From d455ee87b55804a380e51c8c8dd20424a55c155e Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Thu, 13 Apr 2023 21:04:43 -0700 Subject: [PATCH] Add a proptest for the JIT backend. --- runtime/rts.c | 4 ++-- src/backend.rs | 21 ++++++++++++++++++++- src/backend/eval.rs | 22 ++++++++++++++++++++-- src/backend/into_crane.rs | 6 +++++- src/backend/runtime.rs | 12 +++++++++--- src/bin/ngri.rs | 2 +- src/eval.rs | 36 +++++++++++++++++++++++++++++++++++- 7 files changed, 92 insertions(+), 11 deletions(-) diff --git a/runtime/rts.c b/runtime/rts.c index c7c87fa..0ddd90f 100644 --- a/runtime/rts.c +++ b/runtime/rts.c @@ -1,12 +1,12 @@ #include #include -void print(char *variable_name, uint64_t value) { +void print(char *_ignore, char *variable_name, uint64_t value) { printf("%s = %llii64\n", variable_name, value); } void caller() { - print("x", 4); + print(NULL, "x", 4); } extern void gogogo(); diff --git a/src/backend.rs b/src/backend.rs index 24d71e6..f70a52f 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -22,10 +22,11 @@ pub struct Backend { runtime_functions: RuntimeFunctions, defined_strings: HashMap, defined_symbols: HashMap, + output_buffer: Option, } impl Backend { - pub fn jit() -> Result { + pub fn jit(output_buffer: Option) -> Result { let platform = Triple::host(); let isa_builder = isa::lookup(platform.clone())?; let mut settings_builder = settings::builder(); @@ -45,6 +46,7 @@ impl Backend { runtime_functions, defined_strings: HashMap::new(), defined_symbols: HashMap::new(), + output_buffer, }) } @@ -70,6 +72,7 @@ impl Backend { runtime_functions, defined_strings: HashMap::new(), defined_symbols: HashMap::new(), + output_buffer: None, }) } @@ -104,4 +107,20 @@ impl Backend { self.defined_symbols.insert(name, id); Ok(id) } + + pub fn output_buffer_ptr(&mut self) -> *mut String { + if let Some(str) = self.output_buffer.as_mut() { + str as *mut String + } else { + std::ptr::null_mut() + } + } + + pub fn output(self) -> String { + if let Some(s) = self.output_buffer { + s + } else { + String::new() + } + } } diff --git a/src/backend/eval.rs b/src/backend/eval.rs index fef25bc..950a5a1 100644 --- a/src/backend/eval.rs +++ b/src/backend/eval.rs @@ -8,8 +8,14 @@ use cranelift_object::ObjectModule; use target_lexicon::Triple; impl Backend { - pub fn eval(_program: Program) -> Result { - unimplemented!() + pub fn eval(program: Program) -> Result { + let mut jitter = Backend::jit(Some(String::new()))?; + let function_id = jitter.compile_function("test", program)?; + jitter.module.finalize_definitions()?; + let compiled_bytes = jitter.bytes(function_id); + let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; + compiled_function(); + Ok(jitter.output()) } } @@ -86,4 +92,16 @@ proptest::proptest! { assert_eq!(basic_result, compiled_result); } } + + #[test] + fn jit_backend_works(program: Program) { + use crate::eval::PrimOpError; + + let basic_result = program.eval(); + + if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) { + let compiled_result = Backend::::eval(program); + assert_eq!(basic_result, compiled_result); + } + } } diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index 8605feb..6ceff02 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -60,6 +60,8 @@ impl Backend { for stmt in program.statements.drain(..) { match stmt { Statement::Print(ann, var) => { + let buffer_ptr = self.output_buffer_ptr(); + let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); let local_name_ref = string_table.get(&var).unwrap(); let name_ptr = builder.ins().symbol_value(types::I64, *local_name_ref); let val = ValueOrRef::Ref(ann, var).into_cranelift( @@ -67,7 +69,9 @@ impl Backend { &variable_table, &pre_defined_symbols, )?; - builder.ins().call(print_func_ref, &[name_ptr, val]); + builder + .ins() + .call(print_func_ref, &[buffer_ptr, name_ptr, val]); } Statement::Binding(_, var_name, value) => { diff --git a/src/backend/runtime.rs b/src/backend/runtime.rs index 01e00bc..1338a73 100644 --- a/src/backend/runtime.rs +++ b/src/backend/runtime.rs @@ -4,6 +4,7 @@ use cranelift_jit::JITBuilder; use cranelift_module::{FuncId, Linkage, Module, ModuleResult}; use std::collections::HashMap; use std::ffi::CStr; +use std::fmt::Write; use target_lexicon::Triple; use thiserror::Error; @@ -18,10 +19,15 @@ pub enum RuntimeFunctionError { CannotFindRuntimeFunction(String), } -extern "C" fn runtime_print(name: *const i8, value: u64) { +extern "C" fn runtime_print(output_buffer: *mut String, name: *const i8, value: i64) { let cstr = unsafe { CStr::from_ptr(name) }; let reconstituted = cstr.to_string_lossy(); - println!("{} = {}", reconstituted, value); + + if let Some(output_buffer) = unsafe { output_buffer.as_mut() } { + writeln!(output_buffer, "{} = {}i64", reconstituted, value).unwrap(); + } else { + println!("{} = {}", reconstituted, value); + } } impl RuntimeFunctions { @@ -36,7 +42,7 @@ impl RuntimeFunctions { "print", Linkage::Import, &Signature { - params: vec![string_param, int64_param], + params: vec![string_param, string_param, int64_param], returns: vec![], call_conv: CallConv::triple_default(platform), }, diff --git a/src/bin/ngri.rs b/src/bin/ngri.rs index c78179f..b1d74de 100644 --- a/src/bin/ngri.rs +++ b/src/bin/ngri.rs @@ -48,7 +48,7 @@ impl<'a> RunLoop<'a> { pub fn new(writer: &'a mut dyn WriteColor, config: Config) -> Result { Ok(RunLoop { file_database: SimpleFiles::new(), - jitter: Backend::jit()?, + jitter: Backend::jit(None)?, variable_binding_sites: HashMap::new(), gensym_index: 1, writer, diff --git a/src/eval.rs b/src/eval.rs index edb2ca9..b764eb5 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -2,13 +2,14 @@ mod env; mod primop; mod value; +use cranelift_module::ModuleError; pub use env::{EvalEnvironment, LookupError}; pub use primop::PrimOpError; pub use value::Value; use crate::backend::BackendError; -#[derive(Debug, PartialEq, thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub enum EvalError { #[error(transparent)] Lookup(#[from] LookupError), @@ -18,6 +19,8 @@ pub enum EvalError { Backend(#[from] BackendError), #[error("IO error: {0}")] IO(String), + #[error(transparent)] + Module(#[from] ModuleError), } impl From for EvalError { @@ -25,3 +28,34 @@ impl From for EvalError { EvalError::IO(value.to_string()) } } + +impl PartialEq for EvalError { + fn eq(&self, other: &Self) -> bool { + match self { + EvalError::Lookup(a) => match other { + EvalError::Lookup(b) => a == b, + _ => false, + }, + + EvalError::PrimOp(a) => match other { + EvalError::PrimOp(b) => a == b, + _ => false, + }, + + EvalError::Backend(a) => match other { + EvalError::Backend(b) => a == b, + _ => false, + }, + + EvalError::IO(a) => match other { + EvalError::IO(b) => a == b, + _ => false, + }, + + EvalError::Module(a) => match other { + EvalError::Module(b) => a.to_string() == b.to_string(), + _ => false, + }, + } + } +}