Add a proptest for the JIT backend.
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
void print(char *variable_name, uint64_t value) {
|
||||
void print(char *_ignore, char *variable_name, uint64_t value) {
|
||||
printf("%s = %llii64\n", variable_name, value);
|
||||
}
|
||||
|
||||
void caller() {
|
||||
print("x", 4);
|
||||
print(NULL, "x", 4);
|
||||
}
|
||||
|
||||
extern void gogogo();
|
||||
|
||||
@@ -22,10 +22,11 @@ pub struct Backend<M: Module> {
|
||||
runtime_functions: RuntimeFunctions,
|
||||
defined_strings: HashMap<String, DataId>,
|
||||
defined_symbols: HashMap<String, DataId>,
|
||||
output_buffer: Option<String>,
|
||||
}
|
||||
|
||||
impl Backend<JITModule> {
|
||||
pub fn jit() -> Result<Self, BackendError> {
|
||||
pub fn jit(output_buffer: Option<String>) -> Result<Self, BackendError> {
|
||||
let platform = Triple::host();
|
||||
let isa_builder = isa::lookup(platform.clone())?;
|
||||
let mut settings_builder = settings::builder();
|
||||
@@ -45,6 +46,7 @@ impl Backend<JITModule> {
|
||||
runtime_functions,
|
||||
defined_strings: HashMap::new(),
|
||||
defined_symbols: HashMap::new(),
|
||||
output_buffer,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -70,6 +72,7 @@ impl Backend<ObjectModule> {
|
||||
runtime_functions,
|
||||
defined_strings: HashMap::new(),
|
||||
defined_symbols: HashMap::new(),
|
||||
output_buffer: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -104,4 +107,20 @@ impl<M: Module> Backend<M> {
|
||||
self.defined_symbols.insert(name, id);
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
pub fn output_buffer_ptr(&mut self) -> *mut String {
|
||||
if let Some(str) = self.output_buffer.as_mut() {
|
||||
str as *mut String
|
||||
} else {
|
||||
std::ptr::null_mut()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn output(self) -> String {
|
||||
if let Some(s) = self.output_buffer {
|
||||
s
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,14 @@ use cranelift_object::ObjectModule;
|
||||
use target_lexicon::Triple;
|
||||
|
||||
impl Backend<JITModule> {
|
||||
pub fn eval(_program: Program) -> Result<String, EvalError> {
|
||||
unimplemented!()
|
||||
pub fn eval(program: Program) -> Result<String, EvalError> {
|
||||
let mut jitter = Backend::jit(Some(String::new()))?;
|
||||
let function_id = jitter.compile_function("test", program)?;
|
||||
jitter.module.finalize_definitions()?;
|
||||
let compiled_bytes = jitter.bytes(function_id);
|
||||
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
|
||||
compiled_function();
|
||||
Ok(jitter.output())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,4 +92,16 @@ proptest::proptest! {
|
||||
assert_eq!(basic_result, compiled_result);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jit_backend_works(program: Program) {
|
||||
use crate::eval::PrimOpError;
|
||||
|
||||
let basic_result = program.eval();
|
||||
|
||||
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
|
||||
let compiled_result = Backend::<JITModule>::eval(program);
|
||||
assert_eq!(basic_result, compiled_result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,6 +60,8 @@ impl<M: Module> Backend<M> {
|
||||
for stmt in program.statements.drain(..) {
|
||||
match stmt {
|
||||
Statement::Print(ann, var) => {
|
||||
let buffer_ptr = self.output_buffer_ptr();
|
||||
let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64);
|
||||
let local_name_ref = string_table.get(&var).unwrap();
|
||||
let name_ptr = builder.ins().symbol_value(types::I64, *local_name_ref);
|
||||
let val = ValueOrRef::Ref(ann, var).into_cranelift(
|
||||
@@ -67,7 +69,9 @@ impl<M: Module> Backend<M> {
|
||||
&variable_table,
|
||||
&pre_defined_symbols,
|
||||
)?;
|
||||
builder.ins().call(print_func_ref, &[name_ptr, val]);
|
||||
builder
|
||||
.ins()
|
||||
.call(print_func_ref, &[buffer_ptr, name_ptr, val]);
|
||||
}
|
||||
|
||||
Statement::Binding(_, var_name, value) => {
|
||||
|
||||
@@ -4,6 +4,7 @@ use cranelift_jit::JITBuilder;
|
||||
use cranelift_module::{FuncId, Linkage, Module, ModuleResult};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::CStr;
|
||||
use std::fmt::Write;
|
||||
use target_lexicon::Triple;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -18,10 +19,15 @@ pub enum RuntimeFunctionError {
|
||||
CannotFindRuntimeFunction(String),
|
||||
}
|
||||
|
||||
extern "C" fn runtime_print(name: *const i8, value: u64) {
|
||||
extern "C" fn runtime_print(output_buffer: *mut String, name: *const i8, value: i64) {
|
||||
let cstr = unsafe { CStr::from_ptr(name) };
|
||||
let reconstituted = cstr.to_string_lossy();
|
||||
|
||||
if let Some(output_buffer) = unsafe { output_buffer.as_mut() } {
|
||||
writeln!(output_buffer, "{} = {}i64", reconstituted, value).unwrap();
|
||||
} else {
|
||||
println!("{} = {}", reconstituted, value);
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeFunctions {
|
||||
@@ -36,7 +42,7 @@ impl RuntimeFunctions {
|
||||
"print",
|
||||
Linkage::Import,
|
||||
&Signature {
|
||||
params: vec![string_param, int64_param],
|
||||
params: vec![string_param, string_param, int64_param],
|
||||
returns: vec![],
|
||||
call_conv: CallConv::triple_default(platform),
|
||||
},
|
||||
|
||||
@@ -48,7 +48,7 @@ impl<'a> RunLoop<'a> {
|
||||
pub fn new(writer: &'a mut dyn WriteColor, config: Config) -> Result<Self, BackendError> {
|
||||
Ok(RunLoop {
|
||||
file_database: SimpleFiles::new(),
|
||||
jitter: Backend::jit()?,
|
||||
jitter: Backend::jit(None)?,
|
||||
variable_binding_sites: HashMap::new(),
|
||||
gensym_index: 1,
|
||||
writer,
|
||||
|
||||
36
src/eval.rs
36
src/eval.rs
@@ -2,13 +2,14 @@ mod env;
|
||||
mod primop;
|
||||
mod value;
|
||||
|
||||
use cranelift_module::ModuleError;
|
||||
pub use env::{EvalEnvironment, LookupError};
|
||||
pub use primop::PrimOpError;
|
||||
pub use value::Value;
|
||||
|
||||
use crate::backend::BackendError;
|
||||
|
||||
#[derive(Debug, PartialEq, thiserror::Error)]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum EvalError {
|
||||
#[error(transparent)]
|
||||
Lookup(#[from] LookupError),
|
||||
@@ -18,6 +19,8 @@ pub enum EvalError {
|
||||
Backend(#[from] BackendError),
|
||||
#[error("IO error: {0}")]
|
||||
IO(String),
|
||||
#[error(transparent)]
|
||||
Module(#[from] ModuleError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for EvalError {
|
||||
@@ -25,3 +28,34 @@ impl From<std::io::Error> for EvalError {
|
||||
EvalError::IO(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for EvalError {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match self {
|
||||
EvalError::Lookup(a) => match other {
|
||||
EvalError::Lookup(b) => a == b,
|
||||
_ => false,
|
||||
},
|
||||
|
||||
EvalError::PrimOp(a) => match other {
|
||||
EvalError::PrimOp(b) => a == b,
|
||||
_ => false,
|
||||
},
|
||||
|
||||
EvalError::Backend(a) => match other {
|
||||
EvalError::Backend(b) => a == b,
|
||||
_ => false,
|
||||
},
|
||||
|
||||
EvalError::IO(a) => match other {
|
||||
EvalError::IO(b) => a == b,
|
||||
_ => false,
|
||||
},
|
||||
|
||||
EvalError::Module(a) => match other {
|
||||
EvalError::Module(b) => a.to_string() == b.to_string(),
|
||||
_ => false,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user