DRY out validator, simplify code

This commit is contained in:
Scott Richmond 2024-12-15 23:49:43 -05:00
parent 6c78cffe56
commit 9c3205d4c1

View File

@ -1,5 +1,5 @@
use crate::parser::*;
use crate::spans::Span;
use crate::spans::{Span, Spanned};
use crate::value::Value;
use std::collections::{HashMap, HashSet};
@ -143,6 +143,13 @@ impl<'a> Validator<'a> {
}
}
fn visit(&mut self, node: &'a Spanned<Ast>) {
let (expr, span) = node;
self.ast = expr;
self.span = *span;
self.validate();
}
pub fn validate(&mut self) {
use Ast::*;
let root = self.ast;
@ -179,18 +186,13 @@ impl<'a> Validator<'a> {
}
let to = self.locals.len();
let tailpos = self.status.tail_position;
for (expr, span) in block.iter().take(block.len() - 1) {
for line in block.iter().take(block.len() - 1) {
self.status.tail_position = false;
self.ast = expr;
self.span = *span;
self.validate();
self.visit(line);
}
let (expr, span) = block.last().unwrap();
self.ast = expr;
self.span = *span;
self.status.tail_position = tailpos;
self.validate();
self.visit(block.last().unwrap());
let block_bindings = self.locals.split_off(to);
@ -207,22 +209,12 @@ impl<'a> Validator<'a> {
let tailpos = self.status.tail_position;
self.status.tail_position = false;
let (expr, span) = cond.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(cond.as_ref());
// pass through tailpos only to then/else
self.status.tail_position = tailpos;
let (expr, span) = then.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
let (expr, span) = r#else.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(then.as_ref());
self.visit(r#else.as_ref());
}
Tuple(members) => {
if members.is_empty() {
@ -230,10 +222,8 @@ impl<'a> Validator<'a> {
}
let tailpos = self.status.tail_position;
self.status.tail_position = false;
for (expr, span) in members {
self.ast = expr;
self.span = *span;
self.validate();
for member in members {
self.visit(member);
}
self.status.tail_position = tailpos;
}
@ -244,10 +234,8 @@ impl<'a> Validator<'a> {
}
let tailpos = self.status.tail_position;
self.status.tail_position = false;
for (expr, span) in args {
self.ast = expr;
self.span = *span;
self.validate();
for arg in args {
self.visit(arg);
}
self.status.has_placeholder = false;
self.status.tail_position = tailpos;
@ -267,30 +255,21 @@ impl<'a> Validator<'a> {
}
let tailpos = self.status.tail_position;
self.status.tail_position = false;
for (expr, span) in list {
self.ast = expr;
self.span = *span;
self.validate();
for member in list {
self.visit(member);
}
self.status.tail_position = tailpos;
}
Pair(_, value) => {
let (expr, span) = value.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
}
Pair(_, value) => self.visit(value.as_ref()),
Dict(dict) => {
if dict.is_empty() {
return;
}
let tailpos = self.status.tail_position;
self.status.tail_position = false;
for (expr, span) in dict {
self.ast = expr;
self.span = *span;
self.validate();
for pair in dict {
self.visit(pair)
}
self.status.tail_position = tailpos;
}
@ -299,31 +278,16 @@ impl<'a> Validator<'a> {
// check arity against fn info if first term is word and second term is args
Synthetic(first, second, rest) => {
match (&first.0, &second.0) {
(Ast::Word(_), Ast::Keyword(_)) => {
let (expr, span) = first.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
}
(Ast::Word(_), Ast::Keyword(_)) => self.visit(first.as_ref()),
(Ast::Keyword(_), Ast::Arguments(args)) => {
if args.len() != 1 {
self.err("called keywords may only take one argument".to_string())
}
let (expr, span) = second.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(second.as_ref());
}
(Ast::Word(name), Ast::Arguments(args)) => {
let (expr, span) = first.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
let (expr, span) = second.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(first.as_ref());
self.visit(second.as_ref());
//TODO: check arities of prelude fns, too
let fn_binding = self.bound(name);
@ -337,32 +301,20 @@ impl<'a> Validator<'a> {
_ => unreachable!(),
}
for term in rest {
let (expr, span) = term;
self.ast = expr;
self.span = *span;
self.validate();
self.visit(term);
}
}
WhenClause(cond, body) => {
let tailpos = self.status.tail_position;
self.status.tail_position = false;
let (expr, span) = cond.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(cond.as_ref());
//pass through tail position for when bodies
self.status.tail_position = tailpos;
let (expr, span) = body.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(body.as_ref());
}
When(clauses) => {
for clause in clauses {
let (expr, span) = clause;
self.ast = expr;
self.span = *span;
self.validate();
self.visit(clause);
}
}
@ -374,54 +326,30 @@ impl<'a> Validator<'a> {
} else {
self.bind(name.to_string());
}
let (expr, span) = boxed.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(boxed.as_ref());
}
Let(lhs, rhs) => {
let (expr, span) = rhs.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
let (expr, span) = lhs.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(rhs.as_ref());
self.visit(lhs.as_ref());
}
MatchClause(pattern, guard, body) => {
let to = self.locals.len();
let (patt, span) = pattern.as_ref();
self.ast = patt;
self.span = *span;
self.validate();
self.visit(pattern.as_ref());
if let Some((expr, span)) = guard.as_ref() {
self.ast = expr;
self.span = *span;
self.validate();
if let Some(guard) = guard.as_ref() {
self.visit(guard);
}
let (expr, span) = body.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(body.as_ref());
self.locals.truncate(to);
}
Match(scrutinee, clauses) => {
let (expr, span) = scrutinee.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(scrutinee.as_ref());
for clause in clauses {
let (expr, span) = clause;
self.ast = expr;
self.span = *span;
self.validate();
self.visit(clause);
}
}
FnDeclaration(name) => {
@ -488,10 +416,7 @@ impl<'a> Validator<'a> {
Panic(msg) => {
let tailpos = self.status.tail_position;
self.status.tail_position = false;
let (expr, span) = msg.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(msg.as_ref());
self.status.tail_position = tailpos;
}
// TODO: fix the tail call here?
@ -500,39 +425,23 @@ impl<'a> Validator<'a> {
return self.err("do expressions must have at least two terms".to_string());
}
for term in terms.iter().take(terms.len() - 1) {
let (expr, span) = term;
self.ast = expr;
self.span = *span;
self.validate();
self.visit(term);
}
let (expr, span) = terms.last().unwrap();
self.ast = expr;
self.span = *span;
if matches!(expr, Ast::Recur(_)) {
let last = terms.last().unwrap();
self.visit(last);
if matches!(last.0, Ast::Recur(_)) {
self.err("`recur` may not be used in `do` forms".to_string());
}
self.validate();
}
Repeat(times, body) => {
self.status.tail_position = false;
let (expr, span) = times.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
let (expr, span) = body.as_ref();
self.ast = expr;
self.span = *span;
self.validate();
self.visit(times.as_ref());
self.visit(body.as_ref());
}
Loop(with, body) => {
let (expr, span) = with.as_ref();
self.span = *span;
self.ast = expr;
self.validate();
self.visit(with.as_ref());
let Ast::Tuple(input) = expr else {
let Ast::Tuple(input) = &with.0 else {
unreachable!()
};
@ -588,10 +497,7 @@ impl<'a> Validator<'a> {
self.status.tail_position = false;
for arg in args {
let (expr, span) = arg;
self.ast = expr;
self.span = *span;
self.validate();
self.visit(arg);
}
}
WordPattern(name) => match self.bound(name) {
@ -654,25 +560,15 @@ impl<'a> Validator<'a> {
return;
}
for term in terms.iter().take(terms.len() - 1) {
let (patt, span) = term;
self.ast = patt;
self.span = *span;
self.validate();
self.visit(term);
}
self.status.last_term = true;
let (patt, span) = terms.last().unwrap();
self.ast = patt;
self.span = *span;
self.validate();
let last = terms.last().unwrap();
self.visit(last);
self.status.last_term = false;
}
PairPattern(_, patt) => {
let (patt, span) = patt.as_ref();
self.ast = patt;
self.span = *span;
self.validate();
}
PairPattern(_, patt) => self.visit(patt.as_ref()),
// terminals can never be invalid
Nil | Boolean(_) | Number(_) | Keyword(_) | String(_) => (),
// terminal patterns can never be invalid