diff --git a/simple.ed b/simple.ed index d3131726c..1251c241f 100644 --- a/simple.ed +++ b/simple.ed @@ -1,5 +1,9 @@ - -fn main(x: i64) -> i64 { - let x = 2 + 3; - return x; +fn main(x: i32) -> i32 { + let y = 0; + if x == 5 { + y = 2 * x; + } else { + y = 3 * x; + } + return y; } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 41494b3af..9d39049c1 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -5,6 +5,10 @@ pub enum OpCode { Mul, Div, Rem, + And, + Or, + Eq, + Ne, } impl OpCode { @@ -15,6 +19,10 @@ impl OpCode { OpCode::Mul => "muli", OpCode::Div => "divi", OpCode::Rem => "remi", + OpCode::And => "and", + OpCode::Or => "or", + OpCode::Eq => "eq", + OpCode::Ne => "ne", } } } @@ -82,6 +90,11 @@ pub enum Statement { name: String, value: Box, }, + If { + condition: Box, + body: Vec, + else_body: Option>, + }, Return(Option>), Function(Function), } diff --git a/src/codegen.rs b/src/codegen.rs index b37c99b63..95ebe1784 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, path::{Path, PathBuf}, todo, }; @@ -12,6 +12,7 @@ use inkwell::{ module::Module, types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum}, values::{BasicMetadataValueEnum, BasicValue, BasicValueEnum}, + IntPredicate, }; use itertools::{Either, Itertools}; @@ -40,6 +41,12 @@ pub struct CodeGen<'ctx> { ast: ast::Program, } +#[derive(Debug, Clone)] +struct BlockInfo<'a> { + pub blocks: Vec>, + pub current_block: usize, +} + impl<'ctx> CodeGen<'ctx> { pub fn new( context: &'ctx Context, @@ -70,6 +77,7 @@ impl<'ctx> CodeGen<'ctx> { match &statement { Statement::Variable { .. } => unreachable!(), Statement::Return(_) => unreachable!(), + Statement::If { .. } => unreachable!(), Statement::Function(function) => { functions.push(function); self.compile_function_signature(function)?; @@ -130,25 +138,27 @@ impl<'ctx> CodeGen<'ctx> { self.builder.position_at_end(entry_block); - let mut variables: HashMap> = HashMap::new(); + let mut variables: HashMap, usize)> = HashMap::new(); for (i, param) in function.params.iter().enumerate() { let id = param.ident.clone(); variables.insert( id.clone(), - func.get_nth_param(i.try_into().unwrap()) - .expect("parameter"), + ( + func.get_nth_param(i.try_into().unwrap()) + .expect("parameter"), + 0, + ), ); } - // todo: check function has return? let mut has_return = false; for statement in &function.body { if let Statement::Return(_) = statement { has_return = true } - self.compile_statement(&entry_block, statement, &mut variables)?; + self.compile_statement(statement, &mut variables)?; } if !has_return { @@ -160,29 +170,124 @@ impl<'ctx> CodeGen<'ctx> { fn compile_statement( &self, - block: &BasicBlock, statement: &Statement, - variables: &mut HashMap>, + // value, assignments + variables: &mut HashMap, usize)>, ) -> Result<()> { match statement { // Variable assignment Statement::Variable { name, value } => { let result = self - .compile_expression(block, value, variables)? + .compile_expression(value, variables)? .expect("should have result"); - variables.insert(name.clone(), result); + let accesses = if let Some(x) = variables.get(name) { + x.1 + 1 + } else { + 0 + }; + variables.insert(name.clone(), (result, accesses)); } Statement::Return(ret) => { if let Some(ret) = ret { let result = self - .compile_expression(block, ret, variables)? + .compile_expression(ret, variables)? .expect("should have result"); self.builder.build_return(Some(&result)); } else { self.builder.build_return(None); } } + Statement::If { + condition, + body, + else_body, + } => { + let condition = self + .compile_expression(condition, variables)? + .expect("should produce a value"); + + let func = self + .builder + .get_insert_block() + .unwrap() + .get_parent() + .expect("parent should exist"); + + let mut if_block = self.context.append_basic_block(func, "if"); + let mut else_block = self.context.append_basic_block(func, "else"); + let merge_block = self.context.append_basic_block(func, "merge"); + + self.builder.build_conditional_branch( + condition.into_int_value(), + if_block, + if let Some(else_body) = else_body { + else_block + } else { + merge_block + }, + ); + + let mut variables_if = variables.clone(); + self.builder.position_at_end(if_block); + for s in body { + self.compile_statement(s, &mut variables_if); + } + // should we set the builder at the end of the if_block again? + self.builder.build_unconditional_branch(merge_block); + if_block = self.builder.get_insert_block().unwrap(); // update for phi + + let mut variables_else = variables.clone(); + if let Some(else_body) = else_body { + self.builder.position_at_end(else_block); + + for s in else_body { + self.compile_statement(s, &mut variables_else); + } + // should we set the builder at the end of the if_block again? + self.builder.build_unconditional_branch(merge_block); + else_block = self.builder.get_insert_block().unwrap(); // update for phi + } + + self.builder.position_at_end(merge_block); + + let mut processed_vars = HashMap::new(); + for (name, (value, acc)) in variables_if { + if variables.contains_key(&name) { + let (old_val, old_acc) = variables.get(&name).unwrap(); + if acc > *old_acc { + let phi = self + .builder + .build_phi(old_val.get_type(), &format!("{name}_phi")); + phi.add_incoming(&[(&value, if_block)]); + processed_vars.insert(name, (value, phi)); + } + } + } + + if else_body.is_some() { + for (name, (value, acc)) in variables_else { + if variables.contains_key(&name) { + let (old_val, old_acc) = variables.get(&name).unwrap(); + if acc > *old_acc { + if let Some((_, phi)) = processed_vars.get(&name) { + phi.add_incoming(&[(&value, else_block)]); + } else { + let phi = self + .builder + .build_phi(old_val.get_type(), &format!("{name}_phi")); + phi.add_incoming(&[(&value, else_block)]); + processed_vars.insert(name, (value, phi)); + } + } + } + } + } + + for (name, (_, phi)) in processed_vars { + variables.insert(name, (phi.as_basic_value(), 0)); + } + } Statement::Function(_function) => unreachable!(), }; @@ -191,28 +296,24 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_expression( &self, - block: &BasicBlock, expr: &Expression, - variables: &mut HashMap>, + variables: &mut HashMap, usize)>, ) -> Result>> { Ok(match expr { Expression::Variable(term) => Some(self.compile_variable(term, variables)?), Expression::Literal(term) => Some(self.compile_literal(term)?), - Expression::Call { function, args } => { - self.compile_call(block, function, args, variables)? - } + Expression::Call { function, args } => self.compile_call(function, args, variables)?, Expression::BinaryOp(lhs, op, rhs) => { - Some(self.compile_binary_op(block, lhs, op, rhs, variables)?) + Some(self.compile_binary_op(lhs, op, rhs, variables)?) } }) } pub fn compile_call( &self, - block: &BasicBlock, func_name: &str, args: &[Box], - variables: &mut HashMap>, + variables: &mut HashMap, usize)>, ) -> Result>> { let function = self.module.get_function(func_name).expect("should exist"); @@ -220,7 +321,7 @@ impl<'ctx> CodeGen<'ctx> { for arg in args { let res = self - .compile_expression(block, arg, variables)? + .compile_expression(arg, variables)? .expect("should have result"); value_args.push(res.into()); } @@ -238,18 +339,17 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_binary_op( &self, - block: &BasicBlock, lhs: &Expression, op: &OpCode, rhs: &Expression, - variables: &mut HashMap>, + variables: &mut HashMap, usize)>, ) -> Result> { let lhs = self - .compile_expression(block, lhs, variables)? + .compile_expression(lhs, variables)? .expect("should have result") .into_int_value(); let rhs = self - .compile_expression(block, rhs, variables)? + .compile_expression(rhs, variables)? .expect("should have result") .into_int_value(); @@ -259,6 +359,14 @@ impl<'ctx> CodeGen<'ctx> { OpCode::Mul => self.builder.build_int_mul(lhs, rhs, "mul"), OpCode::Div => self.builder.build_int_signed_div(lhs, rhs, "div"), OpCode::Rem => self.builder.build_int_signed_rem(lhs, rhs, "rem"), + OpCode::And => self.builder.build_and(lhs, rhs, "and"), + OpCode::Or => self.builder.build_or(lhs, rhs, "or"), + OpCode::Eq => self + .builder + .build_int_compare(IntPredicate::EQ, lhs, rhs, "eq"), + OpCode::Ne => self + .builder + .build_int_compare(IntPredicate::NE, lhs, rhs, "eq"), }; Ok(result.as_basic_value_enum()) @@ -288,9 +396,9 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_variable( &self, variable: &str, - variables: &mut HashMap>, + variables: &mut HashMap, usize)>, ) -> Result> { let var = *variables.get(variable).expect("value"); - Ok(var) + Ok(var.0) } } diff --git a/src/grammar.lalrpop b/src/grammar.lalrpop index 7bd230317..54cea829e 100644 --- a/src/grammar.lalrpop +++ b/src/grammar.lalrpop @@ -14,6 +14,9 @@ extern { enum Token { "let" => Token::KeywordLet, "print" => Token::KeywordPrint, + "struct" => Token::KeywordStruct, + "if" => Token::KeywordIf, + "else" => Token::KeywordElse, "identifier" => Token::Identifier(), "int" => Token::Integer(), "return" => Token::KeywordReturn, @@ -32,6 +35,10 @@ extern { "*" => Token::OperatorMul, "/" => Token::OperatorDiv, "%" => Token::OperatorRem, + "&&" => Token::OperatorAnd, + "||" => Token::OperatorOr, + "==" => Token::OperatorEq, + "!=" => Token::OperatorNe, } } @@ -66,15 +73,30 @@ Statement: ast::Statement = { BasicStatement: ast::Statement = { "let" "=" ";" => ast::Statement::Variable { name: i, value: e}, "=" ";" => ast::Statement::Variable { name: i, value: e}, + "if" "{" "}" => ast::Statement::If { condition: cond, body: s, else_body: e}, "return" ";" => ast::Statement::Return(e), }; -ExprOp: ast::OpCode = { +ElseExpr: Vec = { + "else" "{" "}" => s +} + +Level0_Op: ast::OpCode = { + "&&" => ast::OpCode::And, + "||" => ast::OpCode::Or, +} + +Level1_Op: ast::OpCode = { + "==" => ast::OpCode::Eq, + "!=" => ast::OpCode::Ne, +} + +Level2_Op: ast::OpCode = { "+" => ast::OpCode::Add, "-" => ast::OpCode::Sub, }; -FactorOp: ast::OpCode = { +Level3_Op: ast::OpCode = { "*" => ast::OpCode::Mul, "/" => ast::OpCode::Div, "%" => ast::OpCode::Rem, @@ -85,18 +107,20 @@ Tier: Box = { NextTier }; -Expr = Tier; -Factor = Tier; +Expr = Tier; +Expr2 = Tier; +Expr3 = Tier; +Expr4 = Tier; // Terms: variables, literals, calls Term: Box = { => Box::new(ast::Expression::Variable(i)), - => Box::new(ast::Expression::Literal(n)), + => Box::new(ast::Expression::Literal(n)), "(" > ")" => Box::new(ast::Expression::Call { function: i, args: values}), "(" ")" }; -Num: ast::LiteralValue = => ast::LiteralValue::Integer { bits: None, signed: true, value: n.to_string()}; +Number: ast::LiteralValue = => ast::LiteralValue::Integer { bits: None, signed: true, value: n.to_string()}; // Function handling Param: ast::Parameter = { diff --git a/src/main.rs b/src/main.rs index bc19dcbfd..148eedf52 100644 --- a/src/main.rs +++ b/src/main.rs @@ -121,6 +121,7 @@ fn main() -> Result<()> { // return Ok(()); //} + println!("{:#?}", ast); let context = Context::create(); let codegen = codegen::CodeGen::new(&context, &file_name, program, ast)?; codegen.compile_ast()?; diff --git a/src/tokens.rs b/src/tokens.rs index ac8a7ab44..986fba6c4 100644 --- a/src/tokens.rs +++ b/src/tokens.rs @@ -13,6 +13,12 @@ pub enum Token { KeywordFn, #[token("return")] KeywordReturn, + #[token("struct")] + KeywordStruct, + #[token("if")] + KeywordIf, + #[token("else")] + KeywordElse, #[regex(r"_?\p{XID_Start}\p{XID_Continue}*", |lex| lex.slice().parse().ok())] Identifier(String), @@ -48,6 +54,14 @@ pub enum Token { OperatorDiv, #[token("%")] OperatorRem, + #[token("&&")] + OperatorAnd, + #[token("||")] + OperatorOr, + #[token("==")] + OperatorEq, + #[token("!=")] + OperatorNe, } impl fmt::Display for Token {