Skip to content

Commit

Permalink
JIT: Add try/catch
Browse files Browse the repository at this point in the history
Summary:
Set up the JIT try/catch handling in a similar way to the native
backend.

If there are any catch handlers in the function, call `setjmp` at the
start with a stack-allocated `SHJmpBuf`. The catch table is checked
separately later.

Reviewed By: neildhar

Differential Revision: D64199003

fbshipit-source-id: c37f7af67da5a04873730be4caa354530f1d7354
  • Loading branch information
avp authored and facebook-github-bot committed Nov 12, 2024
1 parent 54b2b10 commit 5b029ce
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 37 deletions.
3 changes: 3 additions & 0 deletions include/hermes/BCGen/HBC/BytecodeFileFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,11 @@ struct ExceptionHandlerTableHeader {
/// We need HBCExceptionHandlerInfo other than using ExceptionHandlerInfo
/// directly because we don't need depth in HBC.
struct HBCExceptionHandlerInfo {
/// Start offset of the try, inclusive.
uint32_t start;
/// End offset of the try, exclusive.
uint32_t end;
/// Handler offset.
uint32_t target;
};

Expand Down
10 changes: 10 additions & 0 deletions lib/VM/JIT/DiscoverBB.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "hermes/Inst/InstDecode.h"
#include "hermes/VM/CodeBlock.h"
#include "hermes/VM/RuntimeModule.h"

#include "llvh/Support/Debug.h"

Expand Down Expand Up @@ -93,6 +94,15 @@ void discoverBasicBlocks(
// Add the end of the bytecode
addLabel(ip);

auto excTable =
codeBlock->getRuntimeModule()->getBytecode()->getExceptionTable(
codeBlock->getFunctionID());

// Add labels for the handler of each try block.
for (const auto &tryRegion : excTable) {
addLabel(begin + tryRegion.target);
}

// Sort all labels into a sequence of basic blocks.
basicBlocks.clear();
basicBlocks.reserve(labelSet.size());
Expand Down
1 change: 1 addition & 0 deletions lib/VM/JIT/RuntimeOffsets.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct RuntimeOffsets {
static constexpr uint32_t currentFrame = offsetof(Runtime, currentFrame_);
static constexpr uint32_t globalObject = offsetof(Runtime, global_);
static constexpr uint32_t thrownValue = offsetof(Runtime, thrownValue_);
static constexpr uint32_t shLocals = offsetof(Runtime, shLocals);
static constexpr uint32_t nativeStackHigh =
offsetof(Runtime, overflowGuard_) +
offsetof(StackOverflowGuard, nativeStackHigh);
Expand Down
16 changes: 14 additions & 2 deletions lib/VM/JIT/arm64/JIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,17 @@ JITCompiledFunctionPtr JITContext::Compiler::compileCodeBlockImpl() {
compileBB(bbIndex);
}

auto excTable =
codeBlock_->getRuntimeModule()->getBytecode()->getExceptionTable(
codeBlock_->getFunctionID());
llvh::SmallVector<const asmjit::Label *, 4> handlers{};
handlers.reserve(excTable.size());
for (const auto &entry : excTable) {
handlers.push_back(&bbLabels_.at(ofsToBBIndex_.at(entry.target)));
}

em_.leave();
codeBlock_->setJITCompiled(em_.addToRuntime(jc_.impl_->jr));
codeBlock_->setJITCompiled(em_.addToRuntime(jc_.impl_->jr, handlers));

LLVM_DEBUG(
llvh::outs() << "\n Bytecode:";
Expand Down Expand Up @@ -286,7 +295,6 @@ JITCompiledFunctionPtr JITContext::Compiler::compileCodeBlockImpl() {
}

EMIT_UNIMPLEMENTED(GetEnvironment)
EMIT_UNIMPLEMENTED(Catch)
EMIT_UNIMPLEMENTED(DirectEval)
EMIT_UNIMPLEMENTED(AsyncBreakCheck)

Expand Down Expand Up @@ -739,6 +747,10 @@ inline void JITContext::Compiler::emitRet(const inst::RetInst *inst) {
em_.ret(FR(inst->op1));
}

inline void JITContext::Compiler::emitCatch(const inst::CatchInst *inst) {
em_.catchInst(FR(inst->op1));
}

inline void JITContext::Compiler::emitGetGlobalObject(
const inst::GetGlobalObjectInst *inst) {
em_.getGlobalObject(FR(inst->op1));
Expand Down
174 changes: 140 additions & 34 deletions lib/VM/JIT/arm64/JitEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,12 @@ void Emitter::enter(uint32_t numCount, uint32_t npCount) {
frameRegs_[frIndex].globalType = FRType::UnknownNonPtr;
}

if (!codeBlock_->getRuntimeModule()
->getBytecode()
->getExceptionTable(codeBlock_->getFunctionID())
.empty())
catchTableLabel_ = a.newNamedLabel("CATCH_TABLE");

frameSetup(
frameRegs_.size(), nextGp - kGPSaved.first, nextVec - kVecSaved.first);
}
Expand All @@ -408,7 +414,10 @@ void Emitter::comment(const char *fmt, ...) {
a.comment(buf);
}

JITCompiledFunctionPtr Emitter::addToRuntime(asmjit::JitRuntime &jr) {
JITCompiledFunctionPtr Emitter::addToRuntime(
asmjit::JitRuntime &jr,
llvh::ArrayRef<const asmjit::Label *> exceptionHandlers) {
emitCatchTable(exceptionHandlers);
emitSlowPaths();
emitThunks();
emitROData();
Expand Down Expand Up @@ -504,25 +513,27 @@ void Emitter::frameSetup(
gpSaveCount_ = gpSaveCount;
vecSaveCount_ = vecSaveCount;

// +-------+----- old sp
// | x30 |
// +-------+
// | x29 |
// +-------+----- new x29
// | ... |
// +-------+
// | x21 |
// +-------+
// | x20 |
// +-------+
// | x19 |
// +-------+---- new sp
a.sub(
a64::sp,
a64::sp,
(((gpSaveCount + 1) & ~1) + ((vecSaveCount + 1) & ~1) + 2) * 8);

unsigned stackOfs = 0;
// Higher addresses are at the top.
// +-----------------------------+<---- old sp
// | x30 |
// +-----------------------------+
// | x29 |
// +-----------------------------+<---- new x29
// | ... |
// +-----------------------------+
// | x21 |
// +-----------------------------+
// | x20 |
// +-----------------------------+
// | x19 |
// +-----------------------------+
// | Saved SHLocals* (optional) |
// +-----------------------------+
// | SHJmpBuf (optional) |
// +-----------------------------+<--- new sp
a.sub(a64::sp, a64::sp, getStackSize());

unsigned stackOfs = getSavedRegsOffset();
for (unsigned i = 0; i < gpSaveCount; i += 2, stackOfs += 16) {
if (i + 1 < gpSaveCount)
a.stp(a64::GpX(19 + i), a64::GpX(20 + i), a64::Mem(a64::sp, stackOfs));
Expand All @@ -544,6 +555,15 @@ void Emitter::frameSetup(
comment("// xRuntime");
a.mov(xRuntime, a64::x0);

// Save the SHLocals pointer because we don't allocate and push a new
// SHLocals in the JIT.
// Used in CatchInst to restore state.
if (catchTableLabel_.isValid()) {
comment("// saved SHLocals *");
a.ldr(a64::x0, a64::Mem(xRuntime, RuntimeOffsets::shLocals));
a.str(a64::x0, a64::Mem(a64::sp, getSavedSHLocalsOffset()));
}

#ifndef HERMES_CHECK_NATIVE_STACK
#error Only native stack checking is supported in the JIT
#endif
Expand Down Expand Up @@ -718,6 +738,27 @@ void Emitter::frameSetup(
em, void (*)(SHRuntime *), _sh_throw_register_stack_overflow);
}});

if (catchTableLabel_.isValid()) {
comment("// _sh_try");
uint32_t jmpBufOffset = getJmpBufOffset();
// buf->prev = shr->shCurJmpBuf;
a.ldr(a64::x0, a64::Mem(xRuntime, offsetof(SHRuntime, shCurJmpBuf)));
a.str(a64::x0, a64::Mem(a64::sp, jmpBufOffset + offsetof(SHJmpBuf, prev)));

// shr->shCurJmpBuf = buf;
a.add(a64::x0, a64::sp, jmpBufOffset);
a.str(a64::x0, a64::Mem(xRuntime, offsetof(SHRuntime, shCurJmpBuf)));

// _setjmp(buf->buf);
a.add(a64::x0, a64::sp, jmpBufOffset + offsetof(SHJmpBuf, buf));
// setjmp can't throw and it'll be called once, so don't use a thunk.
EMIT_RUNTIME_CALL_WITHOUT_THUNK_AND_SAVED_IP(
*this, int (*)(jmp_buf), _setjmp);
// If this a catch, go to the catch table to jump to either a handler BB or
// rethrow.
a.cbnz(a64::x0, catchTableLabel_);
}

if (dumpJitCode_ & DumpJitCode::EntryExit) {
comment("// print entry");
a.mov(a64::w0, 1);
Expand All @@ -739,6 +780,15 @@ void Emitter::leave() {
EMIT_RUNTIME_CALL_WITHOUT_SAVED_IP(
*this, void (*)(bool, const char *), _sh_print_function_entry_exit);
}

if (catchTableLabel_.isValid()) {
comment("// _sh_end_try");
// shr->shCurJmpBuf = buf->prev
uint32_t jmpBufOffset = getJmpBufOffset();
a.ldr(a64::x0, a64::Mem(a64::sp, jmpBufOffset + offsetof(SHJmpBuf, prev)));
a.str(a64::x0, a64::Mem(xRuntime, offsetof(SHRuntime, shCurJmpBuf)));
}

// _sh_leave(shr, &locals.head, frame);
// Restore the previous stack frame.
a.str(xFrame, a64::Mem(xRuntime, RuntimeOffsets::stackPointer));
Expand All @@ -753,7 +803,7 @@ void Emitter::leave() {
// register.
a.mov(a64::x0, a64::x21);

unsigned stackOfs = 0;
unsigned stackOfs = getSavedRegsOffset();
for (unsigned i = 0; i < gpSaveCount_; i += 2, stackOfs += 16) {
if (i + 1 < gpSaveCount_)
a.ldp(a64::GpX(19 + i), a64::GpX(20 + i), a64::Mem(a64::sp, stackOfs));
Expand All @@ -771,10 +821,7 @@ void Emitter::leave() {
}
a.ldp(a64::x29, a64::x30, a64::Mem(a64::sp, stackOfs));

a.add(
a64::sp,
a64::sp,
(((gpSaveCount_ + 1) & ~1) + ((vecSaveCount_ + 1) & ~1) + 2) * 8);
a.add(a64::sp, a64::sp, getStackSize());

a.ret(a64::x30);
}
Expand Down Expand Up @@ -1328,6 +1375,23 @@ void Emitter::profilePoint(uint16_t pointIndex) {
#endif
}

void Emitter::catchInst(FR frRes) {
comment("// Catch r%u", frRes.index());

HWReg hwTemp = allocTempGpX();
HWReg hwRes = getOrAllocFRInGpX(frRes, false);
frUpdatedWithHW(frRes, hwRes);
freeReg(hwTemp);

// Catch simply returns the thrown value and clears it.

// Read thrown value.
a.ldr(hwRes.a64GpX(), a64::Mem(xRuntime, RuntimeOffsets::thrownValue));
// Clear thrown value.
loadBits64InGp(hwTemp.a64GpX(), _sh_ljs_empty().raw, "empty");
a.str(hwTemp.a64GpX(), a64::Mem(xRuntime, RuntimeOffsets::thrownValue));
}

void Emitter::ret(FR frValue) {
movHWFromFR(HWReg::gpX(21), frValue);
a.b(returnLabel_);
Expand Down Expand Up @@ -1892,10 +1956,10 @@ void Emitter::fastArrayLoad(FR frRes, FR frArr, FR frIdx) {
a.ccmp(
hwTmpSize.a64GpX().w(), hwTmpIdxGpX.a64GpX().w(), 0, a64::CondCode::kEQ);
// If the index is out-of-bounds jump to the failure path.
// TODO: We currently disregard the state of the registers on an OOB access
// because we cannot JIT try-catch, so we are guaranteed to be leaving the
// current function. If we add support for JIT of try-catch, we will have to
// sync registers when the access is inside a try region.
// We will have to sync registers when the access is inside a try region
// because we could read from the FRs again in this function.
if (isInTry())
syncAllFRTempExcept(frRes != frArr && frRes != frIdx ? frRes : FR());
a.b_ls(slowPathLab);

// Add the offset of the actual data in the ArrayStorage.
Expand Down Expand Up @@ -2558,7 +2622,10 @@ void Emitter::iteratorClose(FR frIteratorOrIdx, bool ignoreExceptions) {
void Emitter::throwInst(FR frInput) {
comment("// Throw r%u", frInput.index());

syncAllFRTempExcept({});
// We have to sync registers when the throw is inside a try region
// because we could read from the FRs again in this function.
if (isInTry())
syncAllFRTempExcept({});
movHWFromFR(HWReg::gpX(1), frInput);
freeAllFRTempExcept({});

Expand All @@ -2571,12 +2638,14 @@ void Emitter::throwIfEmpty(FR frRes, FR frInput) {

asmjit::Label slowPathLab = newSlowPathLabel();

// TODO: Add back the sync/free calls inside try.
// Outside a try it's not observable behavior.
// syncAllFRTempExcept(frRes != frInput ? frRes : FR());
// We have to sync registers when the throw is inside a try region
// because we could read from the FRs again in this function.
if (isInTry())
syncAllFRTempExcept(frRes != frInput ? frRes : FR());
HWReg hwInput = getOrAllocFRInGpX(frInput, true);
HWReg hwTemp = allocTempGpX();
// freeAllFRTempExcept({});
if (isInTry())
freeAllFRTempExcept({});
freeReg(hwTemp);

emit_sh_ljs_is_empty(a, hwTemp.a64GpX(), hwInput.a64GpX());
Expand Down Expand Up @@ -3473,6 +3542,43 @@ asmjit::Label Emitter::registerThunk(void *fn, const char *name) {
return thunks_[it->second].first;
}

void Emitter::emitCatchTable(
llvh::ArrayRef<const asmjit::Label *> exceptionHandlers) {
// No trys in the function, nothing to do here.
if (!catchTableLabel_.isValid())
return;

a.bind(catchTableLabel_);

asmjit::Label addressTableLab = a.newLabel();

// Find the catch target for the exception.
a.mov(a64::x0, xRuntime);
loadBits64InGp(a64::x1, (uint64_t)codeBlock_, "CodeBlock");
a.mov(a64::x2, xFrame);
a.add(a64::x3, a64::sp, getJmpBufOffset());
a.ldr(a64::x4, a64::Mem(a64::sp, getSavedSHLocalsOffset()));
a.adr(a64::x5, addressTableLab);
EMIT_RUNTIME_CALL_WITHOUT_THUNK_AND_SAVED_IP(
*this,
void *(*)(SHRuntime *,
SHCodeBlock *,
SHLegacyValue *,
SHJmpBuf *,
SHLocals *,
int32_t *),
_jit_find_catch_target);

// The address to branch to was returned here.
a.br(a64::x0);

// Table of offsets from addressTableLab to jump to.
a.bind(addressTableLab);
for (const asmjit::Label *handler : exceptionHandlers) {
a.embedLabelDelta(*handler, addressTableLab, /* size */ 4);
}
}

void Emitter::emitSlowPaths() {
while (!slowPaths_.empty()) {
SlowPath &sp = slowPaths_.front();
Expand Down
Loading

0 comments on commit 5b029ce

Please sign in to comment.