From 4aa3a9419aeeb215e712429e9b91eea17c2b2f70 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Fri, 24 Mar 2023 10:28:32 -0500 Subject: [PATCH] Very weirdly organized, but it JITs! --- .gitignore | 1 + Cargo.toml | 6 +- src/backend.rs | 16 +++- src/backend/into_crane.rs | 115 ++++++++++++++++++---------- src/{bin.rs => bin/ngrc.rs} | 0 src/bin/ngri.rs | 148 ++++++++++++++++++++++++++++++++++++ src/ir/from_syntax.rs | 8 ++ src/jit.rs | 32 ++++++++ src/jit/engine.rs | 90 ++++++++++++++++++++++ src/lib.rs | 1 + src/syntax.rs | 15 +++- src/syntax/parser.lalrpop | 2 +- src/syntax/simplify.rs | 30 +++++--- src/syntax/validate.rs | 58 +++++++++----- 14 files changed, 441 insertions(+), 81 deletions(-) rename src/{bin.rs => bin/ngrc.rs} (100%) create mode 100644 src/bin/ngri.rs create mode 100644 src/jit.rs create mode 100644 src/jit/engine.rs diff --git a/.gitignore b/.gitignore index 1547085..ce17bee 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ Cargo.lock **/*.o test *.dSYM +.vscode diff --git a/Cargo.toml b/Cargo.toml index 308bb19..ddfdda2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,15 +8,12 @@ edition = "2021" name = "ngr" path = "src/lib.rs" -[[bin]] -name = "ngrc" -path = "src/bin.rs" - [dependencies] clap = { version = "^3.0.14", features = ["derive"] } codespan = "0.11.1" codespan-reporting = "0.11.1" cranelift-codegen = { path = "vendor/wasmtime/cranelift/codegen" } +cranelift-jit = { path = "vendor/wasmtime/cranelift/jit" } cranelift-frontend = { path = "vendor/wasmtime/cranelift/frontend" } cranelift-module = { path = "vendor/wasmtime/cranelift/module" } cranelift-native = { path = "vendor/wasmtime/cranelift/native" } @@ -26,6 +23,7 @@ lalrpop-util = "^0.19.7" lazy_static = "^1.4.0" logos = "^0.12.0" pretty = { version = "^0.11.2", features = ["termcolor"] } +rustyline = "^11.0.0" target-lexicon = "^0.12.5" thiserror = "^1.0.30" diff --git a/src/backend.rs b/src/backend.rs index 049ee8e..348f949 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,19 +1,21 @@ mod into_crane; mod runtime; -use self::runtime::{RuntimeFunctionError, RuntimeFunctions}; +use std::collections::HashMap; + +pub use self::runtime::{RuntimeFunctionError, RuntimeFunctions}; use crate::ir; use codespan_reporting::diagnostic::Diagnostic; use cranelift_codegen::isa::LookupError; use cranelift_codegen::settings::{Configurable, SetError}; use cranelift_codegen::{isa, settings, CodegenError}; -use cranelift_module::{default_libcall_names, ModuleCompiledFunction, ModuleError}; +use cranelift_module::{default_libcall_names, FuncId, ModuleError}; use cranelift_object::{object, ObjectBuilder, ObjectModule}; use target_lexicon::Triple; use thiserror::Error; pub struct Program { - _compiled: ModuleCompiledFunction, + _func_id: FuncId, module: ObjectModule, } @@ -70,7 +72,13 @@ impl Program { let rtfuns = RuntimeFunctions::new(&platform, &mut object_module)?; Ok(Program { - _compiled: ir.into_cranelift(&mut object_module, &rtfuns)?, + _func_id: ir.into_cranelift::( + &mut object_module, + "gogogo", + &rtfuns, + &HashMap::new(), + &HashMap::new(), + )?, module: object_module, }) } diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index 698d0ce..bac7d3f 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -4,37 +4,51 @@ use crate::backend::runtime::RuntimeFunctions; use crate::ir::{Expression, Primitive, Program, Statement, Value, ValueOrRef}; use cranelift_codegen::entity::EntityRef; use cranelift_codegen::ir::{ - entities, types, Function, GlobalValue, InstBuilder, Signature, UserFuncName, + entities, types, Function, GlobalValue, InstBuilder, Signature, UserFuncName, MemFlags, }; use cranelift_codegen::isa::CallConv; use cranelift_codegen::Context; use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Variable}; -use cranelift_module::{DataContext, Linkage, Module, ModuleCompiledFunction, ModuleError}; +use cranelift_module::{DataContext, FuncId, Linkage, Module, ModuleError, DataId}; use internment::ArcIntern; +use super::RuntimeFunctionError; + type StringTable = HashMap, GlobalValue>; impl Program { - pub fn into_cranelift( + pub fn into_cranelift( mut self, module: &mut M, + function_name: &str, rtfuns: &RuntimeFunctions, - ) -> Result { + pre_defined_strings: &HashMap, + pre_defined_symbols: &HashMap, + ) -> Result + where + E: From, + E: From, + M: Module, + { let basic_signature = Signature { params: vec![], returns: vec![], call_conv: CallConv::SystemV, }; - let func_id = module.declare_function("gogogo", Linkage::Export, &basic_signature)?; + let func_id = module.declare_function(function_name, Linkage::Export, &basic_signature)?; let mut ctx = Context::new(); ctx.func = Function::with_name_signature(UserFuncName::user(0, func_id.as_u32()), basic_signature); - let string_table = self.build_string_table(module, &mut ctx.func)?; + let string_table = self.build_string_table(module, &mut ctx.func, pre_defined_strings)?; let mut variable_table = HashMap::new(); let mut next_var_num = 1; let print_func_ref = rtfuns.include_runtime_function("print", module, &mut ctx.func)?; + let pre_defined_symbols: HashMap = pre_defined_symbols.iter().map(|(k, v)| { + let local_data = module.declare_data_in_func(*v, &mut ctx.func); + (k.clone(), local_data) + }).collect(); let mut fctx = FunctionBuilderContext::new(); let mut builder = FunctionBuilder::new(&mut ctx.func, &mut fctx); @@ -43,21 +57,15 @@ impl Program { for stmt in self.statements.drain(..) { match stmt { - Statement::Print(_ann, var) => { + Statement::Print(ann, var) => { let local_name_ref = string_table.get(&var).unwrap(); let name_ptr = builder.ins().symbol_value(types::I64, *local_name_ref); - let value_var_num = variable_table.get(&var).unwrap(); - let val = builder.use_var(Variable::new(*value_var_num)); + let val = ValueOrRef::Ref(ann, var).into_cranelift(&mut builder, &variable_table, &pre_defined_symbols)?; builder.ins().call(print_func_ref, &[name_ptr, val]); } Statement::Binding(_, var_name, value) => { - let var = Variable::new(next_var_num); - variable_table.insert(var_name, next_var_num); - next_var_num += 1; - builder.declare_var(var, types::I64); - - let val = match value { + let val = match value { Expression::Value(_, Value::Number(_, v)) => { builder.ins().iconst(types::I64, v) } @@ -71,11 +79,11 @@ impl Program { let right = vals .pop() .unwrap() - .into_cranelift(&mut builder, &variable_table); + .into_cranelift(&mut builder, &variable_table, &pre_defined_symbols)?; let left = vals .pop() .unwrap() - .into_cranelift(&mut builder, &variable_table); + .into_cranelift(&mut builder, &variable_table, &pre_defined_symbols)?; match prim { Primitive::Plus => builder.ins().iadd(left, right), @@ -86,7 +94,16 @@ impl Program { } }; - builder.def_var(var, val); + if let Some(global_id) = pre_defined_symbols.get(var_name.as_str()) { + let val_ptr = builder.ins().symbol_value(types::I64, *global_id); + builder.ins().store(MemFlags::new(), val, val_ptr, 0); + } else { + let var = Variable::new(next_var_num); + variable_table.insert(var_name, next_var_num); + next_var_num += 1; + builder.declare_var(var, types::I64); + builder.def_var(var, val); + } } } } @@ -95,33 +112,42 @@ impl Program { builder.seal_block(main_block); builder.finalize(); - Ok(module.define_function(func_id, &mut ctx)?) + let _ = module.define_function(func_id, &mut ctx)?; + + Ok(func_id) } fn build_string_table( &self, module: &mut M, func: &mut Function, + pre_defined_strings: &HashMap, ) -> Result { let mut string_table = HashMap::new(); for (idx, interned_value) in self.strings().drain().enumerate() { - let global_id = module.declare_data( - &format!("local-string-{}", idx), - Linkage::Local, - false, - false, - )?; - let mut data_context = DataContext::new(); - data_context.set_align(8); - data_context.define( - interned_value - .as_str() - .to_owned() - .into_boxed_str() - .into_boxed_bytes(), - ); - module.define_data(global_id, &data_context)?; + let global_id = match pre_defined_strings.get(interned_value.as_str()) { + Some(x) => *x, + None => { + let global_id = module.declare_data( + &format!("local-string-{}", idx), + Linkage::Local, + false, + false, + )?; + let mut data_context = DataContext::new(); + data_context.set_align(8); + data_context.define( + interned_value + .as_str() + .to_owned() + .into_boxed_str() + .into_boxed_bytes(), + ); + module.define_data(global_id, &data_context)?; + global_id + } + }; let local_data = module.declare_data_in_func(global_id, func); string_table.insert(interned_value, local_data); } @@ -134,16 +160,25 @@ impl ValueOrRef { fn into_cranelift( self, builder: &mut FunctionBuilder, - varmap: &HashMap, usize>, - ) -> entities::Value { + local_variables: &HashMap, usize>, + global_variables: &HashMap, + ) -> Result { match self { ValueOrRef::Value(_, value) => match value { - Value::Number(_base, numval) => builder.ins().iconst(types::I64, numval), + Value::Number(_base, numval) => Ok(builder.ins().iconst(types::I64, numval)), }, ValueOrRef::Ref(_, name) => { - let num = varmap.get(&name).unwrap(); - builder.use_var(Variable::new(*num)) + if let Some(local_num) = local_variables.get(&name) { + return Ok(builder.use_var(Variable::new(*local_num))); + } + + if let Some(global_id) = global_variables.get(name.as_str()) { + let val_ptr = builder.ins().symbol_value(types::I64, *global_id); + return Ok(builder.ins().load(types::I64, MemFlags::new(), val_ptr, 0)) + } + + Err(ModuleError::Undeclared(name.to_string())) } } } diff --git a/src/bin.rs b/src/bin/ngrc.rs similarity index 100% rename from src/bin.rs rename to src/bin/ngrc.rs diff --git a/src/bin/ngri.rs b/src/bin/ngri.rs new file mode 100644 index 0000000..b2ad037 --- /dev/null +++ b/src/bin/ngri.rs @@ -0,0 +1,148 @@ +use codespan_reporting::diagnostic::Diagnostic; +use codespan_reporting::files::SimpleFiles; +use codespan_reporting::term::{self, Config}; +use ngr::ir::Program as IR; +use ngr::jit::{JITEngine, JITError}; +use ngr::syntax::{Location, ParserError, Statement}; +use pretty::termcolor::{ColorChoice, StandardStream, WriteColor}; +use rustyline::error::ReadlineError; +use rustyline::DefaultEditor; +use std::collections::HashMap; + +pub struct RunLoop<'a> { + file_database: SimpleFiles<&'a str, String>, + jitter: JITEngine, + variable_binding_sites: HashMap, + gensym_index: usize, + writer: &'a mut dyn WriteColor, + config: Config, +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, thiserror::Error)] +enum REPLError { + #[error("Error parsing statement: {0}")] + Parser(#[from] ParserError), + #[error("JIT error: {0}")] + JIT(#[from] JITError), + #[error(transparent)] + Reporting(#[from] codespan_reporting::files::Error), +} + +impl From for Diagnostic { + fn from(value: REPLError) -> Self { + match value { + REPLError::Parser(err) => Diagnostic::from(&err), + REPLError::JIT(err) => Diagnostic::from(err), + REPLError::Reporting(err) => Diagnostic::bug().with_message(format!("{}", err)), + } + } +} + +impl<'a> RunLoop<'a> { + pub fn new(writer: &'a mut dyn WriteColor, config: Config) -> Result { + Ok(RunLoop { + file_database: SimpleFiles::new(), + jitter: JITEngine::new()?, + variable_binding_sites: HashMap::new(), + gensym_index: 1, + writer, + config, + }) + } + + fn emit_diagnostic( + &mut self, + diagnostic: Diagnostic, + ) -> Result<(), codespan_reporting::files::Error> { + term::emit(self.writer, &self.config, &self.file_database, &diagnostic) + } + + fn process_input(&mut self, line_no: usize, command: String) { + if let Err(err) = self.process(line_no, command) { + if let Err(e) = self.emit_diagnostic(Diagnostic::from(err)) { + eprintln!( + "WOAH! System having trouble printing error messages. This is very bad. ({})", + e + ); + } + } + } + + fn process(&mut self, line_no: usize, command: String) -> Result<(), REPLError> { + let entry = self.file_database.add("entry", command); + let source = self + .file_database + .get(entry) + .expect("entry exists") + .source(); + let syntax = Statement::parse(entry, source)?; + + // if this is a variable binding, and we've never defined this variable before, + // we should tell cranelift about it. this is optimistic; if we fail to compile, + // then we won't use this definition until someone tries again. + if let Statement::Binding(_, ref name, _) = syntax { + if !self.variable_binding_sites.contains_key(name.as_str()) { + self.jitter.define_string(name.clone())?; + self.jitter.define_variable(name.clone())?; + } + }; + + let (mut errors, mut warnings) = syntax.validate(&mut self.variable_binding_sites); + let stop = !errors.is_empty(); + let messages = errors + .drain(..) + .map(Into::into) + .chain(warnings.drain(..).map(Into::into)); + + for message in messages { + self.emit_diagnostic(message)?; + } + + if stop { + return Ok(()); + } + + let ir = IR::from(syntax.simplify(&mut self.gensym_index)); + let compiled = self.jitter.compile(line_no, ir)?; + compiled(); + Ok(()) + } +} + +fn main() -> Result<(), JITError> { + let mut editor = DefaultEditor::new().expect("rustyline works"); + let mut line_no = 0; + let mut writer = StandardStream::stdout(ColorChoice::Auto); + let config = codespan_reporting::term::Config::default(); + let mut state = RunLoop::new(&mut writer, config)?; + + println!("No Good Reason, the Interpreter!"); + loop { + line_no += 1; + match editor.readline("> ") { + Ok(command) => match command.trim() { + "" => continue, + ":quit" => break, + _ => state.process_input(line_no, command), + }, + Err(ReadlineError::Io(e)) => { + eprintln!("IO error: {}", e); + break; + } + Err(ReadlineError::Eof) => break, + Err(ReadlineError::Interrupted) => break, + Err(ReadlineError::Errno(e)) => { + eprintln!("Unknown syscall error: {}", e); + break; + } + Err(ReadlineError::WindowResized) => continue, + Err(e) => { + eprintln!("Unknown internal error: {}", e); + break; + } + } + } + + Ok(()) +} diff --git a/src/ir/from_syntax.rs b/src/ir/from_syntax.rs index 73de9d0..c4d710d 100644 --- a/src/ir/from_syntax.rs +++ b/src/ir/from_syntax.rs @@ -11,6 +11,14 @@ impl From for ir::Program { } } +impl From> for ir::Program { + fn from(mut value: Vec) -> Self { + ir::Program { + statements: value.drain(..).map(Into::into).collect(), + } + } +} + impl From for ir::Statement { fn from(value: syntax::Statement) -> Self { match value { diff --git a/src/jit.rs b/src/jit.rs new file mode 100644 index 0000000..eb03687 --- /dev/null +++ b/src/jit.rs @@ -0,0 +1,32 @@ +use crate::backend::RuntimeFunctionError; +use codespan_reporting::diagnostic::Diagnostic; +use cranelift_codegen::{isa, settings::SetError, CodegenError}; +use cranelift_module::ModuleError; + +pub mod engine; + +pub use self::engine::JITEngine; + +#[derive(Debug, thiserror::Error)] +pub enum JITError { + #[error("JIT code generation error: {0}")] + Codegen(#[from] CodegenError), + + #[error("JIT configuration flag error: {0}")] + Set(#[from] SetError), + + #[error("ISA lookup error: {0}")] + Lookup(#[from] isa::LookupError), + + #[error("Cranelift module error: {0}")] + Cranelift(#[from] ModuleError), + + #[error("Runtime function error: {0}")] + Runtime(#[from] RuntimeFunctionError), +} + +impl From for Diagnostic { + fn from(value: JITError) -> Self { + Diagnostic::bug().with_message(format!("{}", value)) + } +} diff --git a/src/jit/engine.rs b/src/jit/engine.rs new file mode 100644 index 0000000..5601406 --- /dev/null +++ b/src/jit/engine.rs @@ -0,0 +1,90 @@ +use crate::backend::RuntimeFunctions; +use crate::ir::Program as IR; +use crate::jit::JITError; +use cranelift_codegen::{ + isa, + settings::{self, Configurable}, +}; +use cranelift_jit::{JITBuilder, JITModule}; +use cranelift_module::{DataContext, DataId, Linkage, Module}; +use std::{collections::HashMap, ffi::{CString, CStr}}; +use target_lexicon::Triple; + +const EMPTY_DATUM: [u8; 8] = [0; 8]; + +pub struct JITEngine { + data_ctx: DataContext, + module: JITModule, + runtime_functions: RuntimeFunctions, + defined_strings: HashMap, + defined_symbols: HashMap, +} + +extern fn runtime_print(name: *const i8, value: u64) { + let cstr = unsafe { CStr::from_ptr(name) }; + let reconstituted = cstr.to_string_lossy(); + println!("{} = {}", reconstituted, value); +} + +impl JITEngine { + pub fn new() -> Result { + let platform = Triple::host(); + let isa_builder = isa::lookup(platform.clone())?; + let mut settings_builder = settings::builder(); + settings_builder.set("use_colocated_libcalls", "false")?; + settings_builder.set("is_pic", "false")?; + let isa = isa_builder.finish(settings::Flags::new(settings_builder))?; + let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names()); + + builder.symbol("print", runtime_print as *const u8); + + let mut module = JITModule::new(builder); + let runtime_functions = RuntimeFunctions::new(&platform, &mut module)?; + + Ok(JITEngine { + data_ctx: DataContext::new(), + module, + runtime_functions, + defined_strings: HashMap::new(), + defined_symbols: HashMap::new(), + }) + } + + pub fn define_string(&mut self, s: String) -> Result<(), JITError> { + let name = format!("{}",s); + let global_id = self.module.declare_data(&name, Linkage::Local, false, false)?; + let mut data_context = DataContext::new(); + data_context.set_align(8); + data_context.define(s.as_str().to_owned().into_boxed_str().into_boxed_bytes()); + self.module.define_data(global_id, &data_context)?; + self.defined_strings.insert(s, global_id); + Ok(()) + } + + pub fn define_variable(&mut self, name: String) -> Result<(), JITError> { + self.data_ctx.define(Box::new(EMPTY_DATUM)); + let id = self + .module + .declare_data(&name, Linkage::Export, true, false)?; + self.module.define_data(id, &self.data_ctx)?; + self.data_ctx.clear(); + self.module.finalize_definitions()?; + self.defined_symbols.insert(name, id); + Ok(()) + } + + pub fn compile(&mut self, line: usize, program: IR) -> Result (), JITError> { + let function_name = format!("line{}", line); + let function_id = program.into_cranelift::( + &mut self.module, + &function_name, + &self.runtime_functions, + &self.defined_strings, + &self.defined_symbols, + )?; + self.module.finalize_definitions()?; + let code_ptr = self.module.get_finalized_function(function_id); + + unsafe { Ok(std::mem::transmute::<_, fn() -> ()>(code_ptr)) } + } +} diff --git a/src/lib.rs b/src/lib.rs index 6ed733f..765542b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ pub mod backend; pub mod ir; +pub mod jit; pub mod syntax; diff --git a/src/syntax.rs b/src/syntax.rs index de64c7d..05b586e 100644 --- a/src/syntax.rs +++ b/src/syntax.rs @@ -1,4 +1,4 @@ -use codespan_reporting::{files::SimpleFiles, diagnostic::Diagnostic}; +use codespan_reporting::{diagnostic::Diagnostic, files::SimpleFiles}; use lalrpop_util::lalrpop_mod; use logos::Logos; @@ -22,6 +22,8 @@ use lalrpop_util::ParseError; use std::str::FromStr; use thiserror::Error; +use self::parser::StatementParser; + #[derive(Debug, Error)] pub enum ParserError { #[error("Invalid token")] @@ -178,6 +180,17 @@ impl Program { } } +impl Statement { + pub fn parse(file_idx: usize, buffer: &str) -> Result { + let lexer = Token::lexer(buffer) + .spanned() + .map(|(token, range)| (range.start, token, range.end)); + StatementParser::new() + .parse(file_idx, lexer) + .map_err(|e| ParserError::convert(file_idx, e)) + } +} + #[cfg(test)] impl FromStr for Program { type Err = ParserError; diff --git a/src/syntax/parser.lalrpop b/src/syntax/parser.lalrpop index cfa2553..12cd2c7 100644 --- a/src/syntax/parser.lalrpop +++ b/src/syntax/parser.lalrpop @@ -41,7 +41,7 @@ Statements: Vec = { } } -Statement: Statement = { +pub Statement: Statement = { "> "=" ";" => Statement::Binding(Location::new(file_idx, l), v.to_string(), e), "print" "> ";" => Statement::Print(Location::new(file_idx, l), v.to_string()), } diff --git a/src/syntax/simplify.rs b/src/syntax/simplify.rs index 58574c0..28ad377 100644 --- a/src/syntax/simplify.rs +++ b/src/syntax/simplify.rs @@ -6,16 +6,7 @@ impl Program { let mut gensym_index = 1; for stmt in self.statements.drain(..) { - match stmt { - Statement::Print(_, _) => new_statements.push(stmt), - Statement::Binding(_, _, Expression::Reference(_, _)) => new_statements.push(stmt), - Statement::Binding(_, _, Expression::Value(_, _)) => new_statements.push(stmt), - Statement::Binding(loc, name, value) => { - let (mut prereqs, new_value) = value.rebind(&name, &mut gensym_index); - new_statements.append(&mut prereqs); - new_statements.push(Statement::Binding(loc, name, new_value)) - } - } + new_statements.append(&mut stmt.simplify(&mut gensym_index)); } self.statements = new_statements; @@ -23,6 +14,25 @@ impl Program { } } +impl Statement { + pub fn simplify(self, gensym_index: &mut usize) -> Vec { + let mut new_statements = vec![]; + + match self { + Statement::Print(_, _) => new_statements.push(self), + Statement::Binding(_, _, Expression::Reference(_, _)) => new_statements.push(self), + Statement::Binding(_, _, Expression::Value(_, _)) => new_statements.push(self), + Statement::Binding(loc, name, value) => { + let (mut prereqs, new_value) = value.rebind(&name, gensym_index); + new_statements.append(&mut prereqs); + new_statements.push(Statement::Binding(loc, name, new_value)) + } + } + + new_statements + } +} + impl Expression { fn rebind(self, base_name: &str, gensym_index: &mut usize) -> (Vec, Expression) { match self { diff --git a/src/syntax/validate.rs b/src/syntax/validate.rs index 32cefb4..da2410c 100644 --- a/src/syntax/validate.rs +++ b/src/syntax/validate.rs @@ -43,31 +43,47 @@ impl Program { let mut bound_variables = HashMap::new(); for stmt in self.statements.iter() { - match stmt { - Statement::Binding(loc, var, val) => { - // we're going to make the decision that a variable is not bound in the right - // hand side of its binding, which makes a lot of things easier. So we'll just - // immediately check the expression, and go from there. - let (mut exp_errors, mut exp_warnings) = val.validate(&bound_variables); + let (mut new_errors, mut new_warnings) = stmt.validate(&mut bound_variables); + errors.append(&mut new_errors); + warnings.append(&mut new_warnings); + } - errors.append(&mut exp_errors); - warnings.append(&mut exp_warnings); - if let Some(original_binding_site) = bound_variables.get(var) { - warnings.push(Warning::ShadowedVariable( - original_binding_site.clone(), - loc.clone(), - var.clone(), - )); - } else { - bound_variables.insert(var.clone(), loc.clone()); - } - } + (errors, warnings) + } +} - Statement::Print(_, var) if bound_variables.contains_key(var) => {} - Statement::Print(loc, var) => { - errors.push(Error::UnboundVariable(loc.clone(), var.clone())) +impl Statement { + pub fn validate( + &self, + bound_variables: &mut HashMap, + ) -> (Vec, Vec) { + let mut errors = vec![]; + let mut warnings = vec![]; + + match self { + Statement::Binding(loc, var, val) => { + // we're going to make the decision that a variable is not bound in the right + // hand side of its binding, which makes a lot of things easier. So we'll just + // immediately check the expression, and go from there. + let (mut exp_errors, mut exp_warnings) = val.validate(bound_variables); + + errors.append(&mut exp_errors); + warnings.append(&mut exp_warnings); + if let Some(original_binding_site) = bound_variables.get(var) { + warnings.push(Warning::ShadowedVariable( + original_binding_site.clone(), + loc.clone(), + var.clone(), + )); + } else { + bound_variables.insert(var.clone(), loc.clone()); } } + + Statement::Print(_, var) if bound_variables.contains_key(var) => {} + Statement::Print(loc, var) => { + errors.push(Error::UnboundVariable(loc.clone(), var.clone())) + } } (errors, warnings)