🤔 Add a type inference engine, along with typed literals. (#4)

The typed literal formatting mirrors that of Rust. If no type can be
inferred for an untagged literal, the type inference engine will warn
the user and then assume that they meant an unsigned 64-bit number.
(This is slightly inconvenient, because there can be cases in which our
Arbitrary instance may generate a unary negation, in which we should
assume that it's a signed 64-bit number; we may want to revisit this
later.)

The type inference engine is a standard two phase one, in which we first
generate a series of type constraints, and then we solve those
constraints. In this particular implementation, we actually use a third
phase to generate a final AST.

Finally, to increase the amount of testing performed, I've removed the
overflow checking in the evaluator. The only thing we now check for is
division by zero. This does make things a trace slower in testing, but
hopefully we get more coverage this way.
This commit was merged in pull request #4.
This commit is contained in:
2023-09-19 20:40:05 -07:00
committed by GitHub
parent 1fbfd0c2d2
commit bd3b9af469
44 changed files with 3258 additions and 702 deletions

3
.gitignore vendored
View File

@@ -2,6 +2,9 @@
Cargo.lock Cargo.lock
**/*.o **/*.o
test test
test.exe
test.ilk
test.pdb
*.dSYM *.dSYM
.vscode .vscode
proptest-regressions/ proptest-regressions/

View File

@@ -12,12 +12,12 @@ path = "src/lib.rs"
clap = { version = "^3.0.14", features = ["derive"] } clap = { version = "^3.0.14", features = ["derive"] }
codespan = "0.11.1" codespan = "0.11.1"
codespan-reporting = "0.11.1" codespan-reporting = "0.11.1"
cranelift-codegen = "0.94.0" cranelift-codegen = "0.99.2"
cranelift-jit = "0.94.0" cranelift-jit = "0.99.2"
cranelift-frontend = "0.94.0" cranelift-frontend = "0.99.2"
cranelift-module = "0.94.0" cranelift-module = "0.99.2"
cranelift-native = "0.94.0" cranelift-native = "0.99.2"
cranelift-object = "0.94.0" cranelift-object = "0.99.2"
internment = { version = "0.7.0", default-features = false, features = ["arc"] } internment = { version = "0.7.0", default-features = false, features = ["arc"] }
lalrpop-util = "^0.20.0" lalrpop-util = "^0.20.0"
lazy_static = "^1.4.0" lazy_static = "^1.4.0"

View File

@@ -1,5 +1,88 @@
extern crate lalrpop; extern crate lalrpop;
use std::env;
use std::fs::{self, File};
use std::io::Write;
use std::path::{Path, PathBuf};
fn main() { fn main() {
lalrpop::process_root().unwrap(); lalrpop::process_root().unwrap();
if let Err(e) = generate_example_tests() {
eprintln!("Failure building example tests: {}", e);
std::process::exit(3);
}
}
fn generate_example_tests() -> std::io::Result<()> {
let out_dir = env::var_os("OUT_DIR").expect("OUT_DIR");
let dest_path = Path::new(&out_dir).join("examples.rs");
let mut output = File::create(dest_path)?;
generate_tests(&mut output, PathBuf::from("examples"))?;
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=example.rs");
Ok(())
}
fn generate_tests(f: &mut File, path_so_far: PathBuf) -> std::io::Result<()> {
for entry in fs::read_dir(path_so_far)? {
let entry = entry?;
let file_name = entry
.file_name()
.into_string()
.expect("reasonable file name");
let name = file_name.trim_end_matches(".ngr");
if entry.file_type()?.is_dir() {
writeln!(f, "mod {} {{", name)?;
writeln!(f, " use super::*;")?;
generate_tests(f, entry.path())?;
writeln!(f, "}}")?;
} else {
writeln!(f, "#[test]")?;
writeln!(f, "fn {}() {{", name)?;
writeln!(f, " let mut file_database = SimpleFiles::new();")?;
writeln!(
f,
" let syntax = Syntax::parse_file(&mut file_database, {:?});",
entry.path().display()
)?;
if entry.path().to_string_lossy().contains("broken") {
writeln!(f, " if syntax.is_err() {{")?;
writeln!(f, " return;")?;
writeln!(f, " }}")?;
writeln!(f, " let (errors, _) = syntax.unwrap().validate();")?;
writeln!(
f,
" assert_ne!(errors.len(), 0, \"should have seen an error\");"
)?;
} else {
// NOTE: Since the advent of defaulting rules and type checking, we
// can't guarantee that syntax.eval() will return the same result as
// ir.eval() or backend::eval(). We must do type checking to force
// constants into the right types, first. So this now checks only that
// the result of ir.eval() and backend::eval() are the same.
writeln!(
f,
" let syntax = syntax.expect(\"file should have parsed\");"
)?;
writeln!(f, " let (errors, _) = syntax.validate();")?;
writeln!(
f,
" assert_eq!(errors.len(), 0, \"file should have no validation errors\");"
)?;
writeln!(
f,
" let ir = syntax.type_infer().expect(\"example is typed correctly\");"
)?;
writeln!(f, " let ir_result = ir.eval();")?;
writeln!(f, " let compiled_result = Backend::<JITModule>::eval(ir);")?;
writeln!(f, " assert_eq!(ir_result, compiled_result);")?;
}
writeln!(f, "}}")?;
}
}
Ok(())
} }

View File

@@ -0,0 +1,4 @@
x = 5;
x = 4*x + 3;
print x;
print y;

7
examples/basic/cast1.ngr Normal file
View File

@@ -0,0 +1,7 @@
x8 = 5u8;
x16 = <u16>x8 + 1u16;
print x16;
x32 = <u32>x16 + 1u32;
print x32;
x64 = <u64>x32 + 1u64;
print x64;

7
examples/basic/cast2.ngr Normal file
View File

@@ -0,0 +1,7 @@
x8 = 5i8;
x16 = <i16>x8 - 1i16;
print x16;
x32 = <i32>x16 - 1i32;
print x32;
x64 = <i64>x32 - 1i64;
print x64;

View File

@@ -0,0 +1,3 @@
x = 2i64 + 2i64;
y = -x;
print y;

View File

@@ -1,4 +1,4 @@
x = 5; // this test is useful for making sure we don't accidentally
x = 4*x + 3; // use a signed divide operation anywhere important
print x; a = 96u8 / 160u8;
print y; print a;

View File

@@ -0,0 +1,6 @@
x = 1 + 1u16;
print x;
y = 1u16 + 1;
print y;
z = 1 + 1 + 1;
print z;

View File

@@ -2,12 +2,33 @@
#include <stdio.h> #include <stdio.h>
#include <inttypes.h> #include <inttypes.h>
void print(char *_ignore, char *variable_name, int64_t value) { void print(char *_ignore, char *variable_name, int64_t vtype, int64_t value) {
printf("%s = %" PRId64 "i64\n", variable_name, value); switch(vtype) {
} case /* U8 = */ 10:
printf("%s = %" PRIu8 "u8\n", variable_name, (uint8_t)value);
void caller() { break;
print(NULL, "x", 4); case /* U16 = */ 11:
printf("%s = %" PRIu16 "u16\n", variable_name, (uint16_t)value);
break;
case /* U32 = */ 12:
printf("%s = %" PRIu32 "u32\n", variable_name, (uint32_t)value);
break;
case /* U64 = */ 13:
printf("%s = %" PRIu64 "u64\n", variable_name, (uint64_t)value);
break;
case /* I8 = */ 20:
printf("%s = %" PRIi8 "i8\n", variable_name, (int8_t)value);
break;
case /* I16 = */ 21:
printf("%s = %" PRIi16 "i16\n", variable_name, (int16_t)value);
break;
case /* I32 = */ 22:
printf("%s = %" PRIi32 "i32\n", variable_name, (int32_t)value);
break;
case /* I64 = */ 23:
printf("%s = %" PRIi64 "i64\n", variable_name, value);
break;
}
} }
extern void gogogo(); extern void gogogo();

View File

@@ -31,15 +31,15 @@ mod eval;
mod into_crane; mod into_crane;
mod runtime; mod runtime;
use std::collections::HashMap;
pub use self::error::BackendError; pub use self::error::BackendError;
pub use self::runtime::{RuntimeFunctionError, RuntimeFunctions}; pub use self::runtime::{RuntimeFunctionError, RuntimeFunctions};
use crate::syntax::ConstantType;
use cranelift_codegen::settings::Configurable; use cranelift_codegen::settings::Configurable;
use cranelift_codegen::{isa, settings}; use cranelift_codegen::{isa, settings};
use cranelift_jit::{JITBuilder, JITModule}; use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{default_libcall_names, DataContext, DataId, FuncId, Linkage, Module}; use cranelift_module::{default_libcall_names, DataDescription, DataId, FuncId, Linkage, Module};
use cranelift_object::{ObjectBuilder, ObjectModule}; use cranelift_object::{ObjectBuilder, ObjectModule};
use std::collections::HashMap;
use target_lexicon::Triple; use target_lexicon::Triple;
const EMPTY_DATUM: [u8; 8] = [0; 8]; const EMPTY_DATUM: [u8; 8] = [0; 8];
@@ -55,11 +55,12 @@ const EMPTY_DATUM: [u8; 8] = [0; 8];
/// implementations. /// implementations.
pub struct Backend<M: Module> { pub struct Backend<M: Module> {
pub module: M, pub module: M,
data_ctx: DataContext, data_ctx: DataDescription,
runtime_functions: RuntimeFunctions, runtime_functions: RuntimeFunctions,
defined_strings: HashMap<String, DataId>, defined_strings: HashMap<String, DataId>,
defined_symbols: HashMap<String, DataId>, defined_symbols: HashMap<String, (DataId, ConstantType)>,
output_buffer: Option<String>, output_buffer: Option<String>,
platform: Triple,
} }
impl Backend<JITModule> { impl Backend<JITModule> {
@@ -85,11 +86,12 @@ impl Backend<JITModule> {
Ok(Backend { Ok(Backend {
module, module,
data_ctx: DataContext::new(), data_ctx: DataDescription::new(),
runtime_functions, runtime_functions,
defined_strings: HashMap::new(), defined_strings: HashMap::new(),
defined_symbols: HashMap::new(), defined_symbols: HashMap::new(),
output_buffer, output_buffer,
platform: Triple::host(),
}) })
} }
@@ -123,11 +125,12 @@ impl Backend<ObjectModule> {
Ok(Backend { Ok(Backend {
module, module,
data_ctx: DataContext::new(), data_ctx: DataDescription::new(),
runtime_functions, runtime_functions,
defined_strings: HashMap::new(), defined_strings: HashMap::new(),
defined_symbols: HashMap::new(), defined_symbols: HashMap::new(),
output_buffer: None, output_buffer: None,
platform,
}) })
} }
@@ -154,7 +157,7 @@ impl<M: Module> Backend<M> {
let global_id = self let global_id = self
.module .module
.declare_data(&name, Linkage::Local, false, false)?; .declare_data(&name, Linkage::Local, false, false)?;
let mut data_context = DataContext::new(); let mut data_context = DataDescription::new();
data_context.set_align(8); data_context.set_align(8);
data_context.define(s0.into_boxed_str().into_boxed_bytes()); data_context.define(s0.into_boxed_str().into_boxed_bytes());
self.module.define_data(global_id, &data_context)?; self.module.define_data(global_id, &data_context)?;
@@ -167,14 +170,18 @@ impl<M: Module> Backend<M> {
/// These variables can be shared between functions, and will be exported from the /// These variables can be shared between functions, and will be exported from the
/// module itself as public data in the case of static compilation. There initial /// module itself as public data in the case of static compilation. There initial
/// value will be null. /// value will be null.
pub fn define_variable(&mut self, name: String) -> Result<DataId, BackendError> { pub fn define_variable(
&mut self,
name: String,
ctype: ConstantType,
) -> Result<DataId, BackendError> {
self.data_ctx.define(Box::new(EMPTY_DATUM)); self.data_ctx.define(Box::new(EMPTY_DATUM));
let id = self let id = self
.module .module
.declare_data(&name, Linkage::Export, true, false)?; .declare_data(&name, Linkage::Export, true, false)?;
self.module.define_data(id, &self.data_ctx)?; self.module.define_data(id, &self.data_ctx)?;
self.data_ctx.clear(); self.data_ctx.clear();
self.defined_symbols.insert(name, id); self.defined_symbols.insert(name, (id, ctype));
Ok(id) Ok(id)
} }

View File

@@ -1,4 +1,4 @@
use crate::backend::runtime::RuntimeFunctionError; use crate::{backend::runtime::RuntimeFunctionError, eval::PrimitiveType, ir::Type};
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use cranelift_codegen::{isa::LookupError, settings::SetError, CodegenError}; use cranelift_codegen::{isa::LookupError, settings::SetError, CodegenError};
use cranelift_module::ModuleError; use cranelift_module::ModuleError;
@@ -39,6 +39,8 @@ pub enum BackendError {
LookupError(#[from] LookupError), LookupError(#[from] LookupError),
#[error(transparent)] #[error(transparent)]
Write(#[from] cranelift_object::object::write::Error), Write(#[from] cranelift_object::object::write::Error),
#[error("Invalid type cast from {from} to {to}")]
InvalidTypeCast { from: PrimitiveType, to: Type },
} }
impl From<BackendError> for Diagnostic<usize> { impl From<BackendError> for Diagnostic<usize> {
@@ -64,6 +66,9 @@ impl From<BackendError> for Diagnostic<usize> {
BackendError::Write(me) => { BackendError::Write(me) => {
Diagnostic::error().with_message(format!("Cranelift object write error: {}", me)) Diagnostic::error().with_message(format!("Cranelift object write error: {}", me))
} }
BackendError::InvalidTypeCast { from, to } => Diagnostic::error().with_message(
format!("Internal error trying to cast from {} to {}", from, to),
),
} }
} }
} }
@@ -103,6 +108,17 @@ impl PartialEq for BackendError {
BackendError::Write(b) => a == b, BackendError::Write(b) => a == b,
_ => false, _ => false,
}, },
BackendError::InvalidTypeCast {
from: from1,
to: to1,
} => match other {
BackendError::InvalidTypeCast {
from: from2,
to: to2,
} => from1 == from2 && to1 == to2,
_ => false,
},
} }
} }
} }

View File

@@ -1,10 +1,13 @@
use std::path::Path;
use crate::backend::Backend; use crate::backend::Backend;
use crate::eval::EvalError; use crate::eval::EvalError;
use crate::ir::Program; use crate::ir::Program;
#[cfg(test)]
use crate::syntax::arbitrary::GenerationEnvironment;
use cranelift_jit::JITModule; use cranelift_jit::JITModule;
use cranelift_object::ObjectModule; use cranelift_object::ObjectModule;
#[cfg(test)]
use proptest::arbitrary::Arbitrary;
use std::path::Path;
use target_lexicon::Triple; use target_lexicon::Triple;
impl Backend<JITModule> { impl Backend<JITModule> {
@@ -28,7 +31,8 @@ impl Backend<JITModule> {
let compiled_bytes = jitter.bytes(function_id); let compiled_bytes = jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function(); compiled_function();
Ok(jitter.output()) let output = jitter.output();
Ok(output)
} }
} }
@@ -116,7 +120,7 @@ proptest::proptest! {
// without error, assuming any possible input ... well, any possible input that // without error, assuming any possible input ... well, any possible input that
// doesn't involve overflow or underflow. // doesn't involve overflow or underflow.
#[test] #[test]
fn static_backend(program: Program) { fn static_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::PrimOpError; use crate::eval::PrimOpError;
let basic_result = program.eval(); let basic_result = program.eval();
@@ -127,8 +131,18 @@ proptest::proptest! {
let basic_result = basic_result.map(|x| x.replace('\n', "\r\n")); let basic_result = basic_result.map(|x| x.replace('\n', "\r\n"));
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) { if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
// use pretty::{DocAllocator, Pretty};
// let allocator = pretty::BoxAllocator;
// allocator
// .text("---------------")
// .append(allocator.hardline())
// .append(program.pretty(&allocator))
// .1
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
// .expect("rendering works");
let compiled_result = Backend::<ObjectModule>::eval(program); let compiled_result = Backend::<ObjectModule>::eval(program);
assert_eq!(basic_result, compiled_result); proptest::prop_assert_eq!(basic_result, compiled_result);
} }
} }
@@ -136,14 +150,24 @@ proptest::proptest! {
// without error, assuming any possible input ... well, any possible input that // without error, assuming any possible input ... well, any possible input that
// doesn't involve overflow or underflow. // doesn't involve overflow or underflow.
#[test] #[test]
fn jit_backend(program: Program) { fn jit_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::PrimOpError; use crate::eval::PrimOpError;
// use pretty::{DocAllocator, Pretty};
// let allocator = pretty::BoxAllocator;
// allocator
// .text("---------------")
// .append(allocator.hardline())
// .append(program.pretty(&allocator))
// .1
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
// .expect("rendering works");
let basic_result = program.eval(); let basic_result = program.eval();
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) { if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
let compiled_result = Backend::<JITModule>::eval(program); let compiled_result = Backend::<JITModule>::eval(program);
assert_eq!(basic_result, compiled_result); proptest::prop_assert_eq!(basic_result, compiled_result);
} }
} }
} }

View File

