feat: compile ifs

This commit is contained in:
Edgar 2024-02-12 12:09:40 +01:00
parent f02e2291c7
commit 9ae8435e36
No known key found for this signature in database
GPG key ID: 70ADAE8F35904387
5 changed files with 217 additions and 25 deletions

View file

@ -422,7 +422,39 @@ fn compile_fn(ctx: &ModuleCompileCtx, fn_id: DefId) -> Result<(), BuilderError>
ctx.builder.build_return(None)?;
}
},
ir::Terminator::Switch => todo!(),
ir::Terminator::SwitchInt {
discriminator,
targets,
} => {
let (condition, condition_ty) =
compile_load_operand(ctx, fn_id, &locals, discriminator)?;
let cond = condition.into_int_value();
dbg!(&cond);
dbg!(&condition_ty);
let mut cases = Vec::new();
for (value, target) in targets.values.iter().zip(targets.targets.iter()) {
let target = *target;
let ty_kind = value.get_type();
dbg!(&ty_kind);
let block = blocks[target];
let value = compile_value(
ctx,
value,
&TypeInfo {
span: None,
kind: ty_kind,
},
)?
.into_int_value();
dbg!(&value);
cases.push((value, block));
}
ctx.builder
.build_switch(cond, blocks[*targets.targets.last().unwrap()], &cases)?;
}
ir::Terminator::Call {
func,
args,

View file

@ -160,7 +160,10 @@ pub enum StatementKind {
pub enum Terminator {
Target(usize),
Return,
Switch,
SwitchInt {
discriminator: Operand,
targets: SwitchTarget,
},
Call {
/// The function to call.
func: DefId,
@ -174,6 +177,13 @@ pub enum Terminator {
Unreachable,
}
/// Used for ifs, match
#[derive(Debug, Clone)]
pub struct SwitchTarget {
pub values: Vec<ValueTree>,
pub targets: Vec<usize>,
}
#[derive(Debug, Clone)]
pub struct TypeInfo {
pub span: Option<Span>,
@ -268,6 +278,29 @@ pub enum ValueTree {
Branch(Vec<Self>),
}
impl ValueTree {
pub fn get_type(&self) -> TypeKind {
match self {
ValueTree::Leaf(value) => match value {
ConstValue::Bool(_) => TypeKind::Bool,
ConstValue::I8(_) => TypeKind::Int(IntTy::I8),
ConstValue::I16(_) => TypeKind::Int(IntTy::I16),
ConstValue::I32(_) => TypeKind::Int(IntTy::I32),
ConstValue::I64(_) => TypeKind::Int(IntTy::I64),
ConstValue::I128(_) => TypeKind::Int(IntTy::I128),
ConstValue::U8(_) => TypeKind::Uint(UintTy::U8),
ConstValue::U16(_) => TypeKind::Uint(UintTy::U16),
ConstValue::U32(_) => TypeKind::Uint(UintTy::U32),
ConstValue::U64(_) => TypeKind::Uint(UintTy::U64),
ConstValue::U128(_) => TypeKind::Uint(UintTy::U8),
ConstValue::F32(_) => TypeKind::Float(FloatTy::F32),
ConstValue::F64(_) => TypeKind::Float(FloatTy::F64),
},
ValueTree::Branch(_) => todo!(),
}
}
}
#[derive(Debug, Clone)]
pub enum RValue {
Use(Operand),

View file

@ -70,7 +70,7 @@ impl BodyBuilder {
id
}
pub fn _add_temp_local(&mut self, ty_kind: TypeKind) -> usize {
pub fn add_temp_local(&mut self, ty_kind: TypeKind) -> usize {
let id = self.body.locals.len();
self.body.locals.push(Local::temp(TypeInfo {
span: None,
@ -79,7 +79,7 @@ impl BodyBuilder {
id
}
pub fn _get_local(&self, name: &str) -> Option<&Local> {
pub fn get_local(&self, name: &str) -> Option<&Local> {
self.body.locals.get(*(self.name_to_local.get(name)?))
}

View file

@ -1,12 +1,12 @@
use std::collections::HashMap;
use ast::ModuleStatement;
use ast::{BinaryOp, ModuleStatement};
use common::{BodyBuilder, BuildCtx};
use edlang_ast as ast;
use edlang_ir as ir;
use ir::{
BasicBlock, Body, DefId, Local, LocalKind, Operand, Place, ProgramBody, Statement,
StatementKind, Terminator, TypeInfo, TypeKind,
BasicBlock, Body, ConstValue, DefId, Local, LocalKind, Operand, Place, ProgramBody, Statement,
StatementKind, SwitchTarget, Terminator, TypeInfo, TypeKind, ValueTree,
};
mod common;
@ -170,7 +170,7 @@ fn lower_statement(builder: &mut BodyBuilder, info: &ast::Statement, ret_type: &
ast::Statement::Assign(info) => lower_assign(builder, info),
ast::Statement::For(_) => todo!(),
ast::Statement::While(_) => todo!(),
ast::Statement::If(_) => todo!(),
ast::Statement::If(info) => lower_if_stmt(builder, info, ret_type),
ast::Statement::Return(info) => lower_return(builder, info, ret_type),
ast::Statement::FnCall(info) => {
lower_fn_call(builder, info);
@ -178,6 +178,76 @@ fn lower_statement(builder: &mut BodyBuilder, info: &ast::Statement, ret_type: &
}
}
fn lower_if_stmt(builder: &mut BodyBuilder, info: &ast::IfStmt, ret_type: &TypeInfo) {
let cond_ty = find_expr_type(builder, &info.condition).expect("coouldnt find cond type");
let condition = lower_expr(builder, &info.condition, Some(&cond_ty));
let local = builder.add_temp_local(TypeKind::Bool);
let place = Place {
local,
projection: vec![].into(),
};
builder.statements.push(Statement {
span: None,
kind: StatementKind::Assign(place.clone(), condition),
});
// keep idx to change terminator
let current_block_idx = builder.body.blocks.len();
let statements = std::mem::take(&mut builder.statements);
builder.body.blocks.push(BasicBlock {
statements: statements.into(),
terminator: Terminator::Unreachable,
});
// keep idx for switch targets
let first_then_block_idx = builder.body.blocks.len();
for stmt in &info.then_block.body {
lower_statement(builder, stmt, ret_type);
}
// keet idx to change terminator
let last_then_block_idx = builder.body.blocks.len();
let statements = std::mem::take(&mut builder.statements);
builder.body.blocks.push(BasicBlock {
statements: statements.into(),
terminator: Terminator::Unreachable,
});
let first_else_block_idx = builder.body.blocks.len();
if let Some(contents) = &info.else_block {
for stmt in &contents.body {
lower_statement(builder, stmt, ret_type);
}
}
let last_else_block_idx = builder.body.blocks.len();
let statements = std::mem::take(&mut builder.statements);
builder.body.blocks.push(BasicBlock {
statements: statements.into(),
terminator: Terminator::Unreachable,
});
let targets = SwitchTarget {
values: vec![TypeKind::Bool.get_falsy_value()],
targets: vec![first_else_block_idx, first_then_block_idx],
};
let kind = Terminator::SwitchInt {
discriminator: Operand::Move(place),
targets,
};
builder.body.blocks[current_block_idx].terminator = kind;
let next_block_idx = builder.body.blocks.len();
builder.body.blocks[last_then_block_idx].terminator = Terminator::Target(next_block_idx);
builder.body.blocks[last_else_block_idx].terminator = Terminator::Target(next_block_idx);
}
fn lower_let(builder: &mut BodyBuilder, info: &ast::LetStmt) {
let ty = lower_type(&builder.ctx, &info.r#type);
let rvalue = lower_expr(builder, &info.value, Some(&ty));
@ -210,6 +280,58 @@ fn lower_assign(builder: &mut BodyBuilder, info: &ast::AssignStmt) {
})
}
fn find_expr_type(builder: &mut BodyBuilder, info: &ast::Expression) -> Option<TypeInfo> {
Some(TypeInfo {
span: None,
kind: match info {
ast::Expression::Value(x) => match x {
ast::ValueExpr::Bool { .. } => TypeKind::Bool,
ast::ValueExpr::Char { .. } => TypeKind::Char,
ast::ValueExpr::Int { .. } => return None,
ast::ValueExpr::Float { .. } => return None,
ast::ValueExpr::Str { .. } => todo!(),
ast::ValueExpr::Path(path) => {
// todo: handle full path
builder.get_local(&path.first.name)?.ty.kind.clone()
}
},
ast::Expression::FnCall(info) => {
let fn_id = {
let mod_body = builder.get_module_body();
if let Some(id) = mod_body.symbols.functions.get(&info.name.name) {
*id
} else {
*mod_body
.imports
.get(&info.name.name)
.expect("function call not found")
}
};
builder
.ctx
.body
.function_signatures
.get(&fn_id)?
.1
.kind
.clone()
}
ast::Expression::Unary(_, info) => find_expr_type(builder, info)?.kind,
ast::Expression::Binary(lhs, op, rhs) => {
if matches!(op, BinaryOp::Logic(_, _)) {
TypeKind::Bool
} else {
find_expr_type(builder, lhs)
.or(find_expr_type(builder, rhs))?
.kind
}
}
},
})
}
fn lower_expr(
builder: &mut BodyBuilder,
info: &ast::Expression,
@ -232,19 +354,27 @@ fn lower_binary_expr(
rhs: &ast::Expression,
type_hint: Option<&TypeInfo>,
) -> ir::RValue {
let expr_type = type_hint.expect("type hint needed");
let lhs = lower_expr(builder, lhs, type_hint);
let rhs = lower_expr(builder, rhs, type_hint);
let (lhs, lhs_ty) = if type_hint.is_none() {
let ty = find_expr_type(builder, lhs);
(lower_expr(builder, lhs, ty.as_ref()), ty)
} else {
(lower_expr(builder, lhs, type_hint), type_hint.cloned())
};
let (rhs, rhs_ty) = if type_hint.is_none() {
let ty = find_expr_type(builder, rhs);
(lower_expr(builder, rhs, ty.as_ref()), ty)
} else {
(lower_expr(builder, rhs, type_hint), type_hint.cloned())
};
let local_ty = expr_type;
let lhs_local = builder.add_local(Local::temp(local_ty.clone()));
let rhs_local = builder.add_local(Local::temp(local_ty.clone()));
let lhs_local = builder.add_local(Local::temp(lhs_ty.unwrap().clone()));
let rhs_local = builder.add_local(Local::temp(rhs_ty.unwrap().clone()));
let lhs_place = Place {
local: lhs_local,
projection: Default::default(),
};
let rhs_place = Place {
local: lhs_local,
local: rhs_local,
projection: Default::default(),
};

View file

@ -1,15 +1,12 @@
mod Main {
pub fn main(argc: i32) -> i32 {
let mut x: i32 = 2;
x = 4;
a();
return x + 2;
}
pub fn main(argc: i64) -> i64 {
let mut a: i64 = 0;
pub fn a() {
let mut x: i32 = 2;
x = 4;
return;
if argc > 2 {
a = 1;
}
return a;
}
}