This commit is contained in:
Edgar 2023-06-09 11:08:40 +02:00
parent 10cc3ae591
commit e1736a4ee8
No known key found for this signature in database
GPG key ID: 70ADAE8F35904387
2 changed files with 49 additions and 20 deletions

View file

@ -18,7 +18,7 @@ fn works(x: i64) -> i64 {
} }
fn main() -> i64 { fn main() -> i64 {
let y: i64 = 2; let y = 2;
let z = y; let z = y;
return works(z); return works(z);
} }

View file

@ -5,6 +5,25 @@ use tracing::{info, warn};
use crate::ast::{self, Expression, Function, Statement, TypeExp}; use crate::ast::{self, Expression, Function, Statement, TypeExp};
pub fn type_inference(ast: &mut ast::Program) { pub fn type_inference(ast: &mut ast::Program) {
let mut struct_cache: HashMap<String, HashMap<String, TypeExp>> = HashMap::new();
for statement in ast.statements.iter_mut() {
if let Statement::Struct(st) = statement {
let fields = st
.fields
.iter()
.map(|x| (x.ident.clone(), x.type_exp.clone()))
.collect();
struct_cache.insert(st.name.clone(), fields);
}
}
let mut fn_cache: HashMap<String, Function> = HashMap::new();
for statement in ast.statements.iter_mut() {
if let Statement::Function(function) = statement {
fn_cache.insert(function.name.clone(), function.clone());
}
}
for statement in ast.statements.iter_mut() { for statement in ast.statements.iter_mut() {
if let Statement::Function(function) = statement { if let Statement::Function(function) = statement {
let ret_type = function.return_type.clone(); let ret_type = function.return_type.clone();
@ -22,12 +41,12 @@ pub fn type_inference(ast: &mut ast::Program) {
} }
} }
update_statements(&mut function.body, &mut var_cache); update_statements(&mut function.body, &mut var_cache, &fn_cache);
} }
} }
} }
fn update_statements(statements: &mut [Statement], var_cache: &mut HashMap<String, TypeExp>) { fn update_statements(statements: &mut [Statement], var_cache: &mut HashMap<String, TypeExp>, fn_cache: &HashMap<String, Function>) {
let mut var_cache = var_cache.clone(); let mut var_cache = var_cache.clone();
{ {
@ -54,7 +73,7 @@ fn update_statements(statements: &mut [Statement], var_cache: &mut HashMap<Strin
if var_cache.contains_key(name) { if var_cache.contains_key(name) {
*value_type = var_cache.get(name).cloned(); *value_type = var_cache.get(name).cloned();
let mut env = Some(value_type.clone().unwrap()); let mut env = Some(value_type.clone().unwrap());
set_exp_types_from_cache(value, &mut var_cache, &mut env); set_exp_types_from_cache(value, &mut var_cache, &mut env, fn_cache);
} else { } else {
// no type info? // no type info?
} }
@ -80,7 +99,7 @@ fn update_statements(statements: &mut [Statement], var_cache: &mut HashMap<Strin
if var_cache.contains_key(name) { if var_cache.contains_key(name) {
*value_type = var_cache.get(name).cloned(); *value_type = var_cache.get(name).cloned();
let mut env = Some(value_type.clone().unwrap()); let mut env = Some(value_type.clone().unwrap());
set_exp_types_from_cache(value, &mut var_cache, &mut env); set_exp_types_from_cache(value, &mut var_cache, &mut env, fn_cache);
} else { } else {
// no type info? // no type info?
} }
@ -101,7 +120,7 @@ fn update_statements(statements: &mut [Statement], var_cache: &mut HashMap<Strin
if value_type.is_none() { if value_type.is_none() {
// evalue the value expr first to find a possible type. // evalue the value expr first to find a possible type.
let mut env = None; let mut env = None;
set_exp_types_from_cache(value, &mut var_cache, &mut env); set_exp_types_from_cache(value, &mut var_cache, &mut env, fn_cache);
// try to find if it was set on the cache // try to find if it was set on the cache
if var_cache.contains_key(name) { if var_cache.contains_key(name) {
@ -128,7 +147,7 @@ fn update_statements(statements: &mut [Statement], var_cache: &mut HashMap<Strin
if var_cache.contains_key(name) { if var_cache.contains_key(name) {
*value_type = var_cache.get(name).cloned(); *value_type = var_cache.get(name).cloned();
let mut env = Some(value_type.clone().unwrap()); let mut env = Some(value_type.clone().unwrap());
set_exp_types_from_cache(value, &mut var_cache, &mut env); set_exp_types_from_cache(value, &mut var_cache, &mut env, fn_cache);
} else { } else {
// no type info? // no type info?
} }
@ -140,16 +159,16 @@ fn update_statements(statements: &mut [Statement], var_cache: &mut HashMap<Strin
else_body, else_body,
} => { } => {
let mut env = None; let mut env = None;
set_exp_types_from_cache(condition, &mut var_cache, &mut env); set_exp_types_from_cache(condition, &mut var_cache, &mut env, fn_cache);
update_statements(body, &mut var_cache); update_statements(body, &mut var_cache, fn_cache);
if let Some(else_body) = else_body { if let Some(else_body) = else_body {
update_statements(else_body, &mut var_cache); update_statements(else_body, &mut var_cache, fn_cache);
} }
} }
Statement::Return(exp) => { Statement::Return(exp) => {
if let Some(exp) = exp { if let Some(exp) = exp {
let mut env = None; let mut env = None;
set_exp_types_from_cache(exp, &mut var_cache, &mut env); set_exp_types_from_cache(exp, &mut var_cache, &mut env, fn_cache);
} }
} }
Statement::Function(_) => unreachable!(), Statement::Function(_) => unreachable!(),
@ -172,6 +191,7 @@ fn set_exp_types_from_cache(
exp: &mut Expression, exp: &mut Expression,
var_cache: &mut HashMap<String, TypeExp>, var_cache: &mut HashMap<String, TypeExp>,
env: &mut Option<TypeExp>, env: &mut Option<TypeExp>,
fn_cache: &HashMap<String, Function>
) { ) {
match exp { match exp {
Expression::Variable { name, value_type } => { Expression::Variable { name, value_type } => {
@ -189,15 +209,15 @@ fn set_exp_types_from_cache(
} }
Expression::BinaryOp(lhs, op, rhs) => match op { Expression::BinaryOp(lhs, op, rhs) => match op {
ast::OpCode::Eq | ast::OpCode::Ne => { ast::OpCode::Eq | ast::OpCode::Ne => {
set_exp_types_from_cache(lhs, var_cache, env); set_exp_types_from_cache(lhs, var_cache, env, fn_cache);
set_exp_types_from_cache(rhs, var_cache, env); set_exp_types_from_cache(rhs, var_cache, env, fn_cache);
set_exp_types_from_cache(lhs, var_cache, env); set_exp_types_from_cache(lhs, var_cache, env, fn_cache);
*env = Some(TypeExp::Boolean); *env = Some(TypeExp::Boolean);
} }
_ => { _ => {
set_exp_types_from_cache(lhs, var_cache, env); set_exp_types_from_cache(lhs, var_cache, env, fn_cache);
set_exp_types_from_cache(rhs, var_cache, env); set_exp_types_from_cache(rhs, var_cache, env, fn_cache);
set_exp_types_from_cache(lhs, var_cache, env); // needed in case 2 == x set_exp_types_from_cache(lhs, var_cache, env, fn_cache); // needed in case 2 == x
} }
}, },
Expression::Literal(lit) => match lit { Expression::Literal(lit) => match lit {
@ -219,19 +239,28 @@ fn set_exp_types_from_cache(
} }
}, },
Expression::Call { Expression::Call {
function: _, function,
args: _, args,
value_type, value_type,
} => { } => {
let fn_type = fn_cache.get(function).unwrap().clone();
match value_type { match value_type {
Some(value_type) => *env = Some(value_type.clone()), Some(value_type) => *env = Some(value_type.clone()),
None => { None => {
if env.is_some() { if env.is_some() {
let env = env.clone();
*value_type = env.clone(); *value_type = env.clone();
} else {
*value_type = fn_type.return_type.clone();
*env = fn_type.return_type.clone();
} }
} }
} }
// TODO: infer args based on function args!
for (i, arg) in args.iter_mut().enumerate() {
let mut env = Some(fn_type.params[i].type_exp.clone());
set_exp_types_from_cache(arg, var_cache, &mut env, fn_cache);
}
} }
} }
} }