🤔 Add a type inference engine, along with typed literals. #4

Merged
acw merged 25 commits from acw/type-checker into develop 2023-09-19 20:40:05 -07:00
44 changed files with 3258 additions and 702 deletions

3
.gitignore vendored
View File

@@ -2,6 +2,9 @@
Cargo.lock Cargo.lock
**/*.o **/*.o
test test
test.exe
test.ilk
test.pdb
*.dSYM *.dSYM
.vscode .vscode
proptest-regressions/ proptest-regressions/

View File

@@ -12,12 +12,12 @@ path = "src/lib.rs"
clap = { version = "^3.0.14", features = ["derive"] } clap = { version = "^3.0.14", features = ["derive"] }
codespan = "0.11.1" codespan = "0.11.1"
codespan-reporting = "0.11.1" codespan-reporting = "0.11.1"
cranelift-codegen = "0.94.0" cranelift-codegen = "0.99.2"
cranelift-jit = "0.94.0" cranelift-jit = "0.99.2"
cranelift-frontend = "0.94.0" cranelift-frontend = "0.99.2"
cranelift-module = "0.94.0" cranelift-module = "0.99.2"
cranelift-native = "0.94.0" cranelift-native = "0.99.2"
cranelift-object = "0.94.0" cranelift-object = "0.99.2"
internment = { version = "0.7.0", default-features = false, features = ["arc"] } internment = { version = "0.7.0", default-features = false, features = ["arc"] }
lalrpop-util = "^0.20.0" lalrpop-util = "^0.20.0"
lazy_static = "^1.4.0" lazy_static = "^1.4.0"

View File

@@ -1,5 +1,88 @@
extern crate lalrpop; extern crate lalrpop;
use std::env;
use std::fs::{self, File};
use std::io::Write;
use std::path::{Path, PathBuf};
fn main() { fn main() {
lalrpop::process_root().unwrap(); lalrpop::process_root().unwrap();
if let Err(e) = generate_example_tests() {
eprintln!("Failure building example tests: {}", e);
std::process::exit(3);
}
}
fn generate_example_tests() -> std::io::Result<()> {
let out_dir = env::var_os("OUT_DIR").expect("OUT_DIR");
let dest_path = Path::new(&out_dir).join("examples.rs");
let mut output = File::create(dest_path)?;
generate_tests(&mut output, PathBuf::from("examples"))?;
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=example.rs");
Ok(())
}
fn generate_tests(f: &mut File, path_so_far: PathBuf) -> std::io::Result<()> {
for entry in fs::read_dir(path_so_far)? {
let entry = entry?;
let file_name = entry
.file_name()
.into_string()
.expect("reasonable file name");
let name = file_name.trim_end_matches(".ngr");
if entry.file_type()?.is_dir() {
writeln!(f, "mod {} {{", name)?;
writeln!(f, " use super::*;")?;
generate_tests(f, entry.path())?;
writeln!(f, "}}")?;
} else {
writeln!(f, "#[test]")?;
writeln!(f, "fn {}() {{", name)?;
writeln!(f, " let mut file_database = SimpleFiles::new();")?;
writeln!(
f,
" let syntax = Syntax::parse_file(&mut file_database, {:?});",
entry.path().display()
)?;
if entry.path().to_string_lossy().contains("broken") {
writeln!(f, " if syntax.is_err() {{")?;
writeln!(f, " return;")?;
writeln!(f, " }}")?;
writeln!(f, " let (errors, _) = syntax.unwrap().validate();")?;
writeln!(
f,
" assert_ne!(errors.len(), 0, \"should have seen an error\");"
)?;
} else {
// NOTE: Since the advent of defaulting rules and type checking, we
// can't guarantee that syntax.eval() will return the same result as
// ir.eval() or backend::eval(). We must do type checking to force
// constants into the right types, first. So this now checks only that
// the result of ir.eval() and backend::eval() are the same.
writeln!(
f,
" let syntax = syntax.expect(\"file should have parsed\");"
)?;
writeln!(f, " let (errors, _) = syntax.validate();")?;
writeln!(
f,
" assert_eq!(errors.len(), 0, \"file should have no validation errors\");"
)?;
writeln!(
f,
" let ir = syntax.type_infer().expect(\"example is typed correctly\");"
)?;
writeln!(f, " let ir_result = ir.eval();")?;
writeln!(f, " let compiled_result = Backend::<JITModule>::eval(ir);")?;
writeln!(f, " assert_eq!(ir_result, compiled_result);")?;
}
writeln!(f, "}}")?;
}
}
Ok(())
} }

View File

@@ -0,0 +1,4 @@
x = 5;
x = 4*x + 3;
print x;
print y;

7
examples/basic/cast1.ngr Normal file
View File

@@ -0,0 +1,7 @@
x8 = 5u8;
x16 = <u16>x8 + 1u16;
print x16;
x32 = <u32>x16 + 1u32;
print x32;
x64 = <u64>x32 + 1u64;
print x64;

7
examples/basic/cast2.ngr Normal file
View File

@@ -0,0 +1,7 @@
x8 = 5i8;
x16 = <i16>x8 - 1i16;
print x16;
x32 = <i32>x16 - 1i32;
print x32;
x64 = <i64>x32 - 1i64;
print x64;

View File

@@ -0,0 +1,3 @@
x = 2i64 + 2i64;
y = -x;
print y;

View File

@@ -1,4 +1,4 @@
x = 5; // this test is useful for making sure we don't accidentally
x = 4*x + 3; // use a signed divide operation anywhere important
print x; a = 96u8 / 160u8;
print y; print a;

View File

@@ -0,0 +1,6 @@
x = 1 + 1u16;
print x;
y = 1u16 + 1;
print y;
z = 1 + 1 + 1;
print z;

View File

@@ -2,12 +2,33 @@
#include <stdio.h> #include <stdio.h>
#include <inttypes.h> #include <inttypes.h>
void print(char *_ignore, char *variable_name, int64_t value) { void print(char *_ignore, char *variable_name, int64_t vtype, int64_t value) {
printf("%s = %" PRId64 "i64\n", variable_name, value); switch(vtype) {
case /* U8 = */ 10:
printf("%s = %" PRIu8 "u8\n", variable_name, (uint8_t)value);
break;
case /* U16 = */ 11:
printf("%s = %" PRIu16 "u16\n", variable_name, (uint16_t)value);
break;
case /* U32 = */ 12:
printf("%s = %" PRIu32 "u32\n", variable_name, (uint32_t)value);
break;
case /* U64 = */ 13:
printf("%s = %" PRIu64 "u64\n", variable_name, (uint64_t)value);
break;
case /* I8 = */ 20:
printf("%s = %" PRIi8 "i8\n", variable_name, (int8_t)value);
break;
case /* I16 = */ 21:
printf("%s = %" PRIi16 "i16\n", variable_name, (int16_t)value);
break;
case /* I32 = */ 22:
printf("%s = %" PRIi32 "i32\n", variable_name, (int32_t)value);
break;
case /* I64 = */ 23:
printf("%s = %" PRIi64 "i64\n", variable_name, value);
break;
} }
void caller() {
print(NULL, "x", 4);
} }
extern void gogogo(); extern void gogogo();

View File

@@ -31,15 +31,15 @@ mod eval;
mod into_crane; mod into_crane;
mod runtime; mod runtime;
use std::collections::HashMap;
pub use self::error::BackendError; pub use self::error::BackendError;
pub use self::runtime::{RuntimeFunctionError, RuntimeFunctions}; pub use self::runtime::{RuntimeFunctionError, RuntimeFunctions};
use crate::syntax::ConstantType;
use cranelift_codegen::settings::Configurable; use cranelift_codegen::settings::Configurable;
use cranelift_codegen::{isa, settings}; use cranelift_codegen::{isa, settings};
use cranelift_jit::{JITBuilder, JITModule}; use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{default_libcall_names, DataContext, DataId, FuncId, Linkage, Module}; use cranelift_module::{default_libcall_names, DataDescription, DataId, FuncId, Linkage, Module};
use cranelift_object::{ObjectBuilder, ObjectModule}; use cranelift_object::{ObjectBuilder, ObjectModule};
use std::collections::HashMap;
use target_lexicon::Triple; use target_lexicon::Triple;
const EMPTY_DATUM: [u8; 8] = [0; 8]; const EMPTY_DATUM: [u8; 8] = [0; 8];
@@ -55,11 +55,12 @@ const EMPTY_DATUM: [u8; 8] = [0; 8];
/// implementations. /// implementations.
pub struct Backend<M: Module> { pub struct Backend<M: Module> {
pub module: M, pub module: M,
data_ctx: DataContext, data_ctx: DataDescription,
runtime_functions: RuntimeFunctions, runtime_functions: RuntimeFunctions,
defined_strings: HashMap<String, DataId>, defined_strings: HashMap<String, DataId>,
defined_symbols: HashMap<String, DataId>, defined_symbols: HashMap<String, (DataId, ConstantType)>,
output_buffer: Option<String>, output_buffer: Option<String>,
platform: Triple,
} }
impl Backend<JITModule> { impl Backend<JITModule> {
@@ -85,11 +86,12 @@ impl Backend<JITModule> {
Ok(Backend { Ok(Backend {
module, module,
data_ctx: DataContext::new(), data_ctx: DataDescription::new(),
runtime_functions, runtime_functions,
defined_strings: HashMap::new(), defined_strings: HashMap::new(),
defined_symbols: HashMap::new(), defined_symbols: HashMap::new(),
output_buffer, output_buffer,
platform: Triple::host(),
}) })
} }
@@ -123,11 +125,12 @@ impl Backend<ObjectModule> {
Ok(Backend { Ok(Backend {
module, module,
data_ctx: DataContext::new(), data_ctx: DataDescription::new(),
runtime_functions, runtime_functions,
defined_strings: HashMap::new(), defined_strings: HashMap::new(),
defined_symbols: HashMap::new(), defined_symbols: HashMap::new(),
output_buffer: None, output_buffer: None,
platform,
}) })
} }
@@ -154,7 +157,7 @@ impl<M: Module> Backend<M> {
let global_id = self let global_id = self
.module .module
.declare_data(&name, Linkage::Local, false, false)?; .declare_data(&name, Linkage::Local, false, false)?;
let mut data_context = DataContext::new(); let mut data_context = DataDescription::new();
data_context.set_align(8); data_context.set_align(8);
data_context.define(s0.into_boxed_str().into_boxed_bytes()); data_context.define(s0.into_boxed_str().into_boxed_bytes());
self.module.define_data(global_id, &data_context)?; self.module.define_data(global_id, &data_context)?;
@@ -167,14 +170,18 @@ impl<M: Module> Backend<M> {
/// These variables can be shared between functions, and will be exported from the /// These variables can be shared between functions, and will be exported from the
/// module itself as public data in the case of static compilation. There initial /// module itself as public data in the case of static compilation. There initial
/// value will be null. /// value will be null.
pub fn define_variable(&mut self, name: String) -> Result<DataId, BackendError> { pub fn define_variable(
&mut self,
name: String,
ctype: ConstantType,
) -> Result<DataId, BackendError> {
self.data_ctx.define(Box::new(EMPTY_DATUM)); self.data_ctx.define(Box::new(EMPTY_DATUM));
let id = self let id = self
.module .module
.declare_data(&name, Linkage::Export, true, false)?; .declare_data(&name, Linkage::Export, true, false)?;
self.module.define_data(id, &self.data_ctx)?; self.module.define_data(id, &self.data_ctx)?;
self.data_ctx.clear(); self.data_ctx.clear();
self.defined_symbols.insert(name, id); self.defined_symbols.insert(name, (id, ctype));
Ok(id) Ok(id)
} }

View File

@@ -1,4 +1,4 @@
use crate::backend::runtime::RuntimeFunctionError; use crate::{backend::runtime::RuntimeFunctionError, eval::PrimitiveType, ir::Type};
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use cranelift_codegen::{isa::LookupError, settings::SetError, CodegenError}; use cranelift_codegen::{isa::LookupError, settings::SetError, CodegenError};
use cranelift_module::ModuleError; use cranelift_module::ModuleError;
@@ -39,6 +39,8 @@ pub enum BackendError {
LookupError(#[from] LookupError), LookupError(#[from] LookupError),
#[error(transparent)] #[error(transparent)]
Write(#[from] cranelift_object::object::write::Error), Write(#[from] cranelift_object::object::write::Error),
#[error("Invalid type cast from {from} to {to}")]
InvalidTypeCast { from: PrimitiveType, to: Type },
} }
impl From<BackendError> for Diagnostic<usize> { impl From<BackendError> for Diagnostic<usize> {
@@ -64,6 +66,9 @@ impl From<BackendError> for Diagnostic<usize> {
BackendError::Write(me) => { BackendError::Write(me) => {
Diagnostic::error().with_message(format!("Cranelift object write error: {}", me)) Diagnostic::error().with_message(format!("Cranelift object write error: {}", me))
} }
BackendError::InvalidTypeCast { from, to } => Diagnostic::error().with_message(
format!("Internal error trying to cast from {} to {}", from, to),
),
} }
} }
} }
@@ -103,6 +108,17 @@ impl PartialEq for BackendError {
BackendError::Write(b) => a == b, BackendError::Write(b) => a == b,
_ => false, _ => false,
}, },
BackendError::InvalidTypeCast {
from: from1,
to: to1,
} => match other {
BackendError::InvalidTypeCast {
from: from2,
to: to2,
} => from1 == from2 && to1 == to2,
_ => false,
},
} }
} }
} }

View File

@@ -1,10 +1,13 @@
use std::path::Path;
use crate::backend::Backend; use crate::backend::Backend;
use crate::eval::EvalError; use crate::eval::EvalError;
use crate::ir::Program; use crate::ir::Program;
#[cfg(test)]
use crate::syntax::arbitrary::GenerationEnvironment;
use cranelift_jit::JITModule; use cranelift_jit::JITModule;
use cranelift_object::ObjectModule; use cranelift_object::ObjectModule;
#[cfg(test)]
use proptest::arbitrary::Arbitrary;
use std::path::Path;
use target_lexicon::Triple; use target_lexicon::Triple;
impl Backend<JITModule> { impl Backend<JITModule> {
@@ -28,7 +31,8 @@ impl Backend<JITModule> {
let compiled_bytes = jitter.bytes(function_id); let compiled_bytes = jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function(); compiled_function();
Ok(jitter.output()) let output = jitter.output();
Ok(output)
} }
} }
@@ -116,7 +120,7 @@ proptest::proptest! {
// without error, assuming any possible input ... well, any possible input that // without error, assuming any possible input ... well, any possible input that
// doesn't involve overflow or underflow. // doesn't involve overflow or underflow.
#[test] #[test]
fn static_backend(program: Program) { fn static_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::PrimOpError; use crate::eval::PrimOpError;
let basic_result = program.eval(); let basic_result = program.eval();
@@ -127,8 +131,18 @@ proptest::proptest! {
let basic_result = basic_result.map(|x| x.replace('\n', "\r\n")); let basic_result = basic_result.map(|x| x.replace('\n', "\r\n"));
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) { if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
// use pretty::{DocAllocator, Pretty};
// let allocator = pretty::BoxAllocator;
// allocator
// .text("---------------")
// .append(allocator.hardline())
// .append(program.pretty(&allocator))
// .1
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
// .expect("rendering works");
let compiled_result = Backend::<ObjectModule>::eval(program); let compiled_result = Backend::<ObjectModule>::eval(program);
assert_eq!(basic_result, compiled_result); proptest::prop_assert_eq!(basic_result, compiled_result);
} }
} }
@@ -136,14 +150,24 @@ proptest::proptest! {
// without error, assuming any possible input ... well, any possible input that // without error, assuming any possible input ... well, any possible input that
// doesn't involve overflow or underflow. // doesn't involve overflow or underflow.
#[test] #[test]
fn jit_backend(program: Program) { fn jit_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::PrimOpError; use crate::eval::PrimOpError;
// use pretty::{DocAllocator, Pretty};
// let allocator = pretty::BoxAllocator;
// allocator
// .text("---------------")
// .append(allocator.hardline())
// .append(program.pretty(&allocator))
// .1
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
// .expect("rendering works");
let basic_result = program.eval(); let basic_result = program.eval();
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) { if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
let compiled_result = Backend::<JITModule>::eval(program); let compiled_result = Backend::<JITModule>::eval(program);
assert_eq!(basic_result, compiled_result); proptest::prop_assert_eq!(basic_result, compiled_result);
} }
} }
} }

View File

