🧪 Add evaluation tests to ensure that passes retain NGR semantics. #2

Merged
acw merged 6 commits from acw/eval-tests into develop 2023-04-16 16:07:45 -07:00
7 changed files with 92 additions and 11 deletions
Showing only changes of commit d455ee87b5 - Show all commits

View File

@@ -1,12 +1,12 @@
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
void print(char *variable_name, uint64_t value) { void print(char *_ignore, char *variable_name, uint64_t value) {
printf("%s = %llii64\n", variable_name, value); printf("%s = %llii64\n", variable_name, value);
} }
void caller() { void caller() {
print("x", 4); print(NULL, "x", 4);
} }
extern void gogogo(); extern void gogogo();

View File

@@ -22,10 +22,11 @@ pub struct Backend<M: Module> {
runtime_functions: RuntimeFunctions, runtime_functions: RuntimeFunctions,
defined_strings: HashMap<String, DataId>, defined_strings: HashMap<String, DataId>,
defined_symbols: HashMap<String, DataId>, defined_symbols: HashMap<String, DataId>,
output_buffer: Option<String>,
} }
impl Backend<JITModule> { impl Backend<JITModule> {
pub fn jit() -> Result<Self, BackendError> { pub fn jit(output_buffer: Option<String>) -> Result<Self, BackendError> {
let platform = Triple::host(); let platform = Triple::host();
let isa_builder = isa::lookup(platform.clone())?; let isa_builder = isa::lookup(platform.clone())?;
let mut settings_builder = settings::builder(); let mut settings_builder = settings::builder();
@@ -45,6 +46,7 @@ impl Backend<JITModule> {
runtime_functions, runtime_functions,
defined_strings: HashMap::new(), defined_strings: HashMap::new(),
defined_symbols: HashMap::new(), defined_symbols: HashMap::new(),
output_buffer,
}) })
} }
@@ -70,6 +72,7 @@ impl Backend<ObjectModule> {
runtime_functions, runtime_functions,
defined_strings: HashMap::new(), defined_strings: HashMap::new(),
defined_symbols: HashMap::new(), defined_symbols: HashMap::new(),
output_buffer: None,
}) })
} }
@@ -104,4 +107,20 @@ impl<M: Module> Backend<M> {
self.defined_symbols.insert(name, id); self.defined_symbols.insert(name, id);
Ok(id) Ok(id)
} }
pub fn output_buffer_ptr(&mut self) -> *mut String {
if let Some(str) = self.output_buffer.as_mut() {
str as *mut String
} else {
std::ptr::null_mut()
}
}
pub fn output(self) -> String {
if let Some(s) = self.output_buffer {
s
} else {
String::new()
}
}
} }

View File

@@ -8,8 +8,14 @@ use cranelift_object::ObjectModule;
use target_lexicon::Triple; use target_lexicon::Triple;
impl Backend<JITModule> { impl Backend<JITModule> {
pub fn eval(_program: Program) -> Result<String, EvalError> { pub fn eval(program: Program) -> Result<String, EvalError> {
unimplemented!() let mut jitter = Backend::jit(Some(String::new()))?;
let function_id = jitter.compile_function("test", program)?;
jitter.module.finalize_definitions()?;
let compiled_bytes = jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function();
Ok(jitter.output())
} }
} }
@@ -86,4 +92,16 @@ proptest::proptest! {
assert_eq!(basic_result, compiled_result); assert_eq!(basic_result, compiled_result);
} }
} }
#[test]
fn jit_backend_works(program: Program) {
use crate::eval::PrimOpError;
let basic_result = program.eval();
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
let compiled_result = Backend::<JITModule>::eval(program);
assert_eq!(basic_result, compiled_result);
}
}
} }

View File

