From c568cf08e9167a3fe42ab5aec4abceee5cd2f9aa Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Sat, 3 Jun 2023 12:27:30 +0200 Subject: [PATCH] add rudimentary type inference, todo: args --- Cargo.lock | 23 ++-- Cargo.toml | 2 +- src/ast/mod.rs | 15 ++- src/codegen.rs | 156 +++++++----------------- src/grammar.lalrpop | 22 +++- src/main.rs | 12 +- src/type_analysis.rs | 276 +++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 363 insertions(+), 143 deletions(-) create mode 100644 src/type_analysis.rs diff --git a/Cargo.lock b/Cargo.lock index d618a1646..4f0a67fdd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -156,9 +156,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.3.0" +version = "4.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93aae7a4192245f70fe75dd9157fc7b4a5bf53e88d30bd4396f7d8f9284d5acc" +checksum = "b4ed2379f8603fa2b7509891660e802b88c70a79a6427a70abb5968054de2c28" dependencies = [ "clap_builder", "clap_derive", @@ -167,9 +167,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.3.0" +version = "4.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f423e341edefb78c9caba2d9c7f7687d0e72e89df3ce3394554754393ac3990" +checksum = "72394f3339a76daf211e57d4bcb374410f3965dcc606dd0e03738c7888766980" dependencies = [ "anstream", "anstyle", @@ -180,9 +180,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.3.0" +version = "4.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "191d9573962933b4027f932c600cd252ce27a8ad5979418fe78e43c07996f27b" +checksum = "59e9ef9a08ee1c0e1f2e162121665ac45ac3783b0f897db7244ae75ad9a8f65b" dependencies = [ "heck", "proc-macro2", @@ -520,12 +520,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.17" +version = "0.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" -dependencies = [ - "cfg-if", -] +checksum = "518ef76f2f87365916b142844c16d8fefd85039bc5699050210a7778ee1cd1de" [[package]] name = "logos" @@ -610,9 +607,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.1" +version = "1.17.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +checksum = "9670a07f94779e00908f3e686eab508878ebb390ba6e604d3a284c00e8d0487b" [[package]] name = "overload" diff --git a/Cargo.toml b/Cargo.toml index 59fb864ab..eeecd749e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ categories = ["compilers"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -clap = { version = "4.3.0", features = ["derive"] } +clap = { version = "4.3.1", features = ["derive"] } color-eyre = "0.6.2" itertools = "0.10.5" lalrpop-util = { version = "0.20.0", features = ["lexer"] } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 2c01e44b9..14b67573e 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -51,17 +51,25 @@ pub enum TypeExp { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum LiteralValue { String(String), - Integer(String), + Integer { + value: String, + bits: Option, + signed: Option, + }, Boolean(bool), } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Expression { Literal(LiteralValue), - Variable(Spanned), + Variable { + name: Spanned, + value_type: Option, + }, Call { function: String, args: Vec>, + value_type: Option, }, BinaryOp(Box, OpCode, Box), } @@ -128,12 +136,13 @@ pub enum Statement { Let { name: String, value: Box, - type_name: Option, + value_type: Option, span: (usize, usize), }, Mutate { name: String, value: Box, + value_type: Option, span: (usize, usize), }, If { diff --git a/src/codegen.rs b/src/codegen.rs index bb23e98b9..963d5e70a 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -261,7 +261,7 @@ impl<'ctx> CodeGen<'ctx> { if let Statement::Return(_) = statement { has_return = true } - self.compile_statement(function, func, statement, &mut variables, &mut types)?; + self.compile_statement(func, statement, &mut variables, &mut types)?; } if !has_return { @@ -271,41 +271,8 @@ impl<'ctx> CodeGen<'ctx> { Ok(()) } - fn find_expr_type(&self, expr: &Expression, variables: &Variables<'ctx>) -> Option { - match expr { - Expression::Literal(x) => match x { - LiteralValue::String(_s) => { - todo!("make internal string struct") - /* todo: internal string structure here - Some( - self.context - .i8_type() - .array_type(s.bytes().len() as u32 + 1) - .as_basic_type_enum(), - ) */ - } - LiteralValue::Integer(_) => Some(TypeExp::Integer { - bits: 32, - signed: true, - }), - LiteralValue::Boolean(_) => Some(TypeExp::Boolean), - }, - Expression::Variable(x) => variables.get(&x.value).cloned().map(|x| x.type_exp), - Expression::Call { function, args: _ } => { - self.functions.get(function).unwrap().clone().1 - } - Expression::BinaryOp(lhs, op, rhs) => match op { - //OpCode::Eq | OpCode::Ne => Some(TypeExp::Boolean), - _ => self - .find_expr_type(lhs, variables) - .or_else(|| self.find_expr_type(rhs, variables)), - }, - } - } - fn compile_statement( &self, - function: &Function, function_value: FunctionValue, statement: &Statement, // value, assignments @@ -317,22 +284,11 @@ impl<'ctx> CodeGen<'ctx> { Statement::Let { name, value, - type_name, + value_type: _, .. } => { - let type_hint = if let Some(type_name) = type_name { - type_name.clone() - } else { - let type_exp = self - .find_expr_type(value, variables) - .expect("type should be found"); - let ty = self.get_llvm_type(&type_exp)?; - types.insert(type_exp.clone(), ty); - type_exp - }; - let (value, value_type) = self - .compile_expression(value, variables, types, Some(type_hint))? + .compile_expression(value, variables, types)? .expect("should have result"); if !types.contains_key(&value_type) { @@ -352,11 +308,8 @@ impl<'ctx> CodeGen<'ctx> { ); } Statement::Mutate { name, value, .. } => { - let var = variables.get(name).cloned().expect("variable should exist"); - let type_hint = var.type_exp; - let (value, value_type) = self - .compile_expression(value, variables, types, Some(type_hint))? + .compile_expression(value, variables, types)? .expect("should have result"); let var = variables.get_mut(name).expect("variable should exist"); @@ -367,14 +320,8 @@ impl<'ctx> CodeGen<'ctx> { } Statement::Return(ret) => { if let Some(ret) = ret { - let type_hint = self - .functions - .get(&function.name) - .expect("function should exist") - .clone() - .1; let (value, _value_type) = self - .compile_expression(ret, variables, types, type_hint)? + .compile_expression(ret, variables, types)? .expect("should have result"); self.builder.build_return(Some(&value)); } else { @@ -386,9 +333,8 @@ impl<'ctx> CodeGen<'ctx> { body, else_body, } => { - let type_hint = self.find_expr_type(condition, variables); let (condition, _cond_type) = self - .compile_expression(condition, variables, types, type_hint)? + .compile_expression(condition, variables, types)? .expect("should produce a value"); let mut if_block = self.context.append_basic_block(function_value, "if"); @@ -408,7 +354,7 @@ impl<'ctx> CodeGen<'ctx> { let mut variables_if = variables.clone(); self.builder.position_at_end(if_block); for s in body { - self.compile_statement(function, function_value, s, &mut variables_if, types)?; + self.compile_statement(function_value, s, &mut variables_if, types)?; } self.builder.build_unconditional_branch(merge_block); if_block = self.builder.get_insert_block().unwrap(); // update for phi @@ -418,13 +364,7 @@ impl<'ctx> CodeGen<'ctx> { self.builder.position_at_end(else_block); for s in else_body { - self.compile_statement( - function, - function_value, - s, - &mut variables_else, - types, - )?; + self.compile_statement(function_value, s, &mut variables_else, types)?; } self.builder.build_unconditional_branch(merge_block); else_block = self.builder.get_insert_block().unwrap(); // update for phi @@ -489,16 +429,20 @@ impl<'ctx> CodeGen<'ctx> { expr: &Expression, variables: &mut Variables<'ctx>, types: &mut TypeStorage<'ctx>, - type_hint: Option, ) -> Result, TypeExp)>> { Ok(match expr { - Expression::Variable(term) => Some(self.compile_variable(&term.value, variables)?), - Expression::Literal(term) => Some(self.compile_literal(term, type_hint)?), - Expression::Call { function, args } => { - self.compile_call(function, args, variables, types)? - } + Expression::Variable { + name, + value_type: _, + } => Some(self.compile_variable(&name.value, variables)?), + Expression::Literal(term) => Some(self.compile_literal(term)?), + Expression::Call { + function, + args, + value_type, + } => self.compile_call(function, args, variables, types, value_type.clone())?, Expression::BinaryOp(lhs, op, rhs) => { - Some(self.compile_binary_op(lhs, op, rhs, variables, types, type_hint)?) + Some(self.compile_binary_op(lhs, op, rhs, variables, types)?) } }) } @@ -509,20 +453,16 @@ impl<'ctx> CodeGen<'ctx> { args: &[Box], variables: &mut Variables<'ctx>, types: &mut TypeStorage<'ctx>, + value_type: Option, ) -> Result, TypeExp)>> { info!("compiling fn call: func_name={}", func_name); let function = self.module.get_function(func_name).expect("should exist"); - let func_info = self - .functions - .get(func_name) - .cloned() - .expect("should exist"); let mut value_args: Vec = Vec::with_capacity(args.len()); - for (arg, arg_type) in args.iter().zip(func_info.0.iter()) { + for arg in args.iter() { let (res, _res_type) = self - .compile_expression(arg, variables, types, Some(arg_type.clone()))? + .compile_expression(arg, variables, types)? .expect("should have result"); value_args.push(res.into()); } @@ -533,10 +473,7 @@ impl<'ctx> CodeGen<'ctx> { .try_as_basic_value(); Ok(match result { - Either::Left(val) => Some(( - val, - func_info.1.expect("should have ret type info if returns"), - )), + Either::Left(val) => Some((val, value_type.unwrap())), Either::Right(_) => None, }) } @@ -548,13 +485,12 @@ impl<'ctx> CodeGen<'ctx> { rhs: &Expression, variables: &mut Variables<'ctx>, types: &mut TypeStorage<'ctx>, - type_hint: Option, ) -> Result<(BasicValueEnum<'ctx>, TypeExp)> { let (lhs, lhs_type) = self - .compile_expression(lhs, variables, types, type_hint.clone())? + .compile_expression(lhs, variables, types)? .expect("should have result"); let (rhs, _rhs_type) = self - .compile_expression(rhs, variables, types, type_hint)? + .compile_expression(rhs, variables, types)? .expect("should have result"); let lhs = lhs.into_int_value(); @@ -593,11 +529,7 @@ impl<'ctx> CodeGen<'ctx> { Ok((result.as_basic_value_enum(), res_type)) } - pub fn compile_literal( - &self, - term: &LiteralValue, - type_hint: Option, - ) -> Result<(BasicValueEnum<'ctx>, TypeExp)> { + pub fn compile_literal(&self, term: &LiteralValue) -> Result<(BasicValueEnum<'ctx>, TypeExp)> { let value = match term { LiteralValue::String(_s) => { todo!() @@ -614,28 +546,20 @@ impl<'ctx> CodeGen<'ctx> { .as_basic_value_enum(), TypeExp::Boolean, ), - LiteralValue::Integer(v) => { - if let Some(type_hint) = type_hint { - ( - self.get_llvm_type(&type_hint)? - .into_int_type() - .const_int(v.parse().unwrap(), false) - .as_basic_value_enum(), - type_hint, - ) - } else { - let type_exp = TypeExp::Integer { - bits: 32, - signed: true, - }; - ( - self.get_llvm_type(&type_exp)? - .into_int_type() - .const_int(v.parse().unwrap(), false) - .as_basic_value_enum(), - type_exp, - ) - } + LiteralValue::Integer { + value, + bits, + signed, + } => { + let bits = bits.unwrap_or(32); + let signed = signed.unwrap_or(true); + ( + self.context + .custom_width_int_type(bits) + .const_int(value.parse().unwrap(), false) + .as_basic_value_enum(), + TypeExp::Integer { bits, signed }, + ) } }; diff --git a/src/grammar.lalrpop b/src/grammar.lalrpop index 26bca94b3..adeb1e360 100644 --- a/src/grammar.lalrpop +++ b/src/grammar.lalrpop @@ -94,9 +94,12 @@ TypeInfo: ast::TypeExp = { // statements not including function definitions BasicStatement: ast::Statement = { - "let" "=" ";" => ast::Statement::Let { name: i, value: e, type_name: t, span: (lo, hi) }, - "=" ";" => ast::Statement::Mutate { name: i, value: e, span: (lo, hi) }, - "if" "{" "}" => ast::Statement::If { condition: cond, body: s, else_body: e}, + "let" "=" ";" => + ast::Statement::Let { name: i, value: e, value_type: t, span: (lo, hi) }, + "=" ";" => + ast::Statement::Mutate { name: i, value: e, span: (lo, hi), value_type: None }, + "if" "{" "}" => + ast::Statement::If { condition: cond, body: s, else_body: e}, "return" ";" => ast::Statement::Return(e), }; @@ -137,15 +140,22 @@ Expr4 = Tier; // Terms: variables, literals, calls Term: Box = { - => Box::new(ast::Expression::Variable(Spanned::new(i, (lo, hi)))), + => Box::new(ast::Expression::Variable { + name: Spanned::new(i, (lo, hi)), + value_type: None + }), => Box::new(ast::Expression::Literal(n)), => Box::new(ast::Expression::Literal(n)), => Box::new(ast::Expression::Literal(n)), - "(" > ")" => Box::new(ast::Expression::Call { function: i, args: values}), + "(" > ")" => Box::new(ast::Expression::Call { function: i, args: values, value_type: None }), "(" ")" }; -Number: ast::LiteralValue = => ast::LiteralValue::Integer(n); +Number: ast::LiteralValue = => ast::LiteralValue::Integer { + value: n, + bits: None, + signed: None +}; StringLit: ast::LiteralValue = => ast::LiteralValue::String(n[1..(n.len()-1)].to_string()); diff --git a/src/main.rs b/src/main.rs index a366bd52c..32e1fefd8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,6 +14,7 @@ pub mod check; pub mod codegen; pub mod lexer; pub mod tokens; +pub mod type_analysis; lalrpop_mod!(pub grammar); @@ -99,15 +100,17 @@ fn main() -> Result<()> { let code = fs::read_to_string(&input)?; let lexer = Lexer::new(code.as_str()); let parser = grammar::ProgramParser::new(); - let ast = parser.parse(lexer)?; + let mut ast = parser.parse(lexer)?; + type_analysis::type_inference(&mut ast); let program = ProgramData::new(&input, &code); check_program(&program, &ast); } Commands::Ast { input } => { - let code = fs::read_to_string(&input)?; + let code = fs::read_to_string(input)?; let lexer = Lexer::new(code.as_str()); let parser = grammar::ProgramParser::new(); - let ast = parser.parse(lexer)?; + let mut ast = parser.parse(lexer)?; + type_analysis::type_inference(&mut ast); println!("{ast:#?}"); } Commands::Compile { @@ -119,7 +122,8 @@ fn main() -> Result<()> { let code = fs::read_to_string(&input)?; let lexer = Lexer::new(code.as_str()); let parser = grammar::ProgramParser::new(); - let ast: Program = parser.parse(lexer)?; + let mut ast: Program = parser.parse(lexer)?; + type_analysis::type_inference(&mut ast); let program = ProgramData::new(&input, &code); diff --git a/src/type_analysis.rs b/src/type_analysis.rs new file mode 100644 index 000000000..a72cfe832 --- /dev/null +++ b/src/type_analysis.rs @@ -0,0 +1,276 @@ +use std::collections::HashMap; + +use tracing::{info, warn}; + +use crate::ast::{self, Expression, Function, Statement, TypeExp}; + +pub fn type_inference(ast: &mut ast::Program) { + for statement in ast.statements.iter_mut() { + if let Statement::Function(function) = statement { + let ret_type = function.return_type.clone(); + let mut var_cache: HashMap = HashMap::new(); + + if let Some(ret_type) = &ret_type { + let ret_type_exp = fn_return_type(function); + + if let Some(exp) = ret_type_exp { + set_expression_type(exp, ret_type, &mut var_cache); + } + } + + update_statements(&mut function.body, &mut var_cache); + } + } +} + +fn update_statements(statements: &mut [Statement], var_cache: &mut HashMap) { + let mut var_cache = var_cache.clone(); + + { + let mut let_or_mut: Vec<&mut Statement> = statements + .iter_mut() + .filter(|x| matches!(x, Statement::Let { .. } | Statement::Mutate { .. })) + .collect(); + + // process mutate first + for st in let_or_mut.iter_mut() { + if let Statement::Mutate { + name, + value, + value_type, + .. + } = st + { + if let Some(value_type) = value_type { + // todo: check types matches? + var_cache.insert(name.clone(), value_type.clone()); + set_expression_type(value, value_type, &mut var_cache); + } else { + // evalue the value expr first to find a possible type. + if var_cache.contains_key(name) { + *value_type = var_cache.get(name).cloned(); + let mut env = Some(value_type.clone().unwrap()); + set_exp_types_from_cache(value, &mut var_cache, &mut env); + } else { + // no type info? + } + } + } + } + + // we need to process lets with a specified type first. + for st in let_or_mut.iter_mut() { + if let Statement::Let { + name, + value, + value_type, + .. + } = st + { + if let Some(value_type) = value_type { + // todo: check types matches? + var_cache.insert(name.clone(), value_type.clone()); + set_expression_type(value, value_type, &mut var_cache); + } else { + // evalue the value expr first to find a possible type. + if var_cache.contains_key(name) { + *value_type = var_cache.get(name).cloned(); + let mut env = Some(value_type.clone().unwrap()); + set_exp_types_from_cache(value, &mut var_cache, &mut env); + } else { + // no type info? + } + } + } + } + } + + for st in statements.iter_mut() { + match st { + Statement::Let { + name, + value_type, + value, + .. + } => { + // infer type if let has no type + if value_type.is_none() { + // evalue the value expr first to find a possible type. + let mut env = None; + set_exp_types_from_cache(value, &mut var_cache, &mut env); + + // try to find if it was set on the cache + if var_cache.contains_key(name) { + *value_type = var_cache.get(name).cloned(); + set_expression_type(value, value_type.as_ref().unwrap(), &mut var_cache); + } else { + // what here? no let type, no cache + println!("no cache let found") + } + } + } + Statement::Mutate { + name, + value_type, + value, + .. + } => { + if let Some(value_type) = value_type { + // todo: check types matches? + var_cache.insert(name.clone(), value_type.clone()); + set_expression_type(value, value_type, &mut var_cache); + } else { + // evalue the value expr first to find a possible type. + if var_cache.contains_key(name) { + *value_type = var_cache.get(name).cloned(); + let mut env = Some(value_type.clone().unwrap()); + set_exp_types_from_cache(value, &mut var_cache, &mut env); + } else { + // no type info? + } + } + } + Statement::If { + condition, + body, + else_body, + } => { + let mut env = None; + set_exp_types_from_cache(condition, &mut var_cache, &mut env); + update_statements(body, &mut var_cache); + if let Some(else_body) = else_body { + update_statements(else_body, &mut var_cache); + } + } + Statement::Return(exp) => { + if let Some(exp) = exp { + let mut env = None; + set_exp_types_from_cache(exp, &mut var_cache, &mut env); + } + } + Statement::Function(_) => unreachable!(), + Statement::Struct(_) => unreachable!(), + } + } +} + +fn fn_return_type(func: &mut Function) -> Option<&mut Box> { + for st in func.body.iter_mut() { + if let Statement::Return(r) = st { + return r.as_mut(); + } + } + None +} + +// set variables using the cache +fn set_exp_types_from_cache( + exp: &mut Expression, + var_cache: &mut HashMap, + env: &mut Option, +) { + match exp { + Expression::Variable { name, value_type } => { + let name = name.value.clone(); + if let Some(value_type) = value_type { + // todo: check types matches? + var_cache.insert(name, value_type.clone()); + } else if var_cache.contains_key(&name) { + *value_type = var_cache.get(&name).cloned(); + if env.is_none() { + *env = value_type.clone(); + } + } + } + Expression::BinaryOp(lhs, op, rhs) => match op { + ast::OpCode::Eq | ast::OpCode::Ne => {} + _ => { + set_exp_types_from_cache(lhs, var_cache, env); + set_exp_types_from_cache(rhs, var_cache, env); + } + }, + Expression::Literal(lit) => match lit { + ast::LiteralValue::String(_) => { + warn!("found string, unimplemented") + } + ast::LiteralValue::Integer { bits, signed, .. } => { + if let Some(TypeExp::Integer { + bits: t_bits, + signed: t_signed, + }) = env + { + *bits = Some(*t_bits); + *signed = Some(*t_signed); + } + } + ast::LiteralValue::Boolean(_) => { + warn!("found bool, unimplemented") + } + }, + Expression::Call { + function: _, + args: _, + value_type, + } => { + match value_type { + Some(value_type) => *env = Some(value_type.clone()), + None => { + if env.is_some() { + *value_type = env.clone(); + } + } + } + // TODO: infer args based on function args! + } + } +} + +fn set_expression_type( + exp: &mut Expression, + expected_type: &TypeExp, + var_cache: &mut HashMap, +) { + match exp { + Expression::Variable { name, value_type } => { + // if needed? + if value_type.is_none() { + *value_type = Some(expected_type.clone()); + } + if !var_cache.contains_key(&name.value) { + var_cache.insert(name.value.clone(), expected_type.clone()); + } + } + Expression::BinaryOp(lhs, op, rhs) => match op { + ast::OpCode::Eq | ast::OpCode::Ne => {} + _ => { + set_expression_type(lhs, expected_type, var_cache); + set_expression_type(rhs, expected_type, var_cache); + } + }, + Expression::Literal(lit) => match lit { + ast::LiteralValue::String(_) => { + warn!("found string, unimplemented") + } + ast::LiteralValue::Integer { bits, signed, .. } => { + if let TypeExp::Integer { + bits: t_bits, + signed: t_signed, + } = expected_type + { + *bits = Some(*t_bits); + *signed = Some(*t_signed); + } + } + ast::LiteralValue::Boolean(_) => { + warn!("found bool, unimplemented") + } + }, + Expression::Call { + function: _, + args: _, + value_type, + } => { + *value_type = Some(expected_type.clone()); + } + } +}