@@ -1,9 +1,11 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::ir::{Expression, Primitive, Program, Statement, Value, ValueOrRef}; use crate::eval::PrimitiveType;
use crate::ir::{Expression, Primitive, Program, Statement, Type, Value, ValueOrRef};
use crate::syntax::ConstantType;
use cranelift_codegen::entity::EntityRef; use cranelift_codegen::entity::EntityRef;
use cranelift_codegen::ir::{ use cranelift_codegen::ir::{
entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName, self, entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName,
}; };
use cranelift_codegen::isa::CallConv; use cranelift_codegen::isa::CallConv;
use cranelift_codegen::Context; use cranelift_codegen::Context;
@@ -41,7 +43,7 @@ impl<M: Module> Backend<M> {
let basic_signature = Signature { let basic_signature = Signature {
params: vec![], params: vec![],
returns: vec![], returns: vec![],
call_conv: CallConv::SystemV, call_conv: CallConv::triple_default(&self.platform),
}; };
// this generates the handle for the function that we'll eventually want to // this generates the handle for the function that we'll eventually want to
@@ -85,12 +87,12 @@ impl<M: Module> Backend<M> {
// Just like with strings, generating the `GlobalValue`s we need can potentially // Just like with strings, generating the `GlobalValue`s we need can potentially
// be a little tricky to do on the fly, so we generate the complete list right // be a little tricky to do on the fly, so we generate the complete list right
// here and then use it later. // here and then use it later.
let pre_defined_symbols: HashMap<String, GlobalValue> = self let pre_defined_symbols: HashMap<String, (GlobalValue, ConstantType)> = self
.defined_symbols .defined_symbols
.iter() .iter()
.map(|(k, v)| { .map(|(k, (v, t))| {
let local_data = self.module.declare_data_in_func(*v, &mut ctx.func); let local_data = self.module.declare_data_in_func(*v, &mut ctx.func);
(k.clone(), local_data) (k.clone(), (local_data, *t))
}) })
.collect(); .collect();
@@ -123,7 +125,7 @@ impl<M: Module> Backend<M> {
// Print statements are fairly easy to compile: we just lookup the // Print statements are fairly easy to compile: we just lookup the
// output buffer, the address of the string to print, and the value // output buffer, the address of the string to print, and the value
// of whatever variable we're printing. Then we just call print. // of whatever variable we're printing. Then we just call print.
Statement::Print(ann, var) => { Statement::Print(ann, t, var) => {
// Get the output buffer (or null) from our general compilation context. // Get the output buffer (or null) from our general compilation context.
let buffer_ptr = self.output_buffer_ptr(); let buffer_ptr = self.output_buffer_ptr();
let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64);
@@ -135,31 +137,47 @@ impl<M: Module> Backend<M> {
// Look up the value for the variable. Because this might be a // Look up the value for the variable. Because this might be a
// global variable (and that requires special logic), we just turn // global variable (and that requires special logic), we just turn
// this into an `Expression` and re-use the logic in that implementation. // this into an `Expression` and re-use the logic in that implementation.
let val = Expression::Reference(ann, var).into_crane( let (val, vtype) = ValueOrRef::Ref(ann, t, var).into_crane(
&mut builder, &mut builder,
&variable_table, &variable_table,
&pre_defined_symbols, &pre_defined_symbols,
)?; )?;
let vtype_repr = builder.ins().iconst(types::I64, vtype as i64);
let casted_val = match vtype {
ConstantType::U64 | ConstantType::I64 => val,
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 => {
builder.ins().sextend(types::I64, val)
}
ConstantType::U8 | ConstantType::U16 | ConstantType::U32 => {
builder.ins().uextend(types::I64, val)
}
};
// Finally, we can generate the call to print. // Finally, we can generate the call to print.
builder builder.ins().call(
.ins() print_func_ref,
.call(print_func_ref, &[buffer_ptr, name_ptr, val]); &[buffer_ptr, name_ptr, vtype_repr, casted_val],
);
} }
// Variable binding is a little more con // Variable binding is a little more con
Statement::Binding(_, var_name, value) => { Statement::Binding(_, var_name, _, value) => {
// Kick off to the `Expression` implementation to see what value we're going // Kick off to the `Expression` implementation to see what value we're going
// to bind to this variable. // to bind to this variable.
let val = let (val, etype) =
value.into_crane(&mut builder, &variable_table, &pre_defined_symbols)?; value.into_crane(&mut builder, &variable_table, &pre_defined_symbols)?;
// Now the question is: is this a local variable, or a global one? // Now the question is: is this a local variable, or a global one?
if let Some(global_id) = pre_defined_symbols.get(var_name.as_str()) { if let Some((global_id, ctype)) = pre_defined_symbols.get(var_name.as_str()) {
// It's a global variable! In this case, we assume that someone has already // It's a global variable! In this case, we assume that someone has already
// dedicated some space in memory to store this value. We look this location // dedicated some space in memory to store this value. We look this location
// up, and then tell Cranelift to store the value there. // up, and then tell Cranelift to store the value there.
let val_ptr = builder.ins().symbol_value(types::I64, *global_id); assert_eq!(etype, *ctype);
let val_ptr = builder
.ins()
.symbol_value(ir::Type::from(*ctype), *global_id);
builder.ins().store(MemFlags::new(), val, val_ptr, 0); builder.ins().store(MemFlags::new(), val, val_ptr, 0);
} else { } else {
// It's a local variable! In this case, we need to allocate a new Cranelift // It's a local variable! In this case, we need to allocate a new Cranelift
@@ -171,12 +189,10 @@ impl<M: Module> Backend<M> {
next_var_num += 1; next_var_num += 1;
// We can add the variable directly to our local variable map; it's `Copy`. // We can add the variable directly to our local variable map; it's `Copy`.
variable_table.insert(var_name, var); variable_table.insert(var_name, (var, etype));
// Now we tell Cranelift about our new variable, which has type I64 because // Now we tell Cranelift about our new variable!
// everything we have at this point is of type I64. Once it's declare, we builder.declare_var(var, ir::Type::from(etype));
// define it as having the value we computed above.
builder.declare_var(var, types::I64);
builder.def_var(var, val); builder.def_var(var, val);
} }
} }
@@ -195,7 +211,7 @@ impl<M: Module> Backend<M> {
// so we register it using the function ID and our builder context. However, the // so we register it using the function ID and our builder context. However, the
// result of this function isn't actually super helpful. So we ignore it, unless // result of this function isn't actually super helpful. So we ignore it, unless
// it's an error. // it's an error.
let _ = self.module.define_function(func_id, &mut ctx)?; self.module.define_function(func_id, &mut ctx)?;
// done! // done!
Ok(func_id) Ok(func_id)
@@ -231,54 +247,110 @@ impl Expression {
fn into_crane( fn into_crane(
self, self,
builder: &mut FunctionBuilder, builder: &mut FunctionBuilder,
local_variables: &HashMap<ArcIntern<String>, Variable>, local_variables: &HashMap<ArcIntern<String>, (Variable, ConstantType)>,
global_variables: &HashMap<String, GlobalValue>, global_variables: &HashMap<String, (GlobalValue, ConstantType)>,
) -> Result<entities::Value, BackendError> { ) -> Result<(entities::Value, ConstantType), BackendError> {
match self { match self {
// Values are pretty straightforward to compile, mostly because we only Expression::Atomic(x) => x.into_crane(builder, local_variables, global_variables),
// have one type of variable, and it's an integer type.
Expression::Value(_, Value::Number(_, v)) => Ok(builder.ins().iconst(types::I64, v)),
Expression::Reference(_, name) => { Expression::Cast(_, target_type, expr) => {
// first we see if this is a local variable (which is nicer, from an let (val, val_type) =
// optimization point of view.) expr.into_crane(builder, local_variables, global_variables)?;
if let Some(local_var) = local_variables.get(&name) {
return Ok(builder.use_var(*local_var)); match (val_type, &target_type) {
(ConstantType::I8, Type::Primitive(PrimitiveType::I8)) => Ok((val, val_type)),
(ConstantType::I8, Type::Primitive(PrimitiveType::I16)) => {
Ok((builder.ins().sextend(types::I16, val), ConstantType::I16))
}
(ConstantType::I8, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().sextend(types::I32, val), ConstantType::I32))
}
(ConstantType::I8, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
} }
// then we check to see if this is a global reference, which requires us to (ConstantType::I16, Type::Primitive(PrimitiveType::I16)) => Ok((val, val_type)),
// first lookup where the value is stored, and then load it. (ConstantType::I16, Type::Primitive(PrimitiveType::I32)) => {
if let Some(global_var) = global_variables.get(name.as_ref()) { Ok((builder.ins().sextend(types::I32, val), ConstantType::I32))
let val_ptr = builder.ins().symbol_value(types::I64, *global_var); }
return Ok(builder.ins().load(types::I64, MemFlags::new(), val_ptr, 0)); (ConstantType::I16, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
} }
// this should never happen, because we should have made sure that there are (ConstantType::I32, Type::Primitive(PrimitiveType::I32)) => Ok((val, val_type)),
// no unbound variables a long time before this. but still ... (ConstantType::I32, Type::Primitive(PrimitiveType::I64)) => {
Err(BackendError::VariableLookupFailure(name)) Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
} }
Expression::Primitive(_, prim, mut vals) => { (ConstantType::I64, Type::Primitive(PrimitiveType::I64)) => Ok((val, val_type)),
// we're going to use `pop`, so we're going to pull and compile the right value ...
let right = (ConstantType::U8, Type::Primitive(PrimitiveType::U8)) => Ok((val, val_type)),
vals.pop() (ConstantType::U8, Type::Primitive(PrimitiveType::U16)) => {
.unwrap() Ok((builder.ins().uextend(types::I16, val), ConstantType::U16))
.into_crane(builder, local_variables, global_variables)?; }
// ... and then the left. (ConstantType::U8, Type::Primitive(PrimitiveType::U32)) => {
let left = Ok((builder.ins().uextend(types::I32, val), ConstantType::U32))
vals.pop() }
.unwrap() (ConstantType::U8, Type::Primitive(PrimitiveType::U64)) => {
.into_crane(builder, local_variables, global_variables)?; Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::U16)) => Ok((val, val_type)),
(ConstantType::U16, Type::Primitive(PrimitiveType::U32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::U32))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U32, Type::Primitive(PrimitiveType::U32)) => Ok((val, val_type)),
(ConstantType::U32, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)),
_ => Err(BackendError::InvalidTypeCast {
from: val_type.into(),
to: target_type,
}),
}
}
Expression::Primitive(_, _, prim, mut vals) => {
let mut values = vec![];
let mut first_type = None;
for val in vals.drain(..) {
let (compiled, compiled_type) =
val.into_crane(builder, local_variables, global_variables)?;
if let Some(leftmost_type) = first_type {
assert_eq!(leftmost_type, compiled_type);
} else {
first_type = Some(compiled_type);
}
values.push(compiled);
}
let first_type = first_type.expect("primitive op has at least one argument");
// then we just need to tell Cranelift how to do each of our primitives! Much // then we just need to tell Cranelift how to do each of our primitives! Much
// like Statements, above, we probably want to eventually shuffle this off into // like Statements, above, we probably want to eventually shuffle this off into
// a separate function (maybe something off `Primitive`), but for now it's simple // a separate function (maybe something off `Primitive`), but for now it's simple
// enough that we just do the `match` here. // enough that we just do the `match` here.
match prim { match prim {
Primitive::Plus => Ok(builder.ins().iadd(left, right)), Primitive::Plus => Ok((builder.ins().iadd(values[0], values[1]), first_type)),
Primitive::Minus => Ok(builder.ins().isub(left, right)), Primitive::Minus if values.len() == 2 => {
Primitive::Times => Ok(builder.ins().imul(left, right)), Ok((builder.ins().isub(values[0], values[1]), first_type))
Primitive::Divide => Ok(builder.ins().sdiv(left, right)), }
Primitive::Minus => Ok((builder.ins().ineg(values[0]), first_type)),
Primitive::Times => Ok((builder.ins().imul(values[0], values[1]), first_type)),
Primitive::Divide if first_type.is_signed() => {
Ok((builder.ins().sdiv(values[0], values[1]), first_type))
}
Primitive::Divide => Ok((builder.ins().udiv(values[0], values[1]), first_type)),
} }
} }
} }
@@ -291,9 +363,66 @@ impl ValueOrRef {
fn into_crane( fn into_crane(
self, self,
builder: &mut FunctionBuilder, builder: &mut FunctionBuilder,
local_variables: &HashMap<ArcIntern<String>, Variable>, local_variables: &HashMap<ArcIntern<String>, (Variable, ConstantType)>,
global_variables: &HashMap<String, GlobalValue>, global_variables: &HashMap<String, (GlobalValue, ConstantType)>,
) -> Result<entities::Value, BackendError> { ) -> Result<(entities::Value, ConstantType), BackendError> {
Expression::from(self).into_crane(builder, local_variables, global_variables) match self {
// Values are pretty straightforward to compile, mostly because we only
// have one type of variable, and it's an integer type.
ValueOrRef::Value(_, _, val) => match val {
Value::I8(_, v) => {
Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::I8))
}
Value::I16(_, v) => Ok((
builder.ins().iconst(types::I16, v as i64),
ConstantType::I16,
)),
Value::I32(_, v) => Ok((
builder.ins().iconst(types::I32, v as i64),
ConstantType::I32,
)),
Value::I64(_, v) => Ok((builder.ins().iconst(types::I64, v), ConstantType::I64)),
Value::U8(_, v) => {
Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::U8))
}
Value::U16(_, v) => Ok((
builder.ins().iconst(types::I16, v as i64),
ConstantType::U16,
)),
Value::U32(_, v) => Ok((
builder.ins().iconst(types::I32, v as i64),
ConstantType::U32,
)),
Value::U64(_, v) => Ok((
builder.ins().iconst(types::I64, v as i64),
ConstantType::U64,
)),
},
ValueOrRef::Ref(_, _, name) => {
// first we see if this is a local variable (which is nicer, from an
// optimization point of view.)
if let Some((local_var, etype)) = local_variables.get(&name) {
return Ok((builder.use_var(*local_var), *etype));
}
// then we check to see if this is a global reference, which requires us to
// first lookup where the value is stored, and then load it.
if let Some((global_var, etype)) = global_variables.get(name.as_ref()) {
let cranelift_type = ir::Type::from(*etype);
let val_ptr = builder.ins().symbol_value(cranelift_type, *global_var);
return Ok((
builder
.ins()
.load(cranelift_type, MemFlags::new(), val_ptr, 0),
*etype,
));
}
// this should never happen, because we should have made sure that there are
// no unbound variables a long time before this. but still ...
Err(BackendError::VariableLookupFailure(name))
}
}
} }
} }

View File

@@ -8,6 +8,8 @@ use std::fmt::Write;
use target_lexicon::Triple; use target_lexicon::Triple;
use thiserror::Error; use thiserror::Error;
use crate::syntax::ConstantType;
/// An object for querying / using functions built into the runtime. /// An object for querying / using functions built into the runtime.
/// ///
/// Right now, this is a quite a bit of boilerplate for very nebulous /// Right now, this is a quite a bit of boilerplate for very nebulous
@@ -49,7 +51,7 @@ impl RuntimeFunctions {
"print", "print",
Linkage::Import, Linkage::Import,
&Signature { &Signature {
params: vec![string_param, string_param, int64_param], params: vec![string_param, string_param, int64_param, int64_param],
returns: vec![], returns: vec![],
call_conv: CallConv::triple_default(platform), call_conv: CallConv::triple_default(platform),
}, },
@@ -98,13 +100,30 @@ impl RuntimeFunctions {
// we extend with the output, so that multiple JIT'd `Program`s can run concurrently // we extend with the output, so that multiple JIT'd `Program`s can run concurrently
// without stomping over each other's output. If `output_buffer` is NULL, we just print // without stomping over each other's output. If `output_buffer` is NULL, we just print
// to stdout. // to stdout.
extern "C" fn runtime_print(output_buffer: *mut String, name: *const i8, value: i64) { extern "C" fn runtime_print(
output_buffer: *mut String,
name: *const i8,
vtype_repr: i64,
value: i64,
) {
let cstr = unsafe { CStr::from_ptr(name) }; let cstr = unsafe { CStr::from_ptr(name) };
let reconstituted = cstr.to_string_lossy(); let reconstituted = cstr.to_string_lossy();
let output = match vtype_repr.try_into() {
Ok(ConstantType::I8) => format!("{} = {}i8", reconstituted, value as i8),
Ok(ConstantType::I16) => format!("{} = {}i16", reconstituted, value as i16),
Ok(ConstantType::I32) => format!("{} = {}i32", reconstituted, value as i32),
Ok(ConstantType::I64) => format!("{} = {}i64", reconstituted, value),
Ok(ConstantType::U8) => format!("{} = {}u8", reconstituted, value as u8),
Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16),
Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32),
Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64),
Err(_) => format!("{} = {}<unknown type>", reconstituted, value),
};
if let Some(output_buffer) = unsafe { output_buffer.as_mut() } { if let Some(output_buffer) = unsafe { output_buffer.as_mut() } {
writeln!(output_buffer, "{} = {}i64", reconstituted, value).unwrap(); writeln!(output_buffer, "{}", output).unwrap();
} else { } else {
println!("{} = {}", reconstituted, value); println!("{}", output);
} }
} }

View File

