From a7b85d37dae1c4744da077911d48ffec9e2a651f Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Sat, 16 Mar 2024 16:41:23 -0700 Subject: [PATCH] basic support for structures through the IR --- src/backend/into_crane.rs | 12 ++ src/ir/arbitrary.rs | 8 +- src/ir/ast.rs | 54 +++++++- src/ir/eval.rs | 25 ++++ src/ir/pretty.rs | 46 +++++++ src/syntax/ast.rs | 4 +- src/syntax/parser.lalrpop | 15 +- src/syntax/pretty.rs | 19 +-- src/type_infer.rs | 2 +- src/type_infer/convert.rs | 172 +++++++++++++++++++---- src/type_infer/finalize.rs | 36 ++++- src/type_infer/solve.rs | 273 ++++++++++++++++++++++++++++++++----- 12 files changed, 572 insertions(+), 94 deletions(-) diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index 3806c8a..fe7e1c0 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -33,6 +33,9 @@ impl Backend { types::Type::triple_pointer_type(&self.platform), ir::ArgumentExtension::None, ), + Type::Structure(_) => { + unimplemented!() + } Type::Primitive(PrimitiveType::Void) => (types::I8, ir::ArgumentExtension::None), // FIXME? Type::Primitive(PrimitiveType::I8) => (types::I8, ir::ArgumentExtension::Sext), Type::Primitive(PrimitiveType::I16) => (types::I16, ir::ArgumentExtension::Sext), @@ -88,6 +91,10 @@ impl Backend { self.defined_symbols .insert(top_level_name, (data_id, pt.into())); } + + Type::Structure(_) => { + unimplemented!() + } } } @@ -392,6 +399,9 @@ impl Backend { } } + Expression::Construct(_, _, _, _) => unimplemented!(), + Expression::FieldRef(_, _, _, _) => unimplemented!(), + Expression::Block(_, _, mut exprs) => match exprs.pop() { None => Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void)), Some(last) => { @@ -511,8 +521,10 @@ impl Backend { Type::Function(_, _) => { panic!("function returns a function?") } + Type::Structure(_) => unimplemented!(), Type::Primitive(ct) => Ok((*result, ct.into())), }, + Type::Structure(_) => unimplemented!(), }, _ => panic!("don't support multi-value returns yet"), } diff --git a/src/ir/arbitrary.rs b/src/ir/arbitrary.rs index 71de979..b704eaa 100644 --- a/src/ir/arbitrary.rs +++ b/src/ir/arbitrary.rs @@ -9,6 +9,7 @@ use proptest::test_runner::{TestRng, TestRunner}; use rand::distributions::{Distribution, WeightedIndex}; use rand::seq::SliceRandom; use rand::Rng; +use std::collections::HashMap; use std::str::FromStr; lazy_static::lazy_static! { @@ -214,7 +215,10 @@ impl ProgramTree { } } - let current = Program { items }; + let current = Program { + items, + type_definitions: HashMap::new(), + }; ProgramTree { _rng: rng, current } } @@ -328,6 +332,7 @@ fn generate_random_expression( .expect("actually chose type"); Expression::Cast(Location::manufactured(), Type::Primitive(*to_type), inner) } + Type::Structure(_) => unimplemented!(), } } @@ -350,6 +355,7 @@ fn generate_random_expression( Expression::Primitive(Location::manufactured(), out_type, primop, args) } }, + Type::Structure(_) => unimplemented!(), } } diff --git a/src/ir/ast.rs b/src/ir/ast.rs index ae76a95..88820b6 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -2,6 +2,7 @@ use crate::eval::PrimitiveType; use crate::syntax::{ConstantType, Location}; use internment::ArcIntern; use proptest::arbitrary::Arbitrary; +use std::collections::HashMap; use std::convert::TryFrom; use std::str::FromStr; use std::sync::atomic::AtomicUsize; @@ -50,7 +51,9 @@ pub fn gensym(base: &str) -> Variable { pub struct Program { // For now, a program is just a vector of statements. In the future, we'll probably // extend this to include a bunch of other information, but for now: just a list. - pub(crate) items: Vec>, + pub items: Vec>, + // The set of types declared in this program. + pub type_definitions: HashMap, Type>, } impl Arbitrary for Program { @@ -103,6 +106,13 @@ pub enum Expression { Atomic(ValueOrRef), Cast(Location, Type, ValueOrRef), Primitive(Location, Type, Primitive, Vec>), + Construct( + Location, + Type, + ArcIntern, + HashMap, ValueOrRef>, + ), + FieldRef(Location, Type, ValueOrRef, ArcIntern), Block(Location, Type, Vec>), Print(Location, ValueOrRef), Call(Location, Type, Box>, Vec>), @@ -117,6 +127,8 @@ impl Expression { Expression::Atomic(x) => x.type_of(), Expression::Cast(_, t, _) => t.clone(), Expression::Primitive(_, t, _, _) => t.clone(), + Expression::Construct(_, t, _, _) => t.clone(), + Expression::FieldRef(_, t, _, _) => t.clone(), Expression::Block(_, t, _) => t.clone(), Expression::Print(_, _) => Type::void(), Expression::Call(_, t, _, _) => t.clone(), @@ -131,6 +143,8 @@ impl Expression { Expression::Atomic(ValueOrRef::Value(l, _, _)) => l, Expression::Cast(l, _, _) => l, Expression::Primitive(l, _, _, _) => l, + Expression::Construct(l, _, _, _) => l, + Expression::FieldRef(l, _, _, _) => l, Expression::Block(l, _, _) => l, Expression::Print(l, _) => l, Expression::Call(l, _, _, _) => l, @@ -233,6 +247,7 @@ impl Value { pub enum Type { Primitive(PrimitiveType), Function(Vec, Box), + Structure(HashMap, Type>), } impl Type { @@ -256,6 +271,7 @@ impl<'a> TryInto for &'a Type { match self { Type::Primitive(pt) => Ok((*pt).into()), Type::Function(_, _) => Err(()), + Type::Structure(_) => Err(()), } } } @@ -265,6 +281,7 @@ pub enum TypeOrVar { Primitive(PrimitiveType), Variable(Location, ArcIntern), Function(Vec, Box), + Structure(HashMap, TypeOrVar>), } impl Default for TypeOrVar { @@ -311,6 +328,10 @@ impl TypeOrVar { } TypeOrVar::Primitive(_) => false, + + TypeOrVar::Structure(fields) => { + fields.values_mut().any(|x| x.replace(name, replace_with)) + } } } @@ -323,6 +344,7 @@ impl TypeOrVar { TypeOrVar::Function(args, ret) => { args.iter().all(TypeOrVar::is_resolved) && ret.is_resolved() } + TypeOrVar::Structure(fields) => fields.values().all(TypeOrVar::is_resolved), } } } @@ -339,6 +361,16 @@ impl PartialEq for TypeOrVar { TypeOrVar::Primitive(x) => a == x, _ => false, }, + + Type::Structure(fields1) => match self { + TypeOrVar::Structure(fields2) => { + fields1.len() == fields2.len() + && fields1.iter().all(|(name, subtype)| { + fields2.get(name).map(|x| x == subtype).unwrap_or(false) + }) + } + _ => false, + }, } } } @@ -386,6 +418,9 @@ impl> From for TypeOrVar { args.into_iter().map(Into::into).collect(), Box::new((*ret).into()), ), + Type::Structure(fields) => { + TypeOrVar::Structure(fields.into_iter().map(|(n, t)| (n, t.into())).collect()) + } } } } @@ -413,7 +448,22 @@ impl TryFrom for Type { } TypeOrVar::Primitive(t) => Ok(Type::Primitive(t)), - _ => Err(value), + + TypeOrVar::Structure(fields) => { + let mut new_fields = HashMap::with_capacity(fields.len()); + + for (name, field) in fields.iter() { + if let Ok(new_field) = field.clone().try_into() { + new_fields.insert(name.clone(), new_field); + } else { + return Err(TypeOrVar::Structure(fields)); + } + } + + Ok(Type::Structure(new_fields)) + } + + TypeOrVar::Variable(_, _) => Err(value), } } } diff --git a/src/ir/eval.rs b/src/ir/eval.rs index 5b758ae..6c8b01c 100644 --- a/src/ir/eval.rs +++ b/src/ir/eval.rs @@ -2,6 +2,7 @@ use super::{Primitive, Type, ValueOrRef}; use crate::eval::{EvalError, Value}; use crate::ir::{Expression, Program, TopLevel, Variable}; use crate::util::scoped_map::ScopedMap; +use std::collections::HashMap; type IRValue = Value>; type IREvalError = EvalError>; @@ -60,6 +61,7 @@ where match ty { Type::Primitive(pt) => Ok(pt.safe_cast(&value)?), Type::Function(_, _) => Err(EvalError::CastToFunction(ty.to_string())), + Type::Structure(_) => unimplemented!(), } } @@ -79,6 +81,29 @@ where } } + Expression::Construct(_, _, name, fields) => { + let mut result_fields = HashMap::with_capacity(fields.len()); + + for (name, subexpr) in fields.iter() { + result_fields.insert(name.clone(), subexpr.eval(env)?); + } + + Ok(Value::Structure(Some(name.clone()), result_fields)) + } + + Expression::FieldRef(loc, _, valref, field) => match valref.eval(env)? { + Value::Structure(oname, mut fields) => match fields.remove(field) { + None => Err(EvalError::NoFieldForValue( + loc.clone(), + Value::Structure(oname, fields), + field.clone(), + )), + Some(value) => Ok(value), + }, + + x => Err(EvalError::NoFieldForValue(loc.clone(), x, field.clone())), + }, + Expression::Block(_, _, stmts) => { let mut result = Value::Void; diff --git a/src/ir/pretty.rs b/src/ir/pretty.rs index ae0ce15..b407059 100644 --- a/src/ir/pretty.rs +++ b/src/ir/pretty.rs @@ -56,6 +56,28 @@ impl Expression { Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => { op.pretty(allocator).append(exprs[0].pretty(allocator)) } + Expression::Construct(_, _, name, fields) => { + let inner = allocator + .intersperse( + fields.iter().map(|(k, v)| { + allocator + .text(k.to_string()) + .append(":") + .append(allocator.space()) + .append(v.pretty(allocator)) + .append(allocator.text(";")) + }), + allocator.line(), + ) + .indent(2) + .braces(); + allocator.text(name.to_string()).append(inner) + } + Expression::FieldRef(_, _, val, field) => val.pretty(allocator).append( + allocator + .text(".") + .append(allocator.text(field.to_string())), + ), Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => { let left = exprs[0].pretty(allocator); let right = exprs[1].pretty(allocator); @@ -180,6 +202,18 @@ impl Type { match self { Type::Function(args, rettype) => pretty_function_type!(allocator, args, rettype), Type::Primitive(prim) => prim.pretty(allocator), + Type::Structure(fields) => allocator.text("struct").append( + allocator + .concat(fields.iter().map(|(n, t)| { + allocator + .text(n.to_string()) + .append(allocator.text(":")) + .append(allocator.space()) + .append(t.pretty(allocator)) + .append(allocator.text(";")) + })) + .braces(), + ), } } } @@ -190,6 +224,18 @@ impl TypeOrVar { TypeOrVar::Function(args, rettype) => pretty_function_type!(allocator, args, rettype), TypeOrVar::Primitive(prim) => prim.pretty(allocator), TypeOrVar::Variable(_, name) => allocator.text(name.to_string()), + TypeOrVar::Structure(fields) => allocator.text("struct").append( + allocator + .concat(fields.iter().map(|(n, t)| { + allocator + .text(n.to_string()) + .append(allocator.text(":")) + .append(allocator.space()) + .append(t.pretty(allocator)) + .append(allocator.text(";")) + })) + .braces(), + ), } } } diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index 92733df..2fff204 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -33,7 +33,7 @@ pub enum TopLevel { Option, Expression, ), - Structure(Location, Option, Vec<(Name, Type)>), + Structure(Location, Name, Vec<(Name, Type)>), } /// A Name. @@ -212,5 +212,5 @@ pub enum Value { #[derive(Clone, Debug, PartialEq, Eq)] pub enum Type { Named(Name), - Struct(Option, Vec<(Option, Option)>), + Struct(Vec<(Name, Option)>), } diff --git a/src/syntax/parser.lalrpop b/src/syntax/parser.lalrpop index e93e29f..70534ea 100644 --- a/src/syntax/parser.lalrpop +++ b/src/syntax/parser.lalrpop @@ -100,8 +100,8 @@ OptionalComma: () = { } Structure: TopLevel = { - "struct" "{" "}" => { - TopLevel::Structure(Location::new(file_idx, s..e), on, fields) + "struct" "{" "}" => { + TopLevel::Structure(Location::new(file_idx, s..e), n, fields) } } @@ -113,14 +113,13 @@ Field: (Name, Type) = { Type: Type = { => Type::Named(name), => Type::Named(t), - "struct" "{" "}" => - Type::Struct(on, fields), + "struct" "{" "}" => + Type::Struct(fields), } -TypeField: (Option, Option) = { - ":" ";" => (Some(name), Some(ty)), - (":" "_")? ";" => (Some(name), None), - "_" ":" ";" => (None, Some(ty)), +TypeField: (Name, Option) = { + ":" ";" => (name, Some(ty)), + (":" "_")? ";" => (name, None), } Name: Name = { diff --git a/src/syntax/pretty.rs b/src/syntax/pretty.rs index ea306b6..10fa97c 100644 --- a/src/syntax/pretty.rs +++ b/src/syntax/pretty.rs @@ -65,11 +65,7 @@ impl TopLevel { TopLevel::Structure(_, name, fields) => allocator .text("struct") .append(allocator.space()) - .append( - name.as_ref() - .map(|x| allocator.text(x.to_string())) - .unwrap_or_else(|| allocator.nil()), - ) + .append(allocator.text(name.to_string())) .append(allocator.space()) .append(allocator.text("{")) .append(allocator.hardline()) @@ -224,22 +220,13 @@ impl Type { pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> { match self { Type::Named(x) => allocator.text(x.to_string()), - Type::Struct(name, fields) => allocator + Type::Struct(fields) => allocator .text("struct") .append(allocator.space()) - .append( - name.as_ref() - .map(|x| allocator.text(x.to_string())) - .unwrap_or_else(|| allocator.nil()), - ) .append(allocator.intersperse( fields.iter().map(|(name, ty)| { allocator - .text( - name.as_ref() - .map(|x| x.to_string()) - .unwrap_or_else(|| "_".to_string()), - ) + .text(name.to_string()) .append(allocator.text(":")) .append(allocator.space()) .append( diff --git a/src/type_infer.rs b/src/type_infer.rs index 080ce8d..31916dd 100644 --- a/src/type_infer.rs +++ b/src/type_infer.rs @@ -33,7 +33,7 @@ impl syntax::Program { /// this method, otherwise you may experience panics during operation. pub fn type_infer(self) -> TypeInferenceResult> { let (program, constraint_db) = convert_program(self); - let inference_result = solve_constraints(constraint_db); + let inference_result = solve_constraints(&program.type_definitions, constraint_db); inference_result.map(|resolutions| finalize_program(program, &resolutions)) } diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index 5e004d8..1e103d3 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -7,6 +7,11 @@ use internment::ArcIntern; use std::collections::HashMap; use std::str::FromStr; +enum TopLevelItem { + Type(ArcIntern, ir::TypeOrVar), + Expression(ir::TopLevel), +} + /// This function takes a syntactic program and converts it into the IR version of the /// program, with appropriate type variables introduced and their constraints added to /// the given database. @@ -21,28 +26,37 @@ pub fn convert_program( let mut items = Vec::new(); let mut renames = ScopedMap::new(); let mut bindings = HashMap::new(); + let mut type_definitions = HashMap::new(); for item in program.items.drain(..) { - items.push(convert_top_level( - item, - &mut constraint_db, - &mut renames, - &mut bindings, - )); + let tli = convert_top_level(item, &mut constraint_db, &mut renames, &mut bindings); + + match tli { + TopLevelItem::Expression(item) => items.push(item), + TopLevelItem::Type(name, decl) => { + let _ = type_definitions.insert(name, decl); + } + } } - (ir::Program { items }, constraint_db) + ( + ir::Program { + items, + type_definitions, + }, + constraint_db, + ) } /// This function takes a top-level item and converts it into the IR version of the /// program, with all the appropriate type variables introduced and their constraints /// added to the given database. -pub fn convert_top_level( +fn convert_top_level( top_level: syntax::TopLevel, constraint_db: &mut Vec, renames: &mut ScopedMap, ArcIntern>, bindings: &mut HashMap, ir::TypeOrVar>, -) -> ir::TopLevel { +) -> TopLevelItem { match top_level { syntax::TopLevel::Function(name, args, _, expr) => { // First, at some point we're going to want to know a location for this function, @@ -75,8 +89,8 @@ pub fn convert_top_level( // appropriately. renames.new_scope(); let arginfo = args - .iter() - .map(|(name, _)| { + .into_iter() + .map(|(name, mut declared_type)| { let new_type = ir::TypeOrVar::new(); constraint_db.push(Constraint::IsSomething( name.location.clone(), @@ -84,7 +98,16 @@ pub fn convert_top_level( )); let new_name = finalize_name(bindings, renames, name.clone()); bindings.insert(new_name.clone(), new_type.clone()); - unimplemented!(); + + if let Some(declared_type) = declared_type.take() { + let declared_type = convert_type(declared_type, constraint_db); + constraint_db.push(Constraint::Equivalent( + name.location.clone(), + new_type.clone(), + declared_type, + )); + } + (new_name, new_type) }) .collect::>(); @@ -114,16 +137,27 @@ pub fn convert_top_level( // Remember to exit this scoping level! renames.release_scope(); - ir::TopLevel::Function(function_name, arginfo, rettype, expr) + TopLevelItem::Expression(ir::TopLevel::Function( + function_name, + arginfo, + rettype, + expr, + )) } - syntax::TopLevel::Statement(stmt) => { - ir::TopLevel::Statement(convert_statement(stmt, constraint_db, renames, bindings)) - } + syntax::TopLevel::Statement(stmt) => TopLevelItem::Expression(ir::TopLevel::Statement( + convert_statement(stmt, constraint_db, renames, bindings), + )), - syntax::TopLevel::Structure(loc, oname, fields) => { - unimplemented!() - } + syntax::TopLevel::Structure(_loc, name, fields) => TopLevelItem::Type( + name.intern(), + ir::TypeOrVar::Structure( + fields + .into_iter() + .map(|(name, t)| (name.intern(), convert_type(t, constraint_db))) + .collect(), + ), + ), } } @@ -258,7 +292,31 @@ fn convert_expression( } }, - syntax::Expression::Constructor(_, _, _) => unimplemented!(), + syntax::Expression::Constructor(loc, name, fields) => { + let mut result_fields = HashMap::new(); + let mut type_fields = HashMap::new(); + let mut prereqs = vec![]; + let result_type = ir::TypeOrVar::new(); + + for (name, syntax_expr) in fields.into_iter() { + let (ir_expr, expr_type) = + convert_expression(syntax_expr, constraint_db, renames, bindings); + type_fields.insert(name.clone().intern(), expr_type); + let (prereq, value) = simplify_expr(ir_expr); + result_fields.insert(name.clone().intern(), value); + merge_prereq(&mut prereqs, prereq); + } + + constraint_db.push(Constraint::NamedTypeIs( + loc.clone(), + name.clone().intern(), + ir::TypeOrVar::Structure(type_fields), + )); + let result = + ir::Expression::Construct(loc, result_type.clone(), name.intern(), result_fields); + + (finalize_expressions(prereqs, result), result_type) + } syntax::Expression::Reference(loc, name) => { let iname = ArcIntern::new(name); @@ -273,7 +331,26 @@ fn convert_expression( (refexp, rtype) } - syntax::Expression::FieldRef(_, _, _) => unimplemented!(), + syntax::Expression::FieldRef(loc, expr, field) => { + let (nexpr, etype) = convert_expression(*expr, constraint_db, renames, bindings); + let (prereqs, val_or_ref) = simplify_expr(nexpr); + let result_type = ir::TypeOrVar::new(); + let result = ir::Expression::FieldRef( + loc.clone(), + result_type.clone(), + val_or_ref, + field.clone().intern(), + ); + + constraint_db.push(Constraint::TypeHasField( + loc, + etype, + field.intern(), + result_type.clone(), + )); + + (finalize_expression(prereqs, result), result_type) + } syntax::Expression::Cast(loc, target, expr) => { let (nexpr, etype) = convert_expression(*expr, constraint_db, renames, bindings); @@ -366,15 +443,7 @@ fn convert_expression( let last_call = ir::Expression::Call(loc.clone(), return_type.clone(), Box::new(fun), new_args); - if prereqs.is_empty() { - (last_call, return_type) - } else { - prereqs.push(last_call); - ( - ir::Expression::Block(loc, return_type.clone(), prereqs), - return_type, - ) - } + (finalize_expressions(prereqs, last_call), return_type) } syntax::Expression::Block(loc, stmts) => { @@ -396,6 +465,35 @@ fn convert_expression( } } +fn convert_type(ty: syntax::Type, constraint_db: &mut Vec) -> ir::TypeOrVar { + match ty { + syntax::Type::Named(x) => match PrimitiveType::from_str(x.name.as_str()) { + Err(_) => { + let retval = ir::TypeOrVar::new_located(x.location.clone()); + constraint_db.push(Constraint::NamedTypeIs( + x.location.clone(), + x.intern(), + retval.clone(), + )); + retval + } + Ok(v) => ir::TypeOrVar::Primitive(v), + }, + syntax::Type::Struct(fields) => ir::TypeOrVar::Structure( + fields + .into_iter() + .map(|(n, t)| { + ( + n.intern(), + t.map(|x| convert_type(x, constraint_db)) + .unwrap_or_else(ir::TypeOrVar::new), + ) + }) + .collect(), + ), + } +} + fn simplify_expr( expr: ir::Expression, ) -> ( @@ -431,6 +529,20 @@ fn finalize_expression( } } +fn finalize_expressions( + mut prereqs: Vec>, + actual: ir::Expression, +) -> ir::Expression { + if prereqs.is_empty() { + actual + } else { + let return_type = actual.type_of(); + let loc = actual.location().clone(); + prereqs.push(actual); + ir::Expression::Block(loc, return_type, prereqs) + } +} + fn finalize_name( bindings: &HashMap, ir::TypeOrVar>, renames: &mut ScopedMap, ArcIntern>, diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index 883f2a2..1990332 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -3,7 +3,7 @@ use crate::eval::PrimitiveType; use crate::ir::{Expression, Program, TopLevel, Type, TypeOrVar, Value, ValueOrRef}; pub fn finalize_program( - mut program: Program, + program: Program, resolutions: &TypeResolutions, ) -> Program { for (name, ty) in resolutions.iter() { @@ -13,9 +13,15 @@ pub fn finalize_program( Program { items: program .items - .drain(..) + .into_iter() .map(|x| finalize_top_level(x, resolutions)) .collect(), + + type_definitions: program + .type_definitions + .into_iter() + .map(|(n, t)| (n, finalize_type(t, resolutions))) + .collect(), } } @@ -57,6 +63,23 @@ fn finalize_expression( .collect(), ), + Expression::Construct(loc, ty, name, fields) => Expression::Construct( + loc, + finalize_type(ty, resolutions), + name, + fields + .into_iter() + .map(|(k, v)| (k, finalize_val_or_ref(v, resolutions))) + .collect(), + ), + + Expression::FieldRef(loc, ty, valref, field) => Expression::FieldRef( + loc, + finalize_type(ty, resolutions), + finalize_val_or_ref(valref, resolutions), + field, + ), + Expression::Block(loc, ty, exprs) => { let mut final_exprs = Vec::with_capacity(exprs.len()); @@ -111,6 +134,12 @@ fn finalize_type(ty: TypeOrVar, resolutions: &TypeResolutions) -> Type { .collect(), Box::new(finalize_type(*ret, resolutions)), ), + TypeOrVar::Structure(fields) => Type::Structure( + fields + .into_iter() + .map(|(name, subtype)| (name, finalize_type(subtype, resolutions))) + .collect(), + ), } } @@ -129,6 +158,9 @@ fn finalize_val_or_ref( Type::Function(_, _) => { panic!("Somehow inferred that a constant was a function") } + Type::Structure(_) => { + panic!("Somehow inferred that a constant was a structure") + } Type::Primitive(PrimitiveType::Void) => { panic!("Somehow inferred that a constant was void") } diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index 82123b3..e2c715e 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ -16,6 +16,9 @@ pub enum Constraint { ProperPrimitiveArgs(Location, Primitive, Vec, TypeOrVar), /// The given type can be casted to the target type safely CanCastTo(Location, TypeOrVar, TypeOrVar), + /// The given type has the given field in it, and the type of that field + /// is as given. + TypeHasField(Location, TypeOrVar, ArcIntern, TypeOrVar), /// The given type must be some numeric type, but this is not a constant /// value, so don't try to default it if we can't figure it out NumericType(Location, TypeOrVar), @@ -29,6 +32,8 @@ pub enum Constraint { IsSomething(Location, TypeOrVar), /// The given type can be negated IsSigned(Location, TypeOrVar), + /// Checks to see if the given named type is equivalent to the provided one. + NamedTypeIs(Location, ArcIntern, TypeOrVar), } impl fmt::Display for Constraint { @@ -44,11 +49,15 @@ impl fmt::Display for Constraint { } Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret), Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2), + Constraint::TypeHasField(_, ty1, field, ty2) => { + write!(f, "FIELD {}.{} -> {}", ty1, field, ty2) + } Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty), Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty), Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2), Constraint::IsSomething(_, ty) => write!(f, "SOMETHING {}", ty), Constraint::IsSigned(_, ty) => write!(f, "SIGNED {}", ty), + Constraint::NamedTypeIs(_, name, ty) => write!(f, "TYPE_EQUIV {} == {}", name, ty), } } } @@ -65,6 +74,9 @@ impl Constraint { Constraint::CanCastTo(_, ty1, ty2) => { ty1.replace(name, replace_with) || ty2.replace(name, replace_with) } + Constraint::TypeHasField(_, ty1, _, ty2) => { + ty1.replace(name, replace_with) || ty2.replace(name, replace_with) + } Constraint::ConstantNumericType(_, ty) => ty.replace(name, replace_with), Constraint::Equivalent(_, ty1, ty2) => { ty1.replace(name, replace_with) || ty2.replace(name, replace_with) @@ -76,6 +88,7 @@ impl Constraint { ret.replace(name, replace_with) | args.iter_mut().any(|x| x.replace(name, replace_with)) } + Constraint::NamedTypeIs(_, name, ty) => ty.replace(name, replace_with), } } } @@ -142,21 +155,22 @@ pub enum TypeInferenceError { CannotSafelyCast(Location, PrimitiveType, PrimitiveType), /// The primitive invocation provided the wrong number of arguments. WrongPrimitiveArity(Location, Primitive, usize, usize, usize), - /// We cannot cast between the given function types, usually because they - /// have different argument lengths - CannotCastBetweenFunctinoTypes(Location, TypeOrVar, TypeOrVar), - /// We cannot cast from a function type to something else. - CannotCastFromFunctionType(Location, TypeOrVar), - /// We cannot cast to a function type from something else. - CannotCastToFunctionType(Location, TypeOrVar), + /// We cannot cast between the type types, for any number of reasons + CannotCast(Location, TypeOrVar, TypeOrVar), /// We cannot turn a number into a function. CannotMakeNumberAFunction(Location, TypeOrVar, Option), + /// We cannot turn a number into a Structure. + CannotMakeNumberAStructure(Location, TypeOrVar, Option), /// We had a constraint we just couldn't solve. CouldNotSolve(Constraint), /// Functions are not printable. FunctionsAreNotPrintable(Location), /// The given type isn't signed, and can't be negated - IsNotSigned(Location, PrimitiveType), + IsNotSigned(Location, TypeOrVar), + /// The given type doesn't have the given field. + NoFieldForType(Location, ArcIntern, TypeOrVar), + /// There is no type with the given name. + UnknownTypeName(Location, ArcIntern), } impl From for Diagnostic { @@ -196,22 +210,12 @@ impl From for Diagnostic { prim, observed )), - TypeInferenceError::CannotCastBetweenFunctinoTypes(loc, t1, t2) => loc - .labelled_error("cannot cast between function types") + TypeInferenceError::CannotCast(loc, t1, t2) => loc + .labelled_error("cannot cast between types") .with_message(format!( "tried to cast from {} to {}", t1, t2, )), - TypeInferenceError::CannotCastFromFunctionType(loc, t) => loc - .labelled_error("cannot cast from a function type to anything else") - .with_message(format!( - "function type was {}", t, - )), - TypeInferenceError::CannotCastToFunctionType(loc, t) => loc - .labelled_error("cannot cast to a function type") - .with_message(format!( - "function type was {}", t, - )), TypeInferenceError::CannotMakeNumberAFunction(loc, t, val) => loc .labelled_error(if let Some(val) = val { format!("cannot turn {} into a function", val) @@ -219,17 +223,32 @@ impl From for Diagnostic { "cannot use a constant as a function type".to_string() }) .with_message(format!("function type was {}", t)), + TypeInferenceError::CannotMakeNumberAStructure(loc, t, val) => loc + .labelled_error(if let Some(val) = val { + format!("cannot turn {} into a function", val) + } else { + "cannot use a constant as a function type".to_string() + }) + .with_message(format!("function type was {}", t)), TypeInferenceError::FunctionsAreNotPrintable(loc) => loc .labelled_error("cannot print function values"), TypeInferenceError::IsNotSigned(loc, pt) => loc .labelled_error(format!("type {} is not signed", pt)) .with_message("and so it cannot be negated"), + TypeInferenceError::NoFieldForType(loc, field, t) => loc + .labelled_error(format!("no field {} available for type {}", field, t)), + TypeInferenceError::UnknownTypeName(loc , name) => loc + .labelled_error(format!("unknown type named {}", name)), TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => { loc.labelled_error("internal error").with_message(format!( - "could not determine if it was safe to cast from {} to {:#?}", + "could not determine if it was safe to cast from {} to {}", a, b )) } + TypeInferenceError::CouldNotSolve(Constraint::TypeHasField(loc, a, field, _)) => { + loc.labelled_error("internal error") + .with_message(format!("fould not determine if type {} has field {}", a, field)) + } TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => { loc.labelled_error("internal error").with_message(format!( "could not determine if {} and {} were equivalent", @@ -263,6 +282,9 @@ impl From for Diagnostic { TypeInferenceError::CouldNotSolve(Constraint::IsSigned(loc, t)) => loc .labelled_error("internal error") .with_message(format!("could not infer that type {} was signed", t)), + TypeInferenceError::CouldNotSolve(Constraint::NamedTypeIs(loc, name, ty)) => loc + .labelled_error("internal error") + .with_message(format!("could not infer that the name {} refers to {}", name, ty)), } } } @@ -296,6 +318,7 @@ impl From for Diagnostic { /// successful set of type resolutions (mappings from type variables to their values), or /// a series of inference errors. pub fn solve_constraints( + known_types: &HashMap, TypeOrVar>, mut constraint_db: Vec, ) -> TypeInferenceResult { let mut errors = vec![]; @@ -346,7 +369,7 @@ pub fn solve_constraints( } all_constraints_solved = false; } else { - errors.push(TypeInferenceError::CannotCastBetweenFunctinoTypes( + errors.push(TypeInferenceError::CannotCast( loc, TypeOrVar::Function(args1, ret1), TypeOrVar::Function(args2, ret2), @@ -360,21 +383,100 @@ pub fn solve_constraints( Constraint::CanCastTo( loc, - ft @ TypeOrVar::Function(_, _), - pt @ TypeOrVar::Primitive(_), + st1 @ TypeOrVar::Structure(_), + st2 @ TypeOrVar::Structure(_), ) => { - tracing::trace!(function_type = %ft, primitive_type = %pt, "we can't cast a function type to a primitive type"); - errors.push(TypeInferenceError::CannotCastFromFunctionType(loc, pt)); + tracing::trace!( + "structures can be equivalent, if their fields and types are exactly the same" + ); + new_constraints.push(Constraint::Equivalent(loc, st1, st2)); + changed_something = true; + } + + Constraint::CanCastTo( + loc, + ft @ TypeOrVar::Function(_, _), + ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Structure(_), + ) => { + tracing::trace!(function_type = %ft, other_type = %ot, "we can't cast a function type to a primitive or structure type"); + errors.push(TypeInferenceError::CannotCast(loc, ft, ot)); changed_something = true; } Constraint::CanCastTo( loc, pt @ TypeOrVar::Primitive(_), - ft @ TypeOrVar::Function(_, _), + ot @ TypeOrVar::Function(_, _) | ot @ TypeOrVar::Structure(_), ) => { - tracing::trace!(function_type = %ft, primitive_type = %pt, "we can't cast a primitive type to a function type"); - errors.push(TypeInferenceError::CannotCastToFunctionType(loc, pt)); + tracing::trace!(other_type = %ot, primitive_type = %pt, "we can't cast a primitive type to a function or structure type"); + errors.push(TypeInferenceError::CannotCast(loc, pt, ot)); + changed_something = true; + } + + Constraint::CanCastTo( + loc, + st @ TypeOrVar::Structure(_), + ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Function(_, _), + ) => { + tracing::trace!(structure_type = %st, other_type = %ot, "we can't cast a structure type to a function or primitive type"); + errors.push(TypeInferenceError::CannotCast(loc, st, ot)); + changed_something = true; + } + + Constraint::NamedTypeIs(loc, name, ty) => match known_types.get(&name) { + None => { + tracing::trace!(type_name = %name, "we don't know a type named name"); + errors.push(TypeInferenceError::UnknownTypeName(loc, name)); + changed_something = true; + } + + Some(declared_type) => { + tracing::trace!(type_name = %name, declared = %declared_type, provided = %ty, "validating that named type is equivalent to provided"); + new_constraints.push(Constraint::Equivalent( + loc, + declared_type.clone(), + ty, + )); + changed_something = true; + } + }, + + Constraint::TypeHasField( + loc, + TypeOrVar::Structure(mut fields), + field, + result_type, + ) => match fields.remove(&field) { + None => { + let reconstituted = TypeOrVar::Structure(fields); + tracing::trace!(structure_type = %reconstituted, %field, "no field found in type"); + errors.push(TypeInferenceError::NoFieldForType( + loc, + field, + reconstituted, + )); + changed_something = true; + } + + Some(field_subtype) => { + tracing::trace!(%field_subtype, %result_type, %field, "validating that field's subtype matches target result type"); + new_constraints.push(Constraint::Equivalent( + loc, + result_type, + field_subtype, + )); + changed_something = true; + } + }, + + Constraint::TypeHasField( + loc, + ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Function(_, _), + field, + _, + ) => { + tracing::trace!(other_type = %ot, %field, "can't get field from primitive or function type"); + errors.push(TypeInferenceError::NoFieldForType(loc, field, ot)); changed_something = true; } @@ -394,6 +496,13 @@ pub fn solve_constraints( changed_something = true; } + // if we're testing if a function type is numeric, then throw a useful warning + Constraint::ConstantNumericType(loc, t @ TypeOrVar::Structure(_)) => { + tracing::trace!(structure_type = %t, "structures can't be constant numbers"); + errors.push(TypeInferenceError::CannotMakeNumberAStructure(loc, t, None)); + changed_something = true; + } + // if we're testing if a number can fit into a numeric type, we can just do that! Constraint::FitsInNumType(loc, TypeOrVar::Primitive(ctype), val) => { match ctype.max_value() { @@ -420,9 +529,21 @@ pub fn solve_constraints( changed_something = true; } + // if we're testing if a function type can fit into a numeric type, that's a problem + Constraint::FitsInNumType(loc, t @ TypeOrVar::Structure(_), val) => { + tracing::trace!(function_type = %t, "values don't fit in structure types"); + errors.push(TypeInferenceError::CannotMakeNumberAStructure( + loc, + t, + Some(val), + )); + changed_something = true; + } + // if we want to know if a type is something, and it is something, then we're done Constraint::IsSomething(_, t @ TypeOrVar::Function(_, _)) - | Constraint::IsSomething(_, t @ TypeOrVar::Primitive(_)) => { + | Constraint::IsSomething(_, t @ TypeOrVar::Primitive(_)) + | Constraint::IsSomething(_, t @ TypeOrVar::Structure(_)) => { tracing::trace!(tested_type = %t, "type is definitely something"); changed_something = true; } @@ -431,7 +552,10 @@ pub fn solve_constraints( Constraint::IsSigned(loc, TypeOrVar::Primitive(pt)) => { tracing::trace!(primitive_type = %pt, "we can check if a primitive is signed"); if !pt.valid_operators().contains(&("-", 1)) { - errors.push(TypeInferenceError::IsNotSigned(loc, pt)); + errors.push(TypeInferenceError::IsNotSigned( + loc, + TypeOrVar::Primitive(pt), + )); } changed_something = true; } @@ -439,7 +563,14 @@ pub fn solve_constraints( // again with the functions and the numbers Constraint::IsSigned(loc, t @ TypeOrVar::Function(_, _)) => { tracing::trace!(function_type = %t, "functions are not signed"); - errors.push(TypeInferenceError::CannotCastFromFunctionType(loc, t)); + errors.push(TypeInferenceError::IsNotSigned(loc, t)); + changed_something = true; + } + + // again with the functions and the numbers + Constraint::IsSigned(loc, t @ TypeOrVar::Structure(_)) => { + tracing::trace!(structure_type = %t, "structures are not signed"); + errors.push(TypeInferenceError::IsNotSigned(loc, t)); changed_something = true; } @@ -459,6 +590,13 @@ pub fn solve_constraints( changed_something = true; } + // if we're testing if a structure type is numeric, then throw a useful warning + Constraint::NumericType(loc, t @ TypeOrVar::Structure(_)) => { + tracing::trace!(structure_type = %t, "structure types aren't numeric"); + errors.push(TypeInferenceError::CannotMakeNumberAStructure(loc, t, None)); + changed_something = true; + } + // all of our primitive types are printable Constraint::Printable(_, TypeOrVar::Primitive(pt)) => { tracing::trace!(primitive_type = %pt, "primitive types are printable"); @@ -472,6 +610,17 @@ pub fn solve_constraints( changed_something = true; } + // structure types are printable if all the types inside them are printable + Constraint::Printable(loc, TypeOrVar::Structure(fields)) => { + tracing::trace!( + "structure types are printable if all their subtypes are printable" + ); + for (_, subtype) in fields.into_iter() { + new_constraints.push(Constraint::Printable(loc.clone(), subtype)); + } + changed_something = true; + } + Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim { Primitive::Plus | Primitive::Minus | Primitive::Times | Primitive::Divide if args.len() == 2 => @@ -558,6 +707,36 @@ pub fn solve_constraints( changed_something = true; } + Constraint::Equivalent( + loc, + pt @ TypeOrVar::Primitive(_), + st @ TypeOrVar::Structure(_), + ) + | Constraint::Equivalent( + loc, + st @ TypeOrVar::Structure(_), + pt @ TypeOrVar::Primitive(_), + ) => { + tracing::trace!(primitive_type = %pt, structure_type = %st, "structure and primitive types cannot be equivalent"); + errors.push(TypeInferenceError::NotEquivalent(loc, pt, st)); + changed_something = true; + } + + Constraint::Equivalent( + loc, + st @ TypeOrVar::Structure(_), + ft @ TypeOrVar::Function(_, _), + ) + | Constraint::Equivalent( + loc, + ft @ TypeOrVar::Function(_, _), + st @ TypeOrVar::Structure(_), + ) => { + tracing::trace!(structure_type = %st, function_type = %ft, "structure and primitive types cannot be equivalent"); + errors.push(TypeInferenceError::NotEquivalent(loc, st, ft)); + changed_something = true; + } + Constraint::Equivalent( _, TypeOrVar::Variable(_, name1), @@ -588,6 +767,35 @@ pub fn solve_constraints( tracing::trace!("we checked/rewrote if function types are equivalent"); } + Constraint::Equivalent( + loc, + TypeOrVar::Structure(fields1), + TypeOrVar::Structure(mut fields2), + ) => { + if fields1.len() == fields2.len() + && fields1.keys().all(|x| fields2.contains_key(x)) + { + for (name, subtype1) in fields1.into_iter() { + let subtype2 = fields2 + .remove(&name) + .expect("can find matching field after equivalence check"); + new_constraints.push(Constraint::Equivalent( + loc.clone(), + subtype1, + subtype2, + )); + } + } else { + errors.push(TypeInferenceError::NotEquivalent( + loc, + TypeOrVar::Structure(fields1), + TypeOrVar::Structure(fields2), + )) + } + changed_something = true; + tracing::trace!("we checked/rewrote if structures are equivalent"); + } + Constraint::Equivalent(_, TypeOrVar::Variable(_, ref name), ref rhs) => { changed_something |= replace_variable(&mut constraint_db, name, rhs); changed_something |= replace_variable(&mut new_constraints, name, rhs); @@ -607,6 +815,7 @@ pub fn solve_constraints( Constraint::CanCastTo(_, TypeOrVar::Variable(_, _), _) | Constraint::CanCastTo(_, _, TypeOrVar::Variable(_, _)) + | Constraint::TypeHasField(_, TypeOrVar::Variable(_, _), _, _) | Constraint::ConstantNumericType(_, TypeOrVar::Variable(_, _)) | Constraint::FitsInNumType(_, TypeOrVar::Variable(_, _), _) | Constraint::IsSomething(_, TypeOrVar::Variable(_, _))