fix upvalue resolution, start work on TCO

This commit is contained in:
Scott Richmond 2025-06-19 20:26:16 -04:00
parent bf1e7e4072
commit 1af75bc516
5 changed files with 62 additions and 28 deletions

View File

@ -242,7 +242,7 @@ println!("{a} // {b_high}/{b_low} // {c}");
To reiterate the punch list that *I would have needed for Computer Class 1*: To reiterate the punch list that *I would have needed for Computer Class 1*:
* [x] jump instructions need 16 bits of operand * [x] jump instructions need 16 bits of operand
- Whew, that took longer than I expected - Whew, that took longer than I expected
* [ ] splatterns * [x] splatterns
- [ ] validator should ensure splatterns are the longest patterns in a form - [ ] validator should ensure splatterns are the longest patterns in a form
* [ ] improve validator * [ ] improve validator
- [ ] Tuples may not be longer than n members - [ ] Tuples may not be longer than n members

View File

@ -76,6 +76,7 @@ pub enum Op {
Stringify, Stringify,
Call, Call,
TailCall,
Return, Return,
Partial, Partial,
@ -218,6 +219,7 @@ impl std::fmt::Display for Op {
Call => "call", Call => "call",
Return => "return", Return => "return",
Partial => "partial", Partial => "partial",
TailCall => "tail_call",
SetUpvalue => "set_upvalue", SetUpvalue => "set_upvalue",
GetUpvalue => "get_upvalue", GetUpvalue => "get_upvalue",
@ -275,7 +277,8 @@ impl Chunk {
PushBinding | MatchTuple | MatchSplattedTuple | LoadSplattedTuple | MatchList PushBinding | MatchTuple | MatchSplattedTuple | LoadSplattedTuple | MatchList
| MatchSplattedList | LoadSplattedList | MatchDict | MatchSplattedDict | MatchSplattedList | LoadSplattedList | MatchDict | MatchSplattedDict
| DropDictEntry | LoadDictValue | PushTuple | PushBox | MatchDepth | PopN | StoreAt | DropDictEntry | LoadDictValue | PushTuple | PushBox | MatchDepth | PopN | StoreAt
| Call | SetUpvalue | GetUpvalue | Partial | MatchString | PushStringMatches => { | Call | SetUpvalue | GetUpvalue | Partial | MatchString | PushStringMatches
| TailCall => {
let next = self.bytecode[*i + 1]; let next = self.bytecode[*i + 1];
println!("{i:04}: {:16} {next:03}", op.to_string()); println!("{i:04}: {:16} {next:03}", op.to_string());
*i += 1; *i += 1;
@ -357,6 +360,7 @@ pub struct Compiler<'a> {
pub enclosing: Option<&'a Compiler<'a>>, pub enclosing: Option<&'a Compiler<'a>>,
pub upvalues: Vec<Upvalue>, pub upvalues: Vec<Upvalue>,
loop_info: Vec<LoopInfo>, loop_info: Vec<LoopInfo>,
tail_pos: bool,
} }
fn is_binding(expr: &Spanned<Ast>) -> bool { fn is_binding(expr: &Spanned<Ast>) -> bool {
@ -401,6 +405,7 @@ impl<'a> Compiler<'a> {
upvalues: vec![], upvalues: vec![],
src, src,
name, name,
tail_pos: false,
} }
} }
@ -533,7 +538,7 @@ impl<'a> Compiler<'a> {
} }
None => match self.resolve_upvalue(name) { None => match self.resolve_upvalue(name) {
Some(position) => { Some(position) => {
println!("resolved upvalue: {name}"); println!("resolved upvalue: {name} at {position}");
self.emit_op(Op::GetUpvalue); self.emit_op(Op::GetUpvalue);
self.emit_byte(position); self.emit_byte(position);
self.stack_depth += 1; self.stack_depth += 1;
@ -542,7 +547,7 @@ impl<'a> Compiler<'a> {
println!("setting upvalue: {name}"); println!("setting upvalue: {name}");
let upvalue = self.get_upvalue(name); let upvalue = self.get_upvalue(name);
self.emit_op(Op::GetUpvalue); self.emit_op(Op::GetUpvalue);
self.emit_byte(upvalue.stack_pos); self.emit_byte(self.upvalues.len());
self.upvalues.push(upvalue); self.upvalues.push(upvalue);
dbg!(&self.upvalues); dbg!(&self.upvalues);
self.stack_depth += 1; self.stack_depth += 1;
@ -551,16 +556,6 @@ impl<'a> Compiler<'a> {
} }
} }
// fn resolve_binding(&self, name: &'static str) -> usize {
// if let Some(pos) = self.resolve_local(name) {
// return pos;
// }
// match self.enclosing {
// Some(compiler) => compiler.resolve_binding(name),
// None => unreachable!(),
// }
// }
fn pop(&mut self) { fn pop(&mut self) {
self.emit_op(Op::Pop); self.emit_op(Op::Pop);
self.stack_depth -= 1; self.stack_depth -= 1;
@ -617,6 +612,8 @@ impl<'a> Compiler<'a> {
} }
Keyword(s) => self.emit_constant(Value::Keyword(s)), Keyword(s) => self.emit_constant(Value::Keyword(s)),
Block(lines) => { Block(lines) => {
let tail_pos = self.tail_pos;
self.tail_pos = false;
// increase the scope // increase the scope
self.scope_depth += 1; self.scope_depth += 1;
// stash the stack depth // stash the stack depth
@ -650,6 +647,7 @@ impl<'a> Compiler<'a> {
} }
// otherwise, just evaluate it and leave the value on the stack // otherwise, just evaluate it and leave the value on the stack
_ => { _ => {
self.tail_pos = tail_pos;
self.visit(last_expr); self.visit(last_expr);
} }
} }
@ -675,9 +673,12 @@ impl<'a> Compiler<'a> {
self.emit_op(Op::Load); self.emit_op(Op::Load);
} }
If(cond, then, r#else) => { If(cond, then, r#else) => {
let tail_pos = self.tail_pos;
self.tail_pos = false;
self.visit(cond); self.visit(cond);
let jif_idx = self.stub_jump(Op::JumpIfFalse); let jif_idx = self.stub_jump(Op::JumpIfFalse);
self.stack_depth -= 1; self.stack_depth -= 1;
self.tail_pos = tail_pos;
self.visit(then); self.visit(then);
let jump_idx = self.stub_jump(Op::Jump); let jump_idx = self.stub_jump(Op::Jump);
self.visit(r#else); self.visit(r#else);
@ -1015,7 +1016,10 @@ impl<'a> Compiler<'a> {
self.emit_constant(Value::Keyword(key)); self.emit_constant(Value::Keyword(key));
self.visit(value); self.visit(value);
} }
// TODO: thread tail position through this
Synthetic(first, second, rest) => { Synthetic(first, second, rest) => {
let tail_pos = self.tail_pos;
self.tail_pos = false;
match (&first.0, &second.0) { match (&first.0, &second.0) {
(Word(_), Keyword(_)) => { (Word(_), Keyword(_)) => {
self.visit(first); self.visit(first);
@ -1090,7 +1094,15 @@ impl<'a> Compiler<'a> {
self.visit(arg); self.visit(arg);
} }
self.resolve_binding(fn_name); self.resolve_binding(fn_name);
self.emit_op(Op::Call); // if we're in tail position AND there aren't any rest args, this should be a tail call (I think)
if rest.is_empty() {
self.tail_pos = tail_pos;
}
if self.tail_pos {
self.emit_op(Op::TailCall);
} else {
self.emit_op(Op::Call);
}
self.emit_byte(arity); self.emit_byte(arity);
self.stack_depth -= arity; self.stack_depth -= arity;
} }
@ -1099,6 +1111,7 @@ impl<'a> Compiler<'a> {
} }
_ => unreachable!(), _ => unreachable!(),
} }
// the last term in rest should be in tail position if we are in tail position
for (term, _) in rest { for (term, _) in rest {
match term { match term {
Keyword(str) => { Keyword(str) => {
@ -1123,12 +1136,15 @@ impl<'a> Compiler<'a> {
} }
} }
When(clauses) => { When(clauses) => {
let tail_pos = self.tail_pos;
let mut jump_idxes = vec![]; let mut jump_idxes = vec![];
let mut clauses = clauses.iter(); let mut clauses = clauses.iter();
while let Some((WhenClause(cond, body), _)) = clauses.next() { while let Some((WhenClause(cond, body), _)) = clauses.next() {
self.tail_pos = false;
self.visit(cond.as_ref()); self.visit(cond.as_ref());
let jif_jump_idx = self.stub_jump(Op::JumpIfFalse); let jif_jump_idx = self.stub_jump(Op::JumpIfFalse);
self.stack_depth -= 1; self.stack_depth -= 1;
self.tail_pos = tail_pos;
self.visit(body); self.visit(body);
self.stack_depth -= 1; self.stack_depth -= 1;
jump_idxes.push(self.stub_jump(Op::Jump)); jump_idxes.push(self.stub_jump(Op::Jump));
@ -1142,11 +1158,14 @@ impl<'a> Compiler<'a> {
} }
WhenClause(..) => unreachable!(), WhenClause(..) => unreachable!(),
Match(scrutinee, clauses) => { Match(scrutinee, clauses) => {
let tail_pos = self.tail_pos;
self.tail_pos = false;
self.visit(scrutinee.as_ref()); self.visit(scrutinee.as_ref());
let stack_depth = self.stack_depth; let stack_depth = self.stack_depth;
let mut jump_idxes = vec![]; let mut jump_idxes = vec![];
let mut clauses = clauses.iter(); let mut clauses = clauses.iter();
while let Some((MatchClause(pattern, guard, body), _)) = clauses.next() { while let Some((MatchClause(pattern, guard, body), _)) = clauses.next() {
self.tail_pos = false;
let mut no_match_jumps = vec![]; let mut no_match_jumps = vec![];
self.scope_depth += 1; self.scope_depth += 1;
self.match_depth = 0; self.match_depth = 0;
@ -1159,6 +1178,7 @@ impl<'a> Compiler<'a> {
no_match_jumps.push(self.stub_jump(Op::JumpIfFalse)); no_match_jumps.push(self.stub_jump(Op::JumpIfFalse));
self.stack_depth -= 1; self.stack_depth -= 1;
} }
self.tail_pos = tail_pos;
self.visit(body); self.visit(body);
self.emit_op(Op::Store); self.emit_op(Op::Store);
self.scope_depth -= 1; self.scope_depth -= 1;
@ -1222,6 +1242,7 @@ impl<'a> Compiler<'a> {
compilers.get_mut(&arity).unwrap() compilers.get_mut(&arity).unwrap()
} }
}; };
compiler.tail_pos = false;
compiler.stack_depth += arity as usize; compiler.stack_depth += arity as usize;
compiler.scope_depth += 1; compiler.scope_depth += 1;
@ -1244,12 +1265,8 @@ impl<'a> Compiler<'a> {
for idx in tup_jump_idxes { for idx in tup_jump_idxes {
compiler.patch_jump(idx, compiler.len() - idx - 3); compiler.patch_jump(idx, compiler.len() - idx - 3);
} }
// compiler.pop_n(arity as usize);
compiler.emit_op(Op::PopN); compiler.emit_op(Op::PopN);
compiler.emit_byte(arity as usize); compiler.emit_byte(arity as usize);
// for _ in 0..arity {
// compiler.emit_op(Op::Pop);
// }
compiler.patch_jump(jump_idx, compiler.len() - jump_idx - 3); compiler.patch_jump(jump_idx, compiler.len() - jump_idx - 3);
let mut no_match_jumps = vec![]; let mut no_match_jumps = vec![];
no_match_jumps.push(compiler.stub_jump(Op::JumpIfNoMatch)); no_match_jumps.push(compiler.stub_jump(Op::JumpIfNoMatch));
@ -1260,6 +1277,7 @@ impl<'a> Compiler<'a> {
no_match_jumps.push(compiler.stub_jump(Op::JumpIfFalse)); no_match_jumps.push(compiler.stub_jump(Op::JumpIfFalse));
compiler.stack_depth -= 1; compiler.stack_depth -= 1;
} }
compiler.tail_pos = true;
compiler.visit(clause_body); compiler.visit(clause_body);
compiler.emit_op(Op::Store); compiler.emit_op(Op::Store);
compiler.scope_depth -= 1; compiler.scope_depth -= 1;
@ -1354,6 +1372,8 @@ impl<'a> Compiler<'a> {
self.emit_constant(Value::Nil); self.emit_constant(Value::Nil);
} }
Loop(value, clauses) => { Loop(value, clauses) => {
let tail_pos = self.tail_pos;
self.tail_pos = false;
//algo: //algo:
//first, put the values on the stack //first, put the values on the stack
let (Ast::Tuple(members), _) = value.as_ref() else { let (Ast::Tuple(members), _) = value.as_ref() else {
@ -1374,6 +1394,7 @@ impl<'a> Compiler<'a> {
let mut clauses = clauses.iter(); let mut clauses = clauses.iter();
let mut jump_idxes = vec![]; let mut jump_idxes = vec![];
while let Some((Ast::MatchClause(pattern, guard, body), _)) = clauses.next() { while let Some((Ast::MatchClause(pattern, guard, body), _)) = clauses.next() {
self.tail_pos = false;
self.emit_op(Op::ResetMatch); self.emit_op(Op::ResetMatch);
self.scope_depth += 1; self.scope_depth += 1;
let (Ast::TuplePattern(members), _) = pattern.as_ref() else { let (Ast::TuplePattern(members), _) = pattern.as_ref() else {
@ -1395,6 +1416,7 @@ impl<'a> Compiler<'a> {
jnm_idxes.push(self.stub_jump(Op::JumpIfFalse)); jnm_idxes.push(self.stub_jump(Op::JumpIfFalse));
self.stack_depth -= 1; self.stack_depth -= 1;
} }
self.tail_pos = tail_pos;
self.visit(body); self.visit(body);
self.emit_op(Op::Store); self.emit_op(Op::Store);
self.scope_depth -= 1; self.scope_depth -= 1;
@ -1462,7 +1484,12 @@ impl<'a> Compiler<'a> {
Do(terms) => { Do(terms) => {
let mut terms = terms.iter(); let mut terms = terms.iter();
let first = terms.next().unwrap(); let first = terms.next().unwrap();
let mut terms = terms.rev();
let last = terms.next().unwrap();
let terms = terms.rev();
// put the first value on the stack // put the first value on the stack
let tail_pos = self.tail_pos;
self.tail_pos = false;
self.visit(first); self.visit(first);
for term in terms { for term in terms {
self.visit(term); self.visit(term);
@ -1470,6 +1497,11 @@ impl<'a> Compiler<'a> {
self.emit_byte(1); self.emit_byte(1);
self.stack_depth -= 1; self.stack_depth -= 1;
} }
self.tail_pos = tail_pos;
self.visit(last);
self.emit_op(Op::Call);
self.emit_byte(1);
self.stack_depth -= 1;
} }
Placeholder => { Placeholder => {
self.emit_op(Op::Nothing); self.emit_op(Op::Nothing);

View File

@ -75,15 +75,14 @@ pub fn run(src: &'static str) {
pub fn main() { pub fn main() {
env::set_var("RUST_BACKTRACE", "1"); env::set_var("RUST_BACKTRACE", "1");
let src = " let src = "
fn arity { fn nope () -> nil
() -> 0 fn foo () -> {
(_) -> 1 nope ()
(_, _) -> 2 :foo
(_, _, _) -> 3
(_, _, _, ...args) -> (:a_lot, args)
} }
fn bar () -> foo ()
arity (1, 2, 3, 4, 5, 6) bar ()
"; ";
run(src); run(src);
} }

View File

@ -26,6 +26,7 @@ impl LFn {
match self { match self {
LFn::Declared { .. } => unreachable!(), LFn::Declared { .. } => unreachable!(),
LFn::Defined { closed, .. } => { LFn::Defined { closed, .. } => {
println!("closing over in {}: {value}", self.name());
closed.borrow_mut().push(value); closed.borrow_mut().push(value);
} }
} }
@ -87,7 +88,7 @@ impl LFn {
} => { } => {
let chunk_pos = arities.iter().position(|a| arity == *a); let chunk_pos = arities.iter().position(|a| arity == *a);
match chunk_pos { match chunk_pos {
Some(pos) => &chunks[pos as usize], Some(pos) => &chunks[pos],
None => &chunks[*splat as usize], None => &chunks[*splat as usize],
} }
} }

View File

@ -854,7 +854,7 @@ impl Vm {
}; };
self.push(Value::Partial(Rc::new(partial))); self.push(Value::Partial(Rc::new(partial)));
} }
Call => { Call | TailCall => {
let arity = self.chunk().bytecode[self.ip + 1]; let arity = self.chunk().bytecode[self.ip + 1];
self.ip += 2; self.ip += 2;
@ -961,6 +961,8 @@ impl Vm {
GetUpvalue => { GetUpvalue => {
let idx = self.chunk().bytecode[self.ip + 1]; let idx = self.chunk().bytecode[self.ip + 1];
self.ip += 2; self.ip += 2;
println!("getting upvalue {idx}");
dbg!(&self.frame.function);
if let Value::Fn(ref inner) = self.frame.function { if let Value::Fn(ref inner) = self.frame.function {
self.push(inner.as_ref().upvalue(idx)); self.push(inner.as_ref().upvalue(idx));
} else { } else {