@@ -17,7 +17,7 @@ fn main() {
let args = CommandLineArguments::parse(); let args = CommandLineArguments::parse();
let mut compiler = ngr::Compiler::default(); let mut compiler = ngr::Compiler::default();
let output_file = args.output.unwrap_or("output.o".to_string()); let output_file = args.output.unwrap_or_else(|| "output.o".to_string());
if let Some(bytes) = compiler.compile(&args.file) { if let Some(bytes) = compiler.compile(&args.file) {
std::fs::write(&output_file, bytes) std::fs::write(&output_file, bytes)

View File

@@ -1,6 +1,5 @@
use crate::backend::Backend;
use crate::ir::Program as IR;
use crate::syntax::Program as Syntax; use crate::syntax::Program as Syntax;
use crate::{backend::Backend, type_infer::TypeInferenceResult};
use codespan_reporting::{ use codespan_reporting::{
diagnostic::Diagnostic, diagnostic::Diagnostic,
files::SimpleFiles, files::SimpleFiles,
@@ -100,8 +99,38 @@ impl Compiler {
return Ok(None); return Ok(None);
} }
// Now that we've validated it, turn it into IR. // Now that we've validated it, let's do type inference, potentially turning
let ir = IR::from(syntax); // into IR while we're at it.
let ir = match syntax.type_infer() {
TypeInferenceResult::Failure {
mut errors,
mut warnings,
} => {
let messages = errors
.drain(..)
.map(Into::into)
.chain(warnings.drain(..).map(Into::into));
for message in messages {
self.emit(message);
}
return Ok(None);
}
TypeInferenceResult::Success {
result,
mut warnings,
} => {
let messages = warnings.drain(..).map(Into::into);
for message in messages {
self.emit(message);
}
result
}
};
// Finally, send all this to Cranelift for conversion into an object file. // Finally, send all this to Cranelift for conversion into an object file.
let mut backend = Backend::object_file(Triple::host())?; let mut backend = Backend::object_file(Triple::host())?;

View File

@@ -35,11 +35,13 @@
//! //!
mod env; mod env;
mod primop; mod primop;
mod primtype;
mod value; mod value;
use cranelift_module::ModuleError; use cranelift_module::ModuleError;
pub use env::{EvalEnvironment, LookupError}; pub use env::{EvalEnvironment, LookupError};
pub use primop::PrimOpError; pub use primop::PrimOpError;
pub use primtype::PrimitiveType;
pub use value::Value; pub use value::Value;
use crate::backend::BackendError; use crate::backend::BackendError;

View File

@@ -87,9 +87,9 @@ mod tests {
let tester = tester.extend(arced("bar"), 2i64.into()); let tester = tester.extend(arced("bar"), 2i64.into());
let tester = tester.extend(arced("goo"), 5i64.into()); let tester = tester.extend(arced("goo"), 5i64.into());
assert_eq!(tester.lookup(arced("foo")), Ok(1.into())); assert_eq!(tester.lookup(arced("foo")), Ok(1i64.into()));
assert_eq!(tester.lookup(arced("bar")), Ok(2.into())); assert_eq!(tester.lookup(arced("bar")), Ok(2i64.into()));
assert_eq!(tester.lookup(arced("goo")), Ok(5.into())); assert_eq!(tester.lookup(arced("goo")), Ok(5i64.into()));
assert!(tester.lookup(arced("baz")).is_err()); assert!(tester.lookup(arced("baz")).is_err());
} }
@@ -103,14 +103,14 @@ mod tests {
check_nested(&tester); check_nested(&tester);
assert_eq!(tester.lookup(arced("foo")), Ok(1.into())); assert_eq!(tester.lookup(arced("foo")), Ok(1i64.into()));
assert!(tester.lookup(arced("bar")).is_err()); assert!(tester.lookup(arced("bar")).is_err());
} }
fn check_nested(env: &EvalEnvironment) { fn check_nested(env: &EvalEnvironment) {
let nested_env = env.extend(arced("bar"), 2i64.into()); let nested_env = env.extend(arced("bar"), 2i64.into());
assert_eq!(nested_env.lookup(arced("foo")), Ok(1.into())); assert_eq!(nested_env.lookup(arced("foo")), Ok(1i64.into()));
assert_eq!(nested_env.lookup(arced("bar")), Ok(2.into())); assert_eq!(nested_env.lookup(arced("bar")), Ok(2i64.into()));
} }
fn arced(s: &str) -> ArcIntern<String> { fn arced(s: &str) -> ArcIntern<String> {

View File

@@ -1,3 +1,4 @@
use crate::eval::primtype::PrimitiveType;
use crate::eval::value::Value; use crate::eval::value::Value;
/// Errors that can occur running primitive operations in the evaluators. /// Errors that can occur running primitive operations in the evaluators.
@@ -22,6 +23,13 @@ pub enum PrimOpError {
BadArgCount(String, usize), BadArgCount(String, usize),
#[error("Unknown primitive operation {0}")] #[error("Unknown primitive operation {0}")]
UnknownPrimOp(String), UnknownPrimOp(String),
#[error("Unsafe cast from {from} to {to}")]
UnsafeCast {
from: PrimitiveType,
to: PrimitiveType,
},
#[error("Unknown primitive type {0}")]
UnknownPrimType(String),
} }
// Implementing primitives in an interpreter like this is *super* tedious, // Implementing primitives in an interpreter like this is *super* tedious,
@@ -37,39 +45,95 @@ pub enum PrimOpError {
macro_rules! run_op { macro_rules! run_op {
($op: ident, $left: expr, $right: expr) => { ($op: ident, $left: expr, $right: expr) => {
match $op { match $op {
"+" => $left "+" => Ok($left.wrapping_add($right).into()),
.checked_add($right) "-" => Ok($left.wrapping_sub($right).into()),
.ok_or(PrimOpError::MathFailure("+")) "*" => Ok($left.wrapping_mul($right).into()),
.map(Into::into), "/" if $right == 0 => Err(PrimOpError::MathFailure("/")),
"-" => $left "/" => Ok($left.wrapping_div($right).into()),
.checked_sub($right)
.ok_or(PrimOpError::MathFailure("-"))
.map(Into::into),
"*" => $left
.checked_mul($right)
.ok_or(PrimOpError::MathFailure("*"))
.map(Into::into),
"/" => $left
.checked_div($right)
.ok_or(PrimOpError::MathFailure("/"))
.map(Into::into),
_ => Err(PrimOpError::UnknownPrimOp($op.to_string())), _ => Err(PrimOpError::UnknownPrimOp($op.to_string())),
} }
}; };
} }
impl Value { impl Value {
fn unary_op(operation: &str, value: &Value) -> Result<Value, PrimOpError> {
match operation {
"-" => match value {
Value::I8(x) => Ok(Value::I8(x.wrapping_neg())),
Value::I16(x) => Ok(Value::I16(x.wrapping_neg())),
Value::I32(x) => Ok(Value::I32(x.wrapping_neg())),
Value::I64(x) => Ok(Value::I64(x.wrapping_neg())),
_ => Err(PrimOpError::BadTypeFor("-", value.clone())),
},
_ => Err(PrimOpError::BadArgCount(operation.to_owned(), 1)),
}
}
fn binary_op(operation: &str, left: &Value, right: &Value) -> Result<Value, PrimOpError> { fn binary_op(operation: &str, left: &Value, right: &Value) -> Result<Value, PrimOpError> {
match left { match left {
// for now we only have one type, but in the future this is Value::I8(x) => match right {
// going to be very irritating. Value::I8(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::I16(x) => match right {
Value::I16(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::I32(x) => match right {
Value::I32(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::I64(x) => match right { Value::I64(x) => match right {
Value::I64(y) => run_op!(operation, x, *y), Value::I64(y) => run_op!(operation, x, *y),
// _ => Err(PrimOpError::TypeMismatch( _ => Err(PrimOpError::TypeMismatch(
// operation.to_string(), operation.to_string(),
// left.clone(), left.clone(),
// right.clone(), right.clone(),
// )), )),
},
Value::U8(x) => match right {
Value::U8(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::U16(x) => match right {
Value::U16(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::U32(x) => match right {
Value::U32(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::U64(x) => match right {
Value::U64(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
}, },
} }
} }
@@ -83,13 +147,10 @@ impl Value {
/// its worth being careful to make sure that your inputs won't cause either /// its worth being careful to make sure that your inputs won't cause either
/// condition. /// condition.
pub fn calculate(operation: &str, values: Vec<Value>) -> Result<Value, PrimOpError> { pub fn calculate(operation: &str, values: Vec<Value>) -> Result<Value, PrimOpError> {
if values.len() == 2 { match values.len() {
Value::binary_op(operation, &values[0], &values[1]) 1 => Value::unary_op(operation, &values[0]),
} else { 2 => Value::binary_op(operation, &values[0], &values[1]),
Err(PrimOpError::BadArgCount( x => Err(PrimOpError::BadArgCount(operation.to_string(), x)),
operation.to_string(),
values.len(),
))
} }
} }
} }

173
src/eval/primtype.rs Normal file
View File

@@ -0,0 +1,173 @@
use crate::{
eval::{PrimOpError, Value},
syntax::ConstantType,
};
use std::{fmt::Display, str::FromStr};
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PrimitiveType {
U8,
U16,
U32,
U64,
I8,
I16,
I32,
I64,
}
impl Display for PrimitiveType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PrimitiveType::I8 => write!(f, "i8"),
PrimitiveType::I16 => write!(f, "i16"),
PrimitiveType::I32 => write!(f, "i32"),
PrimitiveType::I64 => write!(f, "i64"),
PrimitiveType::U8 => write!(f, "u8"),
PrimitiveType::U16 => write!(f, "u16"),
PrimitiveType::U32 => write!(f, "u32"),
PrimitiveType::U64 => write!(f, "u64"),
}
}
}
impl<'a> From<&'a Value> for PrimitiveType {
fn from(value: &Value) -> Self {
match value {
Value::I8(_) => PrimitiveType::I8,
Value::I16(_) => PrimitiveType::I16,
Value::I32(_) => PrimitiveType::I32,
Value::I64(_) => PrimitiveType::I64,
Value::U8(_) => PrimitiveType::U8,
Value::U16(_) => PrimitiveType::U16,
Value::U32(_) => PrimitiveType::U32,
Value::U64(_) => PrimitiveType::U64,
}
}
}
impl From<ConstantType> for PrimitiveType {
fn from(value: ConstantType) -> Self {
match value {
ConstantType::I8 => PrimitiveType::I8,
ConstantType::I16 => PrimitiveType::I16,
ConstantType::I32 => PrimitiveType::I32,
ConstantType::I64 => PrimitiveType::I64,
ConstantType::U8 => PrimitiveType::U8,
ConstantType::U16 => PrimitiveType::U16,
ConstantType::U32 => PrimitiveType::U32,
ConstantType::U64 => PrimitiveType::U64,
}
}
}
impl FromStr for PrimitiveType {
type Err = PrimOpError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"i8" => Ok(PrimitiveType::I8),
"i16" => Ok(PrimitiveType::I16),
"i32" => Ok(PrimitiveType::I32),
"i64" => Ok(PrimitiveType::I64),
"u8" => Ok(PrimitiveType::U8),
"u16" => Ok(PrimitiveType::U16),
"u32" => Ok(PrimitiveType::U32),
"u64" => Ok(PrimitiveType::U64),
_ => Err(PrimOpError::UnknownPrimType(s.to_string())),
}
}
}
impl PrimitiveType {
/// Return true if this type can be safely cast into the target type.
pub fn can_cast_to(&self, target: &PrimitiveType) -> bool {
match self {
PrimitiveType::U8 => matches!(
target,
PrimitiveType::U8
| PrimitiveType::U16
| PrimitiveType::U32
| PrimitiveType::U64
| PrimitiveType::I16
| PrimitiveType::I32
| PrimitiveType::I64
),
PrimitiveType::U16 => matches!(
target,
PrimitiveType::U16
| PrimitiveType::U32
| PrimitiveType::U64
| PrimitiveType::I32
| PrimitiveType::I64
),
PrimitiveType::U32 => matches!(
target,
PrimitiveType::U32 | PrimitiveType::U64 | PrimitiveType::I64
),
PrimitiveType::U64 => target == &PrimitiveType::U64,
PrimitiveType::I8 => matches!(
target,
PrimitiveType::I8 | PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64
),
PrimitiveType::I16 => matches!(
target,
PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64
),
PrimitiveType::I32 => matches!(target, PrimitiveType::I32 | PrimitiveType::I64),
PrimitiveType::I64 => target == &PrimitiveType::I64,
}
}
/// Try to cast the given value to this type, returning the new value.
///
/// Returns an error if the cast is not safe *in* *general*. This means that
/// this function will error even if the number will actually fit in the target
/// type, but it would not be generally safe to cast a member of the given
/// type to the target type. (So, for example, "1i64" is a number that could
/// work as a "u64", but since negative numbers wouldn't work, a cast from
/// "1i64" to "u64" will fail.)
pub fn safe_cast(&self, source: &Value) -> Result<Value, PrimOpError> {
match (self, source) {
(PrimitiveType::U8, Value::U8(x)) => Ok(Value::U8(*x)),
(PrimitiveType::U16, Value::U8(x)) => Ok(Value::U16(*x as u16)),
(PrimitiveType::U16, Value::U16(x)) => Ok(Value::U16(*x)),
(PrimitiveType::U32, Value::U8(x)) => Ok(Value::U32(*x as u32)),
(PrimitiveType::U32, Value::U16(x)) => Ok(Value::U32(*x as u32)),
(PrimitiveType::U32, Value::U32(x)) => Ok(Value::U32(*x)),
(PrimitiveType::U64, Value::U8(x)) => Ok(Value::U64(*x as u64)),
(PrimitiveType::U64, Value::U16(x)) => Ok(Value::U64(*x as u64)),
(PrimitiveType::U64, Value::U32(x)) => Ok(Value::U64(*x as u64)),
(PrimitiveType::U64, Value::U64(x)) => Ok(Value::U64(*x)),
(PrimitiveType::I8, Value::I8(x)) => Ok(Value::I8(*x)),
(PrimitiveType::I16, Value::I8(x)) => Ok(Value::I16(*x as i16)),
(PrimitiveType::I16, Value::I16(x)) => Ok(Value::I16(*x)),
(PrimitiveType::I32, Value::I8(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I32, Value::I16(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I32, Value::I32(x)) => Ok(Value::I32(*x)),
(PrimitiveType::I64, Value::I8(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I16(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)),
_ => Err(PrimOpError::UnsafeCast {
from: source.into(),
to: *self,
}),
}
}
pub fn max_value(&self) -> u64 {
match self {
PrimitiveType::U8 => u8::MAX as u64,
PrimitiveType::U16 => u16::MAX as u64,
PrimitiveType::U32 => u32::MAX as u64,
PrimitiveType::U64 => u64::MAX,
PrimitiveType::I8 => i8::MAX as u64,
PrimitiveType::I16 => i16::MAX as u64,
PrimitiveType::I32 => i32::MAX as u64,
PrimitiveType::I64 => i64::MAX as u64,
}
}
}

View File

@@ -7,19 +7,75 @@ use std::fmt::Display;
/// by type so that we don't mix them up. /// by type so that we don't mix them up.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum Value { pub enum Value {
I8(i8),
I16(i16),
I32(i32),
I64(i64), I64(i64),
U8(u8),
U16(u16),
U32(u32),
U64(u64),
} }
impl Display for Value { impl Display for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Value::I8(x) => write!(f, "{}i8", x),
Value::I16(x) => write!(f, "{}i16", x),
Value::I32(x) => write!(f, "{}i32", x),
Value::I64(x) => write!(f, "{}i64", x), Value::I64(x) => write!(f, "{}i64", x),
Value::U8(x) => write!(f, "{}u8", x),
Value::U16(x) => write!(f, "{}u16", x),
Value::U32(x) => write!(f, "{}u32", x),
Value::U64(x) => write!(f, "{}u64", x),
} }
} }
} }
impl From<i8> for Value {
fn from(value: i8) -> Self {
Value::I8(value)
}
}
impl From<i16> for Value {
fn from(value: i16) -> Self {
Value::I16(value)
}
}
impl From<i32> for Value {
fn from(value: i32) -> Self {
Value::I32(value)
}
}
impl From<i64> for Value { impl From<i64> for Value {
fn from(value: i64) -> Self { fn from(value: i64) -> Self {
Value::I64(value) Value::I64(value)
} }
} }
impl From<u8> for Value {
fn from(value: u8) -> Self {
Value::U8(value)
}
}
impl From<u16> for Value {
fn from(value: u16) -> Self {
Value::U16(value)
}
}
impl From<u32> for Value {
fn from(value: u32) -> Self {
Value::U32(value)
}
}
impl From<u64> for Value {
fn from(value: u64) -> Self {
Value::U64(value)
}
}

6
src/examples.rs Normal file
View File

@@ -0,0 +1,6 @@
use crate::backend::Backend;
use crate::syntax::Program as Syntax;
use codespan_reporting::files::SimpleFiles;
use cranelift_jit::JITModule;
include!(concat!(env!("OUT_DIR"), "/examples.rs"));

View File

@@ -12,9 +12,8 @@
//! validating syntax, and then figuring out how to turn it into Cranelift //! validating syntax, and then figuring out how to turn it into Cranelift
//! and object code. After that point, however, this will be the module to //! and object code. After that point, however, this will be the module to
//! come to for analysis and optimization work. //! come to for analysis and optimization work.
mod ast; pub mod ast;
mod eval; mod eval;
mod from_syntax;
mod strings; mod strings;
pub use ast::*; pub use ast::*;

View File

@@ -1,10 +1,14 @@
use crate::syntax::Location; use crate::{
eval::PrimitiveType,
syntax::{self, ConstantType, Location},
};
use internment::ArcIntern; use internment::ArcIntern;
use pretty::{DocAllocator, Pretty}; use pretty::{BoxAllocator, DocAllocator, Pretty};
use proptest::{ use proptest::{
prelude::Arbitrary, prelude::Arbitrary,
strategy::{BoxedStrategy, Strategy}, strategy::{BoxedStrategy, Strategy},
}; };
use std::{fmt, str::FromStr};
/// We're going to represent variables as interned strings. /// We're going to represent variables as interned strings.
/// ///
@@ -52,12 +56,15 @@ where
} }
impl Arbitrary for Program { impl Arbitrary for Program {
type Parameters = (); type Parameters = crate::syntax::arbitrary::GenerationEnvironment;
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
crate::syntax::Program::arbitrary_with(args) crate::syntax::Program::arbitrary_with(args)
.prop_map(Program::from) .prop_map(|x| {
x.type_infer()
.expect("arbitrary_with should generate type-correct programs")
})
.boxed() .boxed()
} }
} }
@@ -74,8 +81,8 @@ impl Arbitrary for Program {
/// ///
#[derive(Debug)] #[derive(Debug)]
pub enum Statement { pub enum Statement {
Binding(Location, Variable, Expression), Binding(Location, Variable, Type, Expression),
Print(Location, Variable), Print(Location, Type, Variable),
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
@@ -85,13 +92,13 @@ where
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { match self {
Statement::Binding(_, var, expr) => allocator Statement::Binding(_, var, _, expr) => allocator
.text(var.as_ref().to_string()) .text(var.as_ref().to_string())
.append(allocator.space()) .append(allocator.space())
.append(allocator.text("=")) .append(allocator.text("="))
.append(allocator.space()) .append(allocator.space())
.append(expr.pretty(allocator)), .append(expr.pretty(allocator)),
Statement::Print(_, var) => allocator Statement::Print(_, _, var) => allocator
.text("print") .text("print")
.append(allocator.space()) .append(allocator.space())
.append(allocator.text(var.as_ref().to_string())), .append(allocator.text(var.as_ref().to_string())),
@@ -113,9 +120,32 @@ where
/// variable reference. /// variable reference.
#[derive(Debug)] #[derive(Debug)]
pub enum Expression { pub enum Expression {
Value(Location, Value), Atomic(ValueOrRef),
Reference(Location, Variable), Cast(Location, Type, ValueOrRef),
Primitive(Location, Primitive, Vec<ValueOrRef>), Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}
impl Expression {
/// Return a reference to the type of the expression, as inferred or recently
/// computed.
pub fn type_of(&self) -> &Type {
match self {
Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t,
Expression::Atomic(ValueOrRef::Value(_, t, _)) => t,
Expression::Cast(_, t, _) => t,
Expression::Primitive(_, t, _, _) => t,
}
}
/// Return a reference to the location associated with the expression.
pub fn location(&self) -> &Location {
match self {
Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l,
}
}
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
@@ -125,12 +155,16 @@ where
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { match self {
Expression::Value(_, val) => val.pretty(allocator), Expression::Atomic(x) => x.pretty(allocator),
Expression::Reference(_, var) => allocator.text(var.as_ref().to_string()), Expression::Cast(_, t, e) => allocator
Expression::Primitive(_, op, exprs) if exprs.len() == 1 => { .text("<")
.append(t.pretty(allocator))
.append(allocator.text(">"))
.append(e.pretty(allocator)),
Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator)) op.pretty(allocator).append(exprs[0].pretty(allocator))
} }
Expression::Primitive(_, op, exprs) if exprs.len() == 2 => { Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator); let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator); let right = exprs[1].pretty(allocator);
@@ -140,7 +174,7 @@ where
.append(right) .append(right)
.parens() .parens()
} }
Expression::Primitive(_, op, exprs) => { Expression::Primitive(_, _, op, exprs) => {
allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
} }
} }
@@ -161,10 +195,10 @@ pub enum Primitive {
Divide, Divide,
} }
impl<'a> TryFrom<&'a str> for Primitive { impl FromStr for Primitive {
type Error = String; type Err = String;
fn try_from(value: &str) -> Result<Self, Self::Error> { fn from_str(value: &str) -> Result<Self, Self::Err> {
match value { match value {
"+" => Ok(Primitive::Plus), "+" => Ok(Primitive::Plus),
"-" => Ok(Primitive::Minus), "-" => Ok(Primitive::Minus),
@@ -190,15 +224,21 @@ where
} }
} }
impl fmt::Display for Primitive {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<&Primitive as Pretty<'_, BoxAllocator, ()>>::pretty(self, &BoxAllocator).render_fmt(72, f)
}
}
/// An expression that is always either a value or a reference. /// An expression that is always either a value or a reference.
/// ///
/// This is the type used to guarantee that we don't nest expressions /// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one /// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference. /// of these, which can only be a constant or a reference.
#[derive(Debug)] #[derive(Clone, Debug)]
pub enum ValueOrRef { pub enum ValueOrRef {
Value(Location, Value), Value(Location, Type, Value),
Ref(Location, ArcIntern<String>), Ref(Location, Type, ArcIntern<String>),
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
@@ -208,30 +248,50 @@ where
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { match self {
ValueOrRef::Value(_, v) => v.pretty(allocator), ValueOrRef::Value(_, _, v) => v.pretty(allocator),
ValueOrRef::Ref(_, v) => allocator.text(v.as_ref().to_string()), ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()),
} }
} }
} }
impl From<ValueOrRef> for Expression { impl From<ValueOrRef> for Expression {
fn from(value: ValueOrRef) -> Self { fn from(value: ValueOrRef) -> Self {
match value { Expression::Atomic(value)
ValueOrRef::Value(loc, val) => Expression::Value(loc, val),
ValueOrRef::Ref(loc, var) => Expression::Reference(loc, var),
}
} }
} }
/// A constant in the IR. /// A constant in the IR.
#[derive(Debug)]
pub enum Value {
/// A numerical constant.
/// ///
/// The optional argument is the base that was used by the user to input /// The optional argument in numeric types is the base that was used by the
/// the number. By retaining it, we can ensure that if we need to print the /// user to input the number. By retaining it, we can ensure that if we need
/// number back out, we can do so in the form that the user entered it. /// to print the number back out, we can do so in the form that the user
Number(Option<u8>, i64), /// entered it.
#[derive(Clone, Debug)]
pub enum Value {
I8(Option<u8>, i8),
I16(Option<u8>, i16),
I32(Option<u8>, i32),
I64(Option<u8>, i64),
U8(Option<u8>, u8),
U16(Option<u8>, u16),
U32(Option<u8>, u32),
U64(Option<u8>, u64),
}
impl Value {
/// Return the type described by this value
pub fn type_of(&self) -> Type {
match self {
Value::I8(_, _) => Type::Primitive(PrimitiveType::I8),
Value::I16(_, _) => Type::Primitive(PrimitiveType::I16),
Value::I32(_, _) => Type::Primitive(PrimitiveType::I32),
Value::I64(_, _) => Type::Primitive(PrimitiveType::I64),
Value::U8(_, _) => Type::Primitive(PrimitiveType::U8),
Value::U16(_, _) => Type::Primitive(PrimitiveType::U16),
Value::U32(_, _) => Type::Primitive(PrimitiveType::U32),
Value::U64(_, _) => Type::Primitive(PrimitiveType::U64),
}
}
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
@@ -240,19 +300,64 @@ where
D: ?Sized + DocAllocator<'a, A>, D: ?Sized + DocAllocator<'a, A>,
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { let pretty_internal = |opt_base: &Option<u8>, x, t| {
Value::Number(opt_base, value) => { syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
let value_str = match opt_base {
None => format!("{}", value),
Some(2) => format!("0b{:b}", value),
Some(8) => format!("0o{:o}", value),
Some(10) => format!("0d{}", value),
Some(16) => format!("0x{:x}", value),
Some(_) => format!("!!{:x}!!", value),
}; };
allocator.text(value_str) let pretty_internal_signed = |opt_base, x: i64, t| {
let base = pretty_internal(opt_base, x.unsigned_abs(), t);
allocator.text("-").append(base)
};
match self {
Value::I8(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
}
Value::I16(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
}
Value::I32(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
}
Value::I64(opt_base, value) => {
pretty_internal_signed(opt_base, *value, ConstantType::I64)
}
Value::U8(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U8)
}
Value::U16(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U16)
}
Value::U32(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U32)
}
Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
} }
} }
} }
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type {
Primitive(PrimitiveType),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Type::Primitive(pt) => allocator.text(format!("{}", pt)),
}
}
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Primitive(pt) => pt.fmt(f),
}
}
} }

View File

