From 1958c7e3086ced3deb2092f2a147cc212aee7709 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnther=20Foidl?= Date: Mon, 25 Apr 2022 17:15:54 +0200 Subject: [PATCH] Handle final elements in SpanHelpers.Contains for byte and char vectorized (#67492) * Handle final elements in SpanHelpers.Contains(ref byte, byte, int) vectorized * Handle final elements in SpanHelpers.Contains(ref char, char, int) vectorized * Use equality operator instead of Vector.Zero.Equals due to codegen issue Cf. https://github.com/dotnet/runtime/pull/67492#discussion_r841219532 --- .../src/System/SpanHelpers.Byte.cs | 61 ++++++++------- .../src/System/SpanHelpers.Char.cs | 75 ++++++++++--------- 2 files changed, 75 insertions(+), 61 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs index 5bd50caa89f5e..e8fca14efa52a 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs @@ -341,26 +341,26 @@ public static bool Contains(ref byte searchSpace, byte value, int length) uint uValue = value; // Use uint for comparisons to avoid unnecessary 8->32 extensions nuint offset = 0; // Use nuint for arithmetic to avoid unnecessary 64->32->64 truncations - nuint lengthToExamine = (nuint)(uint)length; + nuint lengthToExamine = (uint)length; if (Vector.IsHardwareAccelerated && length >= Vector.Count * 2) { lengthToExamine = UnalignedCountVector(ref searchSpace); } - SequentialScan: while (lengthToExamine >= 8) { lengthToExamine -= 8; - - if (uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 0) || - uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 1) || - uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 2) || - uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 3) || - uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 4) || - uValue == Unsafe.AddByteOffset(ref 
searchSpace, offset + 5) || - uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 6) || - uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 7)) + ref byte start = ref Unsafe.AddByteOffset(ref searchSpace, offset); + + if (uValue == Unsafe.AddByteOffset(ref start, 0) || + uValue == Unsafe.AddByteOffset(ref start, 1) || + uValue == Unsafe.AddByteOffset(ref start, 2) || + uValue == Unsafe.AddByteOffset(ref start, 3) || + uValue == Unsafe.AddByteOffset(ref start, 4) || + uValue == Unsafe.AddByteOffset(ref start, 5) || + uValue == Unsafe.AddByteOffset(ref start, 6) || + uValue == Unsafe.AddByteOffset(ref start, 7)) { goto Found; } @@ -371,11 +371,12 @@ public static bool Contains(ref byte searchSpace, byte value, int length) if (lengthToExamine >= 4) { lengthToExamine -= 4; + ref byte start = ref Unsafe.AddByteOffset(ref searchSpace, offset); - if (uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 0) || - uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 1) || - uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 2) || - uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 3)) + if (uValue == Unsafe.AddByteOffset(ref start, 0) || + uValue == Unsafe.AddByteOffset(ref start, 1) || + uValue == Unsafe.AddByteOffset(ref start, 2) || + uValue == Unsafe.AddByteOffset(ref start, 3)) { goto Found; } @@ -385,24 +386,25 @@ public static bool Contains(ref byte searchSpace, byte value, int length) while (lengthToExamine > 0) { - lengthToExamine -= 1; + lengthToExamine--; if (uValue == Unsafe.AddByteOffset(ref searchSpace, offset)) goto Found; - offset += 1; + offset++; } - if (Vector.IsHardwareAccelerated && (offset < (nuint)(uint)length)) + if (Vector.IsHardwareAccelerated && (offset < (uint)length)) { - lengthToExamine = (((nuint)(uint)length - offset) & (nuint)~(Vector.Count - 1)); + lengthToExamine = ((uint)length - offset) & (nuint)~(Vector.Count - 1); - Vector values = new Vector(value); + Vector values = new(value); + Vector matches; - while 
(lengthToExamine > offset) + while (offset < lengthToExamine) { - var matches = Vector.Equals(values, LoadVector(ref searchSpace, offset)); - if (Vector.Zero.Equals(matches)) + matches = Vector.Equals(values, LoadVector(ref searchSpace, offset)); + if (matches == Vector.Zero) { offset += (nuint)Vector.Count; continue; @@ -411,10 +413,17 @@ public static bool Contains(ref byte searchSpace, byte value, int length) goto Found; } - if (offset < (nuint)(uint)length) + // The total length is at least Vector.Count, so instead of falling back to a + // sequential scan for the remainder, we check the vector read from the end -- note: unaligned read necessary. + // We do this only if at least one element is left. + if (offset < (uint)length) { - lengthToExamine = ((nuint)(uint)length - offset); - goto SequentialScan; + offset = (uint)(length - Vector.Count); + matches = Vector.Equals(values, LoadVector(ref searchSpace, offset)); + if (matches != Vector.Zero) + { + goto Found; + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs index 7ca768ea789e0..a662fe73e1ca4 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs @@ -388,7 +388,7 @@ public static unsafe int SequenceCompareTo(ref char first, int firstLength, ref while (minLength >= (i + (nuint)(sizeof(nuint) / sizeof(char)))) { - if (Unsafe.ReadUnaligned (ref Unsafe.As(ref Unsafe.Add(ref first, (nint)i))) != + if (Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref first, (nint)i))) != Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref second, (nint)i)))) { break; @@ -428,83 +428,88 @@ public static unsafe bool Contains(ref char searchSpace, char value, int length) fixed (char* pChars = &searchSpace) { - char* pCh = pChars; - char* pEndCh = pCh + length; + nuint offset = 0; // Use nuint for arithmetic to avoid 
unnecessary 64->32->64 truncations + nuint lengthToExamine = (uint)length; if (Vector.IsHardwareAccelerated && length >= Vector.Count * 2) { // Figure out how many characters to read sequentially until we are vector aligned // This is equivalent to: - // unaligned = ((int)pCh % Unsafe.SizeOf>()) / elementsPerByte + // unaligned = ((int)pChars % Unsafe.SizeOf>()) / ElementsPerByte // length = (Vector.Count - unaligned) % Vector.Count - const int elementsPerByte = sizeof(ushort) / sizeof(byte); - int unaligned = ((int)pCh & (Unsafe.SizeOf>() - 1)) / elementsPerByte; - length = (Vector.Count - unaligned) & (Vector.Count - 1); + const int ElementsPerByte = sizeof(ushort) / sizeof(byte); + int unaligned = (int)((uint)((int)pChars & (Unsafe.SizeOf>() - 1)) / ElementsPerByte); + lengthToExamine = (uint)((Vector.Count - unaligned) & (Vector.Count - 1)); } - SequentialScan: - while (length >= 4) + while (lengthToExamine >= 4) { - length -= 4; + lengthToExamine -= 4; + char* pStart = pChars + offset; - if (value == *pCh || - value == *(pCh + 1) || - value == *(pCh + 2) || - value == *(pCh + 3)) + if (value == pStart[0] || + value == pStart[1] || + value == pStart[2] || + value == pStart[3]) { goto Found; } - pCh += 4; + offset += 4; } - while (length > 0) + while (lengthToExamine > 0) { - length--; + lengthToExamine--; - if (value == *pCh) + if (value == pChars[offset]) goto Found; - pCh++; + offset++; } // We get past SequentialScan only if IsHardwareAccelerated is true. However, we still have the redundant check to allow - // the JIT to see that the code is unreachable and eliminate it when the platform does not have hardware accelerated. - if (Vector.IsHardwareAccelerated && pCh < pEndCh) + // the JIT to see that the code is unreachable and eliminate it when the platform does not have hardware acceleration. + if (Vector.IsHardwareAccelerated && (offset < (uint)length)) { // Get the highest multiple of Vector.Count that is within the search space.
// That will be how many times we iterate in the loop below. - // This is equivalent to: length = Vector.Count * ((int)(pEndCh - pCh) / Vector.Count) - length = (int)((pEndCh - pCh) & ~(Vector.Count - 1)); + // This is equivalent to: lengthToExamine = Vector.Count * (((uint)length - offset) / Vector.Count) + lengthToExamine = ((uint)length - offset) & (nuint)~(Vector.Count - 1); - // Get comparison Vector - Vector vComparison = new Vector(value); + Vector values = new(value); + Vector matches; - while (length > 0) + while (offset < lengthToExamine) { // Using Unsafe.Read instead of ReadUnaligned since the search space is pinned and pCh is always vector aligned - Debug.Assert(((int)pCh & (Unsafe.SizeOf>() - 1)) == 0); - Vector vMatches = Vector.Equals(vComparison, Unsafe.Read>(pCh)); - if (Vector.Zero.Equals(vMatches)) + Debug.Assert(((int)(pChars + offset) % Unsafe.SizeOf>()) == 0); + matches = Vector.Equals(values, Unsafe.Read>(pChars + offset)); + if (matches == Vector.Zero) { - pCh += Vector.Count; - length -= Vector.Count; + offset += (nuint)Vector.Count; continue; } goto Found; } - if (pCh < pEndCh) + // The total length is at least Vector.Count, so instead of falling back to a + // sequential scan for the remainder, we check the vector read from the end -- note: unaligned read necessary. + // We do this only if at least one element is left. + if (offset < (uint)length) { - length = (int)(pEndCh - pCh); - goto SequentialScan; + matches = Vector.Equals(values, Unsafe.ReadUnaligned>(pChars + (uint)length - (uint)Vector.Count)); + if (matches != Vector.Zero) + { + goto Found; + } } } return false; - Found: + Found: return true; } }