diff --git a/build.rs b/build.rs index 266ae6b..00c1dba 100644 --- a/build.rs +++ b/build.rs @@ -45,17 +45,20 @@ fn generate_tests(f: &mut File, path_so_far: PathBuf) -> std::io::Result<()> { writeln!(f, " let mut file_database = SimpleFiles::new();")?; writeln!( f, - " let syntax = Syntax::parse_file(&mut file_database, {:?});", + " let syntax = crate::syntax::parse_file(&mut file_database, {:?});", entry.path().display() )?; if entry.path().to_string_lossy().contains("broken") { writeln!(f, " if syntax.is_err() {{")?; writeln!(f, " return;")?; writeln!(f, " }}")?; - writeln!(f, " let (errors, _) = syntax.unwrap().validate();")?; writeln!( f, - " assert_ne!(errors.len(), 0, \"should have seen an error\");" + " let mut validation_result = Syntax::validate(syntax.unwrap());" + )?; + writeln!( + f, + " assert!(validation_result.is_err(), \"should have seen an error\");" )?; } else { // NOTE: Since the advent of defaulting rules and type checking, we @@ -67,10 +70,14 @@ fn generate_tests(f: &mut File, path_so_far: PathBuf) -> std::io::Result<()> { f, " let syntax = syntax.expect(\"file should have parsed\");" )?; - writeln!(f, " let (errors, _) = syntax.validate();")?; + writeln!(f, " let validation_result = Syntax::validate(syntax);")?; writeln!( f, - " assert_eq!(errors.len(), 0, \"file should have no validation errors, but saw: {{:?}}\", errors);" + " assert!(validation_result.is_ok(), \"file should have no validation errors\");" + )?; + writeln!( + f, + " let syntax = validation_result.into_result().unwrap();" )?; writeln!(f, " let syntax_result = syntax.eval();")?; writeln!( diff --git a/src/backend.rs b/src/backend.rs index 2b475e4..4c642ee 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -33,7 +33,8 @@ mod runtime; pub use self::error::BackendError; pub use self::runtime::{RuntimeFunctionError, RuntimeFunctions}; -use crate::syntax::ConstantType; +use crate::ir::Name; +use crate::syntax::{ConstantType, Location}; use cranelift_codegen::entity::EntityRef; use cranelift_codegen::ir::types; use cranelift_codegen::settings::Configurable; @@ -41,7 +42,6 @@ use cranelift_codegen::{isa, settings}; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{default_libcall_names, DataDescription, DataId, FuncId, Linkage, Module}; use cranelift_object::{ObjectBuilder, ObjectModule}; -use internment::ArcIntern; use std::collections::HashMap; use target_lexicon::Triple; @@ -61,8 +61,8 @@ pub struct Backend { data_ctx: DataDescription, runtime_functions: RuntimeFunctions, defined_strings: HashMap, - defined_functions: HashMap, FuncId>, - defined_symbols: HashMap, (DataId, types::Type)>, + defined_functions: HashMap, + defined_symbols: HashMap, output_buffer: Option, platform: Triple, next_variable: usize, @@ -106,7 +106,7 @@ impl Backend { .module .declare_data(&alloc, Linkage::Import, true, false)?; retval.defined_symbols.insert( - ArcIntern::new(alloc), + Name::new(alloc, Location::manufactured()), (id, retval.module.target_config().pointer_type()), ); @@ -158,7 +158,7 @@ impl Backend { .module .declare_data(&alloc, Linkage::Import, true, false)?; retval.defined_symbols.insert( - ArcIntern::new(alloc), + Name::new(alloc, Location::manufactured()), (id, retval.module.target_config().pointer_type()), ); @@ -203,17 +203,16 @@ impl Backend { /// value will be null. pub fn define_variable( &mut self, - name: String, + name: Name, ctype: ConstantType, ) -> Result { self.data_ctx.define(Box::new(EMPTY_DATUM)); let id = self .module - .declare_data(&name, Linkage::Export, true, false)?; + .declare_data(name.current_name(), Linkage::Export, true, false)?; self.module.define_data(id, &self.data_ctx)?; self.data_ctx.clear(); - self.defined_symbols - .insert(ArcIntern::new(name), (id, ctype.into())); + self.defined_symbols.insert(name, (id, ctype.into())); Ok(id) } diff --git a/src/backend/error.rs b/src/backend/error.rs index 0fd0e7e..128dd59 100644 --- a/src/backend/error.rs +++ b/src/backend/error.rs @@ -1,4 +1,5 @@ -use crate::{backend::runtime::RuntimeFunctionError, ir::Type}; +use crate::backend::runtime::RuntimeFunctionError; +use crate::ir::{Name, Type}; use codespan_reporting::diagnostic::Diagnostic; use cranelift_codegen::{isa::LookupError, settings::SetError, CodegenError}; use cranelift_module::ModuleError; @@ -30,7 +31,7 @@ pub enum BackendError { #[error("Builtin function error: {0}")] BuiltinError(#[from] RuntimeFunctionError), #[error("Internal variable lookup error")] - VariableLookupFailure(ArcIntern), + VariableLookupFailure(Name), #[error(transparent)] CodegenError(#[from] CodegenError), #[error(transparent)] diff --git a/src/backend/eval.rs b/src/backend/eval.rs index 4aabe7e..097c34e 100644 --- a/src/backend/eval.rs +++ b/src/backend/eval.rs @@ -1,6 +1,7 @@ use crate::backend::Backend; use crate::eval::EvalError; -use crate::ir::{Expression, Program, Type}; +use crate::ir::{Expression, Name, Program, Type}; +use crate::syntax::Location; use cranelift_jit::JITModule; use cranelift_object::ObjectModule; #[cfg(test)] @@ -24,7 +25,10 @@ impl Backend { /// of the built-in test systems.) pub fn eval(program: Program) -> Result>> { let mut jitter = Backend::jit(Some(String::new()))?; - let function_id = jitter.compile_program("___test_jit_eval___", program)?; + let function_id = jitter.compile_program( + Name::new("___test_jit_eval___", Location::manufactured()), + program, + )?; jitter.module.finalize_definitions()?; let compiled_bytes = jitter.bytes(function_id); let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; @@ -58,7 +62,7 @@ impl Backend { let object_path = my_directory.path().join("object.o"); let executable_path = my_directory.path().join("test_executable"); - backend.compile_program("gogogo", program)?; + backend.compile_program(Name::new("gogogo", Location::manufactured()), program)?; let bytes = backend.bytes()?; std::fs::write(&object_path, bytes)?; Self::link(&object_path, &executable_path)?; diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index 68a7774..58267d0 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -1,7 +1,7 @@ use crate::backend::error::BackendError; use crate::backend::Backend; use crate::eval::PrimitiveType; -use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable}; +use crate::ir::{Expression, Name, Primitive, Program, Type, Value, ValueOrRef}; use crate::syntax::{ConstantType, Location}; use cranelift_codegen::ir::{ self, entities, types, AbiParam, Function, GlobalValue, InstBuilder, MemFlags, Signature, @@ -11,7 +11,6 @@ use cranelift_codegen::isa::CallConv; use cranelift_codegen::Context; use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext}; use cranelift_module::{DataDescription, FuncId, Linkage, Module}; -use internment::ArcIntern; use std::collections::{hash_map, HashMap}; const VOID_REPR_TYPE: types::Type = types::I64; @@ -60,17 +59,16 @@ impl Backend { /// are no such statements, the function will do nothing.) pub fn compile_program( &mut self, - function_name: &str, + function_name: Name, program: Program, ) -> Result { - let mut generated_body = vec![]; let mut variables = HashMap::new(); for (top_level_name, top_level_type) in program.get_top_level_variables() { match top_level_type { Type::Function(argument_types, return_type) => { let func_id = self.declare_function( - top_level_name.as_str(), + top_level_name.current_name(), Linkage::Export, argument_types, *return_type, @@ -80,7 +78,7 @@ impl Backend { Type::Primitive(pt) => { let data_id = self.module.declare_data( - top_level_name.as_str(), + top_level_name.current_name(), Linkage::Export, true, false, @@ -94,7 +92,7 @@ impl Backend { Type::Structure(mut fields) => { let data_id = self.module.declare_data( - top_level_name.as_str(), + top_level_name.current_name(), Linkage::Export, true, false, @@ -109,21 +107,24 @@ impl Backend { } let void = Type::Primitive(PrimitiveType::Void); - let main_func_id = - self.declare_function(function_name, Linkage::Export, vec![], void.clone())?; + let main_func_id = self.declare_function( + function_name.current_name(), + Linkage::Export, + vec![], + void.clone(), + )?; self.defined_functions - .insert(ArcIntern::new(function_name.to_string()), main_func_id); + .insert(function_name.clone(), main_func_id); - for item in program.items { - match item { - TopLevel::Function(name, args, rettype, body) => { - self.compile_function(&mut variables, name.as_str(), &args, rettype, body)?; - } - - TopLevel::Statement(stmt) => { - generated_body.push(stmt); - } - } + for (_, function) in program.functions.into_iter() { + let func_id = self.compile_function( + &mut variables, + function.name.clone(), + &function.arguments, + function.return_type, + function.body, + )?; + self.defined_functions.insert(function.name, func_id); } self.compile_function( @@ -131,7 +132,7 @@ impl Backend { function_name, &[], void.clone(), - Expression::Block(Location::manufactured(), void, generated_body), + program.body, ) } @@ -168,9 +169,9 @@ impl Backend { #[tracing::instrument(level = "debug", skip(self, variables, body))] pub fn compile_function( &mut self, - variables: &mut HashMap, - function_name: &str, - arguments: &[(Variable, Type)], + variables: &mut HashMap, + function_name: Name, + arguments: &[(Name, Type)], return_type: Type, body: Expression, ) -> Result { @@ -195,13 +196,12 @@ impl Backend { // return to the user. For now, we declare all functions defined by this // function as public/global/exported, although we may want to reconsider // this decision later. - let interned_name = ArcIntern::new(function_name.to_string()); - let func_id = match self.defined_functions.entry(interned_name) { + let func_id = match self.defined_functions.entry(function_name.clone()) { hash_map::Entry::Occupied(entry) => *entry.get(), hash_map::Entry::Vacant(vac) => { - tracing::warn!(name = ?function_name, "compiling undeclared function"); + tracing::warn!(name = ?function_name.current_name(), "compiling undeclared function"); let func_id = self.module.declare_function( - function_name, + function_name.current_name(), Linkage::Export, &basic_signature, )?; @@ -277,7 +277,7 @@ impl Backend { fn compile_expression( &mut self, expr: Expression, - variables: &mut HashMap, + variables: &mut HashMap, builder: &mut FunctionBuilder, ) -> Result<(entities::Value, types::Type), BackendError> { match expr { @@ -379,7 +379,8 @@ impl Backend { panic!("Got to backend with non-structure type in structure construction?!"); }; - let global_allocator = ArcIntern::new("__global_allocation_pointer__".to_string()); + let global_allocator = + Name::new("__global_allocation_pointer__", Location::manufactured()); let Some(ReferenceBuilder::Global(_, allocator_variable)) = variables.get(&global_allocator) else { @@ -635,7 +636,7 @@ impl Backend { fn compile_value_or_ref( &self, value_or_ref: ValueOrRef, - variables: &HashMap, + variables: &HashMap, builder: &mut FunctionBuilder, ) -> Result<(entities::Value, types::Type), BackendError> { match value_or_ref { diff --git a/src/bin/ngrun.rs b/src/bin/ngrun.rs index 0cc42ec..97cb4a2 100644 --- a/src/bin/ngrun.rs +++ b/src/bin/ngrun.rs @@ -2,7 +2,7 @@ use clap::Parser; use codespan_reporting::files::SimpleFiles; use ngr::backend::Backend; use ngr::eval::Value; -use ngr::syntax; +use ngr::syntax::{self, Location, Name}; use ngr::type_infer::TypeInferenceResult; use pretty::termcolor::StandardStream; use tracing_subscriber::prelude::*; @@ -36,7 +36,7 @@ fn print_result(result: (Value, String)) { fn jit(ir: ngr::ir::Program) -> Result { let mut backend = Backend::jit(None)?; - let function_id = backend.compile_program("gogogo", ir)?; + let function_id = backend.compile_program(Name::new("gogogo", Location::manufactured()), ir)?; backend.module.finalize_definitions()?; let compiled_bytes = backend.bytes(function_id); Ok(unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }) @@ -51,10 +51,11 @@ fn main() { let mut file_database = SimpleFiles::new(); let mut console = StandardStream::stdout(pretty::termcolor::ColorChoice::Auto); let console_options = codespan_reporting::term::Config::default(); - let syntax = syntax::Program::parse_file(&mut file_database, cli.file.as_ref()); + let syntax = syntax::parse_file(&mut file_database, cli.file.as_ref()); let mut emit = |x| { let _ = codespan_reporting::term::emit(&mut console, &console_options, &file_database, &x); }; + let syntax = match syntax { Ok(x) => x, Err(e) => { @@ -63,18 +64,17 @@ fn main() { } }; - let (errors, warnings) = syntax.validate(); - let stop = !errors.is_empty(); - for error in errors { - emit(error.into()); + let mut validation_result = syntax::Program::validate(syntax); + for item in validation_result.diagnostics() { + emit(item); } - for warning in warnings { - emit(warning.into()); - } - if stop { + if validation_result.is_err() { return; } + let syntax = validation_result + .into_result() + .expect("we already checked this"); if cli.interpreter == Interpreter::Syntax { match syntax.eval() { Err(e) => tracing::error!(error = %e, "Evaluation error"), diff --git a/src/compiler.rs b/src/compiler.rs index 8e8e5ec..2925842 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,5 +1,6 @@ -use crate::syntax::Program as Syntax; -use crate::{backend::Backend, type_infer::TypeInferenceResult}; +use crate::backend::Backend; +use crate::syntax::{Location, Name, Program as Syntax}; +use crate::type_infer::TypeInferenceResult; use codespan_reporting::{ diagnostic::Diagnostic, files::SimpleFiles, @@ -76,29 +77,23 @@ impl Compiler { fn compile_internal(&mut self, input_file: &str) -> Result>, CompilerError> { // Try to parse the file into our syntax AST. If we fail, emit the error // and then immediately return `None`. - let syntax = Syntax::parse_file(&mut self.file_database, input_file)?; + let raw_syntax = crate::syntax::parse_file(&mut self.file_database, input_file)?; // Now validate the user's syntax AST. This can possibly find errors and/or // create warnings. We can continue if we only get warnings, but need to stop // if we get any errors. - let (mut errors, mut warnings) = syntax.validate(); - let stop = !errors.is_empty(); - let messages = errors - .drain(..) - .map(Into::into) - .chain(warnings.drain(..).map(Into::into)); - + let mut validation_result = Syntax::validate(raw_syntax); // emit all the messages we receive; warnings *and* errors - for message in messages { + for message in validation_result.diagnostics() { self.emit(message); } // we got errors, so just stop right now. perhaps oddly, this is Ok(None); // we've already said all we're going to say in the messags above, so there's // no need to provide another `Err` result. - if stop { + let Some(syntax) = validation_result.into_result() else { return Ok(None); - } + }; // Now that we've validated it, let's do type inference, potentially turning // into IR while we're at it. @@ -136,7 +131,7 @@ impl Compiler { // Finally, send all this to Cranelift for conversion into an object file. let mut backend = Backend::object_file(Triple::host())?; let unknown = "".to_string(); - backend.compile_program("gogogo", ir)?; + backend.compile_program(Name::new("gogogo", Location::manufactured()), ir)?; for (_, decl) in backend.module.declarations().get_functions() { tracing::debug!(name = %decl.name.as_ref().unwrap_or(&unknown), linkage = ?decl.linkage, "function definition"); diff --git a/src/eval.rs b/src/eval.rs index daa2b4b..5abf096 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -38,8 +38,8 @@ mod primop; mod primtype; mod value; +use crate::syntax::Name; use cranelift_module::ModuleError; -use internment::ArcIntern; pub use primop::PrimOpError; pub use primtype::PrimitiveType; pub use value::Value; @@ -80,20 +80,11 @@ pub enum EvalError { #[error("Attempted to call something that wasn't a function at {0:?} (it was a {1})")] NotAFunction(crate::syntax::Location, Value), #[error("Wrong argument call for function ({1:?}) at {0:?}; expected {2}, saw {3}")] - WrongArgCount( - crate::syntax::Location, - Option>, - usize, - usize, - ), + WrongArgCount(crate::syntax::Location, Option, usize, usize), #[error("Value has no fields {1} (attempt to get field {2} at {0:?})")] - NoFieldForValue(crate::syntax::Location, Value, ArcIntern), + NoFieldForValue(crate::syntax::Location, Value, Name), #[error("Bad field {2} for structure {1:?} at {0:?}")] - BadFieldForStructure( - crate::syntax::Location, - Option>, - ArcIntern, - ), + BadFieldForStructure(crate::syntax::Location, Option, Name), } impl PartialEq> for EvalError { diff --git a/src/eval/primtype.rs b/src/eval/primtype.rs index c714627..6f48d5e 100644 --- a/src/eval/primtype.rs +++ b/src/eval/primtype.rs @@ -72,10 +72,10 @@ impl<'a, IR> TryFrom<&'a Value> for PrimitiveType { // not sure this is the right call Value::Number(_) => Ok(PrimitiveType::U64), Value::Closure(name, _, _, _) => Err(ValuePrimitiveTypeError::CannotConvertFunction( - name.as_ref().map(|x| (**x).clone()), + name.as_ref().map(|x| x.current_name().to_string()), )), Value::Structure(name, _) => Err(ValuePrimitiveTypeError::CannotConvertStructure( - name.as_ref().map(|x| (**x).clone()), + name.as_ref().map(|x| x.current_name().to_string()), )), Value::Primitive(prim) => Err(ValuePrimitiveTypeError::CannotConvertPrimitive( prim.clone(), diff --git a/src/eval/value.rs b/src/eval/value.rs index 99e8248..8a574bc 100644 --- a/src/eval/value.rs +++ b/src/eval/value.rs @@ -1,5 +1,5 @@ +use crate::syntax::Name; use crate::util::scoped_map::ScopedMap; -use internment::ArcIntern; use std::collections::HashMap; use std::fmt; @@ -21,16 +21,8 @@ pub enum Value { U64(u64), // a number of unknown type Number(u64), - Closure( - Option>, - ScopedMap, Value>, - Vec>, - IR, - ), - Structure( - Option>, - HashMap, Value>, - ), + Closure(Option, ScopedMap>, Vec, IR), + Structure(Option, HashMap>), Primitive(String), } @@ -85,7 +77,7 @@ fn format_value(value: &Value, f: &mut fmt::Formatter<'_>) -> fmt::Resul Value::Closure(None, _, _, _) => write!(f, ""), Value::Structure(on, fields) => { if let Some(n) = on { - write!(f, "{}", n.as_str())?; + write!(f, "{}", n.current_name())?; } write!(f, "{{")?; for (n, v) in fields.iter() { diff --git a/src/ir/arbitrary.rs b/src/ir/arbitrary.rs index 182c43b..ce1aa24 100644 --- a/src/ir/arbitrary.rs +++ b/src/ir/arbitrary.rs @@ -1,6 +1,6 @@ use crate::eval::PrimitiveType; use crate::ir::{ - Expression, Primitive, Program, TopLevel, Type, TypeWithVoid, Value, ValueOrRef, Variable, + Expression, Name, Primitive, Program, TopLevel, Type, TypeWithVoid, Value, ValueOrRef, }; use crate::syntax::Location; use crate::util::scoped_map::ScopedMap; @@ -216,8 +216,9 @@ impl ProgramTree { } let current = Program { - items, + functions: HashMap::new(), type_definitions: HashMap::new(), + body: unimplemented!(), }; ProgramTree { _rng: rng, current } @@ -284,7 +285,7 @@ impl ExpressionTree { fn generate_random_expression( rng: &mut TestRng, - env: &mut ScopedMap, + env: &mut ScopedMap, ) -> Expression { match EXPRESSION_TYPE_FREQUENCIES[EXPRESSION_TYPE_DISTRIBUTION.sample(rng)].0 { ExpressionType::Atomic => Expression::Atomic(generate_random_valueref(rng, env, None)), @@ -401,10 +402,7 @@ fn generate_random_expression( } } -fn generate_random_binding( - rng: &mut TestRng, - env: &mut ScopedMap, -) -> Expression { +fn generate_random_binding(rng: &mut TestRng, env: &mut ScopedMap) -> Expression { let name = generate_random_name(rng); let expr = generate_random_expression(rng, env); let ty = expr.type_of(); @@ -414,7 +412,7 @@ fn generate_random_binding( fn generate_random_valueref( rng: &mut TestRng, - env: &mut ScopedMap, + env: &mut ScopedMap, target_type: Option, ) -> ValueOrRef { let mut bindings = env.bindings(); @@ -458,9 +456,9 @@ fn generate_random_valueref( } } -fn generate_random_name(rng: &mut TestRng) -> Variable { +fn generate_random_name(rng: &mut TestRng) -> Name { let start = rng.gen_range('a'..='z'); - crate::ir::gensym(&format!("{}", start)) + Name::gensym(start) } fn generate_random_argument_type(rng: &mut TestRng) -> Type { diff --git a/src/ir/ast.rs b/src/ir/ast.rs index bb88f4a..8920d69 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -1,37 +1,14 @@ use crate::eval::PrimitiveType; pub use crate::ir::fields::Fields; +pub use crate::syntax::Name; use crate::syntax::{ConstantType, Location}; -use internment::ArcIntern; use proptest::arbitrary::Arbitrary; use std::collections::HashMap; use std::convert::TryFrom; use std::str::FromStr; -use std::sync::atomic::AtomicUsize; use super::arbitrary::ProgramGenerator; -/// We're going to represent variables as interned strings. -/// -/// These should be fast enough for comparison that it's OK, since it's going to end up -/// being pretty much the pointer to the string. -pub type Variable = ArcIntern; - -/// Generate a new symbol that is guaranteed to be different from every other symbol -/// currently known. -/// -/// This function will use the provided string as a base name for the symbol, but -/// extend it with numbers and characters to make it unique. While technically you -/// could roll-over these symbols, you probably don't need to worry about it. -pub fn gensym(base: &str) -> Variable { - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - ArcIntern::new(format!( - "{}<{}>", - base, - COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) - )) -} - /// The representation of a program within our IR. For now, this is exactly one file. /// /// A program consists of a series of statements and functions. The statements should @@ -50,11 +27,20 @@ pub fn gensym(base: &str) -> Variable { /// it's easiest to just make it an argument. #[derive(Clone, Debug)] pub struct Program { - // For now, a program is just a vector of statements. In the future, we'll probably - // extend this to include a bunch of other information, but for now: just a list. - pub items: Vec>, + // The set of functions declared in this program. + pub functions: HashMap>, // The set of types declared in this program. - pub type_definitions: HashMap, Type>, + pub type_definitions: HashMap, + // The thing to evaluate in the end. + pub body: Expression, +} + +#[derive(Clone, Debug)] +pub struct FunctionDefinition { + pub name: Name, + pub arguments: Vec<(Name, Type)>, + pub return_type: Type, + pub body: Expression, } impl Arbitrary for Program { @@ -76,7 +62,7 @@ pub enum TopLevel { Statement(Expression), // FIXME: Is the return type actually necessary, given we can infer it from // the expression type? - Function(Variable, Vec<(Variable, Type)>, Type, Expression), + Function(Name, Vec<(Name, Type)>, Type, Expression), } impl TopLevel { @@ -108,16 +94,11 @@ impl TopLevel { pub enum Expression { Atomic(ValueOrRef), Cast(Location, Type, ValueOrRef), - Construct( - Location, - Type, - ArcIntern, - HashMap, ValueOrRef>, - ), - FieldRef(Location, Type, Type, ValueOrRef, ArcIntern), + Construct(Location, Type, Name, HashMap>), + FieldRef(Location, Type, Type, ValueOrRef, Name), Block(Location, Type, Vec>), Call(Location, Type, ValueOrRef, Vec>), - Bind(Location, Variable, Type, Box>), + Bind(Location, Name, Type, Box>), } impl Expression { @@ -191,7 +172,7 @@ impl FromStr for Primitive { #[derive(Clone, Debug)] pub enum ValueOrRef { Value(Location, Type, Value), - Ref(Location, Type, ArcIntern), + Ref(Location, Type, Name), Primitive(Location, Type, Primitive), } @@ -293,7 +274,7 @@ impl<'a> TryInto for &'a Type { #[derive(Clone, Debug, Eq, PartialEq)] pub enum TypeOrVar { Primitive(PrimitiveType), - Variable(Location, ArcIntern), + Variable(Location, Name), Function(Vec, Box), Structure(Fields), } @@ -322,12 +303,12 @@ impl TypeOrVar { /// of this function could potentially cause overflow problems, but you're going to have /// to try really hard (like, 2^64 times) to make that happen. pub fn new_located(loc: Location) -> Self { - TypeOrVar::Variable(loc, gensym("t")) + TypeOrVar::Variable(loc.clone(), Name::located_gensym(loc, "t")) } /// Try replacing the given type variable with the given type, returning true if anything /// was changed. - pub fn replace(&mut self, name: &ArcIntern, replace_with: &TypeOrVar) -> bool { + pub fn replace(&mut self, name: &Name, replace_with: &TypeOrVar) -> bool { match self { TypeOrVar::Variable(_, var_name) if name == var_name => { *self = replace_with.clone(); @@ -487,7 +468,7 @@ impl TryFrom for Type { #[test] fn struct_sizes_are_rational() { - assert_eq!(8, std::mem::size_of::>()); + assert_eq!(8, std::mem::size_of::()); assert_eq!(24, std::mem::size_of::()); assert_eq!(1, std::mem::size_of::()); assert_eq!(1, std::mem::size_of::()); diff --git a/src/ir/eval.rs b/src/ir/eval.rs index 9a42560..ce89490 100644 --- a/src/ir/eval.rs +++ b/src/ir/eval.rs @@ -1,6 +1,6 @@ -use super::{Type, ValueOrRef}; +use super::{FunctionDefinition, Type, ValueOrRef}; use crate::eval::{EvalError, Value}; -use crate::ir::{Expression, Program, TopLevel, Variable}; +use crate::ir::{Expression, Name, Program}; use crate::util::scoped_map::ScopedMap; use std::collections::HashMap; use std::fmt::Display; @@ -9,7 +9,8 @@ type IRValue = Value>; type IREvalError = EvalError>; pub struct Evaluator { - env: ScopedMap>, + env: ScopedMap>, + functions: HashMap>, stdout: String, } @@ -17,6 +18,7 @@ impl Default for Evaluator { fn default() -> Self { Evaluator { env: ScopedMap::new(), + functions: HashMap::new(), stdout: String::new(), } } @@ -32,30 +34,9 @@ where /// /// The print outs will be newline separated, with one print out per line. pub fn eval(mut self, program: Program) -> Result<(IRValue, String), IREvalError> { - let mut last_value = Value::Void; - - for stmt in program.items.into_iter() { - match stmt { - TopLevel::Function(name, args, _, body) => { - let closure = Value::Closure( - Some(name.clone()), - self.env.clone(), - args.iter().map(|(x, _)| x.clone()).collect(), - body.clone(), - ); - - self.env.insert(name.clone(), closure.clone()); - - last_value = closure; - } - - TopLevel::Statement(expr) => { - last_value = self.eval_expr(expr)?; - } - } - } - - Ok((last_value, self.stdout)) + self.functions.extend(program.functions); + let retval = self.eval_expr(program.body)?; + Ok((retval, self.stdout)) } /// Get the current output of the evaluated program. @@ -203,8 +184,11 @@ where #[test] fn two_plus_three() { - let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); - let ir = input.type_infer().expect("test should be type-valid"); + let input = crate::syntax::parse_string(0, "x = 2 + 3; print x;").expect("parse works"); + let program = crate::syntax::Program::validate(input) + .into_result() + .unwrap(); + let ir = program.type_infer().expect("test should be type-valid"); let evaluator = Evaluator::default(); let (_, result) = evaluator.eval(ir).expect("runs successfully"); assert_eq!("x = 5u64\n", &result); @@ -213,8 +197,11 @@ fn two_plus_three() { #[test] fn lotsa_math() { let input = - crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); - let ir = input.type_infer().expect("test should be type-valid"); + crate::syntax::parse_string(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); + let program = crate::syntax::Program::validate(input) + .into_result() + .unwrap(); + let ir = program.type_infer().expect("test should be type-valid"); let evaluator = Evaluator::default(); let (_, result) = evaluator.eval(ir).expect("runs successfully"); assert_eq!("x = 7u64\n", &result); diff --git a/src/ir/fields.rs b/src/ir/fields.rs index 239f1cf..744056c 100644 --- a/src/ir/fields.rs +++ b/src/ir/fields.rs @@ -1,12 +1,12 @@ +use crate::syntax::Name; use cranelift_module::DataDescription; -use internment::ArcIntern; use std::fmt; #[derive(Clone)] pub struct Fields { ordering: FieldOrdering, total_size: usize, - fields: Vec<(ArcIntern, T)>, + fields: Vec<(Name, T)>, } impl PartialEq for Fields { @@ -48,12 +48,12 @@ impl Fields { self.ordering } - pub fn insert(&mut self, name: ArcIntern, t: T) { + pub fn insert(&mut self, name: Name, t: T) { self.total_size += 8; self.fields.push((name, t)); } - pub fn get(&self, name: &ArcIntern) -> Option<&T> { + pub fn get(&self, name: &Name) -> Option<&T> { for (n, res) in self.fields.iter() { if n == name { return Some(res); @@ -75,11 +75,11 @@ impl Fields { self.fields.len() } - pub fn has_field(&self, name: &ArcIntern) -> bool { + pub fn has_field(&self, name: &Name) -> bool { self.fields.iter().any(|(current, _)| current == name) } - pub fn remove_field(&mut self, name: &ArcIntern) -> Option { + pub fn remove_field(&mut self, name: &Name) -> Option { let mut field_index = None; for (idx, (current, _)) in self.fields.iter().enumerate() { @@ -92,11 +92,11 @@ impl Fields { field_index.map(|i| self.fields.remove(i).1) } - pub fn iter(&self) -> impl Iterator, &T)> { + pub fn iter(&self) -> impl Iterator { self.fields.iter().map(|(x, y)| (x, y)) } - pub fn field_names(&self) -> impl Iterator> { + pub fn field_names(&self) -> impl Iterator { self.fields.iter().map(|(n, _)| n) } @@ -115,7 +115,7 @@ impl Fields { cranelift_description } - pub fn field_type_and_offset(&self, field: &ArcIntern) -> Option<(&T, i32)> { + pub fn field_type_and_offset(&self, field: &Name) -> Option<(&T, i32)> { let mut offset = 0; for (current, ty) in self.fields.iter() { @@ -135,7 +135,7 @@ impl Fields { } impl IntoIterator for Fields { - type Item = (ArcIntern, T); + type Item = (Name, T); type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { diff --git a/src/ir/pretty.rs b/src/ir/pretty.rs index ab55bc3..9c20e74 100644 --- a/src/ir/pretty.rs +++ b/src/ir/pretty.rs @@ -5,12 +5,55 @@ use pretty::{Arena, DocAllocator, DocBuilder}; impl Program { pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> { - allocator - .intersperse( - self.items.iter().map(|x| x.pretty(allocator)), - allocator.line(), - ) - .align() + let mut result = allocator.nil(); + + for (name, ty) in self.type_definitions.iter() { + result = result + .append(allocator.text("type")) + .append(allocator.space()) + .append(allocator.text(name.current_name().to_string())) + .append(allocator.space()) + .append(allocator.text("=")) + .append(allocator.space()) + .append(ty.pretty(allocator)) + .append(allocator.hardline()); + } + + if !self.type_definitions.is_empty() { + result = result.append(allocator.hardline()); + } + + for function in self.functions.values() { + result = result + .append(allocator.text("function")) + .append(allocator.space()) + .append(allocator.text(function.name.current_name().to_string())) + .append(allocator.text("(")) + .append(allocator.intersperse( + function.arguments.iter().map(|(x, t)| { + allocator + .text(x.original_name().to_string()) + .append(allocator.text(":")) + .append(allocator.space()) + .append(t.pretty(allocator)) + }), + allocator.text(","), + )) + .append(allocator.text(")")) + .append(allocator.space()) + .append(allocator.text("->")) + .append(allocator.space()) + .append(function.return_type.pretty(allocator)) + .append(allocator.softline()) + .append(function.body.pretty(allocator)) + .append(allocator.hardline()); + } + + if !self.functions.is_empty() { + result = result.append(allocator.hardline()); + } + + result.append(self.body.pretty(allocator)) } } @@ -23,7 +66,7 @@ impl TopLevel { TopLevel::Function(name, args, _, expr) => allocator .text("function") .append(allocator.space()) - .append(allocator.text(name.as_ref().to_string())) + .append(allocator.text(name.current_name().to_string())) .append(allocator.space()) .append( allocator @@ -104,7 +147,7 @@ impl Expression { } }, Expression::Bind(_, var, ty, expr) => allocator - .text(var.as_ref().to_string()) + .text(var.current_name().to_string()) .append(allocator.space()) .append(allocator.text(":")) .append(allocator.space()) @@ -134,7 +177,7 @@ impl ValueOrRef { pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> { match self { ValueOrRef::Value(_, _, v) => v.pretty(allocator), - ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()), + ValueOrRef::Ref(_, _, v) => allocator.text(v.current_name().to_string()), ValueOrRef::Primitive(_, _, p) => p.pretty(allocator), } } diff --git a/src/ir/strings.rs b/src/ir/strings.rs index 5d6d07f..27540b0 100644 --- a/src/ir/strings.rs +++ b/src/ir/strings.rs @@ -1,4 +1,4 @@ -use super::ast::{Expression, Program, TopLevel}; +use super::ast::{Expression, Program}; use internment::ArcIntern; use std::collections::HashSet; @@ -10,23 +10,14 @@ impl Program { pub fn strings(&self) -> HashSet> { let mut result = HashSet::new(); - for stmt in self.items.iter() { - stmt.register_strings(&mut result); + for function in self.functions.values() { + function.body.register_strings(&mut result); } result } } -impl TopLevel { - fn register_strings(&self, string_set: &mut HashSet>) { - match self { - TopLevel::Function(_, _, _, body) => body.register_strings(string_set), - TopLevel::Statement(stmt) => stmt.register_strings(string_set), - } - } -} - impl Expression { fn register_strings(&self, _string_set: &mut HashSet>) { // nothing has a string in here, at the moment diff --git a/src/ir/top_level.rs b/src/ir/top_level.rs index 728472d..7706112 100644 --- a/src/ir/top_level.rs +++ b/src/ir/top_level.rs @@ -1,38 +1,18 @@ -use crate::ir::{Expression, Program, TopLevel, TypeWithFunction, TypeWithVoid, Variable}; +use crate::ir::{Expression, Name, Program, TypeWithFunction, TypeWithVoid}; use std::collections::HashMap; impl Program { /// Retrieve the complete set of variables that are defined at the top level of /// this program. - pub fn get_top_level_variables(&self) -> HashMap { - let mut result = HashMap::new(); - - for item in self.items.iter() { - result.extend(item.get_top_level_variables()); - } - - result - } -} - -impl TopLevel { - /// Retrieve the complete set of variables that are defined at the top level of - /// this top-level item. - /// - /// For functions, this is the function name. For expressions this can be a little - /// bit more complicated, as it sort of depends on the block structuring. - pub fn get_top_level_variables(&self) -> HashMap { - match self { - TopLevel::Function(name, _, _, _) => HashMap::from([(name.clone(), self.type_of())]), - TopLevel::Statement(expr) => expr.get_top_level_variables(), - } + pub fn get_top_level_variables(&self) -> HashMap { + self.body.get_top_level_variables() } } impl Expression { /// Retrieve the complete set of variables that are defined at the top level of /// this expression. Basically, returns the variable named in bind. - pub fn get_top_level_variables(&self) -> HashMap { + pub fn get_top_level_variables(&self) -> HashMap { match self { Expression::Bind(_, name, ty, expr) => { let mut tlvs = expr.get_top_level_variables(); diff --git a/src/lambda_lift.rs b/src/lambda_lift.rs deleted file mode 100644 index 339b1c0..0000000 --- a/src/lambda_lift.rs +++ /dev/null @@ -1 +0,0 @@ -mod free_variables; diff --git a/src/lib.rs b/src/lib.rs index fcb21f6..23f924c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,7 +66,6 @@ pub mod eval; #[cfg(test)] mod examples; pub mod ir; -pub mod lambda_lift; pub mod syntax; pub mod type_infer; pub mod util; diff --git a/src/repl.rs b/src/repl.rs index 473dfa4..5daacd1 100644 --- a/src/repl.rs +++ b/src/repl.rs @@ -1,5 +1,5 @@ use crate::backend::{Backend, BackendError}; -use crate::syntax::{ConstantType, Expression, Location, ParserError, TopLevel}; +use crate::syntax::{ConstantType, Expression, Location, Name, ParserError, Program, TopLevel}; use crate::type_infer::TypeInferenceResult; use crate::util::scoped_map::ScopedMap; use codespan_reporting::diagnostic::Diagnostic; @@ -128,86 +128,76 @@ impl REPL { .expect("entry exists") .source(); let syntax = TopLevel::parse(entry, source)?; - - let program = match syntax { + let top_levels = match syntax { TopLevel::Expression(Expression::Binding(loc, name, expr)) => { // if this is a variable binding, and we've never defined this variable before, // we should tell cranelift about it. this is optimistic; if we fail to compile, // then we won't use this definition until someone tries again. - if !self.variable_binding_sites.contains_key(&name.current_name().to_string()) { + if !self + .variable_binding_sites + .contains_key(&name.current_name().to_string()) + { self.jitter.define_string(name.current_name())?; self.jitter - .define_variable(name.to_string(), ConstantType::U64)?; + .define_variable(name.clone(), ConstantType::U64)?; } - crate::syntax::Program { - items: vec![ - TopLevel::Expression(Expression::Binding(loc.clone(), name.clone(), expr)), - TopLevel::Expression(Expression::Call( + vec![ + TopLevel::Expression(Expression::Binding(loc.clone(), name.clone(), expr)), + TopLevel::Expression(Expression::Call( + loc.clone(), + Box::new(Expression::Primitive( loc.clone(), - Box::new(Expression::Primitive( - loc.clone(), - crate::syntax::Name::manufactured("print"), - )), - vec![Expression::Reference(name.clone())], + crate::syntax::Name::manufactured("print"), )), - ], - } + vec![Expression::Reference(name.clone())], + )), + ] } - x => crate::syntax::Program { items: vec![x] }, + x => vec![x], }; - - let (mut errors, mut warnings) = - program.validate_with_bindings(&mut self.variable_binding_sites); - let stop = !errors.is_empty(); - let messages = errors - .drain(..) - .map(Into::into) - .chain(warnings.drain(..).map(Into::into)); - - for message in messages { + let mut validation_result = + Program::validate_with_bindings(top_levels, &mut self.variable_binding_sites); + for message in validation_result.diagnostics() { self.emit_diagnostic(message)?; } - if stop { - return Ok(()); - } + if let Some(program) = validation_result.into_result() { + match program.type_infer() { + TypeInferenceResult::Failure { + mut errors, + mut warnings, + } => { + let messages = errors + .drain(..) + .map(Into::into) + .chain(warnings.drain(..).map(Into::into)); - match program.type_infer() { - TypeInferenceResult::Failure { - mut errors, - mut warnings, - } => { - let messages = errors - .drain(..) - .map(Into::into) - .chain(warnings.drain(..).map(Into::into)); - - for message in messages { - self.emit_diagnostic(message)?; + for message in messages { + self.emit_diagnostic(message)?; + } } - Ok(()) - } + TypeInferenceResult::Success { + result, + mut warnings, + } => { + for message in warnings.drain(..).map(Into::into) { + self.emit_diagnostic(message)?; + } - TypeInferenceResult::Success { - result, - mut warnings, - } => { - for message in warnings.drain(..).map(Into::into) { - self.emit_diagnostic(message)?; + let name = Name::new(format!("line{}", line_no), Location::manufactured()); + let function_id = self.jitter.compile_program(name, result)?; + self.jitter.module.finalize_definitions()?; + let compiled_bytes = self.jitter.bytes(function_id); + let compiled_function = + unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; + compiled_function(); } - - let name = format!("line{}", line_no); - let function_id = self.jitter.compile_program(&name, result)?; - self.jitter.module.finalize_definitions()?; - let compiled_bytes = self.jitter.bytes(function_id); - let compiled_function = - unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; - compiled_function(); - Ok(()) } } + + Ok(()) } } diff --git a/src/syntax.rs b/src/syntax.rs index 7e54c31..15ee23a 100644 --- a/src/syntax.rs +++ b/src/syntax.rs @@ -30,8 +30,10 @@ use logos::Logos; pub mod arbitrary; mod ast; pub mod eval; +mod free_variables; mod location; mod name; +mod replace_references; mod tokens; lalrpop_mod!( #[allow(clippy::just_underscores_and_digits, clippy::clone_on_copy)] @@ -41,19 +43,13 @@ lalrpop_mod!( pub mod pretty; mod validate; -#[cfg(test)] -use crate::syntax::arbitrary::GenerationEnvironment; pub use crate::syntax::ast::*; pub use crate::syntax::location::Location; pub use crate::syntax::name::Name; -pub use crate::syntax::parser::{ProgramParser, TopLevelParser, ExpressionParser}; +pub use crate::syntax::parser::{ExpressionParser, ProgramParser, TopLevelParser}; pub use crate::syntax::tokens::{LexerError, Token}; use lalrpop_util::ParseError; -#[cfg(test)] -use proptest::{arbitrary::Arbitrary, prop_assert}; use std::ops::Range; -#[cfg(test)] -use std::str::FromStr; use thiserror::Error; /// One of the many errors that can occur when processing text input. @@ -206,38 +202,36 @@ impl<'a> From<&'a ParserError> for Diagnostic { } } -impl Program { - /// Parse the given file, adding it to the database as part of the process. - /// - /// This operation reads the file from disk and adds it to the database for future - /// reference. If you get an error, we strongly suggest conversion to [`Diagnostic`] - /// and then reporting it to the user via [`codespan_reporting`]. You should use - /// this function if you're pretty sure that you've never seen this file before, - /// and [`Program::parse`] if you have and know its index and already have it in - /// memory. - pub fn parse_file( - file_database: &mut SimpleFiles, - file_name: &str, - ) -> Result { - let file_contents = std::fs::read_to_string(file_name)?; - let file_handle = file_database.add(file_name.to_string(), file_contents); - let file_db_info = file_database.get(file_handle)?; - Program::parse(file_handle, file_db_info.source()) - } +/// Parse the given file, adding it to the database as part of the process. +/// +/// This operation reads the file from disk and adds it to the database for future +/// reference. If you get an error, we strongly suggest conversion to [`Diagnostic`] +/// and then reporting it to the user via [`codespan_reporting`]. You should use +/// this function if you're pretty sure that you've never seen this file before, +/// and [`Program::parse`] if you have and know its index and already have it in +/// memory. +pub fn parse_file( + file_database: &mut SimpleFiles, + file_name: &str, +) -> Result, ParserError> { + let file_contents = std::fs::read_to_string(file_name)?; + let file_handle = file_database.add(file_name.to_string(), file_contents); + let file_db_info = file_database.get(file_handle)?; + parse_string(file_handle, file_db_info.source()) +} - /// Parse a block of text you have in memory, using the given index for [`Location`]s. - /// - /// If you use a nonsensical file index, everything will work fine until you try to - /// report an error, at which point [`codespan_reporting`] may have some nasty things - /// to say to you. - pub fn parse(file_idx: usize, buffer: &str) -> Result { - let lexer = Token::lexer(buffer) - .spanned() - .map(|x| permute_lexer_result(file_idx, x)); - ProgramParser::new() - .parse(file_idx, lexer) - .map_err(|e| ParserError::convert(file_idx, e)) - } +/// Parse a block of text you have in memory, using the given index for [`Location`]s. +/// +/// If you use a nonsensical file index, everything will work fine until you try to +/// report an error, at which point [`codespan_reporting`] may have some nasty things +/// to say to you. +pub fn parse_string(file_idx: usize, buffer: &str) -> Result, ParserError> { + let lexer = Token::lexer(buffer) + .spanned() + .map(|x| permute_lexer_result(file_idx, x)); + ProgramParser::new() + .parse(file_idx, lexer) + .map_err(|e| ParserError::convert(file_idx, e)) } impl TopLevel { @@ -286,70 +280,58 @@ fn permute_lexer_result( } } -#[cfg(test)] -impl FromStr for Program { - type Err = ParserError; - - fn from_str(s: &str) -> Result { - Program::parse(0, s) - } -} - #[test] fn order_of_operations() { let muladd1 = "x = 1 + 2 * 3;"; let testfile = 0; assert_eq!( - Program::from_str(muladd1).unwrap(), - Program { - items: vec![TopLevel::Expression(Expression::Binding( - Location::new(testfile, 0..1), - Name::manufactured("x"), - Box::new(Expression::Call( + parse_string(0, muladd1).unwrap(), + vec![TopLevel::Expression(Expression::Binding( + Location::new(testfile, 0..1), + Name::manufactured("x"), + Box::new(Expression::Call( + Location::new(testfile, 6..7), + Box::new(Expression::Primitive( Location::new(testfile, 6..7), - Box::new(Expression::Primitive( - Location::new(testfile, 6..7), - Name::manufactured("+") - )), - vec![ - Expression::Value( - Location::new(testfile, 4..5), - Value::Number(None, None, 1), - ), - Expression::Call( + Name::manufactured("+") + )), + vec![ + Expression::Value(Location::new(testfile, 4..5), Value::Number(None, None, 1),), + Expression::Call( + Location::new(testfile, 10..11), + Box::new(Expression::Primitive( Location::new(testfile, 10..11), - Box::new(Expression::Primitive( - Location::new(testfile, 10..11), - Name::manufactured("*") - )), - vec![ - Expression::Value( - Location::new(testfile, 8..9), - Value::Number(None, None, 2), - ), - Expression::Value( - Location::new(testfile, 12..13), - Value::Number(None, None, 3), - ), - ] - ) - ] - )) - ))], - } + Name::manufactured("*") + )), + vec![ + Expression::Value( + Location::new(testfile, 8..9), + Value::Number(None, None, 2), + ), + Expression::Value( + Location::new(testfile, 12..13), + Value::Number(None, None, 3), + ), + ] + ) + ] + )) + ))], ); } proptest::proptest! { #[test] - fn random_syntaxes_validate(program: Program) { - let (errors, _) = program.validate(); - prop_assert!(errors.is_empty()); + fn random_syntaxes_validate(program in self::arbitrary::ProgramGenerator::default()) { + let result = Program::validate(program); + proptest::prop_assert!(result.is_ok()); } #[test] - fn generated_run_or_overflow(program in Program::arbitrary_with(GenerationEnvironment::new(false))) { + fn generated_run_or_overflow(program in self::arbitrary::ProgramGenerator::default()) { use crate::eval::{EvalError, PrimOpError}; - prop_assert!(matches!(program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_))))); + let validated = Program::validate(program); + let actual_program = validated.into_result().expect("got a valid result"); + proptest::prop_assert!(matches!(actual_program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_))))); } } diff --git a/src/syntax/arbitrary.rs b/src/syntax/arbitrary.rs index c87d941..1d489a7 100644 --- a/src/syntax/arbitrary.rs +++ b/src/syntax/arbitrary.rs @@ -1,7 +1,9 @@ use crate::syntax::ast::{ConstantType, Expression, Program, TopLevel, Value}; -use crate::syntax::name::Name; use crate::syntax::location::Location; +use crate::syntax::name::Name; use proptest::sample::select; +use proptest::strategy::{NewTree, ValueTree}; +use proptest::test_runner::{TestRng, TestRunner}; use proptest::{ prelude::{Arbitrary, BoxedStrategy, Strategy}, strategy::{Just, Union}, @@ -11,248 +13,296 @@ use std::ops::Range; pub const VALID_VARIABLE_NAMES: &str = r"[a-z][a-zA-Z0-9_]*"; -impl ConstantType { - fn get_operators(&self) -> &'static [(&'static str, usize)] { - match self { - ConstantType::Void => &[], - ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64 => { - &[("+", 2), ("negate", 1), ("-", 2), ("*", 2), ("/", 2)] - } - ConstantType::U8 | ConstantType::U16 | ConstantType::U32 | ConstantType::U64 => { - &[("+", 2), ("-", 2), ("*", 2), ("/", 2)] - } +#[derive(Debug, Default)] +pub struct ProgramGenerator {} + +impl Strategy for ProgramGenerator { + type Tree = ProgramTree; + type Value = Vec; + + fn new_tree(&self, runner: &mut TestRunner) -> NewTree { + unimplemented!() + } +} + +pub struct ProgramTree { + _rng: TestRng, + current: Vec, +} + +impl ProgramTree { + fn new(mut rng: TestRng) -> Self { + ProgramTree { + _rng: rng, + current: vec![], } } } -#[derive(Clone)] -pub struct GenerationEnvironment { - allow_inference: bool, - block_length: Range, - bindings: HashMap, - return_type: ConstantType, -} +impl ValueTree for ProgramTree { + type Value = Vec; -impl Default for GenerationEnvironment { - fn default() -> Self { - GenerationEnvironment { - allow_inference: true, - block_length: 2..10, - bindings: HashMap::new(), - return_type: ConstantType::U64, - } + fn current(&self) -> Self::Value { + self.current.clone() + } + + fn simplify(&mut self) -> bool { + false + } + + fn complicate(&mut self) -> bool { + false } } -impl GenerationEnvironment { - pub fn new(allow_inference: bool) -> Self { - GenerationEnvironment { - allow_inference, - ..Default::default() - } - } -} - -impl Arbitrary for Program { - type Parameters = GenerationEnvironment; - type Strategy = BoxedStrategy; - - fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy { - proptest::collection::vec( - ProgramTopLevelInfo::arbitrary(), - genenv.block_length.clone(), - ) - .prop_flat_map(move |mut ptlis| { - let mut items = Vec::new(); - let mut genenv = genenv.clone(); - - for psi in ptlis.drain(..) { - if genenv.bindings.is_empty() || psi.should_be_binding { - genenv.return_type = psi.binding_type; - let expr = Expression::arbitrary_with(genenv.clone()); - genenv.bindings.insert(psi.name.clone(), psi.binding_type); - items.push( - expr.prop_map(move |expr| { - TopLevel::Expression(Expression::Binding( - Location::manufactured(), - psi.name.clone(), - Box::new(expr), - )) - }) - .boxed(), - ); - } else { - let printers = genenv.bindings.keys().map(|n| { - Just(TopLevel::Expression(Expression::Call( - Location::manufactured(), - Box::new(Expression::Primitive( - Location::manufactured(), - Name::manufactured("print"), - )), - vec![Expression::Reference(n.clone())], - ))) - }); - items.push(Union::new(printers).boxed()); - } - } - - items.prop_map(|items| Program { items }).boxed() - }) - .boxed() - } -} - -impl Arbitrary for Name { - type Parameters = (); - type Strategy = BoxedStrategy; - - fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { - VALID_VARIABLE_NAMES.prop_map(Name::manufactured).boxed() - } -} - -#[derive(Debug)] -struct ProgramTopLevelInfo { - should_be_binding: bool, - name: Name, - binding_type: ConstantType, -} - -impl Arbitrary for ProgramTopLevelInfo { - type Parameters = (); - type Strategy = BoxedStrategy; - - fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { - ( - Union::new(vec![Just(true), Just(true), Just(false)]), - Name::arbitrary(), - ConstantType::arbitrary(), - ) - .prop_map( - |(should_be_binding, name, binding_type)| ProgramTopLevelInfo { - should_be_binding, - name, - binding_type, - }, - ) - .boxed() - } -} - -impl Arbitrary for Expression { - type Parameters = GenerationEnvironment; - type Strategy = BoxedStrategy; - - fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy { - // Value(Location, Value). These are the easiest variations to create, because we can always - // create one. - let value_strategy = Value::arbitrary_with(genenv.clone()) - .prop_map(|x| Expression::Value(Location::manufactured(), x)) - .boxed(); - - // Reference(Location, String), These are slightly trickier, because we can end up in a situation - // where either no variables are defined, or where none of the defined variables have a type we - // can work with. So what we're going to do is combine this one with the previous one as a "leaf - // strategy" -- our non-recursive items -- if we can, or just set that to be the value strategy - // if we can't actually create an references. - let mut bound_variables_of_type = genenv - .bindings - .iter() - .filter(|(_, v)| genenv.return_type == **v) - .map(|(n, _)| n) - .collect::>(); - let leaf_strategy = if bound_variables_of_type.is_empty() { - value_strategy - } else { - let mut strats = bound_variables_of_type - .drain(..) - .map(|x| { Just(Expression::Reference(x.clone())).boxed() - }) - .collect::>(); - strats.push(value_strategy); - Union::new(strats).boxed() - }; - - // now we generate our recursive types, given our leaf strategy - leaf_strategy - .prop_recursive(3, 10, 2, move |strat| { - ( - select(genenv.return_type.get_operators()), - strat.clone(), - strat, - ) - .prop_map(|((oper, count), left, right)| { - let mut args = vec![left, right]; - while args.len() > count { - args.pop(); - } - Expression::Call( - Location::manufactured(), - Box::new(Expression::Primitive( - Location::manufactured(), - Name::manufactured(oper), - )), - args, - ) - }) - }) - .boxed() - } -} - -impl Arbitrary for Value { - type Parameters = GenerationEnvironment; - type Strategy = BoxedStrategy; - - fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy { - let printed_base_strategy = Union::new([ - Just(None::), - Just(Some(2)), - Just(Some(8)), - Just(Some(10)), - Just(Some(16)), - ]); - let value_strategy = u64::arbitrary(); - - (printed_base_strategy, bool::arbitrary(), value_strategy) - .prop_map(move |(base, declare_type, value)| { - let converted_value = match genenv.return_type { - ConstantType::Void => value, - ConstantType::I8 => value % (i8::MAX as u64), - ConstantType::U8 => value % (u8::MAX as u64), - ConstantType::I16 => value % (i16::MAX as u64), - ConstantType::U16 => value % (u16::MAX as u64), - ConstantType::I32 => value % (i32::MAX as u64), - ConstantType::U32 => value % (u32::MAX as u64), - ConstantType::I64 => value % (i64::MAX as u64), - ConstantType::U64 => value, - }; - let ty = if declare_type || !genenv.allow_inference { - Some(genenv.return_type) - } else { - None - }; - Value::Number(base, ty, converted_value) - }) - .boxed() - } -} - -impl Arbitrary for ConstantType { - type Parameters = (); - type Strategy = BoxedStrategy; - - fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { - Union::new([ - Just(ConstantType::I8), - Just(ConstantType::I16), - Just(ConstantType::I32), - Just(ConstantType::I64), - Just(ConstantType::U8), - Just(ConstantType::U16), - Just(ConstantType::U32), - Just(ConstantType::U64), - ]) - .boxed() - } -} +//impl ConstantType { +// fn get_operators(&self) -> &'static [(&'static str, usize)] { +// match self { +// ConstantType::Void => &[], +// ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64 => { +// &[("+", 2), ("negate", 1), ("-", 2), ("*", 2), ("/", 2)] +// } +// ConstantType::U8 | ConstantType::U16 | ConstantType::U32 | ConstantType::U64 => { +// &[("+", 2), ("-", 2), ("*", 2), ("/", 2)] +// } +// } +// } +//} +// +//#[derive(Clone)] +//pub struct GenerationEnvironment { +// allow_inference: bool, +// block_length: Range, +// bindings: HashMap, +// return_type: ConstantType, +//} +// +//impl Default for GenerationEnvironment { +// fn default() -> Self { +// GenerationEnvironment { +// allow_inference: true, +// block_length: 2..10, +// bindings: HashMap::new(), +// return_type: ConstantType::U64, +// } +// } +//} +// +//impl GenerationEnvironment { +// pub fn new(allow_inference: bool) -> Self { +// GenerationEnvironment { +// allow_inference, +// ..Default::default() +// } +// } +//} +// +//impl Arbitrary for Program { +// type Parameters = GenerationEnvironment; +// type Strategy = BoxedStrategy; +// +// fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy { +// proptest::collection::vec( +// ProgramTopLevelInfo::arbitrary(), +// genenv.block_length.clone(), +// ) +// .prop_flat_map(move |mut ptlis| { +// let mut items = Vec::new(); +// let mut genenv = genenv.clone(); +// +// for psi in ptlis.drain(..) { +// if genenv.bindings.is_empty() || psi.should_be_binding { +// genenv.return_type = psi.binding_type; +// let expr = Expression::arbitrary_with(genenv.clone()); +// genenv.bindings.insert(psi.name.clone(), psi.binding_type); +// items.push( +// expr.prop_map(move |expr| { +// TopLevel::Expression(Expression::Binding( +// Location::manufactured(), +// psi.name.clone(), +// Box::new(expr), +// )) +// }) +// .boxed(), +// ); +// } else { +// let printers = genenv.bindings.keys().map(|n| { +// Just(TopLevel::Expression(Expression::Call( +// Location::manufactured(), +// Box::new(Expression::Primitive( +// Location::manufactured(), +// Name::manufactured("print"), +// )), +// vec![Expression::Reference(n.clone())], +// ))) +// }); +// items.push(Union::new(printers).boxed()); +// } +// } +// +// items +// .prop_map(|items| Program { +// functions: HashMap::new(), +// structures: HashMap::new(), +// body: unimplemented!(), +// }) +// .boxed() +// }) +// .boxed() +// } +//} +// +//impl Arbitrary for Name { +// type Parameters = (); +// type Strategy = BoxedStrategy; +// +// fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { +// VALID_VARIABLE_NAMES.prop_map(Name::manufactured).boxed() +// } +//} +// +//#[derive(Debug)] +//struct ProgramTopLevelInfo { +// should_be_binding: bool, +// name: Name, +// binding_type: ConstantType, +//} +// +//impl Arbitrary for ProgramTopLevelInfo { +// type Parameters = (); +// type Strategy = BoxedStrategy; +// +// fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { +// ( +// Union::new(vec![Just(true), Just(true), Just(false)]), +// Name::arbitrary(), +// ConstantType::arbitrary(), +// ) +// .prop_map( +// |(should_be_binding, name, binding_type)| ProgramTopLevelInfo { +// should_be_binding, +// name, +// binding_type, +// }, +// ) +// .boxed() +// } +//} +// +//impl Arbitrary for Expression { +// type Parameters = GenerationEnvironment; +// type Strategy = BoxedStrategy; +// +// fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy { +// // Value(Location, Value). These are the easiest variations to create, because we can always +// // create one. +// let value_strategy = Value::arbitrary_with(genenv.clone()) +// .prop_map(|x| Expression::Value(Location::manufactured(), x)) +// .boxed(); +// +// // Reference(Location, String), These are slightly trickier, because we can end up in a situation +// // where either no variables are defined, or where none of the defined variables have a type we +// // can work with. So what we're going to do is combine this one with the previous one as a "leaf +// // strategy" -- our non-recursive items -- if we can, or just set that to be the value strategy +// // if we can't actually create an references. +// let mut bound_variables_of_type = genenv +// .bindings +// .iter() +// .filter(|(_, v)| genenv.return_type == **v) +// .map(|(n, _)| n) +// .collect::>(); +// let leaf_strategy = if bound_variables_of_type.is_empty() { +// value_strategy +// } else { +// let mut strats = bound_variables_of_type +// .drain(..) +// .map(|x| Just(Expression::Reference(x.clone())).boxed()) +// .collect::>(); +// strats.push(value_strategy); +// Union::new(strats).boxed() +// }; +// +// // now we generate our recursive types, given our leaf strategy +// leaf_strategy +// .prop_recursive(3, 10, 2, move |strat| { +// ( +// select(genenv.return_type.get_operators()), +// strat.clone(), +// strat, +// ) +// .prop_map(|((oper, count), left, right)| { +// let mut args = vec![left, right]; +// while args.len() > count { +// args.pop(); +// } +// Expression::Call( +// Location::manufactured(), +// Box::new(Expression::Primitive( +// Location::manufactured(), +// Name::manufactured(oper), +// )), +// args, +// ) +// }) +// }) +// .boxed() +// } +//} +// +//impl Arbitrary for Value { +// type Parameters = GenerationEnvironment; +// type Strategy = BoxedStrategy; +// +// fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy { +// let printed_base_strategy = Union::new([ +// Just(None::), +// Just(Some(2)), +// Just(Some(8)), +// Just(Some(10)), +// Just(Some(16)), +// ]); +// let value_strategy = u64::arbitrary(); +// +// (printed_base_strategy, bool::arbitrary(), value_strategy) +// .prop_map(move |(base, declare_type, value)| { +// let converted_value = match genenv.return_type { +// ConstantType::Void => value, +// ConstantType::I8 => value % (i8::MAX as u64), +// ConstantType::U8 => value % (u8::MAX as u64), +// ConstantType::I16 => value % (i16::MAX as u64), +// ConstantType::U16 => value % (u16::MAX as u64), +// ConstantType::I32 => value % (i32::MAX as u64), +// ConstantType::U32 => value % (u32::MAX as u64), +// ConstantType::I64 => value % (i64::MAX as u64), +// ConstantType::U64 => value, +// }; +// let ty = if declare_type || !genenv.allow_inference { +// Some(genenv.return_type) +// } else { +// None +// }; +// Value::Number(base, ty, converted_value) +// }) +// .boxed() +// } +//} +// +//impl Arbitrary for ConstantType { +// type Parameters = (); +// type Strategy = BoxedStrategy; +// +// fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { +// Union::new([ +// Just(ConstantType::I8), +// Just(ConstantType::I16), +// Just(ConstantType::I32), +// Just(ConstantType::I64), +// Just(ConstantType::U8), +// Just(ConstantType::U16), +// Just(ConstantType::U32), +// Just(ConstantType::U64), +// ]) +// .boxed() +// } +//} +// diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index c0fa09f..7fb24e9 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -1,6 +1,9 @@ use crate::syntax::name::Name; -use crate::syntax::Location; pub use crate::syntax::tokens::ConstantType; +use crate::syntax::Location; +use std::collections::HashMap; + +use super::location::Located; /// A structure represented a parsed program. /// @@ -12,7 +15,57 @@ pub use crate::syntax::tokens::ConstantType; /// `validate` and it comes back without errors. #[derive(Clone, Debug, PartialEq)] pub struct Program { - pub items: Vec, + pub functions: HashMap, + pub structures: HashMap, + pub body: Expression, +} + +/// A function that we want to compile. +/// +/// Later, when we've done a lot of analysis, the `Option`s +/// will turn into concrete types. For now, though, we stick with +/// the surface syntax and leave them as optional. The name of the +/// function is intentionally duplicated, to make our life easier. +#[derive(Clone, Debug, PartialEq)] +pub struct FunctionDefinition { + pub name: Name, + pub arguments: Vec<(Name, Option)>, + pub return_type: Option, + pub body: Expression, +} + +impl FunctionDefinition { + pub fn new( + name: Name, + arguments: Vec<(Name, Option)>, + return_type: Option, + body: Expression, + ) -> Self { + FunctionDefinition { + name, + arguments, + return_type, + body, + } + } +} + +/// A structure type that we might want to reference in the future. +#[derive(Clone, Debug, PartialEq)] +pub struct StructureDefinition { + pub name: Name, + pub location: Location, + pub fields: Vec<(Name, Option)>, +} + +impl StructureDefinition { + pub fn new(location: Location, name: Name, fields: Vec<(Name, Option)>) -> Self { + StructureDefinition { + name, + location, + fields, + } + } } /// A thing that can sit at the top level of a file. @@ -26,6 +79,15 @@ pub enum TopLevel { Structure(Location, Name, Vec<(Name, Type)>), } +impl Located for TopLevel { + fn location(&self) -> &Location { + match self { + TopLevel::Expression(exp) => exp.location(), + TopLevel::Structure(loc, _, _) => loc, + } + } +} + /// An expression in the underlying syntax. /// /// Like statements, these expressions are guaranteed to have been @@ -114,9 +176,9 @@ impl PartialEq for Expression { } } -impl Expression { +impl Located for Expression { /// Get the location of the expression in the source file (if there is one). - pub fn location(&self) -> &Location { + fn location(&self) -> &Location { match self { Expression::Value(loc, _) => loc, Expression::Constructor(loc, _, _) => loc, diff --git a/src/syntax/eval.rs b/src/syntax/eval.rs index 58ac657..e07a876 100644 --- a/src/syntax/eval.rs +++ b/src/syntax/eval.rs @@ -1,7 +1,6 @@ use crate::eval::{EvalError, PrimitiveType, Value}; -use crate::syntax::{ConstantType, Expression, Program, TopLevel}; +use crate::syntax::{ConstantType, Expression, Name, Program}; use crate::util::scoped_map::ScopedMap; -use internment::ArcIntern; use std::collections::HashMap; use std::str::FromStr; @@ -20,19 +19,8 @@ impl Program { pub fn eval(&self) -> Result<(Value, String), EvalError> { let mut env = ScopedMap::new(); let mut stdout = String::new(); - let mut last_result = Value::Void; - - for stmt in self.items.iter() { - match stmt { - TopLevel::Expression(expr) => last_result = expr.eval(&mut stdout, &mut env)?, - - TopLevel::Structure(_, _, _) => { - last_result = Value::Void; - } - } - } - - Ok((last_result, stdout)) + let result = self.body.eval(&mut stdout, &mut env)?; + Ok((result, stdout)) } } @@ -40,7 +28,7 @@ impl Expression { fn eval( &self, stdout: &mut String, - env: &mut ScopedMap, Value>, + env: &mut ScopedMap>, ) -> Result, EvalError> { match self { Expression::Value(_, v) => match v { @@ -63,35 +51,37 @@ impl Expression { let mut map = HashMap::with_capacity(fields.len()); for (k, v) in fields.iter() { - map.insert(k.clone().intern(), v.eval(stdout, env)?); + map.insert(k.clone(), v.eval(stdout, env)?); } - Ok(Value::Structure(Some(on.clone().intern()), map)) + Ok(Value::Structure(Some(on.clone()), map)) } Expression::Reference(n) => env - .get(n.current_interned()) - .ok_or_else(|| EvalError::LookupFailed(n.location().clone(), n.current_name().to_string())) + .get(n) + .ok_or_else(|| { + EvalError::LookupFailed(n.location().clone(), n.current_name().to_string()) + }) .cloned(), Expression::FieldRef(loc, expr, field) => { let struck = expr.eval(stdout, env)?; if let Value::Structure(on, mut fields) = struck { - if let Some(value) = fields.remove(&field.clone().intern()) { + if let Some(value) = fields.remove(&field.clone()) { Ok(value) } else { Err(EvalError::BadFieldForStructure( loc.clone(), on, - field.clone().intern(), + field.clone(), )) } } else { Err(EvalError::NoFieldForValue( loc.clone(), struck, - field.clone().intern(), + field.clone(), )) } } @@ -130,8 +120,7 @@ impl Expression { Value::Primitive(name) if name == "print" => { if let [Expression::Reference(name)] = &args[..] { - let value = Expression::Reference(name.clone()) - .eval(stdout, env)?; + let value = Expression::Reference(name.clone()).eval(stdout, env)?; let value = match value { Value::Number(x) => Value::U64(x), x => x, @@ -172,24 +161,20 @@ impl Expression { Expression::Binding(_, name, value) => { let actual_value = value.eval(stdout, env)?; - env.insert(name.clone().intern(), actual_value.clone()); + env.insert(name.clone(), actual_value.clone()); Ok(actual_value) } Expression::Function(_, name, arg_names, _, body) => { let result = Value::Closure( - name.as_ref().map(|n| n.current_interned().clone()), + name.clone(), env.clone(), - arg_names - .iter() - .cloned() - .map(|(x, _)| x.current_interned().clone()) - .collect(), + arg_names.iter().cloned().map(|(x, _)| x.clone()).collect(), *body.clone(), ); if let Some(name) = name { - env.insert(name.clone().intern(), result.clone()); + env.insert(name.clone(), result.clone()); } Ok(result) @@ -200,14 +185,17 @@ impl Expression { #[test] fn two_plus_three() { - let input = Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); - let (_, output) = input.eval().expect("runs successfully"); + let input = crate::syntax::parse_string(0, "x = 2 + 3; print x;").expect("parse works"); + let program = Program::validate(input).into_result().unwrap(); + let (_, output) = program.eval().expect("runs successfully"); assert_eq!("x = 5u64\n", &output); } #[test] fn lotsa_math() { - let input = Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); - let (_, output) = input.eval().expect("runs successfully"); + let input = + crate::syntax::parse_string(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); + let program = Program::validate(input).into_result().unwrap(); + let (_, output) = program.eval().expect("runs successfully"); assert_eq!("x = 7u64\n", &output); } diff --git a/src/lambda_lift/free_variables.rs b/src/syntax/free_variables.rs similarity index 88% rename from src/lambda_lift/free_variables.rs rename to src/syntax/free_variables.rs index 2f8f868..1c53ac3 100644 --- a/src/lambda_lift/free_variables.rs +++ b/src/syntax/free_variables.rs @@ -9,22 +9,22 @@ impl Expression { pub fn free_variables(&self) -> HashSet { match self { Expression::Value(_, _) => HashSet::new(), - Expression::Constructor(_, _, args) => + Expression::Constructor(_, _, args) => { args.iter().fold(HashSet::new(), |mut existing, (_, expr)| { existing.extend(expr.free_variables()); existing - }), + }) + } Expression::Reference(n) => HashSet::from([n.clone()]), Expression::FieldRef(_, expr, _) => expr.free_variables(), Expression::Cast(_, _, expr) => expr.free_variables(), Expression::Primitive(_, _) => HashSet::new(), - Expression::Call(_, f, args) => - args.iter() - .fold(f.free_variables(), - |mut existing, expr| { - existing.extend(expr.free_variables()); - existing - }), + Expression::Call(_, f, args) => { + args.iter().fold(f.free_variables(), |mut existing, expr| { + existing.extend(expr.free_variables()); + existing + }) + } Expression::Block(_, exprs) => { let mut free_vars = HashSet::new(); let mut bound_vars = HashSet::new(); @@ -64,22 +64,22 @@ impl Expression { pub fn new_bindings(&self) -> HashSet { match self { Expression::Value(_, _) => HashSet::new(), - Expression::Constructor(_, _, args) => + Expression::Constructor(_, _, args) => { args.iter().fold(HashSet::new(), |mut existing, (_, expr)| { existing.extend(expr.new_bindings()); existing - }), + }) + } Expression::Reference(_) => HashSet::new(), Expression::FieldRef(_, expr, _) => expr.new_bindings(), Expression::Cast(_, _, expr) => expr.new_bindings(), Expression::Primitive(_, _) => HashSet::new(), - Expression::Call(_, f, args) => - args.iter() - .fold(f.new_bindings(), - |mut existing, expr| { - existing.extend(expr.new_bindings()); - existing - }), + Expression::Call(_, f, args) => { + args.iter().fold(f.new_bindings(), |mut existing, expr| { + existing.extend(expr.new_bindings()); + existing + }) + } Expression::Block(_, _) => HashSet::new(), Expression::Binding(_, name, expr) => { let mut others = expr.new_bindings(); diff --git a/src/syntax/location.rs b/src/syntax/location.rs index 09e6e0c..bf55823 100644 --- a/src/syntax/location.rs +++ b/src/syntax/location.rs @@ -12,6 +12,10 @@ pub struct Location { location: Range, } +pub trait Located { + fn location(&self) -> &Location; +} + impl Location { /// Generate a new `Location` from a file index and an offset from the /// start of the file. @@ -116,4 +120,36 @@ impl Location { }) } } + + /// Infer a location set by combining all of the information we have + /// in the list of located things. + /// + /// This will attempt to throw away manufactured locations whenever + /// possible, but if there's multiple files mixed in it will likely + /// fail. In all failure cases, including when the set of items is + /// empty, will return a manufactured location to use. + pub fn infer_from(items: &[T]) -> Location { + let mut current = None; + + for item in items { + let location = item.location(); + + if (location.file_idx != 0) + || (location.location.start != 0) + || (location.location.end != 0) + { + current = match current { + None => Some(Some(location.clone())), + Some(None) => Some(None), // we ran into an error somewhere + Some(Some(actual)) => Some(actual.merge(location)), + }; + } + } + + match current { + None => Location::manufactured(), + Some(None) => Location::manufactured(), + Some(Some(x)) => x, + } + } } diff --git a/src/syntax/name.rs b/src/syntax/name.rs index 5ffc5b5..f532de6 100644 --- a/src/syntax/name.rs +++ b/src/syntax/name.rs @@ -2,19 +2,20 @@ use crate::syntax::Location; use internment::ArcIntern; use std::fmt; use std::hash::Hash; +use std::sync::atomic::{AtomicU64, Ordering}; /// The name of a thing in the source language. /// /// In many ways, you can treat this like a string, but it's a very tricky /// string in a couple of ways: -/// +/// /// First, it's a string associated with a particular location in the source /// file, and you can find out what that source location is relatively easily. -/// +/// /// Second, it's a name that retains something of its identity across renaming, /// so that you can keep track of what a variables original name was, as well as /// what it's new name is if it's been renamed. -/// +/// /// Finally, when it comes to equality tests, comparisons, and hashing, `Name` /// uses *only* the new name, if the variable has been renamed, or the original /// name, if it has not been renamed. It never uses the location. This allows @@ -30,7 +31,7 @@ pub struct Name { impl Name { /// Create a new name at the given location. - /// + /// /// This creates an "original" name, which has not been renamed, at the /// given location. pub fn new(n: S, location: Location) -> Name { @@ -42,7 +43,7 @@ impl Name { } /// Create a new name with no location information. - /// + /// /// This creates an "original" name, which has not been renamed, at the /// given location. You should always prefer to use [`Location::new`] if /// there is any possible way to get it, because that will be more @@ -55,8 +56,35 @@ impl Name { } } + /// Create a unique name based on the original name provided. + /// + /// This will automatically append a number and wrap that in + /// <>, which is hoped to be unique. + pub fn gensym(n: S) -> Name { + static GENSYM_COUNTER: AtomicU64 = AtomicU64::new(0); + + let new_name = format!( + "<{}{}>", + n.to_string(), + GENSYM_COUNTER.fetch_add(1, Ordering::SeqCst) + ); + Name { + name: ArcIntern::new(new_name), + rename: None, + location: Location::manufactured(), + } + } + + /// As with gensym, but tie the name to the given location + pub fn located_gensym(location: Location, n: S) -> Name { + Name { + location, + ..Name::gensym(n) + } + } + /// Returns a reference to the original name of the variable. - /// + /// /// Regardless of whether or not the function has been renamed, this will /// return whatever name this variable started with. pub fn original_name(&self) -> &str { @@ -64,11 +92,14 @@ impl Name { } /// Returns a reference to the current name of the variable. - /// + /// /// If the variable has been renamed, it will return that, otherwise we'll /// return the current name. pub fn current_name(&self) -> &str { - self.rename.as_ref().map(|x| x.as_str()).unwrap_or_else(|| self.name.as_str()) + self.rename + .as_ref() + .map(|x| x.as_str()) + .unwrap_or_else(|| self.name.as_str()) } /// Returns the current name of the variable as an interned string. @@ -110,4 +141,3 @@ impl fmt::Display for Name { self.current_name().fmt(f) } } - diff --git a/src/syntax/parser.lalrpop b/src/syntax/parser.lalrpop index e2777b5..2a9516a 100644 --- a/src/syntax/parser.lalrpop +++ b/src/syntax/parser.lalrpop @@ -9,7 +9,7 @@ //! eventually want to leave lalrpop behind.) //! use crate::syntax::{Location, ParserError}; -use crate::syntax::ast::{Program,TopLevel,Expression,Value,Type}; +use crate::syntax::ast::{TopLevel,Expression,Value,Type}; use crate::syntax::name::Name; use crate::syntax::tokens::{ConstantType, Token}; use internment::ArcIntern; @@ -63,20 +63,12 @@ extern { } } -pub Program: Program = { - // a program is just a set of statements - => Program { - items - }, - => Program { items: vec![] }, -} - -ProgramTopLevel: Vec = { - => { +pub Program: Vec = { + => { rest.push(t); rest }, - => vec![t], + => vec![], } pub TopLevel: TopLevel = { diff --git a/src/syntax/pretty.rs b/src/syntax/pretty.rs index 3ba412e..e184aaf 100644 --- a/src/syntax/pretty.rs +++ b/src/syntax/pretty.rs @@ -6,14 +6,78 @@ impl Program { pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> { let mut result = allocator.nil(); - for tl in self.items.iter() { + for definition in self.structures.values() { result = result - .append(tl.pretty(allocator)) - .append(allocator.text(";")) + .append(allocator.text("struct")) + .append(allocator.space()) + .append(allocator.text(definition.name.original_name().to_string())) + .append(allocator.space()) + .append(allocator.text("{")) + .append(allocator.hardline()) + .append( + allocator + .concat(definition.fields.iter().map(|(name, ty)| { + let mut type_bit = allocator.nil(); + + if let Some(ty) = ty { + type_bit = allocator + .text(":") + .append(allocator.space()) + .append(ty.pretty(allocator)); + } + + allocator + .text(name.original_name().to_string()) + .append(type_bit) + .append(allocator.text(";")) + .append(allocator.hardline()) + })) + .nest(2), + ) + .append(allocator.text("}")) .append(allocator.hardline()); } - result + for definition in self.functions.values() { + let mut return_type_bit = allocator.nil(); + + if let Some(rettype) = definition.return_type.as_ref() { + return_type_bit = allocator + .text("->") + .append(allocator.space()) + .append(rettype.pretty(allocator)); + } + + result = result + .append(allocator.text("function")) + .append(allocator.space()) + .append(allocator.text(definition.name.original_name().to_string())) + .append(allocator.text("(")) + .append(allocator.intersperse( + definition.arguments.iter().map(|(x, t)| { + let mut type_bit = allocator.nil(); + + if let Some(ty) = t { + type_bit = allocator + .text(":") + .append(allocator.space()) + .append(ty.pretty(allocator)); + } + + allocator + .text(x.original_name().to_string()) + .append(type_bit) + }), + allocator.text(","), + )) + .append(allocator.text(")")) + .append(return_type_bit) + .append(allocator.softline()) + .append(definition.body.pretty(allocator)) + .append(allocator.hardline()); + } + + result.append(self.body.pretty(allocator)) } } diff --git a/src/syntax/replace_references.rs b/src/syntax/replace_references.rs new file mode 100644 index 0000000..3c7b7d5 --- /dev/null +++ b/src/syntax/replace_references.rs @@ -0,0 +1,51 @@ +use super::{Expression, Name}; +use std::collections::HashMap; + +impl Expression { + /// Replace all references in the given map to their alternative expression values + pub fn replace_references(&mut self, map: &mut HashMap) { + match self { + Expression::Value(_, _) => {} + Expression::Constructor(_, _, items) => { + for (_, item) in items.iter_mut() { + item.replace_references(map); + } + } + Expression::Reference(name) => match map.get(name) { + None => {} + Some(x) => *self = x.clone(), + }, + Expression::FieldRef(_, subexp, _) => { + subexp.replace_references(map); + } + Expression::Cast(_, _, subexp) => { + subexp.replace_references(map); + } + Expression::Primitive(_, _) => {} + Expression::Call(_, fun, args) => { + fun.replace_references(map); + for arg in args.iter_mut() { + arg.replace_references(map); + } + } + Expression::Block(_, exprs) => { + for expr in exprs.iter_mut() { + expr.replace_references(map); + } + } + Expression::Binding(_, n, expr) => { + expr.replace_references(map); + map.remove(n); + } + Expression::Function(_, mname, args, _, body) => { + if let Some(name) = mname { + map.remove(name); + } + for (arg_name, _) in args.iter() { + map.remove(arg_name); + } + body.replace_references(map); + } + } + } +} diff --git a/src/syntax/validate.rs b/src/syntax/validate.rs index 4db79c8..8655643 100644 --- a/src/syntax/validate.rs +++ b/src/syntax/validate.rs @@ -1,11 +1,13 @@ -use crate::{ - eval::PrimitiveType, - syntax::{Expression, Location, Program, TopLevel}, - util::scoped_map::ScopedMap, -}; +use crate::eval::PrimitiveType; +use crate::syntax::{Expression, Location, Program, StructureDefinition, TopLevel}; +use crate::util::scoped_map::ScopedMap; +use crate::util::warning_result::WarningResult; use codespan_reporting::diagnostic::Diagnostic; +use std::collections::HashMap; use std::str::FromStr; +use super::{FunctionDefinition, Name, Type}; + /// An error we found while validating the input program. /// /// These errors indicate that we should stop trying to compile @@ -66,9 +68,9 @@ impl Program { /// This checks for things like references to variables that don't exist, for /// example, and generates warnings for things that are inadvisable but not /// actually a problem. - pub fn validate(&self) -> (Vec, Vec) { + pub fn validate(raw_syntax: Vec) -> WarningResult { let mut bound_variables = ScopedMap::new(); - self.validate_with_bindings(&mut bound_variables) + Self::validate_with_bindings(raw_syntax, &mut bound_variables) } /// Validate that the program makes semantic sense, not just syntactic sense. @@ -77,118 +79,137 @@ impl Program { /// example, and generates warnings for things that are inadvisable but not /// actually a problem. pub fn validate_with_bindings( - &self, + raw_syntax: Vec, bound_variables: &mut ScopedMap, - ) -> (Vec, Vec) { - let mut errors = vec![]; - let mut warnings = vec![]; + ) -> WarningResult { + let mut functions = HashMap::new(); + let mut structures = HashMap::new(); + let mut result = WarningResult::ok(vec![]); + let location = Location::infer_from(&raw_syntax); - for stmt in self.items.iter() { - let (mut new_errors, mut new_warnings) = stmt.validate_with_bindings(bound_variables); - errors.append(&mut new_errors); - warnings.append(&mut new_warnings); + for stmt in raw_syntax.into_iter() { + match stmt { + TopLevel::Expression(expr) => { + let expr_result = + expr.validate(bound_variables, &mut structures, &mut functions); + result = result.merge_with(expr_result, |mut previous, current| { + previous.push(current); + Ok(previous) + }); + } + + TopLevel::Structure(loc, name, fields) => { + let definition = StructureDefinition::new( + loc, + name.clone(), + fields.into_iter().map(|(n, t)| (n, Some(t))).collect(), + ); + + structures.insert(name, definition); + } + } } - (errors, warnings) - } -} - -impl TopLevel { - /// Validate that the top level item makes semantic sense, not just syntactic - /// sense. - /// - /// This checks for things like references to variables that don't exist, for - /// example, and generates warnings for thins that are inadvisable but not - /// actually a problem. - pub fn validate(&self) -> (Vec, Vec) { - let mut bound_variables = ScopedMap::new(); - self.validate_with_bindings(&mut bound_variables) - } - - /// Validate that the top level item makes semantic sense, not just syntactic - /// sense. - /// - /// This checks for things like references to variables that don't exist, for - /// example, and generates warnings for thins that are inadvisable but not - /// actually a problem. - pub fn validate_with_bindings( - &self, - bound_variables: &mut ScopedMap, - ) -> (Vec, Vec) { - match self { - TopLevel::Expression(expr) => expr.validate(bound_variables), - TopLevel::Structure(_, _, _) => (vec![], vec![]), - } + result.map(move |exprs| Program { + functions, + structures, + body: Expression::Block(location, exprs), + }) } } impl Expression { fn validate( - &self, + self, variable_map: &mut ScopedMap, - ) -> (Vec, Vec) { + structure_map: &mut HashMap, + function_map: &mut HashMap, + ) -> WarningResult { match self { - Expression::Value(_, _) => (vec![], vec![]), - Expression::Constructor(_, _, fields) => { - let mut errors = vec![]; - let mut warnings = vec![]; + Expression::Value(_, _) => WarningResult::ok(self), - for (_, expr) in fields.iter() { - let (mut e, mut w) = expr.validate(variable_map); - errors.append(&mut e); - warnings.append(&mut w); + Expression::Constructor(location, name, fields) => { + let mut result = WarningResult::ok(vec![]); + + for (name, expr) in fields.into_iter() { + let expr_result = expr.validate(variable_map, structure_map, function_map); + result = result.merge_with(expr_result, move |mut fields, new_expr| { + fields.push((name, new_expr)); + Ok(fields) + }); } - (errors, warnings) + result.map(move |fields| Expression::Constructor(location, name, fields)) } - Expression::Reference(var) if variable_map.contains_key(&var.original_name().to_string()) => (vec![], vec![]), - Expression::Reference(var) => ( - vec![Error::UnboundVariable(var.location().clone(), var.original_name().to_string())], - vec![], - ), - Expression::FieldRef(_, exp, _) => exp.validate(variable_map), + + Expression::Reference(ref var) + if variable_map.contains_key(&var.original_name().to_string()) => + { + WarningResult::ok(self) + } + Expression::Reference(var) => WarningResult::err(Error::UnboundVariable( + var.location().clone(), + var.original_name().to_string(), + )), + + Expression::FieldRef(location, exp, field) => exp + .validate(variable_map, structure_map, function_map) + .map(|x| Expression::FieldRef(location, Box::new(x), field)), + Expression::Cast(location, t, expr) => { - let (mut errs, warns) = expr.validate(variable_map); + let mut expr_result = expr.validate(variable_map, structure_map, function_map); - if PrimitiveType::from_str(t).is_err() { - errs.push(Error::UnknownType(location.clone(), t.clone())) + if PrimitiveType::from_str(&t).is_err() { + expr_result.add_error(Error::UnknownType(location.clone(), t.clone())); } - (errs, warns) + expr_result.map(|e| Expression::Cast(location, t, Box::new(e))) } - Expression::Primitive(_, _) => (vec![], vec![]), - Expression::Call(_, func, args) => { - let (mut errors, mut warnings) = func.validate(variable_map); - for arg in args.iter() { - let (mut e, mut w) = arg.validate(variable_map); - errors.append(&mut e); - warnings.append(&mut w); + // FIXME: Check for valid primitives here!! + Expression::Primitive(_, _) => WarningResult::ok(self), + + Expression::Call(loc, func, args) => { + let mut result = func + .validate(variable_map, structure_map, function_map) + .map(|x| (x, vec![])); + + for arg in args.into_iter() { + let expr_result = arg.validate(variable_map, structure_map, function_map); + result = + result.merge_with(expr_result, |(func, mut previous_args), new_arg| { + previous_args.push(new_arg); + Ok((func, previous_args)) + }); } - (errors, warnings) + result.map(|(func, args)| Expression::Call(loc, Box::new(func), args)) } - Expression::Block(_, stmts) => { - let mut errors = vec![]; - let mut warnings = vec![]; - for stmt in stmts.iter() { - let (mut local_errors, mut local_warnings) = stmt.validate(variable_map); + Expression::Block(loc, stmts) => { + let mut result = WarningResult::ok(vec![]); - errors.append(&mut local_errors); - warnings.append(&mut local_warnings); + for stmt in stmts.into_iter() { + let stmt_result = stmt.validate(variable_map, structure_map, function_map); + result = result.merge_with(stmt_result, |mut stmts, stmt| { + stmts.push(stmt); + Ok(stmts) + }); } - (errors, warnings) + result.map(|stmts| Expression::Block(loc, stmts)) } + Expression::Binding(loc, var, val) => { // we're going to make the decision that a variable is not bound in the right // hand side of its binding, which makes a lot of things easier. So we'll just // immediately check the expression, and go from there. - let (errors, mut warnings) = val.validate(variable_map); + let mut result = val.validate(variable_map, structure_map, function_map); - if let Some(original_binding_site) = variable_map.get(&var.original_name().to_string()) { - warnings.push(Warning::ShadowedVariable( + if let Some(original_binding_site) = + variable_map.get(&var.original_name().to_string()) + { + result.add_warning(Warning::ShadowedVariable( original_binding_site.clone(), loc.clone(), var.to_string(), @@ -197,19 +218,118 @@ impl Expression { variable_map.insert(var.to_string(), loc.clone()); } - (errors, warnings) + result.map(|val| Expression::Binding(loc, var, Box::new(val))) } - Expression::Function(_, name, arguments, _, body) => { - if let Some(name) = name { + + Expression::Function(loc, name, mut arguments, return_type, body) => { + let mut result = WarningResult::ok(()); + + // first we should check for shadowing + for new_name in name.iter().chain(arguments.iter().map(|x| &x.0)) { + if let Some(original_site) = variable_map.get(new_name.original_name()) { + result.add_warning(Warning::ShadowedVariable( + original_site.clone(), + loc.clone(), + new_name.original_name().to_string(), + )); + } + } + + // the function name is now available in our current scope, if the function was given one + if let Some(name) = &name { variable_map.insert(name.original_name().to_string(), name.location().clone()); } + + // the arguments are available in a new scope, which we will use to validate the function + // body variable_map.new_scope(); for (arg, _) in arguments.iter() { variable_map.insert(arg.original_name().to_string(), arg.location().clone()); } - let result = body.validate(variable_map); + + let body_result = body.validate(variable_map, structure_map, function_map); variable_map.release_scope(); - result + + body_result.merge_with(result, move |mut body, _| { + // figure out what, if anything, needs to be in the closure for this function. + let mut free_variables = body.free_variables(); + for (n, _) in arguments.iter() { + free_variables.remove(n); + } + // generate a new name for the closure type we're about to create + let closure_type_name = Name::located_gensym( + loc.clone(), + name.as_ref().map(Name::original_name).unwrap_or("closure_"), + ); + // ... and then create a structure type that has all of the free variables + // in it + let closure_type = StructureDefinition::new( + loc.clone(), + closure_type_name.clone(), + free_variables.iter().map(|x| (x.clone(), None)).collect(), + ); + // this will become the first argument of the function, so name it and add + // it to the argument list. + let closure_arg = Name::gensym("__closure_arg"); + arguments.insert( + 0, + ( + closure_arg.clone(), + Some(Type::Named(closure_type_name.clone())), + ), + ); + // Now make a map from the old free variable names to references into + // our closure argument + let rebinds = free_variables + .into_iter() + .map(|n| { + ( + n.clone(), + Expression::FieldRef( + n.location().clone(), + Box::new(Expression::Reference(closure_arg.clone())), + n, + ), + ) + }) + .collect::>(); + let mut rebind_map = rebinds.iter().cloned().collect(); + // and replace all the references in the function with this map + body.replace_references(&mut rebind_map); + // OK! This function definitely needs a name; if the user didn't give + // it one, we'll do so. + let function_name = + name.unwrap_or_else(|| Name::located_gensym(loc.clone(), "function")); + // And finally, we can make the function definition and insert it into our global + // list along with the new closure type. + let function = FunctionDefinition::new( + function_name.clone(), + arguments.clone(), + return_type.clone(), + body, + ); + + structure_map.insert(closure_type_name.clone(), closure_type); + function_map.insert(function_name.clone(), function); + + // And the result of this function is a call to a primitive that generates + // the closure value in some sort of reasonable way. + Ok(Expression::Call( + Location::manufactured(), + Box::new(Expression::Primitive( + Location::manufactured(), + Name::new("", Location::manufactured()), + )), + vec![ + Expression::Reference(function_name), + Expression::Constructor( + Location::manufactured(), + closure_type_name, + rebinds, + ), + ], + )) + }) } } } @@ -217,16 +337,19 @@ impl Expression { #[test] fn cast_checks_are_reasonable() { - let good_stmt = TopLevel::parse(0, "x = 4u8;").expect("valid test case"); - let (good_errs, good_warns) = good_stmt.validate(); + let mut variable_map = ScopedMap::new(); + let mut structure_map = HashMap::new(); + let mut function_map = HashMap::new(); - assert!(good_errs.is_empty()); - assert!(good_warns.is_empty()); + let good_stmt = Expression::parse(0, "x = 4u8;").expect("valid test case"); + let result_good = good_stmt.validate(&mut variable_map, &mut structure_map, &mut function_map); - let bad_stmt = TopLevel::parse(0, "x = 4u8;").expect("valid test case"); - let (bad_errs, bad_warns) = bad_stmt.validate(); + assert!(result_good.is_ok()); + assert!(result_good.warnings().is_empty()); - assert!(bad_warns.is_empty()); - assert_eq!(bad_errs.len(), 1); - assert!(matches!(bad_errs[0], Error::UnknownType(_, ref x) if x == "apple")); + let bad_stmt = Expression::parse(0, "x = 4u8;").expect("valid test case"); + let result_err = bad_stmt.validate(&mut variable_map, &mut structure_map, &mut function_map); + + assert!(result_err.is_err()); + assert!(result_err.warnings().is_empty()); } diff --git a/src/type_infer.rs b/src/type_infer.rs index 94bb3cb..6a8e5e2 100644 --- a/src/type_infer.rs +++ b/src/type_infer.rs @@ -24,26 +24,15 @@ pub use self::result::TypeInferenceResult; use self::warning::TypeInferenceWarning; use crate::ir::ast as ir; use crate::syntax; -#[cfg(test)] -use crate::syntax::arbitrary::GenerationEnvironment; -use internment::ArcIntern; -#[cfg(test)] -use proptest::prelude::Arbitrary; +use crate::syntax::Name; use std::collections::HashMap; -#[derive(Default)] struct InferenceEngine { constraints: Vec, - type_definitions: HashMap, ir::TypeOrVar>, - variable_types: HashMap, ir::TypeOrVar>, - functions: HashMap< - ArcIntern, - ( - Vec<(ArcIntern, ir::TypeOrVar)>, - ir::Expression, - ), - >, - statements: Vec>, + type_definitions: HashMap, + variable_types: HashMap, + functions: HashMap>, + body: ir::Expression, errors: Vec, warnings: Vec, } @@ -55,8 +44,7 @@ impl syntax::Program { /// You really should have made sure that this program was validated before running /// this method, otherwise you may experience panics during operation. pub fn type_infer(self) -> TypeInferenceResult> { - let mut engine = InferenceEngine::default(); - engine.injest_program(self); + let mut engine = InferenceEngine::from(self); engine.solve_constraints(); if engine.errors.is_empty() { @@ -89,10 +77,11 @@ impl syntax::Program { proptest::proptest! { #[test] - fn translation_maintains_semantics(input in syntax::Program::arbitrary_with(GenerationEnvironment::new(false))) { - let syntax_result = input.eval().map(|(x,o)| (x.strip(), o)); - let ir = input.type_infer().expect("arbitrary should generate type-safe programs"); - let ir_evaluator = crate::ir::Evaluator::default(); + fn translation_maintains_semantics(input in syntax::arbitrary::ProgramGenerator::default()) { + let input_program = syntax::Program::validate(input).into_result().expect("can validate random program"); + let syntax_result = input_program.eval().map(|(x,o)| (x.strip(), o)); + let ir = input_program.type_infer().expect("arbitrary should generate type-safe programs"); + let ir_evaluator = crate::ir::Evaluator::::default(); let ir_result = ir_evaluator.eval(ir).map(|(x, o)| (x.strip(), o)); match (syntax_result, ir_result) { (Err(e1), Err(e2)) => proptest::prop_assert_eq!(e1, e2), diff --git a/src/type_infer/constraint.rs b/src/type_infer/constraint.rs index e6520fb..6d05ba2 100644 --- a/src/type_infer/constraint.rs +++ b/src/type_infer/constraint.rs @@ -1,6 +1,5 @@ use crate::ir::TypeOrVar; -use crate::syntax::Location; -use internment::ArcIntern; +use crate::syntax::{Location, Name}; use std::fmt; /// A type inference constraint that we're going to need to solve. @@ -14,7 +13,7 @@ pub enum Constraint { CanCastTo(Location, TypeOrVar, TypeOrVar), /// The given type has the given field in it, and the type of that field /// is as given. - TypeHasField(Location, TypeOrVar, ArcIntern, TypeOrVar), + TypeHasField(Location, TypeOrVar, Name, TypeOrVar), /// The given type must be some numeric type, but this is not a constant /// value, so don't try to default it if we can't figure it out NumericType(Location, TypeOrVar), @@ -29,7 +28,7 @@ pub enum Constraint { /// The given type can be negated IsSigned(Location, TypeOrVar), /// Checks to see if the given named type is equivalent to the provided one. - NamedTypeIs(Location, ArcIntern, TypeOrVar), + NamedTypeIs(Location, Name, TypeOrVar), } impl fmt::Display for Constraint { @@ -56,7 +55,7 @@ impl Constraint { /// with the given type. /// /// Returns whether or not anything was changed in the constraint. - pub fn replace(&mut self, name: &ArcIntern, replace_with: &TypeOrVar) -> bool { + pub fn replace(&mut self, name: &Name, replace_with: &TypeOrVar) -> bool { match self { Constraint::Printable(_, ty) => ty.replace(name, replace_with), Constraint::FitsInNumType(_, ty, _) => ty.replace(name, replace_with), diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index 2eac333..1c0060d 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -1,74 +1,109 @@ use super::constraint::Constraint; use super::InferenceEngine; use crate::eval::PrimitiveType; -use crate::ir; -use crate::syntax::{self, ConstantType}; +use crate::ir::{self, Fields}; +use crate::syntax::Name; +use crate::syntax::{self, ConstantType, Location}; use crate::util::scoped_map::ScopedMap; use internment::ArcIntern; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::str::FromStr; -struct ExpressionInfo { - expression: ir::Expression, - result_type: ir::TypeOrVar, - free_variables: HashSet>, - bound_variables: HashSet>, -} +impl From for InferenceEngine { + fn from(value: syntax::Program) -> Self { + let syntax::Program { + functions, + structures, + body, + } = value; + let mut result = InferenceEngine { + constraints: Vec::new(), + type_definitions: HashMap::new(), + variable_types: HashMap::new(), + functions: HashMap::new(), + body: ir::Expression::Block(Location::manufactured(), ir::TypeOrVar::new(), vec![]), + errors: vec![], + warnings: vec![], + }; + let mut renames = ScopedMap::new(); -impl ExpressionInfo { - fn simple(expression: ir::Expression, result_type: ir::TypeOrVar) -> Self { - ExpressionInfo { - expression, - result_type, - free_variables: HashSet::new(), - bound_variables: HashSet::new(), + // first let's transfer all the type information over into our new + // data structures + for (_, structure) in structures.into_iter() { + let mut fields = Fields::default(); + + for (name, optty) in structure.fields { + match optty { + None => { + let newty = ir::TypeOrVar::new_located(name.location().clone()); + fields.insert(name, newty); + } + + Some(t) => { + let existing_ty = result.convert_type(t); + fields.insert(name, existing_ty); + } + } + } + + result + .type_definitions + .insert(structure.name.clone(), ir::TypeOrVar::Structure(fields)); } + + // then transfer all the functions over to the new system + for (_, function) in functions.into_iter() { + // convert the arguments into the new type scheme. if given, use the ones + // given, otherwise generate a new type variable for us to solve for. + let mut arguments = vec![]; + for (name, ty) in function.arguments.into_iter() { + match ty { + None => { + let inferred_type = ir::TypeOrVar::new_located(name.location().clone()); + arguments.push((name, inferred_type)); + } + + Some(t) => { + arguments.push((name, result.convert_type(t))); + } + } + } + + // similarly, use the provided return type if given, otherwise generate + // a new type variable to use. + let return_type = if let Some(t) = function.return_type { + result.convert_type(t) + } else { + ir::TypeOrVar::new_located(function.name.location().clone()) + }; + + let (body, body_type) = result.convert_expression(function.body, &mut renames); + result.constraints.push(Constraint::Equivalent( + function.name.location().clone(), + return_type.clone(), + body_type, + )); + + let new_function = ir::FunctionDefinition { + name: function.name, + arguments, + return_type, + body, + }; + + result + .functions + .insert(new_function.name.clone(), new_function); + } + + // finally we can transfer the body over + result.body = result.convert_expression(body, &mut renames).0; + + result } } impl InferenceEngine { - /// This function takes a syntactic program and converts it into the IR version of the - /// program, with appropriate type variables introduced and their constraints added to - /// the given database. - /// - /// If the input function has been validated (which it should be), then this should run - /// into no error conditions. However, if you failed to validate the input, then this - /// function can panic. - pub fn injest_program(&mut self, program: syntax::Program) { - let mut renames = ScopedMap::new(); - - for item in program.items.into_iter() { - self.convert_top_level(item, &mut renames); - } - } - - /// This function takes a top-level item and converts it into the IR version of the - /// program, with all the appropriate type variables introduced and their constraints - /// added to the given database. - fn convert_top_level( - &mut self, - top_level: syntax::TopLevel, - renames: &mut ScopedMap, ArcIntern>, - ) { - match top_level { - syntax::TopLevel::Expression(expr) => { - let expr_info = self.convert_expression(expr, renames); - self.statements.push(expr_info.expression); - } - - syntax::TopLevel::Structure(_loc, name, fields) => { - let mut updated_fields = ir::Fields::default(); - - for (name, field_type) in fields.into_iter() { - updated_fields.insert(name.intern(), self.convert_type(field_type)); - } - - self.type_definitions - .insert(name.intern(), ir::TypeOrVar::Structure(updated_fields)); - } - } - } - /// This function takes a syntactic expression and converts it into a series /// of IR statements, adding type variables and constraints as necessary. /// @@ -83,8 +118,8 @@ impl InferenceEngine { fn convert_expression( &mut self, expression: syntax::Expression, - renames: &mut ScopedMap, ArcIntern>, - ) -> ExpressionInfo { + renames: &mut ScopedMap>, + ) -> (ir::Expression, ir::TypeOrVar) { match expression { // converting values is mostly tedious, because there's so many cases // involved @@ -145,7 +180,7 @@ impl InferenceEngine { value, )); - ExpressionInfo::simple( + ( ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)), newtype, ) @@ -156,93 +191,73 @@ impl InferenceEngine { let mut result_fields = HashMap::new(); let mut type_fields = ir::Fields::default(); let mut prereqs = vec![]; - let mut free_variables = HashSet::new(); - let mut bound_variables = HashSet::new(); for (name, syntax_expr) in fields.into_iter() { - let field_expr_info = self.convert_expression(syntax_expr, renames); - type_fields.insert(name.clone().intern(), field_expr_info.result_type); - let (prereq, value) = simplify_expr(field_expr_info.expression); - result_fields.insert(name.clone().intern(), value); + let (field_expr, field_type) = self.convert_expression(syntax_expr, renames); + type_fields.insert(name.clone(), field_type); + let (prereq, value) = simplify_expr(field_expr); + result_fields.insert(name.clone(), value); merge_prereq(&mut prereqs, prereq); - free_variables.extend(field_expr_info.free_variables); - bound_variables.extend(field_expr_info.bound_variables); } let result_type = ir::TypeOrVar::Structure(type_fields); self.constraints.push(Constraint::NamedTypeIs( loc.clone(), - name.clone().intern(), + name.clone(), result_type.clone(), )); - let expression = ir::Expression::Construct( - loc, - result_type.clone(), - name.intern(), - result_fields, - ); + let expression = + ir::Expression::Construct(loc, result_type.clone(), name, result_fields); - ExpressionInfo { - expression, - result_type, - free_variables, - bound_variables, - } + (expression, result_type) } syntax::Expression::Reference(mut name) => { - if let Some(rename) = renames.get(name.current_interned()) { + if let Some(rename) = renames.get(&name) { name.rename(rename); } + let result_type = self .variable_types - .get(name.current_interned()) + .get(&name) .cloned() .expect("variable bound before use"); + let expression = ir::Expression::Atomic(ir::ValueOrRef::Ref( name.location().clone(), result_type.clone(), - name.current_interned().clone(), + name.clone(), )); - let free_variables = HashSet::from([name.current_interned().clone()]); - ExpressionInfo { - expression, - result_type, - free_variables, - bound_variables: HashSet::new(), - } + (expression, result_type) } syntax::Expression::FieldRef(loc, expr, field) => { - let mut expr_info = self.convert_expression(*expr, renames); - let (prereqs, val_or_ref) = simplify_expr(expr_info.expression); + let (expr, expr_type) = self.convert_expression(*expr, renames); + let (prereqs, val_or_ref) = simplify_expr(expr); let result_type = ir::TypeOrVar::new(); let result = ir::Expression::FieldRef( loc.clone(), result_type.clone(), - expr_info.result_type.clone(), + expr_type.clone(), val_or_ref, - field.clone().intern(), + field.clone(), ); self.constraints.push(Constraint::TypeHasField( loc, - expr_info.result_type.clone(), - field.intern(), + expr_type.clone(), + field, result_type.clone(), )); - expr_info.expression = finalize_expression(prereqs, result); - expr_info.result_type = result_type; - - expr_info + (finalize_expression(prereqs, result), result_type) } syntax::Expression::Cast(loc, target, expr) => { - let mut expr_info = self.convert_expression(*expr, renames); - let (prereqs, val_or_ref) = simplify_expr(expr_info.expression); + let (expr, expr_type) = self.convert_expression(*expr, renames); + let (prereqs, val_or_ref) = simplify_expr(expr); let target_type: ir::TypeOrVar = PrimitiveType::from_str(&target) .expect("valid type for cast") .into(); @@ -250,14 +265,11 @@ impl InferenceEngine { self.constraints.push(Constraint::CanCastTo( loc, - expr_info.result_type.clone(), + expr_type.clone(), target_type.clone(), )); - expr_info.expression = finalize_expression(prereqs, res); - expr_info.result_type = target_type; - - expr_info + (finalize_expression(prereqs, res), target_type) } syntax::Expression::Primitive(loc, name) => { @@ -273,7 +285,7 @@ impl InferenceEngine { Box::new(numeric_type.clone()), ); let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); - ExpressionInfo::simple(ir::Expression::Atomic(result_value), funtype) + (ir::Expression::Atomic(result_value), funtype) } ir::Primitive::Minus => { @@ -285,7 +297,7 @@ impl InferenceEngine { Box::new(numeric_type.clone()), ); let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); - ExpressionInfo::simple(ir::Expression::Atomic(result_value), funtype) + (ir::Expression::Atomic(result_value), funtype) } ir::Primitive::Print => { @@ -297,7 +309,7 @@ impl InferenceEngine { Box::new(ir::TypeOrVar::Primitive(PrimitiveType::Void)), ); let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); - ExpressionInfo::simple(ir::Expression::Atomic(result_value), funtype) + (ir::Expression::Atomic(result_value), funtype) } ir::Primitive::Negate => { @@ -309,7 +321,7 @@ impl InferenceEngine { let funtype = ir::TypeOrVar::Function(vec![arg_type.clone()], Box::new(arg_type)); let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); - ExpressionInfo::simple(ir::Expression::Atomic(result_value), funtype) + (ir::Expression::Atomic(result_value), funtype) } } } @@ -321,34 +333,32 @@ impl InferenceEngine { .map(|_| ir::TypeOrVar::new()) .collect::>(); - let mut expr_info = self.convert_expression(*fun, renames); + let (fun, fun_type) = self.convert_expression(*fun, renames); let target_fun_type = ir::TypeOrVar::Function(arg_types.clone(), Box::new(return_type.clone())); self.constraints.push(Constraint::Equivalent( loc.clone(), - expr_info.result_type, + fun_type, target_fun_type, )); let mut prereqs = vec![]; - let (fun_prereqs, fun) = simplify_expr(expr_info.expression); + let (fun_prereqs, fun) = simplify_expr(fun); merge_prereq(&mut prereqs, fun_prereqs); let new_args = args .into_iter() .zip(arg_types) .map(|(arg, target_type)| { - let arg_info = self.convert_expression(arg, renames); - let location = arg_info.expression.location().clone(); - let (arg_prereq, new_valref) = simplify_expr(arg_info.expression); + let (arg, arg_type) = self.convert_expression(arg, renames); + let location = arg.location().clone(); + let (arg_prereq, new_valref) = simplify_expr(arg); merge_prereq(&mut prereqs, arg_prereq); self.constraints.push(Constraint::Equivalent( location, - arg_info.result_type, + arg_type, target_type, )); - expr_info.free_variables.extend(arg_info.free_variables); - expr_info.bound_variables.extend(arg_info.bound_variables); new_valref }) .collect(); @@ -356,140 +366,40 @@ impl InferenceEngine { let last_call = ir::Expression::Call(loc.clone(), return_type.clone(), fun, new_args); - expr_info.expression = finalize_expressions(prereqs, last_call); - expr_info.result_type = return_type; - - expr_info + (finalize_expressions(prereqs, last_call), return_type) } syntax::Expression::Block(loc, stmts) => { let mut result_type = ir::TypeOrVar::Primitive(PrimitiveType::Void); let mut exprs = vec![]; - let mut free_variables = HashSet::new(); - let mut bound_variables = HashSet::new(); for xpr in stmts.into_iter() { - let expr_info = self.convert_expression(xpr, renames); - result_type = expr_info.result_type; - exprs.push(expr_info.expression); - free_variables.extend( - expr_info - .free_variables - .difference(&bound_variables) - .cloned() - .collect::>(), - ); - bound_variables.extend(expr_info.bound_variables); + let (expr, expr_type) = self.convert_expression(xpr, renames); + result_type = expr_type; + exprs.push(expr); } - ExpressionInfo { - expression: ir::Expression::Block(loc, result_type.clone(), exprs), + ( + ir::Expression::Block(loc, result_type.clone(), exprs), result_type, - free_variables, - bound_variables, - } + ) } syntax::Expression::Binding(loc, name, expr) => { - let mut expr_info = self.convert_expression(*expr, renames); + let (expr, expr_type) = self.convert_expression(*expr, renames); let final_name = self.finalize_name(renames, name); self.variable_types - .insert(final_name.clone(), expr_info.result_type.clone()); - expr_info.expression = ir::Expression::Bind( - loc, - final_name.clone(), - expr_info.result_type.clone(), - Box::new(expr_info.expression), - ); - expr_info.bound_variables.insert(final_name); - expr_info + .insert(final_name.clone(), expr_type.clone()); + let result_expr = + ir::Expression::Bind(loc, final_name, expr_type.clone(), Box::new(expr)); + + (result_expr, expr_type) } - syntax::Expression::Function(loc, name, args, _, expr) => { - // First, at some point we're going to want to know a location for this function, - // which should either be the name if we have one, or the body if we don't. - let function_location = match name { - None => loc, - Some(ref name) => name.location().clone(), - }; - // Next, let us figure out what we're going to name this function. If the user - // didn't provide one, we'll just call it "function:" for them. (We'll - // want a name for this function, eventually, so we might as well do it now.) - // - // If they did provide a name, see if we're shadowed. IF we are, then we'll have - // to specialize the name a bit. Otherwise we'll stick with their name. - let function_name = match name { - None => ir::gensym("function"), - Some(unbound) => self.finalize_name(renames, unbound), - }; - - // This function is going to have a type. We don't know what it is, but it'll have - // one. - let function_type = ir::TypeOrVar::new(); - self.variable_types - .insert(function_name.clone(), function_type.clone()); - - // Then, let's figure out what to do with the argument names, which similarly - // may need to be renamed. We'll also generate some new type variables to associate - // with all of them. - // - // Note that we want to do all this in a new renaming scope, so that we shadow - // appropriately. - renames.new_scope(); - let arginfo = args - .into_iter() - .map(|(name, mut declared_type)| { - let new_type = ir::TypeOrVar::new(); - self.constraints.push(Constraint::IsSomething( - name.location().clone(), - new_type.clone(), - )); - let new_name = self.finalize_name(renames, name.clone()); - self.variable_types - .insert(new_name.clone(), new_type.clone()); - - if let Some(declared_type) = declared_type.take() { - let declared_type = self.convert_type(declared_type); - self.constraints.push(Constraint::Equivalent( - name.location().clone(), - new_type.clone(), - declared_type, - )); - } - - (new_name, new_type) - }) - .collect::>(); - - // Now we manufacture types for the outputs and then a type for the function itself. - // We're not going to make any claims on these types, yet; they're all just unknown - // type variables we need to work out. - let rettype = ir::TypeOrVar::new(); - let actual_function_type = ir::TypeOrVar::Function( - arginfo.iter().map(|x| x.1.clone()).collect(), - Box::new(rettype.clone()), + syntax::Expression::Function(_, _, _, _, _) => { + panic!( + "Function expressions should not survive validation to get to type checking!" ); - self.constraints.push(Constraint::Equivalent( - function_location, - function_type, - actual_function_type, - )); - - // Now let's convert the body over to the new IR. - let expr_info = self.convert_expression(*expr, renames); - self.constraints.push(Constraint::Equivalent( - expr_info.expression.location().clone(), - rettype.clone(), - expr_info.result_type.clone(), - )); - - // Remember to exit this scoping level! - renames.release_scope(); - - self.functions - .insert(function_name, (arginfo, expr_info.expression.clone())); - - unimplemented!() } } } @@ -501,13 +411,14 @@ impl InferenceEngine { let retval = ir::TypeOrVar::new_located(x.location().clone()); self.constraints.push(Constraint::NamedTypeIs( x.location().clone(), - x.intern(), + x, retval.clone(), )); retval } Ok(v) => ir::TypeOrVar::Primitive(v), }, + syntax::Type::Struct(fields) => { let mut new_fields = ir::Fields::default(); @@ -515,7 +426,7 @@ impl InferenceEngine { let new_field_type = field_type .map(|x| self.convert_type(x)) .unwrap_or_else(ir::TypeOrVar::new); - new_fields.insert(name.intern(), new_field_type); + new_fields.insert(name, new_field_type); } ir::TypeOrVar::Structure(new_fields) @@ -525,18 +436,16 @@ impl InferenceEngine { fn finalize_name( &mut self, - renames: &mut ScopedMap, ArcIntern>, - name: syntax::Name, - ) -> ArcIntern { - if self - .variable_types - .contains_key(name.current_interned()) - { - let new_name = ir::gensym(name.original_name()); - renames.insert(name.current_interned().clone(), new_name.clone()); - new_name + renames: &mut ScopedMap>, + mut name: syntax::Name, + ) -> syntax::Name { + if self.variable_types.contains_key(&name) { + let new_name = Name::gensym(name.original_name()).intern(); + renames.insert(name.clone(), new_name.clone()); + name.rename(&new_name); + name } else { - ArcIntern::new(name.to_string()) + name } } } @@ -552,7 +461,7 @@ fn simplify_expr( expr => { let etype = expr.type_of().clone(); let loc = expr.location().clone(); - let nname = ir::gensym("g"); + let nname = Name::located_gensym(loc.clone(), "g"); let nbinding = ir::Expression::Bind(loc.clone(), nname.clone(), etype.clone(), Box::new(expr)); diff --git a/src/type_infer/error.rs b/src/type_infer/error.rs index dcd0a6c..c611be5 100644 --- a/src/type_infer/error.rs +++ b/src/type_infer/error.rs @@ -1,9 +1,8 @@ use super::constraint::Constraint; use crate::eval::PrimitiveType; use crate::ir::{Primitive, TypeOrVar}; -use crate::syntax::Location; +use crate::syntax::{Location, Name}; use codespan_reporting::diagnostic::Diagnostic; -use internment::ArcIntern; /// The various kinds of errors that can occur while doing type inference. pub enum TypeInferenceError { @@ -30,9 +29,9 @@ pub enum TypeInferenceError { /// The given type isn't signed, and can't be negated IsNotSigned(Location, TypeOrVar), /// The given type doesn't have the given field. - NoFieldForType(Location, ArcIntern, TypeOrVar), + NoFieldForType(Location, Name, TypeOrVar), /// There is no type with the given name. - UnknownTypeName(Location, ArcIntern), + UnknownTypeName(Location, Name), } impl From for Diagnostic { diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index 9265953..9ab53fe 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -1,60 +1,57 @@ use crate::eval::PrimitiveType; -use crate::ir::{Expression, Program, TopLevel, Type, TypeOrVar, TypeWithVoid, Value, ValueOrRef}; -use crate::syntax::Location; -use internment::ArcIntern; +use crate::ir::{Expression, FunctionDefinition, Program, Type, TypeOrVar, Value, ValueOrRef}; +use crate::syntax::Name; use std::collections::HashMap; -pub type TypeResolutions = HashMap, Type>; +pub type TypeResolutions = HashMap; impl super::InferenceEngine { pub fn finalize_program(self, resolutions: TypeResolutions) -> Program { + // we can't do this in place without some type nonsense, so we're going to + // create a brand new set of program arguments and then construct the new + // `Program` from them. + let mut functions = HashMap::new(); + let mut type_definitions = HashMap::new(); + + // this is handy for debugging for (name, ty) in resolutions.iter() { tracing::debug!(name = %name, resolved_type = %ty, "resolved type variable"); } - let mut type_definitions = HashMap::new(); - let mut items = Vec::new(); - + // copy over the type definitions for (name, def) in self.type_definitions.into_iter() { type_definitions.insert(name, finalize_type(def, &resolutions)); } - for (name, (arguments, body)) in self.functions.into_iter() { - let new_body = finalize_expression(body, &resolutions); - let arguments = arguments + // now copy over the functions + for (name, function_def) in self.functions.into_iter() { + assert_eq!(name, function_def.name); + + let body = finalize_expression(function_def.body, &resolutions); + let arguments = function_def + .arguments .into_iter() .map(|(name, t)| (name, finalize_type(t, &resolutions))) .collect(); - items.push(TopLevel::Function( + + functions.insert( name, - arguments, - new_body.type_of(), - new_body, - )); + FunctionDefinition { + name: function_def.name, + arguments, + return_type: body.type_of(), + body, + }, + ); } - let mut body = vec![]; - let mut last_type = Type::void(); - let mut location = None; - - for expr in self.statements.into_iter() { - let next = finalize_expression(expr, &resolutions); - location = location - .map(|x: Location| x.merge(next.location())) - .unwrap_or_else(|| Some(next.location().clone())); - last_type = next.type_of(); - body.push(next); - } - - items.push(TopLevel::Statement(Expression::Block( - location.unwrap_or_else(Location::manufactured), - last_type, - body, - ))); + // and now we can finally compute the new body + let body = finalize_expression(self.body, &resolutions); Program { - items, + functions, type_definitions, + body, } } } diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index 74f1c9f..e2a162e 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ -3,7 +3,7 @@ use super::error::TypeInferenceError; use super::warning::TypeInferenceWarning; use crate::eval::PrimitiveType; use crate::ir::TypeOrVar; -use internment::ArcIntern; +use crate::syntax::Name; impl super::InferenceEngine { /// Solve all the constraints in the provided database. @@ -562,7 +562,7 @@ impl super::InferenceEngine { /// runs if you don't want to lose it. (If you do want to lose it, of course, go ahead.) fn replace_variable( constraint_db: &mut Vec, - variable: &ArcIntern, + variable: &Name, replace_with: &TypeOrVar, ) -> bool { let mut changed_anything = false; diff --git a/src/util.rs b/src/util.rs index 5e15869..d872744 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,2 +1,3 @@ pub mod pretty; pub mod scoped_map; +pub mod warning_result; diff --git a/src/util/warning_result.rs b/src/util/warning_result.rs new file mode 100644 index 0000000..ff410a8 --- /dev/null +++ b/src/util/warning_result.rs @@ -0,0 +1,155 @@ +use codespan_reporting::diagnostic::Diagnostic; + +/// This type is like `Result`, except that both the Ok case and Err case +/// can also include warning data. +/// +/// Unfortunately, this type cannot be used with `?`, and so should probably +/// only be used when it's really, really handy to be able to carry this +/// sort of information. But when it is handy (for example, in type checking +/// and in early validation), it's really useful. +pub enum WarningResult { + Ok(T, Vec), + Err(Vec, Vec), +} + +impl WarningResult { + pub fn ok(value: T) -> Self { + WarningResult::Ok(value, vec![]) + } + + pub fn err(error: E) -> Self { + WarningResult::Err(vec![error], vec![]) + } + + pub fn is_ok(&self) -> bool { + matches!(self, WarningResult::Ok(_, _)) + } + + pub fn is_err(&self) -> bool { + matches!(self, WarningResult::Err(_, _)) + } + + pub fn warnings(&self) -> &[W] { + match self { + WarningResult::Ok(_, warns) => warns.as_slice(), + WarningResult::Err(_, warns) => warns.as_slice(), + } + } + + pub fn into_result(self) -> Option { + match self { + WarningResult::Ok(v, _) => Some(v), + WarningResult::Err(_, _) => None, + } + } + + pub fn into_errors(self) -> Option> { + match self { + WarningResult::Ok(_, _) => None, + WarningResult::Err(errs, _) => Some(errs), + } + } + + pub fn add_warning(&mut self, warning: W) { + match self { + WarningResult::Ok(_, warns) => warns.push(warning), + WarningResult::Err(_, warns) => warns.push(warning), + } + } + + pub fn add_error(&mut self, error: E) { + match self { + WarningResult::Ok(_, warns) => { + *self = WarningResult::Err(vec![error], std::mem::take(warns)) + } + WarningResult::Err(errs, _) => errs.push(error), + } + } + + pub fn modify(&mut self, f: F) + where + F: FnOnce(&mut T), + { + if let WarningResult::Ok(v, _) = self { + f(v); + } + } + + pub fn map(self, f: F) -> WarningResult + where + F: FnOnce(T) -> R, + { + match self { + WarningResult::Ok(v, ws) => WarningResult::Ok(f(v), ws), + WarningResult::Err(e, ws) => WarningResult::Err(e, ws), + } + } + + /// Merges two results together using the given function to combine `Ok` + /// results into a single value. + /// + pub fn merge_with(mut self, other: WarningResult, f: F) -> WarningResult + where + F: FnOnce(T, O) -> Result, + { + match self { + WarningResult::Err(ref mut errors1, ref mut warns1) => match other { + WarningResult::Err(mut errors2, mut warns2) => { + errors1.append(&mut errors2); + warns1.append(&mut warns2); + self + } + + WarningResult::Ok(_, mut ws) => { + warns1.append(&mut ws); + self + } + }, + + WarningResult::Ok(value1, mut warns1) => match other { + WarningResult::Err(errors, mut warns2) => { + warns2.append(&mut warns1); + WarningResult::Err(errors, warns2) + } + + WarningResult::Ok(value2, mut warns2) => { + warns1.append(&mut warns2); + match f(value1, value2) { + Ok(final_value) => WarningResult::Ok(final_value, warns1), + Err(e) => WarningResult::Err(vec![e], warns1), + } + } + }, + } + } +} + +impl WarningResult +where + W: Into>, + E: Into>, +{ + /// Returns the complete set of diagnostics (warnings and errors) as an + /// Iterator. + /// + /// This function removes the diagnostics from the result! So calling + /// this twice is not advised. + pub fn diagnostics(&mut self) -> impl Iterator> { + match self { + WarningResult::Err(errors, warnings) => std::mem::take(errors) + .into_iter() + .map(Into::into) + .chain(std::mem::take(warnings).into_iter().map(Into::into)), + WarningResult::Ok(_, warnings) => + // this is a moderately ridiculous hack to get around + // the two match arms returning different iterator + // types + { + vec![] + .into_iter() + .map(Into::into) + .chain(std::mem::take(warnings).into_iter().map(Into::into)) + } + } + } +}