@@ -1,7 +1,7 @@
use crate::eval::{EvalEnvironment, EvalError, Value}; use crate::eval::{EvalEnvironment, EvalError, Value};
use crate::ir::{Expression, Program, Statement}; use crate::ir::{Expression, Program, Statement};
use super::{Primitive, ValueOrRef}; use super::{Primitive, Type, ValueOrRef};
impl Program { impl Program {
/// Evaluate the program, returning either an error or a string containing everything /// Evaluate the program, returning either an error or a string containing everything
@@ -14,12 +14,12 @@ impl Program {
for stmt in self.statements.iter() { for stmt in self.statements.iter() {
match stmt { match stmt {
Statement::Binding(_, name, value) => { Statement::Binding(_, name, _, value) => {
let actual_value = value.eval(&env)?; let actual_value = value.eval(&env)?;
env = env.extend(name.clone(), actual_value); env = env.extend(name.clone(), actual_value);
} }
Statement::Print(_, name) => { Statement::Print(_, _, name) => {
let value = env.lookup(name.clone())?; let value = env.lookup(name.clone())?;
let line = format!("{} = {}\n", name, value); let line = format!("{} = {}\n", name, value);
stdout.push_str(&line); stdout.push_str(&line);
@@ -34,27 +34,22 @@ impl Program {
impl Expression { impl Expression {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> { fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self { match self {
Expression::Value(_, v) => match v { Expression::Atomic(x) => x.eval(env),
super::Value::Number(_, v) => Ok(Value::I64(*v)),
},
Expression::Reference(_, n) => Ok(env.lookup(n.clone())?), Expression::Cast(_, t, valref) => {
let value = valref.eval(env)?;
Expression::Primitive(_, op, args) => { match t {
let mut arg_values = Vec::with_capacity(args.len()); Type::Primitive(pt) => Ok(pt.safe_cast(&value)?),
// we implement primitive operations by first evaluating each of the
// arguments to the function, and then gathering up all the values
// produced.
for arg in args.iter() {
match arg {
ValueOrRef::Ref(_, n) => arg_values.push(env.lookup(n.clone())?),
ValueOrRef::Value(_, super::Value::Number(_, v)) => {
arg_values.push(Value::I64(*v))
}
} }
} }
Expression::Primitive(_, _, op, args) => {
let arg_values = args
.iter()
.map(|x| x.eval(env))
.collect::<Result<Vec<Value>, EvalError>>()?;
// and then finally we call `calculate` to run them. trust me, it's nice // and then finally we call `calculate` to run them. trust me, it's nice
// to not have to deal with all the nonsense hidden under `calculate`. // to not have to deal with all the nonsense hidden under `calculate`.
match op { match op {
@@ -68,19 +63,38 @@ impl Expression {
} }
} }
impl ValueOrRef {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self {
ValueOrRef::Value(_, _, v) => match v {
super::Value::I8(_, v) => Ok(Value::I8(*v)),
super::Value::I16(_, v) => Ok(Value::I16(*v)),
super::Value::I32(_, v) => Ok(Value::I32(*v)),
super::Value::I64(_, v) => Ok(Value::I64(*v)),
super::Value::U8(_, v) => Ok(Value::U8(*v)),
super::Value::U16(_, v) => Ok(Value::U16(*v)),
super::Value::U32(_, v) => Ok(Value::U32(*v)),
super::Value::U64(_, v) => Ok(Value::U64(*v)),
},
ValueOrRef::Ref(_, _, n) => Ok(env.lookup(n.clone())?),
}
}
}
#[test] #[test]
fn two_plus_three() { fn two_plus_three() {
let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works");
let ir = Program::from(input); let ir = input.type_infer().expect("test should be type-valid");
let output = ir.eval().expect("runs successfully"); let output = ir.eval().expect("runs successfully");
assert_eq!("x = 5i64\n", &output); assert_eq!("x = 5u64\n", &output);
} }
#[test] #[test]
fn lotsa_math() { fn lotsa_math() {
let input = let input =
crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works");
let ir = Program::from(input); let ir = input.type_infer().expect("test should be type-valid");
let output = ir.eval().expect("runs successfully"); let output = ir.eval().expect("runs successfully");
assert_eq!("x = 7i64\n", &output); assert_eq!("x = 7u64\n", &output);
} }

View File

@@ -1,186 +0,0 @@
use internment::ArcIntern;
use std::sync::atomic::AtomicUsize;
use crate::ir::ast as ir;
use crate::syntax;
use super::ValueOrRef;
impl From<syntax::Program> for ir::Program {
/// We implement the top-level conversion of a syntax::Program into an
/// ir::Program using just the standard `From::from`, because we don't
/// need to return any arguments and we shouldn't produce any errors.
/// Technically there's an `unwrap` deep under the hood that we could
/// float out, but the validator really should've made sure that never
/// happens, so we're just going to assume.
fn from(mut value: syntax::Program) -> Self {
let mut statements = Vec::new();
for stmt in value.statements.drain(..) {
statements.append(&mut stmt.simplify());
}
ir::Program { statements }
}
}
impl From<syntax::Statement> for ir::Program {
/// One interesting thing about this conversion is that there isn't
/// a natural translation from syntax::Statement to ir::Statement,
/// because the syntax version can have nested expressions and the
/// IR version can't.
///
/// As a result, we can naturally convert a syntax::Statement into
/// an ir::Program, because we can allow the additional binding
/// sites to be generated, instead. And, bonus, it turns out that
/// this is what we wanted anyways.
fn from(value: syntax::Statement) -> Self {
ir::Program {
statements: value.simplify(),
}
}
}
impl syntax::Statement {
/// Simplify a syntax::Statement into a series of ir::Statements.
///
/// The reason this function is one-to-many is because we may have to
/// introduce new binding sites in order to avoid having nested
/// expressions. Nested expressions, like `(1 + 2) * 3`, are allowed
/// in syntax::Expression but are expressly *not* allowed in
/// ir::Expression. So this pass converts them into bindings, like
/// this:
///
/// x = (1 + 2) * 3;
///
/// ==>
///
/// x:1 = 1 + 2;
/// x:2 = x:1 * 3;
/// x = x:2
///
/// Thus ensuring that things are nice and simple. Note that the
/// binding of `x:2` is not, strictly speaking, necessary, but it
/// makes the code below much easier to read.
fn simplify(self) -> Vec<ir::Statement> {
let mut new_statements = vec![];
match self {
// Print statements we don't have to do much with
syntax::Statement::Print(loc, name) => {
new_statements.push(ir::Statement::Print(loc, ArcIntern::new(name)))
}
// Bindings, however, may involve a single expression turning into
// a series of statements and then an expression.
syntax::Statement::Binding(loc, name, value) => {
let (mut prereqs, new_value) = value.rebind(&name);
new_statements.append(&mut prereqs);
new_statements.push(ir::Statement::Binding(
loc,
ArcIntern::new(name),
new_value.into(),
))
}
}
new_statements
}
}
impl syntax::Expression {
/// This actually does the meat of the simplification work, here, by rebinding
/// any nested expressions into their own variables. We have this return
/// `ValueOrRef` in all cases because it makes for slighly less code; in the
/// case when we actually want an `Expression`, we can just use `into()`.
fn rebind(self, base_name: &str) -> (Vec<ir::Statement>, ir::ValueOrRef) {
match self {
// Values just convert in the obvious way, and require no prereqs
syntax::Expression::Value(loc, val) => (vec![], ValueOrRef::Value(loc, val.into())),
// Similarly, references just convert in the obvious way, and require
// no prereqs
syntax::Expression::Reference(loc, name) => {
(vec![], ValueOrRef::Ref(loc, ArcIntern::new(name)))
}
// Primitive expressions are where we do the real work.
syntax::Expression::Primitive(loc, prim, mut expressions) => {
// generate a fresh new name for the binding site we're going to
// introduce, basing the name on wherever we came from; so if this
// expression was bound to `x` originally, it might become `x:23`.
//
// gensym is guaranteed to give us a name that is unused anywhere
// else in the program.
let new_name = gensym(base_name);
let mut prereqs = Vec::new();
let mut new_exprs = Vec::new();
// here we loop through every argument, and recurse on the expressions
// we find. that will give us any new binding sites that *they* introduce,
// and a simple value or reference that we can use in our result.
for expr in expressions.drain(..) {
let (mut cur_prereqs, arg) = expr.rebind(new_name.as_str());
prereqs.append(&mut cur_prereqs);
new_exprs.push(arg);
}
// now we're going to use those new arguments to run the primitive, binding
// the results to the new variable we introduced.
let prim =
ir::Primitive::try_from(prim.as_str()).expect("is valid primitive function");
prereqs.push(ir::Statement::Binding(
loc.clone(),
new_name.clone(),
ir::Expression::Primitive(loc.clone(), prim, new_exprs),
));
// and finally, we can return all the new bindings, and a reference to
// the variable we just introduced to hold the value of the primitive
// invocation.
(prereqs, ValueOrRef::Ref(loc, new_name))
}
}
}
}
impl From<syntax::Value> for ir::Value {
fn from(value: syntax::Value) -> Self {
match value {
syntax::Value::Number(base, val) => ir::Value::Number(base, val),
}
}
}
impl From<String> for ir::Primitive {
fn from(value: String) -> Self {
value.try_into().unwrap()
}
}
/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
fn gensym(name: &str) -> ArcIntern<String> {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = format!(
"<{}:{}>",
name,
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
);
ArcIntern::new(new_name)
}
proptest::proptest! {
#[test]
fn translation_maintains_semantics(input: syntax::Program) {
let syntax_result = input.eval();
let ir = ir::Program::from(input);
let ir_result = ir.eval();
assert_eq!(syntax_result, ir_result);
}
}

View File

@@ -21,12 +21,12 @@ impl Program {
impl Statement { impl Statement {
fn register_strings(&self, string_set: &mut HashSet<ArcIntern<String>>) { fn register_strings(&self, string_set: &mut HashSet<ArcIntern<String>>) {
match self { match self {
Statement::Binding(_, name, expr) => { Statement::Binding(_, name, _, expr) => {
string_set.insert(name.clone()); string_set.insert(name.clone());
expr.register_strings(string_set); expr.register_strings(string_set);
} }
Statement::Print(_, name) => { Statement::Print(_, _, name) => {
string_set.insert(name.clone()); string_set.insert(name.clone());
} }
} }

View File

@@ -63,8 +63,11 @@
//! //!
pub mod backend; pub mod backend;
pub mod eval; pub mod eval;
#[cfg(test)]
mod examples;
pub mod ir; pub mod ir;
pub mod syntax; pub mod syntax;
pub mod type_infer;
/// Implementation module for the high-level compiler. /// Implementation module for the high-level compiler.
mod compiler; mod compiler;

View File

@@ -1,6 +1,6 @@
use crate::backend::{Backend, BackendError}; use crate::backend::{Backend, BackendError};
use crate::ir::Program as IR; use crate::syntax::{ConstantType, Location, ParserError, Statement};
use crate::syntax::{Location, ParserError, Statement}; use crate::type_infer::TypeInferenceResult;
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use codespan_reporting::files::SimpleFiles; use codespan_reporting::files::SimpleFiles;
use codespan_reporting::term::{self, Config}; use codespan_reporting::term::{self, Config};
@@ -129,17 +129,32 @@ impl REPL {
.source(); .source();
let syntax = Statement::parse(entry, source)?; let syntax = Statement::parse(entry, source)?;
let program = match syntax {
Statement::Binding(loc, name, expr) => {
// if this is a variable binding, and we've never defined this variable before, // if this is a variable binding, and we've never defined this variable before,
// we should tell cranelift about it. this is optimistic; if we fail to compile, // we should tell cranelift about it. this is optimistic; if we fail to compile,
// then we won't use this definition until someone tries again. // then we won't use this definition until someone tries again.
if let Statement::Binding(_, ref name, _) = syntax { if !self.variable_binding_sites.contains_key(&name.name) {
if !self.variable_binding_sites.contains_key(name.as_str()) { self.jitter.define_string(&name.name)?;
self.jitter.define_string(name)?; self.jitter
self.jitter.define_variable(name.clone())?; .define_variable(name.to_string(), ConstantType::U64)?;
} }
crate::syntax::Program {
statements: vec![
Statement::Binding(loc.clone(), name.clone(), expr),
Statement::Print(loc, name),
],
}
}
nonbinding => crate::syntax::Program {
statements: vec![nonbinding],
},
}; };
let (mut errors, mut warnings) = syntax.validate(&mut self.variable_binding_sites); let (mut errors, mut warnings) =
program.validate_with_bindings(&mut self.variable_binding_sites);
let stop = !errors.is_empty(); let stop = !errors.is_empty();
let messages = errors let messages = errors
.drain(..) .drain(..)
@@ -154,13 +169,39 @@ impl REPL {
return Ok(()); return Ok(());
} }
let ir = IR::from(syntax); match program.type_infer() {
TypeInferenceResult::Failure {
mut errors,
mut warnings,
} => {
let messages = errors
.drain(..)
.map(Into::into)
.chain(warnings.drain(..).map(Into::into));
for message in messages {
self.emit_diagnostic(message)?;
}
Ok(())
}
TypeInferenceResult::Success {
result,
mut warnings,
} => {
for message in warnings.drain(..).map(Into::into) {
self.emit_diagnostic(message)?;
}
let name = format!("line{}", line_no); let name = format!("line{}", line_no);
let function_id = self.jitter.compile_function(&name, ir)?; let function_id = self.jitter.compile_function(&name, result)?;
self.jitter.module.finalize_definitions()?; self.jitter.module.finalize_definitions()?;
let compiled_bytes = self.jitter.bytes(function_id); let compiled_bytes = self.jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; let compiled_function =
unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function(); compiled_function();
Ok(()) Ok(())
} }
} }
}
}

View File

@@ -27,7 +27,7 @@ use codespan_reporting::{diagnostic::Diagnostic, files::SimpleFiles};
use lalrpop_util::lalrpop_mod; use lalrpop_util::lalrpop_mod;
use logos::Logos; use logos::Logos;
mod arbitrary; pub mod arbitrary;
mod ast; mod ast;
mod eval; mod eval;
mod location; mod location;
@@ -40,6 +40,8 @@ lalrpop_mod!(
mod pretty; mod pretty;
mod validate; mod validate;
#[cfg(test)]
use crate::syntax::arbitrary::GenerationEnvironment;
pub use crate::syntax::ast::*; pub use crate::syntax::ast::*;
pub use crate::syntax::location::Location; pub use crate::syntax::location::Location;
pub use crate::syntax::parser::{ProgramParser, StatementParser}; pub use crate::syntax::parser::{ProgramParser, StatementParser};
@@ -48,7 +50,7 @@ pub use crate::syntax::tokens::{LexerError, Token};
use ::pretty::{Arena, Pretty}; use ::pretty::{Arena, Pretty};
use lalrpop_util::ParseError; use lalrpop_util::ParseError;
#[cfg(test)] #[cfg(test)]
use proptest::{prop_assert, prop_assert_eq}; use proptest::{arbitrary::Arbitrary, prop_assert, prop_assert_eq};
#[cfg(test)] #[cfg(test)]
use std::str::FromStr; use std::str::FromStr;
use thiserror::Error; use thiserror::Error;
@@ -73,12 +75,12 @@ pub enum ParserError {
/// Raised when we're parsing the file, and run into a token in a /// Raised when we're parsing the file, and run into a token in a
/// place we weren't expecting it. /// place we weren't expecting it.
#[error("Unrecognized token")] #[error("Unrecognized token")]
UnrecognizedToken(Location, Location, Token, Vec<String>), UnrecognizedToken(Location, Token, Vec<String>),
/// Raised when we were expecting the end of the file, but instead /// Raised when we were expecting the end of the file, but instead
/// got another token. /// got another token.
#[error("Extra token")] #[error("Extra token")]
ExtraToken(Location, Token, Location), ExtraToken(Location, Token),
/// Raised when the lexer just had some sort of internal problem /// Raised when the lexer just had some sort of internal problem
/// and just gave up. /// and just gave up.
@@ -106,30 +108,24 @@ impl ParserError {
fn convert(file_idx: usize, err: ParseError<usize, Token, LexerError>) -> Self { fn convert(file_idx: usize, err: ParseError<usize, Token, LexerError>) -> Self {
match err { match err {
ParseError::InvalidToken { location } => { ParseError::InvalidToken { location } => {
ParserError::InvalidToken(Location::new(file_idx, location)) ParserError::InvalidToken(Location::new(file_idx, location..location + 1))
}
ParseError::UnrecognizedEof { location, expected } => {
ParserError::UnrecognizedEOF(Location::new(file_idx, location), expected)
} }
ParseError::UnrecognizedEof { location, expected } => ParserError::UnrecognizedEOF(
Location::new(file_idx, location..location + 1),
expected,
),
ParseError::UnrecognizedToken { ParseError::UnrecognizedToken {
token: (start, token, end), token: (start, token, end),
expected, expected,
} => ParserError::UnrecognizedToken( } => {
Location::new(file_idx, start), ParserError::UnrecognizedToken(Location::new(file_idx, start..end), token, expected)
Location::new(file_idx, end), }
token,
expected,
),
ParseError::ExtraToken { ParseError::ExtraToken {
token: (start, token, end), token: (start, token, end),
} => ParserError::ExtraToken( } => ParserError::ExtraToken(Location::new(file_idx, start..end), token),
Location::new(file_idx, start),
token,
Location::new(file_idx, end),
),
ParseError::User { error } => match error { ParseError::User { error } => match error {
LexerError::LexFailure(offset) => { LexerError::LexFailure(offset) => {
ParserError::LexFailure(Location::new(file_idx, offset)) ParserError::LexFailure(Location::new(file_idx, offset..offset + 1))
} }
}, },
} }
@@ -180,37 +176,25 @@ impl<'a> From<&'a ParserError> for Diagnostic<usize> {
), ),
// encountered a token where it shouldn't be // encountered a token where it shouldn't be
ParserError::UnrecognizedToken(start, end, token, expected) => { ParserError::UnrecognizedToken(loc, token, expected) => {
let expected_str = let expected_str =
format!("unexpected token {}{}", token, display_expected(expected)); format!("unexpected token {}{}", token, display_expected(expected));
let unexpected_str = format!("unexpected token {}", token); let unexpected_str = format!("unexpected token {}", token);
let labels = start.range_label(end);
Diagnostic::error() Diagnostic::error()
.with_labels(
labels
.into_iter()
.map(|l| l.with_message(unexpected_str.clone()))
.collect(),
)
.with_message(expected_str) .with_message(expected_str)
.with_labels(vec![loc.primary_label().with_message(unexpected_str)])
} }
// I think we get this when we get a token, but were expected EOF // I think we get this when we get a token, but were expected EOF
ParserError::ExtraToken(start, token, end) => { ParserError::ExtraToken(loc, token) => {
let expected_str = let expected_str =
format!("unexpected token {} after the expected end of file", token); format!("unexpected token {} after the expected end of file", token);
let unexpected_str = format!("unexpected token {}", token); let unexpected_str = format!("unexpected token {}", token);
let labels = start.range_label(end);
Diagnostic::error() Diagnostic::error()
.with_labels(
labels
.into_iter()
.map(|l| l.with_message(unexpected_str.clone()))
.collect(),
)
.with_message(expected_str) .with_message(expected_str)
.with_labels(vec![loc.primary_label().with_message(unexpected_str)])
} }
// simple lexer errors // simple lexer errors
@@ -293,24 +277,27 @@ fn order_of_operations() {
Program::from_str(muladd1).unwrap(), Program::from_str(muladd1).unwrap(),
Program { Program {
statements: vec![Statement::Binding( statements: vec![Statement::Binding(
Location::new(testfile, 0), Location::new(testfile, 0..1),
"x".to_string(), Name::manufactured("x"),
Expression::Primitive( Expression::Primitive(
Location::new(testfile, 6), Location::new(testfile, 6..7),
"+".to_string(), "+".to_string(),
vec![ vec![
Expression::Value(Location::new(testfile, 4), Value::Number(None, 1)), Expression::Value(
Location::new(testfile, 4..5),
Value::Number(None, None, 1),
),
Expression::Primitive( Expression::Primitive(
Location::new(testfile, 10), Location::new(testfile, 10..11),
"*".to_string(), "*".to_string(),
vec![ vec![
Expression::Value( Expression::Value(
Location::new(testfile, 8), Location::new(testfile, 8..9),
Value::Number(None, 2), Value::Number(None, None, 2),
), ),
Expression::Value( Expression::Value(
Location::new(testfile, 12), Location::new(testfile, 12..13),
Value::Number(None, 3), Value::Number(None, None, 3),
), ),
] ]
) )
@@ -350,8 +337,8 @@ proptest::proptest! {
} }
#[test] #[test]
fn generated_run_or_overflow(program: Program) { fn generated_run_or_overflow(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::{EvalError, PrimOpError}; use crate::eval::{EvalError, PrimOpError};
assert!(matches!(program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_))))) prop_assert!(matches!(program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))));
} }
} }

View File

@@ -1,136 +1,189 @@
use std::collections::HashSet; use crate::syntax::ast::{ConstantType, Expression, Name, Program, Statement, Value};
use crate::syntax::ast::{Expression, Program, Statement, Value};
use crate::syntax::location::Location; use crate::syntax::location::Location;
use proptest::sample::select; use proptest::sample::select;
use proptest::{ use proptest::{
prelude::{Arbitrary, BoxedStrategy, Strategy}, prelude::{Arbitrary, BoxedStrategy, Strategy},
strategy::{Just, Union}, strategy::{Just, Union},
}; };
use std::collections::HashMap;
use std::ops::Range;
const VALID_VARIABLE_NAMES: &str = r"[a-z][a-zA-Z0-9_]*"; const VALID_VARIABLE_NAMES: &str = r"[a-z][a-zA-Z0-9_]*";
#[derive(Debug)] impl ConstantType {
struct Name(String); fn get_operators(&self) -> &'static [(&'static str, usize)] {
match self {
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64 => {
&[("+", 2), ("-", 1), ("-", 2), ("*", 2), ("/", 2)]
}
ConstantType::U8 | ConstantType::U16 | ConstantType::U32 | ConstantType::U64 => {
&[("+", 2), ("-", 2), ("*", 2), ("/", 2)]
}
}
}
}
#[derive(Clone)]
pub struct GenerationEnvironment {
allow_inference: bool,
block_length: Range<usize>,
bindings: HashMap<Name, ConstantType>,
return_type: ConstantType,
}
impl Default for GenerationEnvironment {
fn default() -> Self {
GenerationEnvironment {
allow_inference: true,
block_length: 2..10,
bindings: HashMap::new(),
return_type: ConstantType::U64,
}
}
}
impl GenerationEnvironment {
pub fn new(allow_inference: bool) -> Self {
GenerationEnvironment {
allow_inference,
..Default::default()
}
}
}
impl Arbitrary for Program {
type Parameters = GenerationEnvironment;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy {
proptest::collection::vec(
ProgramStatementInfo::arbitrary(),
genenv.block_length.clone(),
)
.prop_flat_map(move |mut items| {
let mut statements = Vec::new();
let mut genenv = genenv.clone();
for psi in items.drain(..) {
if genenv.bindings.is_empty() || psi.should_be_binding {
genenv.return_type = psi.binding_type;
let expr = Expression::arbitrary_with(genenv.clone());
genenv.bindings.insert(psi.name.clone(), psi.binding_type);
statements.push(
expr.prop_map(move |expr| {
Statement::Binding(Location::manufactured(), psi.name.clone(), expr)
})
.boxed(),
);
} else {
let printers = genenv.bindings.keys().map(|n| {
Just(Statement::Print(
Location::manufactured(),
Name::manufactured(n),
))
});
statements.push(Union::new(printers).boxed());
}
}
statements
.prop_map(|statements| Program { statements })
.boxed()
})
.boxed()
}
}
impl Arbitrary for Name { impl Arbitrary for Name {
type Parameters = (); type Parameters = ();
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
VALID_VARIABLE_NAMES.prop_map(Name).boxed() VALID_VARIABLE_NAMES.prop_map(Name::manufactured).boxed()
} }
} }
impl Arbitrary for Program { #[derive(Debug)]
struct ProgramStatementInfo {
should_be_binding: bool,
name: Name,
binding_type: ConstantType,
}
impl Arbitrary for ProgramStatementInfo {
type Parameters = (); type Parameters = ();
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
let optionals = Vec::<Option<Name>>::arbitrary(); (
Union::new(vec![Just(true), Just(true), Just(false)]),
optionals Name::arbitrary(),
.prop_flat_map(|mut possible_names| { ConstantType::arbitrary(),
let mut statements = Vec::new(); )
let mut defined_variables: HashSet<String> = HashSet::new(); .prop_map(
|(should_be_binding, name, binding_type)| ProgramStatementInfo {
for possible_name in possible_names.drain(..) { should_be_binding,
match possible_name { name,
None if defined_variables.is_empty() => continue, binding_type,
None => statements.push( },
Union::new(defined_variables.iter().map(|name| {
Just(Statement::Print(Location::manufactured(), name.to_string()))
}))
.boxed(),
),
Some(new_name) => {
let closures_name = new_name.0.clone();
let retval =
Expression::arbitrary_with(Some(defined_variables.clone()))
.prop_map(move |exp| {
Statement::Binding(
Location::manufactured(),
closures_name.clone(),
exp,
) )
})
.boxed();
defined_variables.insert(new_name.0);
statements.push(retval);
}
}
}
statements
})
.prop_map(|statements| Program { statements })
.boxed() .boxed()
} }
} }
impl Arbitrary for Statement {
type Parameters = Option<HashSet<String>>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let duplicated_args = args.clone();
let defined_variables = args.unwrap_or_default();
let binding_strategy = (
VALID_VARIABLE_NAMES,
Expression::arbitrary_with(duplicated_args),
)
.prop_map(|(name, exp)| Statement::Binding(Location::manufactured(), name, exp))
.boxed();
if defined_variables.is_empty() {
binding_strategy
} else {
let print_strategy = Union::new(
defined_variables
.iter()
.map(|x| Just(Statement::Print(Location::manufactured(), x.to_string()))),
)
.boxed();
Union::new([binding_strategy, print_strategy]).boxed()
}
}
}
impl Arbitrary for Expression { impl Arbitrary for Expression {
type Parameters = Option<HashSet<String>>; type Parameters = GenerationEnvironment;
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy {
let defined_variables = args.unwrap_or_default(); // Value(Location, Value). These are the easiest variations to create, because we can always
// create one.
let value_strategy = Value::arbitrary() let value_strategy = Value::arbitrary_with(genenv.clone())
.prop_map(move |x| Expression::Value(Location::manufactured(), x)) .prop_map(|x| Expression::Value(Location::manufactured(), x))
.boxed(); .boxed();
let leaf_strategy = if defined_variables.is_empty() { // Reference(Location, String), These are slightly trickier, because we can end up in a situation
// where either no variables are defined, or where none of the defined variables have a type we
// can work with. So what we're going to do is combine this one with the previous one as a "leaf
// strategy" -- our non-recursive items -- if we can, or just set that to be the value strategy
// if we can't actually create an references.
let mut bound_variables_of_type = genenv
.bindings
.iter()
.filter(|(_, v)| genenv.return_type == **v)
.map(|(n, _)| n)
.collect::<Vec<_>>();
let leaf_strategy = if bound_variables_of_type.is_empty() {
value_strategy value_strategy
} else { } else {
let reference_strategy = Union::new(defined_variables.iter().map(|x| { let mut strats = bound_variables_of_type
.drain(..)
.map(|x| {
Just(Expression::Reference( Just(Expression::Reference(
Location::manufactured(), Location::manufactured(),
x.to_owned(), x.name.clone(),
)) ))
})) .boxed()
.boxed(); })
Union::new([value_strategy, reference_strategy]).boxed() .collect::<Vec<_>>();
strats.push(value_strategy);
Union::new(strats).boxed()
}; };
// now we generate our recursive types, given our leaf strategy
leaf_strategy leaf_strategy
.prop_recursive(3, 64, 2, move |inner| { .prop_recursive(3, 10, 2, move |strat| {
( (
select(super::BINARY_OPERATORS), select(genenv.return_type.get_operators()),
proptest::collection::vec(inner, 2), strat.clone(),
strat,
) )
.prop_map(move |(operator, exprs)| { .prop_map(|((oper, count), left, right)| {
Expression::Primitive(Location::manufactured(), operator.to_string(), exprs) let mut args = vec![left, right];
while args.len() > count {
args.pop();
}
Expression::Primitive(Location::manufactured(), oper.to_string(), args)
}) })
}) })
.boxed() .boxed()
@@ -138,22 +191,57 @@ impl Arbitrary for Expression {
} }
impl Arbitrary for Value { impl Arbitrary for Value {
type Parameters = (); type Parameters = GenerationEnvironment;
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy {
let base_strategy = Union::new([ let printed_base_strategy = Union::new([
Just(None::<u8>), Just(None::<u8>),
Just(Some(2)), Just(Some(2)),
Just(Some(8)), Just(Some(8)),
Just(Some(10)), Just(Some(10)),
Just(Some(16)), Just(Some(16)),
]); ]);
let value_strategy = u64::arbitrary();
let value_strategy = i64::arbitrary(); (printed_base_strategy, bool::arbitrary(), value_strategy)
.prop_map(move |(base, declare_type, value)| {
(base_strategy, value_strategy) let converted_value = match genenv.return_type {
.prop_map(move |(base, value)| Value::Number(base, value)) ConstantType::I8 => value % (i8::MAX as u64),
ConstantType::U8 => value % (u8::MAX as u64),
ConstantType::I16 => value % (i16::MAX as u64),
ConstantType::U16 => value % (u16::MAX as u64),
ConstantType::I32 => value % (i32::MAX as u64),
ConstantType::U32 => value % (u32::MAX as u64),
ConstantType::I64 => value % (i64::MAX as u64),
ConstantType::U64 => value,
};
let ty = if declare_type || !genenv.allow_inference {
Some(genenv.return_type)
} else {
None
};
Value::Number(base, ty, converted_value)
})
.boxed()
}
}
impl Arbitrary for ConstantType {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
Union::new([
Just(ConstantType::I8),
Just(ConstantType::I16),
Just(ConstantType::I32),
Just(ConstantType::I64),
Just(ConstantType::U8),
Just(ConstantType::U16),
Just(ConstantType::U32),
Just(ConstantType::U64),
])
.boxed() .boxed()
} }
} }

View File

@@ -1,7 +1,10 @@
use crate::syntax::Location; use std::fmt;
use std::hash::Hash;
/// The set of valid binary operators. use internment::ArcIntern;
pub static BINARY_OPERATORS: &[&str] = &["+", "-", "*", "/"];
pub use crate::syntax::tokens::ConstantType;
use crate::syntax::Location;
/// A structure represented a parsed program. /// A structure represented a parsed program.
/// ///
@@ -16,6 +19,56 @@ pub struct Program {
pub statements: Vec<Statement>, pub statements: Vec<Statement>,
} }
/// A Name.
///
/// This is basically a string, but annotated with the place the string
/// is in the source file.
#[derive(Clone, Debug)]
pub struct Name {
pub name: String,
pub location: Location,
}
impl Name {
pub fn new<S: ToString>(n: S, location: Location) -> Name {
Name {
name: n.to_string(),
location,
}
}
pub fn manufactured<S: ToString>(n: S) -> Name {
Name {
name: n.to_string(),
location: Location::manufactured(),
}
}
pub fn intern(self) -> ArcIntern<String> {
ArcIntern::new(self.name)
}
}
impl PartialEq for Name {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
}
}
impl Eq for Name {}
impl Hash for Name {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state)
}
}
impl fmt::Display for Name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.name.fmt(f)
}
}
/// A parsed statement. /// A parsed statement.
/// ///
/// Statements are guaranteed to be syntactically valid, but may be /// Statements are guaranteed to be syntactically valid, but may be
@@ -29,8 +82,8 @@ pub struct Program {
/// thing, not if they are the exact same statement. /// thing, not if they are the exact same statement.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum Statement { pub enum Statement {
Binding(Location, String, Expression), Binding(Location, Name, Expression),
Print(Location, String), Print(Location, Name),
} }
impl PartialEq for Statement { impl PartialEq for Statement {
@@ -58,6 +111,7 @@ impl PartialEq for Statement {
pub enum Expression { pub enum Expression {
Value(Location, Value), Value(Location, Value),
Reference(Location, String), Reference(Location, String),
Cast(Location, String, Box<Expression>),
Primitive(Location, String, Vec<Expression>), Primitive(Location, String, Vec<Expression>),
} }
@@ -72,6 +126,10 @@ impl PartialEq for Expression {
Expression::Reference(_, var2) => var1 == var2, Expression::Reference(_, var2) => var1 == var2,
_ => false, _ => false,
}, },
Expression::Cast(_, t1, e1) => match other {
Expression::Cast(_, t2, e2) => t1 == t2 && e1 == e2,
_ => false,
},
Expression::Primitive(_, prim1, args1) => match other { Expression::Primitive(_, prim1, args1) => match other {
Expression::Primitive(_, prim2, args2) => prim1 == prim2 && args1 == args2, Expression::Primitive(_, prim2, args2) => prim1 == prim2 && args1 == args2,
_ => false, _ => false,
@@ -83,6 +141,12 @@ impl PartialEq for Expression {
/// A value from the source syntax /// A value from the source syntax
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum Value { pub enum Value {
/// The value of the number, and an optional base that it was written in /// The value of the number, an optional base that it was written in, and any
Number(Option<u8>, i64), /// type information provided.
///
/// u64 is chosen because it should be big enough to carry the amount of
/// information we need, and technically we interpret -4 as the primitive unary
/// operation "-" on the number 4. We'll translate this into a type-specific
/// number at a later time.
Number(Option<u8>, Option<ConstantType>, u64),
} }

View File

@@ -1,7 +1,8 @@
use internment::ArcIntern; use internment::ArcIntern;
use crate::eval::{EvalEnvironment, EvalError, Value}; use crate::eval::{EvalEnvironment, EvalError, PrimitiveType, Value};
use crate::syntax::{Expression, Program, Statement}; use crate::syntax::{ConstantType, Expression, Program, Statement};
use std::str::FromStr;
impl Program { impl Program {
/// Evaluate the program, returning either an error or what it prints out when run. /// Evaluate the program, returning either an error or what it prints out when run.
@@ -24,11 +25,11 @@ impl Program {
match stmt { match stmt {
Statement::Binding(_, name, value) => { Statement::Binding(_, name, value) => {
let actual_value = value.eval(&env)?; let actual_value = value.eval(&env)?;
env = env.extend(ArcIntern::new(name.clone()), actual_value); env = env.extend(name.clone().intern(), actual_value);
} }
Statement::Print(_, name) => { Statement::Print(_, name) => {
let value = env.lookup(ArcIntern::new(name.clone()))?; let value = env.lookup(name.clone().intern())?;
let line = format!("{} = {}\n", name, value); let line = format!("{} = {}\n", name, value);
stdout.push_str(&line); stdout.push_str(&line);
} }
@@ -43,11 +44,28 @@ impl Expression {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> { fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self { match self {
Expression::Value(_, v) => match v { Expression::Value(_, v) => match v {
super::Value::Number(_, v) => Ok(Value::I64(*v)), super::Value::Number(_, ty, v) => match ty {
None => Ok(Value::U64(*v)),
// FIXME: make these types validate their input size
Some(ConstantType::I8) => Ok(Value::I8(*v as i8)),
Some(ConstantType::I16) => Ok(Value::I16(*v as i16)),
Some(ConstantType::I32) => Ok(Value::I32(*v as i32)),
Some(ConstantType::I64) => Ok(Value::I64(*v as i64)),
Some(ConstantType::U8) => Ok(Value::U8(*v as u8)),
Some(ConstantType::U16) => Ok(Value::U16(*v as u16)),
Some(ConstantType::U32) => Ok(Value::U32(*v as u32)),
Some(ConstantType::U64) => Ok(Value::U64(*v)),
},
}, },
Expression::Reference(_, n) => Ok(env.lookup(ArcIntern::new(n.clone()))?), Expression::Reference(_, n) => Ok(env.lookup(ArcIntern::new(n.clone()))?),
Expression::Cast(_, target, expr) => {
let target_type = PrimitiveType::from_str(target)?;
let value = expr.eval(env)?;
Ok(target_type.safe_cast(&value)?)
}
Expression::Primitive(_, op, args) => { Expression::Primitive(_, op, args) => {
let mut arg_values = Vec::with_capacity(args.len()); let mut arg_values = Vec::with_capacity(args.len());
@@ -66,12 +84,12 @@ impl Expression {
fn two_plus_three() { fn two_plus_three() {
let input = Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); let input = Program::parse(0, "x = 2 + 3; print x;").expect("parse works");
let output = input.eval().expect("runs successfully"); let output = input.eval().expect("runs successfully");
assert_eq!("x = 5i64\n", &output); assert_eq!("x = 5u64\n", &output);
} }
#[test] #[test]
fn lotsa_math() { fn lotsa_math() {
let input = Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); let input = Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works");
let output = input.eval().expect("runs successfully"); let output = input.eval().expect("runs successfully");
assert_eq!("x = 7i64\n", &output); assert_eq!("x = 7u64\n", &output);
} }

View File

@@ -1,13 +1,15 @@
use std::ops::Range;
use codespan_reporting::diagnostic::{Diagnostic, Label}; use codespan_reporting::diagnostic::{Diagnostic, Label};
/// A source location, for use in pointing users towards warnings and errors. /// A source location, for use in pointing users towards warnings and errors.
/// ///
/// Internally, locations are very tied to the `codespan_reporting` library, /// Internally, locations are very tied to the `codespan_reporting` library,
/// and the primary use of them is to serve as anchors within that library. /// and the primary use of them is to serve as anchors within that library.
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Location { pub struct Location {
file_idx: usize, file_idx: usize,
offset: usize, location: Range<usize>,
} }
impl Location { impl Location {
@@ -17,8 +19,8 @@ impl Location {
/// The file index is based on the file database being used. See the /// The file index is based on the file database being used. See the
/// `codespan_reporting::files::SimpleFiles::add` function, which is /// `codespan_reporting::files::SimpleFiles::add` function, which is
/// normally where we get this index. /// normally where we get this index.
pub fn new(file_idx: usize, offset: usize) -> Self { pub fn new(file_idx: usize, location: Range<usize>) -> Self {
Location { file_idx, offset } Location { file_idx, location }
} }
/// Generate a `Location` for a completely manufactured bit of code. /// Generate a `Location` for a completely manufactured bit of code.
@@ -30,7 +32,7 @@ impl Location {
pub fn manufactured() -> Self { pub fn manufactured() -> Self {
Location { Location {
file_idx: 0, file_idx: 0,
offset: 0, location: 0..0,
} }
} }
@@ -47,7 +49,7 @@ impl Location {
/// actually happened), but you'd probably want to make the first location /// actually happened), but you'd probably want to make the first location
/// the secondary label to help users find it. /// the secondary label to help users find it.
pub fn primary_label(&self) -> Label<usize> { pub fn primary_label(&self) -> Label<usize> {
Label::primary(self.file_idx, self.offset..self.offset) Label::primary(self.file_idx, self.location.clone())
} }
/// Generate a secondary label for a [`Diagnostic`], based on this source /// Generate a secondary label for a [`Diagnostic`], based on this source
@@ -64,35 +66,7 @@ impl Location {
/// probably want to make the first location the secondary label to help /// probably want to make the first location the secondary label to help
/// users find it. /// users find it.
pub fn secondary_label(&self) -> Label<usize> { pub fn secondary_label(&self) -> Label<usize> {
Label::secondary(self.file_idx, self.offset..self.offset) Label::secondary(self.file_idx, self.location.clone())
}
/// Given this location and another, generate a primary label that
/// specifies the area between those two locations.
///
/// See [`Self::primary_label`] for some discussion of primary versus
/// secondary labels. If the two locations are the same, this method does
/// the exact same thing as [`Self::primary_label`]. If this item was
/// generated by [`Self::manufactured`], it will act as if you'd called
/// `primary_label` on the argument. Otherwise, it will generate the obvious
/// span.
///
/// This function will return `None` only in the case that you provide
/// labels from two different files, which it cannot sensibly handle.
pub fn range_label(&self, end: &Location) -> Option<Label<usize>> {
if self.file_idx == 0 {
return Some(end.primary_label());
}
if self.file_idx != end.file_idx {
return None;
}
if self.offset > end.offset {
Some(Label::primary(self.file_idx, end.offset..self.offset))
} else {
Some(Label::primary(self.file_idx, self.offset..end.offset))
}
} }
/// Return an error diagnostic centered at this location. /// Return an error diagnostic centered at this location.
@@ -102,10 +76,7 @@ impl Location {
/// this particular location. You'll need to extend it with actually useful /// this particular location. You'll need to extend it with actually useful
/// information, like what kind of error it is. /// information, like what kind of error it is.
pub fn error(&self) -> Diagnostic<usize> { pub fn error(&self) -> Diagnostic<usize> {
Diagnostic::error().with_labels(vec![Label::primary( Diagnostic::error().with_labels(vec![Label::primary(self.file_idx, self.location.clone())])
self.file_idx,
self.offset..self.offset,
)])
} }
/// Return an error diagnostic centered at this location, with the given message. /// Return an error diagnostic centered at this location, with the given message.
@@ -115,10 +86,34 @@ impl Location {
/// even more information to ut, using [`Diagnostic::with_labels`], /// even more information to ut, using [`Diagnostic::with_labels`],
/// [`Diagnostic::with_notes`], or [`Diagnostic::with_code`]. /// [`Diagnostic::with_notes`], or [`Diagnostic::with_code`].
pub fn labelled_error(&self, msg: &str) -> Diagnostic<usize> { pub fn labelled_error(&self, msg: &str) -> Diagnostic<usize> {
Diagnostic::error().with_labels(vec![Label::primary( Diagnostic::error().with_labels(vec![
self.file_idx, Label::primary(self.file_idx, self.location.clone()).with_message(msg)
self.offset..self.offset, ])
) }
.with_message(msg)])
/// Merge two locations into a single location spanning the whole range between
/// them.
///
/// This function returns None if the locations are from different files; this
/// can happen if one of the locations is manufactured, for example.
pub fn merge(&self, other: &Self) -> Option<Self> {
if self.file_idx != other.file_idx {
None
} else {
let start = if self.location.start <= other.location.start {
self.location.start
} else {
other.location.start
};
let end = if self.location.end >= other.location.end {
self.location.end
} else {
other.location.end
};
Some(Location {
file_idx: self.file_idx,
location: start..end,
})
}
} }
} }

View File

@@ -9,8 +9,8 @@
//! eventually want to leave lalrpop behind.) //! eventually want to leave lalrpop behind.)
//! //!
use crate::syntax::{LexerError, Location}; use crate::syntax::{LexerError, Location};
use crate::syntax::ast::{Program,Statement,Expression,Value}; use crate::syntax::ast::{Program,Statement,Expression,Value,Name};
use crate::syntax::tokens::Token; use crate::syntax::tokens::{ConstantType, Token};
use internment::ArcIntern; use internment::ArcIntern;
// one cool thing about lalrpop: we can pass arguments. in this case, the // one cool thing about lalrpop: we can pass arguments. in this case, the
@@ -32,6 +32,8 @@ extern {
";" => Token::Semi, ";" => Token::Semi,
"(" => Token::LeftParen, "(" => Token::LeftParen,
")" => Token::RightParen, ")" => Token::RightParen,
"<" => Token::LessThan,
">" => Token::GreaterThan,
"print" => Token::Print, "print" => Token::Print,
@@ -44,7 +46,7 @@ extern {
// to name and use "their value", you get their source location. // to name and use "their value", you get their source location.
// For these, we want "their value" to be their actual contents, // For these, we want "their value" to be their actual contents,
// which is why we put their types in angle brackets. // which is why we put their types in angle brackets.
"<num>" => Token::Number((<Option<u8>>,<i64>)), "<num>" => Token::Number((<Option<u8>>,<Option<ConstantType>>,<u64>)),
"<var>" => Token::Variable(<ArcIntern<String>>), "<var>" => Token::Variable(<ArcIntern<String>>),
} }
} }
@@ -89,10 +91,19 @@ pub Statement: Statement = {
// A statement can be a variable binding. Note, here, that we use this // A statement can be a variable binding. Note, here, that we use this
// funny @L thing to get the source location before the variable, so that // funny @L thing to get the source location before the variable, so that
// we can say that this statement spans across everything. // we can say that this statement spans across everything.
<l:@L> <v:"<var>"> "=" <e:Expression> ";" => Statement::Binding(Location::new(file_idx, l), v.to_string(), e), <ls: @L> <v:"<var>"> <var_end: @L> "=" <e:Expression> ";" <le: @L> =>
Statement::Binding(
Location::new(file_idx, ls..le),
Name::new(v, Location::new(file_idx, ls..var_end)),
e,
),
// Alternatively, a statement can just be a print statement. // Alternatively, a statement can just be a print statement.
"print" <l:@L> <v:"<var>"> ";" => Statement::Print(Location::new(file_idx, l), v.to_string()), <ls: @L> "print" <name_start: @L> <v:"<var>"> <name_end: @L> ";" <le: @L> =>
Statement::Print(
Location::new(file_idx, ls..le),
Name::new(v, Location::new(file_idx, name_start..name_end)),
),
} }
// Expressions! Expressions are a little fiddly, because we're going to // Expressions! Expressions are a little fiddly, because we're going to
@@ -124,15 +135,27 @@ Expression: Expression = {
// we group addition and subtraction under the heading "additive" // we group addition and subtraction under the heading "additive"
AdditiveExpression: Expression = { AdditiveExpression: Expression = {
<e1:AdditiveExpression> <l:@L> "+" <e2:MultiplicativeExpression> => Expression::Primitive(Location::new(file_idx, l), "+".to_string(), vec![e1, e2]), <ls: @L> <e1:AdditiveExpression> <l: @L> "+" <e2:MultiplicativeExpression> <le: @L> =>
<e1:AdditiveExpression> <l:@L> "-" <e2:MultiplicativeExpression> => Expression::Primitive(Location::new(file_idx, l), "-".to_string(), vec![e1, e2]), Expression::Primitive(Location::new(file_idx, ls..le), "+".to_string(), vec![e1, e2]),
<ls: @L> <e1:AdditiveExpression> <l: @L> "-" <e2:MultiplicativeExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "-".to_string(), vec![e1, e2]),
MultiplicativeExpression, MultiplicativeExpression,
} }
// similarly, we group multiplication and division under "multiplicative" // similarly, we group multiplication and division under "multiplicative"
MultiplicativeExpression: Expression = { MultiplicativeExpression: Expression = {
<e1:MultiplicativeExpression> <l:@L> "*" <e2:AtomicExpression> => Expression::Primitive(Location::new(file_idx, l), "*".to_string(), vec![e1, e2]), <ls: @L> <e1:MultiplicativeExpression> <l: @L> "*" <e2:UnaryExpression> <le: @L> =>
<e1:MultiplicativeExpression> <l:@L> "/" <e2:AtomicExpression> => Expression::Primitive(Location::new(file_idx, l), "/".to_string(), vec![e1, e2]), Expression::Primitive(Location::new(file_idx, ls..le), "*".to_string(), vec![e1, e2]),
<ls: @L> <e1:MultiplicativeExpression> <l: @L> "/" <e2:UnaryExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "/".to_string(), vec![e1, e2]),
UnaryExpression,
}
UnaryExpression: Expression = {
<l: @L> "-" <e:UnaryExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, l..le), "-".to_string(), vec![e]),
<l: @L> "<" <v:"<var>"> ">" <e:UnaryExpression> <le: @L> =>
Expression::Cast(Location::new(file_idx, l..le), v.to_string(), Box::new(e)),
AtomicExpression, AtomicExpression,
} }
@@ -140,22 +163,9 @@ MultiplicativeExpression: Expression = {
// they cannot be further divided into parts // they cannot be further divided into parts
AtomicExpression: Expression = { AtomicExpression: Expression = {
// just a variable reference // just a variable reference
<l:@L> <v:"<var>"> => Expression::Reference(Location::new(file_idx, l), v.to_string()), <l: @L> <v:"<var>"> <end: @L> => Expression::Reference(Location::new(file_idx, l..end), v.to_string()),
// just a number // just a number
<l:@L> <n:"<num>"> => { <l: @L> <n:"<num>"> <end: @L> => Expression::Value(Location::new(file_idx, l..end), Value::Number(n.0, n.1, n.2)),
let val = Value::Number(n.0, n.1);
Expression::Value(Location::new(file_idx, l), val)
},
// a tricky case: also just a number, but using a negative sign. an
// alternative way to do this -- and we may do this eventually -- is
// to implement a unary negation expression. this has the odd effect
// that the user never actually writes down a negative number; they just
// write positive numbers which are immediately sent to a negation
// primitive!
<l:@L> "-" <n:"<num>"> => {
let val = Value::Number(n.0, -n.1);
Expression::Value(Location::new(file_idx, l), val)
},
// finally, let people parenthesize expressions and get back to a // finally, let people parenthesize expressions and get back to a
// lower precedence // lower precedence
"(" <e:Expression> ")" => e, "(" <e:Expression> ")" => e,

View File

@@ -1,6 +1,8 @@
use crate::syntax::ast::{Expression, Program, Statement, Value, BINARY_OPERATORS}; use crate::syntax::ast::{Expression, Program, Statement, Value};
use pretty::{DocAllocator, DocBuilder, Pretty}; use pretty::{DocAllocator, DocBuilder, Pretty};
use super::ConstantType;
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where where
A: 'a, A: 'a,
@@ -50,14 +52,14 @@ where
match self { match self {
Expression::Value(_, val) => val.pretty(allocator), Expression::Value(_, val) => val.pretty(allocator),
Expression::Reference(_, var) => allocator.text(var.to_string()), Expression::Reference(_, var) => allocator.text(var.to_string()),
Expression::Primitive(_, op, exprs) if BINARY_OPERATORS.contains(&op.as_ref()) => { Expression::Cast(_, t, e) => allocator
assert_eq!( .text(t.clone())
exprs.len(), .angles()
2, .append(e.pretty(allocator)),
"Found binary operator with {} components?", Expression::Primitive(_, op, exprs) if exprs.len() == 1 => allocator
exprs.len() .text(op.to_string())
); .append(exprs[0].pretty(allocator)),
Expression::Primitive(_, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator); let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator); let right = exprs[1].pretty(allocator);
@@ -84,15 +86,14 @@ where
{ {
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
match self { match self {
Value::Number(opt_base, value) => { Value::Number(opt_base, ty, value) => {
let sign = if *value < 0 { "-" } else { "" };
let value_str = match opt_base { let value_str = match opt_base {
None => format!("{}", value), None => format!("{}{}", value, type_suffix(ty)),
Some(2) => format!("{}0b{:b}", sign, value.abs()), Some(2) => format!("0b{:b}{}", value, type_suffix(ty)),
Some(8) => format!("{}0o{:o}", sign, value.abs()), Some(8) => format!("0o{:o}{}", value, type_suffix(ty)),
Some(10) => format!("{}0d{}", sign, value.abs()), Some(10) => format!("0d{}{}", value, type_suffix(ty)),
Some(16) => format!("{}0x{:x}", sign, value.abs()), Some(16) => format!("0x{:x}{}", value, type_suffix(ty)),
Some(_) => format!("!!{}{:x}!!", sign, value.abs()), Some(_) => format!("!!{:x}{}!!", value, type_suffix(ty)),
}; };
allocator.text(value_str) allocator.text(value_str)
@@ -101,6 +102,20 @@ where
} }
} }
fn type_suffix(x: &Option<ConstantType>) -> &'static str {
match x {
None => "",
Some(ConstantType::I8) => "i8",
Some(ConstantType::I16) => "i16",
Some(ConstantType::I32) => "i32",
Some(ConstantType::I64) => "i64",
Some(ConstantType::U8) => "u8",
Some(ConstantType::U16) => "u16",
Some(ConstantType::U32) => "u32",
Some(ConstantType::U64) => "u64",
}
}
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
struct CommaSep {} struct CommaSep {}

View File

@@ -40,6 +40,12 @@ pub enum Token {
#[token(")")] #[token(")")]
RightParen, RightParen,
#[token("<")]
LessThan,
#[token(">")]
GreaterThan,
// Next we take of any reserved words; I always like to put // Next we take of any reserved words; I always like to put
// these before we start recognizing more complicated regular // these before we start recognizing more complicated regular
// expressions. I don't think it matters, but it works for me. // expressions. I don't think it matters, but it works for me.
@@ -53,13 +59,14 @@ pub enum Token {
/// Numbers capture both the value we read from the input, /// Numbers capture both the value we read from the input,
/// converted to an `i64`, as well as the base the user used /// converted to an `i64`, as well as the base the user used
/// to write the number, if they did so. /// to write the number and/or the type the user specified,
#[regex(r"0b[01]+", |v| parse_number(Some(2), v))] /// if they did either.
#[regex(r"0o[0-7]+", |v| parse_number(Some(8), v))] #[regex(r"0b[01]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(2), v))]
#[regex(r"0d[0-9]+", |v| parse_number(Some(10), v))] #[regex(r"0o[0-7]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(8), v))]
#[regex(r"0x[0-9a-fA-F]+", |v| parse_number(Some(16), v))] #[regex(r"0d[0-9]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(10), v))]
#[regex(r"[0-9]+", |v| parse_number(None, v))] #[regex(r"0x[0-9a-fA-F]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(16), v))]
Number((Option<u8>, i64)), #[regex(r"[0-9]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(None, v))]
Number((Option<u8>, Option<ConstantType>, u64)),
// Variables; this is a very standard, simple set of characters // Variables; this is a very standard, simple set of characters
// for variables, but feel free to experiment with more complicated // for variables, but feel free to experiment with more complicated
@@ -88,15 +95,29 @@ impl fmt::Display for Token {
Token::Semi => write!(f, "';'"), Token::Semi => write!(f, "';'"),
Token::LeftParen => write!(f, "'('"), Token::LeftParen => write!(f, "'('"),
Token::RightParen => write!(f, "')'"), Token::RightParen => write!(f, "')'"),
Token::LessThan => write!(f, "<"),
Token::GreaterThan => write!(f, ">"),
Token::Print => write!(f, "'print'"), Token::Print => write!(f, "'print'"),
Token::Operator(c) => write!(f, "'{}'", c), Token::Operator(c) => write!(f, "'{}'", c),
Token::Number((None, v)) => write!(f, "'{}'", v), Token::Number((None, otype, v)) => write!(f, "'{}{}'", v, display_optional_type(otype)),
Token::Number((Some(2), v)) => write!(f, "'0b{:b}'", v), Token::Number((Some(2), otype, v)) => {
Token::Number((Some(8), v)) => write!(f, "'0o{:o}'", v), write!(f, "'0b{:b}{}'", v, display_optional_type(otype))
Token::Number((Some(10), v)) => write!(f, "'{}'", v), }
Token::Number((Some(16), v)) => write!(f, "'0x{:x}'", v), Token::Number((Some(8), otype, v)) => {
Token::Number((Some(b), v)) => { write!(f, "'0o{:o}{}'", v, display_optional_type(otype))
write!(f, "Invalidly-based-number<base={},val={}>", b, v) }
Token::Number((Some(10), otype, v)) => {
write!(f, "'{}{}'", v, display_optional_type(otype))
}
Token::Number((Some(16), otype, v)) => {
write!(f, "'0x{:x}{}'", v, display_optional_type(otype))
}
Token::Number((Some(b), opt_type, v)) => {
write!(
f,
"Invalidly-based-number<base={},val={},opt_type={:?}>",
b, v, opt_type
)
} }
Token::Variable(s) => write!(f, "'{}'", s), Token::Variable(s) => write!(f, "'{}'", s),
Token::Error => write!(f, "<error>"), Token::Error => write!(f, "<error>"),
@@ -122,6 +143,125 @@ impl Token {
} }
} }
#[repr(i64)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConstantType {
U8 = 10,
U16 = 11,
U32 = 12,
U64 = 13,
I8 = 20,
I16 = 21,
I32 = 22,
I64 = 23,
}
impl From<ConstantType> for cranelift_codegen::ir::Type {
fn from(value: ConstantType) -> Self {
match value {
ConstantType::I8 | ConstantType::U8 => cranelift_codegen::ir::types::I8,
ConstantType::I16 | ConstantType::U16 => cranelift_codegen::ir::types::I16,
ConstantType::I32 | ConstantType::U32 => cranelift_codegen::ir::types::I32,
ConstantType::I64 | ConstantType::U64 => cranelift_codegen::ir::types::I64,
}
}
}
impl ConstantType {
/// Returns true if the given type is (a) numeric and (b) signed;
pub fn is_signed(&self) -> bool {
matches!(
self,
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64
)
}
/// Return the set of types that can be safely casted into this type.
pub fn safe_casts_to(self) -> Vec<ConstantType> {
match self {
ConstantType::I8 => vec![ConstantType::I8],
ConstantType::I16 => vec![ConstantType::I16, ConstantType::I8, ConstantType::U8],
ConstantType::I32 => vec![
ConstantType::I32,
ConstantType::I16,
ConstantType::I8,
ConstantType::U16,
ConstantType::U8,
],
ConstantType::I64 => vec![
ConstantType::I64,
ConstantType::I32,
ConstantType::I16,
ConstantType::I8,
ConstantType::U32,
ConstantType::U16,
ConstantType::U8,
],
ConstantType::U8 => vec![ConstantType::U8],
ConstantType::U16 => vec![ConstantType::U16, ConstantType::U8],
ConstantType::U32 => vec![ConstantType::U32, ConstantType::U16, ConstantType::U8],
ConstantType::U64 => vec![
ConstantType::U64,
ConstantType::U32,
ConstantType::U16,
ConstantType::U8,
],
}
}
/// Return the set of all currently-available constant types
pub fn all_types() -> Vec<Self> {
vec![
ConstantType::U8,
ConstantType::U16,
ConstantType::U32,
ConstantType::U64,
ConstantType::I8,
ConstantType::I16,
ConstantType::I32,
ConstantType::I64,
]
}
/// Return the name of the given type, as a string
pub fn name(&self) -> String {
match self {
ConstantType::I8 => "i8".to_string(),
ConstantType::I16 => "i16".to_string(),
ConstantType::I32 => "i32".to_string(),
ConstantType::I64 => "i64".to_string(),
ConstantType::U8 => "u8".to_string(),
ConstantType::U16 => "u16".to_string(),
ConstantType::U32 => "u32".to_string(),
ConstantType::U64 => "u64".to_string(),
}
}
}
#[derive(Debug, Error, PartialEq)]
pub enum InvalidConstantType {
#[error("Unrecognized constant {0} for constant type")]
Value(i64),
}
impl TryFrom<i64> for ConstantType {
type Error = InvalidConstantType;
fn try_from(value: i64) -> Result<Self, Self::Error> {
match value {
10 => Ok(ConstantType::U8),
11 => Ok(ConstantType::U16),
12 => Ok(ConstantType::U32),
13 => Ok(ConstantType::U64),
20 => Ok(ConstantType::I8),
21 => Ok(ConstantType::I16),
22 => Ok(ConstantType::I32),
23 => Ok(ConstantType::I64),
_ => Err(InvalidConstantType::Value(value)),
}
}
}
/// Parse a number in the given base, return a pair of the base and the /// Parse a number in the given base, return a pair of the base and the
/// parsed number. This is just a helper used for all of the number /// parsed number. This is just a helper used for all of the number
/// regular expression cases, which kicks off to the obvious Rust /// regular expression cases, which kicks off to the obvious Rust
@@ -129,24 +269,66 @@ impl Token {
fn parse_number( fn parse_number(
base: Option<u8>, base: Option<u8>,
value: &Lexer<Token>, value: &Lexer<Token>,
) -> Result<(Option<u8>, i64), ParseIntError> { ) -> Result<(Option<u8>, Option<ConstantType>, u64), ParseIntError> {
let (radix, strval) = match base { let (radix, strval) = match base {
None => (10, value.slice()), None => (10, value.slice()),
Some(radix) => (radix, &value.slice()[2..]), Some(radix) => (radix, &value.slice()[2..]),
}; };
let intval = i64::from_str_radix(strval, radix as u32)?; let (declared_type, strval) = if let Some(strval) = strval.strip_suffix("u8") {
Ok((base, intval)) (Some(ConstantType::U8), strval)
} else if let Some(strval) = strval.strip_suffix("u16") {
(Some(ConstantType::U16), strval)
} else if let Some(strval) = strval.strip_suffix("u32") {
(Some(ConstantType::U32), strval)
} else if let Some(strval) = strval.strip_suffix("u64") {
(Some(ConstantType::U64), strval)
} else if let Some(strval) = strval.strip_suffix("i8") {
(Some(ConstantType::I8), strval)
} else if let Some(strval) = strval.strip_suffix("i16") {
(Some(ConstantType::I16), strval)
} else if let Some(strval) = strval.strip_suffix("i32") {
(Some(ConstantType::I32), strval)
} else if let Some(strval) = strval.strip_suffix("i64") {
(Some(ConstantType::I64), strval)
} else {
(None, strval)
};
let intval = u64::from_str_radix(strval, radix as u32)?;
Ok((base, declared_type, intval))
}
fn display_optional_type(otype: &Option<ConstantType>) -> &'static str {
match otype {
None => "",
Some(ConstantType::I8) => "i8",
Some(ConstantType::I16) => "i16",
Some(ConstantType::I32) => "i32",
Some(ConstantType::I64) => "i64",
Some(ConstantType::U8) => "u8",
Some(ConstantType::U16) => "u16",
Some(ConstantType::U32) => "u32",
Some(ConstantType::U64) => "u64",
}
} }
#[test] #[test]
fn lex_numbers() { fn lex_numbers() {
let mut lex0 = Token::lexer("12 0b1100 0o14 0d12 0xc // 9"); let mut lex0 = Token::lexer("12 0b1100 0o14 0d12 0xc 12u8 0xci64// 9");
assert_eq!(lex0.next(), Some(Token::Number((None, 12)))); assert_eq!(lex0.next(), Some(Token::Number((None, None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(2), 12)))); assert_eq!(lex0.next(), Some(Token::Number((Some(2), None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(8), 12)))); assert_eq!(lex0.next(), Some(Token::Number((Some(8), None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(10), 12)))); assert_eq!(lex0.next(), Some(Token::Number((Some(10), None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(16), 12)))); assert_eq!(lex0.next(), Some(Token::Number((Some(16), None, 12))));
assert_eq!(
lex0.next(),
Some(Token::Number((None, Some(ConstantType::U8), 12)))
);
assert_eq!(
lex0.next(),
Some(Token::Number((Some(16), Some(ConstantType::I64), 12)))
);
assert_eq!(lex0.next(), None); assert_eq!(lex0.next(), None);
} }
@@ -168,6 +350,31 @@ fn lexer_spans() {
assert_eq!(lex0.next(), Some((Token::Equals, 2..3))); assert_eq!(lex0.next(), Some((Token::Equals, 2..3)));
assert_eq!(lex0.next(), Some((Token::var("x"), 4..5))); assert_eq!(lex0.next(), Some((Token::var("x"), 4..5)));
assert_eq!(lex0.next(), Some((Token::Operator('+'), 6..7))); assert_eq!(lex0.next(), Some((Token::Operator('+'), 6..7)));
assert_eq!(lex0.next(), Some((Token::Number((None, 1)), 8..9))); assert_eq!(lex0.next(), Some((Token::Number((None, None, 1)), 8..9)));
assert_eq!(lex0.next(), None); assert_eq!(lex0.next(), None);
} }
#[test]
fn further_spans() {
let mut lex0 = Token::lexer("x = 2i64 + 2i64;\ny = -x;\nprint y;").spanned();
assert_eq!(lex0.next(), Some((Token::var("x"), 0..1)));
assert_eq!(lex0.next(), Some((Token::Equals, 2..3)));
assert_eq!(
lex0.next(),
Some((Token::Number((None, Some(ConstantType::I64), 2)), 4..8))
);
assert_eq!(lex0.next(), Some((Token::Operator('+'), 9..10)));
assert_eq!(
lex0.next(),
Some((Token::Number((None, Some(ConstantType::I64), 2)), 11..15))
);
assert_eq!(lex0.next(), Some((Token::Semi, 15..16)));
assert_eq!(lex0.next(), Some((Token::var("y"), 17..18)));
assert_eq!(lex0.next(), Some((Token::Equals, 19..20)));
assert_eq!(lex0.next(), Some((Token::Operator('-'), 21..22)));
assert_eq!(lex0.next(), Some((Token::var("x"), 22..23)));
assert_eq!(lex0.next(), Some((Token::Semi, 23..24)));
assert_eq!(lex0.next(), Some((Token::Print, 25..30)));
assert_eq!(lex0.next(), Some((Token::var("y"), 31..32)));
assert_eq!(lex0.next(), Some((Token::Semi, 32..33)));
}

View File

@@ -1,6 +1,9 @@
use crate::syntax::{Expression, Location, Program, Statement}; use crate::{
eval::PrimitiveType,
syntax::{Expression, Location, Program, Statement},
};
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use std::collections::HashMap; use std::{collections::HashMap, str::FromStr};
/// An error we found while validating the input program. /// An error we found while validating the input program.
/// ///
@@ -11,6 +14,7 @@ use std::collections::HashMap;
/// and using [`codespan_reporting`] to present them to the user. /// and using [`codespan_reporting`] to present them to the user.
pub enum Error { pub enum Error {
UnboundVariable(Location, String), UnboundVariable(Location, String),
UnknownType(Location, String),
} }
impl From<Error> for Diagnostic<usize> { impl From<Error> for Diagnostic<usize> {
@@ -19,6 +23,10 @@ impl From<Error> for Diagnostic<usize> {
Error::UnboundVariable(location, name) => location Error::UnboundVariable(location, name) => location
.labelled_error("unbound here") .labelled_error("unbound here")
.with_message(format!("Unbound variable '{}'", name)), .with_message(format!("Unbound variable '{}'", name)),
Error::UnknownType(location, name) => location
.labelled_error("type referenced here")
.with_message(format!("Unknown type '{}'", name)),
} }
} }
} }
@@ -57,12 +65,24 @@ impl Program {
/// example, and generates warnings for things that are inadvisable but not /// example, and generates warnings for things that are inadvisable but not
/// actually a problem. /// actually a problem.
pub fn validate(&self) -> (Vec<Error>, Vec<Warning>) { pub fn validate(&self) -> (Vec<Error>, Vec<Warning>) {
let mut bound_variables = HashMap::new();
self.validate_with_bindings(&mut bound_variables)
}
/// Validate that the program makes semantic sense, not just syntactic sense.
///
/// This checks for things like references to variables that don't exist, for
/// example, and generates warnings for things that are inadvisable but not
/// actually a problem.
pub fn validate_with_bindings(
&self,
bound_variables: &mut HashMap<String, Location>,
) -> (Vec<Error>, Vec<Warning>) {
let mut errors = vec![]; let mut errors = vec![];
let mut warnings = vec![]; let mut warnings = vec![];
let mut bound_variables = HashMap::new();
for stmt in self.statements.iter() { for stmt in self.statements.iter() {
let (mut new_errors, mut new_warnings) = stmt.validate(&mut bound_variables); let (mut new_errors, mut new_warnings) = stmt.validate(bound_variables);
errors.append(&mut new_errors); errors.append(&mut new_errors);
warnings.append(&mut new_warnings); warnings.append(&mut new_warnings);
} }
@@ -81,7 +101,7 @@ impl Statement {
/// occurs. We use a `HashMap` to map these bound locations to the locations /// occurs. We use a `HashMap` to map these bound locations to the locations
/// where their bound, because these locations are handy when generating errors /// where their bound, because these locations are handy when generating errors
/// and warnings. /// and warnings.
pub fn validate( fn validate(
&self, &self,
bound_variables: &mut HashMap<String, Location>, bound_variables: &mut HashMap<String, Location>,
) -> (Vec<Error>, Vec<Warning>) { ) -> (Vec<Error>, Vec<Warning>) {
@@ -97,20 +117,20 @@ impl Statement {
errors.append(&mut exp_errors); errors.append(&mut exp_errors);
warnings.append(&mut exp_warnings); warnings.append(&mut exp_warnings);
if let Some(original_binding_site) = bound_variables.get(var) { if let Some(original_binding_site) = bound_variables.get(&var.name) {
warnings.push(Warning::ShadowedVariable( warnings.push(Warning::ShadowedVariable(
original_binding_site.clone(), original_binding_site.clone(),
loc.clone(), loc.clone(),
var.clone(), var.to_string(),
)); ));
} else { } else {
bound_variables.insert(var.clone(), loc.clone()); bound_variables.insert(var.to_string(), loc.clone());
} }
} }
Statement::Print(_, var) if bound_variables.contains_key(var) => {} Statement::Print(_, var) if bound_variables.contains_key(&var.name) => {}
Statement::Print(loc, var) => { Statement::Print(loc, var) => {
errors.push(Error::UnboundVariable(loc.clone(), var.clone())) errors.push(Error::UnboundVariable(loc.clone(), var.to_string()))
} }
} }
@@ -127,6 +147,15 @@ impl Expression {
vec![Error::UnboundVariable(loc.clone(), var.clone())], vec![Error::UnboundVariable(loc.clone(), var.clone())],
vec![], vec![],
), ),
Expression::Cast(location, t, expr) => {
let (mut errs, warns) = expr.validate(variable_map);
if PrimitiveType::from_str(t).is_err() {
errs.push(Error::UnknownType(location.clone(), t.clone()))
}
(errs, warns)
}
Expression::Primitive(_, _, args) => { Expression::Primitive(_, _, args) => {
let mut errors = vec![]; let mut errors = vec![];
let mut warnings = vec![]; let mut warnings = vec![];
@@ -142,3 +171,19 @@ impl Expression {
} }
} }
} }
#[test]
fn cast_checks_are_reasonable() {
let good_stmt = Statement::parse(0, "x = <u16>4u8;").expect("valid test case");
let (good_errs, good_warns) = good_stmt.validate(&mut HashMap::new());
assert!(good_errs.is_empty());
assert!(good_warns.is_empty());
let bad_stmt = Statement::parse(0, "x = <apple>4u8;").expect("valid test case");
let (bad_errs, bad_warns) = bad_stmt.validate(&mut HashMap::new());
assert!(bad_warns.is_empty());
assert_eq!(bad_errs.len(), 1);
assert!(matches!(bad_errs[0], Error::UnknownType(_, ref x) if x == "apple"));
}

52
src/type_infer.rs Normal file
View File

@@ -0,0 +1,52 @@
//! A type inference pass for NGR
//!
//! The type checker implemented here is a relatively straightforward one, designed to be
//! fairly easy to understand rather than super fast. So don't be expecting the fastest
//! type checker in the West, here.
//!
//! The actual type checker operates in three phases. In the first phase, we translate
//! the syntax AST into something that's close to the final IR. During the process, we
//! generate a list of type constraints to solve. In the second phase, we try to solve
//! all the constraints we've generated. If that's successful, in the final phase, we
//! do the final conversion to the IR AST, filling in any type information we've learned
//! along the way.
mod ast;
mod convert;
mod finalize;
mod solve;
use self::convert::convert_program;
use self::finalize::finalize_program;
use self::solve::solve_constraints;
pub use self::solve::{TypeInferenceError, TypeInferenceResult, TypeInferenceWarning};
use crate::ir::ast as ir;
use crate::syntax;
#[cfg(test)]
use crate::syntax::arbitrary::GenerationEnvironment;
#[cfg(test)]
use proptest::prelude::Arbitrary;
impl syntax::Program {
/// Infer the types for the syntactic AST, returning either a type-checked program in
/// the IR, or a series of type errors encountered during inference.
///
/// You really should have made sure that this program was validated before running
/// this method, otherwise you may experience panics during operation.
pub fn type_infer(self) -> TypeInferenceResult<ir::Program> {
let mut constraint_db = vec![];
let program = convert_program(self, &mut constraint_db);
let inference_result = solve_constraints(constraint_db);
inference_result.map(|resolutions| finalize_program(program, &resolutions))
}
}
proptest::proptest! {
#[test]
fn translation_maintains_semantics(input in syntax::Program::arbitrary_with(GenerationEnvironment::new(false))) {
let syntax_result = input.eval();
let ir = input.type_infer().expect("arbitrary should generate type-safe programs");
let ir_result = ir.eval();
proptest::prop_assert_eq!(syntax_result, ir_result);
}
}

336
src/type_infer/ast.rs Normal file
View File

@@ -0,0 +1,336 @@
pub use crate::ir::ast::Primitive;
/// This is largely a copy of `ir/ast`, with a couple of extensions that we're going
/// to want to use while we're doing type inference, but don't want to keep around
/// afterwards. These are:
///
/// * A notion of a type variable
/// * An unknown numeric constant form
///
use crate::{
eval::PrimitiveType,
syntax::{self, ConstantType, Location},
};
use internment::ArcIntern;
use pretty::{DocAllocator, Pretty};
use std::fmt;
use std::sync::atomic::AtomicUsize;
/// We're going to represent variables as interned strings.
///
/// These should be fast enough for comparison that it's OK, since it's going to end up
/// being pretty much the pointer to the string.
type Variable = ArcIntern<String>;
/// The representation of a program within our IR. For now, this is exactly one file.
///
/// In addition, for the moment there's not really much of interest to hold here besides
/// the list of statements read from the file. Order is important. In the future, you
/// could imagine caching analysis information in this structure.
///
/// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used
/// to print the structure whenever possible, especially if you value your or your
/// user's time. The latter is useful for testing that conversions of `Program` retain
/// their meaning. All `Program`s generated through [`Arbitrary`] are guaranteed to be
/// syntactically valid, although they may contain runtime issue like over- or underflow.
#[derive(Debug)]
pub struct Program {
// For now, a program is just a vector of statements. In the future, we'll probably
// extend this to include a bunch of other information, but for now: just a list.
pub(crate) statements: Vec<Statement>,
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
let mut result = allocator.nil();
for stmt in self.statements.iter() {
// there's probably a better way to do this, rather than constantly
// adding to the end, but this works.
result = result
.append(stmt.pretty(allocator))
.append(allocator.text(";"))
.append(allocator.hardline());
}
result
}
}
/// The representation of a statement in the language.
///
/// For now, this is either a binding site (`x = 4`) or a print statement
/// (`print x`). Someday, though, more!
///
/// As with `Program`, this type implements [`Pretty`], which should
/// be used to display the structure whenever possible. It does not
/// implement [`Arbitrary`], though, mostly because it's slightly
/// complicated to do so.
///
#[derive(Debug)]
pub enum Statement {
Binding(Location, Variable, Type, Expression),
Print(Location, Type, Variable),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Statement::Binding(_, var, _, expr) => allocator
.text(var.as_ref().to_string())
.append(allocator.space())
.append(allocator.text("="))
.append(allocator.space())
.append(expr.pretty(allocator)),
Statement::Print(_, _, var) => allocator
.text("print")
.append(allocator.space())
.append(allocator.text(var.as_ref().to_string())),
}
}
}
/// The representation of an expression.
///
/// Note that expressions, like everything else in this syntax tree,
/// supports [`Pretty`], and it's strongly encouraged that you use
/// that trait/module when printing these structures.
///
/// Also, Expressions at this point in the compiler are explicitly
/// defined so that they are *not* recursive. By this point, if an
/// expression requires some other data (like, for example, invoking
/// a primitive), any subexpressions have been bound to variables so
/// that the referenced data will always either be a constant or a
/// variable reference.
#[derive(Debug, PartialEq)]
pub enum Expression {
Atomic(ValueOrRef),
Cast(Location, Type, ValueOrRef),
Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}
impl Expression {
/// Return a reference to the type of the expression, as inferred or recently
/// computed.
pub fn type_of(&self) -> &Type {
match self {
Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t,
Expression::Atomic(ValueOrRef::Value(_, t, _)) => t,
Expression::Cast(_, t, _) => t,
Expression::Primitive(_, t, _, _) => t,
}
}
/// Return a reference to the location associated with the expression.
pub fn location(&self) -> &Location {
match self {
Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l,
}
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Expression::Atomic(x) => x.pretty(allocator),
Expression::Cast(_, t, e) => allocator
.text("<")
.append(t.pretty(allocator))
.append(allocator.text(">"))
.append(e.pretty(allocator)),
Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator))
}
Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator);
left.append(allocator.space())
.append(op.pretty(allocator))
.append(allocator.space())
.append(right)
.parens()
}
Expression::Primitive(_, _, op, exprs) => {
allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
}
}
}
}
/// An expression that is always either a value or a reference.
///
/// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference.
#[derive(Clone, Debug, PartialEq)]
pub enum ValueOrRef {
Value(Location, Type, Value),
Ref(Location, Type, ArcIntern<String>),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
ValueOrRef::Value(_, _, v) => v.pretty(allocator),
ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()),
}
}
}
impl From<ValueOrRef> for Expression {
fn from(value: ValueOrRef) -> Self {
Expression::Atomic(value)
}
}
/// A constant in the IR.
///
/// The optional argument in numeric types is the base that was used by the
/// user to input the number. By retaining it, we can ensure that if we need
/// to print the number back out, we can do so in the form that the user
/// entered it.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
Unknown(Option<u8>, u64),
I8(Option<u8>, i8),
I16(Option<u8>, i16),
I32(Option<u8>, i32),
I64(Option<u8>, i64),
U8(Option<u8>, u8),
U16(Option<u8>, u16),
U32(Option<u8>, u32),
U64(Option<u8>, u64),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
let pretty_internal = |opt_base: &Option<u8>, x, t| {
syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
};
let pretty_internal_signed = |opt_base, x: i64, t| {
let base = pretty_internal(opt_base, x.unsigned_abs(), t);
allocator.text("-").append(base)
};
match self {
Value::Unknown(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::U64)
}
Value::I8(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
}
Value::I16(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
}
Value::I32(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
}
Value::I64(opt_base, value) => {
pretty_internal_signed(opt_base, *value, ConstantType::I64)
}
Value::U8(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U8)
}
Value::U16(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U16)
}
Value::U32(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U32)
}
Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type {
Variable(Location, ArcIntern<String>),
Primitive(PrimitiveType),
}
impl Type {
pub fn is_concrete(&self) -> bool {
!matches!(self, Type::Variable(_, _))
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Type::Variable(_, x) => allocator.text(x.to_string()),
Type::Primitive(pt) => allocator.text(format!("{}", pt)),
}
}
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Variable(_, x) => write!(f, "{}", x),
Type::Primitive(pt) => pt.fmt(f),
}
}
}
/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
pub fn gensym(name: &str) -> ArcIntern<String> {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = format!(
"<{}:{}>",
name,
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
);
ArcIntern::new(new_name)
}
/// Generate a fresh new type; this will be a unique new type variable.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
pub fn gentype() -> Type {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let name = ArcIntern::new(format!(
"t<{}>",
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
));
Type::Variable(Location::manufactured(), name)
}

378
src/type_infer/convert.rs Normal file
View File

@@ -0,0 +1,378 @@
use super::ast as ir;
use super::ast::Type;
use crate::eval::PrimitiveType;
use crate::syntax::{self, ConstantType};
use crate::type_infer::solve::Constraint;
use internment::ArcIntern;
use std::collections::HashMap;
use std::str::FromStr;
/// This function takes a syntactic program and converts it into the IR version of the
/// program, with appropriate type variables introduced and their constraints added to
/// the given database.
///
/// If the input function has been validated (which it should be), then this should run
/// into no error conditions. However, if you failed to validate the input, then this
/// function can panic.
pub fn convert_program(
mut program: syntax::Program,
constraint_db: &mut Vec<Constraint>,
) -> ir::Program {
let mut statements = Vec::new();
let mut renames = HashMap::new();
let mut bindings = HashMap::new();
for stmt in program.statements.drain(..) {
statements.append(&mut convert_statement(
stmt,
constraint_db,
&mut renames,
&mut bindings,
));
}
ir::Program { statements }
}
/// This function takes a syntactic statements and converts it into a series of
/// IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_statement(
statement: syntax::Statement,
constraint_db: &mut Vec<Constraint>,
renames: &mut HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> Vec<ir::Statement> {
match statement {
syntax::Statement::Print(loc, name) => {
let iname = ArcIntern::new(name.to_string());
let final_name = renames
.get(&iname)
.map(Clone::clone)
.unwrap_or_else(|| iname.clone());
let varty = bindings
.get(&final_name)
.expect("print variable defined before use")
.clone();
constraint_db.push(Constraint::Printable(loc.clone(), varty.clone()));
vec![ir::Statement::Print(loc, varty, iname)]
}
syntax::Statement::Binding(loc, name, expr) => {
let (mut prereqs, expr, ty) =
convert_expression(expr, constraint_db, renames, bindings);
let iname = ArcIntern::new(name.to_string());
let final_name = if bindings.contains_key(&iname) {
let new_name = ir::gensym(iname.as_str());
renames.insert(iname, new_name.clone());
new_name
} else {
iname
};
bindings.insert(final_name.clone(), ty.clone());
prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr));
prereqs
}
}
}
/// This function takes a syntactic expression and converts it into a series
/// of IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_expression(
expression: syntax::Expression,
constraint_db: &mut Vec<Constraint>,
renames: &HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> (Vec<ir::Statement>, ir::Expression, Type) {
match expression {
syntax::Expression::Value(loc, val) => match val {
syntax::Value::Number(base, mctype, value) => {
let (newval, newtype) = match mctype {
None => {
let newtype = ir::gentype();
let newval = ir::Value::Unknown(base, value);
constraint_db.push(Constraint::ConstantNumericType(
loc.clone(),
newtype.clone(),
));
(newval, newtype)
}
Some(ConstantType::U8) => (
ir::Value::U8(base, value as u8),
ir::Type::Primitive(PrimitiveType::U8),
),
Some(ConstantType::U16) => (
ir::Value::U16(base, value as u16),
ir::Type::Primitive(PrimitiveType::U16),
),
Some(ConstantType::U32) => (
ir::Value::U32(base, value as u32),
ir::Type::Primitive(PrimitiveType::U32),
),
Some(ConstantType::U64) => (
ir::Value::U64(base, value),
ir::Type::Primitive(PrimitiveType::U64),
),
Some(ConstantType::I8) => (
ir::Value::I8(base, value as i8),
ir::Type::Primitive(PrimitiveType::I8),
),
Some(ConstantType::I16) => (
ir::Value::I16(base, value as i16),
ir::Type::Primitive(PrimitiveType::I16),
),
Some(ConstantType::I32) => (
ir::Value::I32(base, value as i32),
ir::Type::Primitive(PrimitiveType::I32),
),
Some(ConstantType::I64) => (
ir::Value::I64(base, value as i64),
ir::Type::Primitive(PrimitiveType::I64),
),
};
constraint_db.push(Constraint::FitsInNumType(
loc.clone(),
newtype.clone(),
value,
));
(
vec![],
ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)),
newtype,
)
}
},
syntax::Expression::Reference(loc, name) => {
let iname = ArcIntern::new(name);
let final_name = renames.get(&iname).cloned().unwrap_or(iname);
let rtype = bindings
.get(&final_name)
.cloned()
.expect("variable bound before use");
let refexp =
ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name));
(vec![], refexp, rtype)
}
syntax::Expression::Cast(loc, target, expr) => {
let (mut stmts, nexpr, etype) =
convert_expression(*expr, constraint_db, renames, bindings);
let val_or_ref = simplify_expr(nexpr, &mut stmts);
let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast");
let target_type = Type::Primitive(target_prim_type);
let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref);
constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone()));
(stmts, res, target_type)
}
syntax::Expression::Primitive(loc, fun, mut args) => {
let primop = ir::Primitive::from_str(&fun).expect("valid primitive");
let mut stmts = vec![];
let mut nargs = vec![];
let mut atypes = vec![];
let ret_type = ir::gentype();
for arg in args.drain(..) {
let (mut astmts, aexp, atype) =
convert_expression(arg, constraint_db, renames, bindings);
stmts.append(&mut astmts);
nargs.push(simplify_expr(aexp, &mut stmts));
atypes.push(atype);
}
constraint_db.push(Constraint::ProperPrimitiveArgs(
loc.clone(),
primop,
atypes.clone(),
ret_type.clone(),
));
(
stmts,
ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs),
ret_type,
)
}
}
}
fn simplify_expr(expr: ir::Expression, stmts: &mut Vec<ir::Statement>) -> ir::ValueOrRef {
match expr {
ir::Expression::Atomic(v_or_ref) => v_or_ref,
expr => {
let etype = expr.type_of().clone();
let loc = expr.location().clone();
let nname = ir::gensym("g");
let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr);
stmts.push(nbinding);
ir::ValueOrRef::Ref(loc, etype, nname)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::syntax::Location;
fn one() -> syntax::Expression {
syntax::Expression::Value(
Location::manufactured(),
syntax::Value::Number(None, None, 1),
)
}
fn vec_contains<T, F: Fn(&T) -> bool>(x: &[T], f: F) -> bool {
for x in x.iter() {
if f(x) {
return true;
}
}
false
}
fn infer_expression(
x: syntax::Expression,
) -> (ir::Expression, Vec<ir::Statement>, Vec<Constraint>, Type) {
let mut constraints = Vec::new();
let renames = HashMap::new();
let mut bindings = HashMap::new();
let (stmts, expr, ty) = convert_expression(x, &mut constraints, &renames, &mut bindings);
(expr, stmts, constraints, ty)
}
fn infer_statement(x: syntax::Statement) -> (Vec<ir::Statement>, Vec<Constraint>) {
let mut constraints = Vec::new();
let mut renames = HashMap::new();
let mut bindings = HashMap::new();
let res = convert_statement(x, &mut constraints, &mut renames, &mut bindings);
(res, constraints)
}
#[test]
fn constant_one() {
let (expr, stmts, constraints, ty) = infer_expression(one());
assert!(stmts.is_empty());
assert!(matches!(
expr,
ir::Expression::Atomic(ir::ValueOrRef::Value(_, _, ir::Value::Unknown(None, 1)))
));
assert!(vec_contains(&constraints, |x| matches!(
x,
Constraint::FitsInNumType(_, _, 1)
)));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == &ty)
));
}
#[test]
fn one_plus_one() {
let opo = syntax::Expression::Primitive(
Location::manufactured(),
"+".to_string(),
vec![one(), one()],
);
let (expr, stmts, constraints, ty) = infer_expression(opo);
assert!(stmts.is_empty());
assert!(matches!(expr, ir::Expression::Primitive(_, t, ir::Primitive::Plus, _) if t == ty));
assert!(vec_contains(&constraints, |x| matches!(
x,
Constraint::FitsInNumType(_, _, 1)
)));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t != &ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ProperPrimitiveArgs(_, ir::Primitive::Plus, args, ret) if args.len() == 2 && ret == &ty)
));
}
#[test]
fn one_plus_one_plus_one() {
let stmt = syntax::Statement::parse(1, "x = 1 + 1 + 1;").expect("basic parse");
let (stmts, constraints) = infer_statement(stmt);
assert_eq!(stmts.len(), 2);
let ir::Statement::Binding(
_args,
name1,
temp_ty1,
ir::Expression::Primitive(_, primty1, ir::Primitive::Plus, primargs1),
) = stmts.get(0).expect("item two")
else {
panic!("Failed to match first statement");
};
let ir::Statement::Binding(
_args,
name2,
temp_ty2,
ir::Expression::Primitive(_, primty2, ir::Primitive::Plus, primargs2),
) = stmts.get(1).expect("item two")
else {
panic!("Failed to match second statement");
};
let &[ir::ValueOrRef::Value(_, ref left1ty, _), ir::ValueOrRef::Value(_, ref right1ty, _)] =
&primargs1[..]
else {
panic!("Failed to match first arguments");
};
let &[ir::ValueOrRef::Ref(_, _, ref left2name), ir::ValueOrRef::Value(_, ref right2ty, _)] =
&primargs2[..]
else {
panic!("Failed to match first arguments");
};
assert_ne!(name1, name2);
assert_ne!(temp_ty1, temp_ty2);
assert_ne!(primty1, primty2);
assert_eq!(name1, left2name);
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == left1ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right1ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right2ty)
));
for (i, s) in stmts.iter().enumerate() {
println!("{}: {:?}", i, s);
}
for (i, c) in constraints.iter().enumerate() {
println!("{}: {:?}", i, c);
}
}
}

