jit works

This commit is contained in:
2024-02-02 10:31:13 -08:00
parent 7ebb31b42f
commit 4ba196d2a6
21 changed files with 477 additions and 185 deletions

View File

@@ -60,6 +60,7 @@ pub struct Backend<M: Module> {
data_ctx: DataDescription,
runtime_functions: RuntimeFunctions,
defined_strings: HashMap<String, DataId>,
defined_functions: HashMap<ArcIntern<String>, FuncId>,
defined_symbols: HashMap<ArcIntern<String>, (DataId, ConstantType)>,
output_buffer: Option<String>,
platform: Triple,
@@ -92,6 +93,7 @@ impl Backend<JITModule> {
data_ctx: DataDescription::new(),
runtime_functions,
defined_strings: HashMap::new(),
defined_functions: HashMap::new(),
defined_symbols: HashMap::new(),
output_buffer,
platform: Triple::host(),
@@ -132,6 +134,7 @@ impl Backend<ObjectModule> {
data_ctx: DataDescription::new(),
runtime_functions,
defined_strings: HashMap::new(),
defined_functions: HashMap::new(),
defined_symbols: HashMap::new(),
output_buffer: None,
platform,

View File

@@ -55,6 +55,7 @@ impl From<BackendError> for Diagnostic<usize> {
match value {
BackendError::Cranelift(me) => {
Diagnostic::error().with_message(format!("Internal cranelift error: {}", me))
.with_notes(vec![format!("{:?}", me)])
}
BackendError::BuiltinError(me) => {
Diagnostic::error().with_message(format!("Internal runtime function error: {}", me))

View File

@@ -26,34 +26,7 @@ impl Backend<JITModule> {
/// of the built-in test systems.)
pub fn eval(program: Program<Type>) -> Result<String, EvalError<Expression<Type>>> {
let mut jitter = Backend::jit(Some(String::new()))?;
let mut function_map = HashMap::new();
let mut main_function_body = vec![];
for item in program.items {
match item {
TopLevel::Function(name, args, rettype, body) => {
let function_id =
jitter.compile_function(name.as_str(), args.as_slice(), rettype, body)?;
function_map.insert(name, function_id);
}
TopLevel::Statement(stmt) => {
main_function_body.push(stmt);
}
}
}
let main_function_body = Expression::Block(
Location::manufactured(),
Type::Primitive(crate::eval::PrimitiveType::Void),
main_function_body,
);
let function_id = jitter.compile_function(
"___test_jit_eval___",
&[],
Type::Primitive(crate::eval::PrimitiveType::Void),
main_function_body,
)?;
let function_id = jitter.compile_program("___test_jit_eval___", program)?;
jitter.module.finalize_definitions()?;
let compiled_bytes = jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
@@ -89,8 +62,13 @@ impl Backend<ObjectModule> {
for item in program.items {
match item {
TopLevel::Function(name, args, rettype, body) => {
let function_id =
backend.compile_function(name.as_str(), args.as_slice(), rettype, body)?;
let function_id = backend.compile_function(
&mut HashMap::new(),
name.as_str(),
args.as_slice(),
rettype,
body,
)?;
function_map.insert(name, function_id);
}
@@ -111,6 +89,7 @@ impl Backend<ObjectModule> {
let executable_path = my_directory.path().join("test_executable");
backend.compile_function(
&mut HashMap::new(),
"gogogo",
&[],
Type::Primitive(crate::eval::PrimitiveType::Void),
@@ -154,6 +133,7 @@ impl Backend<ObjectModule> {
.join("runtime")
.join("rts.c"),
)
.arg("-Wl,-ld_classic")
.arg(object_file)
.arg("-o")
.arg(executable_path)
@@ -206,16 +186,15 @@ proptest::proptest! {
#[test]
fn jit_backend(program in Program::arbitrary()) {
use crate::eval::PrimOpError;
use pretty::{DocAllocator, Pretty};
let allocator = pretty::BoxAllocator;
allocator
.text("---------------")
.append(allocator.hardline())
.append(program.pretty(&allocator))
.1
.render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
.expect("rendering works");
// use pretty::{DocAllocator, Pretty};
// let allocator = pretty::BoxAllocator;
// allocator
// .text("---------------")
// .append(allocator.hardline())
// .append(program.pretty(&allocator))
// .1
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
// .expect("rendering works");
let basic_result = program.eval().map(|(_,x)| x);

View File

@@ -1,24 +1,23 @@
use crate::backend::error::BackendError;
use crate::backend::Backend;
use crate::eval::PrimitiveType;
use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable};
use crate::syntax::{ConstantType, Location};
use crate::util::scoped_map::ScopedMap;
use cranelift_codegen::ir::{
self, entities, types, AbiParam, Function, GlobalValue, InstBuilder, Signature, UserFuncName,
self, entities, types, AbiParam, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName
};
use cranelift_codegen::isa::CallConv;
use cranelift_codegen::Context;
use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
use cranelift_module::{FuncId, Linkage, Module};
use cranelift_module::{DataDescription, FuncId, Linkage, Module};
use internment::ArcIntern;
use crate::backend::error::BackendError;
use crate::backend::Backend;
use std::collections::{hash_map, HashMap};
/// When we're talking about variables, it's handy to just have a table that points
/// from a variable to "what to do if you want to reference this variable", which is
/// agnostic about whether the variable is local, global, an argument, etc. Since
/// the type of that function is a little bit annoying, we summarize it here.
enum ReferenceBuilder {
pub enum ReferenceBuilder {
Global(ConstantType, GlobalValue),
Local(ConstantType, cranelift_frontend::Variable),
Argument(ConstantType, entities::Value),
@@ -29,7 +28,8 @@ impl ReferenceBuilder {
match self {
ReferenceBuilder::Global(ty, gv) => {
let cranelift_type = ir::Type::from(*ty);
let value = builder.ins().symbol_value(cranelift_type, *gv);
let ptr_value = builder.ins().symbol_value(types::I64, *gv);
let value = builder.ins().load(cranelift_type, MemFlags::new(), ptr_value, 0);
(value, *ty)
}
@@ -81,11 +81,44 @@ impl<M: Module> Backend<M> {
program: Program<Type>,
) -> Result<FuncId, BackendError> {
let mut generated_body = vec![];
let mut variables = HashMap::new();
for (top_level_name, top_level_type) in program.get_top_level_variables() {
match top_level_type {
Type::Function(argument_types, return_type) => {
let func_id = self.declare_function(
top_level_name.as_str(),
Linkage::Export,
argument_types,
*return_type,
)?;
self.defined_functions.insert(top_level_name, func_id);
}
Type::Primitive(pt) => {
let data_id = self.module.declare_data(
top_level_name.as_str(),
Linkage::Export,
true,
false,
)?;
self.module.define_data(data_id, &pt.blank_data())?;
self.defined_symbols
.insert(top_level_name, (data_id, pt.into()));
}
}
}
let void = Type::Primitive(PrimitiveType::Void);
let main_func_id =
self.declare_function(function_name, Linkage::Export, vec![], void.clone())?;
self.defined_functions
.insert(ArcIntern::new(function_name.to_string()), main_func_id);
for item in program.items {
match item {
TopLevel::Function(name, args, rettype, body) => {
self.compile_function(name.as_str(), &args, rettype, body)?;
self.compile_function(&mut variables, name.as_str(), &args, rettype, body)?;
}
TopLevel::Statement(stmt) => {
@@ -94,8 +127,8 @@ impl<M: Module> Backend<M> {
}
}
let void = Type::Primitive(PrimitiveType::Void);
self.compile_function(
&mut variables,
function_name,
&[],
void.clone(),
@@ -103,14 +136,47 @@ impl<M: Module> Backend<M> {
)
}
/// Declare (without defining) a function in the underlying module and return
/// its handle.
///
/// The signature is assembled from `argument_types` and `return_type`, using
/// the default calling convention for this backend's target platform. A
/// `Void` return type produces a signature with no return values at all.
/// The symbol is registered under `name` with whatever `linkage` the caller
/// asks for.
fn declare_function(
    &mut self,
    name: &str,
    linkage: Linkage,
    argument_types: Vec<Type>,
    return_type: Type,
) -> Result<FuncId, cranelift_module::ModuleError> {
    // Translate each argument type, in order, into its signature entry.
    let mut params = Vec::with_capacity(argument_types.len());
    for argument_type in argument_types.iter() {
        params.push(self.translate_type(argument_type));
    }

    // Void functions return nothing; everything else returns exactly one value.
    let mut returns = Vec::new();
    if return_type != Type::Primitive(PrimitiveType::Void) {
        returns.push(self.translate_type(&return_type));
    }

    let signature = Signature {
        params,
        returns,
        call_conv: CallConv::triple_default(&self.platform),
    };

    // The resulting id is the handle later used to define or reference the
    // function.
    self.module.declare_function(name, linkage, &signature)
}
/// Compile the given function.
pub fn compile_function(
&mut self,
variables: &mut HashMap<Variable, ReferenceBuilder>,
function_name: &str,
arguments: &[(Variable, Type)],
return_type: Type,
body: Expression<Type>,
) -> Result<FuncId, BackendError> {
// reset the next variable counter. this value shouldn't matter; hopefully
// we won't be using close to 2^32 variables!
self.reset_local_variable_tracker();
let basic_signature = Signature {
params: arguments
.iter()
@@ -124,17 +190,23 @@ impl<M: Module> Backend<M> {
call_conv: CallConv::triple_default(&self.platform),
};
// reset the next variable counter. this value shouldn't matter; hopefully
// we won't be using close to 2^32 variables!
self.reset_local_variable_tracker();
// this generates the handle for the function that we'll eventually want to
// return to the user. For now, we declare all functions defined by this
// function as public/global/exported, although we may want to reconsider
// this decision later.
let func_id =
self.module
.declare_function(function_name, Linkage::Export, &basic_signature)?;
let interned_name = ArcIntern::new(function_name.to_string());
let func_id = match self.defined_functions.entry(interned_name) {
hash_map::Entry::Occupied(entry) => *entry.get(),
hash_map::Entry::Vacant(vac) => {
let func_id = self.module.declare_function(
function_name,
Linkage::Export,
&basic_signature,
)?;
vac.insert(func_id);
func_id
}
};
// Next we have to generate the compilation context for the rest of this
// function. Currently, we generate a fresh context for every function.
@@ -145,14 +217,6 @@ impl<M: Module> Backend<M> {
let user_func_name = UserFuncName::user(0, func_id.as_u32());
ctx.func = Function::with_name_signature(user_func_name, basic_signature);
// Let's start creating the variable table we'll use when we're dereferencing
// them later. This table is a little interesting because instead of pointing
// from data to data, we're going to point from data (the variable) to an
// action to take if we encounter that variable at some later point. This
// makes it nice and easy to have many different ways to access data, such
// as globals, function arguments, etc.
let mut variables: ScopedMap<ArcIntern<String>, ReferenceBuilder> = ScopedMap::new();
// At the outer-most scope of things, we'll put global variables we've defined
// elsewhere in the program.
for (name, (data_id, ty)) in self.defined_symbols.iter() {
@@ -160,12 +224,6 @@ impl<M: Module> Backend<M> {
variables.insert(name.clone(), ReferenceBuilder::Global(*ty, local_data));
}
// Once we have these, we're going to actually push a level of scope and
// add our arguments. We push scope because if there happen to be any with
// the same name (their shouldn't be, but just in case), we want the arguments
// to win.
variables.new_scope();
// Finally (!), we generate the function builder that we're going to use to
// make this function!
let mut fctx = FunctionBuilderContext::new();
@@ -192,7 +250,7 @@ impl<M: Module> Backend<M> {
builder.switch_to_block(main_block);
let (value, _) = self.compile_expression(body, &mut variables, &mut builder)?;
let (value, _) = self.compile_expression(body, variables, &mut builder)?;
// Now that we're done, inject a return function (one with no actual value; basically
// the equivalent of Rust's `return;`). We then seal the block (which lets Cranelift
@@ -221,7 +279,7 @@ impl<M: Module> Backend<M> {
fn compile_expression(
&mut self,
expr: Expression<Type>,
variables: &mut ScopedMap<Variable, ReferenceBuilder>,
variables: &mut HashMap<Variable, ReferenceBuilder>,
builder: &mut FunctionBuilder,
) -> Result<(entities::Value, ConstantType), BackendError> {
match expr {
@@ -282,6 +340,29 @@ impl<M: Module> Backend<M> {
(ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)),
(ConstantType::Void, Type::Primitive(PrimitiveType::Void)) => {
Ok((val, val_type))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::I16)) => {
Ok((builder.ins().uextend(types::I16, val), ConstantType::I16))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::I32))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::I64))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::I32))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::I64))
}
(ConstantType::U32, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::I64))
}
_ => Err(BackendError::InvalidTypeCast {
from: val_type.into(),
to: target_type,
@@ -327,7 +408,7 @@ impl<M: Module> Backend<M> {
}
Expression::Block(_, _, mut exprs) => match exprs.pop() {
None => Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8)),
None => Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void)),
Some(last) => {
for inner in exprs {
// we can ignore all of these return values and such, because we
@@ -354,7 +435,7 @@ impl<M: Module> Backend<M> {
// Look up the value for the variable. Because this might be a
// global variable (and that requires special logic), we just turn
// this into an `Expression` and re-use the logic in that implementation.
let fake_ref = ValueOrRef::Ref(ann, Type::Primitive(PrimitiveType::U8), var);
let fake_ref = ValueOrRef::Ref(ann, Type::Primitive(PrimitiveType::U8), var.clone());
let (val, vtype) = self.compile_value_or_ref(fake_ref, variables, builder)?;
let vtype_repr = builder.ins().iconst(types::I64, vtype as i64);
@@ -379,7 +460,7 @@ impl<M: Module> Backend<M> {
print_func_ref,
&[buffer_ptr, name_ptr, vtype_repr, casted_val],
);
Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8))
Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void))
}
Expression::Bind(_, name, _, expr) => {
@@ -390,7 +471,7 @@ impl<M: Module> Backend<M> {
builder.declare_var(variable, ir_type);
builder.def_var(variable, value);
variables.insert(name, ReferenceBuilder::Local(value_type, variable));
Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8))
Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void))
}
}
}
@@ -400,7 +481,7 @@ impl<M: Module> Backend<M> {
fn compile_value_or_ref(
&self,
valref: ValueOrRef<Type>,
variables: &ScopedMap<Variable, ReferenceBuilder>,
variables: &HashMap<Variable, ReferenceBuilder>,
builder: &mut FunctionBuilder,
) -> Result<(entities::Value, ConstantType), BackendError> {
match valref {
@@ -453,3 +534,23 @@ impl<M: Module> Backend<M> {
}
}
}
impl PrimitiveType {
    /// Build a zero-initialized [`DataDescription`] big enough to hold one
    /// value of this primitive type.
    ///
    /// The size and alignment (both in bytes) match the natural width of the
    /// type: 1 for 8-bit, 2 for 16-bit, 4 for 32-bit and 8 for 64-bit
    /// integers.
    fn blank_data(&self) -> DataDescription {
        let (size, alignment) = match self {
            // NOTE(review): Void keeps the 8-byte slot it had before; it looks
            // like void values are carried around as 64-bit zeros elsewhere in
            // the backend -- confirm before shrinking this.
            PrimitiveType::Void => (8, 8),
            PrimitiveType::U8 => (1, 1),
            PrimitiveType::U16 => (2, 2),
            PrimitiveType::U32 => (4, 4),
            // 64-bit values need all 8 bytes. Reserving only 4 would let a
            // full-width access to the global run past the end of its storage.
            PrimitiveType::U64 => (8, 8),
            PrimitiveType::I8 => (1, 1),
            PrimitiveType::I16 => (2, 2),
            PrimitiveType::I32 => (4, 4),
            PrimitiveType::I64 => (8, 8),
        };

        let mut result = DataDescription::new();
        result.define_zeroinit(size);
        result.set_align(alignment);
        result
    }
}

View File

@@ -119,7 +119,7 @@ extern "C" fn runtime_print(
Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16),
Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32),
Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64),
Err(_) => format!("{} = {}<unknown type>", reconstituted, value),
Err(_) => format!("{} = {}<unknown type {}>", reconstituted, value, vtype_repr),
};
if let Some(output_buffer) = unsafe { output_buffer.as_mut() } {

View File

@@ -191,6 +191,13 @@ impl PrimitiveType {
(PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)),
(PrimitiveType::I16, Value::U8(x)) => Ok(Value::I16(*x as i16)),
(PrimitiveType::I32, Value::U8(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I64, Value::U8(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I32, Value::U16(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I64, Value::U16(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::U32(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::Void, Value::Void) => Ok(Value::Void),
_ => Err(PrimOpError::UnsafeCast {

View File

@@ -16,5 +16,6 @@ mod arbitrary;
pub mod ast;
mod eval;
mod strings;
mod top_level;
pub use ast::*;

View File

@@ -1,5 +1,5 @@
use crate::eval::PrimitiveType;
use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable};
use crate::ir::{Expression, Primitive, Program, TopLevel, Type, TypeWithVoid, Value, ValueOrRef, Variable};
use crate::syntax::Location;
use crate::util::scoped_map::ScopedMap;
use proptest::strategy::{NewTree, Strategy, ValueTree};
@@ -288,14 +288,25 @@ fn generate_random_expression(
ExpressionType::Block => {
let num_stmts = BLOCK_LENGTH_DISTRIBUTION.sample(rng);
let mut stmts = Vec::new();
let mut last_type = Type::Primitive(PrimitiveType::Void);
if num_stmts == 0 {
return Expression::Block(Location::manufactured(), Type::void(), stmts);
}
env.new_scope();
for _ in 0..num_stmts {
let next = generate_random_expression(rng, env);
last_type = next.type_of();
for _ in 1..num_stmts {
let mut next = generate_random_expression(rng, env);
let next_type = next.type_of();
if !next_type.is_void() {
let name = generate_random_name(rng);
env.insert(name.clone(), next_type.clone());
next = Expression::Bind(Location::manufactured(), name, next_type, Box::new(next));
}
stmts.push(next);
}
let last_expr = generate_random_expression(rng, env);
let last_type = last_expr.type_of();
stmts.push(last_expr);
env.release_scope();
Expression::Block(Location::manufactured(), last_type, stmts)

View File

@@ -6,7 +6,10 @@ use crate::{
use internment::ArcIntern;
use pretty::{BoxAllocator, DocAllocator, Pretty};
use proptest::arbitrary::Arbitrary;
use std::{fmt, str::FromStr, sync::atomic::AtomicUsize};
use std::convert::TryFrom;
use std::fmt;
use std::str::FromStr;
use std::sync::atomic::AtomicUsize;
use super::arbitrary::ProgramGenerator;
@@ -97,6 +100,19 @@ pub enum TopLevel<Type> {
Function(Variable, Vec<(Variable, Type)>, Type, Expression<Type>),
}
impl<T: Clone + TypeWithVoid + TypeWithFunction> TopLevel<T> {
    /// Compute the type of this top-level item.
    ///
    /// A statement has the type of its expression. A function has the function
    /// type built from (clones of) its argument types and its return type.
    pub fn type_of(&self) -> T {
        match self {
            TopLevel::Function(_, args, ret, _) => {
                // Only the argument types matter here; the names are dropped.
                let arg_types: Vec<T> = args.iter().map(|(_, arg_type)| arg_type).cloned().collect();
                T::build_function_type(arg_types, ret.clone())
            }
            TopLevel::Statement(expr) => expr.type_of(),
        }
    }
}
impl<'a, 'b, D, A, Type> Pretty<'a, D, A> for &'b TopLevel<Type>
where
A: 'a,
@@ -148,7 +164,7 @@ pub enum Expression<Type> {
}
impl<Type: Clone + TypeWithVoid> Expression<Type> {
/// Return a reference to the type of the expression, as inferred or recently
/// Return the type of the expression, as inferred or recently
/// computed.
pub fn type_of(&self) -> Type {
match self {
@@ -242,6 +258,16 @@ where
}
}
impl Expression<Type> {
    /// Render this expression through its `Pretty` implementation, targeting a
    /// width of 72 columns, and return the result as a `String`.
    ///
    /// # Panics
    ///
    /// Panics if rendering into the in-memory buffer fails, or if the rendered
    /// bytes are not valid UTF-8.
    pub fn to_pretty(&self) -> String {
        let allocator = pretty::Arena::<()>::new();
        let mut rendered = Vec::new();
        self.pretty(&allocator).render(72, &mut rendered).unwrap();
        String::from_utf8(rendered).expect("pretty generates valid utf-8")
    }
}
/// A type representing the primitives allowed in the language.
///
/// Having this as an enumeration avoids a lot of "this should not happen"
@@ -565,27 +591,56 @@ impl TypeOrVar {
}
}
impl PartialEq<Type> for TypeOrVar {
    /// Compare a possibly-unresolved type against a fully resolved one.
    ///
    /// Two function types are equal when their argument lists and return types
    /// are (recursively) equal; two primitives are equal when they are the same
    /// primitive. Every other combination is unequal.
    fn eq(&self, other: &Type) -> bool {
        match (self, other) {
            (TypeOrVar::Function(lhs_args, lhs_ret), Type::Function(rhs_args, rhs_ret)) => {
                lhs_args == rhs_args && **lhs_ret == **rhs_ret
            }
            (TypeOrVar::Primitive(lhs), Type::Primitive(rhs)) => rhs == lhs,
            _ => false,
        }
    }
}
/// A type representation that has a notion of a "void" type.
pub trait TypeWithVoid {
/// Build the void type in this representation.
fn void() -> Self;
/// Report whether this value is the void type.
fn is_void(&self) -> bool;
}
impl TypeWithVoid for Type {
    /// The void type is the `Void` primitive.
    fn void() -> Self {
        Self::Primitive(PrimitiveType::Void)
    }

    /// True exactly when this type equals the `Void` primitive.
    fn is_void(&self) -> bool {
        *self == Self::Primitive(PrimitiveType::Void)
    }
}
impl TypeWithVoid for TypeOrVar {
    /// The void type is the `Void` primitive.
    fn void() -> Self {
        Self::Primitive(PrimitiveType::Void)
    }

    /// True exactly when this type equals the `Void` primitive.
    fn is_void(&self) -> bool {
        *self == Self::Primitive(PrimitiveType::Void)
    }
}
//impl From<Type> for TypeOrVar {
// fn from(value: Type) -> Self {
// TypeOrVar::Type(value)
// }
//}
/// A type representation that can express function types.
pub trait TypeWithFunction: Sized {
/// Build a function type from its argument types and its return type.
fn build_function_type(arg_types: Vec<Self>, ret_type: Self) -> Self;
}
impl TypeWithFunction for Type {
    /// Wrap the argument and return types in a `Function` type, boxing the
    /// return type.
    fn build_function_type(arg_types: Vec<Self>, ret_type: Self) -> Self {
        let boxed_return = Box::new(ret_type);
        Self::Function(arg_types, boxed_return)
    }
}
impl<T: Into<Type>> From<T> for TypeOrVar {
fn from(value: T) -> Self {
@@ -598,3 +653,24 @@ impl<T: Into<Type>> From<T> for TypeOrVar {
}
}
}
impl TryFrom<TypeOrVar> for Type {
    type Error = TypeOrVar;

    /// Convert a `TypeOrVar` into a `Type`, which succeeds only when the value
    /// is built entirely out of primitives and function types.
    ///
    /// # Errors
    ///
    /// Returns the first component that could not be converted. For a function
    /// type this is the offending argument or return component, *not* the whole
    /// function type that was passed in.
    ///
    /// NOTE(review): callers that want to re-queue the original input on
    /// failure should not assume the error is that input -- confirm at call
    /// sites.
    fn try_from(value: TypeOrVar) -> Result<Self, Self::Error> {
        match value {
            TypeOrVar::Primitive(prim) => Ok(Type::Primitive(prim)),
            TypeOrVar::Function(arg_types, ret_type) => {
                // Arguments are converted left to right, stopping at the first
                // failure; the return type is only examined afterwards.
                let mut converted_args = Vec::with_capacity(arg_types.len());
                for arg_type in arg_types {
                    converted_args.push(Type::try_from(arg_type)?);
                }
                let converted_ret = Type::try_from(*ret_type)?;
                Ok(Type::Function(converted_args, Box::new(converted_ret)))
            }
            unresolved => Err(unresolved),
        }
    }
}

45
src/ir/top_level.rs Normal file
View File

@@ -0,0 +1,45 @@
use crate::ir::{Expression, Program, TopLevel, TypeWithFunction, TypeWithVoid, Variable};
use std::collections::HashMap;
impl<T: Clone + TypeWithVoid + TypeWithFunction> Program<T> {
    /// Collect every variable defined at the top level of this program,
    /// together with its type.
    ///
    /// Items are visited in program order; if the same name is reported more
    /// than once, the entry from the later item wins.
    pub fn get_top_level_variables(&self) -> HashMap<Variable, T> {
        self.items
            .iter()
            .flat_map(|item| item.get_top_level_variables())
            .collect()
    }
}
impl<T: Clone + TypeWithVoid + TypeWithFunction> TopLevel<T> {
    /// Collect the variables that this top-level item defines, together with
    /// their types.
    ///
    /// A function defines exactly one name -- its own -- bound to its function
    /// type. A statement defines whatever its expression reports.
    pub fn get_top_level_variables(&self) -> HashMap<Variable, T> {
        match self {
            TopLevel::Statement(expr) => expr.get_top_level_variables(),
            TopLevel::Function(name, ..) => {
                let mut defined = HashMap::with_capacity(1);
                defined.insert(name.clone(), self.type_of());
                defined
            }
        }
    }
}
impl<T: Clone> Expression<T> {
    /// Collect the variables that this expression defines at the top level,
    /// together with their types.
    ///
    /// Only `Bind` expressions define anything: the bound name, plus whatever
    /// the bound expression itself defines. When both use the same name, the
    /// outer binding's type is the one kept. Every other kind of expression
    /// yields an empty map.
    pub fn get_top_level_variables(&self) -> HashMap<Variable, T> {
        if let Expression::Bind(_, name, ty, bound) = self {
            // Gather inner definitions first so the outer binding overrides them.
            let mut defined = bound.get_top_level_variables();
            defined.insert(name.clone(), ty.clone());
            defined
        } else {
            HashMap::new()
        }
    }
}

View File

@@ -154,6 +154,19 @@ impl PartialEq for Expression {
}
}
impl Expression {
    /// Borrow the source location stored in this expression. Every variant
    /// carries one as its first field.
    pub fn location(&self) -> &Location {
        match self {
            Expression::Value(loc, ..)
            | Expression::Reference(loc, ..)
            | Expression::Cast(loc, ..)
            | Expression::Primitive(loc, ..)
            | Expression::Block(loc, ..) => loc,
        }
    }
}
/// A value from the source syntax
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Value {

View File

@@ -270,6 +270,7 @@ impl TryFrom<i64> for ConstantType {
21 => Ok(ConstantType::I16),
22 => Ok(ConstantType::I32),
23 => Ok(ConstantType::I64),
255 => Ok(ConstantType::Void),
_ => Err(InvalidConstantType::Value(value)),
}
}

View File

@@ -4,6 +4,7 @@ use crate::syntax::{self, ConstantType};
use crate::type_infer::solve::Constraint;
use crate::util::scoped_map::ScopedMap;
use internment::ArcIntern;
use std::collections::HashMap;
use std::str::FromStr;
/// This function takes a syntactic program and converts it into the IR version of the
@@ -19,7 +20,7 @@ pub fn convert_program(
let mut constraint_db = Vec::new();
let mut items = Vec::new();
let mut renames = ScopedMap::new();
let mut bindings = ScopedMap::new();
let mut bindings = HashMap::new();
for item in program.items.drain(..) {
items.push(convert_top_level(
@@ -40,43 +41,68 @@ pub fn convert_top_level(
top_level: syntax::TopLevel,
constraint_db: &mut Vec<Constraint>,
renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut ScopedMap<ArcIntern<String>, ir::TypeOrVar>,
bindings: &mut HashMap<ArcIntern<String>, ir::TypeOrVar>,
) -> ir::TopLevel<ir::TypeOrVar> {
match top_level {
syntax::TopLevel::Function(name, args, expr) => {
// First, let us figure out what we're going to name this function. If the user
// First, at some point we're going to want to know a location for this function,
// which should either be the name if we have one, or the body if we don't.
let function_location = match name {
None => expr.location().clone(),
Some(ref name) => name.location.clone(),
};
// Next, let us figure out what we're going to name this function. If the user
// didn't provide one, we'll just call it "function:<something>" for them. (We'll
// want a name for this function, eventually, so we might as well do it now.)
//
// If they did provide a name, see if we're shadowed. IF we are, then we'll have
// to specialize the name a bit. Otherwise we'll stick with their name.
let funname = match name {
let function_name = match name {
None => ir::gensym("function"),
Some(unbound) => finalize_name(bindings, renames, unbound),
};
// Now we manufacture types for the inputs and outputs, and then a type for the
// function itself. We're not going to make any claims on these types, yet; they're
// all just unknown type variables we need to work out.
let argtypes: Vec<ir::TypeOrVar> = args.iter().map(|_| ir::TypeOrVar::new()).collect();
// This function is going to have a type. We don't know what it is, but it'll have
// one.
let function_type = ir::TypeOrVar::new();
bindings.insert(function_name.clone(), function_type.clone());
// Then, let's figure out what to do with the argument names, which similarly
// may need to be renamed. We'll also generate some new type variables to associate
// with all of them.
//
// Note that we want to do all this in a new renaming scope, so that we shadow
// appropriately.
renames.new_scope();
let arginfo = args
.iter()
.map(|name| {
let new_type = ir::TypeOrVar::new();
constraint_db.push(Constraint::IsSomething(
name.location.clone(),
new_type.clone(),
));
let new_name = finalize_name(bindings, renames, name.clone());
bindings.insert(new_name.clone(), new_type.clone());
(new_name, new_type)
})
.collect::<Vec<_>>();
// Now we manufacture types for the outputs and then a type for the function itself.
// We're not going to make any claims on these types, yet; they're all just unknown
// type variables we need to work out.
let rettype = ir::TypeOrVar::new();
let funtype = ir::TypeOrVar::Function(argtypes.clone(), Box::new(rettype.clone()));
// Now let's bind these types into the environment. First, we bind our function
// namae to the function type we just generated.
bindings.insert(funname.clone(), funtype);
// And then we attach the argument names to the argument types. (We have to go
// convert all the names, first.)
let iargs: Vec<ArcIntern<String>> =
args.iter().map(|x| ArcIntern::new(x.to_string())).collect();
assert_eq!(argtypes.len(), iargs.len());
let mut function_args = vec![];
for ((arg_name, arg_type), orig_name) in iargs.iter().zip(argtypes).zip(args) {
bindings.insert(arg_name.clone(), arg_type.clone());
function_args.push((arg_name.clone(), arg_type.clone()));
constraint_db.push(Constraint::IsSomething(orig_name.location, arg_type));
}
let actual_function_type = ir::TypeOrVar::Function(
arginfo.iter().map(|x| x.1.clone()).collect(),
Box::new(rettype.clone()),
);
constraint_db.push(Constraint::Equivalent(
function_location,
function_type,
actual_function_type,
));
// Now let's convert the body over to the new IR.
let (expr, ty) = convert_expression(expr, constraint_db, renames, bindings);
constraint_db.push(Constraint::Equivalent(
expr.location().clone(),
@@ -84,7 +110,10 @@ pub fn convert_top_level(
ty,
));
ir::TopLevel::Function(funname, function_args, rettype, expr)
// Remember to exit this scoping level!
renames.release_scope();
ir::TopLevel::Function(function_name, arginfo, rettype, expr)
}
syntax::TopLevel::Statement(stmt) => {
@@ -108,7 +137,7 @@ fn convert_statement(
statement: syntax::Statement,
constraint_db: &mut Vec<Constraint>,
renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut ScopedMap<ArcIntern<String>, ir::TypeOrVar>,
bindings: &mut HashMap<ArcIntern<String>, ir::TypeOrVar>,
) -> ir::Expression<ir::TypeOrVar> {
match statement {
syntax::Statement::Print(loc, name) => {
@@ -152,7 +181,7 @@ fn convert_expression(
expression: syntax::Expression,
constraint_db: &mut Vec<Constraint>,
renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut ScopedMap<ArcIntern<String>, ir::TypeOrVar>,
bindings: &mut HashMap<ArcIntern<String>, ir::TypeOrVar>,
) -> (ir::Expression<ir::TypeOrVar>, ir::TypeOrVar) {
match expression {
// converting values is mostly tedious, because there's so many cases
@@ -339,7 +368,7 @@ fn finalize_expression(
}
fn finalize_name(
bindings: &ScopedMap<ArcIntern<String>, ir::TypeOrVar>,
bindings: &HashMap<ArcIntern<String>, ir::TypeOrVar>,
renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>,
name: syntax::Name,
) -> ArcIntern<String> {

View File

@@ -105,7 +105,7 @@ fn finalize_type(ty: TypeOrVar, resolutions: &TypeResolutions) -> Type {
TypeOrVar::Primitive(x) => Type::Primitive(x),
TypeOrVar::Variable(_, tvar) => match resolutions.get(&tvar) {
None => panic!("Did not resolve type for type variable {}", tvar),
Some(pt) => Type::Primitive(*pt),
Some(pt) => pt.clone(),
},
TypeOrVar::Function(mut args, ret) => Type::Function(
args.drain(..)

View File

@@ -1,5 +1,5 @@
use crate::eval::PrimitiveType;
use crate::ir::{Primitive, TypeOrVar};
use crate::ir::{Primitive, Type, TypeOrVar};
use crate::syntax::Location;
use codespan_reporting::diagnostic::Diagnostic;
use internment::ArcIntern;
@@ -50,7 +50,7 @@ impl fmt::Display for Constraint {
}
}
pub type TypeResolutions = HashMap<ArcIntern<String>, PrimitiveType>;
pub type TypeResolutions = HashMap<ArcIntern<String>, Type>;
/// The results of type inference; like [`Result`], but with a bit more information.
///
@@ -257,18 +257,21 @@ pub fn solve_constraints(
) -> TypeInferenceResult<TypeResolutions> {
let mut errors = vec![];
let mut warnings = vec![];
let mut resolutions = HashMap::new();
let mut resolutions: HashMap<ArcIntern<String>, Type> = HashMap::new();
let mut changed_something = true;
println!("CONSTRAINTS:");
for constraint in constraint_db.iter() {
println!("{}", constraint);
}
// We want to keep running this inference until we have solved all of our
// constraints. Internal to the loop, we have a check that will make sure that we
// do (eventually) stop.
while changed_something && !constraint_db.is_empty() {
println!("CONSTRAINT:");
for constraint in constraint_db.iter() {
println!(" {}", constraint);
}
println!("RESOLUTIONS:");
for (name, ty) in resolutions.iter() {
println!(" {} = {}", name, ty);
}
// Set this to false at the top of the loop. We'll set this to true if we make
// progress in any way further down, but having this here prevents us from going
// into an infinite loop when we can't figure stuff out.
@@ -292,9 +295,13 @@ pub fn solve_constraints(
Constraint::IsSomething(_, TypeOrVar::Function(_, _))
| Constraint::IsSomething(_, TypeOrVar::Primitive(_)) => changed_something = true,
// Otherwise, we'll keep looking for it.
Constraint::IsSomething(_, TypeOrVar::Variable(_, _)) => {
constraint_db.push(constraint);
// Otherwise, see if we've resolved this variable to anything. If not, add it
// back.
Constraint::IsSomething(_, TypeOrVar::Variable(_, ref name)) => {
if resolutions.get(name).is_none() {
constraint_db.push(constraint);
}
changed_something = true;
}
// Case #1a: We have two primitive types. If they're equal, we've discharged this
@@ -311,35 +318,7 @@ pub fn solve_constraints(
changed_something = true;
}
// Case #2: One of the two constraints is a primitive, and the other is a variable.
// In this case, we'll check to see if we've resolved the variable, and check for
// equivalence if we have. If we haven't, we'll set that variable to be primitive
// type.
Constraint::Equivalent(
loc,
TypeOrVar::Primitive(t),
TypeOrVar::Variable(_, name),
)
| Constraint::Equivalent(
loc,
TypeOrVar::Variable(_, name),
TypeOrVar::Primitive(t),
) => {
match resolutions.get(&name) {
None => {
resolutions.insert(name, t);
}
Some(t2) if &t == t2 => {}
Some(t2) => errors.push(TypeInferenceError::NotEquivalent(
loc,
TypeOrVar::Primitive(t),
TypeOrVar::Primitive(*t2),
)),
}
changed_something = true;
}
// Case #3: They're both variables. In which case, we'll have to do much the same
// Case #2: They're both variables. In which case, we'll have to do much the same
// check, but now on their resolutions.
Constraint::Equivalent(
ref loc,
@@ -350,11 +329,11 @@ pub fn solve_constraints(
constraint_db.push(constraint);
}
(Some(pt), None) => {
resolutions.insert(name2.clone(), *pt);
resolutions.insert(name2.clone(), pt.clone());
changed_something = true;
}
(None, Some(pt)) => {
resolutions.insert(name1.clone(), *pt);
resolutions.insert(name1.clone(), pt.clone());
changed_something = true;
}
(Some(pt1), Some(pt2)) if pt1 == pt2 => {
@@ -363,13 +342,43 @@ pub fn solve_constraints(
(Some(pt1), Some(pt2)) => {
errors.push(TypeInferenceError::NotEquivalent(
loc.clone(),
TypeOrVar::Primitive(*pt1),
TypeOrVar::Primitive(*pt2),
pt1.clone().into(),
pt2.clone().into(),
));
changed_something = true;
}
},
// Case #3: One of the two constraints is a primitive, and the other is a variable.
// In this case, we'll check to see if we've resolved the variable, and check for
// equivalence if we have. If we haven't, we'll set that variable to be primitive
// type.
Constraint::Equivalent(loc, t, TypeOrVar::Variable(vloc, name))
| Constraint::Equivalent(loc, TypeOrVar::Variable(vloc, name), t) => {
match resolutions.get(&name) {
None => match t.try_into() {
Ok(real_type) => {
resolutions.insert(name, real_type);
}
Err(variable_type) => {
constraint_db.push(Constraint::Equivalent(
loc,
variable_type,
TypeOrVar::Variable(vloc, name),
));
continue;
}
},
Some(t2) if &t == t2 => {}
Some(t2) => errors.push(TypeInferenceError::NotEquivalent(
loc,
t,
t2.clone().into(),
)),
}
changed_something = true;
}
// Case #4: Like primitives, but for function types. This is a little complicated, because
// we first want to resolve all the type variables in the two types, and then see if they're
// equivalent. Fortunately, though, we can cheat a bit. What we're going to do is first see
@@ -445,7 +454,7 @@ pub fn solve_constraints(
Some(nt) => {
constraint_db.push(Constraint::FitsInNumType(
loc,
TypeOrVar::Primitive(*nt),
nt.clone().into(),
val,
));
changed_something = true;
@@ -474,7 +483,7 @@ pub fn solve_constraints(
Some(nt) => {
constraint_db.push(Constraint::CanCastTo(
loc,
TypeOrVar::Primitive(*nt),
nt.clone().into(),
to_type,
));
changed_something = true;
@@ -494,7 +503,7 @@ pub fn solve_constraints(
constraint_db.push(Constraint::CanCastTo(
loc,
from_type,
TypeOrVar::Primitive(*nt),
nt.clone().into(),
));
changed_something = true;
}
@@ -560,8 +569,7 @@ pub fn solve_constraints(
None => constraint_db
.push(Constraint::NumericType(loc, TypeOrVar::Variable(vloc, var))),
Some(nt) => {
constraint_db
.push(Constraint::NumericType(loc, TypeOrVar::Primitive(*nt)));
constraint_db.push(Constraint::NumericType(loc, nt.clone().into()));
changed_something = true;
}
}
@@ -592,10 +600,8 @@ pub fn solve_constraints(
TypeOrVar::Variable(vloc, var),
)),
Some(nt) => {
constraint_db.push(Constraint::ConstantNumericType(
loc,
TypeOrVar::Primitive(*nt),
));
constraint_db
.push(Constraint::ConstantNumericType(loc, nt.clone().into()));
changed_something = true;
}
}