From afff04259c09db4ba5bad5a92e68fc0f0aea9714 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Mon, 13 May 2024 20:47:11 -0700 Subject: [PATCH] Free variable analysis. --- src/lambda_lift.rs | 9 +- src/lambda_lift/free_variables.rs | 161 ++++++++++++++++++++++++++++++ src/syntax.rs | 19 +++- src/syntax/parser.lalrpop | 4 +- src/type_infer/convert.rs | 8 +- 5 files changed, 186 insertions(+), 15 deletions(-) create mode 100644 src/lambda_lift/free_variables.rs diff --git a/src/lambda_lift.rs b/src/lambda_lift.rs index 8880273..339b1c0 100644 --- a/src/lambda_lift.rs +++ b/src/lambda_lift.rs @@ -1,8 +1 @@ -use crate::syntax::{Expression, Name}; -use std::collections::{HashSet, HashMap}; - -impl Expression { - fn free_variables(&self) -> HashSet { - unimplemented!() - } -} \ No newline at end of file +mod free_variables; diff --git a/src/lambda_lift/free_variables.rs b/src/lambda_lift/free_variables.rs new file mode 100644 index 0000000..2f8f868 --- /dev/null +++ b/src/lambda_lift/free_variables.rs @@ -0,0 +1,161 @@ +use crate::syntax::{Expression, Name}; +use std::collections::HashSet; + +impl Expression { + /// Find the set of free variables used within this expression. + /// + /// Obviously, if this expression contains a function definition, argument + /// variables in the body will not be reported as free. + pub fn free_variables(&self) -> HashSet { + match self { + Expression::Value(_, _) => HashSet::new(), + Expression::Constructor(_, _, args) => + args.iter().fold(HashSet::new(), |mut existing, (_, expr)| { + existing.extend(expr.free_variables()); + existing + }), + Expression::Reference(n) => HashSet::from([n.clone()]), + Expression::FieldRef(_, expr, _) => expr.free_variables(), + Expression::Cast(_, _, expr) => expr.free_variables(), + Expression::Primitive(_, _) => HashSet::new(), + Expression::Call(_, f, args) => + args.iter() + .fold(f.free_variables(), + |mut existing, expr| { + existing.extend(expr.free_variables()); + existing + }), + Expression::Block(_, exprs) => { + let mut free_vars = HashSet::new(); + let mut bound_vars = HashSet::new(); + + for expr in exprs.iter() { + for var in expr.free_variables().into_iter() { + if !bound_vars.contains(&var) { + free_vars.insert(var); + } + } + + bound_vars.extend(expr.new_bindings()); + } + + free_vars + } + Expression::Binding(_, _, expr) => expr.free_variables(), + Expression::Function(_, name, args, _, body) => { + let mut candidates = body.free_variables(); + if let Some(name) = name { + candidates.remove(name); + } + for (name, _) in args.iter() { + candidates.remove(name); + } + candidates + } + } + } + + /// Find the set of new bindings in the provided expression. + /// + /// New bindings are those that introduce a variable that can be + /// referenced in subsequent statements / expressions within a + /// parent construct. This eventually means something in the next + /// block, but can involve some odd effects in the language. + pub fn new_bindings(&self) -> HashSet { + match self { + Expression::Value(_, _) => HashSet::new(), + Expression::Constructor(_, _, args) => + args.iter().fold(HashSet::new(), |mut existing, (_, expr)| { + existing.extend(expr.new_bindings()); + existing + }), + Expression::Reference(_) => HashSet::new(), + Expression::FieldRef(_, expr, _) => expr.new_bindings(), + Expression::Cast(_, _, expr) => expr.new_bindings(), + Expression::Primitive(_, _) => HashSet::new(), + Expression::Call(_, f, args) => + args.iter() + .fold(f.new_bindings(), + |mut existing, expr| { + existing.extend(expr.new_bindings()); + existing + }), + Expression::Block(_, _) => HashSet::new(), + Expression::Binding(_, name, expr) => { + let mut others = expr.new_bindings(); + others.insert(name.clone()); + others + } + Expression::Function(_, Some(name), _, _, _) => HashSet::from([name.clone()]), + Expression::Function(_, None, _, _, _) => HashSet::new(), + } + } +} + +#[test] +fn basic_frees_works() { + let test = Expression::parse(0, "1u64").unwrap(); + assert_eq!(0, test.free_variables().len()); + + let test = Expression::parse(0, "1u64 + 2").unwrap(); + assert_eq!(0, test.free_variables().len()); + + let test = Expression::parse(0, "x").unwrap(); + assert_eq!(1, test.free_variables().len()); + assert!(test.free_variables().contains(&Name::manufactured("x"))); + + let test = Expression::parse(0, "1 + x").unwrap(); + assert_eq!(1, test.free_variables().len()); + assert!(test.free_variables().contains(&Name::manufactured("x"))); + + let test = Expression::parse(0, "Structure{ field1: x; field2: y; }").unwrap(); + assert_eq!(2, test.free_variables().len()); + assert!(test.free_variables().contains(&Name::manufactured("x"))); + assert!(test.free_variables().contains(&Name::manufactured("y"))); + + let test = Expression::parse(0, "{ print x; print y }").unwrap(); + assert_eq!(2, test.free_variables().len()); + assert!(test.free_variables().contains(&Name::manufactured("x"))); + assert!(test.free_variables().contains(&Name::manufactured("y"))); +} + +#[test] +fn test_around_function() { + let nada = Expression::parse(0, "function(x) x").unwrap(); + assert_eq!(0, nada.free_variables().len()); + + let lift = Expression::parse(0, "function(x) x + y").unwrap(); + assert_eq!(1, lift.free_variables().len()); + assert!(lift.free_variables().contains(&Name::manufactured("y"))); + + let nest = Expression::parse(0, "function(y) function(x) x + y").unwrap(); + assert_eq!(0, nest.free_variables().len()); + + let multi = Expression::parse(0, "function(x, y) x + y + z").unwrap(); + assert_eq!(1, multi.free_variables().len()); + assert!(multi.free_variables().contains(&Name::manufactured("z"))); +} + +#[test] +fn test_is_set() { + let multi = Expression::parse(0, "function(x, y) x + y + z + z").unwrap(); + assert_eq!(1, multi.free_variables().len()); + assert!(multi.free_variables().contains(&Name::manufactured("z"))); +} + +#[test] +fn bindings_remove() { + let x_bind = Expression::parse(0, "{ x = 4; print x }").unwrap(); + assert_eq!(0, x_bind.free_variables().len()); + + let inner = Expression::parse(0, "{ { x = 4; print x }; print y }").unwrap(); + assert_eq!(1, inner.free_variables().len()); + assert!(inner.free_variables().contains(&Name::manufactured("y"))); + + let inner = Expression::parse(0, "{ { x = 4; print x }; print x }").unwrap(); + assert_eq!(1, inner.free_variables().len()); + assert!(inner.free_variables().contains(&Name::manufactured("x"))); + + let double = Expression::parse(0, "{ x = y = 1; x + y }").unwrap(); + assert_eq!(0, double.free_variables().len()); +} diff --git a/src/syntax.rs b/src/syntax.rs index 1e279d5..7e54c31 100644 --- a/src/syntax.rs +++ b/src/syntax.rs @@ -46,7 +46,7 @@ use crate::syntax::arbitrary::GenerationEnvironment; pub use crate::syntax::ast::*; pub use crate::syntax::location::Location; pub use crate::syntax::name::Name; -pub use crate::syntax::parser::{ProgramParser, TopLevelParser}; +pub use crate::syntax::parser::{ProgramParser, TopLevelParser, ExpressionParser}; pub use crate::syntax::tokens::{LexerError, Token}; use lalrpop_util::ParseError; #[cfg(test)] @@ -257,6 +257,23 @@ impl TopLevel { } } +impl Expression { + /// Parse an expression from a string, using the given index for [`Location`]s. + /// + /// As with [`Program::parse`], if you use a bad file index, you'll get weird behaviors + /// when you try to print errors, but things should otherwise work fine. This function + /// will only parse a single expression, which is useful for testing, but probably shouldn't + /// be used when reading in whole files. + pub fn parse(file_idx: usize, buffer: &str) -> Result { + let lexer = Token::lexer(buffer) + .spanned() + .map(|x| permute_lexer_result(file_idx, x)); + ExpressionParser::new() + .parse(file_idx, lexer) + .map_err(|e| ParserError::convert(file_idx, e)) + } +} + fn permute_lexer_result( file_idx: usize, result: (Result, Range), diff --git a/src/syntax/parser.lalrpop b/src/syntax/parser.lalrpop index 6c4d53d..e2777b5 100644 --- a/src/syntax/parser.lalrpop +++ b/src/syntax/parser.lalrpop @@ -150,7 +150,7 @@ TypeName: Name = { // to run through a few examples. Consider thinking about how you want to // parse something like "1 + 2 * 3", for example, versus "1 + 2 + 3" or // "1 * 2 + 3", and hopefully that'll help. -Expression: Expression = { +pub Expression: Expression = { BindingExpression, } @@ -271,4 +271,4 @@ Comma: Vec = { v } } -}; \ No newline at end of file +}; diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index 225f9db..2eac333 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -261,7 +261,7 @@ impl InferenceEngine { } syntax::Expression::Primitive(loc, name) => { - let primop = ir::Primitive::from_str(&name.current_name()).expect("valid primitive"); + let primop = ir::Primitive::from_str(name.current_name()).expect("valid primitive"); match primop { ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => { @@ -409,8 +409,8 @@ impl InferenceEngine { // First, at some point we're going to want to know a location for this function, // which should either be the name if we have one, or the body if we don't. let function_location = match name { - None => expr.location().clone(), - Some(ref name) => loc, + None => loc, + Some(ref name) => name.location().clone(), }; // Next, let us figure out what we're going to name this function. If the user // didn't provide one, we'll just call it "function:" for them. (We'll @@ -532,7 +532,7 @@ impl InferenceEngine { .variable_types .contains_key(name.current_interned()) { - let new_name = ir::gensym(&name.original_name()); + let new_name = ir::gensym(name.original_name()); renames.insert(name.current_interned().clone(), new_name.clone()); new_name } else {