@@ -1,9 +1,11 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::ir::{Expression, Primitive, Program, Statement, Value, ValueOrRef}; use crate::eval::PrimitiveType;
use crate::ir::{Expression, Primitive, Program, Statement, Type, Value, ValueOrRef};
use crate::syntax::ConstantType;
use cranelift_codegen::entity::EntityRef; use cranelift_codegen::entity::EntityRef;
use cranelift_codegen::ir::{ use cranelift_codegen::ir::{
entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName, self, entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName,
}; };
use cranelift_codegen::isa::CallConv; use cranelift_codegen::isa::CallConv;
use cranelift_codegen::Context; use cranelift_codegen::Context;
@@ -41,7 +43,7 @@ impl<M: Module> Backend<M> {
let basic_signature = Signature { let basic_signature = Signature {
params: vec![], params: vec![],
returns: vec![], returns: vec![],
call_conv: CallConv::SystemV, call_conv: CallConv::triple_default(&self.platform),
}; };
// this generates the handle for the function that we'll eventually want to // this generates the handle for the function that we'll eventually want to
@@ -85,12 +87,12 @@ impl<M: Module> Backend<M> {
// Just like with strings, generating the `GlobalValue`s we need can potentially // Just like with strings, generating the `GlobalValue`s we need can potentially
// be a little tricky to do on the fly, so we generate the complete list right // be a little tricky to do on the fly, so we generate the complete list right
// here and then use it later. // here and then use it later.
let pre_defined_symbols: HashMap<String, GlobalValue> = self let pre_defined_symbols: HashMap<String, (GlobalValue, ConstantType)> = self
.defined_symbols .defined_symbols
.iter() .iter()
.map(|(k, v)| { .map(|(k, (v, t))| {
let local_data = self.module.declare_data_in_func(*v, &mut ctx.func); let local_data = self.module.declare_data_in_func(*v, &mut ctx.func);
(k.clone(), local_data) (k.clone(), (local_data, *t))
}) })
.collect(); .collect();
@@ -123,7 +125,7 @@ impl<M: Module> Backend<M> {
// Print statements are fairly easy to compile: we just lookup the // Print statements are fairly easy to compile: we just lookup the
// output buffer, the address of the string to print, and the value // output buffer, the address of the string to print, and the value
// of whatever variable we're printing. Then we just call print. // of whatever variable we're printing. Then we just call print.
Statement::Print(ann, var) => { Statement::Print(ann, t, var) => {
// Get the output buffer (or null) from our general compilation context. // Get the output buffer (or null) from our general compilation context.
let buffer_ptr = self.output_buffer_ptr(); let buffer_ptr = self.output_buffer_ptr();
let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64);
@@ -135,31 +137,47 @@ impl<M: Module> Backend<M> {
// Look up the value for the variable. Because this might be a // Look up the value for the variable. Because this might be a
// global variable (and that requires special logic), we just turn // global variable (and that requires special logic), we just turn
// this into an `Expression` and re-use the logic in that implementation. // this into an `Expression` and re-use the logic in that implementation.
let val = Expression::Reference(ann, var).into_crane( let (val, vtype) = ValueOrRef::Ref(ann, t, var).into_crane(
&mut builder, &mut builder,
&variable_table, &variable_table,
&pre_defined_symbols, &pre_defined_symbols,
)?; )?;
let vtype_repr = builder.ins().iconst(types::I64, vtype as i64);
let casted_val = match vtype {
ConstantType::U64 | ConstantType::I64 => val,
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 => {
builder.ins().sextend(types::I64, val)
}
ConstantType::U8 | ConstantType::U16 | ConstantType::U32 => {
builder.ins().uextend(types::I64, val)
}
};
// Finally, we can generate the call to print. // Finally, we can generate the call to print.
builder builder.ins().call(
.ins() print_func_ref,
.call(print_func_ref, &[buffer_ptr, name_ptr, val]); &[buffer_ptr, name_ptr, vtype_repr, casted_val],
);
} }
// Variable binding is a little more con // Variable binding is a little more con
Statement::Binding(_, var_name, value) => { Statement::Binding(_, var_name, _, value) => {
// Kick off to the `Expression` implementation to see what value we're going // Kick off to the `Expression` implementation to see what value we're going
// to bind to this variable. // to bind to this variable.
let val = let (val, etype) =
value.into_crane(&mut builder, &variable_table, &pre_defined_symbols)?; value.into_crane(&mut builder, &variable_table, &pre_defined_symbols)?;
// Now the question is: is this a local variable, or a global one? // Now the question is: is this a local variable, or a global one?
if let Some(global_id) = pre_defined_symbols.get(var_name.as_str()) { if let Some((global_id, ctype)) = pre_defined_symbols.get(var_name.as_str()) {
// It's a global variable! In this case, we assume that someone has already // It's a global variable! In this case, we assume that someone has already
// dedicated some space in memory to store this value. We look this location // dedicated some space in memory to store this value. We look this location
// up, and then tell Cranelift to store the value there. // up, and then tell Cranelift to store the value there.
let val_ptr = builder.ins().symbol_value(types::I64, *global_id); assert_eq!(etype, *ctype);
let val_ptr = builder
.ins()
.symbol_value(ir::Type::from(*ctype), *global_id);
builder.ins().store(MemFlags::new(), val, val_ptr, 0); builder.ins().store(MemFlags::new(), val, val_ptr, 0);
} else { } else {
// It's a local variable! In this case, we need to allocate a new Cranelift // It's a local variable! In this case, we need to allocate a new Cranelift
@@ -171,12 +189,10 @@ impl<M: Module> Backend<M> {
next_var_num += 1; next_var_num += 1;
// We can add the variable directly to our local variable map; it's `Copy`. // We can add the variable directly to our local variable map; it's `Copy`.
variable_table.insert(var_name, var); variable_table.insert(var_name, (var, etype));
// Now we tell Cranelift about our new variable, which has type I64 because // Now we tell Cranelift about our new variable!
// everything we have at this point is of type I64. Once it's declare, we builder.declare_var(var, ir::Type::from(etype));
// define it as having the value we computed above.
builder.declare_var(var, types::I64);
builder.def_var(var, val); builder.def_var(var, val);
} }
} }
@@ -195,7 +211,7 @@ impl<M: Module> Backend<M> {
// so we register it using the function ID and our builder context. However, the // so we register it using the function ID and our builder context. However, the
// result of this function isn't actually super helpful. So we ignore it, unless // result of this function isn't actually super helpful. So we ignore it, unless
// it's an error. // it's an error.
let _ = self.module.define_function(func_id, &mut ctx)?; self.module.define_function(func_id, &mut ctx)?;
// done! // done!
Ok(func_id) Ok(func_id)
@@ -231,54 +247,110 @@ impl Expression {
fn into_crane( fn into_crane(
self, self,
builder: &mut FunctionBuilder, builder: &mut FunctionBuilder,
local_variables: &HashMap<ArcIntern<String>, Variable>, local_variables: &HashMap<ArcIntern<String>, (Variable, ConstantType)>,
global_variables: &HashMap<String, GlobalValue>, global_variables: &HashMap<String, (GlobalValue, ConstantType)>,
) -> Result<entities::Value, BackendError> { ) -> Result<(entities::Value, ConstantType), BackendError> {
match self { match self {
// Values are pretty straightforward to compile, mostly because we only Expression::Atomic(x) => x.into_crane(builder, local_variables, global_variables),
// have one type of variable, and it's an integer type.
Expression::Value(_, Value::Number(_, v)) => Ok(builder.ins().iconst(types::I64, v)),
Expression::Reference(_, name) => { Expression::Cast(_, target_type, expr) => {
// first we see if this is a local variable (which is nicer, from an let (val, val_type) =
// optimization point of view.) expr.into_crane(builder, local_variables, global_variables)?;
if let Some(local_var) = local_variables.get(&name) {
return Ok(builder.use_var(*local_var)); match (val_type, &target_type) {
(ConstantType::I8, Type::Primitive(PrimitiveType::I8)) => Ok((val, val_type)),
(ConstantType::I8, Type::Primitive(PrimitiveType::I16)) => {
Ok((builder.ins().sextend(types::I16, val), ConstantType::I16))
}
(ConstantType::I8, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().sextend(types::I32, val), ConstantType::I32))
}
(ConstantType::I8, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I16, Type::Primitive(PrimitiveType::I16)) => Ok((val, val_type)),
(ConstantType::I16, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().sextend(types::I32, val), ConstantType::I32))
}
(ConstantType::I16, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I32, Type::Primitive(PrimitiveType::I32)) => Ok((val, val_type)),
(ConstantType::I32, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I64, Type::Primitive(PrimitiveType::I64)) => Ok((val, val_type)),
(ConstantType::U8, Type::Primitive(PrimitiveType::U8)) => Ok((val, val_type)),
(ConstantType::U8, Type::Primitive(PrimitiveType::U16)) => {
Ok((builder.ins().uextend(types::I16, val), ConstantType::U16))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::U32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::U32))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::U16)) => Ok((val, val_type)),
(ConstantType::U16, Type::Primitive(PrimitiveType::U32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::U32))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U32, Type::Primitive(PrimitiveType::U32)) => Ok((val, val_type)),
(ConstantType::U32, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)),
_ => Err(BackendError::InvalidTypeCast {
from: val_type.into(),
to: target_type,
}),
} }
// then we check to see if this is a global reference, which requires us to
// first lookup where the value is stored, and then load it.
if let Some(global_var) = global_variables.get(name.as_ref()) {
let val_ptr = builder.ins().symbol_value(types::I64, *global_var);
return Ok(builder.ins().load(types::I64, MemFlags::new(), val_ptr, 0));
}
// this should never happen, because we should have made sure that there are
// no unbound variables a long time before this. but still ...
Err(BackendError::VariableLookupFailure(name))
} }
Expression::Primitive(_, prim, mut vals) => { Expression::Primitive(_, _, prim, mut vals) => {
// we're going to use `pop`, so we're going to pull and compile the right value ... let mut values = vec![];
let right = let mut first_type = None;
vals.pop()
.unwrap() for val in vals.drain(..) {
.into_crane(builder, local_variables, global_variables)?; let (compiled, compiled_type) =
// ... and then the left. val.into_crane(builder, local_variables, global_variables)?;
let left =
vals.pop() if let Some(leftmost_type) = first_type {
.unwrap() assert_eq!(leftmost_type, compiled_type);
.into_crane(builder, local_variables, global_variables)?; } else {
first_type = Some(compiled_type);
}
values.push(compiled);
}
let first_type = first_type.expect("primitive op has at least one argument");
// then we just need to tell Cranelift how to do each of our primitives! Much // then we just need to tell Cranelift how to do each of our primitives! Much
// like Statements, above, we probably want to eventually shuffle this off into // like Statements, above, we probably want to eventually shuffle this off into
// a separate function (maybe something off `Primitive`), but for now it's simple // a separate function (maybe something off `Primitive`), but for now it's simple
// enough that we just do the `match` here. // enough that we just do the `match` here.
match prim { match prim {
Primitive::Plus => Ok(builder.ins().iadd(left, right)), Primitive::Plus => Ok((builder.ins().iadd(values[0], values[1]), first_type)),
Primitive::Minus => Ok(builder.ins().isub(left, right)), Primitive::Minus if values.len() == 2 => {
Primitive::Times => Ok(builder.ins().imul(left, right)), Ok((builder.ins().isub(values[0], values[1]), first_type))
Primitive::Divide => Ok(builder.ins().sdiv(left, right)), }
Primitive::Minus => Ok((builder.ins().ineg(values[0]), first_type)),
Primitive::Times => Ok((builder.ins().imul(values[0], values[1]), first_type)),
Primitive::Divide if first_type.is_signed() => {
Ok((builder.ins().sdiv(values[0], values[1]), first_type))
}
Primitive::Divide => Ok((builder.ins().udiv(values[0], values[1]), first_type)),
} }
} }
} }
@@ -291,9 +363,66 @@ impl ValueOrRef {
fn into_crane( fn into_crane(
self, self,
builder: &mut FunctionBuilder, builder: &mut FunctionBuilder,
local_variables: &HashMap<ArcIntern<String>, Variable>, local_variables: &HashMap<ArcIntern<String>, (Variable, ConstantType)>,
global_variables: &HashMap<String, GlobalValue>, global_variables: &HashMap<String, (GlobalValue, ConstantType)>,
) -> Result<entities::Value, BackendError> { ) -> Result<(entities::Value, ConstantType), BackendError> {
Expression::from(self).into_crane(builder, local_variables, global_variables) match self {
// Values are pretty straightforward to compile, mostly because we only
// have one type of variable, and it's an integer type.
ValueOrRef::Value(_, _, val) => match val {
Value::I8(_, v) => {
Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::I8))
}
Value::I16(_, v) => Ok((
builder.ins().iconst(types::I16, v as i64),
ConstantType::I16,
)),
Value::I32(_, v) => Ok((
builder.ins().iconst(types::I32, v as i64),
ConstantType::I32,
)),
Value::I64(_, v) => Ok((builder.ins().iconst(types::I64, v), ConstantType::I64)),
Value::U8(_, v) => {
Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::U8))
}
Value::U16(_, v) => Ok((
builder.ins().iconst(types::I16, v as i64),
ConstantType::U16,
)),
Value::U32(_, v) => Ok((
builder.ins().iconst(types::I32, v as i64),
ConstantType::U32,
)),
Value::U64(_, v) => Ok((
builder.ins().iconst(types::I64, v as i64),
ConstantType::U64,
)),
},
ValueOrRef::Ref(_, _, name) => {
// first we see if this is a local variable (which is nicer, from an
// optimization point of view.)
if let Some((local_var, etype)) = local_variables.get(&name) {
return Ok((builder.use_var(*local_var), *etype));
}
// then we check to see if this is a global reference, which requires us to
// first lookup where the value is stored, and then load it.
if let Some((global_var, etype)) = global_variables.get(name.as_ref()) {
let cranelift_type = ir::Type::from(*etype);
let val_ptr = builder.ins().symbol_value(cranelift_type, *global_var);
return Ok((
builder
.ins()
.load(cranelift_type, MemFlags::new(), val_ptr, 0),
*etype,
));
}
// this should never happen, because we should have made sure that there are
// no unbound variables a long time before this. but still ...
Err(BackendError::VariableLookupFailure(name))
}
}
} }
} }

View File

@@ -8,6 +8,8 @@ use std::fmt::Write;
use target_lexicon::Triple; use target_lexicon::Triple;
use thiserror::Error; use thiserror::Error;
use crate::syntax::ConstantType;
/// An object for querying / using functions built into the runtime. /// An object for querying / using functions built into the runtime.
/// ///
/// Right now, this is a quite a bit of boilerplate for very nebulous /// Right now, this is a quite a bit of boilerplate for very nebulous
@@ -49,7 +51,7 @@ impl RuntimeFunctions {
"print", "print",
Linkage::Import, Linkage::Import,
&Signature { &Signature {
params: vec![string_param, string_param, int64_param], params: vec![string_param, string_param, int64_param, int64_param],
returns: vec![], returns: vec![],
call_conv: CallConv::triple_default(platform), call_conv: CallConv::triple_default(platform),
}, },
@@ -98,13 +100,30 @@ impl RuntimeFunctions {
// we extend with the output, so that multiple JIT'd `Program`s can run concurrently // we extend with the output, so that multiple JIT'd `Program`s can run concurrently
// without stomping over each other's output. If `output_buffer` is NULL, we just print // without stomping over each other's output. If `output_buffer` is NULL, we just print
// to stdout. // to stdout.
extern "C" fn runtime_print(output_buffer: *mut String, name: *const i8, value: i64) { extern "C" fn runtime_print(
output_buffer: *mut String,
name: *const i8,
vtype_repr: i64,
value: i64,
) {
let cstr = unsafe { CStr::from_ptr(name) }; let cstr = unsafe { CStr::from_ptr(name) };
let reconstituted = cstr.to_string_lossy(); let reconstituted = cstr.to_string_lossy();
let output = match vtype_repr.try_into() {
Ok(ConstantType::I8) => format!("{} = {}i8", reconstituted, value as i8),
Ok(ConstantType::I16) => format!("{} = {}i16", reconstituted, value as i16),
Ok(ConstantType::I32) => format!("{} = {}i32", reconstituted, value as i32),
Ok(ConstantType::I64) => format!("{} = {}i64", reconstituted, value),
Ok(ConstantType::U8) => format!("{} = {}u8", reconstituted, value as u8),
Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16),
Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32),
Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64),
Err(_) => format!("{} = {}<unknown type>", reconstituted, value),
};
if let Some(output_buffer) = unsafe { output_buffer.as_mut() } { if let Some(output_buffer) = unsafe { output_buffer.as_mut() } {
writeln!(output_buffer, "{} = {}i64", reconstituted, value).unwrap(); writeln!(output_buffer, "{}", output).unwrap();
} else { } else {
println!("{} = {}", reconstituted, value); println!("{}", output);
} }
} }

View File

