Really basic type inference

This commit is contained in:
Edgar 2023-05-20 10:54:25 +02:00
parent 82940755d8
commit 81b57d646d
No known key found for this signature in database
GPG key ID: 70ADAE8F35904387
5 changed files with 140 additions and 54 deletions

View file

@ -1,5 +1,5 @@
fn main(x: i32, z: i32) -> i32 { fn main(x: i64, z: i64) -> i64 {
let y = 0; let y: i64 = 0;
if x == 5 { if x == 5 {
if x == z { if x == z {
y = 2 * x; y = 2 * x;

View file

@ -86,7 +86,12 @@ impl Function {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Statement { pub enum Statement {
Variable { Let {
name: String,
value: Box<Expression>,
type_name: Option<String>,
},
Mutate {
name: String, name: String,
value: Box<Expression>, value: Box<Expression>,
}, },

View file

@ -36,10 +36,14 @@ pub struct CodeGen<'ctx> {
context: &'ctx Context, context: &'ctx Context,
pub module: Module<'ctx>, pub module: Module<'ctx>,
builder: Builder<'ctx>, builder: Builder<'ctx>,
fn_types: VariableTypes<'ctx>,
_program: ProgramData, _program: ProgramData,
ast: ast::Program, ast: ast::Program,
} }
/// Per-function variable table: name -> (current LLVM value, number of reassignments so far).
type Variables<'ctx> = HashMap<String, (BasicValueEnum<'ctx>, usize)>;
/// Maps a name to an LLVM type; used both for variable/parameter types and for function return types.
type VariableTypes<'ctx> = HashMap<String, BasicTypeEnum<'ctx>>;
impl<'ctx> CodeGen<'ctx> { impl<'ctx> CodeGen<'ctx> {
pub fn new( pub fn new(
context: &'ctx Context, context: &'ctx Context,
@ -55,29 +59,38 @@ impl<'ctx> CodeGen<'ctx> {
builder: context.create_builder(), builder: context.create_builder(),
_program, _program,
ast, ast,
fn_types: HashMap::new(),
}; };
Ok(codegen) Ok(codegen)
} }
pub fn compile_ast(&self) -> Result<()> { pub fn compile_ast(&mut self) -> Result<()> {
let mut functions = vec![]; let mut functions = vec![];
let mut types: VariableTypes<'ctx> = HashMap::new();
// todo fix the grammar so top level statements are only functions and static vars. // todo fix the grammar so top level statements are only functions and static vars.
// create the llvm functions first. // create the llvm functions first.
for statement in &self.ast.statements { for statement in &self.ast.statements {
match &statement { match &statement {
Statement::Variable { .. } => unreachable!(), Statement::Let { .. } => unreachable!(),
Statement::Mutate { .. } => unreachable!(),
Statement::Return(_) => unreachable!(), Statement::Return(_) => unreachable!(),
Statement::If { .. } => unreachable!(), Statement::If { .. } => unreachable!(),
Statement::Function(function) => { Statement::Function(function) => {
functions.push(function); functions.push(function);
self.compile_function_signature(function)?; let ret_type = self.compile_function_signature(function)?;
if let Some(ret_type) = ret_type {
types.insert(function.name.clone(), ret_type);
}
} }
} }
} }
self.fn_types = types;
// implement them. // implement them.
for function in functions { for function in functions {
self.compile_function(function)?; self.compile_function(function)?;
@ -104,7 +117,10 @@ impl<'ctx> CodeGen<'ctx> {
} }
/// creates the llvm function without the body, so other function bodies can call it. /// creates the llvm function without the body, so other function bodies can call it.
fn compile_function_signature(&self, function: &Function) -> Result<()> { fn compile_function_signature(
&self,
function: &Function,
) -> Result<Option<BasicTypeEnum<'ctx>>> {
let args_types: Vec<BasicTypeEnum<'ctx>> = function let args_types: Vec<BasicTypeEnum<'ctx>> = function
.params .params
.iter() .iter()
@ -116,13 +132,16 @@ impl<'ctx> CodeGen<'ctx> {
args_types.into_iter().map(|t| t.into()).collect_vec(); args_types.into_iter().map(|t| t.into()).collect_vec();
let fn_type = match &function.return_type { let fn_type = match &function.return_type {
Some(id) => self.get_llvm_type(id)?.fn_type(&args_types, false), Some(id) => {
let return_type = self.get_llvm_type(id)?;
return_type.fn_type(&args_types, false)
}
None => self.context.void_type().fn_type(&args_types, false), None => self.context.void_type().fn_type(&args_types, false),
}; };
self.module.add_function(&function.name, fn_type, None); self.module.add_function(&function.name, fn_type, None);
Ok(()) Ok(fn_type.get_return_type())
} }
fn compile_function(&self, function: &Function) -> Result<()> { fn compile_function(&self, function: &Function) -> Result<()> {
@ -131,18 +150,16 @@ impl<'ctx> CodeGen<'ctx> {
self.builder.position_at_end(entry_block); self.builder.position_at_end(entry_block);
let mut variables: HashMap<String, (BasicValueEnum<'ctx>, usize)> = HashMap::new(); let mut variables: Variables = HashMap::new();
let mut types: VariableTypes = HashMap::new();
for (i, param) in function.params.iter().enumerate() { for (i, param) in function.params.iter().enumerate() {
let id = param.ident.clone(); let id = param.ident.clone();
variables.insert( let param = func
id.clone(), .get_nth_param(i.try_into().unwrap())
( .expect("parameter");
func.get_nth_param(i.try_into().unwrap()) variables.insert(id.clone(), (param, 0));
.expect("parameter"), types.insert(id.clone(), param.get_type());
0,
),
);
} }
let mut has_return = false; let mut has_return = false;
@ -151,7 +168,7 @@ impl<'ctx> CodeGen<'ctx> {
if let Statement::Return(_) = statement { if let Statement::Return(_) = statement {
has_return = true has_return = true
} }
self.compile_statement(statement, &mut variables)?; self.compile_statement(statement, &mut variables, &mut types)?;
} }
if !has_return { if !has_return {
@ -161,30 +178,70 @@ impl<'ctx> CodeGen<'ctx> {
Ok(()) Ok(())
} }
/// Best-effort lookup of the LLVM type of `expr`, used to pick a type hint before codegen.
///
/// - Integer literals yield an integer type only when they carry an explicit bit width.
/// - Variables are looked up in `types`, the per-function variable type table.
/// - Calls resolve to the callee's return type recorded in `self.fn_types`
///   (populated by `compile_ast`); functions without a return type have no entry.
/// - Binary operations take the type of the first operand whose type is known.
///
/// Returns `None` when nothing pins the type down.
///
/// # Panics
///
/// String literals are not supported yet and hit `todo!()`.
fn find_expr_type(
    &self,
    expr: &Expression,
    types: &VariableTypes<'ctx>,
) -> Option<BasicTypeEnum<'ctx>> {
    match expr {
        Expression::Literal(literal) => match literal {
            LiteralValue::String => todo!(),
            LiteralValue::Integer { bits, .. } => {
                bits.map(|bits| self.context.custom_width_int_type(bits).into())
            }
        },
        Expression::Variable(name) => types.get(name).copied(),
        // Function return types live in `self.fn_types`, not in the local variable table;
        // looking the callee up in `types` could never find it.
        Expression::Call { function, .. } => self.fn_types.get(function).copied(),
        Expression::BinaryOp(lhs, _, rhs) => self
            .find_expr_type(lhs, types)
            .or_else(|| self.find_expr_type(rhs, types)),
    }
}
fn compile_statement( fn compile_statement(
&self, &self,
statement: &Statement, statement: &Statement,
// value, assignments // value, assignments
variables: &mut HashMap<String, (BasicValueEnum<'ctx>, usize)>, variables: &mut Variables<'ctx>,
types: &mut VariableTypes<'ctx>,
) -> Result<()> { ) -> Result<()> {
match statement { match statement {
// Variable assignment // Variable assignment
Statement::Variable { name, value } => { Statement::Let {
name,
value,
type_name,
} => {
let type_hint = if let Some(type_name) = type_name {
self.get_llvm_type(type_name)?
} else {
self.find_expr_type(value, types)
.expect("type should be found")
};
types.insert(name.clone(), type_hint);
let result = self let result = self
.compile_expression(value, variables)? .compile_expression(value, variables, types, Some(type_hint))?
.expect("should have result"); .expect("should have result");
let accesses = if let Some(x) = variables.get(name) { variables.insert(name.clone(), (result, 0));
x.1 + 1 }
} else { Statement::Mutate { name, value } => {
0 let type_hint = *types.get(name).expect("should exist");
}; let result = self
variables.insert(name.clone(), (result, accesses)); .compile_expression(value, variables, types, Some(type_hint))?
.expect("should have result");
let (old_val, acc) = variables.get(name).expect("variable should exist");
variables.insert(name.clone(), (result, acc + 1));
} }
Statement::Return(ret) => { Statement::Return(ret) => {
if let Some(ret) = ret { if let Some(ret) = ret {
let type_hint = self.find_expr_type(ret, types);
let result = self let result = self
.compile_expression(ret, variables)? .compile_expression(ret, variables, types, type_hint)?
.expect("should have result"); .expect("should have result");
self.builder.build_return(Some(&result)); self.builder.build_return(Some(&result));
} else { } else {
@ -196,8 +253,9 @@ impl<'ctx> CodeGen<'ctx> {
body, body,
else_body, else_body,
} => { } => {
let type_hint_cond = self.find_expr_type(condition, types);
let condition = self let condition = self
.compile_expression(condition, variables)? .compile_expression(condition, variables, types, type_hint_cond)?
.expect("should produce a value"); .expect("should produce a value");
let func = self let func = self
@ -224,7 +282,7 @@ impl<'ctx> CodeGen<'ctx> {
let mut variables_if = variables.clone(); let mut variables_if = variables.clone();
self.builder.position_at_end(if_block); self.builder.position_at_end(if_block);
for s in body { for s in body {
self.compile_statement(s, &mut variables_if)?; self.compile_statement(s, &mut variables_if, types)?;
} }
self.builder.build_unconditional_branch(merge_block); self.builder.build_unconditional_branch(merge_block);
if_block = self.builder.get_insert_block().unwrap(); // update for phi if_block = self.builder.get_insert_block().unwrap(); // update for phi
@ -234,7 +292,7 @@ impl<'ctx> CodeGen<'ctx> {
self.builder.position_at_end(else_block); self.builder.position_at_end(else_block);
for s in else_body { for s in else_body {
self.compile_statement(s, &mut variables_else)?; self.compile_statement(s, &mut variables_else, types)?;
} }
self.builder.build_unconditional_branch(merge_block); self.builder.build_unconditional_branch(merge_block);
else_block = self.builder.get_insert_block().unwrap(); // update for phi else_block = self.builder.get_insert_block().unwrap(); // update for phi
@ -288,14 +346,18 @@ impl<'ctx> CodeGen<'ctx> {
pub fn compile_expression( pub fn compile_expression(
&self, &self,
expr: &Expression, expr: &Expression,
variables: &mut HashMap<String, (BasicValueEnum<'ctx>, usize)>, variables: &mut Variables<'ctx>,
types: &mut VariableTypes<'ctx>,
type_hint: Option<BasicTypeEnum<'ctx>>,
) -> Result<Option<BasicValueEnum<'ctx>>> { ) -> Result<Option<BasicValueEnum<'ctx>>> {
Ok(match expr { Ok(match expr {
Expression::Variable(term) => Some(self.compile_variable(term, variables)?), Expression::Variable(term) => Some(self.compile_variable(term, variables, types)?),
Expression::Literal(term) => Some(self.compile_literal(term)?), Expression::Literal(term) => Some(self.compile_literal(term, type_hint)?),
Expression::Call { function, args } => self.compile_call(function, args, variables)?, Expression::Call { function, args } => {
self.compile_call(function, args, variables, types)?
}
Expression::BinaryOp(lhs, op, rhs) => { Expression::BinaryOp(lhs, op, rhs) => {
Some(self.compile_binary_op(lhs, op, rhs, variables)?) Some(self.compile_binary_op(lhs, op, rhs, variables, types, type_hint)?)
} }
}) })
} }
@ -304,15 +366,17 @@ impl<'ctx> CodeGen<'ctx> {
&self, &self,
func_name: &str, func_name: &str,
args: &[Box<Expression>], args: &[Box<Expression>],
variables: &mut HashMap<String, (BasicValueEnum<'ctx>, usize)>, variables: &mut Variables<'ctx>,
types: &mut VariableTypes<'ctx>,
) -> Result<Option<BasicValueEnum<'ctx>>> { ) -> Result<Option<BasicValueEnum<'ctx>>> {
let function = self.module.get_function(func_name).expect("should exist"); let function = self.module.get_function(func_name).expect("should exist");
let mut value_args: Vec<BasicMetadataValueEnum> = Vec::with_capacity(args.len()); let mut value_args: Vec<BasicMetadataValueEnum> = Vec::with_capacity(args.len());
for arg in args { for arg in args {
let type_enum = self.find_expr_type(arg, types);
let res = self let res = self
.compile_expression(arg, variables)? .compile_expression(arg, variables, types, type_enum)?
.expect("should have result"); .expect("should have result");
value_args.push(res.into()); value_args.push(res.into());
} }
@ -333,14 +397,16 @@ impl<'ctx> CodeGen<'ctx> {
lhs: &Expression, lhs: &Expression,
op: &OpCode, op: &OpCode,
rhs: &Expression, rhs: &Expression,
variables: &mut HashMap<String, (BasicValueEnum<'ctx>, usize)>, variables: &mut Variables<'ctx>,
types: &mut VariableTypes<'ctx>,
type_hint: Option<BasicTypeEnum<'ctx>>,
) -> Result<BasicValueEnum<'ctx>> { ) -> Result<BasicValueEnum<'ctx>> {
let lhs = self let lhs = self
.compile_expression(lhs, variables)? .compile_expression(lhs, variables, types, type_hint)?
.expect("should have result") .expect("should have result")
.into_int_value(); .into_int_value();
let rhs = self let rhs = self
.compile_expression(rhs, variables)? .compile_expression(rhs, variables, types, type_hint)?
.expect("should have result") .expect("should have result")
.into_int_value(); .into_int_value();
@ -363,7 +429,11 @@ impl<'ctx> CodeGen<'ctx> {
Ok(result.as_basic_value_enum()) Ok(result.as_basic_value_enum())
} }
pub fn compile_literal(&self, term: &LiteralValue) -> Result<BasicValueEnum<'ctx>> { pub fn compile_literal(
&self,
term: &LiteralValue,
type_hint: Option<BasicTypeEnum<'ctx>>,
) -> Result<BasicValueEnum<'ctx>> {
let value = match term { let value = match term {
LiteralValue::String => todo!(), LiteralValue::String => todo!(),
LiteralValue::Integer { LiteralValue::Integer {
@ -371,13 +441,19 @@ impl<'ctx> CodeGen<'ctx> {
signed: _, signed: _,
value, value,
} => { } => {
// todo: type resolution for bit size? if let Some(type_hint) = type_hint {
let bits = bits.unwrap_or(32); type_hint
.into_int_type()
.const_int(value.parse().unwrap(), false)
.as_basic_value_enum()
} else {
let bits = bits.unwrap_or(32);
self.context self.context
.custom_width_int_type(bits) .custom_width_int_type(bits)
.const_int(value.parse().unwrap(), false) .const_int(value.parse().unwrap(), false)
.as_basic_value_enum() .as_basic_value_enum()
}
} }
}; };
@ -387,7 +463,8 @@ impl<'ctx> CodeGen<'ctx> {
pub fn compile_variable( pub fn compile_variable(
&self, &self,
variable: &str, variable: &str,
variables: &mut HashMap<String, (BasicValueEnum<'ctx>, usize)>, variables: &mut Variables<'ctx>,
types: &mut VariableTypes<'ctx>,
) -> Result<BasicValueEnum<'ctx>> { ) -> Result<BasicValueEnum<'ctx>> {
let var = *variables.get(variable).expect("value"); let var = *variables.get(variable).expect("value");
Ok(var.0) Ok(var.0)

View file

@ -69,10 +69,14 @@ Statement: ast::Statement = {
<f:Function> => ast::Statement::Function(f), <f:Function> => ast::Statement::Function(f),
}; };
// Type annotation of the form `: <identifier>`; yields the identifier text.
TypeInfo: String = {
":" <i:"identifier"> => i
};
// statements not including function definitions // statements not including function definitions
BasicStatement: ast::Statement = { BasicStatement: ast::Statement = {
"let" <i:"identifier"> "=" <e:Expr> ";" => ast::Statement::Variable { name: i, value: e}, "let" <i:"identifier"> <t:TypeInfo?> "=" <e:Expr> ";" => ast::Statement::Let { name: i, value: e, type_name: t},
<i:"identifier"> "=" <e:Expr> ";" => ast::Statement::Variable { name: i, value: e}, <i:"identifier"> "=" <e:Expr> ";" => ast::Statement::Mutate { name: i, value: e},
"if" <cond:Expr> "{" <s:Statements> "}" <e:ElseExpr?> => ast::Statement::If { condition: cond, body: s, else_body: e}, "if" <cond:Expr> "{" <s:Statements> "}" <e:ElseExpr?> => ast::Statement::If { condition: cond, body: s, else_body: e},
"return" <e:Expr?> ";" => ast::Statement::Return(e), "return" <e:Expr?> ";" => ast::Statement::Return(e),
}; };

View file

@ -123,7 +123,7 @@ fn main() -> Result<()> {
println!("{:#?}", ast); println!("{:#?}", ast);
let context = Context::create(); let context = Context::create();
let codegen = codegen::CodeGen::new(&context, &file_name, program, ast)?; let mut codegen = codegen::CodeGen::new(&context, &file_name, program, ast)?;
codegen.compile_ast()?; codegen.compile_ast()?;
let generated_llvm_ir = codegen.generated_code(); let generated_llvm_ir = codegen.generated_code();
@ -144,7 +144,7 @@ fn main() -> Result<()> {
let file_name = input.file_name().unwrap().to_string_lossy(); let file_name = input.file_name().unwrap().to_string_lossy();
let context = Context::create(); let context = Context::create();
let codegen = codegen::CodeGen::new(&context, &file_name, program, ast)?; let mut codegen = codegen::CodeGen::new(&context, &file_name, program, ast)?;
codegen.compile_ast()?; codegen.compile_ast()?;
let execution_engine = codegen let execution_engine = codegen
.module .module