diff --git a/rust/interpreter/src/lib.rs b/rust/interpreter/src/lib.rs index b8ecdeb..2170ac2 100644 --- a/rust/interpreter/src/lib.rs +++ b/rust/interpreter/src/lib.rs @@ -560,6 +560,24 @@ impl<'a> Interpreter<'a> { Err(anyhow!("Type mismatch for fn:minus")) } } + "fn:mult" => { + if let (Value::Number(a), Value::Number(b)) = (&vals[0], &vals[1]) { + Ok(Value::Number(a * b)) + } else { + Err(anyhow!("Type mismatch for fn:mult")) + } + } + "fn:div" => { + if let (Value::Number(a), Value::Number(b)) = (&vals[0], &vals[1]) { + if *b == 0 { + Err(anyhow!("Division by zero")) + } else { + Ok(Value::Number(a / b)) + } + } else { + Err(anyhow!("Type mismatch for fn:div")) + } + } _ => Err(anyhow!("Unknown function: {fn_name}")), } } @@ -582,3 +600,63 @@ impl<'a> Interpreter<'a> { } } } + +#[cfg(test)] +mod tests { + use super::*; + use mangle_ir::{Inst, Ir, Operand}; + + #[test] + fn test_mult_div() -> Result<()> { + let mut ir = Ir::new(); + let store = MemStore::new(); + let interpreter = Interpreter::new(&ir, Box::new(store)); + let mut env = Env::new(); + + // Test fn:mult(6, 7) + let fn_mult = ir.intern_name("fn:mult"); + let op6 = Operand::Const(Constant::Number(6)); + let op7 = Operand::Const(Constant::Number(7)); + + // We use Expr::Call manually + let expr_mult = Expr::Call { + function: fn_mult, + args: vec![op6, op7], + }; + + if let Value::Number(n) = interpreter.eval_expr(&expr_mult, &env)? { + assert_eq!(n, 42); + } else { + panic!("Expected number"); + } + + // Test fn:div(10, 2) + let fn_div = ir.intern_name("fn:div"); + let op10 = Operand::Const(Constant::Number(10)); + let op2 = Operand::Const(Constant::Number(2)); + + let expr_div = Expr::Call { + function: fn_div, + args: vec![op10, op2], + }; + + if let Value::Number(n) = interpreter.eval_expr(&expr_div, &env)? { + assert_eq!(n, 5); + } else { + panic!("Expected number"); + } + + // Test fn:div(10, 0) -> Error + let op0 = Operand::Const(Constant::Number(0)); + let expr_div_zero = Expr::Call { + function: fn_div, + args: vec![op10, op0], + }; + + let res = interpreter.eval_expr(&expr_div_zero, &env); + assert!(res.is_err()); + assert_eq!(res.unwrap_err().to_string(), "Division by zero"); + + Ok(()) + } +}