@@ -17,7 +17,7 @@ fn main() {
let args = CommandLineArguments::parse(); let args = CommandLineArguments::parse();
let mut compiler = ngr::Compiler::default(); let mut compiler = ngr::Compiler::default();
let output_file = args.output.unwrap_or("output.o".to_string()); let output_file = args.output.unwrap_or_else(|| "output.o".to_string());
if let Some(bytes) = compiler.compile(&args.file) { if let Some(bytes) = compiler.compile(&args.file) {
std::fs::write(&output_file, bytes) std::fs::write(&output_file, bytes)

View File

@@ -1,6 +1,5 @@
use crate::backend::Backend;
use crate::ir::Program as IR;
use crate::syntax::Program as Syntax; use crate::syntax::Program as Syntax;
use crate::{backend::Backend, type_infer::TypeInferenceResult};
use codespan_reporting::{ use codespan_reporting::{
diagnostic::Diagnostic, diagnostic::Diagnostic,
files::SimpleFiles, files::SimpleFiles,
@@ -100,8 +99,38 @@ impl Compiler {
return Ok(None); return Ok(None);
} }
// Now that we've validated it, turn it into IR. // Now that we've validated it, let's do type inference, potentially turning
let ir = IR::from(syntax); // into IR while we're at it.
let ir = match syntax.type_infer() {
TypeInferenceResult::Failure {
mut errors,
mut warnings,
} => {
let messages = errors
.drain(..)
.map(Into::into)
.chain(warnings.drain(..).map(Into::into));
for message in messages {
self.emit(message);
}
return Ok(None);
}
TypeInferenceResult::Success {
result,
mut warnings,
} => {
let messages = warnings.drain(..).map(Into::into);
for message in messages {
self.emit(message);
}
result
}
};
// Finally, send all this to Cranelift for conversion into an object file. // Finally, send all this to Cranelift for conversion into an object file.
let mut backend = Backend::object_file(Triple::host())?; let mut backend = Backend::object_file(Triple::host())?;

View File

@@ -35,11 +35,13 @@
//! //!
mod env; mod env;
mod primop; mod primop;
mod primtype;
mod value; mod value;
use cranelift_module::ModuleError; use cranelift_module::ModuleError;
pub use env::{EvalEnvironment, LookupError}; pub use env::{EvalEnvironment, LookupError};
pub use primop::PrimOpError; pub use primop::PrimOpError;
pub use primtype::PrimitiveType;
pub use value::Value; pub use value::Value;
use crate::backend::BackendError; use crate::backend::BackendError;

View File

@@ -87,9 +87,9 @@ mod tests {
let tester = tester.extend(arced("bar"), 2i64.into()); let tester = tester.extend(arced("bar"), 2i64.into());
let tester = tester.extend(arced("goo"), 5i64.into()); let tester = tester.extend(arced("goo"), 5i64.into());
assert_eq!(tester.lookup(arced("foo")), Ok(1.into())); assert_eq!(tester.lookup(arced("foo")), Ok(1i64.into()));
assert_eq!(tester.lookup(arced("bar")), Ok(2.into())); assert_eq!(tester.lookup(arced("bar")), Ok(2i64.into()));
assert_eq!(tester.lookup(arced("goo")), Ok(5.into())); assert_eq!(tester.lookup(arced("goo")), Ok(5i64.into()));
assert!(tester.lookup(arced("baz")).is_err()); assert!(tester.lookup(arced("baz")).is_err());
} }
@@ -103,14 +103,14 @@ mod tests {
check_nested(&tester); check_nested(&tester);
assert_eq!(tester.lookup(arced("foo")), Ok(1.into())); assert_eq!(tester.lookup(arced("foo")), Ok(1i64.into()));
assert!(tester.lookup(arced("bar")).is_err()); assert!(tester.lookup(arced("bar")).is_err());
} }
fn check_nested(env: &EvalEnvironment) { fn check_nested(env: &EvalEnvironment) {
let nested_env = env.extend(arced("bar"), 2i64.into()); let nested_env = env.extend(arced("bar"), 2i64.into());
assert_eq!(nested_env.lookup(arced("foo")), Ok(1.into())); assert_eq!(nested_env.lookup(arced("foo")), Ok(1i64.into()));
assert_eq!(nested_env.lookup(arced("bar")), Ok(2.into())); assert_eq!(nested_env.lookup(arced("bar")), Ok(2i64.into()));
} }
fn arced(s: &str) -> ArcIntern<String> { fn arced(s: &str) -> ArcIntern<String> {

View File

@@ -1,3 +1,4 @@
use crate::eval::primtype::PrimitiveType;
use crate::eval::value::Value; use crate::eval::value::Value;
/// Errors that can occur running primitive operations in the evaluators. /// Errors that can occur running primitive operations in the evaluators.
@@ -22,6 +23,13 @@ pub enum PrimOpError {
BadArgCount(String, usize), BadArgCount(String, usize),
#[error("Unknown primitive operation {0}")] #[error("Unknown primitive operation {0}")]
UnknownPrimOp(String), UnknownPrimOp(String),
#[error("Unsafe cast from {from} to {to}")]
UnsafeCast {
from: PrimitiveType,
to: PrimitiveType,
},
#[error("Unknown primitive type {0}")]
UnknownPrimType(String),
} }
// Implementing primitives in an interpreter like this is *super* tedious, // Implementing primitives in an interpreter like this is *super* tedious,
@@ -37,39 +45,95 @@ pub enum PrimOpError {
macro_rules! run_op { macro_rules! run_op {
($op: ident, $left: expr, $right: expr) => { ($op: ident, $left: expr, $right: expr) => {
match $op { match $op {
"+" => $left "+" => Ok($left.wrapping_add($right).into()),
.checked_add($right) "-" => Ok($left.wrapping_sub($right).into()),
.ok_or(PrimOpError::MathFailure("+")) "*" => Ok($left.wrapping_mul($right).into()),
.map(Into::into), "/" if $right == 0 => Err(PrimOpError::MathFailure("/")),
"-" => $left "/" => Ok($left.wrapping_div($right).into()),
.checked_sub($right)
.ok_or(PrimOpError::MathFailure("-"))
.map(Into::into),
"*" => $left
.checked_mul($right)
.ok_or(PrimOpError::MathFailure("*"))
.map(Into::into),
"/" => $left
.checked_div($right)
.ok_or(PrimOpError::MathFailure("/"))
.map(Into::into),
_ => Err(PrimOpError::UnknownPrimOp($op.to_string())), _ => Err(PrimOpError::UnknownPrimOp($op.to_string())),
} }
}; };
} }
impl Value { impl Value {
fn unary_op(operation: &str, value: &Value) -> Result<Value, PrimOpError> {
match operation {
"-" => match value {
Value::I8(x) => Ok(Value::I8(x.wrapping_neg())),
Value::I16(x) => Ok(Value::I16(x.wrapping_neg())),
Value::I32(x) => Ok(Value::I32(x.wrapping_neg())),
Value::I64(x) => Ok(Value::I64(x.wrapping_neg())),
_ => Err(PrimOpError::BadTypeFor("-", value.clone())),
},
_ => Err(PrimOpError::BadArgCount(operation.to_owned(), 1)),
}
}
fn binary_op(operation: &str, left: &Value, right: &Value) -> Result<Value, PrimOpError> { fn binary_op(operation: &str, left: &Value, right: &Value) -> Result<Value, PrimOpError> {
match left { match left {
// for now we only have one type, but in the future this is Value::I8(x) => match right {
// going to be very irritating. Value::I8(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::I16(x) => match right {
Value::I16(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::I32(x) => match right {
Value::I32(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::I64(x) => match right { Value::I64(x) => match right {
Value::I64(y) => run_op!(operation, x, *y), Value::I64(y) => run_op!(operation, x, *y),
// _ => Err(PrimOpError::TypeMismatch( _ => Err(PrimOpError::TypeMismatch(
// operation.to_string(), operation.to_string(),
// left.clone(), left.clone(),
// right.clone(), right.clone(),
// )), )),
},
Value::U8(x) => match right {
Value::U8(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::U16(x) => match right {
Value::U16(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::U32(x) => match right {
Value::U32(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::U64(x) => match right {
Value::U64(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
}, },
} }
} }
@@ -83,13 +147,10 @@ impl Value {
/// its worth being careful to make sure that your inputs won't cause either /// its worth being careful to make sure that your inputs won't cause either
/// condition. /// condition.
pub fn calculate(operation: &str, values: Vec<Value>) -> Result<Value, PrimOpError> { pub fn calculate(operation: &str, values: Vec<Value>) -> Result<Value, PrimOpError> {
if values.len() == 2 { match values.len() {
Value::binary_op(operation, &values[0], &values[1]) 1 => Value::unary_op(operation, &values[0]),
} else { 2 => Value::binary_op(operation, &values[0], &values[1]),
Err(PrimOpError::BadArgCount( x => Err(PrimOpError::BadArgCount(operation.to_string(), x)),
operation.to_string(),
values.len(),
))
} }
} }
} }

173
src/eval/primtype.rs Normal file
View File

@@ -0,0 +1,173 @@
use crate::{
eval::{PrimOpError, Value},
syntax::ConstantType,
};
use std::{fmt::Display, str::FromStr};
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PrimitiveType {
U8,
U16,
U32,
U64,
I8,
I16,
I32,
I64,
}
impl Display for PrimitiveType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PrimitiveType::I8 => write!(f, "i8"),
PrimitiveType::I16 => write!(f, "i16"),
PrimitiveType::I32 => write!(f, "i32"),
PrimitiveType::I64 => write!(f, "i64"),
PrimitiveType::U8 => write!(f, "u8"),
PrimitiveType::U16 => write!(f, "u16"),
PrimitiveType::U32 => write!(f, "u32"),
PrimitiveType::U64 => write!(f, "u64"),
}
}
}
impl<'a> From<&'a Value> for PrimitiveType {
fn from(value: &Value) -> Self {
match value {
Value::I8(_) => PrimitiveType::I8,
Value::I16(_) => PrimitiveType::I16,
Value::I32(_) => PrimitiveType::I32,
Value::I64(_) => PrimitiveType::I64,
Value::U8(_) => PrimitiveType::U8,
Value::U16(_) => PrimitiveType::U16,
Value::U32(_) => PrimitiveType::U32,
Value::U64(_) => PrimitiveType::U64,
}
}
}
impl From<ConstantType> for PrimitiveType {
fn from(value: ConstantType) -> Self {
match value {
ConstantType::I8 => PrimitiveType::I8,
ConstantType::I16 => PrimitiveType::I16,
ConstantType::I32 => PrimitiveType::I32,
ConstantType::I64 => PrimitiveType::I64,
ConstantType::U8 => PrimitiveType::U8,
ConstantType::U16 => PrimitiveType::U16,
ConstantType::U32 => PrimitiveType::U32,
ConstantType::U64 => PrimitiveType::U64,
}
}
}
impl FromStr for PrimitiveType {
type Err = PrimOpError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"i8" => Ok(PrimitiveType::I8),
"i16" => Ok(PrimitiveType::I16),
"i32" => Ok(PrimitiveType::I32),
"i64" => Ok(PrimitiveType::I64),
"u8" => Ok(PrimitiveType::U8),
"u16" => Ok(PrimitiveType::U16),
"u32" => Ok(PrimitiveType::U32),
"u64" => Ok(PrimitiveType::U64),
_ => Err(PrimOpError::UnknownPrimType(s.to_string())),
}
}
}
impl PrimitiveType {
/// Return true if this type can be safely cast into the target type.
pub fn can_cast_to(&self, target: &PrimitiveType) -> bool {
match self {
PrimitiveType::U8 => matches!(
target,
PrimitiveType::U8
| PrimitiveType::U16
| PrimitiveType::U32
| PrimitiveType::U64
| PrimitiveType::I16
| PrimitiveType::I32
| PrimitiveType::I64
),
PrimitiveType::U16 => matches!(
target,
PrimitiveType::U16
| PrimitiveType::U32
| PrimitiveType::U64
| PrimitiveType::I32
| PrimitiveType::I64
),
PrimitiveType::U32 => matches!(
target,
PrimitiveType::U32 | PrimitiveType::U64 | PrimitiveType::I64
),
PrimitiveType::U64 => target == &PrimitiveType::U64,
PrimitiveType::I8 => matches!(
target,
PrimitiveType::I8 | PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64
),
PrimitiveType::I16 => matches!(
target,
PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64
),
PrimitiveType::I32 => matches!(target, PrimitiveType::I32 | PrimitiveType::I64),
PrimitiveType::I64 => target == &PrimitiveType::I64,
}
}
/// Try to cast the given value to this type, returning the new value.
///
/// Returns an error if the cast is not safe *in* *general*. This means that
/// this function will error even if the number will actually fit in the target
/// type, but it would not be generally safe to cast a member of the given
/// type to the target type. (So, for example, "1i64" is a number that could
/// work as a "u64", but since negative numbers wouldn't work, a cast from
/// "1i64" to "u64" will fail.)
pub fn safe_cast(&self, source: &Value) -> Result<Value, PrimOpError> {
match (self, source) {
(PrimitiveType::U8, Value::U8(x)) => Ok(Value::U8(*x)),
(PrimitiveType::U16, Value::U8(x)) => Ok(Value::U16(*x as u16)),
(PrimitiveType::U16, Value::U16(x)) => Ok(Value::U16(*x)),
(PrimitiveType::U32, Value::U8(x)) => Ok(Value::U32(*x as u32)),
(PrimitiveType::U32, Value::U16(x)) => Ok(Value::U32(*x as u32)),
(PrimitiveType::U32, Value::U32(x)) => Ok(Value::U32(*x)),
(PrimitiveType::U64, Value::U8(x)) => Ok(Value::U64(*x as u64)),
(PrimitiveType::U64, Value::U16(x)) => Ok(Value::U64(*x as u64)),
(PrimitiveType::U64, Value::U32(x)) => Ok(Value::U64(*x as u64)),
(PrimitiveType::U64, Value::U64(x)) => Ok(Value::U64(*x)),
(PrimitiveType::I8, Value::I8(x)) => Ok(Value::I8(*x)),
(PrimitiveType::I16, Value::I8(x)) => Ok(Value::I16(*x as i16)),
(PrimitiveType::I16, Value::I16(x)) => Ok(Value::I16(*x)),
(PrimitiveType::I32, Value::I8(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I32, Value::I16(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I32, Value::I32(x)) => Ok(Value::I32(*x)),
(PrimitiveType::I64, Value::I8(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I16(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)),
_ => Err(PrimOpError::UnsafeCast {
from: source.into(),
to: *self,
}),
}
}
pub fn max_value(&self) -> u64 {
match self {
PrimitiveType::U8 => u8::MAX as u64,
PrimitiveType::U16 => u16::MAX as u64,
PrimitiveType::U32 => u32::MAX as u64,
PrimitiveType::U64 => u64::MAX,
PrimitiveType::I8 => i8::MAX as u64,
PrimitiveType::I16 => i16::MAX as u64,
PrimitiveType::I32 => i32::MAX as u64,
PrimitiveType::I64 => i64::MAX as u64,
}
}
}

View File

@@ -7,19 +7,75 @@ use std::fmt::Display;
/// by type so that we don't mix them up. /// by type so that we don't mix them up.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum Value { pub enum Value {
I8(i8),
I16(i16),
I32(i32),
I64(i64), I64(i64),
U8(u8),
U16(u16),
U32(u32),
U64(u64),
} }
impl Display for Value { impl Display for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Value::I8(x) => write!(f, "{}i8", x),
Value::I16(x) => write!(f, "{}i16", x),
Value::I32(x) => write!(f, "{}i32", x),
Value::I64(x) => write!(f, "{}i64", x), Value::I64(x) => write!(f, "{}i64", x),
Value::U8(x) => write!(f, "{}u8", x),
Value::U16(x) => write!(f, "{}u16", x),
Value::U32(x) => write!(f, "{}u32", x),
Value::U64(x) => write!(f, "{}u64", x),
} }
} }
} }
impl From<i8> for Value {
fn from(value: i8) -> Self {
Value::I8(value)
}
}
impl From<i16> for Value {
fn from(value: i16) -> Self {
Value::I16(value)
}
}
impl From<i32> for Value {
fn from(value: i32) -> Self {
Value::I32(value)
}
}
impl From<i64> for Value { impl From<i64> for Value {
fn from(value: i64) -> Self { fn from(value: i64) -> Self {
Value::I64(value) Value::I64(value)
} }
} }
impl From<u8> for Value {
fn from(value: u8) -> Self {
Value::U8(value)
}
}
impl From<u16> for Value {
fn from(value: u16) -> Self {
Value::U16(value)
}
}
impl From<u32> for Value {
fn from(value: u32) -> Self {
Value::U32(value)
}
}
impl From<u64> for Value {
fn from(value: u64) -> Self {
Value::U64(value)
}
}

6
src/examples.rs Normal file
View File

@@ -0,0 +1,6 @@
use crate::backend::Backend;
use crate::syntax::Program as Syntax;
use codespan_reporting::files::SimpleFiles;
use cranelift_jit::JITModule;
include!(concat!(env!("OUT_DIR"), "/examples.rs"));

View File

@@ -12,9 +12,8 @@
//! validating syntax, and then figuring out how to turn it into Cranelift //! validating syntax, and then figuring out how to turn it into Cranelift
//! and object code. After that point, however, this will be the module to //! and object code. After that point, however, this will be the module to
//! come to for analysis and optimization work. //! come to for analysis and optimization work.
mod ast; pub mod ast;
mod eval; mod eval;
mod from_syntax;
mod strings; mod strings;
pub use ast::*; pub use ast::*;

View File

@@ -1,10 +1,14 @@
use crate::syntax::Location; use crate::{
eval::PrimitiveType,
syntax::{self, ConstantType, Location},
};
use internment::ArcIntern; use internment::ArcIntern;
use pretty::{DocAllocator, Pretty}; use pretty::{BoxAllocator, DocAllocator, Pretty};
use proptest::{ use proptest::{
prelude::Arbitrary, prelude::Arbitrary,
strategy::{BoxedStrategy, Strategy}, strategy::{BoxedStrategy, Strategy},
}; };
use std::{fmt, str::FromStr};
/// We're going to represent variables as interned strings. /// We're going to represent variables as interned strings.
/// ///
@@ -52,12 +56,15 @@ where
} }
impl Arbitrary for Program { impl Arbitrary for Program {
type Parameters = (); type Parameters = crate::syntax::arbitrary::GenerationEnvironment;
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
crate::syntax::Program::arbitrary_with(args) crate::syntax::Program::arbitrary_with(args)
.prop_map(Program::from) .prop_map(|x| {
x.type_infer()
.expect("arbitrary_with should generate type-correct programs")
})
.boxed() .boxed()
} }
} }
@@ -74,8 +81,8 @@ impl Arbitrary for Program {
/// ///
#[derive(Debug)] #[derive(Debug)]
pub enum Statement { pub enum Statement {
Binding(Location, Variable, Expression), Binding(Location, Variable, Type, Expression),
Print(Location, Variable), Print(Location, Type, Variable),
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
@@ -85,13 +92,13 @@ where
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { match self {
Statement::Binding(_, var, expr) => allocator Statement::Binding(_, var, _, expr) => allocator
.text(var.as_ref().to_string()) .text(var.as_ref().to_string())
.append(allocator.space()) .append(allocator.space())
.append(allocator.text("=")) .append(allocator.text("="))
.append(allocator.space()) .append(allocator.space())
.append(expr.pretty(allocator)), .append(expr.pretty(allocator)),
Statement::Print(_, var) => allocator Statement::Print(_, _, var) => allocator
.text("print") .text("print")
.append(allocator.space()) .append(allocator.space())
.append(allocator.text(var.as_ref().to_string())), .append(allocator.text(var.as_ref().to_string())),
@@ -113,9 +120,32 @@ where
/// variable reference. /// variable reference.
#[derive(Debug)] #[derive(Debug)]
pub enum Expression { pub enum Expression {
Value(Location, Value), Atomic(ValueOrRef),
Reference(Location, Variable), Cast(Location, Type, ValueOrRef),
Primitive(Location, Primitive, Vec<ValueOrRef>), Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}
impl Expression {
/// Return a reference to the type of the expression, as inferred or recently
/// computed.
pub fn type_of(&self) -> &Type {
match self {
Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t,
Expression::Atomic(ValueOrRef::Value(_, t, _)) => t,
Expression::Cast(_, t, _) => t,
Expression::Primitive(_, t, _, _) => t,
}
}
/// Return a reference to the location associated with the expression.
pub fn location(&self) -> &Location {
match self {
Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l,
}
}
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
@@ -125,12 +155,16 @@ where
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { match self {
Expression::Value(_, val) => val.pretty(allocator), Expression::Atomic(x) => x.pretty(allocator),
Expression::Reference(_, var) => allocator.text(var.as_ref().to_string()), Expression::Cast(_, t, e) => allocator
Expression::Primitive(_, op, exprs) if exprs.len() == 1 => { .text("<")
.append(t.pretty(allocator))
.append(allocator.text(">"))
.append(e.pretty(allocator)),
Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator)) op.pretty(allocator).append(exprs[0].pretty(allocator))
} }
Expression::Primitive(_, op, exprs) if exprs.len() == 2 => { Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator); let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator); let right = exprs[1].pretty(allocator);
@@ -140,7 +174,7 @@ where
.append(right) .append(right)
.parens() .parens()
} }
Expression::Primitive(_, op, exprs) => { Expression::Primitive(_, _, op, exprs) => {
allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
} }
} }
@@ -161,10 +195,10 @@ pub enum Primitive {
Divide, Divide,
} }
impl<'a> TryFrom<&'a str> for Primitive { impl FromStr for Primitive {
type Error = String; type Err = String;
fn try_from(value: &str) -> Result<Self, Self::Error> { fn from_str(value: &str) -> Result<Self, Self::Err> {
match value { match value {
"+" => Ok(Primitive::Plus), "+" => Ok(Primitive::Plus),
"-" => Ok(Primitive::Minus), "-" => Ok(Primitive::Minus),
@@ -190,15 +224,21 @@ where
} }
} }
impl fmt::Display for Primitive {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<&Primitive as Pretty<'_, BoxAllocator, ()>>::pretty(self, &BoxAllocator).render_fmt(72, f)
}
}
/// An expression that is always either a value or a reference. /// An expression that is always either a value or a reference.
/// ///
/// This is the type used to guarantee that we don't nest expressions /// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one /// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference. /// of these, which can only be a constant or a reference.
#[derive(Debug)] #[derive(Clone, Debug)]
pub enum ValueOrRef { pub enum ValueOrRef {
Value(Location, Value), Value(Location, Type, Value),
Ref(Location, ArcIntern<String>), Ref(Location, Type, ArcIntern<String>),
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
@@ -208,30 +248,50 @@ where
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { match self {
ValueOrRef::Value(_, v) => v.pretty(allocator), ValueOrRef::Value(_, _, v) => v.pretty(allocator),
ValueOrRef::Ref(_, v) => allocator.text(v.as_ref().to_string()), ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()),
} }
} }
} }
impl From<ValueOrRef> for Expression { impl From<ValueOrRef> for Expression {
fn from(value: ValueOrRef) -> Self { fn from(value: ValueOrRef) -> Self {
match value { Expression::Atomic(value)
ValueOrRef::Value(loc, val) => Expression::Value(loc, val),
ValueOrRef::Ref(loc, var) => Expression::Reference(loc, var),
}
} }
} }
/// A constant in the IR. /// A constant in the IR.
#[derive(Debug)] ///
/// The optional argument in numeric types is the base that was used by the
/// user to input the number. By retaining it, we can ensure that if we need
/// to print the number back out, we can do so in the form that the user
/// entered it.
#[derive(Clone, Debug)]
pub enum Value { pub enum Value {
/// A numerical constant. I8(Option<u8>, i8),
/// I16(Option<u8>, i16),
/// The optional argument is the base that was used by the user to input I32(Option<u8>, i32),
/// the number. By retaining it, we can ensure that if we need to print the I64(Option<u8>, i64),
/// number back out, we can do so in the form that the user entered it. U8(Option<u8>, u8),
Number(Option<u8>, i64), U16(Option<u8>, u16),
U32(Option<u8>, u32),
U64(Option<u8>, u64),
}
impl Value {
/// Return the type described by this value
pub fn type_of(&self) -> Type {
match self {
Value::I8(_, _) => Type::Primitive(PrimitiveType::I8),
Value::I16(_, _) => Type::Primitive(PrimitiveType::I16),
Value::I32(_, _) => Type::Primitive(PrimitiveType::I32),
Value::I64(_, _) => Type::Primitive(PrimitiveType::I64),
Value::U8(_, _) => Type::Primitive(PrimitiveType::U8),
Value::U16(_, _) => Type::Primitive(PrimitiveType::U16),
Value::U32(_, _) => Type::Primitive(PrimitiveType::U32),
Value::U64(_, _) => Type::Primitive(PrimitiveType::U64),
}
}
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
@@ -240,19 +300,64 @@ where
D: ?Sized + DocAllocator<'a, A>, D: ?Sized + DocAllocator<'a, A>,
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { let pretty_internal = |opt_base: &Option<u8>, x, t| {
Value::Number(opt_base, value) => { syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
let value_str = match opt_base { };
None => format!("{}", value),
Some(2) => format!("0b{:b}", value),
Some(8) => format!("0o{:o}", value),
Some(10) => format!("0d{}", value),
Some(16) => format!("0x{:x}", value),
Some(_) => format!("!!{:x}!!", value),
};
allocator.text(value_str) let pretty_internal_signed = |opt_base, x: i64, t| {
let base = pretty_internal(opt_base, x.unsigned_abs(), t);
allocator.text("-").append(base)
};
match self {
Value::I8(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
} }
Value::I16(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
}
Value::I32(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
}
Value::I64(opt_base, value) => {
pretty_internal_signed(opt_base, *value, ConstantType::I64)
}
Value::U8(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U8)
}
Value::U16(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U16)
}
Value::U32(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U32)
}
Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type {
Primitive(PrimitiveType),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Type::Primitive(pt) => allocator.text(format!("{}", pt)),
}
}
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Primitive(pt) => pt.fmt(f),
} }
} }
} }

View File

