Skip to content

Commit

Permalink
Refactor expandarray to use chain guards
Browse files Browse the repository at this point in the history
  • Loading branch information
maximecb committed Aug 2, 2023
1 parent eb3f093 commit a00de31
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 26 deletions.
11 changes: 11 additions & 0 deletions bootstraptest/test_yjit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2276,6 +2276,17 @@ def expandarray_rhs_too_small
expandarray_rhs_too_small
}

assert_equal '[nil, 2, nil]', %q{
def foo(arr)
a, b, c = arr
end
a, b, c1 = foo([0, 1])
a, b, c2 = foo([0, 1, 2])
a, b, c3 = foo([0, 1])
[c1, c2, c3]
}

assert_equal '[1, [2]]', %q{
def expandarray_splat
a, *b = [1, 2]
Expand Down
91 changes: 66 additions & 25 deletions yjit/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1467,7 +1467,7 @@ fn guard_object_is_not_ruby2_keyword_hash(
fn gen_expandarray(
jit: &mut JITState,
asm: &mut Assembler,
_ocb: &mut OutlinedCb,
ocb: &mut OutlinedCb,
) -> Option<CodegenStatus> {
// Both arguments are rb_num_t which is unsigned
let num = jit.get_arg(0).as_usize();
Expand Down Expand Up @@ -1500,48 +1500,89 @@ fn gen_expandarray(
return Some(KeepCompiling);
}

// Defer compilation so we can specialize on a runtime `self`
if !jit.at_current_insn() {
defer_compilation(jit, asm, ocb);
return Some(EndBlock);
}

let comptime_recv = jit.peek_at_stack(&asm.ctx, 0);

// If the comptime receiver is not an array, bail
if comptime_recv.class_of() != unsafe { rb_cArray } {
return None;
}

// Get the compile-time array length
let comptime_len = unsafe { rb_yjit_array_len(comptime_recv) as u32 };
//println!("comptime_len={}", comptime_len);

// Move the array from the stack and check that it's an array.
guard_object_is_array(
asm,
array_opnd,
array_opnd.into(),
Counter::expandarray_not_array,
);
let array_opnd = asm.stack_pop(1); // pop after using the type info

// If we don't actually want any values, then just return.
if num == 0 {
asm.stack_pop(1); // pop the array
return Some(KeepCompiling);
}

let array_opnd = asm.stack_opnd(0);
let array_reg = asm.load(array_opnd);
let array_len_opnd = get_array_len(asm, array_reg);

// Load the address of the embedded array into REG1.
// (struct RArray *)(obj)->as.ary
let array_reg = asm.load(array_opnd);
let ary_opnd = asm.lea(Opnd::mem(VALUE_BITS, array_reg, RUBY_OFFSET_RARRAY_AS_ARY));

// Conditionally load the address of the heap array into REG1.
// (struct RArray *)(obj)->as.heap.ptr
let flags_opnd = Opnd::mem(VALUE_BITS, array_reg, RUBY_OFFSET_RBASIC_FLAGS);
asm.test(flags_opnd, Opnd::UImm(RARRAY_EMBED_FLAG as u64));
let heap_ptr_opnd = Opnd::mem(
usize::BITS as u8,
asm.load(array_opnd),
RUBY_OFFSET_RARRAY_AS_HEAP_PTR,
// Guard on the comptime/expected array length
asm.comment(&format!("guard array length == {}", comptime_len));
asm.cmp(array_len_opnd, comptime_len.into());
jit_chain_guard(
JCC_JNE,
jit,
asm,
ocb,
OPT_AREF_MAX_CHAIN_DEPTH,
Counter::expandarray_chain_max_depth,
);
let ary_opnd = asm.csel_nz(ary_opnd, heap_ptr_opnd);

// Loop backward through the array and push each element onto the stack.
for i in (0..num).rev() {
let top = asm.stack_push(Type::Unknown);
let offset = i32::try_from(i * SIZEOF_VALUE).unwrap();
let array_opnd = asm.stack_pop(1); // pop after using the type info

//println!("generating loads");

// If the array has length 0, then we don't even need the array pointer
if comptime_len == 0 {
// Loop backward through the array and push Qnils onto the stack.
for _ in 0..num {
let top = asm.stack_push(Type::Nil);
asm.mov(top, Qnil.into());
}
} else {
let array_reg = asm.load(array_opnd);
let ary_opnd = asm.lea(Opnd::mem(VALUE_BITS, array_reg, RUBY_OFFSET_RARRAY_AS_ARY));

// If the element index is less than the length of the array, load it
asm.cmp(array_len_opnd, i.into());
let elem_opnd = asm.csel_g(Opnd::mem(64, ary_opnd, offset), Qnil.into());
asm.mov(top, elem_opnd);
// Conditionally load the address of the heap array
// (struct RArray *)(obj)->as.heap.ptr
let flags_opnd = Opnd::mem(VALUE_BITS, array_reg, RUBY_OFFSET_RBASIC_FLAGS);
asm.test(flags_opnd, Opnd::UImm(RARRAY_EMBED_FLAG as u64));
let heap_ptr_opnd = Opnd::mem(
usize::BITS as u8,
asm.load(array_opnd),
RUBY_OFFSET_RARRAY_AS_HEAP_PTR,
);
let ary_opnd = asm.csel_nz(ary_opnd, heap_ptr_opnd);

// Loop backward through the array and push each element onto the stack.
for i in (0..num).rev() {
let top = asm.stack_push(Type::Unknown);
let offset = i32::try_from(i * SIZEOF_VALUE).unwrap();

// Missing elements are Qnil
asm.comment(&format!("load array[{}]", i));
let elem_opnd = if (i as u32) < comptime_len { Opnd::mem(64, ary_opnd, offset) } else { Qnil.into() };
asm.mov(top, elem_opnd);
}
}

Some(KeepCompiling)
Expand Down Expand Up @@ -7393,7 +7434,7 @@ fn gen_leave(
ocb: &mut OutlinedCb,
) -> Option<CodegenStatus> {
// Only the return value should be on the stack
assert_eq!(1, asm.ctx.get_stack_size());
assert_eq!(1, asm.ctx.get_stack_size(), "leave instruction expects stack size 1");

let ocb_asm = Assembler::new();

Expand Down
2 changes: 1 addition & 1 deletion yjit/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ make_counters! {
expandarray_splat,
expandarray_postarg,
expandarray_not_array,
expandarray_rhs_too_small,
expandarray_chain_max_depth,

// getblockparam
gbp_wb_required,
Expand Down

0 comments on commit a00de31

Please sign in to comment.