diff --git a/src/syntax/parse.rs b/src/syntax/parse.rs index d7cfc06..ae49b4a 100644 --- a/src/syntax/parse.rs +++ b/src/syntax/parse.rs @@ -251,13 +251,12 @@ impl<'lexer> Parser<'lexer> { let mut definitions = vec![]; loop { - let next_token = self.next()?; - - if next_token.is_none() { + if let Some(next_token) = self.next()? { + self.save(next_token); + definitions.push(self.parse_definition()?); + } else { return Ok(Module { definitions }); } - - definitions.push(self.parse_definition()?); } } @@ -696,6 +695,8 @@ impl<'lexer> Parser<'lexer> { 5 }; + let _ = self.require_token(Token::Arrow, "operator definition")?; + let function_name = self.parse_name("operator function definition")?; let end = self.require_token(Token::Semi, "end of operator definition")?; diff --git a/src/syntax/parser_tests.rs b/src/syntax/parser_tests.rs index 05c3e0a..2219d7e 100644 --- a/src/syntax/parser_tests.rs +++ b/src/syntax/parser_tests.rs @@ -254,7 +254,6 @@ fn structures() { assert!(parse_st("structure {").is_err()); assert!(parse_st("structure foo {}").is_err()); - println!("result: {:?}", parse_st("structure Foo {}")); assert!(matches!( parse_st("structure Foo {}"), Ok(StructureDef { name, fields, .. }) @@ -480,7 +479,6 @@ fn structure_value() { assert!(parse_st("Foo{ foo, }").is_err()); assert!(parse_st("Foo{ foo: , }").is_err()); assert!(parse_st("Foo{ , foo: 1, }").is_err()); - println!("result: {:?}", parse_st("Foo{ foo: 1 }")); assert!(matches!( parse_st("Foo{ foo: 1 }"), Ok(Expression::Structure(sv)) @@ -1084,3 +1082,104 @@ fn definitions() { return_type.is_some() && body.len() == 1))); } + +#[test] +fn operators() { + let parse = |str| { + let lexer = Lexer::from(str); + let mut result = Parser::new("test", lexer); + result.parse_module() + }; + + let all_the_operators = r#" +prefix operator - -> negate; +postfix operator ++ -> mutable_add; +infix left operator + -> sum; +infix right operator - -> subtract; +infix operator * at 8 -> multiply; +postfix operator ! 
at 3 -> factorial; +prefix operator $$ at 1 -> money; +"#; + + assert!(parse(all_the_operators).is_ok()); + + assert!(parse("left prefix operator - -> negate;").is_err()); + assert!(parse("right prefix operator - -> negate;").is_err()); + assert!(parse("right infix operator - -> negate;").is_err()); + assert!(parse("left infix operator - -> negate;").is_err()); + assert!(parse("infix operator at 8 - -> negate;").is_err()); + + + // these are designed to replicate the examples in the infix_and_precedence + // tests, but with the precedence set automatically by the parser. + let plus_and_times = |expr| format!(r#" +infix left operator + at 6 -> add; +infix right operator * at 7 -> mul; + +x = {expr}; +"#); + + let plus_example = plus_and_times("1 + 2 + 3"); + assert!(matches!( + parse(&plus_example), + Ok(Module { definitions }) if + matches!(definitions.last(), Some(Definition{ definition, .. }) if + matches!(definition, Def::Value(ValueDef{ value, .. }) if + matches!(value, Expression::Call(plus, CallKind::Infix, args) if + matches!(plus.as_ref(), Expression::Reference(_,n) if n.as_printed() == "+") && + matches!(args.as_slice(), [ + Expression::Call(innerplus, CallKind::Infix, inner_args), + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v3, .. })) + ] if *v3 == 3 && + matches!(innerplus.as_ref(), Expression::Reference(_,n) if n.as_printed() == "+") && + matches!(inner_args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v1, .. })), + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v2, .. })) + ] if *v1 == 1 && *v2 == 2))))))); + + let times_example = plus_and_times("1 * 2 * 3"); + assert!(matches!( + parse(&times_example), + Ok(Module { definitions }) if + matches!(definitions.last(), Some(Definition{ definition, .. }) if + matches!(definition, Def::Value(ValueDef{ value, .. 
}) if + matches!(value, Expression::Call(times, CallKind::Infix, args) if + matches!(times.as_ref(), Expression::Reference(_,n) if n.as_printed() == "*") && + matches!(args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v1, .. })), + Expression::Call(innertimes, CallKind::Infix, inner_args), + ] if *v1 == 1 && + matches!(innertimes.as_ref(), Expression::Reference(_,n) if n.as_printed() == "*") && + matches!(inner_args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v2, .. })), + Expression::Value(ConstantValue::Integer(_, IntegerWithBase{ value: v3, .. })) + ] if *v2 == 2 && *v3 == 3))))))); + + + let mixed_example = plus_and_times("1 + 2 * 3 + 4"); + assert!(matches!( + parse(&mixed_example), + Ok(Module { definitions }) if + matches!(definitions.last(), Some(Definition{ definition, .. }) if + matches!(definition, Def::Value(ValueDef{ value, .. }) if + matches!(value, Expression::Call(plus_right, CallKind::Infix, outer_args) if + matches!(plus_right.as_ref(), Expression::Reference(_,n) if n.as_printed() == "+") && + matches!(outer_args.as_slice(), [ + Expression::Call(plus_left, CallKind::Infix, left_args), + Expression::Value(ConstantValue::Integer(_, v4)) + ] if + matches!(v4, IntegerWithBase{ value: 4, .. }) && + matches!(plus_left.as_ref(), Expression::Reference(_,n) if n.as_printed() == "+") && + matches!(left_args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, v1)), + Expression::Call(times, CallKind::Infix, times_args) + ] if + matches!(v1, IntegerWithBase{ value: 1, .. }) && + matches!(times.as_ref(), Expression::Reference(_,n) if n.as_printed() == "*") && + matches!(times_args.as_slice(), [ + Expression::Value(ConstantValue::Integer(_, v2)), + Expression::Value(ConstantValue::Integer(_, v3)) + ] if + matches!(v2, IntegerWithBase{ value: 2, .. }) && + matches!(v3, IntegerWithBase{ value: 3, .. 
}))))))))); +} diff --git a/src/syntax/tokens.rs b/src/syntax/tokens.rs index 15c5e2d..2d245f7 100644 --- a/src/syntax/tokens.rs +++ b/src/syntax/tokens.rs @@ -411,12 +411,18 @@ impl<'a> LexerState<'a> { }); } - let mut value = 0; + let mut value: u32 = 0; while let Some((idx, char)) = self.next_char() { if let Some(digit) = char.to_digit(16) { - value = (value * 16) + digit; - continue; + if let Some(scaled) = value.checked_mul(16) { /* checked_mul, not checked_shl(4): a 4-bit shift is always in range for u32, so checked_shl never returns None and silently drops the high bits */ + value = scaled + digit; /* cannot overflow: scaled is a multiple of 16 and digit < 16 */ + continue; + } else { + return Err(LexerError::InvalidUnicode { + span: token_start_offset..idx, + }); + } } if char == '}' { @@ -730,3 +736,19 @@ fn arrow_requires_nonop() { let mut next_token = move || lexer.next().map(|x| x.expect("Can read valid token").token); assert_eq!(Some(Token::Arrow), next_token()); } + +#[test] +fn unicode() { + let mut lexer = Lexer::from("'\\u{00BE}'"); + let mut next_token = move || lexer.next().map(|x| x.expect("Can read valid token").token); + assert_eq!(Some(Token::Character('¾')), next_token()); + + let mut lexer = Lexer::from("'\\u{111111111111}'"); + assert!(lexer.next().unwrap().is_err()); + let mut lexer = Lexer::from("'\\u{100000041}'"); /* exceeds u32; its low 32 bits alone would be 0x41 ('A') */ + assert!(lexer.next().unwrap().is_err()); + let mut lexer = Lexer::from("'\\u{00BE'"); + assert!(lexer.next().unwrap().is_err()); + let mut lexer = Lexer::from("'\\u00BE}'"); + assert!(lexer.next().unwrap().is_err()); +}