diff --git a/Cargo.lock b/Cargo.lock index 324b311c7..df271deb3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -370,6 +370,7 @@ version = "0.1.0" dependencies = [ "edlang_ast", "edlang_ir", + "tracing", ] [[package]] diff --git a/lib/edlang_ast/src/lib.rs b/lib/edlang_ast/src/lib.rs index 6d8efbb8e..ca1b512c6 100644 --- a/lib/edlang_ast/src/lib.rs +++ b/lib/edlang_ast/src/lib.rs @@ -18,7 +18,7 @@ pub enum ModuleStatement { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Import { - pub path: PathExpr, + pub module: Vec, /// If symbols is empty then the last path ident is the symbol. pub symbols: Vec, pub span: Span, diff --git a/lib/edlang_codegen_mlir/src/codegen.rs b/lib/edlang_codegen_mlir/src/codegen.rs index e0f0d7ee8..a9276bb75 100644 --- a/lib/edlang_codegen_mlir/src/codegen.rs +++ b/lib/edlang_codegen_mlir/src/codegen.rs @@ -4,24 +4,22 @@ use edlang_ir as ir; use edlang_ir::DefId; use edlang_session::Session; use inkwell::{ - attributes::Attribute, builder::{Builder, BuilderError}, context::Context, debug_info::{DICompileUnit, DebugInfoBuilder}, module::Module, - targets::{InitializationConfig, Target, TargetData, TargetMachine, TargetTriple}, + targets::{InitializationConfig, Target, TargetData, TargetMachine}, types::{AnyType, BasicMetadataTypeEnum, BasicType}, - values::{AnyValue, AnyValueEnum, BasicValue, BasicValueEnum, PointerValue}, + values::{BasicValue, BasicValueEnum, PointerValue}, }; -use ir::{ConstData, Operand, TypeInfo, ValueTree}; +use ir::{ModuleBody, ProgramBody, TypeInfo, ValueTree}; use tracing::info; #[derive(Debug, Clone, Copy)] struct CompileCtx<'a> { context: &'a Context, session: &'a Session, - modules: &'a HashMap, - symbols: &'a HashMap, + program: &'a ProgramBody, } struct ModuleCompileCtx<'ctx, 'm> { @@ -31,21 +29,23 @@ struct ModuleCompileCtx<'ctx, 'm> { di_builder: DebugInfoBuilder<'ctx>, di_unit: DICompileUnit<'ctx>, target_data: TargetData, + module_id: DefId, } -pub fn compile( - session: &Session, - 
modules: &HashMap, - symbols: &HashMap, -) -> Result> { +impl<'ctx, 'm> ModuleCompileCtx<'ctx, 'm> { + pub fn get_module_body(&self) -> &ModuleBody { + self.ctx.program.modules.get(&self.module_id).unwrap() + } +} + +pub fn compile(session: &Session, program: &ProgramBody) -> Result> { let context = Context::create(); let builder = context.create_builder(); let ctx = CompileCtx { context: &context, session, - modules, - symbols, + program, }; let mut llvm_modules = Vec::new(); @@ -68,9 +68,9 @@ pub fn compile( let filename = session.file_path.file_name().unwrap().to_string_lossy(); let dir = session.file_path.parent().unwrap().to_string_lossy(); - for (id, module) in modules.iter() { - let name = ctx.symbols.get(id).unwrap(); - let llvm_module = context.create_module(name); + for module_id in program.top_level_modules.iter() { + let module = ctx.program.modules.get(module_id).unwrap(); + let llvm_module = context.create_module(&module.name); llvm_module.set_source_file_name(&filename); llvm_module.set_triple(&triple); let (di_builder, di_unit) = llvm_module.create_debug_info_builder( @@ -84,7 +84,7 @@ pub fn compile( 1, "", // split name inkwell::debug_info::DWARFEmissionKind::Full, - module.module_id.module_id.try_into().unwrap(), // compile unit id? + module.module_id.program_id.try_into().unwrap(), // compile unit id? 
false, false, "", @@ -98,9 +98,10 @@ pub fn compile( di_unit, builder: &builder, target_data: machine.get_target_data(), + module_id: *module_id, }; - compile_module(&module_ctx, module); + compile_module(&module_ctx, *module_id); module_ctx.module.verify()?; @@ -125,37 +126,36 @@ pub fn compile( Ok(session.output_file.with_extension("o")) } -fn compile_module(ctx: &ModuleCompileCtx, module: &ir::ModuleBody) { +fn compile_module(ctx: &ModuleCompileCtx, module_id: DefId) { + let module = ctx.ctx.program.modules.get(&module_id).unwrap(); info!("compiling module"); - for (_fn_id, func) in module.functions.iter() { - compile_fn_signature(ctx, func); + for id in module.functions.iter() { + compile_fn_signature(ctx, *id); } - for (_fn_id, func) in module.functions.iter() { - compile_fn(ctx, func).unwrap(); + for id in module.functions.iter() { + compile_fn(ctx, *id).unwrap(); } } -fn compile_fn_signature(ctx: &ModuleCompileCtx, body: &ir::Body) { - let name = ctx.ctx.symbols.get(&body.def_id).unwrap(); - info!("compiling fn sig: {}", name); +fn compile_fn_signature(ctx: &ModuleCompileCtx, fn_id: DefId) { + let (arg_types, ret_type) = ctx.ctx.program.function_signatures.get(&fn_id).unwrap(); + let body = ctx.ctx.program.functions.get(&fn_id).unwrap(); + info!("compiling fn sig: {}", body.name); - let (args, ret_type) = { (body.get_args(), body.ret_type.clone()) }; - - let args: Vec = args + let args: Vec = arg_types .iter() - .map(|x| compile_basic_type(ctx, &x.ty).into()) + .map(|x| compile_basic_type(ctx, x).into()) .collect(); - // let ret_type = compile_basic_type(ctx, &ret_type); let fn_type = if let ir::TypeKind::Unit = ret_type.kind { ctx.ctx.context.void_type().fn_type(&args, false) } else { - compile_basic_type(ctx, &ret_type).fn_type(&args, false) + compile_basic_type(ctx, ret_type).fn_type(&args, false) }; let fn_value = ctx.module.add_function( - name, + &body.name, fn_type, Some(if body.is_extern { inkwell::module::Linkage::AvailableExternally @@ -173,12 
+173,11 @@ fn compile_fn_signature(ctx: &ModuleCompileCtx, body: &ir::Body) { ); } -fn compile_fn(ctx: &ModuleCompileCtx, body: &ir::Body) -> Result<(), BuilderError> { - let name = ctx.ctx.symbols.get(&body.def_id).unwrap(); - info!("compiling fn body: {}", name); - // let (args, ret_type) = { (body.get_args(), body.ret_type.clone().unwrap()) }; +fn compile_fn(ctx: &ModuleCompileCtx, fn_id: DefId) -> Result<(), BuilderError> { + let body = ctx.ctx.program.functions.get(&fn_id).unwrap(); + info!("compiling fn body: {}", body.name); - let fn_value = ctx.module.get_function(name).unwrap(); + let fn_value = ctx.module.get_function(&body.name).unwrap(); let block = ctx.ctx.context.append_basic_block(fn_value, "entry"); ctx.builder.position_at_end(block); @@ -236,21 +235,11 @@ fn compile_fn(ctx: &ModuleCompileCtx, body: &ir::Body) -> Result<(), BuilderErro for stmt in &block.statements { info!("compiling stmt"); match &stmt.kind { - ir::StatementKind::Assign(place, rvalue) => match rvalue { - ir::RValue::Use(op) => { - let value = compile_load_operand(ctx, body, &locals, op)?.0; - ctx.builder - .build_store(*locals.get(&place.local).unwrap(), value)?; - } - ir::RValue::Ref(_, _) => todo!(), - ir::RValue::BinOp(op, lhs, rhs) => { - let value = compile_bin_op(ctx, body, &locals, *op, lhs, rhs)?; - ctx.builder - .build_store(*locals.get(&place.local).unwrap(), value)?; - } - ir::RValue::LogicOp(_, _, _) => todo!(), - ir::RValue::UnOp(_, _) => todo!(), - }, + ir::StatementKind::Assign(place, rvalue) => { + let (value, _value_ty) = compile_rvalue(ctx, fn_id, &locals, rvalue)?; + ctx.builder + .build_store(*locals.get(&place.local).unwrap(), value)?; + } ir::StatementKind::StorageLive(_) => { // https://llvm.org/docs/LangRef.html#int-lifestart } @@ -278,44 +267,28 @@ fn compile_fn(ctx: &ModuleCompileCtx, body: &ir::Body) -> Result<(), BuilderErro ir::Terminator::Call { func, args, - dest, + destination: dest, target, } => { - if let Operand::Constant(c) = func { - if let 
ir::TypeKind::FnDef(fn_id, generics) = &c.type_info.kind { - let fn_symbol = ctx.ctx.symbols.get(fn_id).unwrap(); - let fn_value = ctx.module.get_function(fn_symbol).unwrap(); - let args: Vec<_> = args - .iter() - .map(|x| { - compile_load_operand(ctx, body, &locals, x) - .unwrap() - .0 - .into() - }) - .collect(); - let result = ctx.builder.build_call(fn_value, &args, "")?; + let target_fn_body = ctx.ctx.program.functions.get(func).unwrap(); + let fn_value = ctx.module.get_function(&target_fn_body.name).unwrap(); + let args: Vec<_> = args + .iter() + .map(|x| compile_rvalue(ctx, fn_id, &locals, x).unwrap().0.into()) + .collect(); + let result = ctx.builder.build_call(fn_value, &args, "")?; - if let Some(dest) = dest { - let is_void = - matches!(body.locals[dest.local].ty.kind, ir::TypeKind::Unit); + let is_void = matches!(body.locals[dest.local].ty.kind, ir::TypeKind::Unit); - if !is_void { - ctx.builder.build_store( - *locals.get(&dest.local).unwrap(), - result.try_as_basic_value().expect_left("value was right"), - )?; - } - } + if !is_void { + ctx.builder.build_store( + *locals.get(&dest.local).unwrap(), + result.try_as_basic_value().expect_left("value was right"), + )?; + } - if let Some(target) = target { - ctx.builder.build_unconditional_branch(blocks[*target])?; - } - } else { - todo!() - } - } else { - todo!() + if let Some(target) = target { + ctx.builder.build_unconditional_branch(blocks[*target])?; } } ir::Terminator::Unreachable => { @@ -329,21 +302,21 @@ fn compile_fn(ctx: &ModuleCompileCtx, body: &ir::Body) -> Result<(), BuilderErro fn compile_bin_op<'ctx>( ctx: &ModuleCompileCtx<'ctx, '_>, - body: &ir::Body, + fn_id: DefId, locals: &HashMap>, op: ir::BinOp, lhs: &ir::Operand, rhs: &ir::Operand, -) -> Result, BuilderError> { - let (lhs_value, lhs_ty) = compile_load_operand(ctx, body, locals, lhs)?; - let (rhs_value, _rhs_ty) = compile_load_operand(ctx, body, locals, rhs)?; +) -> Result<(BasicValueEnum<'ctx>, TypeInfo), BuilderError> { + let 
(lhs_value, lhs_ty) = compile_load_operand(ctx, fn_id, locals, lhs)?; + let (rhs_value, _rhs_ty) = compile_load_operand(ctx, fn_id, locals, rhs)?; let is_float = matches!(lhs_ty.kind, ir::TypeKind::Float(_)); let is_signed = matches!(lhs_ty.kind, ir::TypeKind::Int(_)); Ok(match op { ir::BinOp::Add => { - if is_float { + let value = if is_float { ctx.builder .build_float_add( lhs_value.into_float_value(), @@ -355,10 +328,11 @@ fn compile_bin_op<'ctx>( ctx.builder .build_int_add(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? .as_basic_value_enum() - } + }; + (value, lhs_ty) } ir::BinOp::Sub => { - if is_float { + let value = if is_float { ctx.builder .build_float_sub( lhs_value.into_float_value(), @@ -370,10 +344,11 @@ fn compile_bin_op<'ctx>( ctx.builder .build_int_sub(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? .as_basic_value_enum() - } + }; + (value, lhs_ty) } ir::BinOp::Mul => { - if is_float { + let value = if is_float { ctx.builder .build_float_mul( lhs_value.into_float_value(), @@ -385,10 +360,11 @@ fn compile_bin_op<'ctx>( ctx.builder .build_int_add(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? .as_basic_value_enum() - } + }; + (value, lhs_ty) } ir::BinOp::Div => { - if is_float { + let value = if is_float { ctx.builder .build_float_div( lhs_value.into_float_value(), @@ -412,10 +388,11 @@ fn compile_bin_op<'ctx>( "", )? .as_basic_value_enum() - } + }; + (value, lhs_ty) } ir::BinOp::Rem => { - if is_float { + let value = if is_float { ctx.builder .build_float_rem( lhs_value.into_float_value(), @@ -439,35 +416,46 @@ fn compile_bin_op<'ctx>( "", )? .as_basic_value_enum() - } + }; + (value, lhs_ty) } - ir::BinOp::BitXor => ctx - .builder - .build_xor(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? - .as_basic_value_enum(), - ir::BinOp::BitAnd => ctx - .builder - .build_and(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? 
- .as_basic_value_enum(), - ir::BinOp::BitOr => ctx - .builder - .build_or(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? - .as_basic_value_enum(), - ir::BinOp::Shl => ctx - .builder - .build_left_shift(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? - .as_basic_value_enum(), - ir::BinOp::Shr => ctx - .builder - .build_right_shift( - lhs_value.into_int_value(), - rhs_value.into_int_value(), - is_signed, - "", - )? - .as_basic_value_enum(), + ir::BinOp::BitXor => ( + ctx.builder + .build_xor(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? + .as_basic_value_enum(), + lhs_ty, + ), + ir::BinOp::BitAnd => ( + ctx.builder + .build_and(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? + .as_basic_value_enum(), + lhs_ty, + ), + ir::BinOp::BitOr => ( + ctx.builder + .build_or(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? + .as_basic_value_enum(), + lhs_ty, + ), + ir::BinOp::Shl => ( + ctx.builder + .build_left_shift(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? + .as_basic_value_enum(), + lhs_ty, + ), + ir::BinOp::Shr => ( + ctx.builder + .build_right_shift( + lhs_value.into_int_value(), + rhs_value.into_int_value(), + is_signed, + "", + )? + .as_basic_value_enum(), + lhs_ty, + ), ir::BinOp::Eq => { - if is_float { + let value = if is_float { ctx.builder .build_float_compare( inkwell::FloatPredicate::OEQ, @@ -485,10 +473,17 @@ fn compile_bin_op<'ctx>( "", )? .as_basic_value_enum() - } + }; + ( + value, + TypeInfo { + span: None, + kind: ir::TypeKind::Bool, + }, + ) } ir::BinOp::Lt => { - if is_float { + let value = if is_float { ctx.builder .build_float_compare( inkwell::FloatPredicate::OLT, @@ -510,10 +505,17 @@ fn compile_bin_op<'ctx>( "", )? 
.as_basic_value_enum() - } + }; + ( + value, + TypeInfo { + span: None, + kind: ir::TypeKind::Bool, + }, + ) } ir::BinOp::Le => { - if is_float { + let value = if is_float { ctx.builder .build_float_compare( inkwell::FloatPredicate::OLE, @@ -535,10 +537,17 @@ fn compile_bin_op<'ctx>( "", )? .as_basic_value_enum() - } + }; + ( + value, + TypeInfo { + span: None, + kind: ir::TypeKind::Bool, + }, + ) } ir::BinOp::Ne => { - if is_float { + let value = if is_float { ctx.builder .build_float_compare( inkwell::FloatPredicate::ONE, @@ -556,10 +565,17 @@ fn compile_bin_op<'ctx>( "", )? .as_basic_value_enum() - } + }; + ( + value, + TypeInfo { + span: None, + kind: ir::TypeKind::Bool, + }, + ) } ir::BinOp::Ge => { - if is_float { + let value = if is_float { ctx.builder .build_float_compare( inkwell::FloatPredicate::OGE, @@ -581,10 +597,17 @@ fn compile_bin_op<'ctx>( "", )? .as_basic_value_enum() - } + }; + ( + value, + TypeInfo { + span: None, + kind: ir::TypeKind::Bool, + }, + ) } ir::BinOp::Gt => { - if is_float { + let value = if is_float { ctx.builder .build_float_compare( inkwell::FloatPredicate::OGT, @@ -606,19 +629,42 @@ fn compile_bin_op<'ctx>( "", )? 
.as_basic_value_enum() - } + }; + ( + value, + TypeInfo { + span: None, + kind: ir::TypeKind::Bool, + }, + ) } ir::BinOp::Offset => todo!(), }) } +fn compile_rvalue<'ctx>( + ctx: &ModuleCompileCtx<'ctx, '_>, + fn_id: DefId, + locals: &HashMap>, + rvalue: &ir::RValue, +) -> Result<(BasicValueEnum<'ctx>, TypeInfo), BuilderError> { + Ok(match rvalue { + ir::RValue::Use(op) => compile_load_operand(ctx, fn_id, locals, op)?, + ir::RValue::Ref(_, _) => todo!(), + ir::RValue::BinOp(op, lhs, rhs) => compile_bin_op(ctx, fn_id, locals, *op, lhs, rhs)?, + ir::RValue::LogicOp(_, _, _) => todo!(), + ir::RValue::UnOp(_, _) => todo!(), + }) +} + fn compile_load_operand<'ctx>( ctx: &ModuleCompileCtx<'ctx, '_>, - body: &ir::Body, + fn_id: DefId, locals: &HashMap>, op: &ir::Operand, ) -> Result<(BasicValueEnum<'ctx>, TypeInfo), BuilderError> { // todo: implement projection + let body = ctx.ctx.program.functions.get(&fn_id).unwrap(); Ok(match op { ir::Operand::Copy(place) => { let pointee_ty = compile_basic_type(ctx, &body.locals[place.local].ty); @@ -715,23 +761,13 @@ fn compile_type<'a>( match &ty.kind { ir::TypeKind::Unit => context.void_type().as_any_type_enum(), ir::TypeKind::FnDef(def_id, _generic_args) => { - let (args, ret_type) = { - let fn_body = ctx - .ctx - .modules - .get(&def_id.get_module_defid()) - .unwrap() - .functions - .get(def_id) - .unwrap(); - (fn_body.get_args(), fn_body.ret_type.clone()) - }; + let (args, ret_type) = { ctx.ctx.program.function_signatures.get(def_id).unwrap() }; let args: Vec = args .iter() - .map(|x| compile_basic_type(ctx, &x.ty).into()) + .map(|x| compile_basic_type(ctx, x).into()) .collect(); - let ret_type = compile_basic_type(ctx, &ret_type); + let ret_type = compile_basic_type(ctx, ret_type); ret_type.fn_type(&args, false).as_any_type_enum() } diff --git a/lib/edlang_codegen_mlir/src/lib.rs b/lib/edlang_codegen_mlir/src/lib.rs index 244fba0ee..29e83809c 100644 --- a/lib/edlang_codegen_mlir/src/lib.rs +++ 
b/lib/edlang_codegen_mlir/src/lib.rs @@ -1,18 +1,10 @@ #![allow(clippy::too_many_arguments)] use edlang_ir as ir; -use std::{ - collections::HashMap, - ffi::{CStr, CString}, - mem::MaybeUninit, - path::PathBuf, - ptr::{addr_of_mut, null_mut}, - sync::OnceLock, -}; +use std::path::PathBuf; -use edlang_session::{OptLevel, Session}; -use inkwell::context::Context; -use ir::DefId; +use edlang_session::Session; +use ir::ProgramBody; /* use llvm_sys::{ core::{LLVMContextCreate, LLVMContextDispose, LLVMDisposeMessage, LLVMDisposeModule}, @@ -34,10 +26,9 @@ pub mod linker; pub fn compile( session: &Session, - modules: &HashMap, - symbols: HashMap, + program: &ProgramBody, ) -> Result> { - codegen::compile(session, modules, &symbols) + codegen::compile(session, program) } // Converts a module to an object. diff --git a/lib/edlang_driver/src/lib.rs b/lib/edlang_driver/src/lib.rs index 999175890..b5e229b63 100644 --- a/lib/edlang_driver/src/lib.rs +++ b/lib/edlang_driver/src/lib.rs @@ -88,14 +88,14 @@ pub fn main() -> Result<(), Box> { return Ok(()); } - let (symbols, module_irs) = lower_modules(&[module.clone()]); + let program_ir = lower_modules(&[module.clone()]); if args.ir { - println!("{:#?}", module_irs); + println!("{:#?}", program_ir); return Ok(()); } - let object_path = edlang_codegen_mlir::compile(&session, &module_irs, symbols)?; + let object_path = edlang_codegen_mlir::compile(&session, &program_ir)?; if session.library { link_shared_lib(&object_path, &session.output_file.with_extension("so"))?; diff --git a/lib/edlang_ir/src/lib.rs b/lib/edlang_ir/src/lib.rs index 31d86acb1..c8e9923db 100644 --- a/lib/edlang_ir/src/lib.rs +++ b/lib/edlang_ir/src/lib.rs @@ -1,42 +1,69 @@ // Based on a cfg -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashSet}; use edlang_span::Span; use smallvec::SmallVec; pub mod scalar_int; +#[derive(Debug, Clone, Default)] +pub struct SymbolTable { + pub symbols: BTreeMap, + pub modules: BTreeMap, + pub functions: 
BTreeMap, + pub constants: BTreeMap, + pub structs: BTreeMap, + pub types: BTreeMap, +} + +#[derive(Debug, Clone, Default)] +pub struct ProgramBody { + pub top_level_module_names: BTreeMap, + /// The top level modules. + pub top_level_modules: Vec, + /// All the modules in a flat map. + pub modules: BTreeMap, + /// This stores all the functions from all modules + pub functions: BTreeMap, + /// The function signatures. + pub function_signatures: BTreeMap, TypeInfo)>, +} + #[derive(Debug, Clone)] pub struct ModuleBody { pub module_id: DefId, - pub functions: BTreeMap, - pub modules: BTreeMap, + pub parent_ids: Vec, + pub name: String, + pub symbols: SymbolTable, + /// Functions defined in this module. + pub functions: HashSet, + /// Structs defined in this module. + pub structs: HashSet, + /// Types defined in this module. + pub types: HashSet, + /// Constants defined in this module. + pub constants: HashSet, + /// Submodules defined in this module. + pub modules: HashSet, + /// Imported items. symbol -> id + pub imports: BTreeMap, pub span: Span, } /// Definition id. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub struct DefId { - pub module_id: usize, + pub program_id: usize, pub id: usize, } -impl DefId { - pub fn get_module_defid(&self) -> Self { - Self { - module_id: self.module_id, - id: 0, - } - } -} - #[derive(Debug, Clone)] pub struct Body { pub def_id: DefId, pub is_pub: bool, pub is_extern: bool, - pub ret_type: TypeInfo, + pub name: String, pub locals: SmallVec<[Local; 4]>, pub blocks: SmallVec<[BasicBlock; 8]>, pub fn_span: Span, @@ -68,7 +95,6 @@ pub struct DebugInfo { #[derive(Debug, Clone)] pub struct BasicBlock { - pub id: usize, pub statements: SmallVec<[Statement; 8]>, pub terminator: Terminator, } @@ -77,10 +103,39 @@ pub struct BasicBlock { pub struct Local { pub mutable: bool, pub span: Option, + pub debug_name: Option, pub ty: TypeInfo, pub kind: LocalKind, } +impl Local { + pub fn new( + span: Option, + kind: LocalKind, + ty: TypeInfo, + debug_name: Option, + mutable: bool, + ) -> Self { + Self { + span, + kind, + ty, + debug_name, + mutable, + } + } + + pub const fn temp(ty: TypeInfo) -> Self { + Self { + span: None, + ty, + kind: LocalKind::Temp, + debug_name: None, + mutable: false, + } + } +} + #[derive(Debug, Clone, Copy)] pub enum LocalKind { Temp, @@ -107,10 +162,14 @@ pub enum Terminator { Return, Switch, Call { - func: Operand, - args: Vec, - dest: Option, - target: Option, // block + /// The function to call. + func: DefId, + /// The arguments. + args: Vec, + /// The place in memory to store the return value of the function call. + destination: Place, + /// What basic block to jump to after the function call, if the function is non-diverging (i.e it returns control back). 
+ target: Option, }, Unreachable, } diff --git a/lib/edlang_lowering/Cargo.toml b/lib/edlang_lowering/Cargo.toml index c03eecb8f..c5799f90a 100644 --- a/lib/edlang_lowering/Cargo.toml +++ b/lib/edlang_lowering/Cargo.toml @@ -8,3 +8,4 @@ edition = "2021" [dependencies] edlang_ast = { version = "0.1.0", path = "../edlang_ast" } edlang_ir = { version = "0.1.0", path = "../edlang_ir" } +tracing.workspace = true diff --git a/lib/edlang_lowering/src/common.rs b/lib/edlang_lowering/src/common.rs index 4b70caf83..11d1df801 100644 --- a/lib/edlang_lowering/src/common.rs +++ b/lib/edlang_lowering/src/common.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use edlang_ir::{Body, DefId, Local, Statement, TypeInfo}; +use edlang_ir::{Body, DefId, Local, ModuleBody, ProgramBody, Statement, TypeInfo, TypeKind}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)] pub struct IdGenerator { @@ -23,7 +23,7 @@ impl IdGenerator { pub fn module_defid(&self) -> DefId { DefId { - module_id: self.module_id, + program_id: self.module_id, id: 0, } } @@ -32,7 +32,7 @@ impl IdGenerator { let id = self.next_id(); DefId { - module_id: self.module_id, + program_id: self.module_id, id, } } @@ -47,18 +47,9 @@ impl IdGenerator { #[derive(Debug, Clone, Default)] pub struct BuildCtx { - pub module_name_to_id: HashMap, - pub modules: HashMap, - pub functions: HashMap, - pub gen: IdGenerator, - pub symbol_names: HashMap, -} - -#[derive(Debug, Clone, Default)] -pub struct ModuleCtx { - pub id: DefId, - pub func_name_to_id: HashMap, - pub functions: HashMap, TypeInfo)>, + pub body: ProgramBody, + pub unresolved_function_signatures: + HashMap, Option)>, pub gen: IdGenerator, } @@ -67,8 +58,8 @@ pub struct BodyBuilder { pub local_module: DefId, pub body: Body, pub statements: Vec, - pub locals: HashMap, - pub ret_local: Option, + pub name_to_local: HashMap, + pub ret_local: usize, pub ctx: BuildCtx, } @@ -78,27 +69,21 @@ impl BodyBuilder { self.body.locals.push(local); id } + + pub fn 
add_temp_local(&mut self, ty_kind: TypeKind) -> usize { + let id = self.body.locals.len(); + self.body.locals.push(Local::temp(TypeInfo { + span: None, + kind: ty_kind, + })); + id + } + pub fn get_local(&self, name: &str) -> Option<&Local> { - self.body.locals.get(*(self.locals.get(name)?)) + self.body.locals.get(*(self.name_to_local.get(name)?)) } - pub fn get_current_module(&self) -> &ModuleCtx { - self.ctx - .modules - .get(&self.local_module) - .expect("current module should exist") - } - - pub fn get_current_module_mut(&mut self) -> &mut ModuleCtx { - self.ctx - .modules - .get_mut(&self.local_module) - .expect("current module should exist") - } - - pub fn get_fn_by_name(&self, name: &str) -> Option<&(Vec, TypeInfo)> { - let id = self.get_current_module().func_name_to_id.get(name)?; - let f = self.get_current_module().functions.get(&id)?; - Some(f) + pub fn get_module_body(&self) -> &ModuleBody { + self.ctx.body.modules.get(&self.local_module).unwrap() } } diff --git a/lib/edlang_lowering/src/lib.rs b/lib/edlang_lowering/src/lib.rs index 7d9f73187..0f9c27aa9 100644 --- a/lib/edlang_lowering/src/lib.rs +++ b/lib/edlang_lowering/src/lib.rs @@ -1,222 +1,196 @@ use std::collections::HashMap; -use common::{BodyBuilder, BuildCtx, IdGenerator, ModuleCtx}; +use ast::ModuleStatement; +use common::{BodyBuilder, BuildCtx, IdGenerator}; use edlang_ast as ast; use edlang_ir as ir; -use ir::{ConstData, ConstKind, DefId, Local, Operand, Place, Statement, Terminator, TypeInfo}; +use ir::{ + BasicBlock, Body, ConstData, ConstKind, DefId, Local, LocalKind, Operand, Place, ProgramBody, + RValue, Statement, StatementKind, Terminator, TypeInfo, TypeKind, +}; mod common; +mod prepass; -pub fn lower_modules( - modules: &[ast::Module], -) -> (HashMap, HashMap) { +pub fn lower_modules(modules: &[ast::Module]) -> ProgramBody { let mut ctx = BuildCtx::default(); - for m in modules { - let module_id = ctx.gen.module_defid(); - ctx.module_name_to_id - .insert(m.name.name.clone(), 
ctx.gen.module_defid()); - ctx.symbol_names.insert(module_id, m.name.name.clone()); - - ctx.modules.insert( - module_id, - ModuleCtx { - id: module_id, - gen: IdGenerator::new(module_id.module_id), - ..Default::default() - }, - ); - - ctx.gen.next_module_defid(); - } - - let mut lowered_modules = HashMap::with_capacity(modules.len()); - - // todo: maybe should do a prepass here populating all symbols - + // resolve symbols for module in modules { - let ir; - (ctx, ir) = lower_module(ctx, module); - lowered_modules.insert(ir.module_id, ir); + ctx = prepass::prepass_module(ctx, module); } - (ctx.symbol_names, lowered_modules) + // resolve imports + for module in modules { + ctx = prepass::prepass_imports(ctx, module); + } + + for mod_def in modules { + let id = *ctx + .body + .top_level_module_names + .get(&mod_def.name.name) + .expect("module should exist"); + + ctx = lower_module(ctx, mod_def, id); + } + + ctx.body } -fn lower_module(mut ctx: BuildCtx, module: &ast::Module) -> (BuildCtx, ir::ModuleBody) { - let module_id = *ctx.module_name_to_id.get(&module.name.name).unwrap(); - let mut body = ir::ModuleBody { - module_id, - functions: Default::default(), - modules: Default::default(), - span: module.span, - }; +fn lower_module(mut ctx: BuildCtx, module: &ast::Module, id: DefId) -> BuildCtx { + let body = ctx.body.modules.get(&id).unwrap(); - for stmt in &module.contents { - match stmt { - ast::ModuleStatement::Function(func) => { - let next_id = { - let module_ctx = ctx.modules.get_mut(&module_id).unwrap(); - let next_id = module_ctx.gen.next_defid(); - module_ctx - .func_name_to_id - .insert(func.name.name.clone(), next_id); - ctx.symbol_names.insert(next_id, func.name.name.clone()); - next_id + // fill fn sigs + for content in &module.contents { + if let ModuleStatement::Function(fn_def) = content { + let fn_id = *body.symbols.functions.get(&fn_def.name.name).unwrap(); + + let mut args = Vec::new(); + let ret_type; + + for arg in &fn_def.params { + let ty = 
lower_type(&ctx, &arg.arg_type); + args.push(ty); + } + + if let Some(ty) = &fn_def.return_type { + ret_type = lower_type(&ctx, ty); + } else { + ret_type = TypeInfo { + span: None, + kind: ir::TypeKind::Unit, }; - - let mut args = Vec::new(); - let ret_type; - - if let Some(ret) = func.return_type.as_ref() { - ret_type = lower_type(&mut ctx, ret); - } else { - ret_type = TypeInfo { - span: None, - kind: ir::TypeKind::Unit, - }; - } - - for arg in &func.params { - let ty = lower_type(&mut ctx, &arg.arg_type); - args.push(ty); - } - - let module_ctx = ctx.modules.get_mut(&module_id).unwrap(); - module_ctx.functions.insert(next_id, (args, ret_type)); } - ast::ModuleStatement::Constant(_) => todo!(), - ast::ModuleStatement::Struct(_) => todo!(), - ast::ModuleStatement::Module(_) => {} + + ctx.body.function_signatures.insert(fn_id, (args, ret_type)); } } - for stmt in &module.contents { - match stmt { - ast::ModuleStatement::Function(func) => { - let (res, new_ctx) = lower_function(ctx, func, body.module_id); - body.functions.insert(res.def_id, res); - ctx = new_ctx; + for content in &module.contents { + match content { + ModuleStatement::Constant(_) => todo!(), + ModuleStatement::Function(fn_def) => { + ctx = lower_function(ctx, fn_def, id); } - ast::ModuleStatement::Constant(_) => todo!(), - ast::ModuleStatement::Struct(_) => todo!(), - ast::ModuleStatement::Module(_) => todo!(), + ModuleStatement::Struct(_) => todo!(), + // ModuleStatement::Type(_) => todo!(), + ModuleStatement::Module(_mod_def) => {} } } - (ctx, body) + ctx } -fn lower_function( - mut ctx: BuildCtx, - func: &ast::Function, - module_id: DefId, -) -> (ir::Body, BuildCtx) { - let def_id = *ctx - .modules - .get(&module_id) - .unwrap() - .func_name_to_id - .get(&func.name.name) - .unwrap(); - - let body = ir::Body { - def_id, - ret_type: func - .return_type - .as_ref() - .map(|x| lower_type(&mut ctx, x)) - .unwrap_or_else(|| TypeInfo { - span: None, - kind: ir::TypeKind::Unit, - }), - locals: 
Default::default(), - blocks: Default::default(), - fn_span: func.span, - is_pub: func.is_public, - is_extern: func.is_extern, - }; - +fn lower_function(ctx: BuildCtx, func: &ast::Function, module_id: DefId) -> BuildCtx { let mut builder = BodyBuilder { - body, - statements: Vec::new(), - locals: HashMap::new(), - ret_local: None, - ctx, + body: Body { + blocks: Default::default(), + locals: Default::default(), + name: func.name.name.clone(), + def_id: { + let body = ctx.body.modules.get(&module_id).unwrap(); + *body.symbols.functions.get(&func.name.name).unwrap() + }, + is_pub: func.is_public, + is_extern: func.is_extern, + fn_span: func.span, + }, local_module: module_id, + ret_local: 0, + name_to_local: HashMap::new(), + statements: Vec::new(), + ctx, }; + let fn_id = builder.body.def_id; + + let (args_ty, ret_ty) = builder + .ctx + .body + .function_signatures + .get(&fn_id) + .unwrap() + .clone(); + // store args ret - if let Some(ret_type) = func.return_type.as_ref() { - let ty = lower_type(&mut builder.ctx, ret_type); + builder.ret_local = builder.body.locals.len(); + builder.body.locals.push(Local::new( + None, + LocalKind::ReturnPointer, + ret_ty.clone(), + None, + false, + )); - let local = Local { - mutable: false, - span: None, - ty, - kind: ir::LocalKind::ReturnPointer, - }; - - builder.ret_local = Some(builder.body.locals.len()); - builder.body.locals.push(local); - } - - for arg in &func.params { - let ty = lower_type(&mut builder.ctx, &arg.arg_type); - let local = Local { - mutable: false, - span: Some(arg.span), - ty, - kind: ir::LocalKind::Arg, - }; + for (arg, ty) in func.params.iter().zip(args_ty) { builder - .locals - .insert(arg.name.name.clone(), builder.locals.len()); - builder.body.locals.push(local); + .name_to_local + .insert(arg.name.name.clone(), builder.body.locals.len()); + builder.body.locals.push(Local::new( + Some(arg.name.span), + LocalKind::Arg, + ty, + Some(arg.name.name.clone()), + false, + )); } + // Get all user defined 
locals for stmt in &func.body.body { - match stmt { - ast::Statement::Let(info) => lower_let(&mut builder, info), - ast::Statement::Assign(info) => lower_assign(&mut builder, info), - ast::Statement::For(_) => todo!(), - ast::Statement::While(_) => todo!(), - ast::Statement::If(_) => todo!(), - ast::Statement::Return(info) => lower_return(&mut builder, info), - ast::Statement::FnCall(info) => { - lower_fn_call_no_ret(&mut builder, info); - } + if let ast::Statement::Let(info) = stmt { + let ty = lower_type(&builder.ctx, &info.r#type); + builder + .name_to_local + .insert(info.name.name.clone(), builder.body.locals.len()); + builder.body.locals.push(Local::new( + Some(info.name.span), + LocalKind::Temp, + ty, + Some(info.name.name.clone()), + info.is_mut, + )); } } - (builder.body, builder.ctx) + for stmt in &func.body.body { + lower_statement(&mut builder, stmt, &ret_ty); + } + + let (mut ctx, body) = (builder.ctx, builder.body); + ctx.unresolved_function_signatures.remove(&body.def_id); + ctx.body.functions.insert(body.def_id, body); + ctx +} + +fn lower_statement(builder: &mut BodyBuilder, info: &ast::Statement, ret_type: &TypeInfo) { + match info { + ast::Statement::Let(info) => lower_let(builder, info), + ast::Statement::Assign(info) => lower_assign(builder, info), + ast::Statement::For(_) => todo!(), + ast::Statement::While(_) => todo!(), + ast::Statement::If(_) => todo!(), + ast::Statement::Return(info) => lower_return(builder, info, ret_type), + ast::Statement::FnCall(info) => { + lower_fn_call(builder, info); + } + } } fn lower_let(builder: &mut BodyBuilder, info: &ast::LetStmt) { - let ty = lower_type(&mut builder.ctx, &info.r#type); + let ty = lower_type(&builder.ctx, &info.r#type); let rvalue = lower_expr(builder, &info.value, Some(&ty)); - - let local = ir::Local { - mutable: info.is_mut, - span: Some(info.span), - ty: lower_type(&mut builder.ctx, &info.r#type), - kind: ir::LocalKind::Temp, - }; - - let id = builder.body.locals.len(); - 
builder.locals.insert(info.name.name.clone(), id); - builder.body.locals.push(local); - - builder.statements.push(ir::Statement { - span: Some(info.span), - kind: ir::StatementKind::StorageLive(id), + let local_idx = builder.name_to_local.get(&info.name.name).copied().unwrap(); + builder.statements.push(Statement { + span: Some(info.name.span), + kind: StatementKind::StorageLive(local_idx), }); - builder.statements.push(ir::Statement { - span: Some(info.span), - kind: ir::StatementKind::Assign( + builder.statements.push(Statement { + span: Some(info.name.span), + kind: StatementKind::Assign( Place { - local: id, + local: local_idx, projection: Default::default(), }, rvalue, @@ -225,14 +199,14 @@ fn lower_let(builder: &mut BodyBuilder, info: &ast::LetStmt) { } fn lower_assign(builder: &mut BodyBuilder, info: &ast::AssignStmt) { - let local = *builder.locals.get(&info.name.first.name).unwrap(); + let local = *builder.name_to_local.get(&info.name.first.name).unwrap(); let ty = builder.body.locals[local].ty.clone(); let rvalue = lower_expr(builder, &info.value, Some(&ty)); let place = lower_path(builder, &info.name); builder.statements.push(Statement { - span: Some(info.span), - kind: ir::StatementKind::Assign(place, rvalue), + span: Some(info.name.first.span), + kind: StatementKind::Assign(place, rvalue), }) } @@ -259,261 +233,140 @@ fn lower_binary_expr( type_hint: Option<&TypeInfo>, ) -> ir::RValue { let expr_type = type_hint.expect("type hint needed"); - let lhs = { - let rvalue = lower_expr(builder, lhs, type_hint); - let local = builder.add_local(Local { - mutable: false, - span: None, - ty: expr_type.clone(), - kind: ir::LocalKind::Temp, - }); + let lhs = lower_expr(builder, lhs, type_hint); + let rhs = lower_expr(builder, rhs, type_hint); - builder.statements.push(Statement { - span: None, - kind: ir::StatementKind::StorageLive(local), - }); - - let place = Place { - local, - projection: Default::default(), - }; - - builder.statements.push(Statement { - span: 
None, - kind: ir::StatementKind::Assign(place.clone(), rvalue), - }); - - place + let local_ty = expr_type; + let lhs_local = builder.add_local(Local::temp(local_ty.clone())); + let rhs_local = builder.add_local(Local::temp(local_ty.clone())); + let lhs_place = Place { + local: lhs_local, + projection: Default::default(), }; - let rhs = { - let rvalue = lower_expr(builder, rhs, type_hint); - let local = builder.add_local(Local { - mutable: false, - span: None, - ty: expr_type.clone(), - kind: ir::LocalKind::Temp, - }); - - builder.statements.push(Statement { - span: None, - kind: ir::StatementKind::StorageLive(local), - }); - - let place = Place { - local, - projection: Default::default(), - }; - - builder.statements.push(Statement { - span: None, - kind: ir::StatementKind::Assign(place.clone(), rvalue), - }); - - place - }; - - match op { - ast::BinaryOp::Arith(op, _) => match op { - ast::ArithOp::Add => { - ir::RValue::BinOp(ir::BinOp::Add, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::ArithOp::Sub => { - ir::RValue::BinOp(ir::BinOp::Sub, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::ArithOp::Mul => { - ir::RValue::BinOp(ir::BinOp::Mul, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::ArithOp::Div => { - ir::RValue::BinOp(ir::BinOp::Div, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::ArithOp::Mod => { - ir::RValue::BinOp(ir::BinOp::Rem, Operand::Move(lhs), Operand::Move(rhs)) - } - }, - ast::BinaryOp::Logic(op, _) => match op { - ast::LogicOp::And => { - ir::RValue::LogicOp(ir::LogicalOp::And, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::LogicOp::Or => { - ir::RValue::LogicOp(ir::LogicalOp::Or, Operand::Move(lhs), Operand::Move(rhs)) - } - }, - ast::BinaryOp::Compare(op, _) => match op { - ast::CmpOp::Eq => { - ir::RValue::BinOp(ir::BinOp::Eq, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::CmpOp::NotEq => { - ir::RValue::BinOp(ir::BinOp::Ne, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::CmpOp::Lt => { - 
ir::RValue::BinOp(ir::BinOp::Lt, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::CmpOp::LtEq => { - ir::RValue::BinOp(ir::BinOp::Le, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::CmpOp::Gt => { - ir::RValue::BinOp(ir::BinOp::Gt, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::CmpOp::GtEq => { - ir::RValue::BinOp(ir::BinOp::Ge, Operand::Move(lhs), Operand::Move(rhs)) - } - }, - ast::BinaryOp::Bitwise(op, _) => match op { - ast::BitwiseOp::And => { - ir::RValue::BinOp(ir::BinOp::BitAnd, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::BitwiseOp::Or => { - ir::RValue::BinOp(ir::BinOp::BitOr, Operand::Move(lhs), Operand::Move(rhs)) - } - ast::BitwiseOp::Xor => { - ir::RValue::BinOp(ir::BinOp::BitXor, Operand::Move(lhs), Operand::Move(rhs)) - } - }, - } -} - -fn lower_fn_call_no_ret(builder: &mut BodyBuilder, info: &ast::FnCallExpr) { - let (arg_types, _ret_type) = builder.get_fn_by_name(&info.name.name).unwrap().clone(); - - let mut args = Vec::new(); - - for (expr, ty) in info.params.iter().zip(arg_types) { - let rvalue = lower_expr(builder, expr, Some(&ty)); - - let local = builder.add_local(Local { - mutable: false, - span: None, - ty, - kind: ir::LocalKind::Temp, - }); - - let place = Place { - local, - projection: Default::default(), - }; - - builder.statements.push(Statement { - span: None, - kind: ir::StatementKind::StorageLive(local), - }); - - builder.statements.push(Statement { - span: None, - kind: ir::StatementKind::Assign(place.clone(), rvalue), - }); - - args.push(Operand::Move(place)) - } - - let fn_id = *builder - .get_current_module() - .func_name_to_id - .get(&info.name.name) - .unwrap(); - - let next_block = builder.body.blocks.len() + 1; - - let terminator = Terminator::Call { - func: Operand::Constant(ConstData { - span: Some(info.span), - type_info: TypeInfo { - span: None, - kind: ir::TypeKind::FnDef(fn_id, vec![]), - }, - kind: ConstKind::ZeroSized, - }), - args, - dest: None, - target: Some(next_block), - }; - - let statements = 
std::mem::take(&mut builder.statements); - - builder.body.blocks.push(ir::BasicBlock { - id: builder.body.blocks.len(), - statements: statements.into(), - terminator, - }); -} - -fn lower_fn_call(builder: &mut BodyBuilder, info: &ast::FnCallExpr) -> ir::Operand { - let (arg_types, ret_type) = builder.get_fn_by_name(&info.name.name).unwrap().clone(); - - let mut args = Vec::new(); - - let target_local = builder.add_local(Local { - mutable: false, - span: None, - ty: ret_type, - kind: ir::LocalKind::Temp, - }); - - let dest_place = Place { - local: target_local, + // The right operand must live in its own temporary (`rhs_local`), not alias the left one. + let rhs_place = Place { + local: rhs_local, projection: Default::default(), }; - for (expr, ty) in info.params.iter().zip(arg_types) { - let rvalue = lower_expr(builder, expr, Some(&ty)); - - let local = builder.add_local(Local { - mutable: false, - span: None, - ty, - kind: ir::LocalKind::Temp, - }); - - let place = Place { - local, - projection: Default::default(), - }; - - builder.statements.push(Statement { - span: None, - kind: ir::StatementKind::StorageLive(local), - }); - - builder.statements.push(Statement { - span: None, - kind: ir::StatementKind::Assign(place.clone(), rvalue), - }); - - args.push(Operand::Move(place)) - } + builder.statements.push(Statement { + span: None, + kind: StatementKind::StorageLive(lhs_local), + }); builder.statements.push(Statement { span: None, - kind: ir::StatementKind::StorageLive(target_local), + kind: StatementKind::Assign(lhs_place.clone(), lhs), }); - let fn_id = *builder - .get_current_module() - .func_name_to_id - .get(&info.name.name) - .unwrap(); + builder.statements.push(Statement { + span: None, + kind: StatementKind::StorageLive(rhs_local), + }); - let next_block = builder.body.blocks.len() + 1; + builder.statements.push(Statement { + span: None, + kind: StatementKind::Assign(rhs_place.clone(), rhs), + }); - let terminator = Terminator::Call { - func: Operand::Constant(ConstData { - span: Some(info.span), - type_info: TypeInfo { - span: None, - kind: 
ir::TypeKind::FnDef(fn_id, vec![]), - }, - kind: ConstKind::ZeroSized, - }), + let lhs = Operand::Move(lhs_place); + let rhs = Operand::Move(rhs_place); + + match op { + ast::BinaryOp::Arith(op, _) => match op { + ast::ArithOp::Add => ir::RValue::BinOp(ir::BinOp::Add, lhs, rhs), + ast::ArithOp::Sub => ir::RValue::BinOp(ir::BinOp::Sub, lhs, rhs), + ast::ArithOp::Mul => ir::RValue::BinOp(ir::BinOp::Mul, lhs, rhs), + ast::ArithOp::Div => ir::RValue::BinOp(ir::BinOp::Div, lhs, rhs), + ast::ArithOp::Mod => ir::RValue::BinOp(ir::BinOp::Rem, lhs, rhs), + }, + ast::BinaryOp::Logic(op, _) => match op { + ast::LogicOp::And => ir::RValue::LogicOp(ir::LogicalOp::And, lhs, rhs), + ast::LogicOp::Or => ir::RValue::LogicOp(ir::LogicalOp::Or, lhs, rhs), + }, + ast::BinaryOp::Compare(op, _) => match op { + ast::CmpOp::Eq => ir::RValue::BinOp(ir::BinOp::Eq, lhs, rhs), + ast::CmpOp::NotEq => ir::RValue::BinOp(ir::BinOp::Ne, lhs, rhs), + ast::CmpOp::Lt => ir::RValue::BinOp(ir::BinOp::Lt, lhs, rhs), + ast::CmpOp::LtEq => ir::RValue::BinOp(ir::BinOp::Le, lhs, rhs), + ast::CmpOp::Gt => ir::RValue::BinOp(ir::BinOp::Gt, lhs, rhs), + ast::CmpOp::GtEq => ir::RValue::BinOp(ir::BinOp::Ge, lhs, rhs), + }, + ast::BinaryOp::Bitwise(op, _) => match op { + ast::BitwiseOp::And => ir::RValue::BinOp(ir::BinOp::BitAnd, lhs, rhs), + ast::BitwiseOp::Or => ir::RValue::BinOp(ir::BinOp::BitOr, lhs, rhs), + ast::BitwiseOp::Xor => ir::RValue::BinOp(ir::BinOp::BitXor, lhs, rhs), + }, + } +} + +fn lower_fn_call(builder: &mut BodyBuilder, info: &ast::FnCallExpr) -> Operand { + let fn_id = { + let mod_body = builder.get_module_body(); + + if let Some(id) = mod_body.symbols.functions.get(&info.name.name) { + *id + } else { + *mod_body + .imports + .get(&info.name.name) + .expect("function call not found") + } + }; + let (args_ty, ret_ty) = { + if let Some(x) = builder.ctx.body.function_signatures.get(&fn_id).cloned() { + x + } else { + let (args, ret) = builder + .ctx + .unresolved_function_signatures + 
.get(&fn_id) + .unwrap(); + + let args: Vec<_> = args.iter().map(|x| lower_type(&builder.ctx, x)).collect(); + let ret = ret + .as_ref() + .map(|x| lower_type(&builder.ctx, x)) + .unwrap_or(TypeInfo { + span: None, + kind: TypeKind::Unit, + }); + builder + .ctx + .body + .function_signatures + .insert(fn_id, (args.clone(), ret.clone())); + (args, ret) + } + }; + + let mut args = Vec::new(); + + for (arg, arg_ty) in info.params.iter().zip(args_ty) { + let rvalue = lower_expr(builder, arg, Some(&arg_ty)); + args.push(rvalue); + } + + let dest_local = builder.add_local(Local::temp(ret_ty)); + + let dest_place = Place { + local: dest_local, + projection: Default::default(), + }; + + let target_block = builder.body.blocks.len() + 1; + + // todo: check if function is diverging such as exit(). + let kind = Terminator::Call { + func: fn_id, args, - dest: Some(dest_place.clone()), - target: Some(next_block), + destination: dest_place.clone(), + target: Some(target_block), }; let statements = std::mem::take(&mut builder.statements); - - builder.body.blocks.push(ir::BasicBlock { - id: builder.body.blocks.len(), + builder.body.blocks.push(BasicBlock { statements: statements.into(), - terminator, + terminator: kind, }); Operand::Move(dest_place) @@ -542,51 +395,61 @@ fn lower_value( kind: ir::ConstKind::Value(ir::ValueTree::Leaf(ir::ConstValue::U32((*value) as u32))), }), ast::ValueExpr::Int { value, span } => { - let (ty, val) = match type_hint { + let (ty, val, type_span) = match type_hint { Some(type_hint) => match &type_hint.kind { - ir::TypeKind::Int(type_hint) => match type_hint { + ir::TypeKind::Int(int_type) => match int_type { ir::IntTy::I128 => ( ir::TypeKind::Int(ir::IntTy::I128), ir::ConstValue::I128((*value) as i128), + type_hint.span, ), ir::IntTy::I64 => ( ir::TypeKind::Int(ir::IntTy::I64), ir::ConstValue::I64((*value) as i64), + type_hint.span, ), ir::IntTy::I32 => ( ir::TypeKind::Int(ir::IntTy::I32), ir::ConstValue::I32((*value) as i32), + type_hint.span, ), 
ir::IntTy::I16 => ( ir::TypeKind::Int(ir::IntTy::I16), ir::ConstValue::I16((*value) as i16), + type_hint.span, ), ir::IntTy::I8 => ( ir::TypeKind::Int(ir::IntTy::I8), ir::ConstValue::I8((*value) as i8), + type_hint.span, ), ir::IntTy::Isize => todo!(), }, - ir::TypeKind::Uint(type_hint) => match type_hint { + ir::TypeKind::Uint(int_type) => match int_type { ir::UintTy::U128 => ( ir::TypeKind::Uint(ir::UintTy::U128), ir::ConstValue::U128(*value), + type_hint.span, ), ir::UintTy::U64 => ( ir::TypeKind::Uint(ir::UintTy::U64), ir::ConstValue::U64((*value) as u64), + type_hint.span, ), ir::UintTy::U32 => ( ir::TypeKind::Uint(ir::UintTy::U32), ir::ConstValue::U32((*value) as u32), + type_hint.span, ), ir::UintTy::U16 => ( ir::TypeKind::Uint(ir::UintTy::U16), ir::ConstValue::U16((*value) as u16), + type_hint.span, ), ir::UintTy::U8 => ( ir::TypeKind::Uint(ir::UintTy::U8), ir::ConstValue::U8((*value) as u8), + type_hint.span, ), _ => todo!(), }, @@ -598,14 +461,41 @@ fn lower_value( ir::Operand::Constant(ir::ConstData { span: Some(*span), type_info: ir::TypeInfo { - span: None, + span: type_span, kind: ty, }, kind: ir::ConstKind::Value(ir::ValueTree::Leaf(val)), }) } - ast::ValueExpr::Float { value, span } => todo!(), - ast::ValueExpr::Str { value, span } => todo!(), + ast::ValueExpr::Float { value, span } => match type_hint { + Some(type_hint) => match &type_hint.kind { + TypeKind::Float(float_ty) => match float_ty { + ir::FloatTy::F32 => ir::Operand::Constant(ir::ConstData { + span: Some(*span), + type_info: ir::TypeInfo { + span: type_hint.span, + kind: ir::TypeKind::Float(ir::FloatTy::F32), + }, + kind: ir::ConstKind::Value(ir::ValueTree::Leaf(ir::ConstValue::F32( + value.parse().unwrap(), + ))), + }), + ir::FloatTy::F64 => ir::Operand::Constant(ir::ConstData { + span: Some(*span), + type_info: ir::TypeInfo { + span: type_hint.span, + kind: ir::TypeKind::Float(ir::FloatTy::F64), + }, + kind: ir::ConstKind::Value(ir::ValueTree::Leaf(ir::ConstValue::F64( + 
value.parse().unwrap(), + ))), + }), + }, + _ => unreachable!(), + }, + None => todo!(), + }, + ast::ValueExpr::Str { value: _, span: _ } => todo!(), ast::ValueExpr::Path(info) => { // add deref info to path Operand::Move(lower_path(builder, info)) @@ -613,36 +503,31 @@ fn lower_value( } } -fn lower_return(builder: &mut BodyBuilder, info: &ast::ReturnStmt) { - let ret_type = builder.body.ret_type.clone(); - if let Some(value) = &info.value { - let rvalue = lower_expr(builder, value, Some(&ret_type)); - let ret_local = builder.ret_local.unwrap(); - +fn lower_return(builder: &mut BodyBuilder, info: &ast::ReturnStmt, return_type: &TypeInfo) { + if let Some(value_expr) = &info.value { + let value = lower_expr(builder, value_expr, Some(return_type)); builder.statements.push(Statement { - span: Some(info.span), - kind: ir::StatementKind::Assign( + span: None, + kind: StatementKind::Assign( Place { - local: ret_local, + local: builder.ret_local, projection: Default::default(), }, - rvalue, + value, ), - }) + }); } let statements = std::mem::take(&mut builder.statements); - - builder.body.blocks.push(ir::BasicBlock { - id: builder.body.blocks.len(), + builder.body.blocks.push(BasicBlock { statements: statements.into(), - terminator: ir::Terminator::Return, + terminator: Terminator::Return, }); } fn lower_path(builder: &mut BodyBuilder, info: &ast::PathExpr) -> ir::Place { let local = *builder - .locals + .name_to_local .get(&info.first.name) .expect("local not found"); @@ -652,7 +537,7 @@ fn lower_path(builder: &mut BodyBuilder, info: &ast::PathExpr) -> ir::Place { } } -pub fn lower_type(_ctx: &mut BuildCtx, t: &ast::Type) -> ir::TypeInfo { +pub fn lower_type(_ctx: &BuildCtx, t: &ast::Type) -> ir::TypeInfo { match t.name.name.as_str() { "()" => ir::TypeInfo { span: Some(t.span), diff --git a/lib/edlang_lowering/src/prepass.rs b/lib/edlang_lowering/src/prepass.rs new file mode 100644 index 000000000..be65f6ecc --- /dev/null +++ b/lib/edlang_lowering/src/prepass.rs @@ -0,0 
+1,316 @@ +use std::collections::HashMap; + +use crate::DefId; + +use super::common::BuildCtx; +use edlang_ast as ast; +use edlang_ir::ModuleBody; + +pub fn prepass_module(mut ctx: BuildCtx, mod_def: &ast::Module) -> BuildCtx { + let module_id = ctx.gen.next_defid(); + tracing::debug!("running ir prepass on module {:?}", module_id); + + ctx.body + .top_level_module_names + .insert(mod_def.name.name.clone(), module_id); + + ctx.body.modules.insert( + module_id, + ModuleBody { + module_id, + parent_ids: vec![], + name: mod_def.name.name.clone(), + symbols: Default::default(), + modules: Default::default(), + functions: Default::default(), + structs: Default::default(), + types: Default::default(), + constants: Default::default(), + imports: Default::default(), + span: mod_def.span, + }, + ); + + { + let mut gen = ctx.gen; + let current_module = ctx + .body + .modules + .get_mut(&module_id) + .expect("module should exist"); + + for ct in &mod_def.contents { + match ct { + ast::ModuleStatement::Constant(info) => { + let next_id = gen.next_defid(); + current_module + .symbols + .constants + .insert(info.name.name.clone(), next_id); + current_module.constants.insert(next_id); + } + ast::ModuleStatement::Function(info) => { + let next_id = gen.next_defid(); + current_module + .symbols + .functions + .insert(info.name.name.clone(), next_id); + current_module.functions.insert(next_id); + ctx.unresolved_function_signatures.insert( + next_id, + ( + info.params.iter().map(|x| &x.arg_type).cloned().collect(), + info.return_type.clone(), + ), + ); + } + ast::ModuleStatement::Struct(info) => { + let next_id = gen.next_defid(); + current_module + .symbols + .structs + .insert(info.name.name.clone(), next_id); + current_module.structs.insert(next_id); + } + /* + ast::ModuleStatement::Type(info) => { + let next_id = gen.next_defid(); + current_module + .symbols + .types + .insert(info.name.name.clone(), next_id); + current_module.types.insert(next_id); + } + */ + 
ast::ModuleStatement::Module(info) => { + let next_id = gen.next_defid(); + current_module + .symbols + .modules + .insert(info.name.name.clone(), next_id); + current_module.modules.insert(next_id); + } + } + } + + ctx.gen = gen; + } + + for ct in &mod_def.contents { + if let ast::ModuleStatement::Module(info) = ct { + let current_module = ctx + .body + .modules + .get_mut(&module_id) + .expect("module should exist"); + + let next_id = *current_module.symbols.modules.get(&info.name.name).unwrap(); + ctx = prepass_sub_module(ctx, &[module_id], next_id, info); + } + } + + ctx +} + +pub fn prepass_sub_module( + mut ctx: BuildCtx, + parent_ids: &[DefId], + module_id: DefId, + mod_def: &ast::Module, +) -> BuildCtx { + tracing::debug!("running ir prepass on submodule {:?}", module_id); + let mut submodule_parents_ids = parent_ids.to_vec(); + submodule_parents_ids.push(module_id); + + { + let mut gen = ctx.gen; + let mut submodule = ModuleBody { + module_id, + name: mod_def.name.name.clone(), + parent_ids: parent_ids.to_vec(), + imports: Default::default(), + symbols: Default::default(), + modules: Default::default(), + functions: Default::default(), + structs: Default::default(), + types: Default::default(), + constants: Default::default(), + span: mod_def.span, + }; + + for ct in &mod_def.contents { + match ct { + ast::ModuleStatement::Constant(info) => { + let next_id = gen.next_defid(); + submodule + .symbols + .constants + .insert(info.name.name.clone(), next_id); + submodule.constants.insert(next_id); + } + ast::ModuleStatement::Function(info) => { + let next_id = gen.next_defid(); + submodule + .symbols + .functions + .insert(info.name.name.clone(), next_id); + submodule.functions.insert(next_id); + ctx.unresolved_function_signatures.insert( + next_id, + ( + info.params.iter().map(|x| &x.arg_type).cloned().collect(), + info.return_type.clone(), + ), + ); + } + ast::ModuleStatement::Struct(info) => { + let next_id = gen.next_defid(); + submodule + .symbols + 
.structs + .insert(info.name.name.clone(), next_id); + submodule.structs.insert(next_id); + } + /* + ast::ModuleStatement::Type(info) => { + let next_id = gen.next_defid(); + submodule + .symbols + .types + .insert(info.name.name.clone(), next_id); + submodule.types.insert(next_id); + } + */ + ast::ModuleStatement::Module(info) => { + let next_id = gen.next_defid(); + submodule + .symbols + .modules + .insert(info.name.name.clone(), next_id); + submodule.modules.insert(next_id); + } + } + } + + ctx.gen = gen; + + ctx.body.modules.insert(module_id, submodule); + } + + for ct in &mod_def.contents { + if let ast::ModuleStatement::Module(info) = ct { + // Reuse the id the loop above stored in `symbols.modules` for this nested module; + // allocating a fresh one here would leave that entry pointing at no `ModuleBody`. + let next_id = *ctx + .body + .modules + .get(&module_id) + .expect("module should exist") + .symbols + .modules + .get(&info.name.name) + .unwrap(); + ctx = prepass_sub_module(ctx, &submodule_parents_ids, next_id, info); + } + } + + ctx +} + +pub fn prepass_imports(mut ctx: BuildCtx, mod_def: &ast::Module) -> BuildCtx { + let mod_id = *ctx + .body + .top_level_module_names + .get(&mod_def.name.name) + .unwrap(); + + for import in &mod_def.imports { + let imported_module_id = ctx + .body + .top_level_module_names + .get(&import.module[0].name) + .expect("import module not found"); + let mut imported_module = ctx.body.modules.get(imported_module_id).unwrap(); + + for x in import.module.iter().skip(1) { + let imported_module_id = imported_module.symbols.modules.get(&x.name).unwrap(); + imported_module = ctx.body.modules.get(imported_module_id).unwrap(); + } + + let mut imports = HashMap::new(); + + for sym in &import.symbols { + if let Some(id) = imported_module.symbols.functions.get(&sym.name) { + imports.insert(sym.name.clone(), *id); + } else if let Some(id) = imported_module.symbols.structs.get(&sym.name) { + imports.insert(sym.name.clone(), *id); + } else if let Some(id) = imported_module.symbols.types.get(&sym.name) { + imports.insert(sym.name.clone(), *id); + } else if let Some(id) = imported_module.symbols.constants.get(&sym.name) { + imports.insert(sym.name.clone(), *id); + } else { + panic!("import symbol not found") + } + } 
+ + ctx.body + .modules + .get_mut(&mod_id) + .unwrap() + .imports + .extend(imports); + } + + for c in &mod_def.contents { + if let ast::ModuleStatement::Module(info) = c { + ctx = prepass_imports_submodule(ctx, info, mod_id); + } + } + + ctx +} + +pub fn prepass_imports_submodule( + mut ctx: BuildCtx, + mod_def: &ast::Module, + parent_id: DefId, +) -> BuildCtx { + let mod_id = *ctx + .body + .modules + .get(&parent_id) + .unwrap() + .symbols + .modules + .get(&mod_def.name.name) + .unwrap(); + + for import in &mod_def.imports { + let imported_module_id = ctx + .body + .top_level_module_names + .get(&import.module[0].name) + .expect("import module not found"); + let mut imported_module = ctx.body.modules.get(imported_module_id).unwrap(); + + for x in import.module.iter().skip(1) { + let imported_module_id = imported_module.symbols.modules.get(&x.name).unwrap(); + imported_module = ctx.body.modules.get(imported_module_id).unwrap(); + } + + let mut imports = HashMap::new(); + + for sym in &import.symbols { + if let Some(id) = imported_module.symbols.functions.get(&sym.name) { + imports.insert(sym.name.clone(), *id); + } else if let Some(id) = imported_module.symbols.structs.get(&sym.name) { + imports.insert(sym.name.clone(), *id); + } else if let Some(id) = imported_module.symbols.types.get(&sym.name) { + imports.insert(sym.name.clone(), *id); + } else if let Some(id) = imported_module.symbols.constants.get(&sym.name) { + imports.insert(sym.name.clone(), *id); + } else { + panic!("import symbol not found") + } + } + + ctx.body + .modules + .get_mut(&mod_id) + .unwrap() + .imports + .extend(imports); + } + + ctx +} diff --git a/lib/edlang_parser/src/grammar.lalrpop b/lib/edlang_parser/src/grammar.lalrpop index 919fe39e3..dcc75feed 100644 --- a/lib/edlang_parser/src/grammar.lalrpop +++ b/lib/edlang_parser/src/grammar.lalrpop @@ -45,6 +45,7 @@ extern { "=" => Token::Assign, ";" => Token::Semicolon, ":" => Token::Colon, + "::" => Token::DoubleColon, "->" => 
Token::Arrow, "," => Token::Coma, "<" => Token::LessThanSign, @@ -104,6 +105,16 @@ SemiColon: Vec = { } }; +DoubleColon: Vec = { + "::")*> => match e { + None => v, + Some(e) => { + v.push(e); + v + } + } +}; + PlusSeparated: Vec = { "+")*> => match e { None => v, @@ -383,8 +394,8 @@ pub(crate) Struct: ast::Struct = { } pub(crate) Import: ast::Import = { - "use" > "}")?> ";" => ast::Import { - path, + "use" > > "}")?> ";" => ast::Import { + module, symbols: symbols.unwrap_or(vec![]), span: ast::Span::new(lo, hi), } diff --git a/lib/edlang_parser/src/tokens.rs b/lib/edlang_parser/src/tokens.rs index 781ed6bdb..aaff73247 100644 --- a/lib/edlang_parser/src/tokens.rs +++ b/lib/edlang_parser/src/tokens.rs @@ -86,6 +86,8 @@ pub enum Token { Semicolon, #[token(":")] Colon, + #[token("::")] + DoubleColon, #[token("->")] Arrow, #[token(",")]