rought draft of splatted fn args

This commit is contained in:
Scott Richmond 2025-06-19 18:26:44 -04:00
parent 442532ecd3
commit bf1e7e4072
4 changed files with 91 additions and 63 deletions

View File

@ -876,7 +876,7 @@ impl<'a> Compiler<'a> {
pairs_len -= 1;
}
let mut match_depth = self.match_depth;
let match_depth = self.match_depth;
self.match_depth = 0;
for pair in pairs.iter().take(pairs_len) {
let (PairPattern(key, pattern), _) = pair else {
@ -922,9 +922,7 @@ impl<'a> Compiler<'a> {
self.patch_jump(before_load_dict_idx, self.len() - before_load_dict_idx - 3);
self.patch_jump(jump_idx, self.len() - jump_idx - 3);
}
Splattern(patt) => {
self.visit(patt);
}
Splattern(patt) => self.visit(patt),
InterpolatedPattern(parts, _) => {
println!("An interpolated pattern of {} parts", parts.len());
let mut pattern = "".to_string();
@ -1181,12 +1179,8 @@ impl<'a> Compiler<'a> {
for idx in jump_idxes {
self.patch_jump(idx, self.len() - idx - 3);
// self.chunk.bytecode[idx] = (self.len() - idx) as u8 - 1;
}
self.pop_n(self.stack_depth - stack_depth);
// while self.stack_depth > stack_depth {
// self.pop();
// }
self.emit_op(Op::Load);
self.stack_depth += 1;
}
@ -1198,10 +1192,12 @@ impl<'a> Compiler<'a> {
unreachable!()
};
let mut compilers: HashMap<usize, Compiler> = HashMap::new();
let mut compilers: HashMap<u8, Compiler> = HashMap::new();
let mut upvalues = vec![];
let mut has_splat = false;
for clause in fn_body {
let MatchClause(pattern, guard, clause_body) = &clause.0 else {
unreachable!()
@ -1210,7 +1206,11 @@ impl<'a> Compiler<'a> {
unreachable!()
};
let arity = pattern.len();
if matches!(pattern.last(), Some((Splattern(_), _))) {
has_splat = true;
};
let arity = pattern.len() as u8;
let compiler = match compilers.get_mut(&arity) {
Some(compiler) => compiler,
@ -1223,9 +1223,9 @@ impl<'a> Compiler<'a> {
}
};
compiler.stack_depth += arity;
compiler.stack_depth += arity as usize;
compiler.scope_depth += 1;
compiler.match_depth = arity;
compiler.match_depth = arity as usize;
std::mem::swap(&mut upvalues, &mut compiler.upvalues);
@ -1244,9 +1244,12 @@ impl<'a> Compiler<'a> {
for idx in tup_jump_idxes {
compiler.patch_jump(idx, compiler.len() - idx - 3);
}
for _ in 0..arity {
compiler.emit_op(Op::Pop);
}
// compiler.pop_n(arity as usize);
compiler.emit_op(Op::PopN);
compiler.emit_byte(arity as usize);
// for _ in 0..arity {
// compiler.emit_op(Op::Pop);
// }
compiler.patch_jump(jump_idx, compiler.len() - jump_idx - 3);
let mut no_match_jumps = vec![];
no_match_jumps.push(compiler.stub_jump(Op::JumpIfNoMatch));
@ -1278,22 +1281,36 @@ impl<'a> Compiler<'a> {
std::mem::swap(&mut compiler.upvalues, &mut upvalues);
}
let mut the_chunks = vec![];
let mut compilers = compilers.into_iter().collect::<Vec<_>>();
compilers.sort_by(|(a, _), (b, _)| a.cmp(b));
for (arity, mut compiler) in compilers.into_iter() {
let mut arities = vec![];
let mut chunks = vec![];
for (arity, mut compiler) in compilers {
compiler.emit_op(Op::PanicNoMatch);
let chunk = compiler.chunk;
if crate::DEBUG_COMPILE {
println!("=== function chuncktion: {name}/{arity} ===");
chunk.dissasemble();
}
the_chunks.push((arity as u8, chunk));
arities.push(arity);
chunks.push(chunk);
}
let splat = if has_splat {
arities.iter().fold(0, |max, curr| max.max(*curr))
} else {
0
};
let lfn = crate::value::LFn::Defined {
name,
doc: *doc,
chunks: the_chunks,
arities,
chunks,
splat,
closed: RefCell::new(vec![]),
};
@ -1370,30 +1387,7 @@ impl<'a> Compiler<'a> {
self.emit_byte(self.match_depth);
self.visit(member);
jnm_idxes.push(self.stub_jump(Op::JumpIfNoMatch));
// self.emit_op(Op::JumpIfNoMatch);
// tup_jump_idxes.push(self.len());
// self.emit_byte(0xff);
}
// let jump_idx = self.stub_jump(Op::Jump);
// self.emit_op(Op::Jump);
// let jump_idx = self.len();
// self.emit_byte(0xff);
// for idx in jnm_idxes {
// self.patch_jump(idx, self.len() - idx - 3);
// // self.chunk.bytecode[idx] = (self.len() - idx) as u8 - 2;
// }
// self.emit_op(Op::PopN);
// self.emit_byte(arity);
// for _ in 0..arity {
// self.emit_op(Op::Pop);
// }
// self.patch_jump(jump_idx, self.len() - jump_idx - 3);
// self.chunk.bytecode[jump_idx] = (self.len() - jump_idx) as u8 - 1;
// let mut no_match_jumps = vec![];
// no_match_jumps.push(self.stub_jump(Op::JumpIfNoMatch));
// self.emit_op(Op::JumpIfNoMatch);
// let jnm_idx = self.len();
// self.emit_byte(0xff);
if guard.is_some() {
let guard_expr: &'static Spanned<Ast> =
Box::leak(Box::new(guard.clone().unwrap()));
@ -1414,26 +1408,17 @@ impl<'a> Compiler<'a> {
while self.stack_depth > stack_depth {
self.pop();
}
// self.stack_depth -= arity;
jump_idxes.push(self.stub_jump(Op::Jump));
// self.emit_op(Op::Jump);
// jump_idxes.push(self.len());
// self.emit_byte(0xff);
for idx in jnm_idxes {
self.patch_jump(idx, self.len() - idx - 3);
}
// self.chunk.bytecode[jnm_idx] = (self.len() - jnm_idx) as u8;
self.scope_depth -= 1;
}
self.emit_op(Op::PanicNoMatch);
for idx in jump_idxes {
self.patch_jump(idx, self.len() - idx - 3);
// self.chunk.bytecode[idx] = (self.len() - idx) as u8 - 1;
}
// self.emit_op(Op::PopN);
// self.emit_byte(arity);
self.stack_depth -= arity;
// println!("Op::Load at end of loop at byte {}", self.len());
self.emit_op(Op::Load);
self.stack_depth += 1;
self.leave_loop();
@ -1448,8 +1433,6 @@ impl<'a> Compiler<'a> {
self.emit_byte(self.loop_root());
self.emit_op(Op::Load);
self.jump(Op::JumpBack, self.len() - self.loop_idx());
// self.emit_op(Op::JumpBack);
// self.emit_byte(self.len() - self.loop_idx());
}
Panic(msg) => {
self.visit(msg);

View File

@ -75,8 +75,15 @@ pub fn run(src: &'static str) {
pub fn main() {
env::set_var("RUST_BACKTRACE", "1");
let src = "
let (#{a, ...}, d, (e)) = (#{:a 1, :b 2}, :foo, (:bar))
(a, d, e)
fn arity {
() -> 0
(_) -> 1
(_, _) -> 2
(_, _, _) -> 3
(_, _, _, ...args) -> (:a_lot, args)
}
arity (1, 2, 3, 4, 5, 6)
";
run(src);
}

View File

@ -14,7 +14,9 @@ pub enum LFn {
Defined {
name: &'static str,
doc: Option<&'static str>,
chunks: Vec<(u8, Chunk)>,
arities: Vec<u8>,
chunks: Vec<Chunk>,
splat: u8,
closed: RefCell<Vec<Value>>,
},
}
@ -47,7 +49,23 @@ impl LFn {
pub fn accepts(&self, arity: u8) -> bool {
match self {
LFn::Defined { chunks, .. } => chunks.iter().any(|(a, _)| *a == arity),
LFn::Defined { arities, splat, .. } => {
if arities.contains(&arity) {
return true;
}
if *splat == 0 {
return false;
}
let max_arity = arities.iter().fold(0, |a, b| a.max(*b));
arity > max_arity
}
LFn::Declared { .. } => unreachable!(),
}
}
pub fn splat_arity(&self) -> u8 {
match self {
LFn::Defined { splat, .. } => *splat,
LFn::Declared { .. } => unreachable!(),
}
}
@ -61,7 +79,18 @@ impl LFn {
pub fn chunk(&self, arity: u8) -> &Chunk {
match self {
LFn::Declared { .. } => unreachable!(),
LFn::Defined { chunks, .. } => &chunks.iter().find(|(a, _)| *a == arity).unwrap().1,
LFn::Defined {
arities,
splat,
chunks,
..
} => {
let chunk_pos = arities.iter().position(|a| arity == *a);
match chunk_pos {
Some(pos) => &chunks[pos as usize],
None => &chunks[*splat as usize],
}
}
}
}

View File

@ -2,11 +2,9 @@ use crate::base::BaseFn;
use crate::compiler::{Chunk, Op};
use crate::parser::Ast;
use crate::spans::Spanned;
use crate::value::{LFn, Partial, Value};
use chumsky::prelude::SimpleSpan;
use crate::value::{LFn, Value};
use imbl::{HashMap, Vector};
use num_traits::FromPrimitive;
use std::cell::OnceCell;
use std::cell::RefCell;
use std::fmt;
use std::mem::swap;
@ -97,7 +95,9 @@ impl Vm {
let lfn = LFn::Defined {
name: "user script",
doc: None,
chunks: vec![(0, chunk)],
chunks: vec![chunk],
arities: vec![0],
splat: 0,
closed: RefCell::new(vec![]),
};
let base_fn = Value::Fn(Rc::new(lfn));
@ -261,7 +261,7 @@ impl Vm {
let cond = self.pop();
match cond {
Value::Number(x) if x <= 0.0 => {
self.ip += jump_len as usize + 3;
self.ip += jump_len + 3;
}
Value::Number(..) => {
self.ip += 3;
@ -868,6 +868,15 @@ impl Vm {
val.show()
));
}
let splat_arity = val.as_fn().splat_arity();
if splat_arity > 0 && arity >= splat_arity {
let splatted_args = self.stack.split_off(
self.stack.len() - (arity - splat_arity) as usize - 1,
);
let gathered_args = Vector::from(splatted_args);
self.push(Value::List(Box::new(gathered_args)));
}
let arity = splat_arity.min(arity);
let mut frame = CallFrame {
function: val,
arity,