require fully typed integers for now

This commit is contained in:
Edgar 2023-06-11 12:07:15 +02:00
parent f3cc72e7ce
commit 8c212d948d
No known key found for this signature in database
GPG key ID: 70ADAE8F35904387
8 changed files with 208 additions and 233 deletions

15
programs/ifelse.ed Normal file
View file

@ -0,0 +1,15 @@
fn works(x: i64) -> i64 {
let z = 0i64;
if 2i64 == x {
z = x * 2i64;
} else {
z = x * 3i64;
}
return z;
}
fn main() -> i64 {
let y = 2i64;
let z = y;
return works(z);
}

View file

@ -9,7 +9,7 @@ fn test(x: Hello) {
fn works(x: i64) -> i64 {
let z = 0i64;
if 2 == x {
if 2i64 == x {
z = x * 2i64;
} else {
z = x * 3i64;

View file

@ -1,3 +1,5 @@
use std::collections::HashMap;
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Spanned<T> {
pub span: (usize, usize),
@ -64,12 +66,10 @@ pub enum Expression {
Literal(LiteralValue),
Variable {
name: Spanned<String>,
value_type: Option<TypeExp>,
},
Call {
function: String,
args: Vec<Box<Self>>,
value_type: Option<TypeExp>,
},
BinaryOp(Box<Self>, OpCode, Box<Self>),
}
@ -86,16 +86,17 @@ impl Parameter {
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Debug, Clone, PartialEq)]
pub struct Function {
pub name: String,
pub params: Vec<Parameter>,
pub body: Vec<Statement>,
pub scope_type_info: HashMap<String, Vec<TypeExp>>,
pub return_type: Option<TypeExp>,
}
impl Function {
pub const fn new(
pub fn new(
name: String,
params: Vec<Parameter>,
body: Vec<Statement>,
@ -106,6 +107,7 @@ impl Function {
params,
body,
return_type,
scope_type_info: HashMap::new(),
}
}
}
@ -113,14 +115,14 @@ impl Function {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StructField {
pub ident: String,
pub type_exp: TypeExp,
pub field_type: TypeExp,
}
impl StructField {
pub const fn new(ident: String, type_name: TypeExp) -> Self {
Self {
ident,
type_exp: type_name,
field_type: type_name,
}
}
}
@ -131,7 +133,7 @@ pub struct Struct {
pub fields: Vec<StructField>,
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
Let {
name: String,
@ -142,13 +144,14 @@ pub enum Statement {
Mutate {
name: String,
value: Box<Expression>,
value_type: Option<TypeExp>,
span: (usize, usize),
},
If {
condition: Box<Expression>,
body: Vec<Statement>,
scope_type_info: HashMap<String, Vec<TypeExp>>,
else_body: Option<Vec<Statement>>,
else_body_scope_type_info: HashMap<String, Vec<TypeExp>>,
},
Return(Option<Box<Expression>>),
Function(Function),

View file

@ -37,10 +37,10 @@ pub struct CodeGen<'ctx> {
context: &'ctx Context,
pub module: Module<'ctx>,
builder: Builder<'ctx>,
types: TypeStorage<'ctx>,
//types: TypeStorage<'ctx>,
struct_types: StructTypeStorage<'ctx>,
// function to return type
functions: HashMap<String, (Vec<TypeExp>, Option<TypeExp>)>,
functions: HashMap<String, Function>,
_program: ProgramData,
ast: ast::Program,
}
@ -48,12 +48,12 @@ pub struct CodeGen<'ctx> {
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Variable<'ctx> {
pub value: BasicValueEnum<'ctx>,
pub type_counter: usize,
pub phi_counter: usize,
pub type_exp: TypeExp,
}
pub type Variables<'ctx> = HashMap<String, Variable<'ctx>>;
pub type TypeStorage<'ctx> = HashMap<TypeExp, BasicTypeEnum<'ctx>>;
// pub type TypeStorage<'ctx> = HashMap<TypeExp, BasicTypeEnum<'ctx>>;
/// Holds the struct type and maps fields to types and the location within the struct.
#[derive(Debug, Clone, PartialEq, Eq)]
@ -79,7 +79,6 @@ impl<'ctx> CodeGen<'ctx> {
builder: context.create_builder(),
_program,
ast,
types: HashMap::new(),
struct_types: HashMap::new(),
functions: HashMap::new(),
};
@ -88,9 +87,8 @@ impl<'ctx> CodeGen<'ctx> {
}
pub fn compile_ast(&mut self) -> Result<()> {
let mut functions = vec![];
let mut func_info = HashMap::new();
let mut types: TypeStorage<'ctx> = HashMap::new();
let mut functions = HashMap::new();
// let mut types: TypeStorage<'ctx> = HashMap::new();
let mut struct_types: StructTypeStorage<'ctx> = HashMap::new();
// todo fix the grammar so top level statements are only functions and static vars.
@ -102,13 +100,11 @@ impl<'ctx> CodeGen<'ctx> {
let mut field_types = vec![];
for (i, field) in s.fields.iter().enumerate() {
if !types.contains_key(&field.type_exp) {
types.insert(field.type_exp.clone(), self.get_llvm_type(&field.type_exp)?);
}
let ty = self.get_llvm_type(&field.type_exp)?;
// todo: this doesn't handle out-of-order struct declarations well
let ty = self.get_llvm_type(&field.field_type)?;
field_types.push(ty);
// todo: ensure alignment and padding here
fields.insert(field.ident.clone(), (i, field.type_exp.clone()));
fields.insert(field.ident.clone(), (i, field.field_type.clone()));
}
let ty = self.context.struct_type(&field_types, false);
@ -123,38 +119,16 @@ impl<'ctx> CodeGen<'ctx> {
// create the llvm functions first.
for statement in &self.ast.statements {
if let Statement::Function(function) = &statement {
functions.push(function);
let (args, ret_type) = self.compile_function_signature(function)?;
let mut arg_types = vec![];
for arg in args {
if !types.contains_key(&arg) {
let ty = self.get_llvm_type(&arg)?;
types.insert(arg.clone(), ty);
}
arg_types.push(arg);
}
if let Some(ret_type) = ret_type {
let ret_type = if !types.contains_key(&ret_type) {
let ty = self.get_llvm_type(&ret_type)?;
types.insert(ret_type.clone(), ty);
ret_type
} else {
ret_type
};
func_info.insert(function.name.clone(), (arg_types, Some(ret_type)));
} else {
func_info.insert(function.name.clone(), (arg_types, None));
}
functions.insert(function.name.clone(), function.clone());
self.compile_function_signature(function)?;
}
}
self.types = types;
self.functions = func_info;
self.functions = functions;
info!("functions:\n{:#?}", self.functions);
// implement them.
for function in functions {
for (_, function) in &self.functions {
self.compile_function(function)?;
}
@ -169,38 +143,31 @@ impl<'ctx> CodeGen<'ctx> {
}
fn get_llvm_type(&self, id: &TypeExp) -> Result<BasicTypeEnum<'ctx>> {
if let Some(ty) = self.types.get(id) {
Ok(*ty)
} else {
Ok(match id {
TypeExp::Integer { bits, signed: _ } => self
.context
.custom_width_int_type(*bits)
.as_basic_type_enum(),
TypeExp::Boolean => self.context.bool_type().as_basic_type_enum(),
TypeExp::Array { of, len } => {
let ty = self.get_llvm_type(of)?;
ty.array_type(len.unwrap()).as_basic_type_enum()
}
TypeExp::Pointer { target } => {
let ty = self.get_llvm_type(target)?;
ty.ptr_type(Default::default()).as_basic_type_enum()
}
TypeExp::Other { id } => self
.struct_types
.get(id)
.expect("struct type not found")
.ty
.as_basic_type_enum(),
})
}
Ok(match id {
TypeExp::Integer { bits, signed: _ } => self
.context
.custom_width_int_type(*bits)
.as_basic_type_enum(),
TypeExp::Boolean => self.context.bool_type().as_basic_type_enum(),
TypeExp::Array { of, len } => {
let ty = self.get_llvm_type(of)?;
ty.array_type(len.unwrap()).as_basic_type_enum()
}
TypeExp::Pointer { target } => {
let ty = self.get_llvm_type(target)?;
ty.ptr_type(Default::default()).as_basic_type_enum()
}
TypeExp::Other { id } => self
.struct_types
.get(id)
.expect("struct type not found")
.ty
.as_basic_type_enum(),
})
}
/// creates the llvm function without the body, so other function bodies can call it.
fn compile_function_signature(
&self,
function: &Function,
) -> Result<(Vec<TypeExp>, Option<TypeExp>)> {
fn compile_function_signature(&self, function: &Function) -> Result<()> {
let args_types: Vec<BasicTypeEnum<'ctx>> = function
.params
.iter()
@ -221,14 +188,7 @@ impl<'ctx> CodeGen<'ctx> {
self.module.add_function(&function.name, fn_type, None);
Ok((
function
.params
.iter()
.map(|param| param.type_exp.clone())
.collect(),
ret_type,
))
Ok(())
}
fn compile_function(&self, function: &Function) -> Result<()> {
@ -238,7 +198,6 @@ impl<'ctx> CodeGen<'ctx> {
self.builder.position_at_end(entry_block);
let mut variables: Variables = HashMap::new();
let mut types: TypeStorage = self.types.clone();
for (i, param) in function.params.iter().enumerate() {
let id = &param.ident;
@ -250,7 +209,7 @@ impl<'ctx> CodeGen<'ctx> {
Variable {
value: param_value,
phi_counter: 0,
type_exp: param.type_exp.clone(),
type_counter: 0,
},
);
}
@ -261,7 +220,7 @@ impl<'ctx> CodeGen<'ctx> {
if let Statement::Return(_) = statement {
has_return = true
}
self.compile_statement(func, statement, &mut variables, &mut types)?;
self.compile_statement(func, statement, &mut variables, &function.scope_type_info)?;
}
if !has_return {
@ -277,7 +236,7 @@ impl<'ctx> CodeGen<'ctx> {
statement: &Statement,
// value, assignments
variables: &mut Variables<'ctx>,
types: &mut TypeStorage<'ctx>,
scope_info: &HashMap<String, Vec<TypeExp>>,
) -> Result<()> {
match statement {
// Variable assignment
@ -287,41 +246,32 @@ impl<'ctx> CodeGen<'ctx> {
value_type: _,
..
} => {
let (value, value_type) = self
.compile_expression(value, variables, types)?
let value = self
.compile_expression(value, variables, scope_info)?
.expect("should have result");
if !types.contains_key(&value_type) {
let ty = self.get_llvm_type(&value_type)?;
types.insert(value_type.clone(), ty);
}
info!("adding variable: name={}, ty={:?}", name, value_type);
variables.insert(
name.clone(),
Variable {
value,
phi_counter: 0,
type_exp: value_type,
type_counter: 0,
},
);
}
Statement::Mutate { name, value, .. } => {
let (value, value_type) = self
.compile_expression(value, variables, types)?
let value = self
.compile_expression(value, variables, scope_info)?
.expect("should have result");
let var = variables.get_mut(name).expect("variable should exist");
var.phi_counter += 1;
var.value = value;
assert_eq!(var.type_exp, value_type, "variable type shouldn't change!");
info!("mutated variable: name={}, ty={:?}", name, var.type_exp);
}
Statement::Return(ret) => {
if let Some(ret) = ret {
let (value, _value_type) = self
.compile_expression(ret, variables, types)?
let value = self
.compile_expression(ret, variables, scope_info)?
.expect("should have result");
self.builder.build_return(Some(&value));
} else {
@ -332,9 +282,11 @@ impl<'ctx> CodeGen<'ctx> {
condition,
body,
else_body,
scope_type_info,
else_body_scope_type_info,
} => {
let (condition, _cond_type) = self
.compile_expression(condition, variables, types)?
let condition = self
.compile_expression(condition, variables, scope_info)?
.expect("should produce a value");
let mut if_block = self.context.append_basic_block(function_value, "if");
@ -354,7 +306,7 @@ impl<'ctx> CodeGen<'ctx> {
let mut variables_if = variables.clone();
self.builder.position_at_end(if_block);
for s in body {
self.compile_statement(function_value, s, &mut variables_if, types)?;
self.compile_statement(function_value, s, &mut variables_if, scope_type_info)?;
}
self.builder.build_unconditional_branch(merge_block);
if_block = self.builder.get_insert_block().unwrap(); // update for phi
@ -364,7 +316,12 @@ impl<'ctx> CodeGen<'ctx> {
self.builder.position_at_end(else_block);
for s in else_body {
self.compile_statement(function_value, s, &mut variables_else, types)?;
self.compile_statement(
function_value,
s,
&mut variables_else,
else_body_scope_type_info,
)?;
}
self.builder.build_unconditional_branch(merge_block);
else_block = self.builder.get_insert_block().unwrap(); // update for phi
@ -381,7 +338,7 @@ impl<'ctx> CodeGen<'ctx> {
.builder
.build_phi(old_var.value.get_type(), &format!("{name}_phi"));
phi.add_incoming(&[(&new_var.value, if_block)]);
processed_vars.insert(name, (phi, new_var.type_exp));
processed_vars.insert(name, phi);
}
}
}
@ -391,7 +348,7 @@ impl<'ctx> CodeGen<'ctx> {
if variables.contains_key(&name) {
let old_var = variables.get(&name).unwrap();
if new_var.phi_counter > old_var.phi_counter {
if let Some((phi, _)) = processed_vars.get(&name) {
if let Some(phi) = processed_vars.get(&name) {
phi.add_incoming(&[(&new_var.value, else_block)]);
} else {
let phi = self.builder.build_phi(
@ -399,22 +356,25 @@ impl<'ctx> CodeGen<'ctx> {
&format!("{name}_phi"),
);
phi.add_incoming(&[(&old_var.value, else_block)]);
processed_vars.insert(name, (phi, new_var.type_exp));
processed_vars.insert(name, phi);
}
}
}
}
}
for (name, (phi, type_exp)) in processed_vars {
for (name, phi) in processed_vars {
/*
variables.insert(
name,
Variable {
value: phi.as_basic_value(),
phi_counter: 0,
type_exp,
},
);
*/
let mut var = variables.get_mut(&name).unwrap();
var.value = phi.as_basic_value();
}
}
Statement::Function(_) => unreachable!(),
@ -428,21 +388,16 @@ impl<'ctx> CodeGen<'ctx> {
&self,
expr: &Expression,
variables: &mut Variables<'ctx>,
types: &mut TypeStorage<'ctx>,
) -> Result<Option<(BasicValueEnum<'ctx>, TypeExp)>> {
scope_info: &HashMap<String, Vec<TypeExp>>,
) -> Result<Option<BasicValueEnum<'ctx>>> {
Ok(match expr {
Expression::Variable {
name,
value_type: _,
} => Some(self.compile_variable(&name.value, variables)?),
Expression::Variable { name } => Some(self.compile_variable(&name.value, variables)?),
Expression::Literal(term) => Some(self.compile_literal(term)?),
Expression::Call {
function,
args,
value_type,
} => self.compile_call(function, args, variables, types, value_type.clone())?,
Expression::Call { function, args } => {
self.compile_call(function, args, variables, scope_info)?
}
Expression::BinaryOp(lhs, op, rhs) => {
Some(self.compile_binary_op(lhs, op, rhs, variables, types)?)
Some(self.compile_binary_op(lhs, op, rhs, variables, scope_info)?)
}
})
}
@ -452,17 +407,16 @@ impl<'ctx> CodeGen<'ctx> {
func_name: &str,
args: &[Box<Expression>],
variables: &mut Variables<'ctx>,
types: &mut TypeStorage<'ctx>,
value_type: Option<TypeExp>,
) -> Result<Option<(BasicValueEnum<'ctx>, TypeExp)>> {
scope_info: &HashMap<String, Vec<TypeExp>>,
) -> Result<Option<BasicValueEnum<'ctx>>> {
info!("compiling fn call: func_name={}", func_name);
let function = self.module.get_function(func_name).expect("should exist");
let mut value_args: Vec<BasicMetadataValueEnum> = Vec::with_capacity(args.len());
for arg in args.iter() {
let (res, _res_type) = self
.compile_expression(arg, variables, types)?
let res = self
.compile_expression(arg, variables, scope_info)?
.expect("should have result");
value_args.push(res.into());
}
@ -473,7 +427,7 @@ impl<'ctx> CodeGen<'ctx> {
.try_as_basic_value();
Ok(match result {
Either::Left(val) => Some((val, value_type.unwrap())),
Either::Left(val) => Some(val),
Either::Right(_) => None,
})
}
@ -484,20 +438,18 @@ impl<'ctx> CodeGen<'ctx> {
op: &OpCode,
rhs: &Expression,
variables: &mut Variables<'ctx>,
types: &mut TypeStorage<'ctx>,
) -> Result<(BasicValueEnum<'ctx>, TypeExp)> {
let (lhs, lhs_type) = self
.compile_expression(lhs, variables, types)?
scope_info: &HashMap<String, Vec<TypeExp>>,
) -> Result<BasicValueEnum<'ctx>> {
let lhs = self
.compile_expression(lhs, variables, scope_info)?
.expect("should have result");
let (rhs, rhs_type) = self
.compile_expression(rhs, variables, types)?
let rhs = self
.compile_expression(rhs, variables, scope_info)?
.expect("should have result");
assert_eq!(lhs_type, rhs_type);
let lhs = lhs.into_int_value();
let rhs = rhs.into_int_value();
let mut bool_result = false;
let result = match op {
OpCode::Add => self.builder.build_int_add(lhs, rhs, "add"),
OpCode::Sub => self.builder.build_int_sub(lhs, rhs, "sub"),
@ -506,31 +458,18 @@ impl<'ctx> CodeGen<'ctx> {
OpCode::Rem => self.builder.build_int_signed_rem(lhs, rhs, "rem"),
OpCode::And => self.builder.build_and(lhs, rhs, "and"),
OpCode::Or => self.builder.build_or(lhs, rhs, "or"),
OpCode::Eq => {
bool_result = true;
self.builder
.build_int_compare(IntPredicate::EQ, lhs, rhs, "eq")
}
OpCode::Ne => {
bool_result = true;
self.builder
.build_int_compare(IntPredicate::NE, lhs, rhs, "eq")
}
OpCode::Eq => self
.builder
.build_int_compare(IntPredicate::EQ, lhs, rhs, "eq"),
OpCode::Ne => self
.builder
.build_int_compare(IntPredicate::NE, lhs, rhs, "eq"),
};
let mut res_type = lhs_type;
if bool_result {
res_type = TypeExp::Integer {
bits: 1,
signed: false,
};
}
Ok((result.as_basic_value_enum(), res_type))
Ok(result.as_basic_value_enum())
}
pub fn compile_literal(&self, term: &LiteralValue) -> Result<(BasicValueEnum<'ctx>, TypeExp)> {
pub fn compile_literal(&self, term: &LiteralValue) -> Result<BasicValueEnum<'ctx>> {
let value = match term {
LiteralValue::String(_s) => {
todo!()
@ -540,13 +479,11 @@ impl<'ctx> CodeGen<'ctx> {
.const_string(s.as_bytes(), true)
.as_basic_value_enum() */
}
LiteralValue::Boolean(v) => (
self.context
.bool_type()
.const_int((*v).into(), false)
.as_basic_value_enum(),
TypeExp::Boolean,
),
LiteralValue::Boolean(v) => self
.context
.bool_type()
.const_int((*v).into(), false)
.as_basic_value_enum(),
LiteralValue::Integer {
value,
bits,
@ -554,13 +491,11 @@ impl<'ctx> CodeGen<'ctx> {
} => {
let bits = *bits;
let signed = *signed;
(
self.context
.custom_width_int_type(bits)
.const_int(value.parse().unwrap(), false)
.as_basic_value_enum(),
TypeExp::Integer { bits, signed },
)
self.context
.custom_width_int_type(bits)
.const_int(value.parse().unwrap(), false)
.as_basic_value_enum()
}
};
@ -571,8 +506,8 @@ impl<'ctx> CodeGen<'ctx> {
&self,
variable: &str,
variables: &mut Variables<'ctx>,
) -> Result<(BasicValueEnum<'ctx>, TypeExp)> {
) -> Result<BasicValueEnum<'ctx>> {
let var = variables.get(variable).expect("value").clone();
Ok((var.value, var.type_exp))
Ok(var.value)
}
}

View file

@ -1,4 +1,4 @@
use std::str::FromStr;
use std::collections::HashMap;
use crate::{
ast::{self, Spanned},
tokens::Token,
@ -99,9 +99,15 @@ BasicStatement: ast::Statement = {
<lo:@L> "let" <i:"identifier"> <t:TypeInfo?> "=" <e:Expr> ";" <hi:@R> =>
ast::Statement::Let { name: i, value: e, value_type: t, span: (lo, hi) },
<lo:@L> <i:"identifier"> "=" <e:Expr> ";" <hi:@R> =>
ast::Statement::Mutate { name: i, value: e, span: (lo, hi), value_type: None },
ast::Statement::Mutate { name: i, value: e, span: (lo, hi) },
"if" <cond:Expr> "{" <s:Statements> "}" <e:ElseExpr?> =>
ast::Statement::If { condition: cond, body: s, else_body: e},
ast::Statement::If {
condition: cond,
body: s,
else_body: e,
scope_type_info: Default::default(),
else_body_scope_type_info: Default::default(),
},
"return" <e:Expr?> ";" => ast::Statement::Return(e),
};
@ -143,13 +149,12 @@ Expr4 = Tier<Level3_Op, Term>;
// Terms: variables, literals, calls
Term: Box<ast::Expression> = {
<lo:@L> <i:"identifier"> <hi:@R> => Box::new(ast::Expression::Variable {
name: Spanned::new(i, (lo, hi)),
value_type: None
name: Spanned::new(i, (lo, hi))
}),
<n:Number> => Box::new(ast::Expression::Literal(n)),
<n:StringLit> => Box::new(ast::Expression::Literal(n)),
<n:BoolLiteral> => Box::new(ast::Expression::Literal(n)),
<i:"identifier"> "(" <values:Comma<Term>> ")" => Box::new(ast::Expression::Call { function: i, args: values, value_type: None }),
<i:"identifier"> "(" <values:Comma<Term>> ")" => Box::new(ast::Expression::Call { function: i, args: values }),
"(" <Term> ")"
};

View file

@ -102,7 +102,7 @@ fn main() -> Result<()> {
let lexer = Lexer::new(code.as_str());
let parser = grammar::ProgramParser::new();
let mut ast = parser.parse(lexer)?;
type_analysis::type_inference2(&mut ast);
type_analysis::type_inference(&mut ast);
let program = ProgramData::new(&input, &code);
check_program(&program, &ast);
}
@ -112,7 +112,7 @@ fn main() -> Result<()> {
let parser = grammar::ProgramParser::new();
match parser.parse(lexer) {
Ok(mut ast) => {
type_analysis::type_inference2(&mut ast);
type_analysis::type_inference(&mut ast);
println!("{ast:#?}");
}
Err(e) => {
@ -130,7 +130,7 @@ fn main() -> Result<()> {
let lexer = Lexer::new(code.as_str());
let parser = grammar::ProgramParser::new();
let mut ast: Program = parser.parse(lexer)?;
type_analysis::type_inference2(&mut ast);
type_analysis::type_inference(&mut ast);
let program = ProgramData::new(&input, &code);

View file

@ -8,19 +8,15 @@ struct Storage {
functions: HashMap<String, Function>,
}
/*
To briefly summarize the union-find algorithm, given the set of all types in a proof,
it allows one to group them together into equivalence classes by means of a union procedure and to
pick a representative for each such class using a find procedure. Emphasizing the word procedure in
the sense of side effect, we're clearly leaving the realm of logic in order to prepare an effective algorithm.
The representative of a $\mathtt{union}(a,b)$ is determined such that, if both a and b are
type variables then the representative is arbitrarily one of them, but while uniting a variable and a term, the
term becomes the representative. Assuming an implementation of union-find at hand, one can formulate the unification of two monotypes as follows:
*/
// problem with scopes,
// let x = 2;
// let x = 2i64;
type ScopeMap = HashMap<String, Vec<Option<TypeExp>>>;
// this works, but need to find a way to store the found info + handle literal integer types (or not?)
// maybe use scope ids
pub fn type_inference2(ast: &mut ast::Program) {
pub fn type_inference(ast: &mut ast::Program) {
let mut storage = Storage::default();
// gather global constructs first
@ -30,7 +26,7 @@ pub fn type_inference2(ast: &mut ast::Program) {
let fields = st
.fields
.iter()
.map(|x| (x.ident.clone(), x.type_exp.clone()))
.map(|x| (x.ident.clone(), x.field_type.clone()))
.collect();
storage.structs.insert(st.name.clone(), fields);
}
@ -46,26 +42,33 @@ pub fn type_inference2(ast: &mut ast::Program) {
dbg!(&storage);
for function in storage.functions.values() {
let mut scope_vars: HashMap<String, Option<TypeExp>> = HashMap::new();
for statement in ast.statements.iter_mut() {
if let Statement::Function(function) = statement {
let mut scope_vars: ScopeMap = HashMap::new();
for arg in &function.params {
scope_vars.insert(arg.ident.clone(), Some(arg.type_exp.clone()));
for arg in &function.params {
scope_vars.insert(arg.ident.clone(), vec![Some(arg.type_exp.clone())]);
}
let func_info = function.clone();
let (new_scope_vars, _) =
type_inference_scope(&mut function.body, &scope_vars, &func_info, &storage);
// todo: check all vars have type info?
function.scope_type_info = new_scope_vars
.into_iter()
.map(|(a, b)| (a, b.into_iter().map(Option::unwrap).collect()))
.collect();
}
let (new_scope_vars, _) =
type_inference_scope(&function.body, &scope_vars, function, &storage);
dbg!(new_scope_vars);
}
}
/// Finds variable types in the scope, returns newly created variables to handle shadowing
fn type_inference_scope(
statements: &[ast::Statement],
scope_vars: &HashMap<String, Option<TypeExp>>,
statements: &mut [ast::Statement],
scope_vars: &ScopeMap,
func: &Function,
storage: &Storage,
) -> (HashMap<String, Option<TypeExp>>, HashSet<String>) {
) -> (ScopeMap, HashSet<String>) {
let mut scope_vars = scope_vars.clone();
let mut new_vars: HashSet<String> = HashSet::new();
@ -81,19 +84,24 @@ fn type_inference_scope(
let exp_type = type_inference_expression(value, &mut scope_vars, storage, None);
if !scope_vars.contains_key(name) {
scope_vars.insert(name.clone(), vec![]);
}
let var = scope_vars.get_mut(name).unwrap();
if value_type.is_none() {
scope_vars.insert(name.clone(), exp_type);
var.push(exp_type);
} else {
if exp_type.is_some() && &exp_type != value_type {
panic!("let type mismatch: {:?} != {:?}", value_type, exp_type);
}
scope_vars.insert(name.clone(), value_type.clone());
var.push(value_type.clone());
}
}
Statement::Mutate {
name,
value,
value_type: _,
span: _,
} => {
if !scope_vars.contains_key(name) {
@ -101,7 +109,7 @@ fn type_inference_scope(
}
let exp_type = type_inference_expression(value, &mut scope_vars, storage, None);
let var = scope_vars.get_mut(name).unwrap();
let var = scope_vars.get_mut(name).unwrap().last_mut().unwrap();
if var.is_none() {
*var = exp_type;
@ -113,6 +121,8 @@ fn type_inference_scope(
condition,
body,
else_body,
scope_type_info,
else_body_scope_type_info,
} => {
type_inference_expression(
condition,
@ -124,23 +134,33 @@ fn type_inference_scope(
let (new_scope_vars, new_vars) =
type_inference_scope(body, &scope_vars, func, storage);
for (k, v) in new_scope_vars.into_iter() {
for (k, v) in new_scope_vars.iter() {
// not a new var within the scope (shadowing), so type info is valid
if scope_vars.contains_key(&k) && !new_vars.contains(&k) {
scope_vars.insert(k, v);
if scope_vars.contains_key(k) && !new_vars.contains(k) {
scope_vars.insert(k.clone(), v.clone());
}
}
*scope_type_info = new_scope_vars
.into_iter()
.map(|(a, b)| (a, b.into_iter().map(Option::unwrap).collect()))
.collect();
if let Some(body) = else_body {
let (new_scope_vars, new_vars) =
type_inference_scope(body, &scope_vars, func, storage);
for (k, v) in new_scope_vars.into_iter() {
for (k, v) in new_scope_vars.iter() {
// not a new var within the scope (shadowing), so type info is valid
if scope_vars.contains_key(&k) && !new_vars.contains(&k) {
scope_vars.insert(k, v);
if scope_vars.contains_key(k) && !new_vars.contains(k) {
scope_vars.insert(k.clone(), v.clone());
}
}
*else_body_scope_type_info = new_scope_vars
.into_iter()
.map(|(a, b)| (a, b.into_iter().map(Option::unwrap).collect()))
.collect();
}
}
Statement::Return(exp) => {
@ -163,7 +183,7 @@ fn type_inference_scope(
fn type_inference_expression(
exp: &Expression,
scope_vars: &mut HashMap<String, Option<TypeExp>>,
scope_vars: &mut ScopeMap,
storage: &Storage,
expected_type: Option<TypeExp>,
) -> Option<TypeExp> {
@ -182,31 +202,28 @@ fn type_inference_expression(
ast::LiteralValue::Boolean(_) => Some(TypeExp::Boolean),
}
}
Expression::Variable {
name,
value_type: _,
} => {
let var = scope_vars.get(&name.value).cloned().flatten();
Expression::Variable { name } => {
let var = scope_vars
.get_mut(&name.value)
.expect("to exist")
.last_mut()
.unwrap();
if expected_type.is_some() {
if var.is_none() {
scope_vars.insert(name.value.clone(), expected_type.clone());
*var = expected_type.clone();
expected_type
} else if expected_type.is_some() {
assert_eq!(var, expected_type, "type mismatch with variables");
assert_eq!(*var, expected_type, "type mismatch with variables");
expected_type
} else {
var
var.clone()
}
} else {
var
var.clone()
}
}
Expression::Call {
function,
args,
value_type: _,
} => {
Expression::Call { function, args } => {
let func = storage.functions.get(function).cloned().unwrap();
for (i, arg) in args.iter().enumerate() {