diff --git a/sdk/mx.sdk/mx_sdk_vm_impl.py b/sdk/mx.sdk/mx_sdk_vm_impl.py index 34ae5bd1da80..347f34953a25 100644 --- a/sdk/mx.sdk/mx_sdk_vm_impl.py +++ b/sdk/mx.sdk/mx_sdk_vm_impl.py @@ -1210,7 +1210,10 @@ def is_ee_supported(self): def is_pgo_supported(self): return self.is_ee_supported() - def native_image(self, build_args, output_file, allow_server=False, nonZeroIsFatal=True, out=None, err=None): + search_tool = 'strings' + has_search_tool = shutil.which(search_tool) is not None + + def native_image(self, build_args, output_file, out=None, err=None, find_bad_strings=False): assert self._svm_supported stage1 = get_stage1_graalvm_distribution() native_image_project_name = GraalVmLauncher.launcher_project_name(mx_sdk.LauncherConfig(mx.exe_suffix('native-image'), [], "", []), stage1=True) @@ -1220,7 +1223,22 @@ def native_image(self, build_args, output_file, allow_server=False, nonZeroIsFat native_image_command += svm_experimental_options([ '-H:Path=' + output_directory or ".", ]) - return mx.run(native_image_command, nonZeroIsFatal=nonZeroIsFatal, out=out, err=err) + + mx.run(native_image_command, nonZeroIsFatal=True, out=out, err=err) + + if find_bad_strings and not mx.is_windows(): + if not self.__class__.has_search_tool: + mx.abort(f"Searching for strings requires '{self.__class__.search_tool}' executable.") + try: + strings_in_image = subprocess.check_output([self.__class__.search_tool, output_file], stderr=None).decode().strip().split('\n') + bad_strings = (output_directory, dirname(native_image_bin)) + for entry in strings_in_image: + for bad_string in bad_strings: + if bad_string in entry: + mx.abort(f"Found forbidden string '{bad_string}' in native image {output_file}.") + + except subprocess.CalledProcessError: + mx.abort(f"Using '{self.__class__.search_tool}' to search for strings in native image {output_file} failed.") def is_debug_supported(self): return self._debug_supported @@ -2367,7 +2385,7 @@ def build(self): mx.ensure_dir_exists(dirname(output_file)) # Disable build server (different Java properties on each build prevent server reuse) - self.svm_support.native_image(build_args, output_file) + self.svm_support.native_image(build_args, output_file, find_bad_strings=True) with open(self._get_command_file(), 'w') as f: f.writelines((l + os.linesep for l in build_args)) diff --git a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/IsolateArgumentParser.java b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/IsolateArgumentParser.java index 4ded2d09e487..0440b8703d2b 100644 --- a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/IsolateArgumentParser.java +++ b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/IsolateArgumentParser.java @@ -183,10 +183,21 @@ public void persistOptions(CLongPointer parsedArgs) { public void verifyOptionValues() { for (int i = 0; i < OPTION_COUNT; i++) { - validate(OPTIONS[i], getOptionValue(i)); + RuntimeOptionKey option = OPTIONS[i]; + if (shouldValidate(option)) { + validate(option, getOptionValue(i)); + } } } + private static boolean shouldValidate(RuntimeOptionKey option) { + if (SubstrateOptions.UseSerialGC.getValue()) { + /* The serial GC supports changing the heap size at run-time to some degree. */ + return option != SubstrateGCOptions.MinHeapSize && option != SubstrateGCOptions.MaxHeapSize && option != SubstrateGCOptions.MaxNewSize; + } + return true; + } + @Uninterruptible(reason = "Called from uninterruptible code.", mayBeInlined = true) public static boolean getBooleanOptionValue(int index) { return PARSED_OPTION_VALUES[index] == 1; diff --git a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jdk/RuntimeSupport.java b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jdk/RuntimeSupport.java index d108e7115982..d6ea31d654f0 100644 --- a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jdk/RuntimeSupport.java +++ b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jdk/RuntimeSupport.java @@ -35,6 +35,7 @@ import org.graalvm.nativeimage.VMRuntime; import org.graalvm.nativeimage.impl.VMRuntimeSupport; +import com.oracle.svm.core.IsolateArgumentParser; import com.oracle.svm.core.Isolates; import com.oracle.svm.core.feature.AutomaticallyRegisteredImageSingleton; import com.oracle.svm.core.heap.HeapSizeVerifier; @@ -92,7 +93,7 @@ public boolean isUninitialized() { public void initialize() { boolean shouldInitialize = initializationState.compareAndSet(InitializationState.Uninitialized, InitializationState.InProgress); if (shouldInitialize) { - // GR-35186: we should verify that none of the early parsed isolate arguments changed. + IsolateArgumentParser.singleton().verifyOptionValues(); HeapSizeVerifier.verifyHeapOptions(); executeHooks(startupHooks); diff --git a/substratevm/src/com.oracle.svm.hosted/src/com/oracle/svm/hosted/image/CCLinkerInvocation.java b/substratevm/src/com.oracle.svm.hosted/src/com/oracle/svm/hosted/image/CCLinkerInvocation.java index 91604578a5be..4b1990d800da 100644 --- a/substratevm/src/com.oracle.svm.hosted/src/com/oracle/svm/hosted/image/CCLinkerInvocation.java +++ b/substratevm/src/com.oracle.svm.hosted/src/com/oracle/svm/hosted/image/CCLinkerInvocation.java @@ -320,6 +320,11 @@ protected void setOutputKind(List cmd) { break; case SHARED_LIBRARY: cmd.add("-shared"); + /* + * Ensure shared library name in image does not use fully qualified build-path + * (GR-46837) + */ + cmd.add("-Wl,-soname=" + outputFile.getFileName()); break; default: VMError.shouldNotReachHereUnexpectedInput(imageKind); // ExcludeFromJacocoGeneratedReport diff --git a/truffle/CHANGELOG.md b/truffle/CHANGELOG.md index 721712772f8c..5eff94bf8533 100644 --- a/truffle/CHANGELOG.md +++ b/truffle/CHANGELOG.md @@ -50,6 +50,7 @@ This changelog summarizes major changes between Truffle versions relevant to lan * Bundle the necessary files into a jar distribution. * Implement the `InternalResource` interface for handling the resource file unpacking. * Call the `Env#getInternalResource` when the language or instrument needs the bundled resource files. This method ensures that the requested `InternalResource` is unpacked and provides a directory containing the unpacked files. Since unpacking internal resources can be an expensive operation, the implementation ensures that internal resources are cached. +* GR-44464 Added `TruffleString.ToValidStringNode` for encoding-level string sanitization. ## Version 23.0.0 diff --git a/truffle/docs/TruffleStrings.md b/truffle/docs/TruffleStrings.md index 8989ab47360d..a0bde6b7e478 100644 --- a/truffle/docs/TruffleStrings.md +++ b/truffle/docs/TruffleStrings.md @@ -108,6 +108,8 @@ Conversion: Convert a MutableTruffleString to an immutable TruffleString. * [AsManaged](https://www.graalvm.org/truffle/javadoc/com/oracle/truffle/api/strings/TruffleString.AsManagedNode.html): Convert a TruffleString backed by a native pointer to one backed by a java byte array. +* [ToValidString](https://www.graalvm.org/truffle/javadoc/com/oracle/truffle/api/strings/TruffleString.ToValidStringNode.html): + Convert a TruffleString to a version that is encoded correctly. * [CopyToByteArray](https://www.graalvm.org/truffle/javadoc/com/oracle/truffle/api/strings/TruffleString.CopyToByteArrayNode.html): Copy a string's content into a byte array. * [GetInternalByteArray](https://www.graalvm.org/truffle/javadoc/com/oracle/truffle/api/strings/TruffleString.GetInternalByteArrayNode.html): diff --git a/truffle/src/com.oracle.truffle.api.strings.test/src/com/oracle/truffle/api/strings/test/ops/TStringToValidStringTest.java b/truffle/src/com.oracle.truffle.api.strings.test/src/com/oracle/truffle/api/strings/test/ops/TStringToValidStringTest.java new file mode 100644 index 000000000000..2bbd8dae9ea2 --- /dev/null +++ b/truffle/src/com.oracle.truffle.api.strings.test/src/com/oracle/truffle/api/strings/test/ops/TStringToValidStringTest.java @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * The Universal Permissive License (UPL), Version 1.0 + * + * Subject to the condition set forth below, permission is hereby granted to any + * person obtaining a copy of this software, associated documentation and/or + * data (collectively the "Software"), free of charge and under any and all + * copyright rights in the Software, and any and all patent rights owned or + * freely licensable by each licensor hereunder covering either (i) the + * unmodified Software as contributed to or provided by such licensor, or (ii) + * the Larger Works (as defined below), to deal in both + * + * (a) the Software, and + * + * (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if + * one is included with the Software each a "Larger Work" to which the Software + * is contributed by such licensors), + * + * without restriction, including without limitation the rights to copy, create + * derivative works of, display, perform, and distribute the Software and make, + * use, sell, offer for sale, import, export, have made, and have sold the + * Software and the Larger Work(s), and to sublicense the foregoing rights on + * either these or other terms. + * + * This license is subject to the following condition: + * + * The above copyright notice and either this complete permission notice or at a + * minimum a reference to the UPL must be included in all copies or substantial + * portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +package com.oracle.truffle.api.strings.test.ops; + +import static com.oracle.truffle.api.strings.TruffleString.Encoding.BYTES; +import static com.oracle.truffle.api.strings.TruffleString.Encoding.ISO_8859_1; +import static com.oracle.truffle.api.strings.TruffleString.Encoding.US_ASCII; +import static com.oracle.truffle.api.strings.TruffleString.Encoding.UTF_16; +import static com.oracle.truffle.api.strings.TruffleString.Encoding.UTF_32; +import static com.oracle.truffle.api.strings.TruffleString.Encoding.UTF_8; +import static com.oracle.truffle.api.strings.test.TStringTestUtil.byteArray; +import static org.junit.runners.Parameterized.Parameter; + +import java.util.Arrays; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import com.oracle.truffle.api.strings.TruffleString; +import com.oracle.truffle.api.strings.test.TStringTestBase; + +@RunWith(Parameterized.class) +public class TStringToValidStringTest extends TStringTestBase { + + @Parameter public TruffleString.ToValidStringNode node; + + @Parameters(name = "{0}") + public static Iterable data() { + return Arrays.asList(TruffleString.ToValidStringNode.create(), TruffleString.ToValidStringNode.getUncached()); + } + + @Test + public void testAll() throws Exception { + forAllStrings(new TruffleString.Encoding[]{US_ASCII, ISO_8859_1, BYTES, UTF_8, UTF_16, UTF_32}, true, (a, array, codeRange, isValid, encoding, codepoints, byteIndices) -> { + TruffleString wellFormed = node.execute(a, encoding); + if (isValid && a instanceof TruffleString) { + Assert.assertSame(a, wellFormed); + } + Assert.assertTrue(wellFormed.isValidUncached(encoding)); + }); + } + + @Test + public void testAscii() { + testAscii(byteArray('a', '?'), byteArray('a', 0xff)); + testAscii(byteArray('a', '?'), byteArray('a', 0x80)); + testAscii(byteArray('a', '?', 'b'), byteArray('a', 0xff, 'b')); + testAscii(byteArray('a', '?', 'b'), byteArray('a', 0x80, 'b')); + testAscii(byteArray('a', 0x7f, 'b'), byteArray('a', 0x7f, 'b')); + } + + @Test + public void testUTF8() { + testUTF8(byteArray('a', 0xEF, 0xBF, 0xBD), byteArray('a', 0xff)); + testUTF8(byteArray('a', 0xEF, 0xBF, 0xBD), byteArray('a', 0xf0, 0x90)); + testUTF8(byteArray('a', 0xEF, 0xBF, 0xBD), byteArray('a', 0xf0, 0x90, 0x80)); + testUTF8(byteArray('a', 0xEF, 0xBF, 0xBD, 0xf0, 0x90, 0x80, 0x80), byteArray('a', 0xf0, 0x90, 0x80, 0xf0, 0x90, 0x80, 0x80)); + testUTF8(byteArray('a', 0xf0, 0x90, 0x80, 0x80, 0xEF, 0xBF, 0xBD), byteArray('a', 0xf0, 0x90, 0x80, 0x80, 0xf0, 0x90, 0x80)); + testUTF8(byteArray('a', 0xf0, 0x90, 0x80, 0x80), byteArray('a', 0xf0, 0x90, 0x80, 0x80)); + testUTF8(byteArray('a', 0xEF, 0xBF, 0xBD, 0xEF, 0xBF, 0xBD), byteArray('a', 0xf8, 0x90)); + testUTF8(byteArray('a', 0xEF, 0xBF, 0xBD, 'b'), byteArray('a', 0xff, 'b')); + testUTF8(byteArray('a', 0xEF, 0xBF, 0xBD, 'b'), byteArray('a', 0xf0, 0x90, 'b')); + testUTF8(byteArray('a', 0xEF, 0xBF, 0xBD, 'b'), byteArray('a', 0xf0, 0x90, 0x80, 'b')); + testUTF8(byteArray('a', 0xf0, 0x90, 0x80, 0x80, 'b'), byteArray('a', 0xf0, 0x90, 0x80, 0x80, 'b')); + testUTF8(byteArray('a', 0xEF, 0xBF, 0xBD, 0xf0, 0x90, 0x80, 0x80, 'b'), byteArray('a', 0xf0, 0x90, 0x80, 0xf0, 0x90, 0x80, 0x80, 'b')); + testUTF8(byteArray('a', 0xEF, 0xBF, 0xBD, 0xEF, 0xBF, 0xBD, 'b'), byteArray('a', 0xf8, 0x90, 'b')); + } + + private void testAscii(byte[] expected, byte[] input) { + testByteArray(expected, input, US_ASCII); + } + + private void testUTF8(byte[] expected, byte[] input) { + testByteArray(expected, input, UTF_8); + } + + private void testByteArray(byte[] expected, byte[] input, TruffleString.Encoding encoding) { + TruffleString wellFormed = node.execute(TruffleString.fromByteArrayUncached(input, encoding), encoding); + for (int i = 0; i < expected.length; i++) { + Assert.assertEquals(Byte.toUnsignedInt(expected[i]), wellFormed.readByteUncached(i, encoding)); + } + Assert.assertTrue(wellFormed.isValidUncached(encoding)); + } + + @Test + public void testUTF16() { + testUTF16("a\ufffd", "a\udfff"); + testUTF16("a\ufffd", "a\udbff"); + testUTF16("a\ufffd\ufffd", "a\udfff\udfff"); + testUTF16("a\ufffd\ufffd", "a\udbff\udbff"); + testUTF16("a\udbff\udfff\ufffd", "a\udbff\udfff\udbff"); + testUTF16("a\udbff\udfff\ufffdb", "a\udbff\udfff\udbffb"); + } + + private void testUTF16(String expected, String input) { + TruffleString wellFormed = node.execute(TruffleString.fromJavaStringUncached(input, UTF_16), UTF_16); + Assert.assertEquals(expected, wellFormed.toJavaStringUncached()); + Assert.assertTrue(wellFormed.isValidUncached(UTF_16)); + } + + @Test + public void testUTF32() { + testUTF32(new int[]{'a', 0xfffd}, new int[]{'a', Character.MIN_SURROGATE}); + testUTF32(new int[]{'a', 0xfffd}, new int[]{'a', Character.MAX_SURROGATE}); + testUTF32(new int[]{'a', 0xfffd}, new int[]{'a', Integer.MAX_VALUE}); + testUTF32(new int[]{'a', 0xfffd}, new int[]{'a', Integer.MIN_VALUE}); + testUTF32(new int[]{'a', 0xfffd}, new int[]{'a', 0x110000}); + testUTF32(new int[]{'a', 0xfffd}, new int[]{'a', 0xffff_ffff}); + testUTF32(new int[]{'a', Character.MAX_CODE_POINT}, new int[]{'a', Character.MAX_CODE_POINT}); + testUTF32(new int[]{'a', Character.MAX_CODE_POINT, 0xfffd}, new int[]{'a', Character.MAX_CODE_POINT, Character.MIN_SURROGATE}); + } + + private void testUTF32(int[] expected, int[] input) { + TruffleString wellFormed = node.execute(TruffleString.fromIntArrayUTF32Uncached(input), UTF_32); + for (int i = 0; i < expected.length; i++) { + Assert.assertEquals(expected[i], wellFormed.codePointAtIndexUncached(i, UTF_32)); + } + Assert.assertTrue(wellFormed.isValidUncached(UTF_32)); + } + + @Test + public void testNull() throws Exception { + expectNullPointerException(() -> node.execute(null, UTF_16)); + expectNullPointerException(() -> node.execute(S_UTF16, null)); + } +} diff --git a/truffle/src/com.oracle.truffle.api.strings/snapshot.sigtest b/truffle/src/com.oracle.truffle.api.strings/snapshot.sigtest index 575debdf03a5..13154837ae55 100644 --- a/truffle/src/com.oracle.truffle.api.strings/snapshot.sigtest +++ b/truffle/src/com.oracle.truffle.api.strings/snapshot.sigtest @@ -79,6 +79,7 @@ hfds GIL_LOCK,PARENT_LIMIT,SAME_LANGUAGE_CHECK_VISITOR,parent CLSS public abstract interface com.oracle.truffle.api.nodes.NodeInterface CLSS public abstract com.oracle.truffle.api.strings.AbstractTruffleString +meth public com.oracle.truffle.api.strings.TruffleString toValidStringUncached(com.oracle.truffle.api.strings.TruffleString$Encoding) meth public final boolean codeRangeEqualsUncached(com.oracle.truffle.api.strings.TruffleString$CodeRange) meth public final boolean equals(java.lang.Object) meth public final boolean equalsUncached(com.oracle.truffle.api.strings.AbstractTruffleString,com.oracle.truffle.api.strings.TruffleString$Encoding) @@ -353,6 +354,7 @@ innr public abstract static SubstringByteIndexNode innr public abstract static SubstringNode innr public abstract static SwitchEncodingNode innr public abstract static ToJavaStringNode +innr public abstract static ToValidStringNode innr public final static !enum CodeRange innr public final static !enum CompactionLevel innr public final static !enum Encoding @@ -985,6 +987,13 @@ meth public static com.oracle.truffle.api.strings.TruffleString$ToJavaStringNode meth public static com.oracle.truffle.api.strings.TruffleString$ToJavaStringNode getUncached() supr com.oracle.truffle.api.nodes.Node +CLSS public abstract static com.oracle.truffle.api.strings.TruffleString$ToValidStringNode + outer com.oracle.truffle.api.strings.TruffleString +meth public abstract com.oracle.truffle.api.strings.TruffleString execute(com.oracle.truffle.api.strings.AbstractTruffleString,com.oracle.truffle.api.strings.TruffleString$Encoding) +meth public static com.oracle.truffle.api.strings.TruffleString$ToValidStringNode create() +meth public static com.oracle.truffle.api.strings.TruffleString$ToValidStringNode getUncached() +supr com.oracle.truffle.api.nodes.Node + CLSS public final static com.oracle.truffle.api.strings.TruffleString$WithMask outer com.oracle.truffle.api.strings.TruffleString innr public abstract static CreateNode @@ -1139,7 +1148,7 @@ CLSS public final com.oracle.truffle.api.strings.TruffleStringFactory cons public init() innr public final static WithMaskFactory supr java.lang.Object -hcls AsManagedNodeGen,AsNativeNodeGen,AsTruffleStringNodeGen,ByteIndexOfAnyByteNodeGen,ByteIndexOfCodePointNodeGen,ByteIndexOfCodePointSetNodeGen,ByteIndexOfStringNodeGen,ByteIndexToCodePointIndexNodeGen,ByteLengthOfCodePointNodeGen,CharIndexOfAnyCharUTF16NodeGen,CodePointAtByteIndexNodeGen,CodePointAtIndexNodeGen,CodePointIndexToByteIndexNodeGen,CodePointLengthNodeGen,CodeRangeEqualsNodeGen,CompareBytesNodeGen,CompareCharsUTF16NodeGen,CompareIntsUTF32NodeGen,ConcatNodeGen,CopyToByteArrayNodeGen,CopyToNativeMemoryNodeGen,CreateBackwardCodePointIteratorNodeGen,CreateCodePointIteratorNodeGen,EqualNodeGen,ForceEncodingNodeGen,FromByteArrayNodeGen,FromCharArrayUTF16NodeGen,FromCodePointNodeGen,FromIntArrayUTF32NodeGen,FromJavaStringNodeGen,FromLongNodeGen,FromNativePointerNodeGen,GetByteCodeRangeNodeGen,GetCodeRangeImpreciseNodeGen,GetCodeRangeNodeGen,GetInternalByteArrayNodeGen,GetInternalNativePointerNodeGen,GetStringCompactionLevelNodeGen,HashCodeNodeGen,IndexOfCodePointNodeGen,IndexOfStringNodeGen,IntIndexOfAnyIntUTF32NodeGen,InternalAsTruffleStringNodeGen,InternalCopyToByteArrayNodeGen,InternalSwitchEncodingNodeGen,IsValidNodeGen,LastByteIndexOfCodePointNodeGen,LastByteIndexOfStringNodeGen,LastIndexOfCodePointNodeGen,LastIndexOfStringNodeGen,MaterializeNodeGen,ParseDoubleNodeGen,ParseIntNodeGen,ParseLongNodeGen,ReadByteNodeGen,ReadCharUTF16NodeGen,RegionEqualByteIndexNodeGen,RegionEqualNodeGen,RepeatNodeGen,SubstringByteIndexNodeGen,SubstringNodeGen,SwitchEncodingNodeGen,ToIndexableNodeGen,ToJavaStringNodeGen +hcls AsManagedNodeGen,AsNativeNodeGen,AsTruffleStringNodeGen,ByteIndexOfAnyByteNodeGen,ByteIndexOfCodePointNodeGen,ByteIndexOfCodePointSetNodeGen,ByteIndexOfStringNodeGen,ByteIndexToCodePointIndexNodeGen,ByteLengthOfCodePointNodeGen,CharIndexOfAnyCharUTF16NodeGen,CodePointAtByteIndexNodeGen,CodePointAtIndexNodeGen,CodePointIndexToByteIndexNodeGen,CodePointLengthNodeGen,CodeRangeEqualsNodeGen,CompareBytesNodeGen,CompareCharsUTF16NodeGen,CompareIntsUTF32NodeGen,ConcatNodeGen,CopyToByteArrayNodeGen,CopyToNativeMemoryNodeGen,CreateBackwardCodePointIteratorNodeGen,CreateCodePointIteratorNodeGen,EqualNodeGen,ForceEncodingNodeGen,FromByteArrayNodeGen,FromCharArrayUTF16NodeGen,FromCodePointNodeGen,FromIntArrayUTF32NodeGen,FromJavaStringNodeGen,FromLongNodeGen,FromNativePointerNodeGen,GetByteCodeRangeNodeGen,GetCodeRangeImpreciseNodeGen,GetCodeRangeNodeGen,GetInternalByteArrayNodeGen,GetInternalNativePointerNodeGen,GetStringCompactionLevelNodeGen,HashCodeNodeGen,IndexOfCodePointNodeGen,IndexOfStringNodeGen,IntIndexOfAnyIntUTF32NodeGen,InternalAsTruffleStringNodeGen,InternalCopyToByteArrayNodeGen,InternalSwitchEncodingNodeGen,IsValidNodeGen,LastByteIndexOfCodePointNodeGen,LastByteIndexOfStringNodeGen,LastIndexOfCodePointNodeGen,LastIndexOfStringNodeGen,MaterializeNodeGen,ParseDoubleNodeGen,ParseIntNodeGen,ParseLongNodeGen,ReadByteNodeGen,ReadCharUTF16NodeGen,RegionEqualByteIndexNodeGen,RegionEqualNodeGen,RepeatNodeGen,SubstringByteIndexNodeGen,SubstringNodeGen,SwitchEncodingNodeGen,ToIndexableNodeGen,ToJavaStringNodeGen,ToValidStringNodeGen CLSS public final static com.oracle.truffle.api.strings.TruffleStringFactory$WithMaskFactory outer com.oracle.truffle.api.strings.TruffleStringFactory diff --git a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/AbstractTruffleString.java b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/AbstractTruffleString.java index cbbb48e2a9e2..232daa373efd 100644 --- a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/AbstractTruffleString.java +++ b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/AbstractTruffleString.java @@ -382,6 +382,10 @@ final void invalidateCodePointLength() { codePointLength = -1; } + final boolean isCodePointLengthKnown() { + return codePointLength >= 0; + } + final void invalidateHashCode() { hashCode = 0; } @@ -1207,6 +1211,16 @@ public final void copyToNativeMemoryUncached(int byteFromIndexA, Object pointerO TruffleString.CopyToNativeMemoryNode.getUncached().execute(this, byteFromIndexA, pointerObject, byteFromIndexDst, byteLength, expectedEncoding); } + /** + * Shorthand for calling the uncached version of {@link TruffleString.ToValidStringNode}. + * + * @since 23.1 + */ + @TruffleBoundary + public TruffleString toValidStringUncached(Encoding expectedEncoding) { + return TruffleString.ToValidStringNode.getUncached().execute(this, expectedEncoding); + } + /** * Shorthand for calling the uncached version of {@link TruffleString.ToJavaStringNode}. * diff --git a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/Encodings.java b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/Encodings.java index 6fa69b59bb32..d6f59b6645b7 100644 --- a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/Encodings.java +++ b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/Encodings.java @@ -163,6 +163,10 @@ final class Encodings { static final byte UTF8_ACCEPT = 0; static final byte UTF8_REJECT = 12; static final byte UTF8_REVERSE_INCOMPLETE_SEQ = 24; + /** + * UTF-8 encoded 0xfffd. + */ + static final byte[] CONVERSION_REPLACEMENT_UTF_8 = {(byte) 0xEF, (byte) 0xBF, (byte) 0xBD}; static byte[] getUTF8DecodingStateMachine(DecodingErrorHandler errorHandler) { return errorHandler == DecodingErrorHandler.DEFAULT_KEEP_SURROGATES_IN_UTF8 ? Encodings.UTF_8_STATE_MACHINE_ALLOW_UTF16_SURROGATES : Encodings.UTF_8_STATE_MACHINE; diff --git a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/JCodingsImpl.java b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/JCodingsImpl.java index 79a9fff04f44..662aeac67f8c 100644 --- a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/JCodingsImpl.java +++ b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/JCodingsImpl.java @@ -282,7 +282,6 @@ private static void econvInsertOutput(TruffleString.Encoding targetEncoding, Enc } private static final byte[] CONVERSION_REPLACEMENT = {'?'}; - private static final byte[] CONVERSION_REPLACEMENT_UTF_8 = {(byte) 0xEF, (byte) 0xBF, (byte) 0xBD}; private static final byte[] CONVERSION_REPLACEMENT_UTF_16 = TStringGuards.littleEndian() ? new byte[]{(byte) 0xFD, (byte) 0xFF} : new byte[]{(byte) 0xFF, (byte) 0xFD}; private static final byte[] CONVERSION_REPLACEMENT_UTF_32 = TStringGuards.littleEndian() ? new byte[]{(byte) 0xFD, (byte) 0xFF, 0, 0} : new byte[]{0, 0, (byte) 0xFF, (byte) 0xFD}; @@ -318,7 +317,7 @@ public TruffleString transcode(Node location, AbstractTruffleString a, Object ar } else { final byte[] replacement; if (isUTF8(targetEncoding)) { - replacement = CONVERSION_REPLACEMENT_UTF_8; + replacement = Encodings.CONVERSION_REPLACEMENT_UTF_8; } else if (isUTF16(targetEncoding)) { replacement = CONVERSION_REPLACEMENT_UTF_16; } else if (isUTF32(targetEncoding)) { diff --git a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TStringInternalNodes.java b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TStringInternalNodes.java index 5aec5d1bc171..d1562cbb12ae 100644 --- a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TStringInternalNodes.java +++ b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TStringInternalNodes.java @@ -43,8 +43,11 @@ import static com.oracle.truffle.api.strings.AbstractTruffleString.checkArrayRange; import static com.oracle.truffle.api.strings.AbstractTruffleString.checkByteLengthUTF16; import static com.oracle.truffle.api.strings.AbstractTruffleString.checkByteLengthUTF32; +import static com.oracle.truffle.api.strings.Encodings.UTF8_ACCEPT; +import static com.oracle.truffle.api.strings.Encodings.UTF8_REJECT; import static com.oracle.truffle.api.strings.Encodings.isUTF16Surrogate; import static com.oracle.truffle.api.strings.TSCodeRange.isBroken; +import static com.oracle.truffle.api.strings.TSCodeRange.isPrecise; import static com.oracle.truffle.api.strings.TStringGuards.indexOfCannotMatch; import static com.oracle.truffle.api.strings.TStringGuards.is16Bit; import static com.oracle.truffle.api.strings.TStringGuards.is7Bit; @@ -54,9 +57,9 @@ import static com.oracle.truffle.api.strings.TStringGuards.isAsciiBytesOrLatin1; import static com.oracle.truffle.api.strings.TStringGuards.isBrokenFixedWidth; import static com.oracle.truffle.api.strings.TStringGuards.isBrokenMultiByte; +import static com.oracle.truffle.api.strings.TStringGuards.isBuiltin; import static com.oracle.truffle.api.strings.TStringGuards.isBytes; import static com.oracle.truffle.api.strings.TStringGuards.isFixedWidth; -import static com.oracle.truffle.api.strings.TStringGuards.isBuiltin; import static com.oracle.truffle.api.strings.TStringGuards.isStride0; import static com.oracle.truffle.api.strings.TStringGuards.isStride1; import static com.oracle.truffle.api.strings.TStringGuards.isStride2; @@ -95,6 +98,10 @@ final class TStringInternalNodes { + /** + * Gets a string's code range with enough precision to decide whether the code range makes the + * string fixed-width. + */ abstract static class GetCodeRangeForIndexCalculationNode extends AbstractInternalNode { abstract int execute(Node node, AbstractTruffleString a, Encoding encoding); @@ -112,6 +119,26 @@ static int get(Node node, AbstractTruffleString a, Encoding encoding, } } + /** + * Gets a string's code range with enough precision to decide whether the string is valid. + */ + abstract static class GetValidOrBrokenCodeRangeNode extends AbstractInternalNode { + + abstract int execute(Node node, AbstractTruffleString a, Encoding encoding); + + @Specialization + static int get(Node node, AbstractTruffleString a, Encoding encoding, + @Cached InlinedConditionProfile impreciseProfile, + @Cached TruffleString.ToIndexableNode toIndexableNode, + @Cached CalcStringAttributesNode calcStringAttributesNode) { + int codeRange = a.codeRange(); + if (impreciseProfile.profile(node, !TSCodeRange.isPrecise(codeRange) && TSCodeRange.isBroken(codeRange))) { + return StringAttributes.getCodeRange(updateAttributes(node, a, encoding, codeRange, toIndexableNode, calcStringAttributesNode)); + } + return codeRange; + } + } + abstract static class GetPreciseCodeRangeNode extends AbstractInternalNode { abstract int execute(Node node, AbstractTruffleString a, Encoding encoding); @@ -1578,7 +1605,173 @@ static String createJavaString(Node node, AbstractTruffleString a, Object arrayA } return TStringUnsafe.createJavaString(bytes, stride); } + } + + abstract static class ToValidStringNode extends AbstractInternalNode { + + private static final int[] UTF_32_ASTRAL_RANGE = {0x10000, 0x10ffff}; + private static final int[] UTF_32_INVALID_RANGES = {Character.MIN_SURROGATE, Character.MAX_SURROGATE, 0x11_0000, 0xffff_ffff}; + + abstract TruffleString execute(Node node, AbstractTruffleString a, Object arrayA, Encoding encoding); + + @Specialization(guards = "isAscii(encoding)") + static TruffleString ascii(Node node, AbstractTruffleString a, Object arrayA, @SuppressWarnings("unused") Encoding encoding) { + assert isStride0(a); + int length = a.length(); + byte[] array = TStringOps.arraycopyOfWithStride(node, arrayA, a.offset(), length, 0, length, 0); + int pos = 0; + int loopCount = 0; + while (pos < length) { + pos = TStringOps.indexOfCodePointWithMaskWithStrideIntl(node, array, 0, length, 0, pos, 0xff, 0x7f); + if (pos >= 0) { + TStringOps.writeToByteArray(array, 0, pos++, '?'); + } else { + break; + } + TStringConstants.truffleSafePointPoll(node, ++loopCount); + } + return TruffleString.createFromByteArray(array, length, 0, Encoding.US_ASCII, length, TSCodeRange.get7Bit()); + } + + @Specialization(guards = "isUTF8(encoding)") + static TruffleString utf8(Node node, AbstractTruffleString a, Object arrayA, @SuppressWarnings("unused") Encoding encoding, + @Cached InlinedBranchProfile outOfMemoryProfile) { + assert isStride0(a); + assert isPrecise(a.codeRange()); + assert a.isCodePointLengthKnown(); + + boolean isLarge = TransCodeIntlNode.isLarge(a.codePointLength()); + byte[] buffer = new byte[isLarge ? TStringConstants.MAX_ARRAY_SIZE : a.codePointLength() * 4]; + int length = 0; + int state = UTF8_ACCEPT; + int lastCodePointPos = 0; + int lastErrorPos = 0; + int codePointLength = a.codePointLength(); + byte[] stateMachine = Encodings.UTF_8_STATE_MACHINE; + int i = 0; + while (i < a.length()) { + int b = readS0(a, arrayA, i++); + int type = stateMachine[b]; + state = stateMachine[256 + state + type]; + if (state == UTF8_ACCEPT) { + lastCodePointPos = i; + } else if (state == UTF8_REJECT) { + int curCPLength = i - (lastCodePointPos + 1); + length = utf8CopyValidRegion(node, a, arrayA, outOfMemoryProfile, isLarge, buffer, length, lastCodePointPos, lastErrorPos); + System.arraycopy(Encodings.CONVERSION_REPLACEMENT_UTF_8, 0, buffer, length, Encodings.CONVERSION_REPLACEMENT_UTF_8.length); + length += Encodings.CONVERSION_REPLACEMENT_UTF_8.length; + state = UTF8_ACCEPT; + if (curCPLength > 1) { + codePointLength -= curCPLength - 1; + i--; + } + lastErrorPos = i; + lastCodePointPos = i; + } + TStringConstants.truffleSafePointPoll(node, i); + } + length = utf8CopyValidRegion(node, a, arrayA, outOfMemoryProfile, isLarge, buffer, length, lastCodePointPos, lastErrorPos); + if (lastCodePointPos != a.length() && lastErrorPos != lastCodePointPos) { + System.arraycopy(Encodings.CONVERSION_REPLACEMENT_UTF_8, 0, buffer, length, Encodings.CONVERSION_REPLACEMENT_UTF_8.length); + length += Encodings.CONVERSION_REPLACEMENT_UTF_8.length; + int curCPLength = a.length() - lastCodePointPos; + if (curCPLength > 1) { + codePointLength -= curCPLength - 1; + } + } + return TruffleString.createFromByteArray(Arrays.copyOf(buffer, length), length, 0, Encoding.UTF_8, codePointLength, TSCodeRange.getValidMultiByte()); + } + + private static int utf8CopyValidRegion(Node node, AbstractTruffleString a, Object arrayA, + InlinedBranchProfile outOfMemoryProfile, boolean isLarge, byte[] buffer, int length, int lastCodePointPos, int lastErrorPos) { + int lengthCPY = lastCodePointPos - lastErrorPos; + if (isLarge && Integer.compareUnsigned(length + lengthCPY + Encodings.CONVERSION_REPLACEMENT_UTF_8.length, buffer.length) > 0) { + outOfMemoryProfile.enter(node); + throw InternalErrors.outOfMemory(); + } + TStringOps.arraycopyWithStride(node, arrayA, a.offset(), 0, lastErrorPos, buffer, 0, 0, length, lengthCPY); + return length + lengthCPY; + } + + @Specialization(guards = "isUTF16(encoding)") + static TruffleString utf16(Node node, AbstractTruffleString a, Object arrayA, @SuppressWarnings("unused") Encoding encoding) { + assert isStride1(a); + int length = a.length(); + byte[] array = TStringOps.arraycopyOfWithStride(node, arrayA, a.offset(), length, 1, length, 1); + int pos = 0; + int codeRange = TSCodeRange.get16Bit(); + int loopCount = 0; + while (true) { + pos = TStringOps.indexOfCodePointWithMaskWithStrideIntl(node, array, 0, length, 1, pos, 0xdfff, 0x7ff); + if (pos >= 0) { + boolean invalid = true; + if (pos != length - 1) { + char c = (char) TStringOps.readFromByteArray(array, 1, pos); + assert Encodings.isUTF16Surrogate(c); + if (!Encodings.isUTF16LowSurrogate(c)) { + assert Encodings.isUTF16HighSurrogate(c); + if (Encodings.isUTF16LowSurrogate((char) TStringOps.readFromByteArray(array, 1, pos + 1))) { + invalid = false; + codeRange = TSCodeRange.getValidMultiByte(); + pos++; + } + } + } + if (invalid) { + TStringOps.writeToByteArray(array, 1, pos, 0xfffd); + } + if (++pos == length) { + break; + } + } else { + break; + } + TStringConstants.truffleSafePointPoll(node, ++loopCount); + } + return TruffleString.createFromByteArray(array, length, 1, Encoding.UTF_16, a.codePointLength(), codeRange); + } + + @Specialization(guards = "isUTF32(encoding)") + static TruffleString utf32(Node node, AbstractTruffleString a, Object arrayA, @SuppressWarnings("unused") Encoding encoding, + @Cached InlinedConditionProfile strideProfile) { + assert isStride2(a); + int length = a.length(); + final byte[] array; + final int stride; + final int codeRange; + if (strideProfile.profile(node, TStringOps.indexOfAnyIntRange(node, arrayA, 0, 2, 0, a.length(), UTF_32_ASTRAL_RANGE) < 0)) { + array = TStringOps.arraycopyOfWithStride(node, arrayA, a.offset(), length, 2, length, 1); + stride = 1; + codeRange = TSCodeRange.get16Bit(); + utf32ReplaceInvalid(node, arrayA, length, array, 1); + } else { + array = TStringOps.arraycopyOfWithStride(node, arrayA, a.offset(), length, 2, length, 2); + stride = 2; + codeRange = TSCodeRange.getValidFixedWidth(); + utf32ReplaceInvalid(node, arrayA, length, array, 2); + } + return TruffleString.createFromByteArray(array, length, stride, Encoding.UTF_32, a.codePointLength(), codeRange); + } + + private static void utf32ReplaceInvalid(Node node, Object arrayA, int length, byte[] array, int stride) { + int pos = 0; + int loopCount = 0; + while (pos < length) { + pos = TStringOps.indexOfAnyIntRange(node, arrayA, 0, 2, pos, length, UTF_32_INVALID_RANGES); + if (pos >= 0) { + TStringOps.writeToByteArray(array, stride, pos++, 0xfffd); + } else { + break; + } + TStringConstants.truffleSafePointPoll(node, ++loopCount); + } + } + @SuppressWarnings("unused") + @Specialization(guards = "isUnsupportedEncoding(encoding)") + static TruffleString unsupported(Node node, AbstractTruffleString a, Object arrayA, Encoding encoding) { + throw InternalErrors.unsupportedOperation(); + } } abstract static class TransCodeNode extends AbstractInternalNode { diff --git a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TStringOps.java b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TStringOps.java index 022b20d2fd54..66fe6ccfe47d 100644 --- a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TStringOps.java +++ b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TStringOps.java @@ -398,7 +398,7 @@ static int indexOfCodePointWithOrMaskWithStride(Node location, AbstractTruffleSt return indexOfCodePointWithMaskWithStrideIntl(location, arrayA, a.offset(), toIndex, strideA, fromIndex, codepoint, maskA); } - private static int indexOfCodePointWithMaskWithStrideIntl(Node location, Object array, int offset, int length, int stride, int fromIndex, int v1, int mask1) { + static int indexOfCodePointWithMaskWithStrideIntl(Node location, Object array, int offset, int length, int stride, int fromIndex, int v1, int mask1) { final boolean isNative = isNativePointer(array); final byte[] stubArray = stubArray(array, isNative); validateRegionIndex(stubArray, offset, length, stride, fromIndex, isNative); diff --git a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TruffleString.java b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TruffleString.java index 64d5ac8c79ad..147893bbfec3 100644 --- a/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TruffleString.java +++ b/truffle/src/com.oracle.truffle.api.strings/src/com/oracle/truffle/api/strings/TruffleString.java @@ -2161,7 +2161,7 @@ final TruffleString doNonEmpty(int[] value, int intOffset, int length, if (length == 0) { return Encoding.UTF_32.getEmpty(); } - if (length == 1 && value[intOffset] <= 0xff) { + if (length == 1 && Integer.compareUnsigned(value[intOffset], 0xff) <= 0) { return TStringConstants.getSingleByte(Encoding.UTF_32, value[intOffset]); } int offsetV = intOffset << 2; @@ -2794,9 +2794,9 @@ public abstract static class IsValidNode extends AbstractPublicNode { @Specialization final boolean isValid(AbstractTruffleString a, Encoding expectedEncoding, - @Cached TStringInternalNodes.GetPreciseCodeRangeNode getPreciseCodeRangeNode) { + @Cached TStringInternalNodes.GetValidOrBrokenCodeRangeNode getCodeRangeNode) { a.checkEncoding(expectedEncoding); - return !isBroken(getPreciseCodeRangeNode.execute(this, a, expectedEncoding)); + return !isBroken(getCodeRangeNode.execute(this, a, expectedEncoding)); } /** @@ -6151,6 +6151,64 @@ public TruffleString asNativeUncached(NativeAllocator allocator, Encoding expect return AsNativeNode.getUncached().execute(this, allocator, expectedEncoding, useCompaction, cacheResult); } + /** + * Node to replace all invalid bytes in a given string, such that the resulting string is + * {@link IsValidNode valid}. See + * {@link #execute(AbstractTruffleString, TruffleString.Encoding)} for details. + * + * @since 23.1 + */ + public abstract static class ToValidStringNode extends AbstractPublicNode { + + ToValidStringNode() { + } + + /** + * Returns a version of string {@code a} that contains only valid codepoints, which may be + * the string itself or a converted version. Invalid byte sequences are replaced with + * {@code '\ufffd'} (for UTF-*) or {@code '?'}. This is useful for string sanitization in + * all uses cases where a string is required to actually be {@link IsValidNode valid}, such + * as libraries that actively reject broken input, network and file system I/O, etc. + * + * @since 23.1 + */ + public abstract TruffleString execute(AbstractTruffleString a, Encoding expectedEncoding); + + @Specialization + final TruffleString toValid(AbstractTruffleString a, Encoding encoding, + @Cached InlinedConditionProfile isValidProfile, + @Cached TStringInternalNodes.GetValidOrBrokenCodeRangeNode getCodeRangeNode, + @Cached InternalAsTruffleStringNode asTruffleStringNode, + @Cached TStringInternalNodes.ToValidStringNode internalNode, + @Cached ToIndexableNode toIndexableNode) { + a.checkEncoding(encoding); + int codeRangeA = getCodeRangeNode.execute(this, a, encoding); + if (isValidProfile.profile(this, !isBroken(codeRangeA))) { + return asTruffleStringNode.execute(this, a, encoding); + } + return internalNode.execute(this, a, toIndexableNode.execute(this, a, a.data()), encoding); + } + + /** + * Create a new {@link SwitchEncodingNode}. + * + * @since 23.1 + */ + @NeverDefault + public static ToValidStringNode create() { + return TruffleStringFactory.ToValidStringNodeGen.create(); + } + + /** + * Get the uncached version of {@link SwitchEncodingNode}. + * + * @since 23.1 + */ + public static ToValidStringNode getUncached() { + return TruffleStringFactory.ToValidStringNodeGen.getUncached(); + } + } + /** * Node to get a given string in a specific encoding. See * {@link #execute(AbstractTruffleString, TruffleString.Encoding)} for details. @@ -6217,7 +6275,7 @@ public static SwitchEncodingNode getUncached() { abstract static class InternalSwitchEncodingNode extends AbstractInternalNode { - public abstract TruffleString execute(Node node, AbstractTruffleString a, Encoding targetEncoding, TranscodingErrorHandler errorHandler); + abstract TruffleString execute(Node node, AbstractTruffleString a, Encoding targetEncoding, TranscodingErrorHandler errorHandler); @Specialization(guards = "a.isCompatibleToIntl(targetEncoding)") static TruffleString compatibleImmutable(TruffleString a, @SuppressWarnings("unused") Encoding targetEncoding, @SuppressWarnings("unused") TranscodingErrorHandler errorHandler) { @@ -6287,7 +6345,6 @@ static TruffleString transCodeMutable(Node node, MutableTruffleString a, Encodin return transCodeNode.execute(node, a, a.data(), codePointLengthA, codeRangeA, targetEncoding, errorHandler); } } - } /** diff --git a/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotContextImpl.java b/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotContextImpl.java index 2a42890ba4bc..2569e7e22982 100644 --- a/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotContextImpl.java +++ b/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotContextImpl.java @@ -906,7 +906,7 @@ Object[] enterThreadChanged(boolean enterReverted, boolean pollSafepoint, boolea initializeThreadLocals(threadInfo); } - prev = threadInfo.enterInternal(engine); + prev = threadInfo.enterInternal(); if (leaveAndEnter) { threadInfo.setLeaveAndEnterInterrupter(null); notifyAll(); @@ -918,7 +918,7 @@ Object[] enterThreadChanged(boolean enterReverted, boolean pollSafepoint, boolea try { threadInfo.notifyEnter(engine, this); } catch (Throwable t) { - threadInfo.leaveInternal(engine, prev); + threadInfo.leaveInternal(prev); throw t; } } @@ -1130,7 +1130,7 @@ void leaveThreadChanged(Object[] prev, boolean entered, boolean finalizeAndDispo * Thread finalization notification is invoked outside of the context lock so that the * guest languages can operate freely without the risk of a deadlock. */ - ex = notifyThreadFinalizing(threadInfo, null); + ex = notifyThreadFinalizing(threadInfo, null, false); } synchronized (this) { if (finalizeAndDispose) { @@ -1145,7 +1145,7 @@ void leaveThreadChanged(Object[] prev, boolean entered, boolean finalizeAndDispo threadInfo.notifyLeave(engine, this); } } finally { - threadInfo.leaveInternal(engine, prev); + threadInfo.leaveInternal(prev); } } if (threadInfo.getEnteredCount() == 0) { @@ -1205,7 +1205,7 @@ private void finishThreadDispose(Thread current, PolyglotThreadInfo info, Throwa } } - private Throwable notifyThreadFinalizing(PolyglotThreadInfo threadInfo, Throwable previousEx) { + private Throwable notifyThreadFinalizing(PolyglotThreadInfo threadInfo, Throwable previousEx, boolean mustSucceed) { Throwable ex = previousEx; Thread thread = threadInfo.getThread(); if (thread == null) { @@ -1258,7 +1258,7 @@ private Throwable notifyThreadFinalizing(PolyglotThreadInfo threadInfo, Throwabl } synchronized (this) { if (finalizedContexts.cardinality() == threadInfo.initializedLanguageContextsCount()) { - threadInfo.setFinalizationComplete(); + threadInfo.setFinalizationComplete(engine, mustSucceed); break; } } @@ -3243,7 +3243,7 @@ private void finalizeContext(boolean notifyInstruments, boolean mustSucceed) { embedderThreads = getSeenThreads().values().stream().filter(threadInfo -> !threadInfo.isPolyglotThread(this)).toList().toArray(new PolyglotThreadInfo[0]); } for (PolyglotThreadInfo threadInfo : embedderThreads) { - ex = notifyThreadFinalizing(threadInfo, ex); + ex = notifyThreadFinalizing(threadInfo, ex, mustSucceed); } if (ex != null) { if (!mustSucceed || isInternalError(ex)) { diff --git a/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotEngineImpl.java b/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotEngineImpl.java index c1ae5dfb6968..c3dd1ec34ac5 100644 --- a/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotEngineImpl.java +++ b/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotEngineImpl.java @@ -2060,7 +2060,7 @@ Object[] enterCached(PolyglotContextImpl context, boolean pollSafepoint) { boolean enterReverted = false; if (CompilerDirectives.injectBranchProbability(CompilerDirectives.LIKELY_PROBABILITY, info.getThread() == Thread.currentThread())) { // Volatile increment is safe if only one thread does it. - prev = info.enterInternal(this); + prev = info.enterInternal(); // Check again whether the cached thread info is still the same as expected if (CompilerDirectives.injectBranchProbability(CompilerDirectives.FASTPATH_PROBABILITY, info == context.getCachedThread())) { @@ -2073,7 +2073,7 @@ Object[] enterCached(PolyglotContextImpl context, boolean pollSafepoint) { try { info.notifyEnter(this, context); } catch (Throwable e) { - info.leaveInternal(this, prev); + info.leaveInternal(prev); throw e; } return prev; @@ -2082,7 +2082,7 @@ Object[] enterCached(PolyglotContextImpl context, boolean pollSafepoint) { * If we go this path and enteredCount drops to 0, the subsequent slowpath enter * must call deactivateThread. */ - info.leaveInternal(this, prev); + info.leaveInternal(prev); enterReverted = true; } } @@ -2133,7 +2133,7 @@ void leaveCached(Object[] prev, PolyglotContextImpl context) { try { info.notifyLeave(this, context); } finally { - info.leaveInternal(this, prev); + info.leaveInternal(prev); entered = false; } if (CompilerDirectives.injectBranchProbability(CompilerDirectives.FASTPATH_PROBABILITY, info == context.getCachedThread())) { diff --git a/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotThreadInfo.java b/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotThreadInfo.java index 564c212596af..1ab839de5bc4 100644 --- a/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotThreadInfo.java +++ b/truffle/src/com.oracle.truffle.polyglot/src/com/oracle/truffle/polyglot/PolyglotThreadInfo.java @@ -46,6 +46,7 @@ import java.util.ArrayList; import java.util.LinkedList; import java.util.List; +import java.util.Objects; import com.oracle.truffle.api.CompilerAsserts; import com.oracle.truffle.api.CompilerDirectives; @@ -171,9 +172,13 @@ boolean isFinalizationComplete() { return finalizationComplete; } - void setFinalizationComplete() { + void setFinalizationComplete(PolyglotEngineImpl engine, boolean mustSucceed) { assert Thread.holdsLock(context); this.finalizationComplete = true; + // Assert only when !mustSucceed, partity might not be met on cancellation. + if (ASSERT_ENTER_RETURN_PARITY && !mustSucceed && engine.probeAssertionsEnabled) { + assertProbeThreadFinalized(); + } } boolean isSafepointActive() { @@ -218,13 +223,10 @@ void setLeaveAndEnterInterrupter(TruffleSafepoint.Interrupter interrupter) { * {@link PolyglotEngineImpl#enter(PolyglotContextImpl, boolean, Node, boolean)} instead. */ @SuppressFBWarnings("VO_VOLATILE_INCREMENT") - Object[] enterInternal(PolyglotEngineImpl engine) { + Object[] enterInternal() { Object[] prev = PolyglotFastThreadLocals.enter(this); assert Thread.currentThread() == getThread() : "Volatile increment is safe on a single thread only."; enteredCount++; - if (ASSERT_ENTER_RETURN_PARITY && engine.probeAssertionsEnabled) { - assertProbeThreadEnter(); - } return prev; } @@ -238,12 +240,9 @@ int getEnteredCount() { * {@link PolyglotEngineImpl#leave(PolyglotContextImpl, PolyglotContextImpl)} instead. */ @SuppressFBWarnings("VO_VOLATILE_INCREMENT") - void leaveInternal(PolyglotEngineImpl engine, Object[] prev) { + void leaveInternal(Object[] prev) { assert Thread.currentThread() == getThread() : "Volatile decrement is safe on a single thread only."; enteredCount--; - if (ASSERT_ENTER_RETURN_PARITY && engine.probeAssertionsEnabled) { - assertProbeThreadLeave(); - } PolyglotFastThreadLocals.leave(prev); } @@ -282,33 +281,38 @@ void notifyLeave(PolyglotEngineImpl engine, PolyglotContextImpl profiledContext) } @TruffleBoundary - private void assertProbeThreadEnter() { + private void assertProbeThreadFinalized() { if (probesEnterList != null) { - probesEnterList.add(null); + assert probesEnterList.isEmpty() : getEnteredProbesMessage(probesEnterList); } } - @TruffleBoundary - private void assertProbeThreadLeave() { - if (probesEnterList != null) { - int size = probesEnterList.size(); - assert size > 0 : "Leave of polyglot thread does not have a preceding enter."; - ProbeNode probe = probesEnterList.remove(size - 1); - assert probe == null : "Found an entered probe without return: " + probe + " with parent node " + probe.getParent().getClass() + "\n" + - "Specifically, a call to ProbeNode.onEnter()/onResume() does not have a corresponding call to ProbeNode.onReturnValue()/onReturnExceptionalOrUnwind()/onYield()."; + private static String getEnteredProbesMessage(List probes) { + StringBuilder sb = new StringBuilder("Found entered probes without return: "); + sb.append(probes); + sb.append("\nSpecifically, a call to ProbeNode.onEnter()/onResume() does not have a corresponding call to ProbeNode.onReturnValue()/onReturnExceptionalOrUnwind()/onYield()."); + for (ProbeNode probe : probes) { + sb.append("\n probe "); + sb.append(probe); + sb.append(" with parent node "); + sb.append(probe.getParent().getClass()); } + sb.append('\n'); + return sb.toString(); } @TruffleBoundary void assertProbeEntered(ProbeNode probe) { + Objects.requireNonNull(probe); probesEnterList.add(probe); } @TruffleBoundary void assertProbeReturned(ProbeNode probe) { - assert !probesEnterList.isEmpty() : "ProbeNode exited without enter"; + assert !probesEnterList.isEmpty() : "ProbeNode " + probe + " with parent " + probe.getParent().getClass() + " exited without enter"; ProbeNode lastProbe = probesEnterList.remove(probesEnterList.size() - 1); - assert probe == lastProbe : "Entered probe " + lastProbe + " differs from the returned probe " + probe + " with parent " + probe.getParent().getClass() + "\n" + + assert probe == lastProbe : "Entered probe " + lastProbe + " with parent " + lastProbe.getParent().getClass() + " differs from the returned probe " + + probe + " with parent " + probe.getParent().getClass() + "\n" + "Specifically, a call to onEnter()/onResume() on " + lastProbe + " was not followed by a call to onReturnValue()/onReturnExceptionalOrUnwind()/onYield() on the same probe, " + "but on " + probe + " instead."; }