@@ -1,7 +1,7 @@
use crate::eval::{EvalEnvironment, EvalError, Value}; use crate::eval::{EvalEnvironment, EvalError, Value};
use crate::ir::{Expression, Program, Statement}; use crate::ir::{Expression, Program, Statement};
use super::{Primitive, ValueOrRef}; use super::{Primitive, Type, ValueOrRef};
impl Program { impl Program {
/// Evaluate the program, returning either an error or a string containing everything /// Evaluate the program, returning either an error or a string containing everything
@@ -14,12 +14,12 @@ impl Program {
for stmt in self.statements.iter() { for stmt in self.statements.iter() {
match stmt { match stmt {
Statement::Binding(_, name, value) => { Statement::Binding(_, name, _, value) => {
let actual_value = value.eval(&env)?; let actual_value = value.eval(&env)?;
env = env.extend(name.clone(), actual_value); env = env.extend(name.clone(), actual_value);
} }
Statement::Print(_, name) => { Statement::Print(_, _, name) => {
let value = env.lookup(name.clone())?; let value = env.lookup(name.clone())?;
let line = format!("{} = {}\n", name, value); let line = format!("{} = {}\n", name, value);
stdout.push_str(&line); stdout.push_str(&line);
@@ -34,26 +34,21 @@ impl Program {
impl Expression { impl Expression {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> { fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self { match self {
Expression::Value(_, v) => match v { Expression::Atomic(x) => x.eval(env),
super::Value::Number(_, v) => Ok(Value::I64(*v)),
},
Expression::Reference(_, n) => Ok(env.lookup(n.clone())?), Expression::Cast(_, t, valref) => {
let value = valref.eval(env)?;
Expression::Primitive(_, op, args) => { match t {
let mut arg_values = Vec::with_capacity(args.len()); Type::Primitive(pt) => Ok(pt.safe_cast(&value)?),
// we implement primitive operations by first evaluating each of the
// arguments to the function, and then gathering up all the values
// produced.
for arg in args.iter() {
match arg {
ValueOrRef::Ref(_, n) => arg_values.push(env.lookup(n.clone())?),
ValueOrRef::Value(_, super::Value::Number(_, v)) => {
arg_values.push(Value::I64(*v))
}
}
} }
}
Expression::Primitive(_, _, op, args) => {
let arg_values = args
.iter()
.map(|x| x.eval(env))
.collect::<Result<Vec<Value>, EvalError>>()?;
// and then finally we call `calculate` to run them. trust me, it's nice // and then finally we call `calculate` to run them. trust me, it's nice
// to not have to deal with all the nonsense hidden under `calculate`. // to not have to deal with all the nonsense hidden under `calculate`.
@@ -68,19 +63,38 @@ impl Expression {
} }
} }
impl ValueOrRef {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self {
ValueOrRef::Value(_, _, v) => match v {
super::Value::I8(_, v) => Ok(Value::I8(*v)),
super::Value::I16(_, v) => Ok(Value::I16(*v)),
super::Value::I32(_, v) => Ok(Value::I32(*v)),
super::Value::I64(_, v) => Ok(Value::I64(*v)),
super::Value::U8(_, v) => Ok(Value::U8(*v)),
super::Value::U16(_, v) => Ok(Value::U16(*v)),
super::Value::U32(_, v) => Ok(Value::U32(*v)),
super::Value::U64(_, v) => Ok(Value::U64(*v)),
},
ValueOrRef::Ref(_, _, n) => Ok(env.lookup(n.clone())?),
}
}
}
#[test] #[test]
fn two_plus_three() { fn two_plus_three() {
let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works");
let ir = Program::from(input); let ir = input.type_infer().expect("test should be type-valid");
let output = ir.eval().expect("runs successfully"); let output = ir.eval().expect("runs successfully");
assert_eq!("x = 5i64\n", &output); assert_eq!("x = 5u64\n", &output);
} }
#[test] #[test]
fn lotsa_math() { fn lotsa_math() {
let input = let input =
crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works");
let ir = Program::from(input); let ir = input.type_infer().expect("test should be type-valid");
let output = ir.eval().expect("runs successfully"); let output = ir.eval().expect("runs successfully");
assert_eq!("x = 7i64\n", &output); assert_eq!("x = 7u64\n", &output);
} }

View File

@@ -1,186 +0,0 @@
use internment::ArcIntern;
use std::sync::atomic::AtomicUsize;
use crate::ir::ast as ir;
use crate::syntax;
use super::ValueOrRef;
impl From<syntax::Program> for ir::Program {
/// We implement the top-level conversion of a syntax::Program into an
/// ir::Program using just the standard `From::from`, because we don't
/// need to return any arguments and we shouldn't produce any errors.
/// Technically there's an `unwrap` deep under the hood that we could
/// float out, but the validator really should've made sure that never
/// happens, so we're just going to assume.
fn from(mut value: syntax::Program) -> Self {
let mut statements = Vec::new();
for stmt in value.statements.drain(..) {
statements.append(&mut stmt.simplify());
}
ir::Program { statements }
}
}
impl From<syntax::Statement> for ir::Program {
/// One interesting thing about this conversion is that there isn't
/// a natural translation from syntax::Statement to ir::Statement,
/// because the syntax version can have nested expressions and the
/// IR version can't.
///
/// As a result, we can naturally convert a syntax::Statement into
/// an ir::Program, because we can allow the additional binding
/// sites to be generated, instead. And, bonus, it turns out that
/// this is what we wanted anyways.
fn from(value: syntax::Statement) -> Self {
ir::Program {
statements: value.simplify(),
}
}
}
impl syntax::Statement {
/// Simplify a syntax::Statement into a series of ir::Statements.
///
/// The reason this function is one-to-many is because we may have to
/// introduce new binding sites in order to avoid having nested
/// expressions. Nested expressions, like `(1 + 2) * 3`, are allowed
/// in syntax::Expression but are expressly *not* allowed in
/// ir::Expression. So this pass converts them into bindings, like
/// this:
///
/// x = (1 + 2) * 3;
///
/// ==>
///
/// x:1 = 1 + 2;
/// x:2 = x:1 * 3;
/// x = x:2
///
/// Thus ensuring that things are nice and simple. Note that the
/// binding of `x:2` is not, strictly speaking, necessary, but it
/// makes the code below much easier to read.
fn simplify(self) -> Vec<ir::Statement> {
let mut new_statements = vec![];
match self {
// Print statements we don't have to do much with
syntax::Statement::Print(loc, name) => {
new_statements.push(ir::Statement::Print(loc, ArcIntern::new(name)))
}
// Bindings, however, may involve a single expression turning into
// a series of statements and then an expression.
syntax::Statement::Binding(loc, name, value) => {
let (mut prereqs, new_value) = value.rebind(&name);
new_statements.append(&mut prereqs);
new_statements.push(ir::Statement::Binding(
loc,
ArcIntern::new(name),
new_value.into(),
))
}
}
new_statements
}
}
impl syntax::Expression {
/// This actually does the meat of the simplification work, here, by rebinding
/// any nested expressions into their own variables. We have this return
/// `ValueOrRef` in all cases because it makes for slighly less code; in the
/// case when we actually want an `Expression`, we can just use `into()`.
fn rebind(self, base_name: &str) -> (Vec<ir::Statement>, ir::ValueOrRef) {
match self {
// Values just convert in the obvious way, and require no prereqs
syntax::Expression::Value(loc, val) => (vec![], ValueOrRef::Value(loc, val.into())),
// Similarly, references just convert in the obvious way, and require
// no prereqs
syntax::Expression::Reference(loc, name) => {
(vec![], ValueOrRef::Ref(loc, ArcIntern::new(name)))
}
// Primitive expressions are where we do the real work.
syntax::Expression::Primitive(loc, prim, mut expressions) => {
// generate a fresh new name for the binding site we're going to
// introduce, basing the name on wherever we came from; so if this
// expression was bound to `x` originally, it might become `x:23`.
//
// gensym is guaranteed to give us a name that is unused anywhere
// else in the program.
let new_name = gensym(base_name);
let mut prereqs = Vec::new();
let mut new_exprs = Vec::new();
// here we loop through every argument, and recurse on the expressions
// we find. that will give us any new binding sites that *they* introduce,
// and a simple value or reference that we can use in our result.
for expr in expressions.drain(..) {
let (mut cur_prereqs, arg) = expr.rebind(new_name.as_str());
prereqs.append(&mut cur_prereqs);
new_exprs.push(arg);
}
// now we're going to use those new arguments to run the primitive, binding
// the results to the new variable we introduced.
let prim =
ir::Primitive::try_from(prim.as_str()).expect("is valid primitive function");
prereqs.push(ir::Statement::Binding(
loc.clone(),
new_name.clone(),
ir::Expression::Primitive(loc.clone(), prim, new_exprs),
));
// and finally, we can return all the new bindings, and a reference to
// the variable we just introduced to hold the value of the primitive
// invocation.
(prereqs, ValueOrRef::Ref(loc, new_name))
}
}
}
}
impl From<syntax::Value> for ir::Value {
fn from(value: syntax::Value) -> Self {
match value {
syntax::Value::Number(base, val) => ir::Value::Number(base, val),
}
}
}
impl From<String> for ir::Primitive {
fn from(value: String) -> Self {
value.try_into().unwrap()
}
}
/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
fn gensym(name: &str) -> ArcIntern<String> {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = format!(
"<{}:{}>",
name,
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
);
ArcIntern::new(new_name)
}
proptest::proptest! {
#[test]
fn translation_maintains_semantics(input: syntax::Program) {
let syntax_result = input.eval();
let ir = ir::Program::from(input);
let ir_result = ir.eval();
assert_eq!(syntax_result, ir_result);
}
}

View File

@@ -21,12 +21,12 @@ impl Program {
impl Statement { impl Statement {
fn register_strings(&self, string_set: &mut HashSet<ArcIntern<String>>) { fn register_strings(&self, string_set: &mut HashSet<ArcIntern<String>>) {
match self { match self {
Statement::Binding(_, name, expr) => { Statement::Binding(_, name, _, expr) => {
string_set.insert(name.clone()); string_set.insert(name.clone());
expr.register_strings(string_set); expr.register_strings(string_set);
} }
Statement::Print(_, name) => { Statement::Print(_, _, name) => {
string_set.insert(name.clone()); string_set.insert(name.clone());
} }
} }

View File

@@ -63,8 +63,11 @@
//! //!
pub mod backend; pub mod backend;
pub mod eval; pub mod eval;
#[cfg(test)]
mod examples;
pub mod ir; pub mod ir;
pub mod syntax; pub mod syntax;
pub mod type_infer;
/// Implementation module for the high-level compiler. /// Implementation module for the high-level compiler.
mod compiler; mod compiler;

View File

@@ -1,6 +1,6 @@
use crate::backend::{Backend, BackendError}; use crate::backend::{Backend, BackendError};
use crate::ir::Program as IR; use crate::syntax::{ConstantType, Location, ParserError, Statement};
use crate::syntax::{Location, ParserError, Statement}; use crate::type_infer::TypeInferenceResult;
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use codespan_reporting::files::SimpleFiles; use codespan_reporting::files::SimpleFiles;
use codespan_reporting::term::{self, Config}; use codespan_reporting::term::{self, Config};
@@ -129,17 +129,32 @@ impl REPL {
.source(); .source();
let syntax = Statement::parse(entry, source)?; let syntax = Statement::parse(entry, source)?;
// if this is a variable binding, and we've never defined this variable before, let program = match syntax {
// we should tell cranelift about it. this is optimistic; if we fail to compile, Statement::Binding(loc, name, expr) => {
// then we won't use this definition until someone tries again. // if this is a variable binding, and we've never defined this variable before,
if let Statement::Binding(_, ref name, _) = syntax { // we should tell cranelift about it. this is optimistic; if we fail to compile,
if !self.variable_binding_sites.contains_key(name.as_str()) { // then we won't use this definition until someone tries again.
self.jitter.define_string(name)?; if !self.variable_binding_sites.contains_key(&name.name) {
self.jitter.define_variable(name.clone())?; self.jitter.define_string(&name.name)?;
self.jitter
.define_variable(name.to_string(), ConstantType::U64)?;
}
crate::syntax::Program {
statements: vec![
Statement::Binding(loc.clone(), name.clone(), expr),
Statement::Print(loc, name),
],
}
} }
nonbinding => crate::syntax::Program {
statements: vec![nonbinding],
},
}; };
let (mut errors, mut warnings) = syntax.validate(&mut self.variable_binding_sites); let (mut errors, mut warnings) =
program.validate_with_bindings(&mut self.variable_binding_sites);
let stop = !errors.is_empty(); let stop = !errors.is_empty();
let messages = errors let messages = errors
.drain(..) .drain(..)
@@ -154,13 +169,39 @@ impl REPL {
return Ok(()); return Ok(());
} }
let ir = IR::from(syntax); match program.type_infer() {
let name = format!("line{}", line_no); TypeInferenceResult::Failure {
let function_id = self.jitter.compile_function(&name, ir)?; mut errors,
self.jitter.module.finalize_definitions()?; mut warnings,
let compiled_bytes = self.jitter.bytes(function_id); } => {
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; let messages = errors
compiled_function(); .drain(..)
Ok(()) .map(Into::into)
.chain(warnings.drain(..).map(Into::into));
for message in messages {
self.emit_diagnostic(message)?;
}
Ok(())
}
TypeInferenceResult::Success {
result,
mut warnings,
} => {
for message in warnings.drain(..).map(Into::into) {
self.emit_diagnostic(message)?;
}
let name = format!("line{}", line_no);
let function_id = self.jitter.compile_function(&name, result)?;
self.jitter.module.finalize_definitions()?;
let compiled_bytes = self.jitter.bytes(function_id);
let compiled_function =
unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function();
Ok(())
}
}
} }
} }

View File

@@ -27,7 +27,7 @@ use codespan_reporting::{diagnostic::Diagnostic, files::SimpleFiles};
use lalrpop_util::lalrpop_mod; use lalrpop_util::lalrpop_mod;
use logos::Logos; use logos::Logos;
mod arbitrary; pub mod arbitrary;
mod ast; mod ast;
mod eval; mod eval;
mod location; mod location;
@@ -40,6 +40,8 @@ lalrpop_mod!(
mod pretty; mod pretty;
mod validate; mod validate;
#[cfg(test)]
use crate::syntax::arbitrary::GenerationEnvironment;
pub use crate::syntax::ast::*; pub use crate::syntax::ast::*;
pub use crate::syntax::location::Location; pub use crate::syntax::location::Location;
pub use crate::syntax::parser::{ProgramParser, StatementParser}; pub use crate::syntax::parser::{ProgramParser, StatementParser};
@@ -48,7 +50,7 @@ pub use crate::syntax::tokens::{LexerError, Token};
use ::pretty::{Arena, Pretty}; use ::pretty::{Arena, Pretty};
use lalrpop_util::ParseError; use lalrpop_util::ParseError;
#[cfg(test)] #[cfg(test)]
use proptest::{prop_assert, prop_assert_eq}; use proptest::{arbitrary::Arbitrary, prop_assert, prop_assert_eq};
#[cfg(test)] #[cfg(test)]
use std::str::FromStr; use std::str::FromStr;
use thiserror::Error; use thiserror::Error;
@@ -73,12 +75,12 @@ pub enum ParserError {
/// Raised when we're parsing the file, and run into a token in a /// Raised when we're parsing the file, and run into a token in a
/// place we weren't expecting it. /// place we weren't expecting it.
#[error("Unrecognized token")] #[error("Unrecognized token")]
UnrecognizedToken(Location, Location, Token, Vec<String>), UnrecognizedToken(Location, Token, Vec<String>),
/// Raised when we were expecting the end of the file, but instead /// Raised when we were expecting the end of the file, but instead
/// got another token. /// got another token.
#[error("Extra token")] #[error("Extra token")]
ExtraToken(Location, Token, Location), ExtraToken(Location, Token),
/// Raised when the lexer just had some sort of internal problem /// Raised when the lexer just had some sort of internal problem
/// and just gave up. /// and just gave up.
@@ -106,30 +108,24 @@ impl ParserError {
fn convert(file_idx: usize, err: ParseError<usize, Token, LexerError>) -> Self { fn convert(file_idx: usize, err: ParseError<usize, Token, LexerError>) -> Self {
match err { match err {
ParseError::InvalidToken { location } => { ParseError::InvalidToken { location } => {
ParserError::InvalidToken(Location::new(file_idx, location)) ParserError::InvalidToken(Location::new(file_idx, location..location + 1))
}
ParseError::UnrecognizedEof { location, expected } => {
ParserError::UnrecognizedEOF(Location::new(file_idx, location), expected)
} }
ParseError::UnrecognizedEof { location, expected } => ParserError::UnrecognizedEOF(
Location::new(file_idx, location..location + 1),
expected,
),
ParseError::UnrecognizedToken { ParseError::UnrecognizedToken {
token: (start, token, end), token: (start, token, end),
expected, expected,
} => ParserError::UnrecognizedToken( } => {
Location::new(file_idx, start), ParserError::UnrecognizedToken(Location::new(file_idx, start..end), token, expected)
Location::new(file_idx, end), }
token,
expected,
),
ParseError::ExtraToken { ParseError::ExtraToken {
token: (start, token, end), token: (start, token, end),
} => ParserError::ExtraToken( } => ParserError::ExtraToken(Location::new(file_idx, start..end), token),
Location::new(file_idx, start),
token,
Location::new(file_idx, end),
),
ParseError::User { error } => match error { ParseError::User { error } => match error {
LexerError::LexFailure(offset) => { LexerError::LexFailure(offset) => {
ParserError::LexFailure(Location::new(file_idx, offset)) ParserError::LexFailure(Location::new(file_idx, offset..offset + 1))
} }
}, },
} }
@@ -180,37 +176,25 @@ impl<'a> From<&'a ParserError> for Diagnostic<usize> {
), ),
// encountered a token where it shouldn't be // encountered a token where it shouldn't be
ParserError::UnrecognizedToken(start, end, token, expected) => { ParserError::UnrecognizedToken(loc, token, expected) => {
let expected_str = let expected_str =
format!("unexpected token {}{}", token, display_expected(expected)); format!("unexpected token {}{}", token, display_expected(expected));
let unexpected_str = format!("unexpected token {}", token); let unexpected_str = format!("unexpected token {}", token);
let labels = start.range_label(end);
Diagnostic::error() Diagnostic::error()
.with_labels(
labels
.into_iter()
.map(|l| l.with_message(unexpected_str.clone()))
.collect(),
)
.with_message(expected_str) .with_message(expected_str)
.with_labels(vec![loc.primary_label().with_message(unexpected_str)])
} }
// I think we get this when we get a token, but were expected EOF // I think we get this when we get a token, but were expected EOF
ParserError::ExtraToken(start, token, end) => { ParserError::ExtraToken(loc, token) => {
let expected_str = let expected_str =
format!("unexpected token {} after the expected end of file", token); format!("unexpected token {} after the expected end of file", token);
let unexpected_str = format!("unexpected token {}", token); let unexpected_str = format!("unexpected token {}", token);
let labels = start.range_label(end);
Diagnostic::error() Diagnostic::error()
.with_labels(
labels
.into_iter()
.map(|l| l.with_message(unexpected_str.clone()))
.collect(),
)
.with_message(expected_str) .with_message(expected_str)
.with_labels(vec![loc.primary_label().with_message(unexpected_str)])
} }
// simple lexer errors // simple lexer errors
@@ -293,24 +277,27 @@ fn order_of_operations() {
Program::from_str(muladd1).unwrap(), Program::from_str(muladd1).unwrap(),
Program { Program {
statements: vec![Statement::Binding( statements: vec![Statement::Binding(
Location::new(testfile, 0), Location::new(testfile, 0..1),
"x".to_string(), Name::manufactured("x"),
Expression::Primitive( Expression::Primitive(
Location::new(testfile, 6), Location::new(testfile, 6..7),
"+".to_string(), "+".to_string(),
vec![ vec![
Expression::Value(Location::new(testfile, 4), Value::Number(None, 1)), Expression::Value(
Location::new(testfile, 4..5),
Value::Number(None, None, 1),
),
Expression::Primitive( Expression::Primitive(
Location::new(testfile, 10), Location::new(testfile, 10..11),
"*".to_string(), "*".to_string(),
vec![ vec![
Expression::Value( Expression::Value(
Location::new(testfile, 8), Location::new(testfile, 8..9),
Value::Number(None, 2), Value::Number(None, None, 2),
), ),
Expression::Value( Expression::Value(
Location::new(testfile, 12), Location::new(testfile, 12..13),
Value::Number(None, 3), Value::Number(None, None, 3),
), ),
] ]
) )
@@ -350,8 +337,8 @@ proptest::proptest! {
} }
#[test] #[test]
fn generated_run_or_overflow(program: Program) { fn generated_run_or_overflow(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::{EvalError, PrimOpError}; use crate::eval::{EvalError, PrimOpError};
assert!(matches!(program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_))))) prop_assert!(matches!(program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))));
} }
} }

View File

@@ -1,136 +1,189 @@
use std::collections::HashSet; use crate::syntax::ast::{ConstantType, Expression, Name, Program, Statement, Value};
use crate::syntax::ast::{Expression, Program, Statement, Value};
use crate::syntax::location::Location; use crate::syntax::location::Location;
use proptest::sample::select; use proptest::sample::select;
use proptest::{ use proptest::{
prelude::{Arbitrary, BoxedStrategy, Strategy}, prelude::{Arbitrary, BoxedStrategy, Strategy},
strategy::{Just, Union}, strategy::{Just, Union},
}; };
use std::collections::HashMap;
use std::ops::Range;
const VALID_VARIABLE_NAMES: &str = r"[a-z][a-zA-Z0-9_]*"; const VALID_VARIABLE_NAMES: &str = r"[a-z][a-zA-Z0-9_]*";
#[derive(Debug)] impl ConstantType {
struct Name(String); fn get_operators(&self) -> &'static [(&'static str, usize)] {
match self {
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64 => {
&[("+", 2), ("-", 1), ("-", 2), ("*", 2), ("/", 2)]
}
ConstantType::U8 | ConstantType::U16 | ConstantType::U32 | ConstantType::U64 => {
&[("+", 2), ("-", 2), ("*", 2), ("/", 2)]
}
}
}
}
#[derive(Clone)]
pub struct GenerationEnvironment {
allow_inference: bool,
block_length: Range<usize>,
bindings: HashMap<Name, ConstantType>,
return_type: ConstantType,
}
impl Default for GenerationEnvironment {
fn default() -> Self {
GenerationEnvironment {
allow_inference: true,
block_length: 2..10,
bindings: HashMap::new(),
return_type: ConstantType::U64,
}
}
}
impl GenerationEnvironment {
pub fn new(allow_inference: bool) -> Self {
GenerationEnvironment {
allow_inference,
..Default::default()
}
}
}
impl Arbitrary for Program {
type Parameters = GenerationEnvironment;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy {
proptest::collection::vec(
ProgramStatementInfo::arbitrary(),
genenv.block_length.clone(),
)
.prop_flat_map(move |mut items| {
let mut statements = Vec::new();
let mut genenv = genenv.clone();
for psi in items.drain(..) {
if genenv.bindings.is_empty() || psi.should_be_binding {
genenv.return_type = psi.binding_type;
let expr = Expression::arbitrary_with(genenv.clone());
genenv.bindings.insert(psi.name.clone(), psi.binding_type);
statements.push(
expr.prop_map(move |expr| {
Statement::Binding(Location::manufactured(), psi.name.clone(), expr)
})
.boxed(),
);
} else {
let printers = genenv.bindings.keys().map(|n| {
Just(Statement::Print(
Location::manufactured(),
Name::manufactured(n),
))
});
statements.push(Union::new(printers).boxed());
}
}
statements
.prop_map(|statements| Program { statements })
.boxed()
})
.boxed()
}
}
impl Arbitrary for Name { impl Arbitrary for Name {
type Parameters = (); type Parameters = ();
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
VALID_VARIABLE_NAMES.prop_map(Name).boxed() VALID_VARIABLE_NAMES.prop_map(Name::manufactured).boxed()
} }
} }
impl Arbitrary for Program { #[derive(Debug)]
struct ProgramStatementInfo {
should_be_binding: bool,
name: Name,
binding_type: ConstantType,
}
impl Arbitrary for ProgramStatementInfo {
type Parameters = (); type Parameters = ();
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
let optionals = Vec::<Option<Name>>::arbitrary(); (
Union::new(vec![Just(true), Just(true), Just(false)]),
optionals Name::arbitrary(),
.prop_flat_map(|mut possible_names| { ConstantType::arbitrary(),
let mut statements = Vec::new(); )
let mut defined_variables: HashSet<String> = HashSet::new(); .prop_map(
|(should_be_binding, name, binding_type)| ProgramStatementInfo {
for possible_name in possible_names.drain(..) { should_be_binding,
match possible_name { name,
None if defined_variables.is_empty() => continue, binding_type,
None => statements.push( },
Union::new(defined_variables.iter().map(|name| { )
Just(Statement::Print(Location::manufactured(), name.to_string()))
}))
.boxed(),
),
Some(new_name) => {
let closures_name = new_name.0.clone();
let retval =
Expression::arbitrary_with(Some(defined_variables.clone()))
.prop_map(move |exp| {
Statement::Binding(
Location::manufactured(),
closures_name.clone(),
exp,
)
})
.boxed();
defined_variables.insert(new_name.0);
statements.push(retval);
}
}
}
statements
})
.prop_map(|statements| Program { statements })
.boxed() .boxed()
} }
} }
impl Arbitrary for Statement {
type Parameters = Option<HashSet<String>>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let duplicated_args = args.clone();
let defined_variables = args.unwrap_or_default();
let binding_strategy = (
VALID_VARIABLE_NAMES,
Expression::arbitrary_with(duplicated_args),
)
.prop_map(|(name, exp)| Statement::Binding(Location::manufactured(), name, exp))
.boxed();
if defined_variables.is_empty() {
binding_strategy
} else {
let print_strategy = Union::new(
defined_variables
.iter()
.map(|x| Just(Statement::Print(Location::manufactured(), x.to_string()))),
)
.boxed();
Union::new([binding_strategy, print_strategy]).boxed()
}
}
}
impl Arbitrary for Expression { impl Arbitrary for Expression {
type Parameters = Option<HashSet<String>>; type Parameters = GenerationEnvironment;
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy {
let defined_variables = args.unwrap_or_default(); // Value(Location, Value). These are the easiest variations to create, because we can always
// create one.
let value_strategy = Value::arbitrary() let value_strategy = Value::arbitrary_with(genenv.clone())
.prop_map(move |x| Expression::Value(Location::manufactured(), x)) .prop_map(|x| Expression::Value(Location::manufactured(), x))
.boxed(); .boxed();
let leaf_strategy = if defined_variables.is_empty() { // Reference(Location, String), These are slightly trickier, because we can end up in a situation
// where either no variables are defined, or where none of the defined variables have a type we
// can work with. So what we're going to do is combine this one with the previous one as a "leaf
// strategy" -- our non-recursive items -- if we can, or just set that to be the value strategy
// if we can't actually create an references.
let mut bound_variables_of_type = genenv
.bindings
.iter()
.filter(|(_, v)| genenv.return_type == **v)
.map(|(n, _)| n)
.collect::<Vec<_>>();
let leaf_strategy = if bound_variables_of_type.is_empty() {
value_strategy value_strategy
} else { } else {
let reference_strategy = Union::new(defined_variables.iter().map(|x| { let mut strats = bound_variables_of_type
Just(Expression::Reference( .drain(..)
Location::manufactured(), .map(|x| {
x.to_owned(), Just(Expression::Reference(
)) Location::manufactured(),
})) x.name.clone(),
.boxed(); ))
Union::new([value_strategy, reference_strategy]).boxed() .boxed()
})
.collect::<Vec<_>>();
strats.push(value_strategy);
Union::new(strats).boxed()
}; };
// now we generate our recursive types, given our leaf strategy
leaf_strategy leaf_strategy
.prop_recursive(3, 64, 2, move |inner| { .prop_recursive(3, 10, 2, move |strat| {
( (
select(super::BINARY_OPERATORS), select(genenv.return_type.get_operators()),
proptest::collection::vec(inner, 2), strat.clone(),
strat,
) )
.prop_map(move |(operator, exprs)| { .prop_map(|((oper, count), left, right)| {
Expression::Primitive(Location::manufactured(), operator.to_string(), exprs) let mut args = vec![left, right];
while args.len() > count {
args.pop();
}
Expression::Primitive(Location::manufactured(), oper.to_string(), args)
}) })
}) })
.boxed() .boxed()
@@ -138,22 +191,57 @@ impl Arbitrary for Expression {
} }
impl Arbitrary for Value { impl Arbitrary for Value {
type Parameters = (); type Parameters = GenerationEnvironment;
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy {
let base_strategy = Union::new([ let printed_base_strategy = Union::new([
Just(None::<u8>), Just(None::<u8>),
Just(Some(2)), Just(Some(2)),
Just(Some(8)), Just(Some(8)),
Just(Some(10)), Just(Some(10)),
Just(Some(16)), Just(Some(16)),
]); ]);
let value_strategy = u64::arbitrary();
let value_strategy = i64::arbitrary(); (printed_base_strategy, bool::arbitrary(), value_strategy)
.prop_map(move |(base, declare_type, value)| {
(base_strategy, value_strategy) let converted_value = match genenv.return_type {
.prop_map(move |(base, value)| Value::Number(base, value)) ConstantType::I8 => value % (i8::MAX as u64),
ConstantType::U8 => value % (u8::MAX as u64),
ConstantType::I16 => value % (i16::MAX as u64),
ConstantType::U16 => value % (u16::MAX as u64),
ConstantType::I32 => value % (i32::MAX as u64),
ConstantType::U32 => value % (u32::MAX as u64),
ConstantType::I64 => value % (i64::MAX as u64),
ConstantType::U64 => value,
};
let ty = if declare_type || !genenv.allow_inference {
Some(genenv.return_type)
} else {
None
};
Value::Number(base, ty, converted_value)
})
.boxed() .boxed()
} }
} }
impl Arbitrary for ConstantType {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
Union::new([
Just(ConstantType::I8),
Just(ConstantType::I16),
Just(ConstantType::I32),
Just(ConstantType::I64),
Just(ConstantType::U8),
Just(ConstantType::U16),
Just(ConstantType::U32),
Just(ConstantType::U64),
])
.boxed()
}
}

View File

@@ -1,7 +1,10 @@
use crate::syntax::Location; use std::fmt;
use std::hash::Hash;
/// The set of valid binary operators. use internment::ArcIntern;
pub static BINARY_OPERATORS: &[&str] = &["+", "-", "*", "/"];
pub use crate::syntax::tokens::ConstantType;
use crate::syntax::Location;
/// A structure represented a parsed program. /// A structure represented a parsed program.
/// ///
@@ -16,6 +19,56 @@ pub struct Program {
pub statements: Vec<Statement>, pub statements: Vec<Statement>,
} }
/// A Name.
///
/// This is basically a string, but annotated with the place the string
/// is in the source file.
#[derive(Clone, Debug)]
pub struct Name {
pub name: String,
pub location: Location,
}
impl Name {
pub fn new<S: ToString>(n: S, location: Location) -> Name {
Name {
name: n.to_string(),
location,
}
}
pub fn manufactured<S: ToString>(n: S) -> Name {
Name {
name: n.to_string(),
location: Location::manufactured(),
}
}
pub fn intern(self) -> ArcIntern<String> {
ArcIntern::new(self.name)
}
}
impl PartialEq for Name {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
}
}
impl Eq for Name {}
impl Hash for Name {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state)
}
}
impl fmt::Display for Name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.name.fmt(f)
}
}
/// A parsed statement. /// A parsed statement.
/// ///
/// Statements are guaranteed to be syntactically valid, but may be /// Statements are guaranteed to be syntactically valid, but may be
@@ -29,8 +82,8 @@ pub struct Program {
/// thing, not if they are the exact same statement. /// thing, not if they are the exact same statement.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum Statement { pub enum Statement {
Binding(Location, String, Expression), Binding(Location, Name, Expression),
Print(Location, String), Print(Location, Name),
} }
impl PartialEq for Statement { impl PartialEq for Statement {
@@ -58,6 +111,7 @@ impl PartialEq for Statement {
pub enum Expression { pub enum Expression {
Value(Location, Value), Value(Location, Value),
Reference(Location, String), Reference(Location, String),
Cast(Location, String, Box<Expression>),
Primitive(Location, String, Vec<Expression>), Primitive(Location, String, Vec<Expression>),
} }
@@ -72,6 +126,10 @@ impl PartialEq for Expression {
Expression::Reference(_, var2) => var1 == var2, Expression::Reference(_, var2) => var1 == var2,
_ => false, _ => false,
}, },
Expression::Cast(_, t1, e1) => match other {
Expression::Cast(_, t2, e2) => t1 == t2 && e1 == e2,
_ => false,
},
Expression::Primitive(_, prim1, args1) => match other { Expression::Primitive(_, prim1, args1) => match other {
Expression::Primitive(_, prim2, args2) => prim1 == prim2 && args1 == args2, Expression::Primitive(_, prim2, args2) => prim1 == prim2 && args1 == args2,
_ => false, _ => false,
@@ -83,6 +141,12 @@ impl PartialEq for Expression {
/// A value from the source syntax /// A value from the source syntax
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum Value { pub enum Value {
/// The value of the number, and an optional base that it was written in /// The value of the number, an optional base that it was written in, and any
Number(Option<u8>, i64), /// type information provided.
///
/// u64 is chosen because it should be big enough to carry the amount of
/// information we need, and technically we interpret -4 as the primitive unary
/// operation "-" on the number 4. We'll translate this into a type-specific
/// number at a later time.
Number(Option<u8>, Option<ConstantType>, u64),
} }

View File

@@ -1,7 +1,8 @@
use internment::ArcIntern; use internment::ArcIntern;
use crate::eval::{EvalEnvironment, EvalError, Value}; use crate::eval::{EvalEnvironment, EvalError, PrimitiveType, Value};
use crate::syntax::{Expression, Program, Statement}; use crate::syntax::{ConstantType, Expression, Program, Statement};
use std::str::FromStr;
impl Program { impl Program {
/// Evaluate the program, returning either an error or what it prints out when run. /// Evaluate the program, returning either an error or what it prints out when run.
@@ -24,11 +25,11 @@ impl Program {
match stmt { match stmt {
Statement::Binding(_, name, value) => { Statement::Binding(_, name, value) => {
let actual_value = value.eval(&env)?; let actual_value = value.eval(&env)?;
env = env.extend(ArcIntern::new(name.clone()), actual_value); env = env.extend(name.clone().intern(), actual_value);
} }
Statement::Print(_, name) => { Statement::Print(_, name) => {
let value = env.lookup(ArcIntern::new(name.clone()))?; let value = env.lookup(name.clone().intern())?;
let line = format!("{} = {}\n", name, value); let line = format!("{} = {}\n", name, value);
stdout.push_str(&line); stdout.push_str(&line);
} }
@@ -43,11 +44,28 @@ impl Expression {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> { fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self { match self {
Expression::Value(_, v) => match v { Expression::Value(_, v) => match v {
super::Value::Number(_, v) => Ok(Value::I64(*v)), super::Value::Number(_, ty, v) => match ty {
None => Ok(Value::U64(*v)),
// FIXME: make these types validate their input size
Some(ConstantType::I8) => Ok(Value::I8(*v as i8)),
Some(ConstantType::I16) => Ok(Value::I16(*v as i16)),
Some(ConstantType::I32) => Ok(Value::I32(*v as i32)),
Some(ConstantType::I64) => Ok(Value::I64(*v as i64)),
Some(ConstantType::U8) => Ok(Value::U8(*v as u8)),
Some(ConstantType::U16) => Ok(Value::U16(*v as u16)),
Some(ConstantType::U32) => Ok(Value::U32(*v as u32)),
Some(ConstantType::U64) => Ok(Value::U64(*v)),
},
}, },
Expression::Reference(_, n) => Ok(env.lookup(ArcIntern::new(n.clone()))?), Expression::Reference(_, n) => Ok(env.lookup(ArcIntern::new(n.clone()))?),
Expression::Cast(_, target, expr) => {
let target_type = PrimitiveType::from_str(target)?;
let value = expr.eval(env)?;
Ok(target_type.safe_cast(&value)?)
}
Expression::Primitive(_, op, args) => { Expression::Primitive(_, op, args) => {
let mut arg_values = Vec::with_capacity(args.len()); let mut arg_values = Vec::with_capacity(args.len());
@@ -66,12 +84,12 @@ impl Expression {
fn two_plus_three() { fn two_plus_three() {
let input = Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); let input = Program::parse(0, "x = 2 + 3; print x;").expect("parse works");
let output = input.eval().expect("runs successfully"); let output = input.eval().expect("runs successfully");
assert_eq!("x = 5i64\n", &output); assert_eq!("x = 5u64\n", &output);
} }
#[test] #[test]
fn lotsa_math() { fn lotsa_math() {
let input = Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); let input = Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works");
let output = input.eval().expect("runs successfully"); let output = input.eval().expect("runs successfully");
assert_eq!("x = 7i64\n", &output); assert_eq!("x = 7u64\n", &output);
} }

View File

@@ -1,13 +1,15 @@
use std::ops::Range;
use codespan_reporting::diagnostic::{Diagnostic, Label}; use codespan_reporting::diagnostic::{Diagnostic, Label};
/// A source location, for use in pointing users towards warnings and errors. /// A source location, for use in pointing users towards warnings and errors.
/// ///
/// Internally, locations are very tied to the `codespan_reporting` library, /// Internally, locations are very tied to the `codespan_reporting` library,
/// and the primary use of them is to serve as anchors within that library. /// and the primary use of them is to serve as anchors within that library.
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Location { pub struct Location {
file_idx: usize, file_idx: usize,
offset: usize, location: Range<usize>,
} }
impl Location { impl Location {
@@ -17,8 +19,8 @@ impl Location {
/// The file index is based on the file database being used. See the /// The file index is based on the file database being used. See the
/// `codespan_reporting::files::SimpleFiles::add` function, which is /// `codespan_reporting::files::SimpleFiles::add` function, which is
/// normally where we get this index. /// normally where we get this index.
pub fn new(file_idx: usize, offset: usize) -> Self { pub fn new(file_idx: usize, location: Range<usize>) -> Self {
Location { file_idx, offset } Location { file_idx, location }
} }
/// Generate a `Location` for a completely manufactured bit of code. /// Generate a `Location` for a completely manufactured bit of code.
@@ -30,7 +32,7 @@ impl Location {
pub fn manufactured() -> Self { pub fn manufactured() -> Self {
Location { Location {
file_idx: 0, file_idx: 0,
offset: 0, location: 0..0,
} }
} }
@@ -47,7 +49,7 @@ impl Location {
/// actually happened), but you'd probably want to make the first location /// actually happened), but you'd probably want to make the first location
/// the secondary label to help users find it. /// the secondary label to help users find it.
pub fn primary_label(&self) -> Label<usize> { pub fn primary_label(&self) -> Label<usize> {
Label::primary(self.file_idx, self.offset..self.offset) Label::primary(self.file_idx, self.location.clone())
} }
/// Generate a secondary label for a [`Diagnostic`], based on this source /// Generate a secondary label for a [`Diagnostic`], based on this source
@@ -64,35 +66,7 @@ impl Location {
/// probably want to make the first location the secondary label to help /// probably want to make the first location the secondary label to help
/// users find it. /// users find it.
pub fn secondary_label(&self) -> Label<usize> { pub fn secondary_label(&self) -> Label<usize> {
Label::secondary(self.file_idx, self.offset..self.offset) Label::secondary(self.file_idx, self.location.clone())
}
/// Given this location and another, generate a primary label that
/// specifies the area between those two locations.
///
/// See [`Self::primary_label`] for some discussion of primary versus
/// secondary labels. If the two locations are the same, this method does
/// the exact same thing as [`Self::primary_label`]. If this item was
/// generated by [`Self::manufactured`], it will act as if you'd called
/// `primary_label` on the argument. Otherwise, it will generate the obvious
/// span.
///
/// This function will return `None` only in the case that you provide
/// labels from two different files, which it cannot sensibly handle.
pub fn range_label(&self, end: &Location) -> Option<Label<usize>> {
if self.file_idx == 0 {
return Some(end.primary_label());
}
if self.file_idx != end.file_idx {
return None;
}
if self.offset > end.offset {
Some(Label::primary(self.file_idx, end.offset..self.offset))
} else {
Some(Label::primary(self.file_idx, self.offset..end.offset))
}
} }
/// Return an error diagnostic centered at this location. /// Return an error diagnostic centered at this location.
@@ -102,10 +76,7 @@ impl Location {
/// this particular location. You'll need to extend it with actually useful /// this particular location. You'll need to extend it with actually useful
/// information, like what kind of error it is. /// information, like what kind of error it is.
pub fn error(&self) -> Diagnostic<usize> { pub fn error(&self) -> Diagnostic<usize> {
Diagnostic::error().with_labels(vec![Label::primary( Diagnostic::error().with_labels(vec![Label::primary(self.file_idx, self.location.clone())])
self.file_idx,
self.offset..self.offset,
)])
} }
/// Return an error diagnostic centered at this location, with the given message. /// Return an error diagnostic centered at this location, with the given message.
@@ -115,10 +86,34 @@ impl Location {
/// even more information to ut, using [`Diagnostic::with_labels`], /// even more information to ut, using [`Diagnostic::with_labels`],
/// [`Diagnostic::with_notes`], or [`Diagnostic::with_code`]. /// [`Diagnostic::with_notes`], or [`Diagnostic::with_code`].
pub fn labelled_error(&self, msg: &str) -> Diagnostic<usize> { pub fn labelled_error(&self, msg: &str) -> Diagnostic<usize> {
Diagnostic::error().with_labels(vec![Label::primary( Diagnostic::error().with_labels(vec![
self.file_idx, Label::primary(self.file_idx, self.location.clone()).with_message(msg)
self.offset..self.offset, ])
) }
.with_message(msg)])
/// Merge two locations into a single location spanning the whole range between
/// them.
///
/// This function returns None if the locations are from different files; this
/// can happen if one of the locations is manufactured, for example.
pub fn merge(&self, other: &Self) -> Option<Self> {
if self.file_idx != other.file_idx {
None
} else {
let start = if self.location.start <= other.location.start {
self.location.start
} else {
other.location.start
};
let end = if self.location.end >= other.location.end {
self.location.end
} else {
other.location.end
};
Some(Location {
file_idx: self.file_idx,
location: start..end,
})
}
} }
} }

View File

@@ -9,8 +9,8 @@
//! eventually want to leave lalrpop behind.) //! eventually want to leave lalrpop behind.)
//! //!
use crate::syntax::{LexerError, Location}; use crate::syntax::{LexerError, Location};
use crate::syntax::ast::{Program,Statement,Expression,Value}; use crate::syntax::ast::{Program,Statement,Expression,Value,Name};
use crate::syntax::tokens::Token; use crate::syntax::tokens::{ConstantType, Token};
use internment::ArcIntern; use internment::ArcIntern;
// one cool thing about lalrpop: we can pass arguments. in this case, the // one cool thing about lalrpop: we can pass arguments. in this case, the
@@ -32,6 +32,8 @@ extern {
";" => Token::Semi, ";" => Token::Semi,
"(" => Token::LeftParen, "(" => Token::LeftParen,
")" => Token::RightParen, ")" => Token::RightParen,
"<" => Token::LessThan,
">" => Token::GreaterThan,
"print" => Token::Print, "print" => Token::Print,
@@ -44,7 +46,7 @@ extern {
// to name and use "their value", you get their source location. // to name and use "their value", you get their source location.
// For these, we want "their value" to be their actual contents, // For these, we want "their value" to be their actual contents,
// which is why we put their types in angle brackets. // which is why we put their types in angle brackets.
"<num>" => Token::Number((<Option<u8>>,<i64>)), "<num>" => Token::Number((<Option<u8>>,<Option<ConstantType>>,<u64>)),
"<var>" => Token::Variable(<ArcIntern<String>>), "<var>" => Token::Variable(<ArcIntern<String>>),
} }
} }
@@ -89,10 +91,19 @@ pub Statement: Statement = {
// A statement can be a variable binding. Note, here, that we use this // A statement can be a variable binding. Note, here, that we use this
// funny @L thing to get the source location before the variable, so that // funny @L thing to get the source location before the variable, so that
// we can say that this statement spans across everything. // we can say that this statement spans across everything.
<l:@L> <v:"<var>"> "=" <e:Expression> ";" => Statement::Binding(Location::new(file_idx, l), v.to_string(), e), <ls: @L> <v:"<var>"> <var_end: @L> "=" <e:Expression> ";" <le: @L> =>
Statement::Binding(
Location::new(file_idx, ls..le),
Name::new(v, Location::new(file_idx, ls..var_end)),
e,
),
// Alternatively, a statement can just be a print statement. // Alternatively, a statement can just be a print statement.
"print" <l:@L> <v:"<var>"> ";" => Statement::Print(Location::new(file_idx, l), v.to_string()), <ls: @L> "print" <name_start: @L> <v:"<var>"> <name_end: @L> ";" <le: @L> =>
Statement::Print(
Location::new(file_idx, ls..le),
Name::new(v, Location::new(file_idx, name_start..name_end)),
),
} }
// Expressions! Expressions are a little fiddly, because we're going to // Expressions! Expressions are a little fiddly, because we're going to
@@ -124,15 +135,27 @@ Expression: Expression = {
// we group addition and subtraction under the heading "additive" // we group addition and subtraction under the heading "additive"
AdditiveExpression: Expression = { AdditiveExpression: Expression = {
<e1:AdditiveExpression> <l:@L> "+" <e2:MultiplicativeExpression> => Expression::Primitive(Location::new(file_idx, l), "+".to_string(), vec![e1, e2]), <ls: @L> <e1:AdditiveExpression> <l: @L> "+" <e2:MultiplicativeExpression> <le: @L> =>
<e1:AdditiveExpression> <l:@L> "-" <e2:MultiplicativeExpression> => Expression::Primitive(Location::new(file_idx, l), "-".to_string(), vec![e1, e2]), Expression::Primitive(Location::new(file_idx, ls..le), "+".to_string(), vec![e1, e2]),
<ls: @L> <e1:AdditiveExpression> <l: @L> "-" <e2:MultiplicativeExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "-".to_string(), vec![e1, e2]),
MultiplicativeExpression, MultiplicativeExpression,
} }
// similarly, we group multiplication and division under "multiplicative" // similarly, we group multiplication and division under "multiplicative"
MultiplicativeExpression: Expression = { MultiplicativeExpression: Expression = {
<e1:MultiplicativeExpression> <l:@L> "*" <e2:AtomicExpression> => Expression::Primitive(Location::new(file_idx, l), "*".to_string(), vec![e1, e2]), <ls: @L> <e1:MultiplicativeExpression> <l: @L> "*" <e2:UnaryExpression> <le: @L> =>
<e1:MultiplicativeExpression> <l:@L> "/" <e2:AtomicExpression> => Expression::Primitive(Location::new(file_idx, l), "/".to_string(), vec![e1, e2]), Expression::Primitive(Location::new(file_idx, ls..le), "*".to_string(), vec![e1, e2]),
<ls: @L> <e1:MultiplicativeExpression> <l: @L> "/" <e2:UnaryExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "/".to_string(), vec![e1, e2]),
UnaryExpression,
}
UnaryExpression: Expression = {
<l: @L> "-" <e:UnaryExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, l..le), "-".to_string(), vec![e]),
<l: @L> "<" <v:"<var>"> ">" <e:UnaryExpression> <le: @L> =>
Expression::Cast(Location::new(file_idx, l..le), v.to_string(), Box::new(e)),
AtomicExpression, AtomicExpression,
} }
@@ -140,22 +163,9 @@ MultiplicativeExpression: Expression = {
// they cannot be further divided into parts // they cannot be further divided into parts
AtomicExpression: Expression = { AtomicExpression: Expression = {
// just a variable reference // just a variable reference
<l:@L> <v:"<var>"> => Expression::Reference(Location::new(file_idx, l), v.to_string()), <l: @L> <v:"<var>"> <end: @L> => Expression::Reference(Location::new(file_idx, l..end), v.to_string()),
// just a number // just a number
<l:@L> <n:"<num>"> => { <l: @L> <n:"<num>"> <end: @L> => Expression::Value(Location::new(file_idx, l..end), Value::Number(n.0, n.1, n.2)),
let val = Value::Number(n.0, n.1);
Expression::Value(Location::new(file_idx, l), val)
},
// a tricky case: also just a number, but using a negative sign. an
// alternative way to do this -- and we may do this eventually -- is
// to implement a unary negation expression. this has the odd effect
// that the user never actually writes down a negative number; they just
// write positive numbers which are immediately sent to a negation
// primitive!
<l:@L> "-" <n:"<num>"> => {
let val = Value::Number(n.0, -n.1);
Expression::Value(Location::new(file_idx, l), val)
},
// finally, let people parenthesize expressions and get back to a // finally, let people parenthesize expressions and get back to a
// lower precedence // lower precedence
"(" <e:Expression> ")" => e, "(" <e:Expression> ")" => e,

View File

@@ -1,6 +1,8 @@
use crate::syntax::ast::{Expression, Program, Statement, Value, BINARY_OPERATORS}; use crate::syntax::ast::{Expression, Program, Statement, Value};
use pretty::{DocAllocator, DocBuilder, Pretty}; use pretty::{DocAllocator, DocBuilder, Pretty};
use super::ConstantType;
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where where
A: 'a, A: 'a,
@@ -50,14 +52,14 @@ where
match self { match self {
Expression::Value(_, val) => val.pretty(allocator), Expression::Value(_, val) => val.pretty(allocator),
Expression::Reference(_, var) => allocator.text(var.to_string()), Expression::Reference(_, var) => allocator.text(var.to_string()),
Expression::Primitive(_, op, exprs) if BINARY_OPERATORS.contains(&op.as_ref()) => { Expression::Cast(_, t, e) => allocator
assert_eq!( .text(t.clone())
exprs.len(), .angles()
2, .append(e.pretty(allocator)),
"Found binary operator with {} components?", Expression::Primitive(_, op, exprs) if exprs.len() == 1 => allocator
exprs.len() .text(op.to_string())
); .append(exprs[0].pretty(allocator)),
Expression::Primitive(_, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator); let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator); let right = exprs[1].pretty(allocator);
@@ -84,15 +86,14 @@ where
{ {
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
match self { match self {
Value::Number(opt_base, value) => { Value::Number(opt_base, ty, value) => {
let sign = if *value < 0 { "-" } else { "" };
let value_str = match opt_base { let value_str = match opt_base {
None => format!("{}", value), None => format!("{}{}", value, type_suffix(ty)),
Some(2) => format!("{}0b{:b}", sign, value.abs()), Some(2) => format!("0b{:b}{}", value, type_suffix(ty)),
Some(8) => format!("{}0o{:o}", sign, value.abs()), Some(8) => format!("0o{:o}{}", value, type_suffix(ty)),
Some(10) => format!("{}0d{}", sign, value.abs()), Some(10) => format!("0d{}{}", value, type_suffix(ty)),
Some(16) => format!("{}0x{:x}", sign, value.abs()), Some(16) => format!("0x{:x}{}", value, type_suffix(ty)),
Some(_) => format!("!!{}{:x}!!", sign, value.abs()), Some(_) => format!("!!{:x}{}!!", value, type_suffix(ty)),
}; };
allocator.text(value_str) allocator.text(value_str)
@@ -101,6 +102,20 @@ where
} }
} }
fn type_suffix(x: &Option<ConstantType>) -> &'static str {
match x {
None => "",
Some(ConstantType::I8) => "i8",
Some(ConstantType::I16) => "i16",
Some(ConstantType::I32) => "i32",
Some(ConstantType::I64) => "i64",
Some(ConstantType::U8) => "u8",
Some(ConstantType::U16) => "u16",
Some(ConstantType::U32) => "u32",
Some(ConstantType::U64) => "u64",
}
}
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
struct CommaSep {} struct CommaSep {}

View File

@@ -40,6 +40,12 @@ pub enum Token {
#[token(")")] #[token(")")]
RightParen, RightParen,
#[token("<")]
LessThan,
#[token(">")]
GreaterThan,
// Next we take of any reserved words; I always like to put // Next we take of any reserved words; I always like to put
// these before we start recognizing more complicated regular // these before we start recognizing more complicated regular
// expressions. I don't think it matters, but it works for me. // expressions. I don't think it matters, but it works for me.
@@ -53,13 +59,14 @@ pub enum Token {
/// Numbers capture both the value we read from the input, /// Numbers capture both the value we read from the input,
/// converted to an `i64`, as well as the base the user used /// converted to an `i64`, as well as the base the user used
/// to write the number, if they did so. /// to write the number and/or the type the user specified,
#[regex(r"0b[01]+", |v| parse_number(Some(2), v))] /// if they did either.
#[regex(r"0o[0-7]+", |v| parse_number(Some(8), v))] #[regex(r"0b[01]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(2), v))]
#[regex(r"0d[0-9]+", |v| parse_number(Some(10), v))] #[regex(r"0o[0-7]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(8), v))]
#[regex(r"0x[0-9a-fA-F]+", |v| parse_number(Some(16), v))] #[regex(r"0d[0-9]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(10), v))]
#[regex(r"[0-9]+", |v| parse_number(None, v))] #[regex(r"0x[0-9a-fA-F]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(16), v))]
Number((Option<u8>, i64)), #[regex(r"[0-9]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(None, v))]
Number((Option<u8>, Option<ConstantType>, u64)),
// Variables; this is a very standard, simple set of characters // Variables; this is a very standard, simple set of characters
// for variables, but feel free to experiment with more complicated // for variables, but feel free to experiment with more complicated
@@ -88,15 +95,29 @@ impl fmt::Display for Token {
Token::Semi => write!(f, "';'"), Token::Semi => write!(f, "';'"),
Token::LeftParen => write!(f, "'('"), Token::LeftParen => write!(f, "'('"),
Token::RightParen => write!(f, "')'"), Token::RightParen => write!(f, "')'"),
Token::LessThan => write!(f, "<"),
Token::GreaterThan => write!(f, ">"),
Token::Print => write!(f, "'print'"), Token::Print => write!(f, "'print'"),
Token::Operator(c) => write!(f, "'{}'", c), Token::Operator(c) => write!(f, "'{}'", c),
Token::Number((None, v)) => write!(f, "'{}'", v), Token::Number((None, otype, v)) => write!(f, "'{}{}'", v, display_optional_type(otype)),
Token::Number((Some(2), v)) => write!(f, "'0b{:b}'", v), Token::Number((Some(2), otype, v)) => {
Token::Number((Some(8), v)) => write!(f, "'0o{:o}'", v), write!(f, "'0b{:b}{}'", v, display_optional_type(otype))
Token::Number((Some(10), v)) => write!(f, "'{}'", v), }
Token::Number((Some(16), v)) => write!(f, "'0x{:x}'", v), Token::Number((Some(8), otype, v)) => {
Token::Number((Some(b), v)) => { write!(f, "'0o{:o}{}'", v, display_optional_type(otype))
write!(f, "Invalidly-based-number<base={},val={}>", b, v) }
Token::Number((Some(10), otype, v)) => {
write!(f, "'{}{}'", v, display_optional_type(otype))
}
Token::Number((Some(16), otype, v)) => {
write!(f, "'0x{:x}{}'", v, display_optional_type(otype))
}
Token::Number((Some(b), opt_type, v)) => {
write!(
f,
"Invalidly-based-number<base={},val={},opt_type={:?}>",
b, v, opt_type
)
} }
Token::Variable(s) => write!(f, "'{}'", s), Token::Variable(s) => write!(f, "'{}'", s),
Token::Error => write!(f, "<error>"), Token::Error => write!(f, "<error>"),
@@ -122,6 +143,125 @@ impl Token {
} }
} }
#[repr(i64)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConstantType {
U8 = 10,
U16 = 11,
U32 = 12,
U64 = 13,
I8 = 20,
I16 = 21,
I32 = 22,
I64 = 23,
}
impl From<ConstantType> for cranelift_codegen::ir::Type {
fn from(value: ConstantType) -> Self {
match value {
ConstantType::I8 | ConstantType::U8 => cranelift_codegen::ir::types::I8,
ConstantType::I16 | ConstantType::U16 => cranelift_codegen::ir::types::I16,
ConstantType::I32 | ConstantType::U32 => cranelift_codegen::ir::types::I32,
ConstantType::I64 | ConstantType::U64 => cranelift_codegen::ir::types::I64,
}
}
}
impl ConstantType {
/// Returns true if the given type is (a) numeric and (b) signed;
pub fn is_signed(&self) -> bool {
matches!(
self,
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64
)
}
/// Return the set of types that can be safely casted into this type.
pub fn safe_casts_to(self) -> Vec<ConstantType> {
match self {
ConstantType::I8 => vec![ConstantType::I8],
ConstantType::I16 => vec![ConstantType::I16, ConstantType::I8, ConstantType::U8],
ConstantType::I32 => vec![
ConstantType::I32,
ConstantType::I16,
ConstantType::I8,
ConstantType::U16,
ConstantType::U8,
],
ConstantType::I64 => vec![
ConstantType::I64,
ConstantType::I32,
ConstantType::I16,
ConstantType::I8,
ConstantType::U32,
ConstantType::U16,
ConstantType::U8,
],
ConstantType::U8 => vec![ConstantType::U8],
ConstantType::U16 => vec![ConstantType::U16, ConstantType::U8],
ConstantType::U32 => vec![ConstantType::U32, ConstantType::U16, ConstantType::U8],
ConstantType::U64 => vec![
ConstantType::U64,
ConstantType::U32,
ConstantType::U16,
ConstantType::U8,
],
}
}
/// Return the set of all currently-available constant types
pub fn all_types() -> Vec<Self> {
vec![
ConstantType::U8,
ConstantType::U16,
ConstantType::U32,
ConstantType::U64,
ConstantType::I8,
ConstantType::I16,
ConstantType::I32,
ConstantType::I64,
]
}
/// Return the name of the given type, as a string
pub fn name(&self) -> String {
match self {
ConstantType::I8 => "i8".to_string(),
ConstantType::I16 => "i16".to_string(),
ConstantType::I32 => "i32".to_string(),
ConstantType::I64 => "i64".to_string(),
ConstantType::U8 => "u8".to_string(),
ConstantType::U16 => "u16".to_string(),
ConstantType::U32 => "u32".to_string(),
ConstantType::U64 => "u64".to_string(),
}
}
}
#[derive(Debug, Error, PartialEq)]
pub enum InvalidConstantType {
#[error("Unrecognized constant {0} for constant type")]
Value(i64),
}
impl TryFrom<i64> for ConstantType {
type Error = InvalidConstantType;
fn try_from(value: i64) -> Result<Self, Self::Error> {
match value {
10 => Ok(ConstantType::U8),
11 => Ok(ConstantType::U16),
12 => Ok(ConstantType::U32),
13 => Ok(ConstantType::U64),
20 => Ok(ConstantType::I8),
21 => Ok(ConstantType::I16),
22 => Ok(ConstantType::I32),
23 => Ok(ConstantType::I64),
_ => Err(InvalidConstantType::Value(value)),
}
}
}
/// Parse a number in the given base, return a pair of the base and the /// Parse a number in the given base, return a pair of the base and the
/// parsed number. This is just a helper used for all of the number /// parsed number. This is just a helper used for all of the number
/// regular expression cases, which kicks off to the obvious Rust /// regular expression cases, which kicks off to the obvious Rust
@@ -129,24 +269,66 @@ impl Token {
fn parse_number( fn parse_number(
base: Option<u8>, base: Option<u8>,
value: &Lexer<Token>, value: &Lexer<Token>,
) -> Result<(Option<u8>, i64), ParseIntError> { ) -> Result<(Option<u8>, Option<ConstantType>, u64), ParseIntError> {
let (radix, strval) = match base { let (radix, strval) = match base {
None => (10, value.slice()), None => (10, value.slice()),
Some(radix) => (radix, &value.slice()[2..]), Some(radix) => (radix, &value.slice()[2..]),
}; };
let intval = i64::from_str_radix(strval, radix as u32)?; let (declared_type, strval) = if let Some(strval) = strval.strip_suffix("u8") {
Ok((base, intval)) (Some(ConstantType::U8), strval)
} else if let Some(strval) = strval.strip_suffix("u16") {
(Some(ConstantType::U16), strval)
} else if let Some(strval) = strval.strip_suffix("u32") {
(Some(ConstantType::U32), strval)
} else if let Some(strval) = strval.strip_suffix("u64") {
(Some(ConstantType::U64), strval)
} else if let Some(strval) = strval.strip_suffix("i8") {
(Some(ConstantType::I8), strval)
} else if let Some(strval) = strval.strip_suffix("i16") {
(Some(ConstantType::I16), strval)
} else if let Some(strval) = strval.strip_suffix("i32") {
(Some(ConstantType::I32), strval)
} else if let Some(strval) = strval.strip_suffix("i64") {
(Some(ConstantType::I64), strval)
} else {
(None, strval)
};
let intval = u64::from_str_radix(strval, radix as u32)?;
Ok((base, declared_type, intval))
}
fn display_optional_type(otype: &Option<ConstantType>) -> &'static str {
match otype {
None => "",
Some(ConstantType::I8) => "i8",
Some(ConstantType::I16) => "i16",
Some(ConstantType::I32) => "i32",
Some(ConstantType::I64) => "i64",
Some(ConstantType::U8) => "u8",
Some(ConstantType::U16) => "u16",
Some(ConstantType::U32) => "u32",
Some(ConstantType::U64) => "u64",
}
} }
#[test] #[test]
fn lex_numbers() { fn lex_numbers() {
let mut lex0 = Token::lexer("12 0b1100 0o14 0d12 0xc // 9"); let mut lex0 = Token::lexer("12 0b1100 0o14 0d12 0xc 12u8 0xci64// 9");
assert_eq!(lex0.next(), Some(Token::Number((None, 12)))); assert_eq!(lex0.next(), Some(Token::Number((None, None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(2), 12)))); assert_eq!(lex0.next(), Some(Token::Number((Some(2), None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(8), 12)))); assert_eq!(lex0.next(), Some(Token::Number((Some(8), None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(10), 12)))); assert_eq!(lex0.next(), Some(Token::Number((Some(10), None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(16), 12)))); assert_eq!(lex0.next(), Some(Token::Number((Some(16), None, 12))));
assert_eq!(
lex0.next(),
Some(Token::Number((None, Some(ConstantType::U8), 12)))
);
assert_eq!(
lex0.next(),
Some(Token::Number((Some(16), Some(ConstantType::I64), 12)))
);
assert_eq!(lex0.next(), None); assert_eq!(lex0.next(), None);
} }
@@ -168,6 +350,31 @@ fn lexer_spans() {
assert_eq!(lex0.next(), Some((Token::Equals, 2..3))); assert_eq!(lex0.next(), Some((Token::Equals, 2..3)));
assert_eq!(lex0.next(), Some((Token::var("x"), 4..5))); assert_eq!(lex0.next(), Some((Token::var("x"), 4..5)));
assert_eq!(lex0.next(), Some((Token::Operator('+'), 6..7))); assert_eq!(lex0.next(), Some((Token::Operator('+'), 6..7)));
assert_eq!(lex0.next(), Some((Token::Number((None, 1)), 8..9))); assert_eq!(lex0.next(), Some((Token::Number((None, None, 1)), 8..9)));
assert_eq!(lex0.next(), None); assert_eq!(lex0.next(), None);
} }
#[test]
fn further_spans() {
let mut lex0 = Token::lexer("x = 2i64 + 2i64;\ny = -x;\nprint y;").spanned();
assert_eq!(lex0.next(), Some((Token::var("x"), 0..1)));
assert_eq!(lex0.next(), Some((Token::Equals, 2..3)));
assert_eq!(
lex0.next(),
Some((Token::Number((None, Some(ConstantType::I64), 2)), 4..8))
);
assert_eq!(lex0.next(), Some((Token::Operator('+'), 9..10)));
assert_eq!(
lex0.next(),
Some((Token::Number((None, Some(ConstantType::I64), 2)), 11..15))
);
assert_eq!(lex0.next(), Some((Token::Semi, 15..16)));
assert_eq!(lex0.next(), Some((Token::var("y"), 17..18)));
assert_eq!(lex0.next(), Some((Token::Equals, 19..20)));
assert_eq!(lex0.next(), Some((Token::Operator('-'), 21..22)));
assert_eq!(lex0.next(), Some((Token::var("x"), 22..23)));
assert_eq!(lex0.next(), Some((Token::Semi, 23..24)));
assert_eq!(lex0.next(), Some((Token::Print, 25..30)));
assert_eq!(lex0.next(), Some((Token::var("y"), 31..32)));
assert_eq!(lex0.next(), Some((Token::Semi, 32..33)));
}

View File

@@ -1,6 +1,9 @@
use crate::syntax::{Expression, Location, Program, Statement}; use crate::{
eval::PrimitiveType,
syntax::{Expression, Location, Program, Statement},
};
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use std::collections::HashMap; use std::{collections::HashMap, str::FromStr};
/// An error we found while validating the input program. /// An error we found while validating the input program.
/// ///
@@ -11,6 +14,7 @@ use std::collections::HashMap;
/// and using [`codespan_reporting`] to present them to the user. /// and using [`codespan_reporting`] to present them to the user.
pub enum Error { pub enum Error {
UnboundVariable(Location, String), UnboundVariable(Location, String),
UnknownType(Location, String),
} }
impl From<Error> for Diagnostic<usize> { impl From<Error> for Diagnostic<usize> {
@@ -19,6 +23,10 @@ impl From<Error> for Diagnostic<usize> {
Error::UnboundVariable(location, name) => location Error::UnboundVariable(location, name) => location
.labelled_error("unbound here") .labelled_error("unbound here")
.with_message(format!("Unbound variable '{}'", name)), .with_message(format!("Unbound variable '{}'", name)),
Error::UnknownType(location, name) => location
.labelled_error("type referenced here")
.with_message(format!("Unknown type '{}'", name)),
} }
} }
} }
@@ -57,12 +65,24 @@ impl Program {
/// example, and generates warnings for things that are inadvisable but not /// example, and generates warnings for things that are inadvisable but not
/// actually a problem. /// actually a problem.
pub fn validate(&self) -> (Vec<Error>, Vec<Warning>) { pub fn validate(&self) -> (Vec<Error>, Vec<Warning>) {
let mut bound_variables = HashMap::new();
self.validate_with_bindings(&mut bound_variables)
}
/// Validate that the program makes semantic sense, not just syntactic sense.
///
/// This checks for things like references to variables that don't exist, for
/// example, and generates warnings for things that are inadvisable but not
/// actually a problem.
pub fn validate_with_bindings(
&self,
bound_variables: &mut HashMap<String, Location>,
) -> (Vec<Error>, Vec<Warning>) {
let mut errors = vec![]; let mut errors = vec![];
let mut warnings = vec![]; let mut warnings = vec![];
let mut bound_variables = HashMap::new();
for stmt in self.statements.iter() { for stmt in self.statements.iter() {
let (mut new_errors, mut new_warnings) = stmt.validate(&mut bound_variables); let (mut new_errors, mut new_warnings) = stmt.validate(bound_variables);
errors.append(&mut new_errors); errors.append(&mut new_errors);
warnings.append(&mut new_warnings); warnings.append(&mut new_warnings);
} }
@@ -81,7 +101,7 @@ impl Statement {
/// occurs. We use a `HashMap` to map these bound locations to the locations /// occurs. We use a `HashMap` to map these bound locations to the locations
/// where their bound, because these locations are handy when generating errors /// where their bound, because these locations are handy when generating errors
/// and warnings. /// and warnings.
pub fn validate( fn validate(
&self, &self,
bound_variables: &mut HashMap<String, Location>, bound_variables: &mut HashMap<String, Location>,
) -> (Vec<Error>, Vec<Warning>) { ) -> (Vec<Error>, Vec<Warning>) {
@@ -97,20 +117,20 @@ impl Statement {
errors.append(&mut exp_errors); errors.append(&mut exp_errors);
warnings.append(&mut exp_warnings); warnings.append(&mut exp_warnings);
if let Some(original_binding_site) = bound_variables.get(var) { if let Some(original_binding_site) = bound_variables.get(&var.name) {
warnings.push(Warning::ShadowedVariable( warnings.push(Warning::ShadowedVariable(
original_binding_site.clone(), original_binding_site.clone(),
loc.clone(), loc.clone(),
var.clone(), var.to_string(),
)); ));
} else { } else {
bound_variables.insert(var.clone(), loc.clone()); bound_variables.insert(var.to_string(), loc.clone());
} }
} }
Statement::Print(_, var) if bound_variables.contains_key(var) => {} Statement::Print(_, var) if bound_variables.contains_key(&var.name) => {}
Statement::Print(loc, var) => { Statement::Print(loc, var) => {
errors.push(Error::UnboundVariable(loc.clone(), var.clone())) errors.push(Error::UnboundVariable(loc.clone(), var.to_string()))
} }
} }
@@ -127,6 +147,15 @@ impl Expression {
vec![Error::UnboundVariable(loc.clone(), var.clone())], vec![Error::UnboundVariable(loc.clone(), var.clone())],
vec![], vec![],
), ),
Expression::Cast(location, t, expr) => {
let (mut errs, warns) = expr.validate(variable_map);
if PrimitiveType::from_str(t).is_err() {
errs.push(Error::UnknownType(location.clone(), t.clone()))
}
(errs, warns)
}
Expression::Primitive(_, _, args) => { Expression::Primitive(_, _, args) => {
let mut errors = vec![]; let mut errors = vec![];
let mut warnings = vec![]; let mut warnings = vec![];
@@ -142,3 +171,19 @@ impl Expression {
} }
} }
} }
#[test]
fn cast_checks_are_reasonable() {
let good_stmt = Statement::parse(0, "x = <u16>4u8;").expect("valid test case");
let (good_errs, good_warns) = good_stmt.validate(&mut HashMap::new());
assert!(good_errs.is_empty());
assert!(good_warns.is_empty());
let bad_stmt = Statement::parse(0, "x = <apple>4u8;").expect("valid test case");
let (bad_errs, bad_warns) = bad_stmt.validate(&mut HashMap::new());
assert!(bad_warns.is_empty());
assert_eq!(bad_errs.len(), 1);
assert!(matches!(bad_errs[0], Error::UnknownType(_, ref x) if x == "apple"));
}

52
src/type_infer.rs Normal file
View File

@@ -0,0 +1,52 @@
//! A type inference pass for NGR
//!
//! The type checker implemented here is a relatively straightforward one, designed to be
//! fairly easy to understand rather than super fast. So don't be expecting the fastest
//! type checker in the West, here.
//!
//! The actual type checker operates in three phases. In the first phase, we translate
//! the syntax AST into something that's close to the final IR. During the process, we
//! generate a list of type constraints to solve. In the second phase, we try to solve
//! all the constraints we've generated. If that's successful, in the final phase, we
//! do the final conversion to the IR AST, filling in any type information we've learned
//! along the way.
mod ast;
mod convert;
mod finalize;
mod solve;
use self::convert::convert_program;
use self::finalize::finalize_program;
use self::solve::solve_constraints;
pub use self::solve::{TypeInferenceError, TypeInferenceResult, TypeInferenceWarning};
use crate::ir::ast as ir;
use crate::syntax;
#[cfg(test)]
use crate::syntax::arbitrary::GenerationEnvironment;
#[cfg(test)]
use proptest::prelude::Arbitrary;
impl syntax::Program {
/// Infer the types for the syntactic AST, returning either a type-checked program in
/// the IR, or a series of type errors encountered during inference.
///
/// You really should have made sure that this program was validated before running
/// this method, otherwise you may experience panics during operation.
pub fn type_infer(self) -> TypeInferenceResult<ir::Program> {
let mut constraint_db = vec![];
let program = convert_program(self, &mut constraint_db);
let inference_result = solve_constraints(constraint_db);
inference_result.map(|resolutions| finalize_program(program, &resolutions))
}
}
proptest::proptest! {
#[test]
fn translation_maintains_semantics(input in syntax::Program::arbitrary_with(GenerationEnvironment::new(false))) {
let syntax_result = input.eval();
let ir = input.type_infer().expect("arbitrary should generate type-safe programs");
let ir_result = ir.eval();
proptest::prop_assert_eq!(syntax_result, ir_result);
}
}

336
src/type_infer/ast.rs Normal file
View File

@@ -0,0 +1,336 @@
pub use crate::ir::ast::Primitive;
/// This is largely a copy of `ir/ast`, with a couple of extensions that we're going
/// to want to use while we're doing type inference, but don't want to keep around
/// afterwards. These are:
///
/// * A notion of a type variable
/// * An unknown numeric constant form
///
use crate::{
eval::PrimitiveType,
syntax::{self, ConstantType, Location},
};
use internment::ArcIntern;
use pretty::{DocAllocator, Pretty};
use std::fmt;
use std::sync::atomic::AtomicUsize;
/// We're going to represent variables as interned strings.
///
/// These should be fast enough for comparison that it's OK, since it's going to end up
/// being pretty much the pointer to the string.
type Variable = ArcIntern<String>;
/// The representation of a program within our IR. For now, this is exactly one file.
///
/// In addition, for the moment there's not really much of interest to hold here besides
/// the list of statements read from the file. Order is important. In the future, you
/// could imagine caching analysis information in this structure.
///
/// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used
/// to print the structure whenever possible, especially if you value your or your
/// user's time. The latter is useful for testing that conversions of `Program` retain
/// their meaning. All `Program`s generated through [`Arbitrary`] are guaranteed to be
/// syntactically valid, although they may contain runtime issue like over- or underflow.
#[derive(Debug)]
pub struct Program {
// For now, a program is just a vector of statements. In the future, we'll probably
// extend this to include a bunch of other information, but for now: just a list.
pub(crate) statements: Vec<Statement>,
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
let mut result = allocator.nil();
for stmt in self.statements.iter() {
// there's probably a better way to do this, rather than constantly
// adding to the end, but this works.
result = result
.append(stmt.pretty(allocator))
.append(allocator.text(";"))
.append(allocator.hardline());
}
result
}
}
/// The representation of a statement in the language.
///
/// For now, this is either a binding site (`x = 4`) or a print statement
/// (`print x`). Someday, though, more!
///
/// As with `Program`, this type implements [`Pretty`], which should
/// be used to display the structure whenever possible. It does not
/// implement [`Arbitrary`], though, mostly because it's slightly
/// complicated to do so.
///
#[derive(Debug)]
pub enum Statement {
Binding(Location, Variable, Type, Expression),
Print(Location, Type, Variable),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Statement::Binding(_, var, _, expr) => allocator
.text(var.as_ref().to_string())
.append(allocator.space())
.append(allocator.text("="))
.append(allocator.space())
.append(expr.pretty(allocator)),
Statement::Print(_, _, var) => allocator
.text("print")
.append(allocator.space())
.append(allocator.text(var.as_ref().to_string())),
}
}
}
/// The representation of an expression.
///
/// Note that expressions, like everything else in this syntax tree,
/// supports [`Pretty`], and it's strongly encouraged that you use
/// that trait/module when printing these structures.
///
/// Also, Expressions at this point in the compiler are explicitly
/// defined so that they are *not* recursive. By this point, if an
/// expression requires some other data (like, for example, invoking
/// a primitive), any subexpressions have been bound to variables so
/// that the referenced data will always either be a constant or a
/// variable reference.
#[derive(Debug, PartialEq)]
pub enum Expression {
Atomic(ValueOrRef),
Cast(Location, Type, ValueOrRef),
Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}
impl Expression {
/// Return a reference to the type of the expression, as inferred or recently
/// computed.
pub fn type_of(&self) -> &Type {
match self {
Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t,
Expression::Atomic(ValueOrRef::Value(_, t, _)) => t,
Expression::Cast(_, t, _) => t,
Expression::Primitive(_, t, _, _) => t,
}
}
/// Return a reference to the location associated with the expression.
pub fn location(&self) -> &Location {
match self {
Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l,
}
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Expression::Atomic(x) => x.pretty(allocator),
Expression::Cast(_, t, e) => allocator
.text("<")
.append(t.pretty(allocator))
.append(allocator.text(">"))
.append(e.pretty(allocator)),
Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator))
}
Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator);
left.append(allocator.space())
.append(op.pretty(allocator))
.append(allocator.space())
.append(right)
.parens()
}
Expression::Primitive(_, _, op, exprs) => {
allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
}
}
}
}
/// An expression that is always either a value or a reference.
///
/// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference.
#[derive(Clone, Debug, PartialEq)]
pub enum ValueOrRef {
Value(Location, Type, Value),
Ref(Location, Type, ArcIntern<String>),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
ValueOrRef::Value(_, _, v) => v.pretty(allocator),
ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()),
}
}
}
impl From<ValueOrRef> for Expression {
fn from(value: ValueOrRef) -> Self {
Expression::Atomic(value)
}
}
/// A constant in the IR.
///
/// The optional argument in numeric types is the base that was used by the
/// user to input the number. By retaining it, we can ensure that if we need
/// to print the number back out, we can do so in the form that the user
/// entered it.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
Unknown(Option<u8>, u64),
I8(Option<u8>, i8),
I16(Option<u8>, i16),
I32(Option<u8>, i32),
I64(Option<u8>, i64),
U8(Option<u8>, u8),
U16(Option<u8>, u16),
U32(Option<u8>, u32),
U64(Option<u8>, u64),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
let pretty_internal = |opt_base: &Option<u8>, x, t| {
syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
};
let pretty_internal_signed = |opt_base, x: i64, t| {
let base = pretty_internal(opt_base, x.unsigned_abs(), t);
allocator.text("-").append(base)
};
match self {
Value::Unknown(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::U64)
}
Value::I8(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
}
Value::I16(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
}
Value::I32(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
}
Value::I64(opt_base, value) => {
pretty_internal_signed(opt_base, *value, ConstantType::I64)
}
Value::U8(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U8)
}
Value::U16(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U16)
}
Value::U32(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U32)
}
Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type {
Variable(Location, ArcIntern<String>),
Primitive(PrimitiveType),
}
impl Type {
pub fn is_concrete(&self) -> bool {
!matches!(self, Type::Variable(_, _))
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Type::Variable(_, x) => allocator.text(x.to_string()),
Type::Primitive(pt) => allocator.text(format!("{}", pt)),
}
}
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Variable(_, x) => write!(f, "{}", x),
Type::Primitive(pt) => pt.fmt(f),
}
}
}
/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
pub fn gensym(name: &str) -> ArcIntern<String> {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = format!(
"<{}:{}>",
name,
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
);
ArcIntern::new(new_name)
}
/// Generate a fresh new type; this will be a unique new type variable.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
pub fn gentype() -> Type {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let name = ArcIntern::new(format!(
"t<{}>",
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
));
Type::Variable(Location::manufactured(), name)
}

378
src/type_infer/convert.rs Normal file
View File

@@ -0,0 +1,378 @@
use super::ast as ir;
use super::ast::Type;
use crate::eval::PrimitiveType;
use crate::syntax::{self, ConstantType};
use crate::type_infer::solve::Constraint;
use internment::ArcIntern;
use std::collections::HashMap;
use std::str::FromStr;
/// This function takes a syntactic program and converts it into the IR version of the
/// program, with appropriate type variables introduced and their constraints added to
/// the given database.
///
/// If the input function has been validated (which it should be), then this should run
/// into no error conditions. However, if you failed to validate the input, then this
/// function can panic.
pub fn convert_program(
mut program: syntax::Program,
constraint_db: &mut Vec<Constraint>,
) -> ir::Program {
let mut statements = Vec::new();
let mut renames = HashMap::new();
let mut bindings = HashMap::new();
for stmt in program.statements.drain(..) {
statements.append(&mut convert_statement(
stmt,
constraint_db,
&mut renames,
&mut bindings,
));
}
ir::Program { statements }
}
/// This function takes a syntactic statements and converts it into a series of
/// IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_statement(
statement: syntax::Statement,
constraint_db: &mut Vec<Constraint>,
renames: &mut HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> Vec<ir::Statement> {
match statement {
syntax::Statement::Print(loc, name) => {
let iname = ArcIntern::new(name.to_string());
let final_name = renames
.get(&iname)
.map(Clone::clone)
.unwrap_or_else(|| iname.clone());
let varty = bindings
.get(&final_name)
.expect("print variable defined before use")
.clone();
constraint_db.push(Constraint::Printable(loc.clone(), varty.clone()));
vec![ir::Statement::Print(loc, varty, iname)]
}
syntax::Statement::Binding(loc, name, expr) => {
let (mut prereqs, expr, ty) =
convert_expression(expr, constraint_db, renames, bindings);
let iname = ArcIntern::new(name.to_string());
let final_name = if bindings.contains_key(&iname) {
let new_name = ir::gensym(iname.as_str());
renames.insert(iname, new_name.clone());
new_name
} else {
iname
};
bindings.insert(final_name.clone(), ty.clone());
prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr));
prereqs
}
}
}
/// This function takes a syntactic expression and converts it into a series
/// of IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_expression(
expression: syntax::Expression,
constraint_db: &mut Vec<Constraint>,
renames: &HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> (Vec<ir::Statement>, ir::Expression, Type) {
match expression {
syntax::Expression::Value(loc, val) => match val {
syntax::Value::Number(base, mctype, value) => {
let (newval, newtype) = match mctype {
None => {
let newtype = ir::gentype();
let newval = ir::Value::Unknown(base, value);
constraint_db.push(Constraint::ConstantNumericType(
loc.clone(),
newtype.clone(),
));
(newval, newtype)
}
Some(ConstantType::U8) => (
ir::Value::U8(base, value as u8),
ir::Type::Primitive(PrimitiveType::U8),
),
Some(ConstantType::U16) => (
ir::Value::U16(base, value as u16),
ir::Type::Primitive(PrimitiveType::U16),
),
Some(ConstantType::U32) => (
ir::Value::U32(base, value as u32),
ir::Type::Primitive(PrimitiveType::U32),
),
Some(ConstantType::U64) => (
ir::Value::U64(base, value),
ir::Type::Primitive(PrimitiveType::U64),
),
Some(ConstantType::I8) => (
ir::Value::I8(base, value as i8),
ir::Type::Primitive(PrimitiveType::I8),
),
Some(ConstantType::I16) => (
ir::Value::I16(base, value as i16),
ir::Type::Primitive(PrimitiveType::I16),
),
Some(ConstantType::I32) => (
ir::Value::I32(base, value as i32),
ir::Type::Primitive(PrimitiveType::I32),
),
Some(ConstantType::I64) => (
ir::Value::I64(base, value as i64),
ir::Type::Primitive(PrimitiveType::I64),
),
};
constraint_db.push(Constraint::FitsInNumType(
loc.clone(),
newtype.clone(),
value,
));
(
vec![],
ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)),
newtype,
)
}
},
syntax::Expression::Reference(loc, name) => {
let iname = ArcIntern::new(name);
let final_name = renames.get(&iname).cloned().unwrap_or(iname);
let rtype = bindings
.get(&final_name)
.cloned()
.expect("variable bound before use");
let refexp =
ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name));
(vec![], refexp, rtype)
}
syntax::Expression::Cast(loc, target, expr) => {
let (mut stmts, nexpr, etype) =
convert_expression(*expr, constraint_db, renames, bindings);
let val_or_ref = simplify_expr(nexpr, &mut stmts);
let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast");
let target_type = Type::Primitive(target_prim_type);
let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref);
constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone()));
(stmts, res, target_type)
}
syntax::Expression::Primitive(loc, fun, mut args) => {
let primop = ir::Primitive::from_str(&fun).expect("valid primitive");
let mut stmts = vec![];
let mut nargs = vec![];
let mut atypes = vec![];
let ret_type = ir::gentype();
for arg in args.drain(..) {
let (mut astmts, aexp, atype) =
convert_expression(arg, constraint_db, renames, bindings);
stmts.append(&mut astmts);
nargs.push(simplify_expr(aexp, &mut stmts));
atypes.push(atype);
}
constraint_db.push(Constraint::ProperPrimitiveArgs(
loc.clone(),
primop,
atypes.clone(),
ret_type.clone(),
));
(
stmts,
ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs),
ret_type,
)
}
}
}
fn simplify_expr(expr: ir::Expression, stmts: &mut Vec<ir::Statement>) -> ir::ValueOrRef {
match expr {
ir::Expression::Atomic(v_or_ref) => v_or_ref,
expr => {
let etype = expr.type_of().clone();
let loc = expr.location().clone();
let nname = ir::gensym("g");
let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr);
stmts.push(nbinding);
ir::ValueOrRef::Ref(loc, etype, nname)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::syntax::Location;
fn one() -> syntax::Expression {
syntax::Expression::Value(
Location::manufactured(),
syntax::Value::Number(None, None, 1),
)
}
fn vec_contains<T, F: Fn(&T) -> bool>(x: &[T], f: F) -> bool {
for x in x.iter() {
if f(x) {
return true;
}
}
false
}
fn infer_expression(
x: syntax::Expression,
) -> (ir::Expression, Vec<ir::Statement>, Vec<Constraint>, Type) {
let mut constraints = Vec::new();
let renames = HashMap::new();
let mut bindings = HashMap::new();
let (stmts, expr, ty) = convert_expression(x, &mut constraints, &renames, &mut bindings);
(expr, stmts, constraints, ty)
}
fn infer_statement(x: syntax::Statement) -> (Vec<ir::Statement>, Vec<Constraint>) {
let mut constraints = Vec::new();
let mut renames = HashMap::new();
let mut bindings = HashMap::new();
let res = convert_statement(x, &mut constraints, &mut renames, &mut bindings);
(res, constraints)
}
#[test]
fn constant_one() {
let (expr, stmts, constraints, ty) = infer_expression(one());
assert!(stmts.is_empty());
assert!(matches!(
expr,
ir::Expression::Atomic(ir::ValueOrRef::Value(_, _, ir::Value::Unknown(None, 1)))
));
assert!(vec_contains(&constraints, |x| matches!(
x,
Constraint::FitsInNumType(_, _, 1)
)));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == &ty)
));
}
#[test]
fn one_plus_one() {
let opo = syntax::Expression::Primitive(
Location::manufactured(),
"+".to_string(),
vec![one(), one()],
);
let (expr, stmts, constraints, ty) = infer_expression(opo);
assert!(stmts.is_empty());
assert!(matches!(expr, ir::Expression::Primitive(_, t, ir::Primitive::Plus, _) if t == ty));
assert!(vec_contains(&constraints, |x| matches!(
x,
Constraint::FitsInNumType(_, _, 1)
)));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t != &ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ProperPrimitiveArgs(_, ir::Primitive::Plus, args, ret) if args.len() == 2 && ret == &ty)
));
}
#[test]
fn one_plus_one_plus_one() {
let stmt = syntax::Statement::parse(1, "x = 1 + 1 + 1;").expect("basic parse");
let (stmts, constraints) = infer_statement(stmt);
assert_eq!(stmts.len(), 2);
let ir::Statement::Binding(
_args,
name1,
temp_ty1,
ir::Expression::Primitive(_, primty1, ir::Primitive::Plus, primargs1),
) = stmts.get(0).expect("item two")
else {
panic!("Failed to match first statement");
};
let ir::Statement::Binding(
_args,
name2,
temp_ty2,
ir::Expression::Primitive(_, primty2, ir::Primitive::Plus, primargs2),
) = stmts.get(1).expect("item two")
else {
panic!("Failed to match second statement");
};
let &[ir::ValueOrRef::Value(_, ref left1ty, _), ir::ValueOrRef::Value(_, ref right1ty, _)] =
&primargs1[..]
else {
panic!("Failed to match first arguments");
};
let &[ir::ValueOrRef::Ref(_, _, ref left2name), ir::ValueOrRef::Value(_, ref right2ty, _)] =
&primargs2[..]
else {
panic!("Failed to match first arguments");
};
assert_ne!(name1, name2);
assert_ne!(temp_ty1, temp_ty2);
assert_ne!(primty1, primty2);
assert_eq!(name1, left2name);
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == left1ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right1ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right2ty)
));
for (i, s) in stmts.iter().enumerate() {
println!("{}: {:?}", i, s);
}
for (i, c) in constraints.iter().enumerate() {
println!("{}: {:?}", i, c);
}
}
}

