From 8c212d948d4c1c234c69822f3043e124f948080a Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Sun, 11 Jun 2023 12:07:15 +0200 Subject: [PATCH] require fully typed integers for now --- example.ed => programs/example.ed | 0 programs/ifelse.ed | 15 ++ simple.ed => programs/simple.ed | 2 +- src/ast/mod.rs | 19 ++- src/codegen.rs | 271 ++++++++++++------------------ src/grammar.lalrpop | 17 +- src/main.rs | 6 +- src/type_analysis.rs | 111 ++++++------ 8 files changed, 208 insertions(+), 233 deletions(-) rename example.ed => programs/example.ed (100%) create mode 100644 programs/ifelse.ed rename simple.ed => programs/simple.ed (93%) diff --git a/example.ed b/programs/example.ed similarity index 100% rename from example.ed rename to programs/example.ed diff --git a/programs/ifelse.ed b/programs/ifelse.ed new file mode 100644 index 000000000..6703b2c37 --- /dev/null +++ b/programs/ifelse.ed @@ -0,0 +1,15 @@ +fn works(x: i64) -> i64 { + let z = 0i64; + if 2i64 == x { + z = x * 2i64; + } else { + z = x * 3i64; + } + return z; +} + +fn main() -> i64 { + let y = 2i64; + let z = y; + return works(z); +} diff --git a/simple.ed b/programs/simple.ed similarity index 93% rename from simple.ed rename to programs/simple.ed index 788b78dd3..c21572016 100644 --- a/simple.ed +++ b/programs/simple.ed @@ -9,7 +9,7 @@ fn test(x: Hello) { fn works(x: i64) -> i64 { let z = 0i64; - if 2 == x { + if 2i64 == x { z = x * 2i64; } else { z = x * 3i64; diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 8668853ad..680e1039d 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Spanned { pub span: (usize, usize), @@ -64,12 +66,10 @@ pub enum Expression { Literal(LiteralValue), Variable { name: Spanned, - value_type: Option, }, Call { function: String, args: Vec>, - value_type: Option, }, BinaryOp(Box, OpCode, Box), } @@ -86,16 +86,17 @@ impl Parameter { } } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq)] pub struct Function { pub name: String, pub params: Vec, pub body: Vec, + pub scope_type_info: HashMap>, pub return_type: Option, } impl Function { - pub const fn new( + pub fn new( name: String, params: Vec, body: Vec, @@ -106,6 +107,7 @@ impl Function { params, body, return_type, + scope_type_info: HashMap::new(), } } } @@ -113,14 +115,14 @@ impl Function { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct StructField { pub ident: String, - pub type_exp: TypeExp, + pub field_type: TypeExp, } impl StructField { pub const fn new(ident: String, type_name: TypeExp) -> Self { Self { ident, - type_exp: type_name, + field_type: type_name, } } } @@ -131,7 +133,7 @@ pub struct Struct { pub fields: Vec, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq)] pub enum Statement { Let { name: String, @@ -142,13 +144,14 @@ pub enum Statement { Mutate { name: String, value: Box, - value_type: Option, span: (usize, usize), }, If { condition: Box, body: Vec, + scope_type_info: HashMap>, else_body: Option>, + else_body_scope_type_info: HashMap>, }, Return(Option>), Function(Function), diff --git a/src/codegen.rs b/src/codegen.rs index 73bbe567e..15f34d8b4 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -37,10 +37,10 @@ pub struct CodeGen<'ctx> { context: &'ctx Context, pub module: Module<'ctx>, builder: Builder<'ctx>, - types: TypeStorage<'ctx>, + //types: TypeStorage<'ctx>, struct_types: StructTypeStorage<'ctx>, // function to return type - functions: HashMap, Option)>, + functions: HashMap, _program: ProgramData, ast: ast::Program, } @@ -48,12 +48,12 @@ pub struct CodeGen<'ctx> { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Variable<'ctx> { pub value: BasicValueEnum<'ctx>, + pub type_counter: usize, pub phi_counter: usize, - pub type_exp: TypeExp, } pub type Variables<'ctx> = HashMap>; -pub type TypeStorage<'ctx> = HashMap>; +// pub type TypeStorage<'ctx> = HashMap>; /// Holds the struct type and maps fields to types and the location within the struct. #[derive(Debug, Clone, PartialEq, Eq)] @@ -79,7 +79,6 @@ impl<'ctx> CodeGen<'ctx> { builder: context.create_builder(), _program, ast, - types: HashMap::new(), struct_types: HashMap::new(), functions: HashMap::new(), }; @@ -88,9 +87,8 @@ impl<'ctx> CodeGen<'ctx> { } pub fn compile_ast(&mut self) -> Result<()> { - let mut functions = vec![]; - let mut func_info = HashMap::new(); - let mut types: TypeStorage<'ctx> = HashMap::new(); + let mut functions = HashMap::new(); + // let mut types: TypeStorage<'ctx> = HashMap::new(); let mut struct_types: StructTypeStorage<'ctx> = HashMap::new(); // todo fix the grammar so top level statements are only functions and static vars. @@ -102,13 +100,11 @@ impl<'ctx> CodeGen<'ctx> { let mut field_types = vec![]; for (i, field) in s.fields.iter().enumerate() { - if !types.contains_key(&field.type_exp) { - types.insert(field.type_exp.clone(), self.get_llvm_type(&field.type_exp)?); - } - let ty = self.get_llvm_type(&field.type_exp)?; + // todo: this doesnt handle out of order structs well + let ty = self.get_llvm_type(&field.field_type)?; field_types.push(ty); // todo: ensure alignment and padding here - fields.insert(field.ident.clone(), (i, field.type_exp.clone())); + fields.insert(field.ident.clone(), (i, field.field_type.clone())); } let ty = self.context.struct_type(&field_types, false); @@ -123,38 +119,16 @@ impl<'ctx> CodeGen<'ctx> { // create the llvm functions first. for statement in &self.ast.statements { if let Statement::Function(function) = &statement { - functions.push(function); - let (args, ret_type) = self.compile_function_signature(function)?; - let mut arg_types = vec![]; - for arg in args { - if !types.contains_key(&arg) { - let ty = self.get_llvm_type(&arg)?; - types.insert(arg.clone(), ty); - } - arg_types.push(arg); - } - if let Some(ret_type) = ret_type { - let ret_type = if !types.contains_key(&ret_type) { - let ty = self.get_llvm_type(&ret_type)?; - types.insert(ret_type.clone(), ty); - ret_type - } else { - ret_type - }; - func_info.insert(function.name.clone(), (arg_types, Some(ret_type))); - } else { - func_info.insert(function.name.clone(), (arg_types, None)); - } + functions.insert(function.name.clone(), function.clone()); + self.compile_function_signature(function)?; } } - - self.types = types; - self.functions = func_info; + self.functions = functions; info!("functions:\n{:#?}", self.functions); // implement them. - for function in functions { + for (_, function) in &self.functions { self.compile_function(function)?; } @@ -169,38 +143,31 @@ impl<'ctx> CodeGen<'ctx> { } fn get_llvm_type(&self, id: &TypeExp) -> Result> { - if let Some(ty) = self.types.get(id) { - Ok(*ty) - } else { - Ok(match id { - TypeExp::Integer { bits, signed: _ } => self - .context - .custom_width_int_type(*bits) - .as_basic_type_enum(), - TypeExp::Boolean => self.context.bool_type().as_basic_type_enum(), - TypeExp::Array { of, len } => { - let ty = self.get_llvm_type(of)?; - ty.array_type(len.unwrap()).as_basic_type_enum() - } - TypeExp::Pointer { target } => { - let ty = self.get_llvm_type(target)?; - ty.ptr_type(Default::default()).as_basic_type_enum() - } - TypeExp::Other { id } => self - .struct_types - .get(id) - .expect("struct type not found") - .ty - .as_basic_type_enum(), - }) - } + Ok(match id { + TypeExp::Integer { bits, signed: _ } => self + .context + .custom_width_int_type(*bits) + .as_basic_type_enum(), + TypeExp::Boolean => self.context.bool_type().as_basic_type_enum(), + TypeExp::Array { of, len } => { + let ty = self.get_llvm_type(of)?; + ty.array_type(len.unwrap()).as_basic_type_enum() + } + TypeExp::Pointer { target } => { + let ty = self.get_llvm_type(target)?; + ty.ptr_type(Default::default()).as_basic_type_enum() + } + TypeExp::Other { id } => self + .struct_types + .get(id) + .expect("struct type not found") + .ty + .as_basic_type_enum(), + }) } /// creates the llvm function without the body, so other function bodies can call it. - fn compile_function_signature( - &self, - function: &Function, - ) -> Result<(Vec, Option)> { + fn compile_function_signature(&self, function: &Function) -> Result<()> { let args_types: Vec> = function .params .iter() @@ -221,14 +188,7 @@ impl<'ctx> CodeGen<'ctx> { self.module.add_function(&function.name, fn_type, None); - Ok(( - function - .params - .iter() - .map(|param| param.type_exp.clone()) - .collect(), - ret_type, - )) + Ok(()) } fn compile_function(&self, function: &Function) -> Result<()> { @@ -238,7 +198,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.position_at_end(entry_block); let mut variables: Variables = HashMap::new(); - let mut types: TypeStorage = self.types.clone(); for (i, param) in function.params.iter().enumerate() { let id = ¶m.ident; @@ -250,7 +209,7 @@ impl<'ctx> CodeGen<'ctx> { Variable { value: param_value, phi_counter: 0, - type_exp: param.type_exp.clone(), + type_counter: 0, }, ); } @@ -261,7 +220,7 @@ impl<'ctx> CodeGen<'ctx> { if let Statement::Return(_) = statement { has_return = true } - self.compile_statement(func, statement, &mut variables, &mut types)?; + self.compile_statement(func, statement, &mut variables, &function.scope_type_info)?; } if !has_return { @@ -277,7 +236,7 @@ impl<'ctx> CodeGen<'ctx> { statement: &Statement, // value, assignments variables: &mut Variables<'ctx>, - types: &mut TypeStorage<'ctx>, + scope_info: &HashMap>, ) -> Result<()> { match statement { // Variable assignment @@ -287,41 +246,32 @@ impl<'ctx> CodeGen<'ctx> { value_type: _, .. } => { - let (value, value_type) = self - .compile_expression(value, variables, types)? + let value = self + .compile_expression(value, variables, scope_info)? .expect("should have result"); - if !types.contains_key(&value_type) { - let ty = self.get_llvm_type(&value_type)?; - types.insert(value_type.clone(), ty); - } - - info!("adding variable: name={}, ty={:?}", name, value_type); - variables.insert( name.clone(), Variable { value, phi_counter: 0, - type_exp: value_type, + type_counter: 0, }, ); } Statement::Mutate { name, value, .. } => { - let (value, value_type) = self - .compile_expression(value, variables, types)? + let value = self + .compile_expression(value, variables, scope_info)? .expect("should have result"); let var = variables.get_mut(name).expect("variable should exist"); var.phi_counter += 1; var.value = value; - assert_eq!(var.type_exp, value_type, "variable type shouldn't change!"); - info!("mutated variable: name={}, ty={:?}", name, var.type_exp); } Statement::Return(ret) => { if let Some(ret) = ret { - let (value, _value_type) = self - .compile_expression(ret, variables, types)? + let value = self + .compile_expression(ret, variables, scope_info)? .expect("should have result"); self.builder.build_return(Some(&value)); } else { @@ -332,9 +282,11 @@ impl<'ctx> CodeGen<'ctx> { condition, body, else_body, + scope_type_info, + else_body_scope_type_info, } => { - let (condition, _cond_type) = self - .compile_expression(condition, variables, types)? + let condition = self + .compile_expression(condition, variables, scope_info)? .expect("should produce a value"); let mut if_block = self.context.append_basic_block(function_value, "if"); @@ -354,7 +306,7 @@ impl<'ctx> CodeGen<'ctx> { let mut variables_if = variables.clone(); self.builder.position_at_end(if_block); for s in body { - self.compile_statement(function_value, s, &mut variables_if, types)?; + self.compile_statement(function_value, s, &mut variables_if, scope_type_info)?; } self.builder.build_unconditional_branch(merge_block); if_block = self.builder.get_insert_block().unwrap(); // update for phi @@ -364,7 +316,12 @@ impl<'ctx> CodeGen<'ctx> { self.builder.position_at_end(else_block); for s in else_body { - self.compile_statement(function_value, s, &mut variables_else, types)?; + self.compile_statement( + function_value, + s, + &mut variables_else, + else_body_scope_type_info, + )?; } self.builder.build_unconditional_branch(merge_block); else_block = self.builder.get_insert_block().unwrap(); // update for phi @@ -381,7 +338,7 @@ impl<'ctx> CodeGen<'ctx> { .builder .build_phi(old_var.value.get_type(), &format!("{name}_phi")); phi.add_incoming(&[(&new_var.value, if_block)]); - processed_vars.insert(name, (phi, new_var.type_exp)); + processed_vars.insert(name, phi); } } } @@ -391,7 +348,7 @@ impl<'ctx> CodeGen<'ctx> { if variables.contains_key(&name) { let old_var = variables.get(&name).unwrap(); if new_var.phi_counter > old_var.phi_counter { - if let Some((phi, _)) = processed_vars.get(&name) { + if let Some(phi) = processed_vars.get(&name) { phi.add_incoming(&[(&new_var.value, else_block)]); } else { let phi = self.builder.build_phi( @@ -399,22 +356,25 @@ impl<'ctx> CodeGen<'ctx> { &format!("{name}_phi"), ); phi.add_incoming(&[(&old_var.value, else_block)]); - processed_vars.insert(name, (phi, new_var.type_exp)); + processed_vars.insert(name, phi); } } } } } - for (name, (phi, type_exp)) in processed_vars { + for (name, phi) in processed_vars { + /* variables.insert( name, Variable { value: phi.as_basic_value(), phi_counter: 0, - type_exp, }, ); + */ + let mut var = variables.get_mut(&name).unwrap(); + var.value = phi.as_basic_value(); } } Statement::Function(_) => unreachable!(), @@ -428,21 +388,16 @@ impl<'ctx> CodeGen<'ctx> { &self, expr: &Expression, variables: &mut Variables<'ctx>, - types: &mut TypeStorage<'ctx>, - ) -> Result, TypeExp)>> { + scope_info: &HashMap>, + ) -> Result>> { Ok(match expr { - Expression::Variable { - name, - value_type: _, - } => Some(self.compile_variable(&name.value, variables)?), + Expression::Variable { name } => Some(self.compile_variable(&name.value, variables)?), Expression::Literal(term) => Some(self.compile_literal(term)?), - Expression::Call { - function, - args, - value_type, - } => self.compile_call(function, args, variables, types, value_type.clone())?, + Expression::Call { function, args } => { + self.compile_call(function, args, variables, scope_info)? + } Expression::BinaryOp(lhs, op, rhs) => { - Some(self.compile_binary_op(lhs, op, rhs, variables, types)?) + Some(self.compile_binary_op(lhs, op, rhs, variables, scope_info)?) } }) } @@ -452,17 +407,16 @@ impl<'ctx> CodeGen<'ctx> { func_name: &str, args: &[Box], variables: &mut Variables<'ctx>, - types: &mut TypeStorage<'ctx>, - value_type: Option, - ) -> Result, TypeExp)>> { + scope_info: &HashMap>, + ) -> Result>> { info!("compiling fn call: func_name={}", func_name); let function = self.module.get_function(func_name).expect("should exist"); let mut value_args: Vec = Vec::with_capacity(args.len()); for arg in args.iter() { - let (res, _res_type) = self - .compile_expression(arg, variables, types)? + let res = self + .compile_expression(arg, variables, scope_info)? .expect("should have result"); value_args.push(res.into()); } @@ -473,7 +427,7 @@ impl<'ctx> CodeGen<'ctx> { .try_as_basic_value(); Ok(match result { - Either::Left(val) => Some((val, value_type.unwrap())), + Either::Left(val) => Some(val), Either::Right(_) => None, }) } @@ -484,20 +438,18 @@ impl<'ctx> CodeGen<'ctx> { op: &OpCode, rhs: &Expression, variables: &mut Variables<'ctx>, - types: &mut TypeStorage<'ctx>, - ) -> Result<(BasicValueEnum<'ctx>, TypeExp)> { - let (lhs, lhs_type) = self - .compile_expression(lhs, variables, types)? + scope_info: &HashMap>, + ) -> Result> { + let lhs = self + .compile_expression(lhs, variables, scope_info)? .expect("should have result"); - let (rhs, rhs_type) = self - .compile_expression(rhs, variables, types)? + let rhs = self + .compile_expression(rhs, variables, scope_info)? .expect("should have result"); - assert_eq!(lhs_type, rhs_type); let lhs = lhs.into_int_value(); let rhs = rhs.into_int_value(); - let mut bool_result = false; let result = match op { OpCode::Add => self.builder.build_int_add(lhs, rhs, "add"), OpCode::Sub => self.builder.build_int_sub(lhs, rhs, "sub"), @@ -506,31 +458,18 @@ impl<'ctx> CodeGen<'ctx> { OpCode::Rem => self.builder.build_int_signed_rem(lhs, rhs, "rem"), OpCode::And => self.builder.build_and(lhs, rhs, "and"), OpCode::Or => self.builder.build_or(lhs, rhs, "or"), - OpCode::Eq => { - bool_result = true; - self.builder - .build_int_compare(IntPredicate::EQ, lhs, rhs, "eq") - } - OpCode::Ne => { - bool_result = true; - self.builder - .build_int_compare(IntPredicate::NE, lhs, rhs, "eq") - } + OpCode::Eq => self + .builder + .build_int_compare(IntPredicate::EQ, lhs, rhs, "eq"), + OpCode::Ne => self + .builder + .build_int_compare(IntPredicate::NE, lhs, rhs, "eq"), }; - let mut res_type = lhs_type; - - if bool_result { - res_type = TypeExp::Integer { - bits: 1, - signed: false, - }; - } - - Ok((result.as_basic_value_enum(), res_type)) + Ok(result.as_basic_value_enum()) } - pub fn compile_literal(&self, term: &LiteralValue) -> Result<(BasicValueEnum<'ctx>, TypeExp)> { + pub fn compile_literal(&self, term: &LiteralValue) -> Result> { let value = match term { LiteralValue::String(_s) => { todo!() @@ -540,13 +479,11 @@ impl<'ctx> CodeGen<'ctx> { .const_string(s.as_bytes(), true) .as_basic_value_enum() */ } - LiteralValue::Boolean(v) => ( - self.context - .bool_type() - .const_int((*v).into(), false) - .as_basic_value_enum(), - TypeExp::Boolean, - ), + LiteralValue::Boolean(v) => self + .context + .bool_type() + .const_int((*v).into(), false) + .as_basic_value_enum(), LiteralValue::Integer { value, bits, @@ -554,13 +491,11 @@ impl<'ctx> CodeGen<'ctx> { } => { let bits = *bits; let signed = *signed; - ( - self.context - .custom_width_int_type(bits) - .const_int(value.parse().unwrap(), false) - .as_basic_value_enum(), - TypeExp::Integer { bits, signed }, - ) + + self.context + .custom_width_int_type(bits) + .const_int(value.parse().unwrap(), false) + .as_basic_value_enum() } }; @@ -571,8 +506,8 @@ impl<'ctx> CodeGen<'ctx> { &self, variable: &str, variables: &mut Variables<'ctx>, - ) -> Result<(BasicValueEnum<'ctx>, TypeExp)> { + ) -> Result> { let var = variables.get(variable).expect("value").clone(); - Ok((var.value, var.type_exp)) + Ok(var.value) } } diff --git a/src/grammar.lalrpop b/src/grammar.lalrpop index 78bfc0ec7..d8e57e0c4 100644 --- a/src/grammar.lalrpop +++ b/src/grammar.lalrpop @@ -1,4 +1,4 @@ -use std::str::FromStr; +use std::collections::HashMap; use crate::{ ast::{self, Spanned}, tokens::Token, @@ -99,9 +99,15 @@ BasicStatement: ast::Statement = { "let" "=" ";" => ast::Statement::Let { name: i, value: e, value_type: t, span: (lo, hi) }, "=" ";" => - ast::Statement::Mutate { name: i, value: e, span: (lo, hi), value_type: None }, + ast::Statement::Mutate { name: i, value: e, span: (lo, hi) }, "if" "{" "}" => - ast::Statement::If { condition: cond, body: s, else_body: e}, + ast::Statement::If { + condition: cond, + body: s, + else_body: e, + scope_type_info: Default::default(), + else_body_scope_type_info: Default::default(), + }, "return" ";" => ast::Statement::Return(e), }; @@ -143,13 +149,12 @@ Expr4 = Tier; // Terms: variables, literals, calls Term: Box = { => Box::new(ast::Expression::Variable { - name: Spanned::new(i, (lo, hi)), - value_type: None + name: Spanned::new(i, (lo, hi)) }), => Box::new(ast::Expression::Literal(n)), => Box::new(ast::Expression::Literal(n)), => Box::new(ast::Expression::Literal(n)), - "(" > ")" => Box::new(ast::Expression::Call { function: i, args: values, value_type: None }), + "(" > ")" => Box::new(ast::Expression::Call { function: i, args: values }), "(" ")" }; diff --git a/src/main.rs b/src/main.rs index 47403110b..ae4df5d60 100644 --- a/src/main.rs +++ b/src/main.rs @@ -102,7 +102,7 @@ fn main() -> Result<()> { let lexer = Lexer::new(code.as_str()); let parser = grammar::ProgramParser::new(); let mut ast = parser.parse(lexer)?; - type_analysis::type_inference2(&mut ast); + type_analysis::type_inference(&mut ast); let program = ProgramData::new(&input, &code); check_program(&program, &ast); } @@ -112,7 +112,7 @@ fn main() -> Result<()> { let parser = grammar::ProgramParser::new(); match parser.parse(lexer) { Ok(mut ast) => { - type_analysis::type_inference2(&mut ast); + type_analysis::type_inference(&mut ast); println!("{ast:#?}"); } Err(e) => { @@ -130,7 +130,7 @@ fn main() -> Result<()> { let lexer = Lexer::new(code.as_str()); let parser = grammar::ProgramParser::new(); let mut ast: Program = parser.parse(lexer)?; - type_analysis::type_inference2(&mut ast); + type_analysis::type_inference(&mut ast); let program = ProgramData::new(&input, &code); diff --git a/src/type_analysis.rs b/src/type_analysis.rs index a5517cc05..271ed112d 100644 --- a/src/type_analysis.rs +++ b/src/type_analysis.rs @@ -8,19 +8,15 @@ struct Storage { functions: HashMap, } -/* -To briefly summarize the union-find algorithm, given the set of all types in a proof, -it allows one to group them together into equivalence classes by means of a union procedure and to - pick a representative for each such class using a find procedure. Emphasizing the word procedure in - the sense of side effect, we're clearly leaving the realm of logic in order to prepare an effective algorithm. - The representative of a u n i o n ( a , b ) {\mathtt {union}}(a,b) is determined such that, if both a and b are - type variables then the representative is arbitrarily one of them, but while uniting a variable and a term, the - term becomes the representative. Assuming an implementation of union-find at hand, one can formulate the unification of two monotypes as follows: - */ +// problem with scopes, +// let x = 2; +// let x = 2i64; + +type ScopeMap = HashMap>>; // this works, but need to find a way to store the found info + handle literal integer types (or not?) // maybe use scope ids -pub fn type_inference2(ast: &mut ast::Program) { +pub fn type_inference(ast: &mut ast::Program) { let mut storage = Storage::default(); // gather global constructs first @@ -30,7 +26,7 @@ pub fn type_inference2(ast: &mut ast::Program) { let fields = st .fields .iter() - .map(|x| (x.ident.clone(), x.type_exp.clone())) + .map(|x| (x.ident.clone(), x.field_type.clone())) .collect(); storage.structs.insert(st.name.clone(), fields); } @@ -46,26 +42,33 @@ pub fn type_inference2(ast: &mut ast::Program) { dbg!(&storage); - for function in storage.functions.values() { - let mut scope_vars: HashMap> = HashMap::new(); + for statement in ast.statements.iter_mut() { + if let Statement::Function(function) = statement { + let mut scope_vars: ScopeMap = HashMap::new(); - for arg in &function.params { - scope_vars.insert(arg.ident.clone(), Some(arg.type_exp.clone())); + for arg in &function.params { + scope_vars.insert(arg.ident.clone(), vec![Some(arg.type_exp.clone())]); + } + + let func_info = function.clone(); + let (new_scope_vars, _) = + type_inference_scope(&mut function.body, &scope_vars, &func_info, &storage); + // todo: check all vars have type info? + function.scope_type_info = new_scope_vars + .into_iter() + .map(|(a, b)| (a, b.into_iter().map(Option::unwrap).collect())) + .collect(); } - - let (new_scope_vars, _) = - type_inference_scope(&function.body, &scope_vars, function, &storage); - dbg!(new_scope_vars); } } /// Finds variable types in the scope, returns newly created variables to handle shadowing fn type_inference_scope( - statements: &[ast::Statement], - scope_vars: &HashMap>, + statements: &mut [ast::Statement], + scope_vars: &ScopeMap, func: &Function, storage: &Storage, -) -> (HashMap>, HashSet) { +) -> (ScopeMap, HashSet) { let mut scope_vars = scope_vars.clone(); let mut new_vars: HashSet = HashSet::new(); @@ -81,19 +84,24 @@ fn type_inference_scope( let exp_type = type_inference_expression(value, &mut scope_vars, storage, None); + if !scope_vars.contains_key(name) { + scope_vars.insert(name.clone(), vec![]); + } + + let var = scope_vars.get_mut(name).unwrap(); + if value_type.is_none() { - scope_vars.insert(name.clone(), exp_type); + var.push(exp_type); } else { if exp_type.is_some() && &exp_type != value_type { panic!("let type mismatch: {:?} != {:?}", value_type, exp_type); } - scope_vars.insert(name.clone(), value_type.clone()); + var.push(value_type.clone()); } } Statement::Mutate { name, value, - value_type: _, span: _, } => { if !scope_vars.contains_key(name) { @@ -101,7 +109,7 @@ fn type_inference_scope( } let exp_type = type_inference_expression(value, &mut scope_vars, storage, None); - let var = scope_vars.get_mut(name).unwrap(); + let var = scope_vars.get_mut(name).unwrap().last_mut().unwrap(); if var.is_none() { *var = exp_type; @@ -113,6 +121,8 @@ fn type_inference_scope( condition, body, else_body, + scope_type_info, + else_body_scope_type_info, } => { type_inference_expression( condition, @@ -124,23 +134,33 @@ fn type_inference_scope( let (new_scope_vars, new_vars) = type_inference_scope(body, &scope_vars, func, storage); - for (k, v) in new_scope_vars.into_iter() { + for (k, v) in new_scope_vars.iter() { // not a new var within the scope (shadowing), so type info is valid - if scope_vars.contains_key(&k) && !new_vars.contains(&k) { - scope_vars.insert(k, v); + if scope_vars.contains_key(k) && !new_vars.contains(k) { + scope_vars.insert(k.clone(), v.clone()); } } + *scope_type_info = new_scope_vars + .into_iter() + .map(|(a, b)| (a, b.into_iter().map(Option::unwrap).collect())) + .collect(); + if let Some(body) = else_body { let (new_scope_vars, new_vars) = type_inference_scope(body, &scope_vars, func, storage); - for (k, v) in new_scope_vars.into_iter() { + for (k, v) in new_scope_vars.iter() { // not a new var within the scope (shadowing), so type info is valid - if scope_vars.contains_key(&k) && !new_vars.contains(&k) { - scope_vars.insert(k, v); + if scope_vars.contains_key(k) && !new_vars.contains(k) { + scope_vars.insert(k.clone(), v.clone()); } } + + *else_body_scope_type_info = new_scope_vars + .into_iter() + .map(|(a, b)| (a, b.into_iter().map(Option::unwrap).collect())) + .collect(); } } Statement::Return(exp) => { @@ -163,7 +183,7 @@ fn type_inference_scope( fn type_inference_expression( exp: &Expression, - scope_vars: &mut HashMap>, + scope_vars: &mut ScopeMap, storage: &Storage, expected_type: Option, ) -> Option { @@ -182,31 +202,28 @@ fn type_inference_expression( ast::LiteralValue::Boolean(_) => Some(TypeExp::Boolean), } } - Expression::Variable { - name, - value_type: _, - } => { - let var = scope_vars.get(&name.value).cloned().flatten(); + Expression::Variable { name } => { + let var = scope_vars + .get_mut(&name.value) + .expect("to exist") + .last_mut() + .unwrap(); if expected_type.is_some() { if var.is_none() { - scope_vars.insert(name.value.clone(), expected_type.clone()); + *var = expected_type.clone(); expected_type } else if expected_type.is_some() { - assert_eq!(var, expected_type, "type mismatch with variables"); + assert_eq!(*var, expected_type, "type mismatch with variables"); expected_type } else { - var + var.clone() } } else { - var + var.clone() } } - Expression::Call { - function, - args, - value_type: _, - } => { + Expression::Call { function, args } => { let func = storage.functions.get(function).cloned().unwrap(); for (i, arg) in args.iter().enumerate() {