From 81b57d646d49134a8ae08c41087aa2b499ce4c21 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Sat, 20 May 2023 10:54:25 +0200 Subject: [PATCH] rly basic type inference --- simple.ed | 4 +- src/ast/mod.rs | 7 +- src/codegen.rs | 171 ++++++++++++++++++++++++++++++++------------ src/grammar.lalrpop | 8 ++- src/main.rs | 4 +- 5 files changed, 140 insertions(+), 54 deletions(-) diff --git a/simple.ed b/simple.ed index 9eadde432..6e4d63acb 100644 --- a/simple.ed +++ b/simple.ed @@ -1,5 +1,5 @@ -fn main(x: i32, z: i32) -> i32 { - let y = 0; +fn main(x: i64, z: i64) -> i64 { + let y: i64 = 0; if x == 5 { if x == z { y = 2 * x; diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 9d39049c1..a9e94d996 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -86,7 +86,12 @@ impl Function { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Statement { - Variable { + Let { + name: String, + value: Box, + type_name: Option, + }, + Mutate { name: String, value: Box, }, diff --git a/src/codegen.rs b/src/codegen.rs index d2d38417c..8f31da497 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -36,10 +36,14 @@ pub struct CodeGen<'ctx> { context: &'ctx Context, pub module: Module<'ctx>, builder: Builder<'ctx>, + fn_types: VariableTypes<'ctx>, _program: ProgramData, ast: ast::Program, } +type Variables<'ctx> = HashMap, usize)>; +type VariableTypes<'ctx> = HashMap>; + impl<'ctx> CodeGen<'ctx> { pub fn new( context: &'ctx Context, @@ -55,29 +59,38 @@ impl<'ctx> CodeGen<'ctx> { builder: context.create_builder(), _program, ast, + fn_types: HashMap::new(), }; Ok(codegen) } - pub fn compile_ast(&self) -> Result<()> { + pub fn compile_ast(&mut self) -> Result<()> { let mut functions = vec![]; + let mut types: VariableTypes<'ctx> = HashMap::new(); // todo fix the grammar so top level statements are only functions and static vars. // create the llvm functions first. + for statement in &self.ast.statements { match &statement { - Statement::Variable { .. } => unreachable!(), + Statement::Let { .. } => unreachable!(), + Statement::Mutate { .. } => unreachable!(), Statement::Return(_) => unreachable!(), Statement::If { .. } => unreachable!(), Statement::Function(function) => { functions.push(function); - self.compile_function_signature(function)?; + let ret_type = self.compile_function_signature(function)?; + if let Some(ret_type) = ret_type { + types.insert(function.name.clone(), ret_type); + } } } } + self.fn_types = types; + // implement them. for function in functions { self.compile_function(function)?; @@ -104,7 +117,10 @@ impl<'ctx> CodeGen<'ctx> { } /// creates the llvm function without the body, so other function bodies can call it. - fn compile_function_signature(&self, function: &Function) -> Result<()> { + fn compile_function_signature( + &self, + function: &Function, + ) -> Result>> { let args_types: Vec> = function .params .iter() @@ -116,13 +132,16 @@ impl<'ctx> CodeGen<'ctx> { args_types.into_iter().map(|t| t.into()).collect_vec(); let fn_type = match &function.return_type { - Some(id) => self.get_llvm_type(id)?.fn_type(&args_types, false), + Some(id) => { + let return_type = self.get_llvm_type(id)?; + return_type.fn_type(&args_types, false) + } None => self.context.void_type().fn_type(&args_types, false), }; self.module.add_function(&function.name, fn_type, None); - Ok(()) + Ok(fn_type.get_return_type()) } fn compile_function(&self, function: &Function) -> Result<()> { @@ -131,18 +150,16 @@ impl<'ctx> CodeGen<'ctx> { self.builder.position_at_end(entry_block); - let mut variables: HashMap, usize)> = HashMap::new(); + let mut variables: Variables = HashMap::new(); + let mut types: VariableTypes = HashMap::new(); for (i, param) in function.params.iter().enumerate() { let id = param.ident.clone(); - variables.insert( - id.clone(), - ( - func.get_nth_param(i.try_into().unwrap()) - .expect("parameter"), - 0, - ), - ); + let param = func + .get_nth_param(i.try_into().unwrap()) + .expect("parameter"); + variables.insert(id.clone(), (param, 0)); + types.insert(id.clone(), param.get_type()); } let mut has_return = false; @@ -151,7 +168,7 @@ impl<'ctx> CodeGen<'ctx> { if let Statement::Return(_) = statement { has_return = true } - self.compile_statement(statement, &mut variables)?; + self.compile_statement(statement, &mut variables, &mut types)?; } if !has_return { @@ -161,30 +178,70 @@ impl<'ctx> CodeGen<'ctx> { Ok(()) } + fn find_expr_type( + &self, + expr: &Expression, + types: &VariableTypes<'ctx>, + ) -> Option> { + match expr { + Expression::Literal(x) => match x { + LiteralValue::String => todo!(), + LiteralValue::Integer { + bits, + signed, + value, + } => bits.map(|bits| self.context.custom_width_int_type(bits).into()), + }, + Expression::Variable(x) => types.get(x).cloned(), + Expression::Call { function, args } => types.get(function).cloned(), + Expression::BinaryOp(lhs, op, rhs) => self + .find_expr_type(lhs, types) + .or_else(|| self.find_expr_type(rhs, types)), + } + } + fn compile_statement( &self, statement: &Statement, // value, assignments - variables: &mut HashMap, usize)>, + variables: &mut Variables<'ctx>, + types: &mut VariableTypes<'ctx>, ) -> Result<()> { match statement { // Variable assignment - Statement::Variable { name, value } => { + Statement::Let { + name, + value, + type_name, + } => { + let type_hint = if let Some(type_name) = type_name { + self.get_llvm_type(type_name)? + } else { + self.find_expr_type(value, types) + .expect("type should be found") + }; + types.insert(name.clone(), type_hint); + let result = self - .compile_expression(value, variables)? + .compile_expression(value, variables, types, Some(type_hint))? .expect("should have result"); - let accesses = if let Some(x) = variables.get(name) { - x.1 + 1 - } else { - 0 - }; - variables.insert(name.clone(), (result, accesses)); + variables.insert(name.clone(), (result, 0)); + } + Statement::Mutate { name, value } => { + let type_hint = *types.get(name).expect("should exist"); + let result = self + .compile_expression(value, variables, types, Some(type_hint))? + .expect("should have result"); + + let (old_val, acc) = variables.get(name).expect("variable should exist"); + variables.insert(name.clone(), (result, acc + 1)); } Statement::Return(ret) => { if let Some(ret) = ret { + let type_hint = self.find_expr_type(ret, types); let result = self - .compile_expression(ret, variables)? + .compile_expression(ret, variables, types, type_hint)? .expect("should have result"); self.builder.build_return(Some(&result)); } else { @@ -196,8 +253,9 @@ impl<'ctx> CodeGen<'ctx> { body, else_body, } => { + let type_hint_cond = self.find_expr_type(condition, types); let condition = self - .compile_expression(condition, variables)? + .compile_expression(condition, variables, types, type_hint_cond)? .expect("should produce a value"); let func = self @@ -224,7 +282,7 @@ impl<'ctx> CodeGen<'ctx> { let mut variables_if = variables.clone(); self.builder.position_at_end(if_block); for s in body { - self.compile_statement(s, &mut variables_if)?; + self.compile_statement(s, &mut variables_if, types)?; } self.builder.build_unconditional_branch(merge_block); if_block = self.builder.get_insert_block().unwrap(); // update for phi @@ -234,7 +292,7 @@ impl<'ctx> CodeGen<'ctx> { self.builder.position_at_end(else_block); for s in else_body { - self.compile_statement(s, &mut variables_else)?; + self.compile_statement(s, &mut variables_else, types)?; } self.builder.build_unconditional_branch(merge_block); else_block = self.builder.get_insert_block().unwrap(); // update for phi @@ -288,14 +346,18 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_expression( &self, expr: &Expression, - variables: &mut HashMap, usize)>, + variables: &mut Variables<'ctx>, + types: &mut VariableTypes<'ctx>, + type_hint: Option>, ) -> Result>> { Ok(match expr { - Expression::Variable(term) => Some(self.compile_variable(term, variables)?), - Expression::Literal(term) => Some(self.compile_literal(term)?), - Expression::Call { function, args } => self.compile_call(function, args, variables)?, + Expression::Variable(term) => Some(self.compile_variable(term, variables, types)?), + Expression::Literal(term) => Some(self.compile_literal(term, type_hint)?), + Expression::Call { function, args } => { + self.compile_call(function, args, variables, types)? + } Expression::BinaryOp(lhs, op, rhs) => { - Some(self.compile_binary_op(lhs, op, rhs, variables)?) + Some(self.compile_binary_op(lhs, op, rhs, variables, types, type_hint)?) } }) } @@ -304,15 +366,17 @@ impl<'ctx> CodeGen<'ctx> { &self, func_name: &str, args: &[Box], - variables: &mut HashMap, usize)>, + variables: &mut Variables<'ctx>, + types: &mut VariableTypes<'ctx>, ) -> Result>> { let function = self.module.get_function(func_name).expect("should exist"); let mut value_args: Vec = Vec::with_capacity(args.len()); for arg in args { + let type_enum = self.find_expr_type(arg, types); let res = self - .compile_expression(arg, variables)? + .compile_expression(arg, variables, types, type_enum)? .expect("should have result"); value_args.push(res.into()); } @@ -333,14 +397,16 @@ impl<'ctx> CodeGen<'ctx> { lhs: &Expression, op: &OpCode, rhs: &Expression, - variables: &mut HashMap, usize)>, + variables: &mut Variables<'ctx>, + types: &mut VariableTypes<'ctx>, + type_hint: Option>, ) -> Result> { let lhs = self - .compile_expression(lhs, variables)? + .compile_expression(lhs, variables, types, type_hint)? .expect("should have result") .into_int_value(); let rhs = self - .compile_expression(rhs, variables)? + .compile_expression(rhs, variables, types, type_hint)? .expect("should have result") .into_int_value(); @@ -363,7 +429,11 @@ impl<'ctx> CodeGen<'ctx> { Ok(result.as_basic_value_enum()) } - pub fn compile_literal(&self, term: &LiteralValue) -> Result> { + pub fn compile_literal( + &self, + term: &LiteralValue, + type_hint: Option>, + ) -> Result> { let value = match term { LiteralValue::String => todo!(), LiteralValue::Integer { @@ -371,13 +441,19 @@ impl<'ctx> CodeGen<'ctx> { signed: _, value, } => { - // todo: type resolution for bit size? - let bits = bits.unwrap_or(32); + if let Some(type_hint) = type_hint { + type_hint + .into_int_type() + .const_int(value.parse().unwrap(), false) + .as_basic_value_enum() + } else { + let bits = bits.unwrap_or(32); - self.context - .custom_width_int_type(bits) - .const_int(value.parse().unwrap(), false) - .as_basic_value_enum() + self.context + .custom_width_int_type(bits) + .const_int(value.parse().unwrap(), false) + .as_basic_value_enum() + } } }; @@ -387,7 +463,8 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_variable( &self, variable: &str, - variables: &mut HashMap, usize)>, + variables: &mut Variables<'ctx>, + types: &mut VariableTypes<'ctx>, ) -> Result> { let var = *variables.get(variable).expect("value"); Ok(var.0) diff --git a/src/grammar.lalrpop b/src/grammar.lalrpop index 54cea829e..3edcc16bb 100644 --- a/src/grammar.lalrpop +++ b/src/grammar.lalrpop @@ -69,10 +69,14 @@ Statement: ast::Statement = { => ast::Statement::Function(f), }; +TypeInfo: String = { + ":" => i +}; + // statements not including function definitions BasicStatement: ast::Statement = { - "let" "=" ";" => ast::Statement::Variable { name: i, value: e}, - "=" ";" => ast::Statement::Variable { name: i, value: e}, + "let" "=" ";" => ast::Statement::Let { name: i, value: e, type_name: t}, + "=" ";" => ast::Statement::Mutate { name: i, value: e}, "if" "{" "}" => ast::Statement::If { condition: cond, body: s, else_body: e}, "return" ";" => ast::Statement::Return(e), }; diff --git a/src/main.rs b/src/main.rs index 148eedf52..c277715c2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -123,7 +123,7 @@ fn main() -> Result<()> { println!("{:#?}", ast); let context = Context::create(); - let codegen = codegen::CodeGen::new(&context, &file_name, program, ast)?; + let mut codegen = codegen::CodeGen::new(&context, &file_name, program, ast)?; codegen.compile_ast()?; let generated_llvm_ir = codegen.generated_code(); @@ -144,7 +144,7 @@ fn main() -> Result<()> { let file_name = input.file_name().unwrap().to_string_lossy(); let context = Context::create(); - let codegen = codegen::CodeGen::new(&context, &file_name, program, ast)?; + let mut codegen = codegen::CodeGen::new(&context, &file_name, program, ast)?; codegen.compile_ast()?; let execution_engine = codegen .module