187
src/type_infer/finalize.rs Normal file
View File

@@ -0,0 +1,187 @@
use super::{ast as input, solve::TypeResolutions};
use crate::{eval::PrimitiveType, ir as output};
pub fn finalize_program(
mut program: input::Program,
resolutions: &TypeResolutions,
) -> output::Program {
output::Program {
statements: program
.statements
.drain(..)
.map(|x| finalize_statement(x, resolutions))
.collect(),
}
}
fn finalize_statement(
statement: input::Statement,
resolutions: &TypeResolutions,
) -> output::Statement {
match statement {
input::Statement::Binding(loc, var, ty, expr) => output::Statement::Binding(
loc,
var,
finalize_type(ty, resolutions),
finalize_expression(expr, resolutions),
),
input::Statement::Print(loc, ty, var) => {
output::Statement::Print(loc, finalize_type(ty, resolutions), var)
}
}
}
fn finalize_expression(
expression: input::Expression,
resolutions: &TypeResolutions,
) -> output::Expression {
match expression {
input::Expression::Atomic(val_or_ref) => {
output::Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions))
}
input::Expression::Cast(loc, target, val_or_ref) => output::Expression::Cast(
loc,
finalize_type(target, resolutions),
finalize_val_or_ref(val_or_ref, resolutions),
),
input::Expression::Primitive(loc, ty, prim, mut args) => output::Expression::Primitive(
loc,
finalize_type(ty, resolutions),
prim,
args.drain(..)
.map(|x| finalize_val_or_ref(x, resolutions))
.collect(),
),
}
}
fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type {
match ty {
input::Type::Primitive(x) => output::Type::Primitive(x),
input::Type::Variable(_, tvar) => match resolutions.get(&tvar) {
None => panic!("Did not resolve type for type variable {}", tvar),
Some(pt) => output::Type::Primitive(*pt),
},
}
}
fn finalize_val_or_ref(
valref: input::ValueOrRef,
resolutions: &TypeResolutions,
) -> output::ValueOrRef {
match valref {
input::ValueOrRef::Ref(loc, ty, var) => {
output::ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var)
}
input::ValueOrRef::Value(loc, ty, val) => {
let new_type = finalize_type(ty, resolutions);
match val {
input::Value::Unknown(base, value) => match new_type {
output::Type::Primitive(PrimitiveType::U8) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U8(base, value as u8),
),
output::Type::Primitive(PrimitiveType::U16) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U16(base, value as u16),
),
output::Type::Primitive(PrimitiveType::U32) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U32(base, value as u32),
),
output::Type::Primitive(PrimitiveType::U64) => {
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
}
output::Type::Primitive(PrimitiveType::I8) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I8(base, value as i8),
),
output::Type::Primitive(PrimitiveType::I16) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I16(base, value as i16),
),
output::Type::Primitive(PrimitiveType::I32) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I32(base, value as i32),
),
output::Type::Primitive(PrimitiveType::I64) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I64(base, value as i64),
),
},
input::Value::U8(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U8)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U8(base, value))
}
input::Value::U16(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U16)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U16(base, value))
}
input::Value::U32(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U32)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U32(base, value))
}
input::Value::U64(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U64)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
}
input::Value::I8(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I8)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I8(base, value))
}
input::Value::I16(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I16)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I16(base, value))
}
input::Value::I32(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I32)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I32(base, value))
}
input::Value::I64(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I64)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I64(base, value))
}
}
}
}
}

