This commit is contained in:
Edgar 2024-01-15 21:36:53 +01:00
parent 19c52009ab
commit bfad93ac5a
No known key found for this signature in database
GPG key ID: 70ADAE8F35904387
8 changed files with 389 additions and 109 deletions

View file

@ -41,6 +41,20 @@ pub struct PathExpr {
pub span: Span,
}
impl PathExpr {
pub fn get_full_path(&self) -> String {
let mut result = self.first.name.clone();
for path in &self.extra {
result.push('.');
match path {
PathSegment::Field(name) => result.push_str(&name.name),
PathSegment::Index { .. } => result.push_str("[]"),
}
}
result
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PathSegment {
Field(Ident),

View file

@ -2,20 +2,25 @@ use std::{collections::HashMap, error::Error};
use bumpalo::Bump;
use edlang_ast::{
ArithOp, AssignStmt, BinaryOp, Constant, Expression, Function, LetStmt, Module,
ModuleStatement, ReturnStmt, Statement, Struct, ValueExpr,
ArithOp, AssignStmt, BinaryOp, CmpOp, Constant, Expression, FnCallExpr, Function, IfStmt,
LetStmt, LogicOp, Module, ModuleStatement, ReturnStmt, Statement, Struct, ValueExpr,
};
use edlang_session::Session;
use melior::{
dialect::{arith, func, memref},
dialect::{
arith::{self, CmpiPredicate},
cf, func, memref,
},
ir::{
attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute},
r#type::{FunctionType, IntegerType, MemRefType},
Attribute, Block, BlockRef, Location, Module as MeliorModule, Region, Type, Value,
ValueLike,
Attribute, Block, BlockRef, Location, Module as MeliorModule, Region, Type, TypeLike,
Value, ValueLike,
},
Context as MeliorContext,
};
use crate::util::call_site;
#[derive(Debug, Clone)]
pub struct LocalVar<'ctx, 'parent: 'ctx> {
pub ast_type: edlang_ast::Type,
@ -48,7 +53,19 @@ struct ScopeContext<'ctx, 'parent: 'ctx> {
pub functions: HashMap<String, &'parent Function>,
pub structs: HashMap<String, &'parent Struct>,
pub constants: HashMap<String, &'parent Constant>,
pub ret_type: Option<&'parent edlang_ast::Type>,
pub function: Option<&'parent edlang_ast::Function>,
}
impl<'ctx, 'parent: 'ctx> ScopeContext<'ctx, 'parent> {
fn is_type_signed(&self, type_info: &edlang_ast::Type) -> bool {
let signed = ["i8", "i16", "i32", "i64", "i128"];
signed.contains(&type_info.name.name.as_str())
}
fn is_float(&self, type_info: &edlang_ast::Type) -> bool {
let signed = ["f32", "f64"];
signed.contains(&type_info.name.name.as_str())
}
}
struct BlockHelper<'ctx, 'this: 'ctx> {
@ -155,6 +172,7 @@ fn compile_function_def<'ctx, 'parent>(
let region = Region::new();
let location = get_location(context, session, info.name.span.lo);
let location = Location::name(context, &info.name.name, location);
let mut args = Vec::with_capacity(info.params.len());
let mut fn_args_types = Vec::with_capacity(info.params.len());
@ -162,6 +180,7 @@ fn compile_function_def<'ctx, 'parent>(
for param in &info.params {
let param_type = scope_ctx.resolve_type(context, &param.arg_type)?;
let loc = get_location(context, session, param.name.span.lo);
let loc = Location::name(context, &param.name.name, loc);
args.push((param_type, loc));
fn_args_types.push(param_type);
}
@ -183,7 +202,7 @@ fn compile_function_def<'ctx, 'parent>(
};
let fn_block = helper.append_block(Block::new(&args));
let mut scope_ctx = scope_ctx.clone();
scope_ctx.ret_type = info.return_type.as_ref();
scope_ctx.function = Some(info);
// Push arguments into locals
for (i, param) in info.params.iter().enumerate() {
@ -205,7 +224,11 @@ fn compile_function_def<'ctx, 'parent>(
if final_block.terminator().is_none() {
final_block.append_operation(func::r#return(
&[],
get_location(context, session, info.span.hi),
Location::name(
context,
"return",
get_location(context, session, info.span.hi),
),
));
}
}
@ -231,7 +254,7 @@ fn compile_block<'ctx, 'parent: 'ctx>(
scope_ctx: &mut ScopeContext<'ctx, 'parent>,
helper: &BlockHelper<'ctx, 'parent>,
mut block: &'parent BlockRef<'ctx, 'parent>,
info: &edlang_ast::Block,
info: &'parent edlang_ast::Block,
) -> Result<&'parent BlockRef<'ctx, 'parent>, Box<dyn std::error::Error>> {
tracing::debug!("compiling block");
for stmt in &info.body {
@ -244,11 +267,15 @@ fn compile_block<'ctx, 'parent: 'ctx>(
}
Statement::For(_) => todo!(),
Statement::While(_) => todo!(),
Statement::If(_) => todo!(),
Statement::If(info) => {
block = compile_if_stmt(session, context, scope_ctx, helper, block, info)?;
}
Statement::Return(info) => {
compile_return(session, context, scope_ctx, helper, block, info)?;
}
Statement::FnCall(_) => todo!(),
Statement::FnCall(info) => {
compile_fn_call(session, context, scope_ctx, helper, block, info)?;
}
}
}
@ -261,7 +288,7 @@ fn compile_let<'ctx, 'parent: 'ctx>(
scope_ctx: &mut ScopeContext<'ctx, 'parent>,
helper: &BlockHelper<'ctx, 'parent>,
block: &'parent BlockRef<'ctx, 'parent>,
info: &LetStmt,
info: &'parent LetStmt,
) -> Result<(), Box<dyn std::error::Error>> {
tracing::debug!("compiling let");
let value = compile_expression(
@ -271,7 +298,7 @@ fn compile_let<'ctx, 'parent: 'ctx>(
helper,
block,
&info.value,
Some(scope_ctx.resolve_type(context, &info.r#type)?),
Some(&info.r#type),
)?;
let location = get_location(context, session, info.name.span.lo);
@ -332,7 +359,7 @@ fn compile_assign<'ctx, 'parent: 'ctx>(
helper,
block,
&info.value,
Some(scope_ctx.resolve_type(context, &local.ast_type)?),
Some(&local.ast_type),
)?;
let k0 = block
@ -357,6 +384,10 @@ fn compile_return<'ctx, 'parent: 'ctx>(
) -> Result<(), Box<dyn std::error::Error>> {
tracing::debug!("compiling return");
let location = get_location(context, session, info.span.lo);
let location = Location::name(context, "return", location);
let ret_type = scope_ctx.function.and_then(|x| x.return_type.clone());
if let Some(value) = &info.value {
let value = compile_expression(
session,
@ -365,9 +396,7 @@ fn compile_return<'ctx, 'parent: 'ctx>(
helper,
block,
value,
scope_ctx
.ret_type
.map(|x| scope_ctx.resolve_type(context, x).unwrap()),
ret_type.as_ref(),
)?;
block.append_operation(func::r#return(&[value], location));
} else {
@ -384,7 +413,7 @@ fn compile_expression<'ctx, 'parent: 'ctx>(
helper: &BlockHelper<'ctx, 'parent>,
block: &'parent BlockRef<'ctx, 'parent>,
info: &Expression,
type_hint: Option<Type<'ctx>>,
type_hint: Option<&'parent edlang_ast::Type>,
) -> Result<Value<'ctx, 'parent>, Box<dyn std::error::Error>> {
tracing::debug!("compiling expression");
Ok(match info {
@ -408,10 +437,10 @@ fn compile_expression<'ctx, 'parent: 'ctx>(
.result(0)?
.into(),
ValueExpr::Int { value, span } => {
let type_it = match type_hint {
let type_it = match type_hint.map(|x| scope_ctx.resolve_type(context, x)) {
Some(info) => info,
None => IntegerType::new(context, 32).into(),
};
None => Ok(IntegerType::new(context, 32).into()),
}?;
block
.append_operation(arith::constant(
context,
@ -422,10 +451,10 @@ fn compile_expression<'ctx, 'parent: 'ctx>(
.into()
}
ValueExpr::Float { value, span } => {
let type_it = match type_hint {
let type_it = match type_hint.map(|x| scope_ctx.resolve_type(context, x)) {
Some(info) => info,
None => Type::float32(context),
};
None => Ok(Type::float32(context)),
}?;
block
.append_operation(arith::constant(
context,
@ -443,6 +472,7 @@ fn compile_expression<'ctx, 'parent: 'ctx>(
.expect("local not found");
let location = get_location(context, session, path.first.span.lo);
let location = Location::name(context, &path.first.name, location);
if local.is_alloca {
let k0 = block
@ -464,49 +494,7 @@ fn compile_expression<'ctx, 'parent: 'ctx>(
}
},
Expression::FnCall(info) => {
let mut args = Vec::with_capacity(info.params.len());
let location = get_location(context, session, info.name.span.lo);
let target_fn = scope_ctx
.functions
.get(&info.name.name)
.expect("function not found");
assert_eq!(
info.params.len(),
target_fn.params.len(),
"parameter length doesnt match"
);
for (arg, arg_info) in info.params.iter().zip(&target_fn.params) {
let value = compile_expression(
session,
context,
scope_ctx,
helper,
block,
arg,
Some(scope_ctx.resolve_type(context, &arg_info.arg_type)?),
)?;
args.push(value);
}
let return_type = if let Some(return_type) = &target_fn.return_type {
vec![scope_ctx.resolve_type(context, return_type)?]
} else {
vec![]
};
block
.append_operation(func::call(
context,
FlatSymbolRefAttribute::new(context, &info.name.name),
&args,
&return_type,
location,
))
.result(0)?
.into()
compile_fn_call(session, context, scope_ctx, helper, block, info)?
}
Expression::Unary(_, _) => todo!(),
Expression::Binary(lhs, op, rhs) => {
@ -516,38 +504,123 @@ fn compile_expression<'ctx, 'parent: 'ctx>(
compile_expression(session, context, scope_ctx, helper, block, rhs, type_hint)?;
match op {
BinaryOp::Arith(op, span) => {
match op {
// todo check if its a float or unsigned
ArithOp::Add => block.append_operation(arith::addi(
lhs,
rhs,
get_location(context, session, span.lo),
)),
ArithOp::Sub => block.append_operation(arith::subi(
lhs,
rhs,
get_location(context, session, span.lo),
)),
ArithOp::Mul => block.append_operation(arith::muli(
lhs,
rhs,
get_location(context, session, span.lo),
)),
ArithOp::Div => block.append_operation(arith::divsi(
lhs,
rhs,
get_location(context, session, span.lo),
)),
ArithOp::Mod => block.append_operation(arith::remsi(
lhs,
rhs,
get_location(context, session, span.lo),
)),
}
BinaryOp::Arith(arith_op, span) => {
let location = get_location(context, session, span.lo);
let ast_type_hint = type_hint.expect("type info missing");
block.append_operation(if scope_ctx.is_float(ast_type_hint) {
match arith_op {
ArithOp::Add => arith::addf(lhs, rhs, location),
ArithOp::Sub => arith::subf(lhs, rhs, location),
ArithOp::Mul => arith::mulf(lhs, rhs, location),
ArithOp::Div => arith::divf(lhs, rhs, location),
ArithOp::Mod => arith::remf(lhs, rhs, location),
}
} else {
match arith_op {
ArithOp::Add => arith::addi(lhs, rhs, location),
ArithOp::Sub => arith::subi(lhs, rhs, location),
ArithOp::Mul => arith::muli(lhs, rhs, location),
ArithOp::Div => {
if scope_ctx.is_type_signed(ast_type_hint) {
arith::divsi(lhs, rhs, location)
} else {
arith::divui(lhs, rhs, location)
}
}
ArithOp::Mod => {
if scope_ctx.is_type_signed(ast_type_hint) {
arith::remsi(lhs, rhs, location)
} else {
arith::remui(lhs, rhs, location)
}
}
}
})
}
BinaryOp::Logic(logic_op, span) => {
let location = get_location(context, session, span.lo);
block.append_operation(match logic_op {
LogicOp::And => {
dbg!(lhs.r#type());
dbg!(rhs.r#type());
let const_true = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(1, IntegerType::new(context, 1).into())
.into(),
location,
))
.result(0)?
.into();
let lhs_bool = block
.append_operation(arith::cmpi(
context,
CmpiPredicate::Eq,
lhs,
const_true,
location,
))
.result(0)?
.into();
let rhs_bool = block
.append_operation(arith::cmpi(
context,
CmpiPredicate::Eq,
rhs,
const_true,
location,
))
.result(0)?
.into();
arith::andi(lhs_bool, rhs_bool, location)
}
LogicOp::Or => {
let const_true = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(1, IntegerType::new(context, 1).into())
.into(),
location,
))
.result(0)?
.into();
let lhs_bool = block
.append_operation(arith::cmpi(
context,
CmpiPredicate::Eq,
lhs,
const_true,
location,
))
.result(0)?
.into();
let rhs_bool = block
.append_operation(arith::cmpi(
context,
CmpiPredicate::Eq,
rhs,
const_true,
location,
))
.result(0)?
.into();
arith::ori(lhs_bool, rhs_bool, location)
}
})
}
BinaryOp::Compare(cmp_op, span) => {
let location = get_location(context, session, span.lo);
block.append_operation(match cmp_op {
CmpOp::Eq => arith::cmpi(context, CmpiPredicate::Eq, lhs, rhs, location),
CmpOp::NotEq => arith::cmpi(context, CmpiPredicate::Ne, lhs, rhs, location),
CmpOp::Lt => arith::cmpi(context, CmpiPredicate::Slt, lhs, rhs, location),
CmpOp::LtEq => arith::cmpi(context, CmpiPredicate::Sle, lhs, rhs, location),
CmpOp::Gt => arith::cmpi(context, CmpiPredicate::Sgt, lhs, rhs, location),
CmpOp::GtEq => arith::cmpi(context, CmpiPredicate::Sge, lhs, rhs, location),
})
}
BinaryOp::Logic(_, _) => todo!(),
BinaryOp::Compare(_, _) => todo!(),
BinaryOp::Bitwise(_, _) => todo!(),
}
.result(0)?
@ -555,3 +628,149 @@ fn compile_expression<'ctx, 'parent: 'ctx>(
}
})
}
fn compile_fn_call<'ctx, 'parent: 'ctx>(
session: &Session,
context: &'ctx MeliorContext,
scope_ctx: &ScopeContext<'ctx, 'parent>,
_helper: &BlockHelper<'ctx, 'parent>,
block: &'parent BlockRef<'ctx, 'parent>,
info: &FnCallExpr,
) -> Result<Value<'ctx, 'parent>, Box<dyn Error>> {
let mut args = Vec::with_capacity(info.params.len());
let location = get_location(context, session, info.name.span.lo);
let location_callee = Location::name(context, &info.name.name, location);
let location_caller = Location::name(
context,
&info.name.name,
get_location(context, session, scope_ctx.function.unwrap().span.lo),
);
let location = call_site(location_callee, location_caller);
let target_fn = scope_ctx
.functions
.get(&info.name.name)
.expect("function not found");
assert_eq!(
info.params.len(),
target_fn.params.len(),
"parameter length doesnt match"
);
for (arg, arg_info) in info.params.iter().zip(&target_fn.params) {
let value = compile_expression(
session,
context,
scope_ctx,
_helper,
block,
arg,
Some(&arg_info.arg_type),
)?;
args.push(value);
}
let return_type = if let Some(ret_type) = &target_fn.return_type {
vec![scope_ctx.resolve_type(context, ret_type)?]
} else {
vec![]
};
Ok(block
.append_operation(func::call(
context,
FlatSymbolRefAttribute::new(context, &info.name.name),
&args,
&return_type,
location,
))
.result(0)?
.into())
}
fn compile_if_stmt<'c, 'this: 'c>(
session: &Session,
context: &'c MeliorContext,
scope_ctx: &mut ScopeContext<'c, 'this>,
helper: &BlockHelper<'c, 'this>,
block: &'this BlockRef<'c, 'this>,
info: &'this IfStmt,
) -> Result<&'this BlockRef<'c, 'this>, Box<dyn Error>> {
let condition = compile_expression(
session,
context,
scope_ctx,
helper,
block,
&info.condition,
None,
)?;
let then_successor = helper.append_block(Block::new(&[]));
let else_successor = helper.append_block(Block::new(&[]));
let location = get_location(context, session, info.span.lo);
block.append_operation(cf::cond_br(
context,
condition,
then_successor,
else_successor,
&[],
&[],
Location::name(context, "if", location),
));
let mut then_successor = then_successor;
let mut else_successor = else_successor;
{
let mut then_scope_ctx = scope_ctx.clone();
then_successor = compile_block(
session,
context,
&mut then_scope_ctx,
helper,
then_successor,
&info.then_block,
)?;
}
if let Some(else_block) = info.else_block.as_ref() {
let mut else_scope_ctx = scope_ctx.clone();
else_successor = compile_block(
session,
context,
&mut else_scope_ctx,
helper,
else_successor,
else_block,
)?;
}
// both return
if then_successor.terminator().is_some() && else_successor.terminator().is_some() {
return Ok(then_successor);
}
let final_block = helper.append_block(Block::new(&[]));
if then_successor.terminator().is_none() {
then_successor.append_operation(cf::br(
final_block,
&[],
get_location(context, session, info.span.hi),
));
}
if else_successor.terminator().is_none() {
else_successor.append_operation(cf::br(
final_block,
&[],
get_location(context, session, info.span.hi),
));
}
Ok(final_block)
}

View file

@ -4,7 +4,7 @@ use edlang_ast::Module;
use edlang_session::Session;
use melior::{
dialect::DialectRegistry,
ir::{Location, Module as MeliorModule},
ir::{operation::OperationPrintingFlags, Location, Module as MeliorModule},
pass::{self, PassManager},
utility::{register_all_dialects, register_all_llvm_translations, register_all_passes},
Context as MeliorContext,
@ -42,16 +42,20 @@ impl Context {
assert!(melior_module.as_operation().verify());
tracing::debug!(
"MLIR Code before passes:\n{:#?}",
melior_module.as_operation()
"MLIR Code before passes:\n{}",
melior_module.as_operation().to_string_with_flags(
OperationPrintingFlags::new().enable_debug_info(true, true)
)?
);
// TODO: Add proper error handling.
self.run_pass_manager(&mut melior_module)?;
tracing::debug!(
"MLIR Code after passes:\n{:#?}",
melior_module.as_operation()
"MLIR Code after passes:\n{}",
melior_module.as_operation().to_string_with_flags(
OperationPrintingFlags::new().enable_debug_info(true, true)
)?
);
Ok(melior_module)

View file

@ -32,6 +32,7 @@ pub mod codegen;
mod context;
mod ffi;
pub mod linker;
mod util;
pub fn compile(session: &Session, program: &Module) -> Result<PathBuf, Box<dyn std::error::Error>> {
let context = Context::new();

View file

@ -2,8 +2,6 @@ use std::path::Path;
use tracing::instrument;
// TODO: Implement a proper linker driver, passing only the arguments needed dynamically based on the requirements.
#[instrument(level = "debug")]
pub fn link_shared_lib(input_path: &Path, output_filename: &Path) -> Result<(), std::io::Error> {
let args: &[&str] = {
@ -24,15 +22,41 @@ pub fn link_shared_lib(input_path: &Path, output_filename: &Path) -> Result<(),
}
#[cfg(target_os = "linux")]
{
let (scrt1, crti, crtn) = {
if file_exists("/usr/lib64/Scrt1.o") {
(
"/usr/lib64/Scrt1.o",
"/usr/lib64/crti.o",
"/usr/lib64/crtn.o",
)
} else {
(
"/lib/x86_64-linux-gnu/Scrt1.o",
"/lib/x86_64-linux-gnu/crti.o",
"/lib/x86_64-linux-gnu/crtn.o",
)
}
};
&[
"-pie",
"--hash-style=gnu",
"--eh-frame-hdr",
"-shared",
"--dynamic-linker",
"/lib64/ld-linux-x86-64.so.2",
"-m",
"elf_x86_64",
scrt1,
crti,
"-o",
&output_filename.display().to_string(),
"-L/lib/../lib64",
"-L/usr/lib/../lib64",
"-L/lib64",
"-L/usr/lib64",
"-L/lib/x86_64-linux-gnu",
"-zrelro",
"--no-as-needed",
"-lc",
crtn,
&input_path.display().to_string(),
]
}
@ -96,3 +120,7 @@ pub fn link_binary(input_path: &Path, output_filename: &Path) -> Result<(), std:
proc.wait_with_output()?;
Ok(())
}
fn file_exists(path: &str) -> bool {
Path::new(path).exists()
}

View file

@ -0,0 +1,6 @@
use melior::{ir::Location, Context};
use mlir_sys::mlirLocationCallSiteGet;
pub fn call_site<'c>(callee: Location<'c>, caller: Location<'c>) -> Location<'c> {
unsafe { Location::from_raw(mlirLocationCallSiteGet(callee.to_raw(), caller.to_raw())) }
}

View file

@ -244,7 +244,7 @@ pub(crate) ForStmt: ast::ForStmt = {
}
pub(crate) IfStmt: ast::IfStmt = {
<lo:@L> "if" <condition:Expression> <then_block:Block> <else_block:Block?> <hi:@R> => {
<lo:@L> "if" <condition:Expression> <then_block:Block> <else_block:("else" <Block>)?> <hi:@R> => {
ast::IfStmt {
condition,
then_block,

View file

@ -3,6 +3,14 @@ mod Main {
return a + b;
}
fn check(a: i32) -> i32 {
if a == 2 {
return a;
} else {
return 0;
}
}
fn main() -> i32 {
let x: i32 = 2 + 3;
let y: i32 = add(x, 4);