From bfad93ac5afe483edc719bfd8f8cb08f5bd2c83f Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 21:36:53 +0100 Subject: [PATCH] ok --- lib/edlang_ast/src/lib.rs | 14 + lib/edlang_codegen_mlir/src/codegen.rs | 415 +++++++++++++++++++------ lib/edlang_codegen_mlir/src/context.rs | 14 +- lib/edlang_codegen_mlir/src/lib.rs | 1 + lib/edlang_codegen_mlir/src/linker.rs | 38 ++- lib/edlang_codegen_mlir/src/util.rs | 6 + lib/edlang_parser/src/grammar.lalrpop | 2 +- programs/example.ed | 8 + 8 files changed, 389 insertions(+), 109 deletions(-) create mode 100644 lib/edlang_codegen_mlir/src/util.rs diff --git a/lib/edlang_ast/src/lib.rs b/lib/edlang_ast/src/lib.rs index 371d628e5..28a0ceffa 100644 --- a/lib/edlang_ast/src/lib.rs +++ b/lib/edlang_ast/src/lib.rs @@ -41,6 +41,20 @@ pub struct PathExpr { pub span: Span, } +impl PathExpr { + pub fn get_full_path(&self) -> String { + let mut result = self.first.name.clone(); + for path in &self.extra { + result.push('.'); + match path { + PathSegment::Field(name) => result.push_str(&name.name), + PathSegment::Index { .. } => result.push_str("[]"), + } + } + result + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum PathSegment { Field(Ident), diff --git a/lib/edlang_codegen_mlir/src/codegen.rs b/lib/edlang_codegen_mlir/src/codegen.rs index f2d930e17..e8850ded1 100644 --- a/lib/edlang_codegen_mlir/src/codegen.rs +++ b/lib/edlang_codegen_mlir/src/codegen.rs @@ -2,20 +2,25 @@ use std::{collections::HashMap, error::Error}; use bumpalo::Bump; use edlang_ast::{ - ArithOp, AssignStmt, BinaryOp, Constant, Expression, Function, LetStmt, Module, - ModuleStatement, ReturnStmt, Statement, Struct, ValueExpr, + ArithOp, AssignStmt, BinaryOp, CmpOp, Constant, Expression, FnCallExpr, Function, IfStmt, + LetStmt, LogicOp, Module, ModuleStatement, ReturnStmt, Statement, Struct, ValueExpr, }; use edlang_session::Session; use melior::{ - dialect::{arith, func, memref}, + dialect::{ + arith::{self, CmpiPredicate}, + cf, func, memref, + }, ir::{ attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute}, r#type::{FunctionType, IntegerType, MemRefType}, - Attribute, Block, BlockRef, Location, Module as MeliorModule, Region, Type, Value, - ValueLike, + Attribute, Block, BlockRef, Location, Module as MeliorModule, Region, Type, TypeLike, + Value, ValueLike, }, Context as MeliorContext, }; + +use crate::util::call_site; #[derive(Debug, Clone)] pub struct LocalVar<'ctx, 'parent: 'ctx> { pub ast_type: edlang_ast::Type, @@ -48,7 +53,19 @@ struct ScopeContext<'ctx, 'parent: 'ctx> { pub functions: HashMap, pub structs: HashMap, pub constants: HashMap, - pub ret_type: Option<&'parent edlang_ast::Type>, + pub function: Option<&'parent edlang_ast::Function>, +} + +impl<'ctx, 'parent: 'ctx> ScopeContext<'ctx, 'parent> { + fn is_type_signed(&self, type_info: &edlang_ast::Type) -> bool { + let signed = ["i8", "i16", "i32", "i64", "i128"]; + signed.contains(&type_info.name.name.as_str()) + } + + fn is_float(&self, type_info: &edlang_ast::Type) -> bool { + let signed = ["f32", "f64"]; + signed.contains(&type_info.name.name.as_str()) + } } struct BlockHelper<'ctx, 'this: 'ctx> { @@ -155,6 +172,7 @@ fn compile_function_def<'ctx, 'parent>( let region = Region::new(); let location = get_location(context, session, info.name.span.lo); + let location = Location::name(context, &info.name.name, location); let mut args = Vec::with_capacity(info.params.len()); let mut fn_args_types = Vec::with_capacity(info.params.len()); @@ -162,6 +180,7 @@ fn compile_function_def<'ctx, 'parent>( for param in &info.params { let param_type = scope_ctx.resolve_type(context, ¶m.arg_type)?; let loc = get_location(context, session, param.name.span.lo); + let loc = Location::name(context, ¶m.name.name, loc); args.push((param_type, loc)); fn_args_types.push(param_type); } @@ -183,7 +202,7 @@ fn compile_function_def<'ctx, 'parent>( }; let fn_block = helper.append_block(Block::new(&args)); let mut scope_ctx = scope_ctx.clone(); - scope_ctx.ret_type = info.return_type.as_ref(); + scope_ctx.function = Some(info); // Push arguments into locals for (i, param) in info.params.iter().enumerate() { @@ -205,7 +224,11 @@ fn compile_function_def<'ctx, 'parent>( if final_block.terminator().is_none() { final_block.append_operation(func::r#return( &[], - get_location(context, session, info.span.hi), + Location::name( + context, + "return", + get_location(context, session, info.span.hi), + ), )); } } @@ -231,7 +254,7 @@ fn compile_block<'ctx, 'parent: 'ctx>( scope_ctx: &mut ScopeContext<'ctx, 'parent>, helper: &BlockHelper<'ctx, 'parent>, mut block: &'parent BlockRef<'ctx, 'parent>, - info: &edlang_ast::Block, + info: &'parent edlang_ast::Block, ) -> Result<&'parent BlockRef<'ctx, 'parent>, Box> { tracing::debug!("compiling block"); for stmt in &info.body { @@ -244,11 +267,15 @@ fn compile_block<'ctx, 'parent: 'ctx>( } Statement::For(_) => todo!(), Statement::While(_) => todo!(), - Statement::If(_) => todo!(), + Statement::If(info) => { + block = compile_if_stmt(session, context, scope_ctx, helper, block, info)?; + } Statement::Return(info) => { compile_return(session, context, scope_ctx, helper, block, info)?; } - Statement::FnCall(_) => todo!(), + Statement::FnCall(info) => { + compile_fn_call(session, context, scope_ctx, helper, block, info)?; + } } } @@ -261,7 +288,7 @@ fn compile_let<'ctx, 'parent: 'ctx>( scope_ctx: &mut ScopeContext<'ctx, 'parent>, helper: &BlockHelper<'ctx, 'parent>, block: &'parent BlockRef<'ctx, 'parent>, - info: &LetStmt, + info: &'parent LetStmt, ) -> Result<(), Box> { tracing::debug!("compiling let"); let value = compile_expression( @@ -271,7 +298,7 @@ fn compile_let<'ctx, 'parent: 'ctx>( helper, block, &info.value, - Some(scope_ctx.resolve_type(context, &info.r#type)?), + Some(&info.r#type), )?; let location = get_location(context, session, info.name.span.lo); @@ -332,7 +359,7 @@ fn compile_assign<'ctx, 'parent: 'ctx>( helper, block, &info.value, - Some(scope_ctx.resolve_type(context, &local.ast_type)?), + Some(&local.ast_type), )?; let k0 = block @@ -357,6 +384,10 @@ fn compile_return<'ctx, 'parent: 'ctx>( ) -> Result<(), Box> { tracing::debug!("compiling return"); let location = get_location(context, session, info.span.lo); + let location = Location::name(context, "return", location); + + let ret_type = scope_ctx.function.and_then(|x| x.return_type.clone()); + if let Some(value) = &info.value { let value = compile_expression( session, @@ -365,9 +396,7 @@ fn compile_return<'ctx, 'parent: 'ctx>( helper, block, value, - scope_ctx - .ret_type - .map(|x| scope_ctx.resolve_type(context, x).unwrap()), + ret_type.as_ref(), )?; block.append_operation(func::r#return(&[value], location)); } else { @@ -384,7 +413,7 @@ fn compile_expression<'ctx, 'parent: 'ctx>( helper: &BlockHelper<'ctx, 'parent>, block: &'parent BlockRef<'ctx, 'parent>, info: &Expression, - type_hint: Option>, + type_hint: Option<&'parent edlang_ast::Type>, ) -> Result, Box> { tracing::debug!("compiling expression"); Ok(match info { @@ -408,10 +437,10 @@ fn compile_expression<'ctx, 'parent: 'ctx>( .result(0)? .into(), ValueExpr::Int { value, span } => { - let type_it = match type_hint { + let type_it = match type_hint.map(|x| scope_ctx.resolve_type(context, x)) { Some(info) => info, - None => IntegerType::new(context, 32).into(), - }; + None => Ok(IntegerType::new(context, 32).into()), + }?; block .append_operation(arith::constant( context, @@ -422,10 +451,10 @@ fn compile_expression<'ctx, 'parent: 'ctx>( .into() } ValueExpr::Float { value, span } => { - let type_it = match type_hint { + let type_it = match type_hint.map(|x| scope_ctx.resolve_type(context, x)) { Some(info) => info, - None => Type::float32(context), - }; + None => Ok(Type::float32(context)), + }?; block .append_operation(arith::constant( context, @@ -443,6 +472,7 @@ fn compile_expression<'ctx, 'parent: 'ctx>( .expect("local not found"); let location = get_location(context, session, path.first.span.lo); + let location = Location::name(context, &path.first.name, location); if local.is_alloca { let k0 = block @@ -464,49 +494,7 @@ fn compile_expression<'ctx, 'parent: 'ctx>( } }, Expression::FnCall(info) => { - let mut args = Vec::with_capacity(info.params.len()); - let location = get_location(context, session, info.name.span.lo); - - let target_fn = scope_ctx - .functions - .get(&info.name.name) - .expect("function not found"); - - assert_eq!( - info.params.len(), - target_fn.params.len(), - "parameter length doesnt match" - ); - - for (arg, arg_info) in info.params.iter().zip(&target_fn.params) { - let value = compile_expression( - session, - context, - scope_ctx, - helper, - block, - arg, - Some(scope_ctx.resolve_type(context, &arg_info.arg_type)?), - )?; - args.push(value); - } - - let return_type = if let Some(return_type) = &target_fn.return_type { - vec![scope_ctx.resolve_type(context, return_type)?] - } else { - vec![] - }; - - block - .append_operation(func::call( - context, - FlatSymbolRefAttribute::new(context, &info.name.name), - &args, - &return_type, - location, - )) - .result(0)? - .into() + compile_fn_call(session, context, scope_ctx, helper, block, info)? } Expression::Unary(_, _) => todo!(), Expression::Binary(lhs, op, rhs) => { @@ -516,38 +504,123 @@ fn compile_expression<'ctx, 'parent: 'ctx>( compile_expression(session, context, scope_ctx, helper, block, rhs, type_hint)?; match op { - BinaryOp::Arith(op, span) => { - match op { - // todo check if its a float or unsigned - ArithOp::Add => block.append_operation(arith::addi( - lhs, - rhs, - get_location(context, session, span.lo), - )), - ArithOp::Sub => block.append_operation(arith::subi( - lhs, - rhs, - get_location(context, session, span.lo), - )), - ArithOp::Mul => block.append_operation(arith::muli( - lhs, - rhs, - get_location(context, session, span.lo), - )), - ArithOp::Div => block.append_operation(arith::divsi( - lhs, - rhs, - get_location(context, session, span.lo), - )), - ArithOp::Mod => block.append_operation(arith::remsi( - lhs, - rhs, - get_location(context, session, span.lo), - )), - } + BinaryOp::Arith(arith_op, span) => { + let location = get_location(context, session, span.lo); + let ast_type_hint = type_hint.expect("type info missing"); + + block.append_operation(if scope_ctx.is_float(ast_type_hint) { + match arith_op { + ArithOp::Add => arith::addf(lhs, rhs, location), + ArithOp::Sub => arith::subf(lhs, rhs, location), + ArithOp::Mul => arith::mulf(lhs, rhs, location), + ArithOp::Div => arith::divf(lhs, rhs, location), + ArithOp::Mod => arith::remf(lhs, rhs, location), + } + } else { + match arith_op { + ArithOp::Add => arith::addi(lhs, rhs, location), + ArithOp::Sub => arith::subi(lhs, rhs, location), + ArithOp::Mul => arith::muli(lhs, rhs, location), + ArithOp::Div => { + if scope_ctx.is_type_signed(ast_type_hint) { + arith::divsi(lhs, rhs, location) + } else { + arith::divui(lhs, rhs, location) + } + } + ArithOp::Mod => { + if scope_ctx.is_type_signed(ast_type_hint) { + arith::remsi(lhs, rhs, location) + } else { + arith::remui(lhs, rhs, location) + } + } + } + }) + } + BinaryOp::Logic(logic_op, span) => { + let location = get_location(context, session, span.lo); + + block.append_operation(match logic_op { + LogicOp::And => { + dbg!(lhs.r#type()); + dbg!(rhs.r#type()); + let const_true = block + .append_operation(arith::constant( + context, + IntegerAttribute::new(1, IntegerType::new(context, 1).into()) + .into(), + location, + )) + .result(0)? + .into(); + let lhs_bool = block + .append_operation(arith::cmpi( + context, + CmpiPredicate::Eq, + lhs, + const_true, + location, + )) + .result(0)? + .into(); + let rhs_bool = block + .append_operation(arith::cmpi( + context, + CmpiPredicate::Eq, + rhs, + const_true, + location, + )) + .result(0)? + .into(); + arith::andi(lhs_bool, rhs_bool, location) + } + LogicOp::Or => { + let const_true = block + .append_operation(arith::constant( + context, + IntegerAttribute::new(1, IntegerType::new(context, 1).into()) + .into(), + location, + )) + .result(0)? + .into(); + let lhs_bool = block + .append_operation(arith::cmpi( + context, + CmpiPredicate::Eq, + lhs, + const_true, + location, + )) + .result(0)? + .into(); + let rhs_bool = block + .append_operation(arith::cmpi( + context, + CmpiPredicate::Eq, + rhs, + const_true, + location, + )) + .result(0)? + .into(); + arith::ori(lhs_bool, rhs_bool, location) + } + }) + } + BinaryOp::Compare(cmp_op, span) => { + let location = get_location(context, session, span.lo); + block.append_operation(match cmp_op { + CmpOp::Eq => arith::cmpi(context, CmpiPredicate::Eq, lhs, rhs, location), + CmpOp::NotEq => arith::cmpi(context, CmpiPredicate::Ne, lhs, rhs, location), + CmpOp::Lt => arith::cmpi(context, CmpiPredicate::Slt, lhs, rhs, location), + CmpOp::LtEq => arith::cmpi(context, CmpiPredicate::Sle, lhs, rhs, location), + CmpOp::Gt => arith::cmpi(context, CmpiPredicate::Sgt, lhs, rhs, location), + CmpOp::GtEq => arith::cmpi(context, CmpiPredicate::Sge, lhs, rhs, location), + }) } - BinaryOp::Logic(_, _) => todo!(), - BinaryOp::Compare(_, _) => todo!(), BinaryOp::Bitwise(_, _) => todo!(), } .result(0)? @@ -555,3 +628,149 @@ fn compile_expression<'ctx, 'parent: 'ctx>( } }) } + +fn compile_fn_call<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + block: &'parent BlockRef<'ctx, 'parent>, + info: &FnCallExpr, +) -> Result, Box> { + let mut args = Vec::with_capacity(info.params.len()); + let location = get_location(context, session, info.name.span.lo); + let location_callee = Location::name(context, &info.name.name, location); + let location_caller = Location::name( + context, + &info.name.name, + get_location(context, session, scope_ctx.function.unwrap().span.lo), + ); + let location = call_site(location_callee, location_caller); + + let target_fn = scope_ctx + .functions + .get(&info.name.name) + .expect("function not found"); + + assert_eq!( + info.params.len(), + target_fn.params.len(), + "parameter length doesnt match" + ); + + for (arg, arg_info) in info.params.iter().zip(&target_fn.params) { + let value = compile_expression( + session, + context, + scope_ctx, + _helper, + block, + arg, + Some(&arg_info.arg_type), + )?; + args.push(value); + } + + let return_type = if let Some(ret_type) = &target_fn.return_type { + vec![scope_ctx.resolve_type(context, ret_type)?] + } else { + vec![] + }; + + Ok(block + .append_operation(func::call( + context, + FlatSymbolRefAttribute::new(context, &info.name.name), + &args, + &return_type, + location, + )) + .result(0)? + .into()) +} + +fn compile_if_stmt<'c, 'this: 'c>( + session: &Session, + context: &'c MeliorContext, + scope_ctx: &mut ScopeContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this BlockRef<'c, 'this>, + info: &'this IfStmt, +) -> Result<&'this BlockRef<'c, 'this>, Box> { + let condition = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.condition, + None, + )?; + + let then_successor = helper.append_block(Block::new(&[])); + let else_successor = helper.append_block(Block::new(&[])); + + let location = get_location(context, session, info.span.lo); + + block.append_operation(cf::cond_br( + context, + condition, + then_successor, + else_successor, + &[], + &[], + Location::name(context, "if", location), + )); + + let mut then_successor = then_successor; + let mut else_successor = else_successor; + + { + let mut then_scope_ctx = scope_ctx.clone(); + then_successor = compile_block( + session, + context, + &mut then_scope_ctx, + helper, + then_successor, + &info.then_block, + )?; + } + + if let Some(else_block) = info.else_block.as_ref() { + let mut else_scope_ctx = scope_ctx.clone(); + else_successor = compile_block( + session, + context, + &mut else_scope_ctx, + helper, + else_successor, + else_block, + )?; + } + + // both return + if then_successor.terminator().is_some() && else_successor.terminator().is_some() { + return Ok(then_successor); + } + + let final_block = helper.append_block(Block::new(&[])); + + if then_successor.terminator().is_none() { + then_successor.append_operation(cf::br( + final_block, + &[], + get_location(context, session, info.span.hi), + )); + } + + if else_successor.terminator().is_none() { + else_successor.append_operation(cf::br( + final_block, + &[], + get_location(context, session, info.span.hi), + )); + } + + Ok(final_block) +} diff --git a/lib/edlang_codegen_mlir/src/context.rs b/lib/edlang_codegen_mlir/src/context.rs index b4c8307be..e8c89e3fc 100644 --- a/lib/edlang_codegen_mlir/src/context.rs +++ b/lib/edlang_codegen_mlir/src/context.rs @@ -4,7 +4,7 @@ use edlang_ast::Module; use edlang_session::Session; use melior::{ dialect::DialectRegistry, - ir::{Location, Module as MeliorModule}, + ir::{operation::OperationPrintingFlags, Location, Module as MeliorModule}, pass::{self, PassManager}, utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}, Context as MeliorContext, @@ -42,16 +42,20 @@ impl Context { assert!(melior_module.as_operation().verify()); tracing::debug!( - "MLIR Code before passes:\n{:#?}", - melior_module.as_operation() + "MLIR Code before passes:\n{}", + melior_module.as_operation().to_string_with_flags( + OperationPrintingFlags::new().enable_debug_info(true, true) + )? ); // TODO: Add proper error handling. self.run_pass_manager(&mut melior_module)?; tracing::debug!( - "MLIR Code after passes:\n{:#?}", - melior_module.as_operation() + "MLIR Code after passes:\n{}", + melior_module.as_operation().to_string_with_flags( + OperationPrintingFlags::new().enable_debug_info(true, true) + )? ); Ok(melior_module) diff --git a/lib/edlang_codegen_mlir/src/lib.rs b/lib/edlang_codegen_mlir/src/lib.rs index ab8867428..4470915be 100644 --- a/lib/edlang_codegen_mlir/src/lib.rs +++ b/lib/edlang_codegen_mlir/src/lib.rs @@ -32,6 +32,7 @@ pub mod codegen; mod context; mod ffi; pub mod linker; +mod util; pub fn compile(session: &Session, program: &Module) -> Result> { let context = Context::new(); diff --git a/lib/edlang_codegen_mlir/src/linker.rs b/lib/edlang_codegen_mlir/src/linker.rs index 6fb2f784c..b44df1718 100644 --- a/lib/edlang_codegen_mlir/src/linker.rs +++ b/lib/edlang_codegen_mlir/src/linker.rs @@ -2,8 +2,6 @@ use std::path::Path; use tracing::instrument; -// TODO: Implement a proper linker driver, passing only the arguments needed dynamically based on the requirements. - #[instrument(level = "debug")] pub fn link_shared_lib(input_path: &Path, output_filename: &Path) -> Result<(), std::io::Error> { let args: &[&str] = { @@ -24,15 +22,41 @@ pub fn link_shared_lib(input_path: &Path, output_filename: &Path) -> Result<(), } #[cfg(target_os = "linux")] { + let (scrt1, crti, crtn) = { + if file_exists("/usr/lib64/Scrt1.o") { + ( + "/usr/lib64/Scrt1.o", + "/usr/lib64/crti.o", + "/usr/lib64/crtn.o", + ) + } else { + ( + "/lib/x86_64-linux-gnu/Scrt1.o", + "/lib/x86_64-linux-gnu/crti.o", + "/lib/x86_64-linux-gnu/crtn.o", + ) + } + }; + &[ + "-pie", "--hash-style=gnu", "--eh-frame-hdr", - "-shared", + "--dynamic-linker", + "/lib64/ld-linux-x86-64.so.2", + "-m", + "elf_x86_64", + scrt1, + crti, "-o", &output_filename.display().to_string(), - "-L/lib/../lib64", - "-L/usr/lib/../lib64", + "-L/lib64", + "-L/usr/lib64", + "-L/lib/x86_64-linux-gnu", + "-zrelro", + "--no-as-needed", "-lc", + crtn, &input_path.display().to_string(), ] } @@ -96,3 +120,7 @@ pub fn link_binary(input_path: &Path, output_filename: &Path) -> Result<(), std: proc.wait_with_output()?; Ok(()) } + +fn file_exists(path: &str) -> bool { + Path::new(path).exists() +} diff --git a/lib/edlang_codegen_mlir/src/util.rs b/lib/edlang_codegen_mlir/src/util.rs new file mode 100644 index 000000000..3f449382e --- /dev/null +++ b/lib/edlang_codegen_mlir/src/util.rs @@ -0,0 +1,6 @@ +use melior::{ir::Location, Context}; +use mlir_sys::mlirLocationCallSiteGet; + +pub fn call_site<'c>(callee: Location<'c>, caller: Location<'c>) -> Location<'c> { + unsafe { Location::from_raw(mlirLocationCallSiteGet(callee.to_raw(), caller.to_raw())) } +} diff --git a/lib/edlang_parser/src/grammar.lalrpop b/lib/edlang_parser/src/grammar.lalrpop index f2e8836d6..919fe39e3 100644 --- a/lib/edlang_parser/src/grammar.lalrpop +++ b/lib/edlang_parser/src/grammar.lalrpop @@ -244,7 +244,7 @@ pub(crate) ForStmt: ast::ForStmt = { } pub(crate) IfStmt: ast::IfStmt = { - "if" => { + "if" )?> => { ast::IfStmt { condition, then_block, diff --git a/programs/example.ed b/programs/example.ed index e2948ac0a..23a532278 100644 --- a/programs/example.ed +++ b/programs/example.ed @@ -3,6 +3,14 @@ mod Main { return a + b; } + fn check(a: i32) -> i32 { + if a == 2 { + return a; + } else { + return 0; + } + } + fn main() -> i32 { let x: i32 = 2 + 3; let y: i32 = add(x, 4);