187
src/type_infer/finalize.rs Normal file
View File

@@ -0,0 +1,187 @@
use super::{ast as input, solve::TypeResolutions};
use crate::{eval::PrimitiveType, ir as output};
pub fn finalize_program(
mut program: input::Program,
resolutions: &TypeResolutions,
) -> output::Program {
output::Program {
statements: program
.statements
.drain(..)
.map(|x| finalize_statement(x, resolutions))
.collect(),
}
}
fn finalize_statement(
statement: input::Statement,
resolutions: &TypeResolutions,
) -> output::Statement {
match statement {
input::Statement::Binding(loc, var, ty, expr) => output::Statement::Binding(
loc,
var,
finalize_type(ty, resolutions),
finalize_expression(expr, resolutions),
),
input::Statement::Print(loc, ty, var) => {
output::Statement::Print(loc, finalize_type(ty, resolutions), var)
}
}
}
fn finalize_expression(
expression: input::Expression,
resolutions: &TypeResolutions,
) -> output::Expression {
match expression {
input::Expression::Atomic(val_or_ref) => {
output::Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions))
}
input::Expression::Cast(loc, target, val_or_ref) => output::Expression::Cast(
loc,
finalize_type(target, resolutions),
finalize_val_or_ref(val_or_ref, resolutions),
),
input::Expression::Primitive(loc, ty, prim, mut args) => output::Expression::Primitive(
loc,
finalize_type(ty, resolutions),
prim,
args.drain(..)
.map(|x| finalize_val_or_ref(x, resolutions))
.collect(),
),
}
}
fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type {
match ty {
input::Type::Primitive(x) => output::Type::Primitive(x),
input::Type::Variable(_, tvar) => match resolutions.get(&tvar) {
None => panic!("Did not resolve type for type variable {}", tvar),
Some(pt) => output::Type::Primitive(*pt),
},
}
}
fn finalize_val_or_ref(
valref: input::ValueOrRef,
resolutions: &TypeResolutions,
) -> output::ValueOrRef {
match valref {
input::ValueOrRef::Ref(loc, ty, var) => {
output::ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var)
}
input::ValueOrRef::Value(loc, ty, val) => {
let new_type = finalize_type(ty, resolutions);
match val {
input::Value::Unknown(base, value) => match new_type {
output::Type::Primitive(PrimitiveType::U8) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U8(base, value as u8),
),
output::Type::Primitive(PrimitiveType::U16) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U16(base, value as u16),
),
output::Type::Primitive(PrimitiveType::U32) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U32(base, value as u32),
),
output::Type::Primitive(PrimitiveType::U64) => {
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
}
output::Type::Primitive(PrimitiveType::I8) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I8(base, value as i8),
),
output::Type::Primitive(PrimitiveType::I16) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I16(base, value as i16),
),
output::Type::Primitive(PrimitiveType::I32) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I32(base, value as i32),
),
output::Type::Primitive(PrimitiveType::I64) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I64(base, value as i64),
),
},
input::Value::U8(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U8)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U8(base, value))
}
input::Value::U16(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U16)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U16(base, value))
}
input::Value::U32(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U32)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U32(base, value))
}
input::Value::U64(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U64)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
}
input::Value::I8(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I8)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I8(base, value))
}
input::Value::I16(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I16)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I16(base, value))
}
input::Value::I32(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I32)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I32(base, value))
}
input::Value::I64(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I64)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I64(base, value))
}
}
}
}
}