542
src/type_infer/solve.rs Normal file
View File

@@ -0,0 +1,542 @@
use super::ast as ir;
use super::ast::Type;
use crate::{eval::PrimitiveType, syntax::Location};
use codespan_reporting::diagnostic::Diagnostic;
use internment::ArcIntern;
use std::{collections::HashMap, fmt};
/// A type inference constraint that we're going to need to solve.
#[derive(Debug)]
pub enum Constraint {
/// The given type must be printable using the `print` built-in
Printable(Location, Type),
/// The provided numeric value fits in the given constant type
FitsInNumType(Location, Type, u64),
/// The given primitive has the proper arguments types associated with it
ProperPrimitiveArgs(Location, ir::Primitive, Vec<Type>, Type),
/// The given type can be casted to the target type safely
CanCastTo(Location, Type, Type),
/// The given type must be some numeric type, but this is not a constant
/// value, so don't try to default it if we can't figure it out
NumericType(Location, Type),
/// The given type is attached to a constant and must be some numeric type.
/// If we can't figure it out, we should warn the user and then just use a
/// default.
ConstantNumericType(Location, Type),
/// The two types should be equivalent
Equivalent(Location, Type, Type),
}
impl fmt::Display for Constraint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Constraint::Printable(_, ty) => write!(f, "PRINTABLE {}", ty),
Constraint::FitsInNumType(_, ty, num) => write!(f, "FITS_IN {} {}", num, ty),
Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 1 => {
write!(f, "PRIM {} {} -> {}", op, args[0], ret)
}
Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 2 => {
write!(f, "PRIM {} ({}, {}) -> {}", op, args[0], args[1], ret)
}
Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret),
Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2),
Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty),
Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty),
Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2),
}
}
}
pub type TypeResolutions = HashMap<ArcIntern<String>, PrimitiveType>;
/// The results of type inference; like [`Result`], but with a bit more information.
///
/// This result is parameterized, because sometimes it's handy to return slightly
/// different things; there's a [`TypeInferenceResult::map`] function for performing
/// those sorts of conversions.
pub enum TypeInferenceResult<Result> {
Success {
result: Result,
warnings: Vec<TypeInferenceWarning>,
},
Failure {
errors: Vec<TypeInferenceError>,
warnings: Vec<TypeInferenceWarning>,
},
}
impl<R> TypeInferenceResult<R> {
// If this was a successful type inference, run the function over the result to
// create a new result.
//
// This is the moral equivalent of [`Result::map`], but for type inference results.
pub fn map<U, F>(self, f: F) -> TypeInferenceResult<U>
where
F: FnOnce(R) -> U,
{
match self {
TypeInferenceResult::Success { result, warnings } => TypeInferenceResult::Success {
result: f(result),
warnings,
},
TypeInferenceResult::Failure { errors, warnings } => {
TypeInferenceResult::Failure { errors, warnings }
}
}
}
// Return the final result, or panic if it's not a success
pub fn expect(self, msg: &str) -> R {
match self {
TypeInferenceResult::Success { result, .. } => result,
TypeInferenceResult::Failure { .. } => {
panic!("tried to get value from failed type inference: {}", msg)
}
}
}
}
/// The various kinds of errors that can occur while doing type inference.
pub enum TypeInferenceError {
/// The user provide a constant that is too large for its inferred type.
ConstantTooLarge(Location, PrimitiveType, u64),
/// The two types needed to be equivalent, but weren't.
NotEquivalent(Location, PrimitiveType, PrimitiveType),
/// We cannot safely cast the first type to the second type.
CannotSafelyCast(Location, PrimitiveType, PrimitiveType),
/// The primitive invocation provided the wrong number of arguments.
WrongPrimitiveArity(Location, ir::Primitive, usize, usize, usize),
/// We had a constraint we just couldn't solve.
CouldNotSolve(Constraint),
}
impl From<TypeInferenceError> for Diagnostic<usize> {
fn from(value: TypeInferenceError) -> Self {
match value {
TypeInferenceError::ConstantTooLarge(loc, primty, value) => loc
.labelled_error("constant too large for type")
.with_message(format!(
"Type {} has a max value of {}, which is smaller than {}",
primty,
primty.max_value(),
value
)),
TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc
.labelled_error("type inference error")
.with_message(format!("Expected type {}, received type {}", ty1, ty2)),
TypeInferenceError::CannotSafelyCast(loc, ty1, ty2) => loc
.labelled_error("unsafe type cast")
.with_message(format!("Cannot safely cast {} to {}", ty1, ty2)),
TypeInferenceError::WrongPrimitiveArity(loc, prim, lower, upper, observed) => loc
.labelled_error("wrong number of arguments")
.with_message(format!(
"expected {} for {}, received {}",
if lower == upper && lower > 1 {
format!("{} arguments", lower)
} else if lower == upper {
format!("{} argument", lower)
} else {
format!("{}-{} arguments", lower, upper)
},
prim,
observed
)),
TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if it was safe to cast from {} to {:#?}",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if {} and {:#?} were equivalent",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::FitsInNumType(loc, ty, val)) => {
loc.labelled_error("internal error").with_message(format!(
"Could not determine if {} could fit in {}",
val, ty
))
}
TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc
.labelled_error("internal error")
.with_message(format!("Could not determine if {} was a numeric type", ty)),
TypeInferenceError::CouldNotSolve(Constraint::ConstantNumericType(loc, ty)) =>
panic!("What? Constants should always eventually be solved, even by default; {:?} and type {:?}", loc, ty),
TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc
.labelled_error("internal error")
.with_message(format!("Could not determine if type {} was printable", ty)),
TypeInferenceError::CouldNotSolve(Constraint::ProperPrimitiveArgs(loc, prim, _, _)) => {
loc.labelled_error("internal error").with_message(format!(
"Could not tell if primitive {} received the proper argument types",
prim
))
}
}
}
}
/// Warnings that we might want to tell the user about.
///
/// These are fine, probably, but could indicate some behavior the user might not
/// expect, and so they might want to do something about them.
pub enum TypeInferenceWarning {
DefaultedTo(Location, Type),
}
impl From<TypeInferenceWarning> for Diagnostic<usize> {
fn from(value: TypeInferenceWarning) -> Self {
match value {
TypeInferenceWarning::DefaultedTo(loc, ty) => Diagnostic::warning()
.with_labels(vec![loc.primary_label().with_message("unknown type")])
.with_message(format!("Defaulted unknown type to {}", ty)),
}
}
}
/// Solve all the constraints in the provided database.
///
/// This process can take a bit, so you might not want to do it multiple times. Basically,
/// it's going to grind on these constraints until either it figures them out, or it stops
/// making progress. I haven't done the math on the constraints to even figure out if this
/// is guaranteed to halt, though, let alone terminate in some reasonable amount of time.
///
/// The return value is a type inference result, which pairs some warnings with either a
/// successful set of type resolutions (mappings from type variables to their values), or
/// a series of inference errors.
pub fn solve_constraints(
mut constraint_db: Vec<Constraint>,
) -> TypeInferenceResult<TypeResolutions> {
let mut errors = vec![];
let mut warnings = vec![];
let mut resolutions = HashMap::new();
let mut changed_something = true;
// We want to run this inference endlessly, until either we have solved all of our
// constraints. Internal to the loop, we have a check that will make sure that we
// do (eventually) stop.
while changed_something && !constraint_db.is_empty() {
// Set this to false at the top of the loop. We'll set this to true if we make
// progress in any way further down, but having this here prevents us from going
// into an infinite look when we can't figure stuff out.
changed_something = false;
// This is sort of a double-buffering thing; we're going to rename constraint_db
// and then set it to a new empty vector, which we'll add to as we find new
// constraints or find ourselves unable to solve existing ones.
let mut local_constraints = constraint_db;
constraint_db = vec![];
// OK. First thing we're going to do is run through all of our constraints,
// and see if we can solve any, or reduce them to theoretically more simple
// constraints.
for constraint in local_constraints.drain(..) {
match constraint {
// Currently, all of our types are printable
Constraint::Printable(_loc, _ty) => changed_something = true,
// Case #1: We have two primitive types. If they're equal, we've discharged this
// constraint! We can just continue. If they're not equal, add an error and then
// see what else we come up with.
Constraint::Equivalent(loc, Type::Primitive(t1), Type::Primitive(t2)) => {
if t1 != t2 {
errors.push(TypeInferenceError::NotEquivalent(loc, t1, t2));
}
changed_something = true;
}
// Case #2: One of the two constraints is a primitive, and the other is a variable.
// In this case, we'll check to see if we've resolved the variable, and check for
// equivalence if we have. If we haven't, we'll set that variable to be primitive
// type.
Constraint::Equivalent(loc, Type::Primitive(t), Type::Variable(_, name))
| Constraint::Equivalent(loc, Type::Variable(_, name), Type::Primitive(t)) => {
match resolutions.get(&name) {
None => {
resolutions.insert(name, t);
}
Some(t2) if &t == t2 => {}
Some(t2) => errors.push(TypeInferenceError::NotEquivalent(loc, t, *t2)),
}
changed_something = true;
}
// Case #3: They're both variables. In which case, we'll have to do much the same
// check, but now on their resolutions.
Constraint::Equivalent(
ref loc,
Type::Variable(_, ref name1),
Type::Variable(_, ref name2),
) => match (resolutions.get(name1), resolutions.get(name2)) {
(None, None) => {
constraint_db.push(constraint);
}
(Some(pt), None) => {
resolutions.insert(name2.clone(), *pt);
changed_something = true;
}
(None, Some(pt)) => {
resolutions.insert(name1.clone(), *pt);
changed_something = true;
}
(Some(pt1), Some(pt2)) if pt1 == pt2 => {
changed_something = true;
}
(Some(pt1), Some(pt2)) => {
errors.push(TypeInferenceError::NotEquivalent(loc.clone(), *pt1, *pt2));
changed_something = true;
}
},
// Make sure that the provided number fits within the provided constant type. For the
// moment, we're going to call an error here a failure, although this could be a
// warning in the future.
Constraint::FitsInNumType(loc, Type::Primitive(ctype), val) => {
if ctype.max_value() < val {
errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val));
}
changed_something = true;
}
// If we have a non-constant type, then let's see if we can advance this to a constant
// type
Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::FitsInNumType(
loc,
Type::Variable(vloc, var),
val,
)),
Some(nt) => {
constraint_db.push(Constraint::FitsInNumType(
loc,
Type::Primitive(*nt),
val,
));
changed_something = true;
}
}
}
// If the left type in a "can cast to" check is a variable, let's see if we can advance
// it into something more tangible
Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::CanCastTo(
loc,
Type::Variable(vloc, var),
to_type,
)),
Some(nt) => {
constraint_db.push(Constraint::CanCastTo(
loc,
Type::Primitive(*nt),
to_type,
));
changed_something = true;
}
}
}
// If the right type in a "can cast to" check is a variable, same deal
Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::CanCastTo(
loc,
from_type,
Type::Variable(vloc, var),
)),
Some(nt) => {
constraint_db.push(Constraint::CanCastTo(
loc,
from_type,
Type::Primitive(*nt),
));
changed_something = true;
}
}
}
// If both of them are types, then we can actually do the test. yay!
Constraint::CanCastTo(
loc,
Type::Primitive(from_type),
Type::Primitive(to_type),
) => {
if !from_type.can_cast_to(&to_type) {
errors.push(TypeInferenceError::CannotSafelyCast(
loc, from_type, to_type,
));
}
changed_something = true;
}
// As per usual, if we're trying to test if a type variable is numeric, first
// we try to advance it to a primitive
Constraint::NumericType(loc, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db
.push(Constraint::NumericType(loc, Type::Variable(vloc, var))),
Some(nt) => {
constraint_db.push(Constraint::NumericType(loc, Type::Primitive(*nt)));
changed_something = true;
}
}
}
// Of course, if we get to a primitive type, then it's true, because all of our
// primitive types are numbers
Constraint::NumericType(_, Type::Primitive(_)) => {
changed_something = true;
}
// As per usual, if we're trying to test if a type variable is numeric, first
// we try to advance it to a primitive
Constraint::ConstantNumericType(loc, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::ConstantNumericType(
loc,
Type::Variable(vloc, var),
)),
Some(nt) => {
constraint_db
.push(Constraint::ConstantNumericType(loc, Type::Primitive(*nt)));
changed_something = true;
}
}
}
// Of course, if we get to a primitive type, then it's true, because all of our
// primitive types are numbers
Constraint::ConstantNumericType(_, Type::Primitive(_)) => {
changed_something = true;
}
// OK, this one could be a little tricky if we tried to do it all at once, but
// instead what we're going to do here is just use this constraint to generate
// a bunch more constraints, and then go have the engine solve those. The only
// real errors we're going to come up with here are "arity errors"; errors we
// find by discovering that the number of arguments provided doesn't make sense
// given the primitive being used.
Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim {
ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide
if args.len() != 2 =>
{
errors.push(TypeInferenceError::WrongPrimitiveArity(
loc,
prim,
2,
2,
args.len(),
));
changed_something = true;
}
ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => {
let right = args.pop().expect("2 > 0");
let left = args.pop().expect("2 > 1");
// technically testing that both are numeric is redundant, but it might give
// a slightly helpful type error if we do both.
constraint_db.push(Constraint::NumericType(loc.clone(), left.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), right.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(
loc.clone(),
left.clone(),
right,
));
constraint_db.push(Constraint::Equivalent(loc, left, ret));
changed_something = true;
}
ir::Primitive::Minus if args.is_empty() || args.len() > 2 => {
errors.push(TypeInferenceError::WrongPrimitiveArity(
loc,
prim,
1,
2,
args.len(),
));
changed_something = true;
}
ir::Primitive::Minus if args.len() == 1 => {
let arg = args.pop().expect("1 > 0");
constraint_db.push(Constraint::NumericType(loc.clone(), arg.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(loc, arg, ret));
changed_something = true;
}
ir::Primitive::Minus => {
let right = args.pop().expect("2 > 0");
let left = args.pop().expect("2 > 1");
// technically testing that both are numeric is redundant, but it might give
// a slightly helpful type error if we do both.
constraint_db.push(Constraint::NumericType(loc.clone(), left.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), right.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(
loc.clone(),
left.clone(),
right,
));
constraint_db.push(Constraint::Equivalent(loc.clone(), left, ret));
changed_something = true;
}
},
}
}
// If that didn't actually come up with anything, and we just recycled all the constraints
// back into the database unchanged, then let's take a look for cases in which we just
// wanted something we didn't know to be a number. Basically, those are cases where the
// user just wrote a number, but didn't tell us what type it was, and there isn't enough
// information in the context to tell us. If that happens, we'll just set that type to
// be u64, and warn the user that we did so.
if !changed_something && !constraint_db.is_empty() {
local_constraints = constraint_db;
constraint_db = vec![];
for constraint in local_constraints.drain(..) {
match constraint {
Constraint::ConstantNumericType(loc, t @ Type::Variable(_, _)) => {
let resty = Type::Primitive(PrimitiveType::U64);
constraint_db.push(Constraint::Equivalent(
loc.clone(),
t,
Type::Primitive(PrimitiveType::U64),
));
warnings.push(TypeInferenceWarning::DefaultedTo(loc, resty));
changed_something = true;
}
_ => constraint_db.push(constraint),
}
}
}
}
// OK, we left our loop. Which means that either we solved everything, or we didn't.
// If we didn't, turn the unsolved constraints into type inference errors, and add
// them to our error list.
let mut unsolved_constraint_errors = constraint_db
.drain(..)
.map(TypeInferenceError::CouldNotSolve)
.collect();
errors.append(&mut unsolved_constraint_errors);
// How'd we do?
if errors.is_empty() {
TypeInferenceResult::Success {
result: resolutions,
warnings,
}
} else {
TypeInferenceResult::Failure { errors, warnings }
}
}