🧪 Add evaluation tests to ensure that passes retain NGR semantics. #2

Merged
acw merged 6 commits from acw/eval-tests into develop 2023-04-16 16:07:45 -07:00
7 changed files with 92 additions and 11 deletions
Showing only changes of commit d455ee87b5 - Show all commits

View File

@@ -1,12 +1,12 @@
#include <stdint.h>
#include <stdio.h>
void print(char *variable_name, uint64_t value) {
void print(char *_ignore, char *variable_name, uint64_t value) {
printf("%s = %llii64\n", variable_name, value);
}
void caller() {
print("x", 4);
print(NULL, "x", 4);
}
extern void gogogo();

View File

@@ -22,10 +22,11 @@ pub struct Backend<M: Module> {
runtime_functions: RuntimeFunctions,
defined_strings: HashMap<String, DataId>,
defined_symbols: HashMap<String, DataId>,
output_buffer: Option<String>,
}
impl Backend<JITModule> {
pub fn jit() -> Result<Self, BackendError> {
pub fn jit(output_buffer: Option<String>) -> Result<Self, BackendError> {
let platform = Triple::host();
let isa_builder = isa::lookup(platform.clone())?;
let mut settings_builder = settings::builder();
@@ -45,6 +46,7 @@ impl Backend<JITModule> {
runtime_functions,
defined_strings: HashMap::new(),
defined_symbols: HashMap::new(),
output_buffer,
})
}
@@ -70,6 +72,7 @@ impl Backend<ObjectModule> {
runtime_functions,
defined_strings: HashMap::new(),
defined_symbols: HashMap::new(),
output_buffer: None,
})
}
@@ -104,4 +107,20 @@ impl<M: Module> Backend<M> {
self.defined_symbols.insert(name, id);
Ok(id)
}
pub fn output_buffer_ptr(&mut self) -> *mut String {
if let Some(str) = self.output_buffer.as_mut() {
str as *mut String
} else {
std::ptr::null_mut()
}
}
pub fn output(self) -> String {
if let Some(s) = self.output_buffer {
s
} else {
String::new()
}
}
}

View File

@@ -8,8 +8,14 @@ use cranelift_object::ObjectModule;
use target_lexicon::Triple;
impl Backend<JITModule> {
pub fn eval(_program: Program) -> Result<String, EvalError> {
unimplemented!()
pub fn eval(program: Program) -> Result<String, EvalError> {
let mut jitter = Backend::jit(Some(String::new()))?;
let function_id = jitter.compile_function("test", program)?;
jitter.module.finalize_definitions()?;
let compiled_bytes = jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function();
Ok(jitter.output())
}
}
@@ -86,4 +92,16 @@ proptest::proptest! {
assert_eq!(basic_result, compiled_result);
}
}
#[test]
fn jit_backend_works(program: Program) {
use crate::eval::PrimOpError;
let basic_result = program.eval();
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
let compiled_result = Backend::<JITModule>::eval(program);
assert_eq!(basic_result, compiled_result);
}
}
}

View File

@@ -60,6 +60,8 @@ impl<M: Module> Backend<M> {
for stmt in program.statements.drain(..) {
match stmt {
Statement::Print(ann, var) => {
let buffer_ptr = self.output_buffer_ptr();
let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64);
let local_name_ref = string_table.get(&var).unwrap();
let name_ptr = builder.ins().symbol_value(types::I64, *local_name_ref);
let val = ValueOrRef::Ref(ann, var).into_cranelift(
@@ -67,7 +69,9 @@ impl<M: Module> Backend<M> {
&variable_table,
&pre_defined_symbols,
)?;
builder.ins().call(print_func_ref, &[name_ptr, val]);
builder
.ins()
.call(print_func_ref, &[buffer_ptr, name_ptr, val]);
}
Statement::Binding(_, var_name, value) => {

View File

@@ -4,6 +4,7 @@ use cranelift_jit::JITBuilder;
use cranelift_module::{FuncId, Linkage, Module, ModuleResult};
use std::collections::HashMap;
use std::ffi::CStr;
use std::fmt::Write;
use target_lexicon::Triple;
use thiserror::Error;
@@ -18,10 +19,15 @@ pub enum RuntimeFunctionError {
CannotFindRuntimeFunction(String),
}
extern "C" fn runtime_print(name: *const i8, value: u64) {
extern "C" fn runtime_print(output_buffer: *mut String, name: *const i8, value: i64) {
let cstr = unsafe { CStr::from_ptr(name) };
let reconstituted = cstr.to_string_lossy();
println!("{} = {}", reconstituted, value);
if let Some(output_buffer) = unsafe { output_buffer.as_mut() } {
writeln!(output_buffer, "{} = {}i64", reconstituted, value).unwrap();
} else {
println!("{} = {}", reconstituted, value);
}
}
impl RuntimeFunctions {
@@ -36,7 +42,7 @@ impl RuntimeFunctions {
"print",
Linkage::Import,
&Signature {
params: vec![string_param, int64_param],
params: vec![string_param, string_param, int64_param],
returns: vec![],
call_conv: CallConv::triple_default(platform),
},

View File

@@ -48,7 +48,7 @@ impl<'a> RunLoop<'a> {
pub fn new(writer: &'a mut dyn WriteColor, config: Config) -> Result<Self, BackendError> {
Ok(RunLoop {
file_database: SimpleFiles::new(),
jitter: Backend::jit()?,
jitter: Backend::jit(None)?,
variable_binding_sites: HashMap::new(),
gensym_index: 1,
writer,

View File

@@ -2,13 +2,14 @@ mod env;
mod primop;
mod value;
use cranelift_module::ModuleError;
pub use env::{EvalEnvironment, LookupError};
pub use primop::PrimOpError;
pub use value::Value;
use crate::backend::BackendError;
#[derive(Debug, PartialEq, thiserror::Error)]
#[derive(Debug, thiserror::Error)]
pub enum EvalError {
#[error(transparent)]
Lookup(#[from] LookupError),
@@ -18,6 +19,8 @@ pub enum EvalError {
Backend(#[from] BackendError),
#[error("IO error: {0}")]
IO(String),
#[error(transparent)]
Module(#[from] ModuleError),
}
impl From<std::io::Error> for EvalError {
@@ -25,3 +28,34 @@ impl From<std::io::Error> for EvalError {
EvalError::IO(value.to_string())
}
}
impl PartialEq for EvalError {
fn eq(&self, other: &Self) -> bool {
match self {
EvalError::Lookup(a) => match other {
EvalError::Lookup(b) => a == b,
_ => false,
},
EvalError::PrimOp(a) => match other {
EvalError::PrimOp(b) => a == b,
_ => false,
},
EvalError::Backend(a) => match other {
EvalError::Backend(b) => a == b,
_ => false,
},
EvalError::IO(a) => match other {
EvalError::IO(b) => a == b,
_ => false,
},
EvalError::Module(a) => match other {
EvalError::Module(b) => a.to_string() == b.to_string(),
_ => false,
},
}
}
}