542
src/type_infer/solve.rs Normal file
View File

@@ -0,0 +1,542 @@
use super::ast as ir;
use super::ast::Type;
use crate::{eval::PrimitiveType, syntax::Location};
use codespan_reporting::diagnostic::Diagnostic;
use internment::ArcIntern;
use std::{collections::HashMap, fmt};
/// A type inference constraint that we're going to need to solve.
#[derive(Debug)]
pub enum Constraint {
/// The given type must be printable using the `print` built-in
Printable(Location, Type),
/// The provided numeric value fits in the given constant type
FitsInNumType(Location, Type, u64),
/// The given primitive has the proper arguments types associated with it
ProperPrimitiveArgs(Location, ir::Primitive, Vec<Type>, Type),
/// The given type can be casted to the target type safely
CanCastTo(Location, Type, Type),
/// The given type must be some numeric type, but this is not a constant
/// value, so don't try to default it if we can't figure it out
NumericType(Location, Type),
/// The given type is attached to a constant and must be some numeric type.
/// If we can't figure it out, we should warn the user and then just use a
/// default.
ConstantNumericType(Location, Type),
/// The two types should be equivalent
Equivalent(Location, Type, Type),
}
impl fmt::Display for Constraint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Constraint::Printable(_, ty) => write!(f, "PRINTABLE {}", ty),
Constraint::FitsInNumType(_, ty, num) => write!(f, "FITS_IN {} {}", num, ty),
Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 1 => {
write!(f, "PRIM {} {} -> {}", op, args[0], ret)
}
Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 2 => {
write!(f, "PRIM {} ({}, {}) -> {}", op, args[0], args[1], ret)
}
Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret),
Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2),
Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty),
Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty),
Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2),
}
}
}
pub type TypeResolutions = HashMap<ArcIntern<String>, PrimitiveType>;
/// The results of type inference; like [`Result`], but with a bit more information.
///
/// This result is parameterized, because sometimes it's handy to return slightly
/// different things; there's a [`TypeInferenceResult::map`] function for performing
/// those sorts of conversions.
pub enum TypeInferenceResult<Result> {
Success {
result: Result,
warnings: Vec<TypeInferenceWarning>,
},
Failure {
errors: Vec<TypeInferenceError>,
warnings: Vec<TypeInferenceWarning>,
},
}
impl<R> TypeInferenceResult<R> {
// If this was a successful type inference, run the function over the result to
// create a new result.
//
// This is the moral equivalent of [`Result::map`], but for type inference results.
pub fn map<U, F>(self, f: F) -> TypeInferenceResult<U>
where
F: FnOnce(R) -> U,
{
match self {
TypeInferenceResult::Success { result, warnings } => TypeInferenceResult::Success {
result: f(result),
warnings,
},
TypeInferenceResult::Failure { errors, warnings } => {
TypeInferenceResult::Failure { errors, warnings }
}
}
}
// Return the final result, or panic if it's not a success
pub fn expect(self, msg: &str) -> R {
match self {
TypeInferenceResult::Success { result, .. } => result,
TypeInferenceResult::Failure { .. } => {
panic!("tried to get value from failed type inference: {}", msg)
}
}
}
}
/// The various kinds of errors that can occur while doing type inference.
pub enum TypeInferenceError {
/// The user provide a constant that is too large for its inferred type.
ConstantTooLarge(Location, PrimitiveType, u64),
/// The two types needed to be equivalent, but weren't.
NotEquivalent(Location, PrimitiveType, PrimitiveType),
/// We cannot safely cast the first type to the second type.
CannotSafelyCast(Location, PrimitiveType, PrimitiveType),
/// The primitive invocation provided the wrong number of arguments.
WrongPrimitiveArity(Location, ir::Primitive, usize, usize, usize),
/// We had a constraint we just couldn't solve.
CouldNotSolve(Constraint),
}
impl From<TypeInferenceError> for Diagnostic<usize> {
fn from(value: TypeInferenceError) -> Self {
match value {
TypeInferenceError::ConstantTooLarge(loc, primty, value) => loc
.labelled_error("constant too large for type")
.with_message(format!(
"Type {} has a max value of {}, which is smaller than {}",
primty,
primty.max_value(),
value
)),
TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc
.labelled_error("type inference error")
.with_message(format!("Expected type {}, received type {}", ty1, ty2)),
TypeInferenceError::CannotSafelyCast(loc, ty1, ty2) => loc
.labelled_error("unsafe type cast")
.with_message(format!("Cannot safely cast {} to {}", ty1, ty2)),
TypeInferenceError::WrongPrimitiveArity(loc, prim, lower, upper, observed) => loc
.labelled_error("wrong number of arguments")
.with_message(format!(
"expected {} for {}, received {}",
if lower == upper && lower > 1 {
format!("{} arguments", lower)
} else if lower == upper {
format!("{} argument", lower)
} else {
format!("{}-{} arguments", lower, upper)
},
prim,
observed
)),
TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if it was safe to cast from {} to {:#?}",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if {} and {:#?} were equivalent",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::FitsInNumType(loc, ty, val)) => {
loc.labelled_error("internal error").with_message(format!(
"Could not determine if {} could fit in {}",
val, ty
))
}
TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc
.labelled_error("internal error")
.with_message(format!("Could not determine if {} was a numeric type", ty)),
TypeInferenceError::CouldNotSolve(Constraint::ConstantNumericType(loc, ty)) =>
panic!("What? Constants should always eventually be solved, even by default; {:?} and type {:?}", loc, ty),
TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc
.labelled_error("internal error")
.with_message(format!("Could not determine if type {} was printable", ty)),
TypeInferenceError::CouldNotSolve(Constraint::ProperPrimitiveArgs(loc, prim, _, _)) => {
loc.labelled_error("internal error").with_message(format!(
"Could not tell if primitive {} received the proper argument types",
prim
))
}
}
}
}
/// Warnings that we might want to tell the user about.
///
/// These are fine, probably, but could indicate some behavior the user might not
/// expect, and so they might want to do something about them.
pub enum TypeInferenceWarning {
DefaultedTo(Location, Type),
}
impl From<TypeInferenceWarning> for Diagnostic<usize> {
fn from(value: TypeInferenceWarning) -> Self {
match value {
TypeInferenceWarning::DefaultedTo(loc, ty) => Diagnostic::warning()
.with_labels(vec![loc.primary_label().with_message("unknown type")])
.with_message(format!("Defaulted unknown type to {}", ty)),
}
}
}
/// Solve all the constraints in the provided database.
///
/// This process can take a bit, so you might not want to do it multiple times. Basically,
/// it's going to grind on these constraints until either it figures them out, or it stops
/// making progress. I haven't done the math on the constraints to even figure out if this
/// is guaranteed to halt, though, let alone terminate in some reasonable amount of time.
///
/// The return value is a type inference result, which pairs some warnings with either a
/// successful set of type resolutions (mappings from type variables to their values), or
/// a series of inference errors.
pub fn solve_constraints(
mut constraint_db: Vec<Constraint>,
) -> TypeInferenceResult<TypeResolutions> {
let mut errors = vec![];
let mut warnings = vec![];
let mut resolutions = HashMap::new();
let mut changed_something = true;
// We want to run this inference endlessly, until either we have solved all of our
// constraints. Internal to the loop, we have a check that will make sure that we
// do (eventually) stop.
while changed_something && !constraint_db.is_empty() {
// Set this to false at the top of the loop. We'll set this to true if we make
// progress in any way further down, but having this here prevents us from going
// into an infinite look when we can't figure stuff out.
changed_something = false;
// This is sort of a double-buffering thing; we're going to rename constraint_db
// and then set it to a new empty vector, which we'll add to as we find new
// constraints or find ourselves unable to solve existing ones.
let mut local_constraints = constraint_db;
constraint_db = vec![];
// OK. First thing we're going to do is run through all of our constraints,
// and see if we can solve any, or reduce them to theoretically more simple
// constraints.
for constraint in local_constraints.drain(..) {
match constraint {
// Currently, all of our types are printable
Constraint::Printable(_loc, _ty) => changed_something = true,
// Case #1: We have two primitive types. If they're equal, we've discharged this
// constraint! We can just continue. If they're not equal, add an error and then
// see what else we come up with.
Constraint::Equivalent(loc, Type::Primitive(t1), Type::Primitive(t2)) => {
if t1 != t2 {
errors.push(TypeInferenceError::NotEquivalent(loc, t1, t2));
}
changed_something = true;
}
// Case #2: One of the two constraints is a primitive, and the other is a variable.
// In this case, we'll check to see if we've resolved the variable, and check for
// equivalence if we have. If we haven't, we'll set that variable to be primitive
// type.
Constraint::Equivalent(loc, Type::Primitive(t), Type::Variable(_, name))
| Constraint::Equivalent(loc, Type::Variable(_, name), Type::Primitive(t)) => {
match resolutions.get(&name) {
None => {
resolutions.insert(name, t);
}
Some(t2) if &t == t2 => {}
Some(t2) => errors.push(TypeInferenceError::NotEquivalent(loc, t, *t2)),
}
changed_something = true;
}
// Case #3: They're both variables. In which case, we'll have to do much the same
// check, but now on their resolutions.
Constraint::Equivalent(
ref loc,
Type::Variable(_, ref name1),
Type::Variable(_, ref name2),
) => match (resolutions.get(name1), resolutions.get(name2)) {
(None, None) => {
constraint_db.push(constraint);
}
(Some(pt), None) => {
resolutions.insert(name2.clone(), *pt);
changed_something = true;
}
(None, Some(pt)) => {
resolutions.insert(name1.clone(), *pt);
changed_something = true;
}
(Some(pt1), Some(pt2)) if pt1 == pt2 => {
changed_something = true;
}
(Some(pt1), Some(pt2)) => {
errors.push(TypeInferenceError::NotEquivalent(loc.clone(), *pt1, *pt2));
changed_something = true;
}
},
// Make sure that the provided number fits within the provided constant type. For the
// moment, we're going to call an error here a failure, although this could be a
// warning in the future.
Constraint::FitsInNumType(loc, Type::Primitive(ctype), val) => {
if ctype.max_value() < val {
errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val));
}
changed_something = true;
}
// If we have a non-constant type, then let's see if we can advance this to a constant
// type
Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::FitsInNumType(
loc,
Type::Variable(vloc, var),
val,
)),
Some(nt) => {
constraint_db.push(Constraint::FitsInNumType(
loc,
Type::Primitive(*nt),
val,
));
changed_something = true;
}
}
}
// If the left type in a "can cast to" check is a variable, let's see if we can advance
// it into something more tangible
Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::CanCastTo(
loc,
Type::Variable(vloc, var),
to_type,
)),
Some(nt) => {
constraint_db.push(Constraint::CanCastTo(
loc,
Type::Primitive(*nt),
to_type,
));
changed_something = true;
}
}
}
// If the right type in a "can cast to" check is a variable, same deal
Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::CanCastTo(
loc,
from_type,
Type::Variable(vloc, var),
)),
Some(nt) => {
constraint_db.push(Constraint::CanCastTo(
loc,
from_type,
Type::Primitive(*nt),
));
changed_something = true;
}
}
}
// If both of them are types, then we can actually do the test. yay!
Constraint::CanCastTo(
loc,
Type::Primitive(from_type),
Type::Primitive(to_type),
) => {
if !from_type.can_cast_to(&to_type) {
errors.push(TypeInferenceError::CannotSafelyCast(
loc, from_type, to_type,
));
}
changed_something = true;
}
// As per usual, if we're trying to test if a type variable is numeric, first
// we try to advance it to a primitive
Constraint::NumericType(loc, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db
.push(Constraint::NumericType(loc, Type::Variable(vloc, var))),
Some(nt) => {
constraint_db.push(Constraint::NumericType(loc, Type::Primitive(*nt)));
changed_something = true;
}
}
}
// Of course, if we get to a primitive type, then it's true, because all of our
// primitive types are numbers
Constraint::NumericType(_, Type::Primitive(_)) => {
changed_something = true;
}
// As per usual, if we're trying to test if a type variable is numeric, first
// we try to advance it to a primitive
Constraint::ConstantNumericType(loc, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::ConstantNumericType(
loc,
Type::Variable(vloc, var),
)),
Some(nt) => {
constraint_db
.push(Constraint::ConstantNumericType(loc, Type::Primitive(*nt)));
changed_something = true;
}
}
}
// Of course, if we get to a primitive type, then it's true, because all of our
// primitive types are numbers
Constraint::ConstantNumericType(_, Type::Primitive(_)) => {
changed_something = true;
}
// OK, this one could be a little tricky if we tried to do it all at once, but
// instead what we're going to do here is just use this constraint to generate
// a bunch more constraints, and then go have the engine solve those. The only
// real errors we're going to come up with here are "arity errors"; errors we
// find by discovering that the number of arguments provided doesn't make sense
// given the primitive being used.
Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim {
ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide
if args.len() != 2 =>
{
errors.push(TypeInferenceError::WrongPrimitiveArity(
loc,
prim,
2,
2,
args.len(),
));
changed_something = true;
}
ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => {
let right = args.pop().expect("2 > 0");
let left = args.pop().expect("2 > 1");
// technically testing that both are numeric is redundant, but it might give
// a slightly helpful type error if we do both.
constraint_db.push(Constraint::NumericType(loc.clone(), left.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), right.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(
loc.clone(),
left.clone(),
right,
));
constraint_db.push(Constraint::Equivalent(loc, left, ret));
changed_something = true;
}
ir::Primitive::Minus if args.is_empty() || args.len() > 2 => {
errors.push(TypeInferenceError::WrongPrimitiveArity(
loc,
prim,
1,
2,
args.len(),
));
changed_something = true;
}
ir::Primitive::Minus if args.len() == 1 => {
let arg = args.pop().expect("1 > 0");
constraint_db.push(Constraint::NumericType(loc.clone(), arg.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(loc, arg, ret));
changed_something = true;
}
ir::Primitive::Minus => {
let right = args.pop().expect("2 > 0");
let left = args.pop().expect("2 > 1");
// technically testing that both are numeric is redundant, but it might give
// a slightly helpful type error if we do both.
constraint_db.push(Constraint::NumericType(loc.clone(), left.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), right.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(
loc.clone(),
left.clone(),
right,
));
constraint_db.push(Constraint::Equivalent(loc.clone(), left, ret));
changed_something = true;
}
},
}
}
// If that didn't actually come up with anything, and we just recycled all the constraints
// back into the database unchanged, then let's take a look for cases in which we just
// wanted something we didn't know to be a number. Basically, those are cases where the
// user just wrote a number, but didn't tell us what type it was, and there isn't enough
// information in the context to tell us. If that happens, we'll just set that type to
// be u64, and warn the user that we did so.
if !changed_something && !constraint_db.is_empty() {
local_constraints = constraint_db;
constraint_db = vec![];
for constraint in local_constraints.drain(..) {
match constraint {
Constraint::ConstantNumericType(loc, t @ Type::Variable(_, _)) => {
let resty = Type::Primitive(PrimitiveType::U64);
constraint_db.push(Constraint::Equivalent(
loc.clone(),
t,
Type::Primitive(PrimitiveType::U64),
));
warnings.push(TypeInferenceWarning::DefaultedTo(loc, resty));
changed_something = true;
}
_ => constraint_db.push(constraint),
}
}
}
}
// OK, we left our loop. Which means that either we solved everything, or we didn't.
// If we didn't, turn the unsolved constraints into type inference errors, and add
// them to our error list.
let mut unsolved_constraint_errors = constraint_db
.drain(..)
.map(TypeInferenceError::CouldNotSolve)
.collect();
errors.append(&mut unsolved_constraint_errors);
// How'd we do?
if errors.is_empty() {
TypeInferenceResult::Success {
result: resolutions,
warnings,
}
} else {
TypeInferenceResult::Failure { errors, warnings }
}
}