diff --git a/src/ir.rs b/src/ir.rs index 48ac76f..5308405 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -15,6 +15,7 @@ mod arbitrary; pub mod ast; mod eval; +mod fields; mod pretty; mod strings; mod top_level; diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 88820b6..8d86fdf 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -1,4 +1,5 @@ use crate::eval::PrimitiveType; +pub use crate::ir::fields::Fields; use crate::syntax::{ConstantType, Location}; use internment::ArcIntern; use proptest::arbitrary::Arbitrary; @@ -247,7 +248,7 @@ impl Value { pub enum Type { Primitive(PrimitiveType), Function(Vec, Box), - Structure(HashMap, Type>), + Structure(Fields), } impl Type { @@ -281,7 +282,7 @@ pub enum TypeOrVar { Primitive(PrimitiveType), Variable(Location, ArcIntern), Function(Vec, Box), - Structure(HashMap, TypeOrVar>), + Structure(Fields), } impl Default for TypeOrVar { @@ -330,7 +331,7 @@ impl TypeOrVar { TypeOrVar::Primitive(_) => false, TypeOrVar::Structure(fields) => { - fields.values_mut().any(|x| x.replace(name, replace_with)) + fields.types_mut().any(|x| x.replace(name, replace_with)) } } } @@ -344,7 +345,7 @@ impl TypeOrVar { TypeOrVar::Function(args, ret) => { args.iter().all(TypeOrVar::is_resolved) && ret.is_resolved() } - TypeOrVar::Structure(fields) => fields.values().all(TypeOrVar::is_resolved), + TypeOrVar::Structure(fields) => fields.types().all(TypeOrVar::is_resolved), } } } @@ -364,7 +365,7 @@ impl PartialEq for TypeOrVar { Type::Structure(fields1) => match self { TypeOrVar::Structure(fields2) => { - fields1.len() == fields2.len() + fields1.count() == fields2.count() && fields1.iter().all(|(name, subtype)| { fields2.get(name).map(|x| x == subtype).unwrap_or(false) }) @@ -418,9 +419,7 @@ impl> From for TypeOrVar { args.into_iter().map(Into::into).collect(), Box::new((*ret).into()), ), - Type::Structure(fields) => { - TypeOrVar::Structure(fields.into_iter().map(|(n, t)| (n, t.into())).collect()) - } + Type::Structure(fields) => TypeOrVar::Structure(fields.map(Into::into)), } } } @@ -450,16 +449,21 @@ impl TryFrom for Type { TypeOrVar::Primitive(t) => Ok(Type::Primitive(t)), TypeOrVar::Structure(fields) => { - let mut new_fields = HashMap::with_capacity(fields.len()); + let mut new_fields = Fields::new(fields.ordering()); + let mut errored = false; for (name, field) in fields.iter() { if let Ok(new_field) = field.clone().try_into() { new_fields.insert(name.clone(), new_field); } else { - return Err(TypeOrVar::Structure(fields)); + errored = true; } } + if errored { + return Err(TypeOrVar::Structure(fields)); + } + Ok(Type::Structure(new_fields)) } diff --git a/src/ir/fields.rs b/src/ir/fields.rs new file mode 100644 index 0000000..71a541c --- /dev/null +++ b/src/ir/fields.rs @@ -0,0 +1,101 @@ +use internment::ArcIntern; +use std::fmt; + +#[derive(Clone, PartialEq, Eq)] +pub struct Fields { + ordering: FieldOrdering, + fields: Vec<(ArcIntern, T)>, +} + +impl fmt::Debug for Fields { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Fields:")?; + self.fields.fmt(f) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum FieldOrdering { + Standard, +} + +impl Default for Fields { + fn default() -> Self { + Self::new(FieldOrdering::Standard) + } +} + +impl Fields { + pub fn new(ordering: FieldOrdering) -> Fields { + Fields { + ordering, + fields: vec![], + } + } + + pub fn ordering(&self) -> FieldOrdering { + self.ordering + } + + pub fn insert(&mut self, name: ArcIntern, t: T) { + self.fields.push((name, t)); + } + + pub fn get(&self, name: &ArcIntern) -> Option<&T> { + for (n, res) in self.fields.iter() { + if n == name { + return Some(res); + } + } + + None + } + + pub fn map T2>(self, f: F) -> Fields { + Fields { + ordering: self.ordering, + fields: self.fields.into_iter().map(|(n, t)| (n, f(t))).collect(), + } + } + + pub fn count(&self) -> usize { + self.fields.len() + } + + pub fn has_field(&self, name: &ArcIntern) -> bool { + self.fields.iter().any(|(current, _)| current == name) + } + + pub fn remove_field(&mut self, name: &ArcIntern) -> Option { + let mut field_index = None; + + for (idx, (current, _)) in self.fields.iter().enumerate() { + if current == name { + field_index = Some(idx); + break; + } + } + + field_index.map(|i| self.fields.remove(i).1) + } + + pub fn iter(&self) -> impl Iterator, &T)> { + self.fields.iter().map(|(x, y)| (x, y)) + } + + pub fn into_iter(self) -> impl Iterator, T)> { + self.fields.into_iter() + } + + pub fn field_names(&self) -> impl Iterator> { + self.fields.iter().map(|(n, _)| n) + } + + pub fn types(&self) -> impl Iterator { + self.fields.iter().map(|(_, x)| x) + } + + pub fn types_mut(&mut self) -> impl Iterator { + self.fields.iter_mut().map(|(_, x)| x) + } +} diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index 1e103d3..c6c9a88 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -149,15 +149,15 @@ fn convert_top_level( convert_statement(stmt, constraint_db, renames, bindings), )), - syntax::TopLevel::Structure(_loc, name, fields) => TopLevelItem::Type( - name.intern(), - ir::TypeOrVar::Structure( - fields - .into_iter() - .map(|(name, t)| (name.intern(), convert_type(t, constraint_db))) - .collect(), - ), - ), + syntax::TopLevel::Structure(_loc, name, fields) => { + let mut updated_fields = ir::Fields::default(); + + for (name, field_type) in fields.into_iter() { + updated_fields.insert(name.intern(), convert_type(field_type, constraint_db)); + } + + TopLevelItem::Type(name.intern(), ir::TypeOrVar::Structure(updated_fields)) + } } } @@ -294,7 +294,7 @@ fn convert_expression( syntax::Expression::Constructor(loc, name, fields) => { let mut result_fields = HashMap::new(); - let mut type_fields = HashMap::new(); + let mut type_fields = ir::Fields::default(); let mut prereqs = vec![]; let result_type = ir::TypeOrVar::new(); @@ -479,18 +479,18 @@ fn convert_type(ty: syntax::Type, constraint_db: &mut Vec) -> ir::Ty } Ok(v) => ir::TypeOrVar::Primitive(v), }, - syntax::Type::Struct(fields) => ir::TypeOrVar::Structure( - fields - .into_iter() - .map(|(n, t)| { - ( - n.intern(), - t.map(|x| convert_type(x, constraint_db)) - .unwrap_or_else(ir::TypeOrVar::new), - ) - }) - .collect(), - ), + syntax::Type::Struct(fields) => { + let mut new_fields = ir::Fields::default(); + + for (name, field_type) in fields.into_iter() { + let new_field_type = field_type + .map(|x| convert_type(x, constraint_db)) + .unwrap_or_else(ir::TypeOrVar::new); + new_fields.insert(name.intern(), new_field_type); + } + + ir::TypeOrVar::Structure(new_fields) + } } } diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index 1990332..03edec2 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -134,12 +134,9 @@ fn finalize_type(ty: TypeOrVar, resolutions: &TypeResolutions) -> Type { .collect(), Box::new(finalize_type(*ret, resolutions)), ), - TypeOrVar::Structure(fields) => Type::Structure( - fields - .into_iter() - .map(|(name, subtype)| (name, finalize_type(subtype, resolutions))) - .collect(), - ), + TypeOrVar::Structure(fields) => { + Type::Structure(fields.map(|subtype| finalize_type(subtype, resolutions))) + } } } diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index e2c715e..e0b69a7 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ -446,7 +446,7 @@ pub fn solve_constraints( TypeOrVar::Structure(mut fields), field, result_type, - ) => match fields.remove(&field) { + ) => match fields.remove_field(&field) { None => { let reconstituted = TypeOrVar::Structure(fields); tracing::trace!(structure_type = %reconstituted, %field, "no field found in type"); @@ -772,12 +772,12 @@ pub fn solve_constraints( TypeOrVar::Structure(fields1), TypeOrVar::Structure(mut fields2), ) => { - if fields1.len() == fields2.len() - && fields1.keys().all(|x| fields2.contains_key(x)) + if fields1.count() == fields2.count() + && fields1.field_names().all(|x| fields2.has_field(x)) { for (name, subtype1) in fields1.into_iter() { let subtype2 = fields2 - .remove(&name) + .remove_field(&name) .expect("can find matching field after equivalence check"); new_constraints.push(Constraint::Equivalent( loc.clone(),