diff --git a/.gitignore b/.gitignore index ea8c4bf7f..035a8d8ef 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +/target_ed diff --git a/Cargo.lock b/Cargo.lock index 722e4ab85..207fd705f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -539,6 +539,7 @@ dependencies = [ name = "edlang_codegen_mlir" version = "0.1.0" dependencies = [ + "bumpalo", "cc", "edlang_ast", "edlang_parser", @@ -553,6 +554,7 @@ dependencies = [ name = "edlang_driver" version = "0.1.0" dependencies = [ + "ariadne", "clap", "color-eyre", "edlang_ast", @@ -570,6 +572,7 @@ version = "0.1.0" dependencies = [ "ariadne", "edlang_ast", + "itertools 0.12.0", "lalrpop", "lalrpop-util", "logos", @@ -754,6 +757,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -771,7 +783,7 @@ dependencies = [ "diff", "ena", "is-terminal", - "itertools", + "itertools 0.10.5", "lalrpop-util", "petgraph", "pico-args", diff --git a/lib/edlang_ast/src/lib.rs b/lib/edlang_ast/src/lib.rs index 0d8ecd18c..371d628e5 100644 --- a/lib/edlang_ast/src/lib.rs +++ b/lib/edlang_ast/src/lib.rs @@ -12,6 +12,7 @@ impl Span { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Module { + pub name: Ident, pub imports: Vec, pub contents: Vec, pub span: Span, @@ -22,6 +23,7 @@ pub enum ModuleStatement { Function(Function), Constant(Constant), Struct(Struct), + Module(Module), } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -56,6 +58,7 @@ pub struct Ident { pub struct Type { pub name: Ident, pub generics: Vec, + pub span: Span, } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -73,15 +76,61 @@ pub struct Block { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Statement { - Let, - Assign, - For, - While, - If, - Return, + Let(LetStmt), + Assign(AssignStmt), + For(ForStmt), + While(WhileStmt), + If(IfStmt), + Return(ReturnStmt), FnCall(FnCallExpr), } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct LetStmt { + pub name: Ident, + pub is_mut: bool, + pub r#type: Type, + pub value: Expression, + pub span: Span, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct AssignStmt { + pub name: PathExpr, + pub value: Expression, + pub span: Span, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct IfStmt { + pub condition: Expression, + pub then_block: Block, + pub else_block: Option, + pub span: Span, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ForStmt { + pub name: Ident, + pub from: Expression, + pub to: Option, + pub block: Block, + pub span: Span, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct WhileStmt { + pub condition: Expression, + pub block: Block, + pub span: Span, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ReturnStmt { + pub value: Option, + pub span: Span, +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Function { pub name: Ident, @@ -102,7 +151,7 @@ pub struct Constant { } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Field { +pub struct StructField { pub name: Ident, pub r#type: Type, pub span: Span, @@ -112,7 +161,7 @@ pub struct Field { pub struct Struct { pub name: Ident, pub generics: Vec, - pub fields: Vec, + pub fields: Vec, pub span: Span, } @@ -137,7 +186,6 @@ pub enum ValueExpr { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct FnCallExpr { pub name: Ident, - pub generic_params: Vec, pub params: Vec, pub span: Span, } diff --git a/lib/edlang_codegen_mlir/Cargo.toml b/lib/edlang_codegen_mlir/Cargo.toml index 758c9e5a2..719e1898f 100644 --- a/lib/edlang_codegen_mlir/Cargo.toml +++ b/lib/edlang_codegen_mlir/Cargo.toml @@ -11,6 +11,7 @@ categories = ["compilers"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bumpalo = { version = "3.14.0", features = ["std"] } edlang_ast = { version = "0.1.0", path = "../edlang_ast" } edlang_parser = { version = "0.1.0", path = "../edlang_parser" } edlang_session = { version = "0.1.0", path = "../edlang_session" } diff --git a/lib/edlang_codegen_mlir/src/codegen.rs b/lib/edlang_codegen_mlir/src/codegen.rs index 8b1378917..41444bc3f 100644 --- a/lib/edlang_codegen_mlir/src/codegen.rs +++ b/lib/edlang_codegen_mlir/src/codegen.rs @@ -1 +1,559 @@ +use std::{cell::Cell, collections::HashMap, error::Error}; +use bumpalo::Bump; +use edlang_ast::{ + ArithOp, AssignStmt, BinaryOp, Constant, Expression, Function, LetStmt, Module, + ModuleStatement, ReturnStmt, Span, Statement, Struct, ValueExpr, +}; +use edlang_session::Session; +use melior::{ + dialect::{arith, cf, func, memref}, + ir::{ + attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute}, + r#type::{FunctionType, IntegerType, MemRefType}, + Attribute, Block, BlockRef, Location, Module as MeliorModule, Region, Type, Value, + ValueLike, + }, + Context as MeliorContext, +}; + +use crate::context::Context; +#[derive(Debug, Clone)] +pub struct LocalVar<'ctx, 'parent: 'ctx> { + pub ast_type: edlang_ast::Type, + // If it's none its on a register, otherwise allocated on the stack. + pub is_alloca: bool, + pub value: Value<'ctx, 'parent>, +} + +impl<'ctx, 'parent> LocalVar<'ctx, 'parent> { + pub fn param(value: Value<'ctx, 'parent>, ast_type: edlang_ast::Type) -> Self { + Self { + value, + ast_type, + is_alloca: false, + } + } + + pub fn alloca(value: Value<'ctx, 'parent>, ast_type: edlang_ast::Type) -> Self { + Self { + value, + ast_type, + is_alloca: true, + } + } +} + +#[derive(Debug, Clone, Default)] +struct ScopeContext<'ctx, 'parent: 'ctx> { + pub locals: HashMap>, + pub functions: HashMap, + pub structs: HashMap, + pub constants: HashMap, + pub ret_type: Option<&'parent edlang_ast::Type>, +} + +struct BlockHelper<'ctx, 'this: 'ctx> { + region: &'this Region<'ctx>, + blocks_arena: &'this Bump, +} + +impl<'ctx, 'this: 'ctx> BlockHelper<'ctx, 'this> { + pub fn append_block(&self, block: Block<'ctx>) -> &'this BlockRef<'ctx, 'this> { + let block = self.region.append_block(block); + + let block_ref: &'this mut BlockRef<'ctx, 'this> = self.blocks_arena.alloc(block); + block_ref + } +} + +impl<'ctx, 'parent: 'ctx> ScopeContext<'ctx, 'parent> { + fn resolve_type_name( + &self, + context: &'ctx MeliorContext, + name: &str, + ) -> Result, Box> { + Ok(match name { + "u128" | "i128" => IntegerType::new(context, 128).into(), + "u64" | "i64" => IntegerType::new(context, 64).into(), + "u32" | "i32" => IntegerType::new(context, 32).into(), + "u16" | "i16" => IntegerType::new(context, 16).into(), + "u8" | "i8" => IntegerType::new(context, 8).into(), + "f32" => Type::float32(context), + "f64" => Type::float64(context), + "bool" => IntegerType::new(context, 1).into(), + _ => todo!("custom type lookup"), + }) + } + + fn resolve_type( + &self, + context: &'ctx MeliorContext, + r#type: &edlang_ast::Type, + ) -> Result, Box> { + self.resolve_type_name(context, &r#type.name.name) + } +} + +pub fn compile_module( + session: &Session, + context: &MeliorContext, + mlir_module: &MeliorModule, + module: &Module, +) -> Result<(), Box> { + let mut scope_ctx: ScopeContext = Default::default(); + let block = mlir_module.body(); + + // Save types + for statement in &module.contents { + match statement { + ModuleStatement::Function(info) => { + scope_ctx.functions.insert(info.name.name.clone(), info); + } + ModuleStatement::Constant(info) => { + scope_ctx.constants.insert(info.name.name.clone(), info); + } + ModuleStatement::Struct(info) => { + scope_ctx.structs.insert(info.name.name.clone(), info); + } + ModuleStatement::Module(_) => todo!(), + } + } + + for statement in &module.contents { + match statement { + ModuleStatement::Function(info) => { + compile_function_def(session, context, &scope_ctx, &block, info)?; + } + ModuleStatement::Constant(_) => todo!(), + ModuleStatement::Struct(_) => todo!(), + ModuleStatement::Module(_) => todo!(), + } + } + + tracing::debug!("compiled module"); + + Ok(()) +} + +fn get_location<'c>(context: &'c MeliorContext, session: &Session, offset: usize) -> Location<'c> { + let (_, line, col) = session.source.get_offset_line(offset).unwrap(); + Location::new( + context, + &session.file_path.display().to_string(), + line + 1, + col + 1, + ) +} + +fn compile_function_def<'ctx, 'parent>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &ScopeContext<'ctx, 'parent>, + block: &'parent Block<'ctx>, + info: &Function, +) -> Result<(), Box> { + tracing::debug!("compiling function: {}", info.name.name); + let region = Region::new(); + + let location = get_location(context, session, info.name.span.lo); + + let mut args = Vec::with_capacity(info.params.len()); + let mut fn_args_types = Vec::with_capacity(info.params.len()); + + for param in &info.params { + let param_type = scope_ctx.resolve_type(context, ¶m.arg_type)?; + let loc = get_location(context, session, param.name.span.lo); + args.push((param_type, loc)); + fn_args_types.push(param_type); + } + + let return_type = if let Some(return_type) = &info.return_type { + vec![scope_ctx.resolve_type(context, return_type)?] + } else { + vec![] + }; + + let func_type = + TypeAttribute::new(FunctionType::new(context, &fn_args_types, &return_type).into()); + + let blocks_arena = Bump::new(); + { + let helper = BlockHelper { + region: ®ion, + blocks_arena: &blocks_arena, + }; + let fn_block = helper.append_block(Block::new(&args)); + let mut scope_ctx = scope_ctx.clone(); + scope_ctx.ret_type = info.return_type.as_ref(); + + // Push arguments into locals + for (i, param) in info.params.iter().enumerate() { + scope_ctx.locals.insert( + param.name.name.clone(), + LocalVar::param(fn_block.argument(i)?.into(), param.arg_type.clone()), + ); + } + + let final_block = compile_block( + session, + context, + &mut scope_ctx, + &helper, + fn_block, + &info.body, + )?; + + if final_block.terminator().is_none() { + final_block.append_operation(func::r#return( + &[], + get_location(context, session, info.span.hi), + )); + } + } + + let op = func::func( + context, + StringAttribute::new(context, &info.name.name), + func_type, + region, + &[], + location, + ); + assert!(op.verify()); + + block.append_operation(op); + + Ok(()) +} + +fn compile_block<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + mut block: &'parent BlockRef<'ctx, 'parent>, + info: &edlang_ast::Block, +) -> Result<&'parent BlockRef<'ctx, 'parent>, Box> { + tracing::debug!("compiling block"); + for stmt in &info.body { + match stmt { + Statement::Let(info) => { + compile_let(session, context, scope_ctx, helper, block, info)?; + } + Statement::Assign(info) => { + compile_assign(session, context, scope_ctx, helper, block, info)?; + } + Statement::For(_) => todo!(), + Statement::While(_) => todo!(), + Statement::If(_) => todo!(), + Statement::Return(info) => { + compile_return(session, context, scope_ctx, helper, block, info)?; + } + Statement::FnCall(_) => todo!(), + } + } + + Ok(block) +} + +fn compile_let<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + block: &'parent BlockRef<'ctx, 'parent>, + info: &LetStmt, +) -> Result<(), Box> { + tracing::debug!("compiling let"); + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.value, + Some(scope_ctx.resolve_type(context, &info.r#type)?), + )?; + let location = get_location(context, session, info.name.span.lo); + + let memref_type = MemRefType::new(value.r#type(), &[1], None, None); + + let alloca: Value = block + .append_operation(memref::alloca( + context, + memref_type, + &[], + &[], + None, + location, + )) + .result(0)? + .into(); + let k0 = block + .append_operation(arith::constant( + context, + IntegerAttribute::new(0, Type::index(context)).into(), + location, + )) + .result(0)? + .into(); + block.append_operation(memref::store(value, alloca, &[k0], location)); + + scope_ctx.locals.insert( + info.name.name.clone(), + LocalVar::alloca(alloca, info.r#type.clone()), + ); + + Ok(()) +} + +fn compile_assign<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + block: &'parent BlockRef<'ctx, 'parent>, + info: &AssignStmt, +) -> Result<(), Box> { + tracing::debug!("compiling assign"); + let local = scope_ctx + .locals + .get(&info.name.first.name) + .expect("local should exist") + .clone(); + + assert!(local.is_alloca, "can only mutate local stack variables"); + + let location = get_location(context, session, info.name.first.span.lo); + + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.value, + Some(scope_ctx.resolve_type(context, &local.ast_type)?), + )?; + + let k0 = block + .append_operation(arith::constant( + context, + IntegerAttribute::new(0, Type::index(context)).into(), + location, + )) + .result(0)? + .into(); + block.append_operation(memref::store(value, local.value, &[k0], location)); + Ok(()) +} + +fn compile_return<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + block: &'parent BlockRef<'ctx, 'parent>, + info: &ReturnStmt, +) -> Result<(), Box> { + tracing::debug!("compiling return"); + let location = get_location(context, session, info.span.lo); + if let Some(value) = &info.value { + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + value, + scope_ctx + .ret_type + .map(|x| scope_ctx.resolve_type(context, x).unwrap()), + )?; + block.append_operation(func::r#return(&[value], location)); + } else { + block.append_operation(func::r#return(&[], location)); + } + + Ok(()) +} + +fn compile_expression<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + block: &'parent BlockRef<'ctx, 'parent>, + info: &Expression, + type_hint: Option>, +) -> Result, Box> { + tracing::debug!("compiling expression"); + Ok(match info { + Expression::Value(info) => match info { + ValueExpr::Bool { value, span } => block + .append_operation(arith::constant( + context, + IntegerAttribute::new((*value) as i64, IntegerType::new(context, 1).into()) + .into(), + get_location(context, session, span.lo), + )) + .result(0)? + .into(), + ValueExpr::Char { value, span } => block + .append_operation(arith::constant( + context, + IntegerAttribute::new((*value) as i64, IntegerType::new(context, 32).into()) + .into(), + get_location(context, session, span.lo), + )) + .result(0)? + .into(), + ValueExpr::Int { value, span } => { + let type_it = match type_hint { + Some(info) => info, + None => IntegerType::new(context, 32).into(), + }; + block + .append_operation(arith::constant( + context, + IntegerAttribute::new((*value) as i64, type_it).into(), + get_location(context, session, span.lo), + )) + .result(0)? + .into() + } + ValueExpr::Float { value, span } => { + let type_it = match type_hint { + Some(info) => info, + None => Type::float32(context), + }; + block + .append_operation(arith::constant( + context, + Attribute::parse(context, &format!("{value} : {type_it}")).unwrap(), + get_location(context, session, span.lo), + )) + .result(0)? + .into() + } + ValueExpr::Str { value: _, span: _ } => todo!(), + ValueExpr::Path(path) => { + let local = scope_ctx + .locals + .get(&path.first.name) + .expect("local not found"); + + let location = get_location(context, session, path.first.span.lo); + + if local.is_alloca { + let k0 = block + .append_operation(arith::constant( + context, + IntegerAttribute::new(0, Type::index(context)).into(), + location, + )) + .result(0)? + .into(); + + block + .append_operation(memref::load(local.value, &[k0], location)) + .result(0)? + .into() + } else { + local.value + } + } + }, + Expression::FnCall(info) => { + let mut args = Vec::with_capacity(info.params.len()); + let location = get_location(context, session, info.name.span.lo); + + let target_fn = scope_ctx + .functions + .get(&info.name.name) + .expect("function not found"); + + assert_eq!( + info.params.len(), + target_fn.params.len(), + "parameter length doesnt match" + ); + + for (arg, arg_info) in info.params.iter().zip(&target_fn.params) { + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + arg, + Some(scope_ctx.resolve_type(context, &arg_info.arg_type)?), + )?; + args.push(value); + } + + let return_type = if let Some(return_type) = &target_fn.return_type { + vec![scope_ctx.resolve_type(context, return_type)?] + } else { + vec![] + }; + + block + .append_operation(func::call( + context, + FlatSymbolRefAttribute::new(context, &info.name.name), + &args, + &return_type, + location, + )) + .result(0)? + .into() + } + Expression::Unary(_, _) => todo!(), + Expression::Binary(lhs, op, rhs) => { + let lhs = + compile_expression(session, context, scope_ctx, helper, block, lhs, type_hint)?; + let rhs = + compile_expression(session, context, scope_ctx, helper, block, rhs, type_hint)?; + + match op { + BinaryOp::Arith(op, span) => { + match op { + // todo check if its a float or unsigned + ArithOp::Add => block.append_operation(arith::addi( + lhs, + rhs, + get_location(context, session, span.lo), + )), + ArithOp::Sub => block.append_operation(arith::subi( + lhs, + rhs, + get_location(context, session, span.lo), + )), + ArithOp::Mul => block.append_operation(arith::muli( + lhs, + rhs, + get_location(context, session, span.lo), + )), + ArithOp::Div => block.append_operation(arith::divsi( + lhs, + rhs, + get_location(context, session, span.lo), + )), + ArithOp::Mod => block.append_operation(arith::remsi( + lhs, + rhs, + get_location(context, session, span.lo), + )), + } + } + BinaryOp::Logic(_, _) => todo!(), + BinaryOp::Compare(_, _) => todo!(), + BinaryOp::Bitwise(_, _) => todo!(), + } + .result(0)? + .into() + } + }) +} diff --git a/lib/edlang_codegen_mlir/src/context.rs b/lib/edlang_codegen_mlir/src/context.rs new file mode 100644 index 000000000..b4c8307be --- /dev/null +++ b/lib/edlang_codegen_mlir/src/context.rs @@ -0,0 +1,87 @@ +use std::error::Error; + +use edlang_ast::Module; +use edlang_session::Session; +use melior::{ + dialect::DialectRegistry, + ir::{Location, Module as MeliorModule}, + pass::{self, PassManager}, + utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}, + Context as MeliorContext, +}; + +#[derive(Debug, Eq, PartialEq)] +pub struct Context { + melior_context: MeliorContext, +} + +impl Default for Context { + fn default() -> Self { + Self::new() + } +} + +impl Context { + pub fn new() -> Self { + let melior_context = initialize_mlir(); + Self { melior_context } + } + + pub fn compile( + &self, + session: &Session, + module: &Module, + ) -> Result> { + let file_path = session.file_path.display().to_string(); + let location = Location::new(&self.melior_context, &file_path, 0, 0); + + let mut melior_module = MeliorModule::new(location); + + super::codegen::compile_module(session, &self.melior_context, &melior_module, module)?; + + assert!(melior_module.as_operation().verify()); + + tracing::debug!( + "MLIR Code before passes:\n{:#?}", + melior_module.as_operation() + ); + + // TODO: Add proper error handling. + self.run_pass_manager(&mut melior_module)?; + + tracing::debug!( + "MLIR Code after passes:\n{:#?}", + melior_module.as_operation() + ); + + Ok(melior_module) + } + + fn run_pass_manager(&self, module: &mut MeliorModule) -> Result<(), melior::Error> { + let pass_manager = PassManager::new(&self.melior_context); + pass_manager.enable_verifier(true); + pass_manager.add_pass(pass::transform::create_canonicalizer()); + pass_manager.add_pass(pass::conversion::create_scf_to_control_flow()); + pass_manager.add_pass(pass::conversion::create_arith_to_llvm()); + pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm()); + pass_manager.add_pass(pass::conversion::create_func_to_llvm()); + pass_manager.add_pass(pass::conversion::create_index_to_llvm()); + pass_manager.add_pass(pass::conversion::create_finalize_mem_ref_to_llvm()); + pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts()); + pass_manager.run(module) + } +} + +/// Initialize an MLIR context. +pub fn initialize_mlir() -> MeliorContext { + let context = MeliorContext::new(); + context.append_dialect_registry(&{ + let registry = DialectRegistry::new(); + register_all_dialects(®istry); + registry + }); + context.load_all_available_dialects(); + register_all_passes(); + register_all_llvm_translations(&context); + context +} diff --git a/lib/edlang_codegen_mlir/src/lib.rs b/lib/edlang_codegen_mlir/src/lib.rs index db21ca1f4..ab8867428 100644 --- a/lib/edlang_codegen_mlir/src/lib.rs +++ b/lib/edlang_codegen_mlir/src/lib.rs @@ -8,6 +8,8 @@ use std::{ sync::OnceLock, }; +use context::Context; +use edlang_ast::Module; use edlang_session::{OptLevel, Session}; use llvm_sys::{ core::{LLVMContextCreate, LLVMContextDispose, LLVMDisposeMessage, LLVMDisposeModule}, @@ -22,14 +24,24 @@ use llvm_sys::{ LLVMTargetRef, }, }; -use melior::ir::Module; +use melior::ir::Module as MeliorModule; use crate::ffi::mlirTranslateModuleToLLVMIR; pub mod codegen; +mod context; mod ffi; pub mod linker; +pub fn compile(session: &Session, program: &Module) -> Result> { + let context = Context::new(); + let mlir_module = context.compile(session, program)?; + + let object_path = compile_to_object(session, &mlir_module)?; + + Ok(object_path) +} + /// Converts a module to an object. /// The object will be written to the specified target path. /// TODO: error handling @@ -37,7 +49,7 @@ pub mod linker; /// Returns the path to the object. pub fn compile_to_object( session: &Session, - module: &Module, + module: &MeliorModule, ) -> Result> { tracing::debug!("Compiling to object file"); if !session.target_dir.exists() { diff --git a/lib/edlang_driver/Cargo.toml b/lib/edlang_driver/Cargo.toml index 7f407feda..1450aaea1 100644 --- a/lib/edlang_driver/Cargo.toml +++ b/lib/edlang_driver/Cargo.toml @@ -11,6 +11,7 @@ categories = ["compilers"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +ariadne = { version = "0.4.0", features = ["auto-color"] } clap = { version = "4.4.16", features = ["derive"] } color-eyre = "0.6.2" edlang_ast = { version = "0.1.0", path = "../edlang_ast" } diff --git a/lib/edlang_driver/src/lib.rs b/lib/edlang_driver/src/lib.rs index 226e442db..8555f3ba6 100644 --- a/lib/edlang_driver/src/lib.rs +++ b/lib/edlang_driver/src/lib.rs @@ -1,6 +1,9 @@ use std::{error::Error, path::PathBuf, time::Instant}; +use ariadne::Source; use clap::Parser; +use edlang_codegen_mlir::linker::{link_binary, link_shared_lib}; +use edlang_session::{DebugInfo, OptLevel, Session}; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -24,21 +27,17 @@ pub fn main() -> Result<(), Box> { let args = CompilerArgs::parse(); - /* - let db = crate::db::Database::default(); - let source = ProgramSource::new(&db, std::fs::read_to_string(args.input.clone())?); - tracing::debug!("source code:\n{}", source.input(&db)); - let program = match concrete_parser::parse_ast(&db, source) { - Some(x) => x, - None => { - Diagnostics::dump( - &db, - source, - &concrete_parser::parse_ast::accumulated::( - &db, source, - ), - ); - panic!(); + let path = args.input.display().to_string(); + let source = std::fs::read_to_string(&args.input)?; + + let module = edlang_parser::parse_ast(&source); + + let module = match module { + Ok(module) => module, + Err(error) => { + let report = edlang_parser::error_to_report(&path, &error)?; + edlang_parser::print_report(&path, &source, report)?; + std::process::exit(1) } }; @@ -64,21 +63,20 @@ pub fn main() -> Result<(), Box> { } else { OptLevel::None }, - source: source.input(&db).to_string(), + source: Source::from(source), library: args.library, target_dir, output_file, }; tracing::debug!("Compiling with session: {:#?}", session); - // let object_path = concrete_codegen_mlir::compile(&session, &program)?; + let object_path = edlang_codegen_mlir::compile(&session, &module)?; if session.library { link_shared_lib(&object_path, &session.output_file.with_extension("so"))?; } else { link_binary(&object_path, &session.output_file.with_extension(""))?; } - */ let elapsed = start_time.elapsed(); tracing::debug!("Done in {:?}", elapsed); diff --git a/lib/edlang_parser/Cargo.toml b/lib/edlang_parser/Cargo.toml index 13df4025b..74748050c 100644 --- a/lib/edlang_parser/Cargo.toml +++ b/lib/edlang_parser/Cargo.toml @@ -13,6 +13,7 @@ categories = ["compilers"] [dependencies] ariadne = { version = "0.4.0", features = ["auto-color"] } edlang_ast = { version = "0.1.0", path = "../edlang_ast" } +itertools = "0.12.0" lalrpop-util = { version = "0.20.0", features = ["lexer"] } logos = "0.13.0" tracing = { workspace = true } diff --git a/lib/edlang_parser/src/grammar.lalrpop b/lib/edlang_parser/src/grammar.lalrpop index 061d31e26..f2e8836d6 100644 --- a/lib/edlang_parser/src/grammar.lalrpop +++ b/lib/edlang_parser/src/grammar.lalrpop @@ -403,4 +403,5 @@ pub(crate) ModuleStatement: ast::ModuleStatement = { => ast::ModuleStatement::Function(<>), => ast::ModuleStatement::Constant(<>), => ast::ModuleStatement::Struct(<>), + => ast::ModuleStatement::Module(<>), } diff --git a/lib/edlang_parser/src/lib.rs b/lib/edlang_parser/src/lib.rs index 4aa0acf39..c357865f2 100644 --- a/lib/edlang_parser/src/lib.rs +++ b/lib/edlang_parser/src/lib.rs @@ -1,9 +1,11 @@ -use std::{ops::Range, path::Path}; +use std::ops::Range; -use ariadne::{Color, ColorGenerator, Fmt, Label, Report, ReportKind, Source}; -use error::Error; +use crate::error::Error; +use ariadne::{ColorGenerator, Label, Report, ReportKind, Source}; +use itertools::Itertools; use lalrpop_util::ParseError; use lexer::{Lexer, LexicalError}; +use tokens::Token; pub mod error; pub mod lexer; @@ -18,57 +20,76 @@ pub mod grammar { lalrpop_mod!(pub grammar); } -pub fn parse_ast(source: &str) { +pub fn parse_ast( + source: &str, +) -> Result> { let lexer = Lexer::new(source); - let parser = grammar::IdentParser::new(); + let parser = grammar::ModuleParser::new(); + parser.parse(lexer) } -pub fn print_error(path: &str, source: &str, error: &Error) -> Result<(), std::io::Error> { +pub fn print_report<'a>( + path: &'a str, + source: &'a str, + report: Report<'static, (&'a str, Range)>, +) -> Result<(), std::io::Error> { let source = Source::from(source); - match error { + report.eprint((path, source)) +} + +pub fn error_to_report<'a>( + path: &'a str, + error: &Error, +) -> Result)>, std::io::Error> { + let mut colors = ColorGenerator::new(); + let report = match error { ParseError::InvalidToken { location } => { let loc = *location; Report::build(ReportKind::Error, path, loc) - .with_code(1) - .with_message("Invalid token") - .with_label(Label::new((path, loc..(loc + 1))).with_message("invalid token")) + .with_code("P1") + .with_label( + Label::new((path, loc..(loc + 1))) + .with_color(colors.next()) + .with_message("invalid token"), + ) .finish() - .eprint((path, source))?; } ParseError::UnrecognizedEof { location, expected } => { let loc = *location; Report::build(ReportKind::Error, path, loc) - .with_code(2) - .with_message("Unrecognized end of file") - .with_label(Label::new((path, loc..(loc + 1))).with_message(format!( - "unrecognized eof, expected one of the following: {:?}", - expected - ))) + .with_code("P2") + .with_label( + Label::new((path, loc..(loc + 1))) + .with_message(format!( + "unrecognized eof, expected one of the following: {}", + expected.iter().join(", ") + )) + .with_color(colors.next()), + ) .finish() - .eprint((path, source))?; } ParseError::UnrecognizedToken { token, expected } => { Report::build(ReportKind::Error, path, token.0) .with_code(3) - .with_message("Unrecognized token") - .with_label(Label::new((path, token.0..token.2)).with_message(format!( - "unrecognized token {:?}, expected one of the following: {:?}", - token.1, expected - ))) - .finish() - .eprint((path, source))?; - } - ParseError::ExtraToken { token } => { - Report::build(ReportKind::Error, path, token.0) - .with_code(4) - .with_message("Extra token") .with_label( Label::new((path, token.0..token.2)) - .with_message(format!("unexpected extra token {:?}", token.1)), + .with_message(format!( + "unrecognized token {:?}, expected one of the following: {}", + token.1, + expected.iter().join(", ") + )) + .with_color(colors.next()), ) .finish() - .eprint((path, source))?; } + ParseError::ExtraToken { token } => Report::build(ReportKind::Error, path, token.0) + .with_code("P3") + .with_message("Extra token") + .with_label( + Label::new((path, token.0..token.2)) + .with_message(format!("unexpected extra token {:?}", token.1)), + ) + .finish(), ParseError::User { error } => match error { LexicalError::InvalidToken(err, range) => match err { tokens::LexingError::NumberParseError => { @@ -77,26 +98,25 @@ pub fn print_error(path: &str, source: &str, error: &Error) -> Result<(), std::i .with_message("Error parsing literal number") .with_label( Label::new((path, range.start..range.end)) - .with_message("error parsing literal number"), + .with_message("error parsing literal number") + .with_color(colors.next()), ) .finish() - .eprint((path, source))?; - } - tokens::LexingError::Other => { - Report::build(ReportKind::Error, path, range.start) - .with_code(4) - .with_message("Other error") - .with_label( - Label::new((path, range.start..range.end)).with_message("other error"), - ) - .finish() - .eprint((path, source))?; } + tokens::LexingError::Other => Report::build(ReportKind::Error, path, range.start) + .with_code(4) + .with_message("Other error") + .with_label( + Label::new((path, range.start..range.end)) + .with_message("other error") + .with_color(colors.next()), + ) + .finish(), }, }, - } + }; - Ok(()) + Ok(report) } #[cfg(test)] diff --git a/programs/example.ed b/programs/example.ed index d863d7052..e2948ac0a 100644 --- a/programs/example.ed +++ b/programs/example.ed @@ -1,9 +1,11 @@ -fn add(a: i32, b: i32) -> i32 { - return a + b; -} +mod Main { + fn add(a: i32, b: i32) -> i32 { + return a + b; + } -fn main() -> i32 { - let x = 2 + 3; - let y = add(x, 4); - return y; + fn main() -> i32 { + let x: i32 = 2 + 3; + let y: i32 = add(x, 4); + return y; + } } diff --git a/programs/ifelse.ed b/programs/ifelse.ed deleted file mode 100644 index 6703b2c37..000000000 --- a/programs/ifelse.ed +++ /dev/null @@ -1,15 +0,0 @@ -fn works(x: i64) -> i64 { - let z = 0i64; - if 2i64 == x { - z = x * 2i64; - } else { - z = x * 3i64; - } - return z; -} - -fn main() -> i64 { - let y = 2i64; - let z = y; - return works(z); -} diff --git a/programs/simple.ed b/programs/simple.ed deleted file mode 100644 index c21572016..000000000 --- a/programs/simple.ed +++ /dev/null @@ -1,24 +0,0 @@ -struct Hello { - x: i32, - y: i32, -} - -fn test(x: Hello) { - return; -} - -fn works(x: i64) -> i64 { - let z = 0i64; - if 2i64 == x { - z = x * 2i64; - } else { - z = x * 3i64; - } - return z; -} - -fn main() -> i64 { - let y = 2i64; - let z = y; - return works(z); -} diff --git a/programs/std.ed b/programs/std.ed deleted file mode 100644 index 5cbfdd516..000000000 --- a/programs/std.ed +++ /dev/null @@ -1,4 +0,0 @@ - -struct String { - -} diff --git a/programs/struct.ed b/programs/struct.ed deleted file mode 100644 index b73f52af9..000000000 --- a/programs/struct.ed +++ /dev/null @@ -1,10 +0,0 @@ - - -struct Hello { - y: i16, - x: i32, -} - -fn test(x: Hello) { - return; -}