From ad353f749d4ab3c6fac34f6ef5d1539e21921ebe Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Sun, 4 Feb 2024 16:21:02 +0100 Subject: [PATCH] basic binop --- lib/edlang_codegen_mlir/src/codegen.rs | 227 ++++++++++++++++++++----- programs/simple.ed | 2 +- 2 files changed, 183 insertions(+), 46 deletions(-) diff --git a/lib/edlang_codegen_mlir/src/codegen.rs b/lib/edlang_codegen_mlir/src/codegen.rs index e53fb4c26..1d1fb19ea 100644 --- a/lib/edlang_codegen_mlir/src/codegen.rs +++ b/lib/edlang_codegen_mlir/src/codegen.rs @@ -10,9 +10,9 @@ use inkwell::{ module::Module, targets::{InitializationConfig, Target, TargetData, TargetMachine, TargetTriple}, types::{AnyType, BasicMetadataTypeEnum, BasicType}, - values::{AnyValue, AnyValueEnum, BasicValue, BasicValueEnum}, + values::{AnyValue, AnyValueEnum, BasicValue, BasicValueEnum, PointerValue}, }; -use ir::ValueTree; +use ir::{TypeInfo, ValueTree}; use tracing::info; #[derive(Debug, Clone, Copy)] @@ -229,47 +229,21 @@ fn compile_fn(ctx: &ModuleCompileCtx, body: &ir::Body) -> Result<(), BuilderErro for stmt in &block.statements { info!("compiling stmt"); match &stmt.kind { - ir::StatementKind::Assign(place, rvalue) => { - match rvalue { - ir::RValue::Use(op) => match op { - ir::Operand::Copy(other_place) => { - // should this just copy the local? - let pointee_ty = - compile_basic_type(ctx, &body.locals[other_place.local].ty); - let value = ctx.builder.build_load( - pointee_ty, - *locals.get(&other_place.local).unwrap(), - "", - )?; - ctx.builder - .build_store(*locals.get(&place.local).unwrap(), value)?; - } - ir::Operand::Move(other_place) => { - let pointee_ty = - compile_basic_type(ctx, &body.locals[other_place.local].ty); - let value = ctx.builder.build_load( - pointee_ty, - *locals.get(&other_place.local).unwrap(), - "", - )?; - ctx.builder - .build_store(*locals.get(&place.local).unwrap(), value)?; - } - ir::Operand::Constant(data) => match &data.kind { - ir::ConstKind::Value(val) => { - let value = compile_value(ctx, val, &data.type_info); - ctx.builder - .build_store(*locals.get(&place.local).unwrap(), value)?; - } - ir::ConstKind::ZeroSized => todo!(), - }, - }, - ir::RValue::Ref(_, _) => todo!(), - ir::RValue::BinOp(_, _, _) => todo!(), - ir::RValue::LogicOp(_, _, _) => todo!(), - ir::RValue::UnOp(_, _) => todo!(), + ir::StatementKind::Assign(place, rvalue) => match rvalue { + ir::RValue::Use(op) => { + let value = compile_load_operand(ctx, body, &locals, op)?.0; + ctx.builder + .build_store(*locals.get(&place.local).unwrap(), value)?; } - } + ir::RValue::Ref(_, _) => todo!(), + ir::RValue::BinOp(op, lhs, rhs) => { + let value = compile_bin_op(ctx, body, &locals, *op, lhs, rhs)?; + ctx.builder + .build_store(*locals.get(&place.local).unwrap(), value)?; + } + ir::RValue::LogicOp(_, _, _) => todo!(), + ir::RValue::UnOp(_, _) => todo!(), + }, ir::StatementKind::StorageLive(_) => { // https://llvm.org/docs/LangRef.html#int-lifestart } @@ -307,13 +281,176 @@ fn compile_fn(ctx: &ModuleCompileCtx, body: &ir::Body) -> Result<(), BuilderErro Ok(()) } +fn compile_bin_op<'ctx>( + ctx: &ModuleCompileCtx<'ctx, '_>, + body: &ir::Body, + locals: &HashMap>, + op: ir::BinOp, + lhs: &ir::Operand, + rhs: &ir::Operand, +) -> Result, BuilderError> { + let (lhs_value, lhs_ty) = compile_load_operand(ctx, body, locals, lhs)?; + let (rhs_value, _rhs_ty) = compile_load_operand(ctx, body, locals, rhs)?; + + let is_float = matches!(lhs_ty.kind, ir::TypeKind::Float(_)); + let is_signed = matches!(lhs_ty.kind, ir::TypeKind::Int(_)); + + Ok(match op { + ir::BinOp::Add => { + if is_float { + ctx.builder + .build_float_add( + lhs_value.into_float_value(), + rhs_value.into_float_value(), + "", + )? + .as_basic_value_enum() + } else { + ctx.builder + .build_int_add(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? + .as_basic_value_enum() + } + } + ir::BinOp::Sub => { + if is_float { + ctx.builder + .build_float_sub( + lhs_value.into_float_value(), + rhs_value.into_float_value(), + "", + )? + .as_basic_value_enum() + } else { + ctx.builder + .build_int_sub(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? + .as_basic_value_enum() + } + } + ir::BinOp::Mul => { + if is_float { + ctx.builder + .build_float_mul( + lhs_value.into_float_value(), + rhs_value.into_float_value(), + "", + )? + .as_basic_value_enum() + } else { + ctx.builder + .build_int_add(lhs_value.into_int_value(), rhs_value.into_int_value(), "")? + .as_basic_value_enum() + } + } + ir::BinOp::Div => { + if is_float { + ctx.builder + .build_float_div( + lhs_value.into_float_value(), + rhs_value.into_float_value(), + "", + )? + .as_basic_value_enum() + } else if is_signed { + ctx.builder + .build_int_signed_div( + lhs_value.into_int_value(), + rhs_value.into_int_value(), + "", + )? + .as_basic_value_enum() + } else { + ctx.builder + .build_int_unsigned_div( + lhs_value.into_int_value(), + rhs_value.into_int_value(), + "", + )? + .as_basic_value_enum() + } + } + ir::BinOp::Rem => { + if is_float { + ctx.builder + .build_float_rem( + lhs_value.into_float_value(), + rhs_value.into_float_value(), + "", + )? + .as_basic_value_enum() + } else if is_signed { + ctx.builder + .build_int_signed_rem( + lhs_value.into_int_value(), + rhs_value.into_int_value(), + "", + )? + .as_basic_value_enum() + } else { + ctx.builder + .build_int_unsigned_rem( + lhs_value.into_int_value(), + rhs_value.into_int_value(), + "", + )? + .as_basic_value_enum() + } + } + ir::BinOp::BitXor => todo!(), + ir::BinOp::BitAnd => todo!(), + ir::BinOp::BitOr => todo!(), + ir::BinOp::Shl => todo!(), + ir::BinOp::Shr => todo!(), + ir::BinOp::Eq => todo!(), + ir::BinOp::Lt => todo!(), + ir::BinOp::Le => todo!(), + ir::BinOp::Ne => todo!(), + ir::BinOp::Ge => todo!(), + ir::BinOp::Gt => todo!(), + ir::BinOp::Offset => todo!(), + }) +} + +fn compile_load_operand<'ctx>( + ctx: &ModuleCompileCtx<'ctx, '_>, + body: &ir::Body, + locals: &HashMap>, + op: &ir::Operand, +) -> Result<(BasicValueEnum<'ctx>, TypeInfo), BuilderError> { + // todo: implement projection + Ok(match op { + ir::Operand::Copy(place) => { + let pointee_ty = compile_basic_type(ctx, &body.locals[place.local].ty); + let ptr = *locals.get(&place.local).unwrap(); + ( + ctx.builder.build_load(pointee_ty, ptr, "")?, + body.locals[place.local].ty.clone(), + ) + } + ir::Operand::Move(place) => { + let pointee_ty = compile_basic_type(ctx, &body.locals[place.local].ty); + let ptr = *locals.get(&place.local).unwrap(); + ( + ctx.builder.build_load(pointee_ty, ptr, "")?, + body.locals[place.local].ty.clone(), + ) + } + ir::Operand::Constant(data) => match &data.kind { + ir::ConstKind::Value(value) => ( + compile_value(ctx, value, &data.type_info)?, + data.type_info.clone(), + ), + ir::ConstKind::ZeroSized => todo!(), + }, + }) +} + fn compile_value<'ctx>( ctx: &ModuleCompileCtx<'ctx, '_>, val: &ValueTree, ty: &ir::TypeInfo, -) -> BasicValueEnum<'ctx> { +) -> Result, BuilderError> { let ty = compile_basic_type(ctx, ty); - match val { + Ok(match val { ValueTree::Leaf(const_val) => match const_val { ir::ConstValue::Bool(x) => ty .into_int_type() @@ -367,7 +504,7 @@ fn compile_value<'ctx>( ir::ConstValue::F64(x) => ty.into_float_type().const_float(*x).as_basic_value_enum(), }, ValueTree::Branch(_) => todo!(), - } + }) } fn compile_type<'a>( diff --git a/programs/simple.ed b/programs/simple.ed index 93fddc209..bd874366f 100644 --- a/programs/simple.ed +++ b/programs/simple.ed @@ -3,7 +3,7 @@ mod Main { pub fn main(argc: i32) -> i32 { let mut x: i32 = 2; x = 4; - return x; + return x + 2; } pub fn a() {