diff --git a/lib/edlang_codegen_llvm/src/codegen.rs b/lib/edlang_codegen_llvm/src/codegen.rs index bf259925d..357201876 100644 --- a/lib/edlang_codegen_llvm/src/codegen.rs +++ b/lib/edlang_codegen_llvm/src/codegen.rs @@ -422,7 +422,39 @@ fn compile_fn(ctx: &ModuleCompileCtx, fn_id: DefId) -> Result<(), BuilderError> ctx.builder.build_return(None)?; } }, - ir::Terminator::Switch => todo!(), + ir::Terminator::SwitchInt { + discriminator, + targets, + } => { + let (condition, condition_ty) = + compile_load_operand(ctx, fn_id, &locals, discriminator)?; + let cond = condition.into_int_value(); + dbg!(&cond); + dbg!(&condition_ty); + + let mut cases = Vec::new(); + + for (value, target) in targets.values.iter().zip(targets.targets.iter()) { + let target = *target; + let ty_kind = value.get_type(); + dbg!(&ty_kind); + let block = blocks[target]; + let value = compile_value( + ctx, + value, + &TypeInfo { + span: None, + kind: ty_kind, + }, + )? + .into_int_value(); + dbg!(&value); + cases.push((value, block)); + } + + ctx.builder + .build_switch(cond, blocks[*targets.targets.last().unwrap()], &cases)?; + } ir::Terminator::Call { func, args, diff --git a/lib/edlang_ir/src/lib.rs b/lib/edlang_ir/src/lib.rs index c8e9923db..c59db7692 100644 --- a/lib/edlang_ir/src/lib.rs +++ b/lib/edlang_ir/src/lib.rs @@ -160,7 +160,10 @@ pub enum StatementKind { pub enum Terminator { Target(usize), Return, - Switch, + SwitchInt { + discriminator: Operand, + targets: SwitchTarget, + }, Call { /// The function to call. func: DefId, @@ -174,6 +177,13 @@ pub enum Terminator { Unreachable, } +/// Used for ifs, match +#[derive(Debug, Clone)] +pub struct SwitchTarget { + pub values: Vec, + pub targets: Vec, +} + #[derive(Debug, Clone)] pub struct TypeInfo { pub span: Option, @@ -268,6 +278,29 @@ pub enum ValueTree { Branch(Vec), } +impl ValueTree { + pub fn get_type(&self) -> TypeKind { + match self { + ValueTree::Leaf(value) => match value { + ConstValue::Bool(_) => TypeKind::Bool, + ConstValue::I8(_) => TypeKind::Int(IntTy::I8), + ConstValue::I16(_) => TypeKind::Int(IntTy::I16), + ConstValue::I32(_) => TypeKind::Int(IntTy::I32), + ConstValue::I64(_) => TypeKind::Int(IntTy::I64), + ConstValue::I128(_) => TypeKind::Int(IntTy::I128), + ConstValue::U8(_) => TypeKind::Uint(UintTy::U8), + ConstValue::U16(_) => TypeKind::Uint(UintTy::U16), + ConstValue::U32(_) => TypeKind::Uint(UintTy::U32), + ConstValue::U64(_) => TypeKind::Uint(UintTy::U64), + ConstValue::U128(_) => TypeKind::Uint(UintTy::U8), + ConstValue::F32(_) => TypeKind::Float(FloatTy::F32), + ConstValue::F64(_) => TypeKind::Float(FloatTy::F64), + }, + ValueTree::Branch(_) => todo!(), + } + } +} + #[derive(Debug, Clone)] pub enum RValue { Use(Operand), diff --git a/lib/edlang_lowering/src/common.rs b/lib/edlang_lowering/src/common.rs index 51f536fc5..11d1df801 100644 --- a/lib/edlang_lowering/src/common.rs +++ b/lib/edlang_lowering/src/common.rs @@ -70,7 +70,7 @@ impl BodyBuilder { id } - pub fn _add_temp_local(&mut self, ty_kind: TypeKind) -> usize { + pub fn add_temp_local(&mut self, ty_kind: TypeKind) -> usize { let id = self.body.locals.len(); self.body.locals.push(Local::temp(TypeInfo { span: None, @@ -79,7 +79,7 @@ impl BodyBuilder { id } - pub fn _get_local(&self, name: &str) -> Option<&Local> { + pub fn get_local(&self, name: &str) -> Option<&Local> { self.body.locals.get(*(self.name_to_local.get(name)?)) } diff --git a/lib/edlang_lowering/src/lib.rs b/lib/edlang_lowering/src/lib.rs index c3daff8cc..7cb5036ee 100644 --- a/lib/edlang_lowering/src/lib.rs +++ b/lib/edlang_lowering/src/lib.rs @@ -1,12 +1,12 @@ use std::collections::HashMap; -use ast::ModuleStatement; +use ast::{BinaryOp, ModuleStatement}; use common::{BodyBuilder, BuildCtx}; use edlang_ast as ast; use edlang_ir as ir; use ir::{ - BasicBlock, Body, DefId, Local, LocalKind, Operand, Place, ProgramBody, Statement, - StatementKind, Terminator, TypeInfo, TypeKind, + BasicBlock, Body, ConstValue, DefId, Local, LocalKind, Operand, Place, ProgramBody, Statement, + StatementKind, SwitchTarget, Terminator, TypeInfo, TypeKind, ValueTree, }; mod common; @@ -170,7 +170,7 @@ fn lower_statement(builder: &mut BodyBuilder, info: &ast::Statement, ret_type: & ast::Statement::Assign(info) => lower_assign(builder, info), ast::Statement::For(_) => todo!(), ast::Statement::While(_) => todo!(), - ast::Statement::If(_) => todo!(), + ast::Statement::If(info) => lower_if_stmt(builder, info, ret_type), ast::Statement::Return(info) => lower_return(builder, info, ret_type), ast::Statement::FnCall(info) => { lower_fn_call(builder, info); @@ -178,6 +178,76 @@ fn lower_statement(builder: &mut BodyBuilder, info: &ast::Statement, ret_type: & } } +fn lower_if_stmt(builder: &mut BodyBuilder, info: &ast::IfStmt, ret_type: &TypeInfo) { + let cond_ty = find_expr_type(builder, &info.condition).expect("coouldnt find cond type"); + let condition = lower_expr(builder, &info.condition, Some(&cond_ty)); + + let local = builder.add_temp_local(TypeKind::Bool); + let place = Place { + local, + projection: vec![].into(), + }; + + builder.statements.push(Statement { + span: None, + kind: StatementKind::Assign(place.clone(), condition), + }); + + // keep idx to change terminator + let current_block_idx = builder.body.blocks.len(); + + let statements = std::mem::take(&mut builder.statements); + builder.body.blocks.push(BasicBlock { + statements: statements.into(), + terminator: Terminator::Unreachable, + }); + + // keep idx for switch targets + let first_then_block_idx = builder.body.blocks.len(); + + for stmt in &info.then_block.body { + lower_statement(builder, stmt, ret_type); + } + + // keet idx to change terminator + let last_then_block_idx = builder.body.blocks.len(); + let statements = std::mem::take(&mut builder.statements); + builder.body.blocks.push(BasicBlock { + statements: statements.into(), + terminator: Terminator::Unreachable, + }); + + let first_else_block_idx = builder.body.blocks.len(); + + if let Some(contents) = &info.else_block { + for stmt in &contents.body { + lower_statement(builder, stmt, ret_type); + } + } + + let last_else_block_idx = builder.body.blocks.len(); + let statements = std::mem::take(&mut builder.statements); + builder.body.blocks.push(BasicBlock { + statements: statements.into(), + terminator: Terminator::Unreachable, + }); + + let targets = SwitchTarget { + values: vec![TypeKind::Bool.get_falsy_value()], + targets: vec![first_else_block_idx, first_then_block_idx], + }; + + let kind = Terminator::SwitchInt { + discriminator: Operand::Move(place), + targets, + }; + builder.body.blocks[current_block_idx].terminator = kind; + + let next_block_idx = builder.body.blocks.len(); + builder.body.blocks[last_then_block_idx].terminator = Terminator::Target(next_block_idx); + builder.body.blocks[last_else_block_idx].terminator = Terminator::Target(next_block_idx); +} + fn lower_let(builder: &mut BodyBuilder, info: &ast::LetStmt) { let ty = lower_type(&builder.ctx, &info.r#type); let rvalue = lower_expr(builder, &info.value, Some(&ty)); @@ -210,6 +280,58 @@ fn lower_assign(builder: &mut BodyBuilder, info: &ast::AssignStmt) { }) } +fn find_expr_type(builder: &mut BodyBuilder, info: &ast::Expression) -> Option { + Some(TypeInfo { + span: None, + kind: match info { + ast::Expression::Value(x) => match x { + ast::ValueExpr::Bool { .. } => TypeKind::Bool, + ast::ValueExpr::Char { .. } => TypeKind::Char, + ast::ValueExpr::Int { .. } => return None, + ast::ValueExpr::Float { .. } => return None, + ast::ValueExpr::Str { .. } => todo!(), + ast::ValueExpr::Path(path) => { + // todo: handle full path + builder.get_local(&path.first.name)?.ty.kind.clone() + } + }, + ast::Expression::FnCall(info) => { + let fn_id = { + let mod_body = builder.get_module_body(); + + if let Some(id) = mod_body.symbols.functions.get(&info.name.name) { + *id + } else { + *mod_body + .imports + .get(&info.name.name) + .expect("function call not found") + } + }; + + builder + .ctx + .body + .function_signatures + .get(&fn_id)? + .1 + .kind + .clone() + } + ast::Expression::Unary(_, info) => find_expr_type(builder, info)?.kind, + ast::Expression::Binary(lhs, op, rhs) => { + if matches!(op, BinaryOp::Logic(_, _)) { + TypeKind::Bool + } else { + find_expr_type(builder, lhs) + .or(find_expr_type(builder, rhs))? + .kind + } + } + }, + }) +} + fn lower_expr( builder: &mut BodyBuilder, info: &ast::Expression, @@ -232,19 +354,27 @@ fn lower_binary_expr( rhs: &ast::Expression, type_hint: Option<&TypeInfo>, ) -> ir::RValue { - let expr_type = type_hint.expect("type hint needed"); - let lhs = lower_expr(builder, lhs, type_hint); - let rhs = lower_expr(builder, rhs, type_hint); + let (lhs, lhs_ty) = if type_hint.is_none() { + let ty = find_expr_type(builder, lhs); + (lower_expr(builder, lhs, ty.as_ref()), ty) + } else { + (lower_expr(builder, lhs, type_hint), type_hint.cloned()) + }; + let (rhs, rhs_ty) = if type_hint.is_none() { + let ty = find_expr_type(builder, rhs); + (lower_expr(builder, rhs, ty.as_ref()), ty) + } else { + (lower_expr(builder, rhs, type_hint), type_hint.cloned()) + }; - let local_ty = expr_type; - let lhs_local = builder.add_local(Local::temp(local_ty.clone())); - let rhs_local = builder.add_local(Local::temp(local_ty.clone())); + let lhs_local = builder.add_local(Local::temp(lhs_ty.unwrap().clone())); + let rhs_local = builder.add_local(Local::temp(rhs_ty.unwrap().clone())); let lhs_place = Place { local: lhs_local, projection: Default::default(), }; let rhs_place = Place { - local: lhs_local, + local: rhs_local, projection: Default::default(), }; diff --git a/programs/simple.ed b/programs/simple.ed index dd88b6c42..df585bf16 100644 --- a/programs/simple.ed +++ b/programs/simple.ed @@ -1,15 +1,12 @@ mod Main { - pub fn main(argc: i32) -> i32 { - let mut x: i32 = 2; - x = 4; - a(); - return x + 2; - } + pub fn main(argc: i64) -> i64 { + let mut a: i64 = 0; - pub fn a() { - let mut x: i32 = 2; - x = 4; - return; + if argc > 2 { + a = 1; + } + + return a; } }