From e5ea36a25cdfdb9f72c8ec802edda572e0837f35 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Sat, 3 Feb 2024 12:06:16 +0100 Subject: [PATCH] fn call --- lib/edlang_driver/src/lib.rs | 7 +- lib/edlang_ir/src/lib.rs | 5 +- lib/edlang_lowering/src/common.rs | 89 +++++++++++++ lib/edlang_lowering/src/lib.rs | 213 +++++++++++++++++++++++------- programs/simple.ed | 5 + 5 files changed, 262 insertions(+), 57 deletions(-) create mode 100644 lib/edlang_lowering/src/common.rs diff --git a/lib/edlang_driver/src/lib.rs b/lib/edlang_driver/src/lib.rs index 0df95afdc..271e2b2a5 100644 --- a/lib/edlang_driver/src/lib.rs +++ b/lib/edlang_driver/src/lib.rs @@ -3,7 +3,7 @@ use std::{error::Error, path::PathBuf, time::Instant}; use ariadne::Source; use clap::Parser; use edlang_codegen_mlir::linker::{link_binary, link_shared_lib}; -use edlang_lowering::{lower_module, IdGenerator}; +use edlang_lowering::lower_modules; use edlang_session::{DebugInfo, OptLevel, Session}; #[derive(Parser, Debug)] @@ -85,11 +85,10 @@ pub fn main() -> Result<(), Box> { return Ok(()); } - let mut gen = IdGenerator::new(0); - let ir = lower_module(&mut gen, &module); + let module_irs = lower_modules(&[module.clone()]); if args.ir { - println!("{:#?}", ir); + println!("{:#?}", module_irs); return Ok(()); } diff --git a/lib/edlang_ir/src/lib.rs b/lib/edlang_ir/src/lib.rs index f15fffbee..5109da80a 100644 --- a/lib/edlang_ir/src/lib.rs +++ b/lib/edlang_ir/src/lib.rs @@ -84,7 +84,6 @@ pub enum Terminator { args: Vec, dest: Place, target: Option, // block - fn_span: Span, }, Unreachable, } @@ -97,12 +96,13 @@ pub struct TypeInfo { #[derive(Debug, Clone)] pub enum TypeKind { + Unit, Bool, Char, Int(IntTy), Uint(UintTy), Float(FloatTy), - FuncDef { name: String, args: Vec }, + FnDef(DefId, Vec), // The vec are generic types, not arg types } #[derive(Debug, Clone)] @@ -141,6 +141,7 @@ pub struct ConstData { #[derive(Debug, Clone)] pub enum ConstKind { Value(ValueTree), + ZeroSized, } #[derive(Debug, Clone)] diff --git a/lib/edlang_lowering/src/common.rs b/lib/edlang_lowering/src/common.rs new file mode 100644 index 000000000..9f0f7fc28 --- /dev/null +++ b/lib/edlang_lowering/src/common.rs @@ -0,0 +1,89 @@ +use std::collections::HashMap; + +use edlang_ir::{Body, DefId, Local, Statement, TypeInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct IdGenerator { + pub current_id: usize, + pub module_id: usize, +} + +impl IdGenerator { + pub const fn new(module_id: usize) -> Self { + Self { + current_id: 0, + module_id: 0, + } + } + + pub fn next_id(&mut self) -> usize { + self.current_id += 1; + self.current_id + } + + pub fn next_defid(&mut self) -> DefId { + let id = self.next_id(); + + DefId { + module_id: self.module_id, + id, + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct BuildCtx { + pub module_name_to_id: HashMap, + pub modules: HashMap, + pub functions: HashMap, + pub module_id_counter: usize, +} + +#[derive(Debug, Clone, Default)] +pub struct ModuleCtx { + pub id: usize, + pub func_name_to_id: HashMap, + pub functions: HashMap, TypeInfo)>, + pub gen: IdGenerator, +} + +#[derive(Debug, Clone)] +pub struct BodyBuilder { + pub local_module: usize, + pub body: Body, + pub statements: Vec, + pub locals: HashMap, + pub ret_local: Option, + pub ctx: BuildCtx, +} + +impl BodyBuilder { + pub fn add_local(&mut self, local: Local) -> usize { + let id = self.body.locals.len(); + self.body.locals.push(local); + id + } + pub fn get_local(&self, name: &str) -> Option<&Local> { + self.body.locals.get(*(self.locals.get(name)?)) + } + + pub fn get_current_module(&self) -> &ModuleCtx { + self.ctx + .modules + .get(&self.local_module) + .expect("current module should exist") + } + + pub fn get_current_module_mut(&mut self) -> &mut ModuleCtx { + self.ctx + .modules + .get_mut(&self.local_module) + .expect("current module should exist") + } + + pub fn get_fn_by_name(&self, name: &str) -> Option<&(Vec, TypeInfo)> { + let id = self.get_current_module().func_name_to_id.get(name)?; + let f = self.get_current_module().functions.get(&id)?; + Some(f) + } +} diff --git a/lib/edlang_lowering/src/lib.rs b/lib/edlang_lowering/src/lib.rs index 8ed67e3b5..c3d440624 100644 --- a/lib/edlang_lowering/src/lib.rs +++ b/lib/edlang_lowering/src/lib.rs @@ -1,71 +1,90 @@ use std::collections::HashMap; -use ast::Function; +use common::{BodyBuilder, BuildCtx, IdGenerator, ModuleCtx}; use edlang_ast as ast; use edlang_ir as ir; -use ir::{DefId, Local, Operand, Place, Statement, TypeInfo}; +use ir::{ConstData, ConstKind, Local, Operand, Place, Statement, Terminator, TypeInfo}; -pub struct IdGenerator { - pub current_id: usize, - pub module_id: usize, -} +mod common; -impl IdGenerator { - pub const fn new(module_id: usize) -> Self { - Self { - current_id: 0, - module_id: 0, - } +pub fn lower_modules(modules: &[ast::Module]) -> Vec { + let mut ctx = BuildCtx::default(); + + for m in modules { + ctx.module_name_to_id + .insert(m.name.name.clone(), ctx.module_id_counter); + ctx.module_id_counter += 1; } - pub fn next_id(&mut self) -> usize { - self.current_id += 1; - self.current_id + let mut lowered_modules = Vec::with_capacity(modules.len()); + + // todo: maybe should do a prepass here populating all symbols + + for module in modules { + let ir; + (ctx, ir) = lower_module(ctx, module); + lowered_modules.push(ir); } - pub fn next_defid(&mut self) -> DefId { - let id = self.next_id(); - - DefId { - module_id: self.module_id, - id, - } - } + lowered_modules } -struct BuildCtx { - pub func_ids: HashMap, - pub functions: HashMap, - pub gen: IdGenerator, -} - -pub fn lower_module(gen: &mut IdGenerator, module: &ast::Module) -> ir::ModuleBody { +fn lower_module(mut ctx: BuildCtx, module: &ast::Module) -> (BuildCtx, ir::ModuleBody) { let mut body = ir::ModuleBody { - module_id: gen.next_id(), + module_id: ctx.module_id_counter, functions: Default::default(), modules: Default::default(), span: module.span, }; + ctx.module_id_counter += 1; - let mut ctx = BuildCtx { - func_ids: Default::default(), - functions: Default::default(), + let mut module_ctx = ModuleCtx { + id: body.module_id, gen: IdGenerator::new(body.module_id), + ..Default::default() }; for stmt in &module.contents { match stmt { ast::ModuleStatement::Function(func) => { - ctx.func_ids.insert(func.name.name.clone(), ctx.gen.next_defid()); + let next_id = module_ctx.gen.next_defid(); + module_ctx + .func_name_to_id + .insert(func.name.name.clone(), next_id); + + let mut args = Vec::new(); + let ret_type; + + if let Some(ret) = func.return_type.as_ref() { + ret_type = lower_type(&mut ctx, ret); + } else { + ret_type = TypeInfo { + span: None, + kind: ir::TypeKind::Unit, + }; + } + + for arg in &func.params { + let ty = lower_type(&mut ctx, &arg.arg_type); + args.push(ty); + } + + module_ctx.functions.insert(next_id, (args, ret_type)); } - _ => {} + ast::ModuleStatement::Constant(_) => todo!(), + ast::ModuleStatement::Struct(_) => todo!(), + ast::ModuleStatement::Module(_) => todo!(), } } + ctx.module_name_to_id + .insert(module.name.name.clone(), body.module_id); + ctx.modules.insert(body.module_id, module_ctx); + for stmt in &module.contents { match stmt { ast::ModuleStatement::Function(func) => { - let (res, new_ctx) = lower_function(ctx, func); + let (res, new_ctx) = lower_function(ctx, func, body.module_id); body.functions.insert(res.def_id, res); ctx = new_ctx; } @@ -75,20 +94,24 @@ pub fn lower_module(gen: &mut IdGenerator, module: &ast::Module) -> ir::ModuleBo } } - body + (ctx, body) } -struct BodyBuilder { - pub body: ir::Body, - pub statements: Vec, - pub locals: HashMap, - pub ret_local: Option, - pub ctx: BuildCtx, -} +fn lower_function( + mut ctx: BuildCtx, + func: &ast::Function, + module_id: usize, +) -> (ir::Body, BuildCtx) { + let def_id = *ctx + .modules + .get(&module_id) + .unwrap() + .func_name_to_id + .get(&func.name.name) + .unwrap(); -fn lower_function(mut ctx: BuildCtx, func: &ast::Function) -> (ir::Body, BuildCtx) { let body = ir::Body { - def_id: *ctx.func_ids.get(&func.name.name).unwrap(), + def_id, ret_type: func.return_type.as_ref().map(|x| lower_type(&mut ctx, x)), locals: Default::default(), blocks: Default::default(), @@ -102,7 +125,8 @@ fn lower_function(mut ctx: BuildCtx, func: &ast::Function) -> (ir::Body, BuildCt statements: Vec::new(), locals: HashMap::new(), ret_local: None, - ctx + ctx, + local_module: module_id, }; // store args ret @@ -189,7 +213,7 @@ fn lower_assign(builder: &mut BodyBuilder, info: &ast::AssignStmt) { builder.statements.push(Statement { span: Some(info.span), - kind: ir::StatementKind::Assign(place, rvalue) + kind: ir::StatementKind::Assign(place, rvalue), }) } @@ -200,12 +224,95 @@ fn lower_expr( ) -> ir::RValue { match info { ast::Expression::Value(info) => ir::RValue::Use(lower_value(builder, info, type_hint)), - ast::Expression::FnCall(_) => todo!(), + ast::Expression::FnCall(info) => ir::RValue::Use(lower_fn_call(builder, info)), ast::Expression::Unary(_, _) => todo!(), ast::Expression::Binary(_, _, _) => todo!(), } } +fn lower_fn_call(builder: &mut BodyBuilder, info: &ast::FnCallExpr) -> ir::Operand { + let (arg_types, ret_type) = builder.get_fn_by_name(&info.name.name).unwrap().clone(); + + let mut args = Vec::new(); + + let target_local = builder.add_local(Local { + mutable: false, + span: None, + ty: ret_type, + kind: ir::LocalKind::Temp, + }); + + let dest_place = Place { + local: target_local, + projection: Default::default(), + }; + + for (expr, ty) in info.params.iter().zip(arg_types) { + let rvalue = lower_expr(builder, expr, Some(&ty)); + + let local = builder.add_local(Local { + mutable: false, + span: None, + ty, + kind: ir::LocalKind::Temp, + }); + + let place = Place { + local, + projection: Default::default(), + }; + + builder.statements.push(Statement { + span: None, + kind: ir::StatementKind::StorageLive(local), + }); + + builder.statements.push(Statement { + span: None, + kind: ir::StatementKind::Assign(place.clone(), rvalue), + }); + + args.push(Operand::Move(place)) + } + + builder.statements.push(Statement { + span: None, + kind: ir::StatementKind::StorageLive(target_local), + }); + + let fn_id = *builder + .get_current_module() + .func_name_to_id + .get(&info.name.name) + .unwrap(); + + let next_block = builder.body.blocks.len() + 1; + + let terminator = Terminator::Call { + func: Operand::Constant(ConstData { + span: Some(info.span), + type_info: TypeInfo { + span: None, + kind: ir::TypeKind::FnDef(fn_id, vec![]), + }, + kind: ConstKind::ZeroSized, + }), + args, + dest: dest_place.clone(), + target: Some(next_block), + }; + + let statements = std::mem::take(&mut builder.statements); + + builder.body.blocks.push(ir::BasicBlock { + id: builder.body.blocks.len(), + statements: statements.into(), + terminator, + }); + + Operand::Move(dest_place) +} + fn lower_value( builder: &mut BodyBuilder, info: &ast::ValueExpr, @@ -339,8 +446,12 @@ fn lower_path(builder: &mut BodyBuilder, info: &ast::PathExpr) -> ir::Place { } } -pub fn lower_type(gen: &mut BuildCtx, t: &ast::Type) -> ir::TypeInfo { +pub fn lower_type(ctx: &mut BuildCtx, t: &ast::Type) -> ir::TypeInfo { match t.name.name.as_str() { + "()" => ir::TypeInfo { + span: Some(t.span), + kind: ir::TypeKind::Unit, + }, "u8" => ir::TypeInfo { span: Some(t.span), kind: ir::TypeKind::Uint(ir::UintTy::U8), diff --git a/programs/simple.ed b/programs/simple.ed index dbb5594c9..bbe5f6b47 100644 --- a/programs/simple.ed +++ b/programs/simple.ed @@ -3,6 +3,11 @@ mod Main { fn main(argc: i32) -> i32 { let mut x: i32 = 2; x = 4; + let y: i32 = other(2); return x; } + + fn other(a: i32) -> i32 { + return a; + } }