From c31be288ad4ab8cafb06d139f1a9d0725fcd7f01 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Sun, 28 Sep 2025 11:42:11 -0700 Subject: [PATCH] Calls and infix expressions. --- src/syntax.rs | 9 ++ src/syntax/parse.rs | 140 ++++++++++++++++++++- src/syntax/parser_tests.rs | 252 +++++++++++++++++++++++++++++++++++++ 3 files changed, 400 insertions(+), 1 deletion(-) diff --git a/src/syntax.rs b/src/syntax.rs index fac81e6..dfcf197 100644 --- a/src/syntax.rs +++ b/src/syntax.rs @@ -126,6 +126,15 @@ pub enum Expression { Reference(Name), EnumerationValue(Name, Name, Option>), StructureValue(Name, Vec), + Call(Box, CallKind, Vec), +} + +#[derive(Debug)] +pub enum CallKind { + Infix, + Normal, + Postfix, + Prefix, } #[derive(Debug)] diff --git a/src/syntax/parse.rs b/src/syntax/parse.rs index 0ee635c..7161071 100644 --- a/src/syntax/parse.rs +++ b/src/syntax/parse.rs @@ -1,19 +1,58 @@ use crate::syntax::error::ParserError; use crate::syntax::tokens::{Lexer, LocatedToken, Token}; use crate::syntax::*; +use std::collections::HashMap; pub struct Parser<'a> { file_id: usize, lexer: Lexer<'a>, known_tokens: Vec, + precedence_table: HashMap, +} + +pub enum Associativity { + Left, + Right, + None, } impl<'a> Parser<'a> { + /// Create a new parser from the given file index and lexer. + /// + /// The file index will be used for annotating locations and for + /// error messages. If you don't care about either, you can use + /// 0 with no loss of functionality. (Obviously, it will be harder + /// to create quality error messages, but you already knew that.) pub fn new(file_id: usize, lexer: Lexer<'a>) -> Parser<'a> { Parser { file_id, lexer, known_tokens: vec![], + precedence_table: HashMap::new(), + } + } + + /// Add the given operator to our precedence table, at the given + /// precedence level and associativity. + pub fn add_infix_precedence( + &mut self, + operator: S, + associativity: Associativity, + level: u8 + ) { + let actual_associativity = match associativity { + Associativity::Left => (level * 2, (level * 2) + 1), + Associativity::Right => ((level * 2) + 1, level * 2), + Associativity::None => (level * 2, level * 2), + }; + + self.precedence_table.insert(operator.to_string(), actual_associativity); + } + + fn get_precedence(&self, name: &String) -> (u8, u8) { + match self.precedence_table.get(name) { + None => (19, 20), + Some(x) => *x, } } @@ -549,7 +588,106 @@ impl<'a> Parser<'a> { } pub fn parse_expression(&mut self) -> Result { - self.parse_base_expression() + let next = self.next()?.ok_or_else(|| + self.bad_eof("looking for an expression"))?; + + self.save(next.clone()); + match next.token { + Token::ValueName(x) if x == "match" => self.parse_match_expression(), + Token::ValueName(x) if x == "if" => self.parse_if_expression(), + _ => self.parse_infix(0), + } + } + + pub fn parse_match_expression(&mut self) -> Result { + unimplemented!() + } + + pub fn parse_if_expression(&mut self) -> Result { + unimplemented!() + } + + pub fn parse_infix(&mut self, level: u8) -> Result { + let mut lhs = self.parse_base_expression()?; + + loop { + let Some(next) = self.next()? else { + return Ok(lhs); + }; + + match next.token { + Token::OpenParen => { + self.save(next); + let args = self.parse_call_arguments()?; + lhs = Expression::Call(Box::new(lhs), CallKind::Normal, args); + } + Token::OperatorName(ref n) => { + let (left_pr, right_pr) = self.get_precedence(&n); + + if left_pr < level { + self.save(next); + break; + } + + let rhs = self.parse_infix(right_pr)?; + let name = Name::new(self.to_location(next.span), n); + let opref = Box::new(Expression::Reference(name)); + let args = vec![lhs, rhs]; + + lhs = Expression::Call(opref, CallKind::Infix, args); + } + _ => { + self.save(next); + return Ok(lhs); + } + } + } + + Ok(lhs) + } + + fn parse_call_arguments(&mut self) -> Result, ParserError> { + let next = self.next()?.ok_or_else(|| self.bad_eof( + "looking for open paren for function arguments"))?; + + if !matches!(next.token, Token::OpenParen) { + return Err(ParserError::UnexpectedToken { + file_id: self.file_id, + span: next.span, + token: next.token, + expected: "open paren for call arguments", + }); + } + + let mut args = vec![]; + + loop { + let next = self.next()?.ok_or_else(|| self.bad_eof( + "looking for an expression or close paren in function arguments"))?; + + if matches!(next.token, Token::CloseParen) { + break; + } + + self.save(next); + let argument = self.parse_infix(0)?; + args.push(argument); + + let next = self.next()?.ok_or_else(|| self.bad_eof( + "looking for comma or close paren in function arguments"))?; + match next.token { + Token::Comma => continue, + Token::CloseParen => break, + _ => return Err(ParserError::UnexpectedToken { + file_id: self.file_id, + span: next.span, + token: next.token, + expected: "comma or close paren in function arguments", + }), + } + } + + Ok(args) } pub fn parse_base_expression(&mut self) -> Result { diff --git a/src/syntax/parser_tests.rs b/src/syntax/parser_tests.rs index bb3ce5a..8be3eff 100644 --- a/src/syntax/parser_tests.rs +++ b/src/syntax/parser_tests.rs @@ -517,3 +517,255 @@ fn structure_value() { parse_st("Foo{ foo: 1,, bar: \"foo\", }"), Err(_))); } + +#[test] +fn infix_and_precedence() { + let parse_ex = |str| { + let lexer = Lexer::from(str); + let mut result = Parser::new(0, lexer); + result.add_infix_precedence("+", parse::Associativity::Left, 6); + result.add_infix_precedence("*", parse::Associativity::Right, 7); + result.parse_expression() + }; + + + assert!(matches!( + parse_ex("0"), + Ok(Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value, .. }))) + if value == 0)); + assert!(matches!( + parse_ex("(0)"), + Ok(Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value, .. }))) + if value == 0)); + assert!(matches!( + parse_ex("((0))"), + Ok(Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value, .. }))) + if value == 0)); + assert!(matches!( + parse_ex("1 + 2"), + Ok(Expression::Call(plus, CallKind::Infix, args)) + if matches!(plus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v1, .. })), + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v2, .. })) + ] if *v1 == 1 && *v2 == 2))); + assert!(matches!( + parse_ex("1 + 2 + 3"), + Ok(Expression::Call(plus, CallKind::Infix, args)) + if matches!(plus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(args.as_slice(), [ + Expression::Call(innerplus, CallKind::Infix, inner_args), + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v3, .. })) + ] if *v3 == 3 && + matches!(innerplus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(inner_args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v1, .. })), + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v2, .. })) + ] if *v1 == 1 && *v2 == 2)))); + assert!(matches!( + parse_ex("1 * 2 * 3"), + Ok(Expression::Call(times, CallKind::Infix, args)) + if matches!(times.as_ref(), Expression::Reference(n) if n.as_printed() == "*") && + matches!(args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v1, .. })), + Expression::Call(innertimes, CallKind::Infix, inner_args), + ] if *v1 == 1 && + matches!(innertimes.as_ref(), Expression::Reference(n) if n.as_printed() == "*") && + matches!(inner_args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v2, .. })), + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v3, .. })) + ] if *v2 == 2 && *v3 == 3)))); + + assert!(matches!( + parse_ex("1 + 2 * 3 + 4"), + Ok(Expression::Call(plus_right, CallKind::Infix, outer_args)) if + matches!(plus_right.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(outer_args.as_slice(), [ + Expression::Call(plus_left, CallKind::Infix, left_args), + Expression::Value(ConstantValue::Integer(_, v4)) + ] if + matches!(v4, IntegerWithBase{ value: 4, .. }) && + matches!(plus_left.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(left_args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, v1)), + Expression::Call(times, CallKind::Infix, times_args) + ] if + matches!(v1, IntegerWithBase{ value: 1, .. }) && + matches!(times.as_ref(), Expression::Reference(n) if n.as_printed() == "*") && + matches!(times_args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, v2)), + Expression::Value(ConstantValue::Integer(_, v3)) + ] if + matches!(v2, IntegerWithBase{ value: 2, .. }) && + matches!(v3, IntegerWithBase{ value: 3, .. })))))); + + assert!(matches!( + parse_ex("1 * 2 + 3 * 4"), + Ok(Expression::Call(plus, CallKind::Infix, outer_args)) if + matches!(plus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(outer_args.as_slice(), [ + Expression::Call(left_times, CallKind::Infix, left_args), + Expression::Call(right_times, CallKind::Infix, right_args) + ] if + matches!(left_times.as_ref(), Expression::Reference(n) if n.as_printed() == "*") && + matches!(right_times.as_ref(), Expression::Reference(n) if n.as_printed() == "*") && + matches!(left_args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, v1)), + Expression::Value(ConstantValue::Integer(_, v2)), + ] if + matches!(v1, IntegerWithBase { value: 1, .. }) && + matches!(v2, IntegerWithBase { value: 2, .. })) && + matches!(right_args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, v3)), + Expression::Value(ConstantValue::Integer(_, v4)), + ] if + matches!(v3, IntegerWithBase { value: 3, .. }) && + matches!(v4, IntegerWithBase { value: 4, .. }))))); +} + +#[test] +fn calls() { + let parse_ex = |str| { + let lexer = Lexer::from(str); + let mut result = Parser::new(0, lexer); + result.add_infix_precedence("+", parse::Associativity::Left, 6); + result.add_infix_precedence("*", parse::Associativity::Right, 7); + result.parse_expression() + }; + + assert!(matches!( + parse_ex("f()"), + Ok(Expression::Call(f, CallKind::Normal, args)) if + matches!(f.as_ref(), Expression::Reference(n) if n.as_printed() == "f") && + args.is_empty())); + assert!(matches!( + parse_ex("f(a)"), + Ok(Expression::Call(f, CallKind::Normal, args)) if + matches!(f.as_ref(), Expression::Reference(n) if n.as_printed() == "f") && + matches!(args.as_slice(), [Expression::Reference(n)] if n.as_printed() == "a"))); + assert!(matches!( + parse_ex("f(a,b)"), + Ok(Expression::Call(f, CallKind::Normal, args)) if + matches!(f.as_ref(), Expression::Reference(n) if n.as_printed() == "f") && + matches!(args.as_slice(), [ + Expression::Reference(a), + Expression::Reference(b), + ] if a.as_printed() == "a" && b.as_printed() == "b"))); + assert!(matches!( + parse_ex("f(a,b,)"), + Ok(Expression::Call(f, CallKind::Normal, args)) if + matches!(f.as_ref(), Expression::Reference(n) if n.as_printed() == "f") && + matches!(args.as_slice(), [ + Expression::Reference(a), + Expression::Reference(b), + ] if a.as_printed() == "a" && b.as_printed() == "b"))); + assert!(matches!( + parse_ex("f(,a,b,)"), + Err(_))); + assert!(matches!( + parse_ex("f(a,,b,)"), + Err(_))); + assert!(matches!( + parse_ex("f(a,b,,)"), + Err(_))); + + assert!(matches!( + parse_ex("f()()"), + Ok(Expression::Call(f, CallKind::Normal, args)) if + matches!(f.as_ref(), Expression::Call(inner, CallKind::Normal, inner_args) if + matches!(inner.as_ref(), Expression::Reference(n) if n.as_printed() == "f") && + inner_args.is_empty()) && + args.is_empty())); + + assert!(matches!( + parse_ex("f() + 1"), + Ok(Expression::Call(plus, CallKind::Infix, args)) if + matches!(plus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(args.as_slice(), [ + Expression::Call(subcall, CallKind::Normal, subargs), + Expression::Value(ConstantValue::Integer(_, v1)) + ] if + matches!(v1, IntegerWithBase{ value: 1, .. }) && + matches!(subcall.as_ref(), Expression::Reference(n) if n.as_printed() == "f") && + subargs.is_empty()))); + + assert!(matches!( + parse_ex("f(a + b, c*d)"), + Ok(Expression::Call(eff, CallKind::Normal, args)) if + matches!(eff.as_ref(), Expression::Reference(n) if n.as_printed() == "f") && + matches!(args.as_slice(), [ + Expression::Call(plus, CallKind::Infix, pargs), + Expression::Call(times, CallKind::Infix, targs), + ] if + matches!(plus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(times.as_ref(), Expression::Reference(n) if n.as_printed() == "*") && + matches!(pargs.as_slice(), [ Expression::Reference(a), Expression::Reference(b) ] if + a.as_printed() == "a" && b.as_printed() == "b") && + matches!(targs.as_slice(), [ Expression::Reference(c), Expression::Reference(d) ] if + c.as_printed() == "c" && d.as_printed() == "d")))); + + assert!(matches!( + parse_ex("f(a + b, c*d,)"), + Ok(Expression::Call(eff, CallKind::Normal, args)) if + matches!(eff.as_ref(), Expression::Reference(n) if n.as_printed() == "f") && + matches!(args.as_slice(), [ + Expression::Call(plus, CallKind::Infix, pargs), + Expression::Call(times, CallKind::Infix, targs), + ] if + matches!(plus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(times.as_ref(), Expression::Reference(n) if n.as_printed() == "*") && + matches!(pargs.as_slice(), [ Expression::Reference(a), Expression::Reference(b) ] if + a.as_printed() == "a" && b.as_printed() == "b") && + matches!(targs.as_slice(), [ Expression::Reference(c), Expression::Reference(d) ] if + c.as_printed() == "c" && d.as_printed() == "d")))); + + assert!(matches!( + parse_ex("3 + f(1 + 2)"), + Ok(Expression::Call(plus, CallKind::Infix, args)) if + matches!(plus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, v3)), + Expression::Call(eff, CallKind::Normal, fargs) + ] if + matches!(v3, IntegerWithBase{ value: 3, .. }) && + matches!(eff.as_ref(), Expression::Reference(n) if n.as_printed() == "f") && + matches!(fargs.as_slice(), [Expression::Call(p, CallKind::Infix, pargs)] if + matches!(p.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(pargs.as_slice(), [Expression::Value(v1), Expression::Value(v2)] if + matches!(v1, ConstantValue::Integer(_, IntegerWithBase { value: 1, .. })) && + matches!(v2, ConstantValue::Integer(_, IntegerWithBase { value: 2, .. }))))))); + + assert!(matches!( + parse_ex("(f . g)(1 + 2)"), + Ok(Expression::Call(fg, CallKind::Normal, args)) if + matches!(fg.as_ref(), Expression::Call(dot, CallKind::Infix, fgargs) if + matches!(dot.as_ref(), Expression::Reference(n) if n.as_printed() == ".") && + matches!(fgargs.as_slice(), [Expression::Reference(f), Expression::Reference(g)] if + f.as_printed() == "f" && g.as_printed() == "g")) && + matches!(args.as_slice(), [Expression::Call(plus, CallKind::Infix, pargs)] if + matches!(plus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(pargs.as_slice(), [Expression::Value(v1), Expression::Value(v2)] if + matches!(v1, ConstantValue::Integer(_, IntegerWithBase{ value: 1, .. })) && + matches!(v2, ConstantValue::Integer(_, IntegerWithBase{ value: 2, .. })))))); + + assert!(matches!( + parse_ex("a + b(2 + 3) * c"), + Ok(Expression::Call(plus, CallKind::Infix, pargs)) if + matches!(plus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(pargs.as_slice(), [ + Expression::Reference(a), + Expression::Call(times, CallKind::Infix, targs) + ] if a.as_printed() == "a" && + matches!(times.as_ref(), Expression::Reference(n) if n.as_printed() == "*") && + matches!(targs.as_slice(), [ + Expression::Call(b, CallKind::Normal, bargs), + Expression::Reference(c), + ] if c.as_printed() == "c" && + matches!(b.as_ref(), Expression::Reference(n) if n.as_printed() == "b") && + matches!(bargs.as_slice(), [Expression::Call(plus, CallKind::Infix, pargs)] if + matches!(plus.as_ref(), Expression::Reference(n) if n.as_printed() == "+") && + matches!(pargs.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: 2, .. })), + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: 3, .. })) + ])))))); +}