@@ -60,6 +60,8 @@ impl<M: Module> Backend<M> {
for stmt in program.statements.drain(..) { for stmt in program.statements.drain(..) {
match stmt { match stmt {
Statement::Print(ann, var) => { Statement::Print(ann, var) => {
let buffer_ptr = self.output_buffer_ptr();
let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64);
let local_name_ref = string_table.get(&var).unwrap(); let local_name_ref = string_table.get(&var).unwrap();
let name_ptr = builder.ins().symbol_value(types::I64, *local_name_ref); let name_ptr = builder.ins().symbol_value(types::I64, *local_name_ref);
let val = ValueOrRef::Ref(ann, var).into_cranelift( let val = ValueOrRef::Ref(ann, var).into_cranelift(
@@ -67,7 +69,9 @@ impl<M: Module> Backend<M> {
&variable_table, &variable_table,
&pre_defined_symbols, &pre_defined_symbols,
)?; )?;
builder.ins().call(print_func_ref, &[name_ptr, val]); builder
.ins()
.call(print_func_ref, &[buffer_ptr, name_ptr, val]);
} }
Statement::Binding(_, var_name, value) => { Statement::Binding(_, var_name, value) => {

View File

@@ -4,6 +4,7 @@ use cranelift_jit::JITBuilder;
use cranelift_module::{FuncId, Linkage, Module, ModuleResult}; use cranelift_module::{FuncId, Linkage, Module, ModuleResult};
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::CStr; use std::ffi::CStr;
use std::fmt::Write;
use target_lexicon::Triple; use target_lexicon::Triple;
use thiserror::Error; use thiserror::Error;
@@ -18,11 +19,16 @@ pub enum RuntimeFunctionError {
CannotFindRuntimeFunction(String), CannotFindRuntimeFunction(String),
} }
extern "C" fn runtime_print(name: *const i8, value: u64) { extern "C" fn runtime_print(output_buffer: *mut String, name: *const i8, value: i64) {
let cstr = unsafe { CStr::from_ptr(name) }; let cstr = unsafe { CStr::from_ptr(name) };
let reconstituted = cstr.to_string_lossy(); let reconstituted = cstr.to_string_lossy();
if let Some(output_buffer) = unsafe { output_buffer.as_mut() } {
writeln!(output_buffer, "{} = {}i64", reconstituted, value).unwrap();
} else {
println!("{} = {}", reconstituted, value); println!("{} = {}", reconstituted, value);
} }
}
impl RuntimeFunctions { impl RuntimeFunctions {
pub fn new<M: Module>(platform: &Triple, module: &mut M) -> ModuleResult<RuntimeFunctions> { pub fn new<M: Module>(platform: &Triple, module: &mut M) -> ModuleResult<RuntimeFunctions> {
@@ -36,7 +42,7 @@ impl RuntimeFunctions {
"print", "print",
Linkage::Import, Linkage::Import,
&Signature { &Signature {
params: vec![string_param, int64_param], params: vec![string_param, string_param, int64_param],
returns: vec![], returns: vec![],
call_conv: CallConv::triple_default(platform), call_conv: CallConv::triple_default(platform),
}, },

View File

@@ -48,7 +48,7 @@ impl<'a> RunLoop<'a> {
pub fn new(writer: &'a mut dyn WriteColor, config: Config) -> Result<Self, BackendError> { pub fn new(writer: &'a mut dyn WriteColor, config: Config) -> Result<Self, BackendError> {
Ok(RunLoop { Ok(RunLoop {
file_database: SimpleFiles::new(), file_database: SimpleFiles::new(),
jitter: Backend::jit()?, jitter: Backend::jit(None)?,
variable_binding_sites: HashMap::new(), variable_binding_sites: HashMap::new(),
gensym_index: 1, gensym_index: 1,
writer, writer,

View File

@@ -2,13 +2,14 @@ mod env;
mod primop; mod primop;
mod value; mod value;
use cranelift_module::ModuleError;
pub use env::{EvalEnvironment, LookupError}; pub use env::{EvalEnvironment, LookupError};
pub use primop::PrimOpError; pub use primop::PrimOpError;
pub use value::Value; pub use value::Value;
use crate::backend::BackendError; use crate::backend::BackendError;
#[derive(Debug, PartialEq, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum EvalError { pub enum EvalError {
#[error(transparent)] #[error(transparent)]
Lookup(#[from] LookupError), Lookup(#[from] LookupError),
@@ -18,6 +19,8 @@ pub enum EvalError {
Backend(#[from] BackendError), Backend(#[from] BackendError),
#[error("IO error: {0}")] #[error("IO error: {0}")]
IO(String), IO(String),
#[error(transparent)]
Module(#[from] ModuleError),
} }
impl From<std::io::Error> for EvalError { impl From<std::io::Error> for EvalError {
@@ -25,3 +28,34 @@ impl From<std::io::Error> for EvalError {
EvalError::IO(value.to_string()) EvalError::IO(value.to_string())
} }
} }
impl PartialEq for EvalError {
fn eq(&self, other: &Self) -> bool {
match self {
EvalError::Lookup(a) => match other {
EvalError::Lookup(b) => a == b,
_ => false,
},
EvalError::PrimOp(a) => match other {
EvalError::PrimOp(b) => a == b,
_ => false,
},
EvalError::Backend(a) => match other {
EvalError::Backend(b) => a == b,
_ => false,
},
EvalError::IO(a) => match other {
EvalError::IO(b) => a == b,
_ => false,
},
EvalError::Module(a) => match other {
EvalError::Module(b) => a.to_string() == b.to_string(),
_ => false,
},
}
}
}