Skip to content

Commit

Permalink
8336768: Allow captureCallState and critical linker options to be com…
Browse files Browse the repository at this point in the history
…bined

Reviewed-by: mcimadamore
  • Loading branch information
JornVernee committed Dec 3, 2024
1 parent 63af2f4 commit 8cad043
Show file tree
Hide file tree
Showing 14 changed files with 180 additions and 88 deletions.
2 changes: 0 additions & 2 deletions src/java.base/share/classes/java/lang/foreign/Linker.java
Original file line number Diff line number Diff line change
Expand Up @@ -852,8 +852,6 @@ static Option firstVariadicArg(int index) {
* // use errno
* }
* }
* <p>
* This linker option can not be combined with {@link #critical}.
*
* @param capturedState the names of the values to save
* @throws IllegalArgumentException if at least one of the provided
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2020, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -195,6 +195,10 @@ public boolean needsTransition() {
return !linkerOptions.isCritical();
}

public boolean usingAddressPairs() {
return linkerOptions.allowsHeapAccess();
}

public int numLeadingParams() {
return 2 + (linkerOptions.hasCapturedCallState() ? 1 : 0); // 2 for addr, allocator
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,18 @@ public CallingSequence build() {
MethodType calleeMethodType;
if (!forUpcall) {
if (linkerOptions.hasCapturedCallState()) {
addArgumentBinding(0, MemorySegment.class, ValueLayout.ADDRESS, List.of(
Binding.unboxAddress(),
Binding.vmStore(abi.capturedStateStorage(), long.class)));
if (linkerOptions.allowsHeapAccess()) {
addArgumentBinding(0, MemorySegment.class, ValueLayout.ADDRESS, List.of(
Binding.dup(),
Binding.segmentBase(),
Binding.vmStore(abi.capturedStateStorage(), Object.class),
Binding.segmentOffsetAllowHeap(),
Binding.vmStore(null, long.class)));
} else {
addArgumentBinding(0, MemorySegment.class, ValueLayout.ADDRESS, List.of(
Binding.unboxAddress(),
Binding.vmStore(abi.capturedStateStorage(), long.class)));
}
}
addArgumentBinding(0, MemorySegment.class, ValueLayout.ADDRESS, List.of(
Binding.unboxAddress(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ public MethodHandle getBoundMethodHandle() {
leafType,
callingSequence.needsReturnBuffer(),
callingSequence.capturedStateMask(),
callingSequence.needsTransition()
callingSequence.needsTransition(),
callingSequence.usingAddressPairs()
);
MethodHandle handle = JLIA.nativeMethodHandle(nep);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,7 @@ private static LinkerOptions forShared(BiConsumer<LinkerOptionImpl, FunctionDesc
optionMap.put(option.getClass(), opImpl);
}

LinkerOptions linkerOptions = new LinkerOptions(optionMap);
if (linkerOptions.hasCapturedCallState() && linkerOptions.isCritical()) {
throw new IllegalArgumentException("Incompatible linker options: captureCallState, critical");
}
return linkerOptions;
return new LinkerOptions(optionMap);
}

public static LinkerOptions empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ public static NativeEntryPoint make(ABIDescriptor abi,
MethodType methodType,
boolean needsReturnBuffer,
int capturedStateMask,
boolean needsTransition) {
boolean needsTransition,
boolean usingAddressPairs) {
if (returnMoves.length > 1 != needsReturnBuffer) {
throw new AssertionError("Multiple register return, but needsReturnBuffer was false");
}
checkType(methodType, needsReturnBuffer, capturedStateMask);
checkMethodType(methodType, needsReturnBuffer, capturedStateMask, usingAddressPairs);

CacheKey key = new CacheKey(methodType, abi, Arrays.asList(argMoves), Arrays.asList(returnMoves),
needsReturnBuffer, capturedStateMask, needsTransition);
Expand All @@ -80,14 +81,26 @@ public static NativeEntryPoint make(ABIDescriptor abi,
});
}

private static void checkType(MethodType methodType, boolean needsReturnBuffer, int savedValueMask) {
if (methodType.parameterType(0) != long.class) {
throw new AssertionError("Address expected as first param: " + methodType);
private static void checkMethodType(MethodType methodType, boolean needsReturnBuffer, int savedValueMask,
boolean usingAddressPairs) {
int checkIdx = 0;
checkParamType(methodType, checkIdx++, long.class, "Function address");
if (needsReturnBuffer) {
checkParamType(methodType, checkIdx++, long.class, "Return buffer address");
}
int checkIdx = 1;
if ((needsReturnBuffer && methodType.parameterType(checkIdx++) != long.class)
|| (savedValueMask != 0 && methodType.parameterType(checkIdx) != long.class)) {
throw new AssertionError("return buffer and/or preserved value address expected: " + methodType);
if (savedValueMask != 0) { // capturing call state
if (usingAddressPairs) {
checkParamType(methodType, checkIdx++, Object.class, "Capture state heap base");
checkParamType(methodType, checkIdx, long.class, "Capture state offset");
} else {
checkParamType(methodType, checkIdx, long.class, "Capture state address");
}
}
}

private static void checkParamType(MethodType methodType, int checkIdx, Class<?> expectedType, String name) {
if (methodType.parameterType(checkIdx) != expectedType) {
throw new AssertionError(name + " expected at index " + checkIdx + ": " + methodType);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,14 @@ private static Object doDowncall(SegmentAllocator returnAllocator, Object[] args
acquiredSessions.add(targetImpl);

MemorySegment capturedState = null;
Object captureStateHeapBase = null;
if (invData.capturedStateMask() != 0) {
capturedState = SharedUtils.checkCaptureSegment((MemorySegment) args[argStart++]);
if (!invData.allowsHeapAccess) {
SharedUtils.checkNative(capturedState);
} else {
captureStateHeapBase = capturedState.heapBase().orElse(null);
}
MemorySessionImpl capturedStateImpl = ((AbstractMemorySegmentImpl) capturedState).sessionImpl();
capturedStateImpl.acquire0();
acquiredSessions.add(capturedStateImpl);
Expand Down Expand Up @@ -199,7 +205,8 @@ private static Object doDowncall(SegmentAllocator returnAllocator, Object[] args
retSeg = (invData.returnLayout() instanceof GroupLayout ? returnAllocator : arena).allocate(invData.returnLayout);
}

LibFallback.doDowncall(invData.cif, target, retSeg, argPtrs, capturedState, invData.capturedStateMask(),
LibFallback.doDowncall(invData.cif, target, retSeg, argPtrs,
captureStateHeapBase, capturedState, invData.capturedStateMask(),
heapBases, args.length);

Reference.reachabilityFence(invData.cif());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ private static boolean tryLoadLibrary() {
* @see jdk.internal.foreign.abi.CapturableState
*/
static void doDowncall(MemorySegment cif, MemorySegment target, MemorySegment retPtr, MemorySegment argPtrs,
MemorySegment capturedState, int capturedStateMask,
Object captureStateHeapBase, MemorySegment capturedState, int capturedStateMask,
Object[] heapBases, int numArgs) {
doDowncall(cif.address(), target.address(),
retPtr == null ? 0 : retPtr.address(), argPtrs.address(),
captureStateHeapBase,
capturedState == null ? 0 : capturedState.address(), capturedStateMask,
heapBases, numArgs);
}
Expand Down Expand Up @@ -212,7 +213,7 @@ private static void checkStatus(int code) {
private static native int createClosure(long cif, Object userData, long[] ptrs);
private static native void freeClosure(long closureAddress, long globalTarget);
private static native void doDowncall(long cif, long fn, long rvalue, long avalues,
long capturedState, int capturedStateMask,
Object captureStateHeapBase, long capturedState, int capturedStateMask,
Object[] heapBases, int numArgs);

private static native int ffi_prep_cif(long cif, int abi, int nargs, long rtype, long atypes);
Expand Down
26 changes: 19 additions & 7 deletions src/java.base/share/native/libfallbackLinker/fallbackLinker.c
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,16 @@ static void do_capture_state(int32_t* value_ptr, int captured_state_mask) {

JNIEXPORT void JNICALL
Java_jdk_internal_foreign_abi_fallback_LibFallback_doDowncall(JNIEnv* env, jclass cls, jlong cif, jlong fn, jlong rvalue,
jlong avalues, jlong jcaptured_state, jint captured_state_mask,
jlong avalues,
jarray capture_state_heap_base, jlong captured_state_offset,
jint captured_state_mask,
jarray heapBases, jint numArgs) {
void** carrays;
int capture_state_hb_offset = numArgs;
int32_t* captured_state_addr = jlong_to_ptr(captured_state_offset);
if (heapBases != NULL) {
void** aptrs = jlong_to_ptr(avalues);
carrays = malloc(sizeof(void*) * numArgs);
carrays = malloc(sizeof(void*) * (numArgs + 1));
for (int i = 0; i < numArgs; i++) {
jarray hb = (jarray) (*env)->GetObjectArrayElement(env, heapBases, i);
if (hb != NULL) {
Expand All @@ -130,24 +134,32 @@ Java_jdk_internal_foreign_abi_fallback_LibFallback_doDowncall(JNIEnv* env, jclas
*((void**)aptrs[i]) = arrayPtr + offset;
}
}
if (capture_state_heap_base != NULL) {
jboolean isCopy;
jbyte* arrayPtr = (*env)->GetPrimitiveArrayCritical(env, capture_state_heap_base, &isCopy);
carrays[capture_state_hb_offset] = arrayPtr;
captured_state_addr = (int32_t*) (arrayPtr + captured_state_offset);
}
}

ffi_call(jlong_to_ptr(cif), jlong_to_ptr(fn), jlong_to_ptr(rvalue), jlong_to_ptr(avalues));

if (captured_state_mask != 0) {
do_capture_state(captured_state_addr, captured_state_mask);
}

if (heapBases != NULL) {
for (int i = 0; i < numArgs; i++) {
jarray hb = (jarray) (*env)->GetObjectArrayElement(env, heapBases, i);
if (hb != NULL) {
(*env)->ReleasePrimitiveArrayCritical(env, hb, carrays[i], JNI_COMMIT);
}
}
if (capture_state_heap_base != NULL) {
(*env)->ReleasePrimitiveArrayCritical(env, capture_state_heap_base, carrays[capture_state_hb_offset], JNI_COMMIT);
}
free(carrays);
}

if (captured_state_mask != 0) {
int32_t* captured_state = jlong_to_ptr(jcaptured_state);
do_capture_state(captured_state, captured_state_mask);
}
}

static void do_upcall(ffi_cif* cif, void* ret, void** args, void* user_data) {
Expand Down
5 changes: 0 additions & 5 deletions test/jdk/java/foreign/TestIllegalLink.java
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,6 @@ public static Object[][] types() {
NO_OPTIONS,
"has unexpected size"
},
{
FunctionDescriptor.ofVoid(),
new Linker.Option[]{Linker.Option.critical(false), Linker.Option.captureCallState("errno")},
"Incompatible linker options: captureCallState, critical"
},
}));

for (ValueLayout illegalLayout : List.of(C_CHAR, ValueLayout.JAVA_CHAR, C_BOOL, C_SHORT, C_FLOAT)) {
Expand Down
Loading

0 comments on commit 8cad043

Please sign in to comment.