From 2c2268925a39e47784f9e0c2062738797fceb753 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Sun, 3 Dec 2023 17:32:37 -0800 Subject: [PATCH] todo: arbitrary ir --- src/backend/eval.rs | 10 ++--- src/backend/into_crane.rs | 39 +++++-------------- src/eval.rs | 20 +++++----- src/eval/primop.rs | 48 ++++++++++++++++++----- src/eval/primtype.rs | 13 ++++--- src/eval/value.rs | 80 +++++++++++++++++++++++++------------- src/ir.rs | 1 + src/ir/arbitrary.rs | 21 ++++++++++ src/ir/ast.rs | 31 +++++---------- src/ir/eval.rs | 73 +++++++++++++++++++++++++--------- src/repl.rs | 1 - src/syntax/eval.rs | 61 +++++++++++++++++++---------- src/type_infer.rs | 13 +++++-- src/type_infer/convert.rs | 11 +++++- src/type_infer/finalize.rs | 18 ++++----- src/util/scoped_map.rs | 21 ++++++++++ 16 files changed, 298 insertions(+), 163 deletions(-) create mode 100644 src/ir/arbitrary.rs diff --git a/src/backend/eval.rs b/src/backend/eval.rs index 81129eb..de7a1fa 100644 --- a/src/backend/eval.rs +++ b/src/backend/eval.rs @@ -26,7 +26,7 @@ impl Backend { /// library do. So, if you're validating equivalence between them, you'll want to weed /// out examples that overflow/underflow before checking equivalence. (This is the behavior /// of the built-in test systems.) - pub fn eval(program: Program) -> Result { + pub fn eval(program: Program) -> Result>> { let mut jitter = Backend::jit(Some(String::new()))?; let mut function_map = HashMap::new(); let mut main_function_body = vec![]; @@ -80,7 +80,7 @@ impl Backend { /// library do. So, if you're validating equivalence between them, you'll want to weed /// out examples that overflow/underflow before checking equivalence. (This is the behavior /// of the built-in test systems.) - pub fn eval(program: Program) -> Result { + pub fn eval(program: Program) -> Result>> { //use pretty::{Arena, Pretty}; //let allocator = Arena::<()>::new(); //program.pretty(&allocator).render(80, &mut std::io::stdout())?; @@ -147,7 +147,7 @@ impl Backend { /// This function assumes that this compilation and linking should run without any /// output, so changes to the RTS should make 100% sure that they do not generate /// any compiler warnings. - fn link(object_file: &Path, executable_path: &Path) -> Result<(), EvalError> { + fn link(object_file: &Path, executable_path: &Path) -> Result<(), EvalError>> { use std::path::PathBuf; let output = std::process::Command::new("clang") @@ -179,7 +179,7 @@ proptest::proptest! { fn static_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) { use crate::eval::PrimOpError; - let basic_result = program.eval(); + let basic_result = program.eval().map(|(_,x)| x); // windows `printf` is going to terminate lines with "\r\n", so we need to adjust // our test result here. @@ -219,7 +219,7 @@ proptest::proptest! { // .expect("rendering works"); - let basic_result = program.eval(); + let basic_result = program.eval().map(|(_,x)| x); if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) { let compiled_result = Backend::::eval(program); diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index 0f77339..d897f7c 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use crate::eval::PrimitiveType; use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable}; use crate::syntax::{ConstantType, Location}; @@ -16,14 +14,6 @@ use internment::ArcIntern; use crate::backend::error::BackendError; use crate::backend::Backend; -/// When we're compiling, we might need to reference some of the strings built into -/// the source code; to do so, we need a `GlobalValue`. Perhaps unexpectedly, given -/// the name, `GlobalValue`s are specific to a single function we're compiling, so -/// we end up computing this table for every function. -/// -/// This just a handy type alias to avoid a lot of confusion in the functions. -type StringTable = HashMap, GlobalValue>; - /// When we're talking about variables, it's handy to just have a table that points /// from a variable to "what to do if you want to reference this variable", which is /// agnostic about whether the variable is local, global, an argument, etc. Since @@ -36,7 +26,9 @@ struct ReferenceBuilder { impl ReferenceBuilder { fn refer_to(&self, builder: &mut FunctionBuilder) -> (entities::Value, ConstantType) { - let value = builder.ins().symbol_value(self.cranelift_type, self.local_data); + let value = builder + .ins() + .symbol_value(self.cranelift_type, self.local_data); (value, self.ir_type) } } @@ -83,7 +75,7 @@ impl Backend { for item in program.items { match item { TopLevel::Function(name, args, rettype, body) => { - self.compile_function(name.as_str(), &args, rettype, body); + self.compile_function(name.as_str(), &args, rettype, body)?; } TopLevel::Statement(stmt) => { @@ -139,18 +131,6 @@ impl Backend { let user_func_name = UserFuncName::user(0, func_id.as_u32()); ctx.func = Function::with_name_signature(user_func_name, basic_signature); - // In the future, we might want to see what runtime functions the function - // we were given uses, and then only include those functions that we care - // about. Presumably, we'd use some sort of lookup table like we do for - // strings. But for now, we only have one runtime function, and we're pretty - // sure we're always going to use it, so we just declare it (and reference - // it) directly. - let print_func_ref = self.runtime_functions.include_runtime_function( - "print", - &mut self.module, - &mut ctx.func, - )?; - // Let's start creating the variable table we'll use when we're dereferencing // them later. This table is a little interesting because instead of pointing // from data to data, we're going to point from data (the variable) to an @@ -166,7 +146,11 @@ impl Backend { let cranelift_type = ir::Type::from(*ty); variables.insert( name.clone(), - ReferenceBuilder { cranelift_type, local_data, ir_type: *ty }, + ReferenceBuilder { + cranelift_type, + local_data, + ir_type: *ty, + }, ); } @@ -176,9 +160,6 @@ impl Backend { // to win. variables.new_scope(); - // FIXME: Add arguments - let mut next_var_num = 1; - // Finally (!), we generate the function builder that we're going to use to // make this function! let mut fctx = FunctionBuilderContext::new(); @@ -326,7 +307,7 @@ impl Backend { for inner in exprs { // we can ignore all of these return values and such, because we // don't actually use them anywhere - self.compile_expression(inner, variables, builder); + self.compile_expression(inner, variables, builder)?; } // instead, we just return the last one self.compile_expression(last, variables, builder) diff --git a/src/eval.rs b/src/eval.rs index b45b4d6..f0fee04 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -33,13 +33,12 @@ //! because the implementation of some parts of these primitives is really //! awful to look at. //! -mod env; +//mod env; mod primop; mod primtype; mod value; use cranelift_module::ModuleError; -pub use env::{EvalEnvironment, LookupError}; pub use primop::PrimOpError; pub use primtype::PrimitiveType; pub use value::Value; @@ -56,11 +55,9 @@ use self::primtype::UnknownPrimType; /// of converting those errors to strings and then seeing if they're the /// same. #[derive(Debug, thiserror::Error)] -pub enum EvalError { +pub enum EvalError { #[error(transparent)] - Lookup(#[from] LookupError), - #[error(transparent)] - PrimOp(#[from] PrimOpError), + PrimOp(#[from] PrimOpError), #[error(transparent)] Backend(#[from] BackendError), #[error("IO error: {0}")] @@ -77,16 +74,17 @@ pub enum EvalError { CastToFunction(String), #[error(transparent)] UnknownPrimType(#[from] UnknownPrimType), + #[error("Variable lookup failed for {1} at {0:?}")] + LookupFailed(crate::syntax::Location, String), } -impl PartialEq for EvalError { - fn eq(&self, other: &Self) -> bool { +impl PartialEq> for EvalError { + fn eq(&self, other: &EvalError) -> bool { match self { - EvalError::Lookup(a) => match other { - EvalError::Lookup(b) => a == b, + EvalError::LookupFailed(a, b) => match other { + EvalError::LookupFailed(x, y) => a == x && b == y, _ => false, }, - EvalError::PrimOp(a) => match other { EvalError::PrimOp(b) => a == b, _ => false, diff --git a/src/eval/primop.rs b/src/eval/primop.rs index 41634a3..c4fcd25 100644 --- a/src/eval/primop.rs +++ b/src/eval/primop.rs @@ -4,19 +4,19 @@ use crate::eval::value::Value; use super::primtype::{UnknownPrimType, ValuePrimitiveTypeError}; /// Errors that can occur running primitive operations in the evaluators. -#[derive(Clone, Debug, PartialEq, thiserror::Error)] -pub enum PrimOpError { +#[derive(Clone, Debug, thiserror::Error)] +pub enum PrimOpError { #[error("Math error (underflow or overflow) computing {0} operator")] MathFailure(&'static str), /// This particular variant covers the case in which a primitive /// operator takes two arguments that are supposed to be the same, /// but they differ. (So, like, all the math operators.) #[error("Type mismatch ({1} vs {2}) computing {0} operator")] - TypeMismatch(String, Value, Value), + TypeMismatch(String, Value, Value), /// This variant covers when an operator must take a particular /// type, but the user has provided a different one. #[error("Bad type for operator {0}: {1}")] - BadTypeFor(String, Value), + BadTypeFor(String, Value), /// Probably obvious from the name, but just to be very clear: this /// happens when you pass three arguments to a two argument operator, /// etc. Technically that's a type error of some sort, but we split @@ -36,6 +36,29 @@ pub enum PrimOpError { ValuePrimitiveTypeError(#[from] ValuePrimitiveTypeError), } +impl PartialEq> for PrimOpError { + fn eq(&self, other: &PrimOpError) -> bool { + match (self, other) { + (PrimOpError::MathFailure(a), PrimOpError::MathFailure(b)) => a == b, + (PrimOpError::TypeMismatch(a, b, c), PrimOpError::TypeMismatch(x, y, z)) => { + a == x && b.strip() == y.strip() && c.strip() == z.strip() + } + (PrimOpError::BadTypeFor(a, b), PrimOpError::BadTypeFor(x, y)) => a == x && b.strip() == y.strip(), + (PrimOpError::BadArgCount(a, b), PrimOpError::BadArgCount(x, y)) => a == x && b == y, + (PrimOpError::UnknownPrimOp(a), PrimOpError::UnknownPrimOp(x)) => a == x, + ( + PrimOpError::UnsafeCast { from: a, to: b }, + PrimOpError::UnsafeCast { from: x, to: y }, + ) => a == x && b == y, + (PrimOpError::UnknownPrimType(a), PrimOpError::UnknownPrimType(x)) => a == x, + (PrimOpError::ValuePrimitiveTypeError(a), PrimOpError::ValuePrimitiveTypeError(x)) => { + a == x + } + _ => false, + } + } +} + // Implementing primitives in an interpreter like this is *super* tedious, // and the only way to make it even somewhat manageable is to use macros. // This particular macro works for binary operations, and assumes that @@ -59,8 +82,8 @@ macro_rules! run_op { }; } -impl Value { - fn unary_op(operation: &str, value: &Value) -> Result { +impl Value { + fn unary_op(operation: &str, value: &Value) -> Result, PrimOpError> { match operation { "-" => match value { Value::I8(x) => Ok(Value::I8(x.wrapping_neg())), @@ -73,7 +96,11 @@ impl Value { } } - fn binary_op(operation: &str, left: &Value, right: &Value) -> Result { + fn binary_op( + operation: &str, + left: &Value, + right: &Value, + ) -> Result, PrimOpError> { match left { Value::I8(x) => match right { Value::I8(y) => run_op!(operation, x, *y), @@ -139,7 +166,7 @@ impl Value { right.clone(), )), }, - Value::Function(_, _) => { + Value::Closure(_, _, _, _) | Value::Void => { Err(PrimOpError::BadTypeFor(operation.to_string(), left.clone())) } } @@ -153,7 +180,10 @@ impl Value { /// implementation catches and raises an error on overflow or underflow, so /// its worth being careful to make sure that your inputs won't cause either /// condition. - pub fn calculate(operation: &str, values: Vec) -> Result { + pub fn calculate( + operation: &str, + values: Vec>, + ) -> Result, PrimOpError> { match values.len() { 1 => Value::unary_op(operation, &values[0]), 2 => Value::binary_op(operation, &values[0], &values[1]), diff --git a/src/eval/primtype.rs b/src/eval/primtype.rs index 3c5818e..0876310 100644 --- a/src/eval/primtype.rs +++ b/src/eval/primtype.rs @@ -39,11 +39,12 @@ pub enum ValuePrimitiveTypeError { CannotConvertFunction(Option), } -impl<'a> TryFrom<&'a Value> for PrimitiveType { +impl<'a, IR> TryFrom<&'a Value> for PrimitiveType { type Error = ValuePrimitiveTypeError; - fn try_from(value: &'a Value) -> Result { + fn try_from(value: &'a Value) -> Result { match value { + Value::Void => Ok(PrimitiveType::Void), Value::I8(_) => Ok(PrimitiveType::I8), Value::I16(_) => Ok(PrimitiveType::I16), Value::I32(_) => Ok(PrimitiveType::I32), @@ -52,9 +53,9 @@ impl<'a> TryFrom<&'a Value> for PrimitiveType { Value::U16(_) => Ok(PrimitiveType::U16), Value::U32(_) => Ok(PrimitiveType::U32), Value::U64(_) => Ok(PrimitiveType::U64), - Value::Function(name, _) => { - Err(ValuePrimitiveTypeError::CannotConvertFunction(name.clone())) - } + Value::Closure(name, _, _, _) => Err(ValuePrimitiveTypeError::CannotConvertFunction( + name.as_ref().map(|x| (**x).clone()), + )), } } } @@ -147,7 +148,7 @@ impl PrimitiveType { /// type to the target type. (So, for example, "1i64" is a number that could /// work as a "u64", but since negative numbers wouldn't work, a cast from /// "1i64" to "u64" will fail.) - pub fn safe_cast(&self, source: &Value) -> Result { + pub fn safe_cast(&self, source: &Value) -> Result, PrimOpError> { match (self, source) { (PrimitiveType::U8, Value::U8(x)) => Ok(Value::U8(*x)), (PrimitiveType::U16, Value::U8(x)) => Ok(Value::U16(*x as u16)), diff --git a/src/eval/value.rs b/src/eval/value.rs index 3634a2f..26c020b 100644 --- a/src/eval/value.rs +++ b/src/eval/value.rs @@ -1,6 +1,6 @@ -use super::EvalError; +use crate::util::scoped_map::ScopedMap; +use internment::ArcIntern; use std::fmt; -use std::rc::Rc; /// Values in the interpreter. /// @@ -8,7 +8,8 @@ use std::rc::Rc; /// are almost entirely identical. However, it's nice to have them separated /// by type so that we don't mix them up. #[derive(Clone)] -pub enum Value { +pub enum Value { + Void, I8(i8), I16(i16), I32(i32), @@ -17,14 +18,44 @@ pub enum Value { U16(u16), U32(u32), U64(u64), - Function( - Option, - Rc) -> Result>, + Closure( + Option>, + ScopedMap, Value>, + Vec>, + IR, ), } -fn format_value(value: &Value, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl Value { + /// Given a Value associated with some expression type, just strip out + /// the expressions and replace them with unit. + /// + /// Doing this transformation will likely make this value useless for + /// computation, but is very useful in allowing equivalence checks. + pub fn strip(&self) -> Value<()> { + match self { + Value::Void => Value::Void, + Value::U8(x) => Value::U8(*x), + Value::U16(x) => Value::U16(*x), + Value::U32(x) => Value::U32(*x), + Value::U64(x) => Value::U64(*x), + Value::I8(x) => Value::I8(*x), + Value::I16(x) => Value::I16(*x), + Value::I32(x) => Value::I32(*x), + Value::I64(x) => Value::I64(*x), + Value::Closure(name, env, args, _) => { + let new_env = env + .clone() + .map_values(|x| x.strip()); + Value::Closure(name.clone(), new_env, args.clone(), ()) + } + } + } +} + +fn format_value(value: &Value, f: &mut fmt::Formatter<'_>) -> fmt::Result { match value { + Value::Void => write!(f, ""), Value::I8(x) => write!(f, "{}i8", x), Value::I16(x) => write!(f, "{}i16", x), Value::I32(x) => write!(f, "{}i32", x), @@ -33,26 +64,27 @@ fn format_value(value: &Value, f: &mut fmt::Formatter<'_>) -> fmt::Result { Value::U16(x) => write!(f, "{}u16", x), Value::U32(x) => write!(f, "{}u32", x), Value::U64(x) => write!(f, "{}u64", x), - Value::Function(Some(name), _) => write!(f, "", name), - Value::Function(None, _) => write!(f, ""), + Value::Closure(Some(name), _, _, _) => write!(f, "", name), + Value::Closure(None, _, _, _) => write!(f, ""), } } -impl fmt::Debug for Value { +impl fmt::Debug for Value { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { format_value(self, f) } } -impl fmt::Display for Value { +impl fmt::Display for Value { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { format_value(self, f) } } -impl PartialEq for Value { - fn eq(&self, other: &Self) -> bool { +impl PartialEq> for Value { + fn eq(&self, other: &Value) -> bool { match self { + Value::Void => matches!(other, Value::Void), Value::I8(x) => match other { Value::I8(y) => x == y, _ => false, @@ -85,58 +117,54 @@ impl PartialEq for Value { Value::U64(y) => x == y, _ => false, }, - Value::Function(Some(x), _) => match other { - Value::Function(Some(y), _) => x == y, - _ => false, - }, - Value::Function(None, _) => false, + Value::Closure(_, _, _, _) => false, } } } -impl From for Value { +impl From for Value { fn from(value: i8) -> Self { Value::I8(value) } } -impl From for Value { +impl From for Value { fn from(value: i16) -> Self { Value::I16(value) } } -impl From for Value { +impl From for Value { fn from(value: i32) -> Self { Value::I32(value) } } -impl From for Value { +impl From for Value { fn from(value: i64) -> Self { Value::I64(value) } } -impl From for Value { +impl From for Value { fn from(value: u8) -> Self { Value::U8(value) } } -impl From for Value { +impl From for Value { fn from(value: u16) -> Self { Value::U16(value) } } -impl From for Value { +impl From for Value { fn from(value: u32) -> Self { Value::U32(value) } } -impl From for Value { +impl From for Value { fn from(value: u64) -> Self { Value::U64(value) } diff --git a/src/ir.rs b/src/ir.rs index af38b02..bb4d0bd 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -12,6 +12,7 @@ //! validating syntax, and then figuring out how to turn it into Cranelift //! and object code. After that point, however, this will be the module to //! come to for analysis and optimization work. +mod arbitrary; pub mod ast; mod eval; mod strings; diff --git a/src/ir/arbitrary.rs b/src/ir/arbitrary.rs new file mode 100644 index 0000000..94f925a --- /dev/null +++ b/src/ir/arbitrary.rs @@ -0,0 +1,21 @@ +use crate::ir::{Program, TopLevel, Expression, ValueOrRef, Value, Type}; +use proptest::{ + prelude::Arbitrary, + strategy::{BoxedStrategy, Strategy}, +}; + +impl Arbitrary for Program { + type Parameters = crate::syntax::arbitrary::GenerationEnvironment; + type Strategy = BoxedStrategy; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + unimplemented!() + //crate::syntax::Program::arbitrary_with(args) + // .prop_map(|x| { + // x.type_infer() + // .expect("arbitrary_with should generate type-correct programs") + // }) + // .boxed() + } +} + diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 728a4d0..67c8b12 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -5,10 +5,6 @@ use crate::{ }; use internment::ArcIntern; use pretty::{BoxAllocator, DocAllocator, Pretty}; -use proptest::{ - prelude::Arbitrary, - strategy::{BoxedStrategy, Strategy}, -}; use std::{fmt, str::FromStr, sync::atomic::AtomicUsize}; /// We're going to represent variables as interned strings. @@ -78,21 +74,6 @@ where } } -impl Arbitrary for Program { - type Parameters = crate::syntax::arbitrary::GenerationEnvironment; - type Strategy = BoxedStrategy; - - fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { - unimplemented!() - //crate::syntax::Program::arbitrary_with(args) - // .prop_map(|x| { - // x.type_infer() - // .expect("arbitrary_with should generate type-correct programs") - // }) - // .boxed() - } -} - /// A thing that can sit at the top level of a file. /// /// For the moment, these are statements and functions. Other things @@ -144,7 +125,7 @@ where /// a primitive), any subexpressions have been bound to variables so /// that the referenced data will always either be a constant or a /// variable reference. -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum Expression { Atomic(ValueOrRef), Cast(Location, Type, ValueOrRef), @@ -497,7 +478,7 @@ impl fmt::Display for TypeOrVar { Some((last_one, rest)) => { write!(f, "(")?; for arg in rest.iter() { - write!(f, "{}, ", arg); + write!(f, "{}, ", arg)?; } write!(f, "{})", last_one)?; } @@ -510,6 +491,12 @@ impl fmt::Display for TypeOrVar { } } +impl Default for TypeOrVar { + fn default() -> Self { + TypeOrVar::new() + } +} + impl TypeOrVar { /// Generate a fresh type variable that is different from all previous type variables. /// @@ -532,7 +519,7 @@ impl TypeOrVar { } } -trait TypeWithVoid { +pub trait TypeWithVoid { fn void() -> Self; } diff --git a/src/ir/eval.rs b/src/ir/eval.rs index af3e31e..dd6340d 100644 --- a/src/ir/eval.rs +++ b/src/ir/eval.rs @@ -1,33 +1,55 @@ use super::{Primitive, Type, ValueOrRef}; -use crate::eval::{EvalEnvironment, EvalError, Value}; -use crate::ir::{Expression, Program, TopLevel}; +use crate::eval::{EvalError, Value}; +use crate::ir::{Expression, Program, TopLevel, Variable}; +use crate::util::scoped_map::ScopedMap; -impl Program { - /// Evaluate the program, returning either an error or a string containing everything - /// the program printed out. +type IRValue = Value>; +type IREvalError = EvalError>; + +impl> Program { + /// Evaluate the program, returning either an error or the result of the final + /// statement and the complete contents of the console output. /// /// The print outs will be newline separated, with one print out per line. - pub fn eval(&self) -> Result { - let mut env = EvalEnvironment::empty(); + pub fn eval(&self) -> Result<(IRValue, String), IREvalError> { + let mut env: ScopedMap> = ScopedMap::new(); let mut stdout = String::new(); + let mut last_value = Value::Void; for stmt in self.items.iter() { match stmt { - TopLevel::Function(_, _, _, _) => unimplemented!(), + TopLevel::Function(name, args, _, body) => { + let closure = Value::Closure( + Some(name.clone()), + env.clone(), + args.iter().map(|(x, _)| x.clone()).collect(), + body.clone(), + ); - TopLevel::Statement(_) => unimplemented!(), + env.insert(name.clone(), closure.clone()); + + last_value = closure; + } + + TopLevel::Statement(expr) => { + last_value = expr.eval(&env, &mut stdout)?; + } } } - Ok(stdout) + Ok((last_value, stdout)) } } impl Expression where - T: Clone + Into, + T: Clone + Into, { - fn eval(&self, env: &EvalEnvironment) -> Result { + fn eval( + &self, + env: &ScopedMap>, + stdout: &mut String, + ) -> Result, IREvalError> { match self { Expression::Atomic(x) => x.eval(env), @@ -45,7 +67,7 @@ where let arg_values = args .iter() .map(|x| x.eval(env)) - .collect::, EvalError>>()?; + .collect::>, IREvalError>>()?; // and then finally we call `calculate` to run them. trust me, it's nice // to not have to deal with all the nonsense hidden under `calculate`. @@ -61,15 +83,25 @@ where unimplemented!() } - Expression::Print(_, _) => unimplemented!(), + Expression::Print(loc, n) => { + let value = env + .get(n) + .cloned() + .ok_or_else(|| EvalError::LookupFailed(loc.clone(), n.to_string()))?; + stdout.push_str(&format!("{} = {}\n", n, value)); + Ok(Value::Void) + } Expression::Bind(_, _, _, _) => unimplemented!(), } } } -impl ValueOrRef { - fn eval(&self, env: &EvalEnvironment) -> Result { +impl ValueOrRef { + fn eval( + &self, + env: &ScopedMap>, + ) -> Result, IREvalError> { match self { ValueOrRef::Value(_, _, v) => match v { super::Value::I8(_, v) => Ok(Value::I8(*v)), @@ -82,7 +114,10 @@ impl ValueOrRef { super::Value::U64(_, v) => Ok(Value::U64(*v)), }, - ValueOrRef::Ref(_, _, n) => Ok(env.lookup(n.clone())?), + ValueOrRef::Ref(loc, _, n) => env + .get(n) + .cloned() + .ok_or_else(|| EvalError::LookupFailed(loc.clone(), n.to_string())), } } } @@ -91,7 +126,7 @@ impl ValueOrRef { fn two_plus_three() { let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); let ir = input.type_infer().expect("test should be type-valid"); - let output = ir.eval().expect("runs successfully"); + let (_, output) = ir.eval().expect("runs successfully"); assert_eq!("x = 5u64\n", &output); } @@ -100,6 +135,6 @@ fn lotsa_math() { let input = crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); let ir = input.type_infer().expect("test should be type-valid"); - let output = ir.eval().expect("runs successfully"); + let (_, output) = ir.eval().expect("runs successfully"); assert_eq!("x = 7u64\n", &output); } diff --git a/src/repl.rs b/src/repl.rs index ea4803e..7d90a50 100644 --- a/src/repl.rs +++ b/src/repl.rs @@ -1,5 +1,4 @@ use crate::backend::{Backend, BackendError}; -use crate::eval::PrimitiveType; use crate::syntax::{ConstantType, Location, ParserError, Statement, TopLevel}; use crate::type_infer::TypeInferenceResult; use crate::util::scoped_map::ScopedMap; diff --git a/src/syntax/eval.rs b/src/syntax/eval.rs index 928d182..8025d07 100644 --- a/src/syntax/eval.rs +++ b/src/syntax/eval.rs @@ -1,11 +1,12 @@ +use crate::eval::{EvalError, PrimitiveType, Value}; +use crate::syntax::{ConstantType, Expression, Name, Program, Statement, TopLevel}; +use crate::util::scoped_map::ScopedMap; use internment::ArcIntern; - -use crate::eval::{EvalEnvironment, EvalError, PrimitiveType, Value}; -use crate::syntax::{ConstantType, Expression, Program, Statement, TopLevel}; use std::str::FromStr; impl Program { - /// Evaluate the program, returning either an error or what it prints out when run. + /// Evaluate the program, returning either an error or a pair of the final value + /// produced and the output printed to the console. /// /// Doing this evaluation is particularly useful for testing, to ensure that if we /// modify a program in some way it does the same thing on both sides of the @@ -15,36 +16,51 @@ impl Program { /// Note that the errors here are slightly more strict that we enforce at runtime. /// For example, we check for overflow and underflow errors during evaluation, and /// we don't check for those in the compiled code. - pub fn eval(&self) -> Result { - let mut env = EvalEnvironment::empty(); + pub fn eval(&self) -> Result<(Value, String), EvalError> { + let mut env = ScopedMap::new(); let mut stdout = String::new(); + let mut last_result = Value::Void; for stmt in self.items.iter() { match stmt { - TopLevel::Function(_name, _arg_names, _body) => { - unimplemented!() - } - // at this point, evaluation is pretty simple. just walk through each - // statement, in order, and record printouts as we come to them. - TopLevel::Statement(Statement::Binding(_, name, value)) => { - let actual_value = value.eval(&env)?; - env = env.extend(name.clone().intern(), actual_value); + TopLevel::Function(name, arg_names, body) => { + last_result = Value::Closure( + name.clone().map(Name::intern), + env.clone(), + arg_names.iter().cloned().map(Name::intern).collect(), + body.clone(), + ); + if let Some(name) = name { + env.insert(name.clone().intern(), last_result.clone()); + } } - TopLevel::Statement(Statement::Print(_, name)) => { - let value = env.lookup(name.clone().intern())?; + TopLevel::Statement(Statement::Binding(_, name, value)) => { + let actual_value = value.eval(&env)?; + env.insert(name.clone().intern(), actual_value); + last_result = Value::Void; + } + + TopLevel::Statement(Statement::Print(loc, name)) => { + let value = env + .get(&name.clone().intern()) + .ok_or_else(|| EvalError::LookupFailed(loc.clone(), name.name.clone()))?; let line = format!("{} = {}\n", name, value); stdout.push_str(&line); + last_result = Value::Void; } } } - Ok(stdout) + Ok((last_result, stdout)) } } impl Expression { - fn eval(&self, env: &EvalEnvironment) -> Result { + fn eval( + &self, + env: &ScopedMap, Value>, + ) -> Result, EvalError> { match self { Expression::Value(_, v) => match v { super::Value::Number(_, ty, v) => match ty { @@ -61,7 +77,10 @@ impl Expression { }, }, - Expression::Reference(_, n) => Ok(env.lookup(ArcIntern::new(n.clone()))?), + Expression::Reference(loc, n) => env + .get(&ArcIntern::new(n.clone())) + .ok_or_else(|| EvalError::LookupFailed(loc.clone(), n.clone())) + .cloned(), Expression::Cast(_, target, expr) => { let target_type = PrimitiveType::from_str(target)?; @@ -86,13 +105,13 @@ impl Expression { #[test] fn two_plus_three() { let input = Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); - let output = input.eval().expect("runs successfully"); + let (_, output) = input.eval().expect("runs successfully"); assert_eq!("x = 5u64\n", &output); } #[test] fn lotsa_math() { let input = Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); - let output = input.eval().expect("runs successfully"); + let (_, output) = input.eval().expect("runs successfully"); assert_eq!("x = 7u64\n", &output); } diff --git a/src/type_infer.rs b/src/type_infer.rs index d47c7f4..080ce8d 100644 --- a/src/type_infer.rs +++ b/src/type_infer.rs @@ -42,9 +42,16 @@ impl syntax::Program { proptest::proptest! { #[test] fn translation_maintains_semantics(input in syntax::Program::arbitrary_with(GenerationEnvironment::new(false))) { - let syntax_result = input.eval(); + let syntax_result = input.eval().map(|(x,o)| (x.strip(), o)); let ir = input.type_infer().expect("arbitrary should generate type-safe programs"); - let ir_result = ir.eval(); - proptest::prop_assert!(syntax_result.eq(&ir_result)); + let ir_result = ir.eval().map(|(x,o)| (x.strip(), o)); + match (syntax_result, ir_result) { + (Err(e1), Err(e2)) => proptest::prop_assert_eq!(e1, e2), + (Ok((v1, o1)), Ok((v2, o2))) => { + proptest::prop_assert_eq!(v1, v2); + proptest::prop_assert_eq!(o1, o2); + } + _ => proptest::prop_assert!(false), + } } } diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index d67f3d8..5cfa210 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -77,7 +77,11 @@ pub fn convert_top_level( } let (expr, ty) = convert_expression(expr, constraint_db, renames, bindings); - constraint_db.push(Constraint::Equivalent(expr.location().clone(), rettype.clone(), ty)); + constraint_db.push(Constraint::Equivalent( + expr.location().clone(), + rettype.clone(), + ty, + )); ir::TopLevel::Function(funname, function_args, rettype, expr) } @@ -267,7 +271,10 @@ fn convert_expression( (last_call, ret_type) } else { prereqs.push(last_call); - (ir::Expression::Block(loc, ret_type.clone(), prereqs), ret_type) + ( + ir::Expression::Block(loc, ret_type.clone(), prereqs), + ret_type, + ) } } } diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index 577f4d8..59ff329 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -17,14 +17,14 @@ pub fn finalize_program( fn finalize_top_level(item: TopLevel, resolutions: &TypeResolutions) -> TopLevel { match item { - TopLevel::Function(name, args, rettype, expr) => { - TopLevel::Function( - name, - args.into_iter().map(|(name, t)| (name, finalize_type(t, resolutions))).collect(), - finalize_type(rettype, resolutions), - finalize_expression(expr, resolutions) - ) - } + TopLevel::Function(name, args, rettype, expr) => TopLevel::Function( + name, + args.into_iter() + .map(|(name, t)| (name, finalize_type(t, resolutions))) + .collect(), + finalize_type(rettype, resolutions), + finalize_expression(expr, resolutions), + ), TopLevel::Statement(expr) => TopLevel::Statement(finalize_expression(expr, resolutions)), } } @@ -53,7 +53,7 @@ fn finalize_expression( .collect(), ), - Expression::Block(loc, ty, mut exprs) => { + Expression::Block(loc, ty, exprs) => { let mut final_exprs = Vec::with_capacity(exprs.len()); for expr in exprs { diff --git a/src/util/scoped_map.rs b/src/util/scoped_map.rs index 69f442c..34e968b 100644 --- a/src/util/scoped_map.rs +++ b/src/util/scoped_map.rs @@ -1,6 +1,7 @@ use std::{borrow::Borrow, collections::HashMap, hash::Hash}; /// A version of [`std::collections::HashMap`] with a built-in notion of scope. +#[derive(Clone)] pub struct ScopedMap { scopes: Vec>, } @@ -84,4 +85,24 @@ impl ScopedMap { pub fn release_scope(&mut self) -> Option> { self.scopes.pop() } + + /// Create a new scoped set by mapping over the values of this one. + pub fn map_values(self, f: F) -> ScopedMap + where + F: Fn(V) -> W, + { + let mut scopes = Vec::with_capacity(self.scopes.len()); + + for scope in self.scopes { + let mut map = HashMap::with_capacity(scope.len()); + + for (k, v) in scope { + map.insert(k, f(v)); + } + + scopes.push(map); + } + + ScopedMap { scopes } + } }