Skip to content

Commit

Permalink
Handle final elements in SpanHelpers.Contains for byte and char vecto…
Browse files Browse the repository at this point in the history
…rized (#67492)

* Handle final elements in SpanHelpers.Contains(ref byte, byte, int) vectorized

* Handle final elements in SpanHelpers.Contains(ref char, char, int) vectorized

* Use equality operator instead of Vector<T>.Zero.Equals due to codegen issue

Cf. #67492 (comment)
  • Loading branch information
gfoidl authored Apr 25, 2022
1 parent c92e8d5 commit 1958c7e
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 61 deletions.
61 changes: 35 additions & 26 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
Original file line number Diff line number Diff line change
Expand Up @@ -341,26 +341,26 @@ public static bool Contains(ref byte searchSpace, byte value, int length)

uint uValue = value; // Use uint for comparisons to avoid unnecessary 8->32 extensions
nuint offset = 0; // Use nuint for arithmetic to avoid unnecessary 64->32->64 truncations
nuint lengthToExamine = (nuint)(uint)length;
nuint lengthToExamine = (uint)length;

if (Vector.IsHardwareAccelerated && length >= Vector<byte>.Count * 2)
{
lengthToExamine = UnalignedCountVector(ref searchSpace);
}

SequentialScan:
while (lengthToExamine >= 8)
{
lengthToExamine -= 8;

if (uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 0) ||
uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 1) ||
uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 2) ||
uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 3) ||
uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 4) ||
uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 5) ||
uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 6) ||
uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 7))
ref byte start = ref Unsafe.AddByteOffset(ref searchSpace, offset);

if (uValue == Unsafe.AddByteOffset(ref start, 0) ||
uValue == Unsafe.AddByteOffset(ref start, 1) ||
uValue == Unsafe.AddByteOffset(ref start, 2) ||
uValue == Unsafe.AddByteOffset(ref start, 3) ||
uValue == Unsafe.AddByteOffset(ref start, 4) ||
uValue == Unsafe.AddByteOffset(ref start, 5) ||
uValue == Unsafe.AddByteOffset(ref start, 6) ||
uValue == Unsafe.AddByteOffset(ref start, 7))
{
goto Found;
}
Expand All @@ -371,11 +371,12 @@ public static bool Contains(ref byte searchSpace, byte value, int length)
if (lengthToExamine >= 4)
{
lengthToExamine -= 4;
ref byte start = ref Unsafe.AddByteOffset(ref searchSpace, offset);

if (uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 0) ||
uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 1) ||
uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 2) ||
uValue == Unsafe.AddByteOffset(ref searchSpace, offset + 3))
if (uValue == Unsafe.AddByteOffset(ref start, 0) ||
uValue == Unsafe.AddByteOffset(ref start, 1) ||
uValue == Unsafe.AddByteOffset(ref start, 2) ||
uValue == Unsafe.AddByteOffset(ref start, 3))
{
goto Found;
}
Expand All @@ -385,24 +386,25 @@ public static bool Contains(ref byte searchSpace, byte value, int length)

while (lengthToExamine > 0)
{
lengthToExamine -= 1;
lengthToExamine--;

if (uValue == Unsafe.AddByteOffset(ref searchSpace, offset))
goto Found;

offset += 1;
offset++;
}

if (Vector.IsHardwareAccelerated && (offset < (nuint)(uint)length))
if (Vector.IsHardwareAccelerated && (offset < (uint)length))
{
lengthToExamine = (((nuint)(uint)length - offset) & (nuint)~(Vector<byte>.Count - 1));
lengthToExamine = ((uint)length - offset) & (nuint)~(Vector<byte>.Count - 1);

Vector<byte> values = new Vector<byte>(value);
Vector<byte> values = new(value);
Vector<byte> matches;

while (lengthToExamine > offset)
while (offset < lengthToExamine)
{
var matches = Vector.Equals(values, LoadVector(ref searchSpace, offset));
if (Vector<byte>.Zero.Equals(matches))
matches = Vector.Equals(values, LoadVector(ref searchSpace, offset));
if (matches == Vector<byte>.Zero)
{
offset += (nuint)Vector<byte>.Count;
continue;
Expand All @@ -411,10 +413,17 @@ public static bool Contains(ref byte searchSpace, byte value, int length)
goto Found;
}

if (offset < (nuint)(uint)length)
// The total length is at least Vector<byte>.Count, so instead of falling back to a
// sequential scan for the remainder, we check the vector read from the end -- note: unaligned read necessary.
// We do this only if at least one element is left.
if (offset < (uint)length)
{
lengthToExamine = ((nuint)(uint)length - offset);
goto SequentialScan;
offset = (uint)(length - Vector<byte>.Count);
matches = Vector.Equals(values, LoadVector(ref searchSpace, offset));
if (matches != Vector<byte>.Zero)
{
goto Found;
}
}
}

Expand Down
75 changes: 40 additions & 35 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ public static unsafe int SequenceCompareTo(ref char first, int firstLength, ref

while (minLength >= (i + (nuint)(sizeof(nuint) / sizeof(char))))
{
if (Unsafe.ReadUnaligned<nuint> (ref Unsafe.As<char, byte>(ref Unsafe.Add(ref first, (nint)i))) !=
if (Unsafe.ReadUnaligned<nuint>(ref Unsafe.As<char, byte>(ref Unsafe.Add(ref first, (nint)i))) !=
Unsafe.ReadUnaligned<nuint>(ref Unsafe.As<char, byte>(ref Unsafe.Add(ref second, (nint)i))))
{
break;
Expand Down Expand Up @@ -428,83 +428,88 @@ public static unsafe bool Contains(ref char searchSpace, char value, int length)

fixed (char* pChars = &searchSpace)
{
char* pCh = pChars;
char* pEndCh = pCh + length;
nuint offset = 0; // Use nuint for arithmetic to avoid unnecessary 64->32->64 truncations
nuint lengthToExamine = (uint)length;

if (Vector.IsHardwareAccelerated && length >= Vector<ushort>.Count * 2)
{
// Figure out how many characters to read sequentially until we are vector aligned
// This is equivalent to:
// unaligned = ((int)pCh % Unsafe.SizeOf<Vector<ushort>>()) / elementsPerByte
// unaligned = ((int)pCh % Unsafe.SizeOf<Vector<ushort>>()) / ElementsPerByte
// length = (Vector<ushort>.Count - unaligned) % Vector<ushort>.Count
const int elementsPerByte = sizeof(ushort) / sizeof(byte);
int unaligned = ((int)pCh & (Unsafe.SizeOf<Vector<ushort>>() - 1)) / elementsPerByte;
length = (Vector<ushort>.Count - unaligned) & (Vector<ushort>.Count - 1);
const int ElementsPerByte = sizeof(ushort) / sizeof(byte);
int unaligned = (int)((uint)((int)pChars & (Unsafe.SizeOf<Vector<ushort>>() - 1)) / ElementsPerByte);
lengthToExamine = (uint)((Vector<ushort>.Count - unaligned) & (Vector<ushort>.Count - 1));
}

SequentialScan:
while (length >= 4)
while (lengthToExamine >= 4)
{
length -= 4;
lengthToExamine -= 4;
char* pStart = pChars + offset;

if (value == *pCh ||
value == *(pCh + 1) ||
value == *(pCh + 2) ||
value == *(pCh + 3))
if (value == pStart[0] ||
value == pStart[1] ||
value == pStart[2] ||
value == pStart[3])
{
goto Found;
}

pCh += 4;
offset += 4;
}

while (length > 0)
while (lengthToExamine > 0)
{
length--;
lengthToExamine--;

if (value == *pCh)
if (value == pChars[offset])
goto Found;

pCh++;
offset++;
}

// We get past SequentialScan only if IsHardwareAccelerated is true. However, we still have the redundant check to allow
// the JIT to see that the code is unreachable and eliminate it when the platform does not have hardware accelerated.
if (Vector.IsHardwareAccelerated && pCh < pEndCh)
// the JIT to see that the code is unreachable and eliminate it when the platform does not have hardware acceleration.
if (Vector.IsHardwareAccelerated && (offset < (uint)length))
{
// Get the highest multiple of Vector<ushort>.Count that is within the search space.
// That will be how many times we iterate in the loop below.
// This is equivalent to: length = Vector<ushort>.Count * ((int)(pEndCh - pCh) / Vector<ushort>.Count)
length = (int)((pEndCh - pCh) & ~(Vector<ushort>.Count - 1));
// This is equivalent to: lengthToExamine = Vector<ushort>.Count + ((uint)length - offset) / Vector<ushort>.Count)
lengthToExamine = ((uint)length - offset) & (nuint)~(Vector<ushort>.Count - 1);

// Get comparison Vector
Vector<ushort> vComparison = new Vector<ushort>(value);
Vector<ushort> values = new(value);
Vector<ushort> matches;

while (length > 0)
while (offset < lengthToExamine)
{
// Using Unsafe.Read instead of ReadUnaligned since the search space is pinned and pCh is always vector aligned
Debug.Assert(((int)pCh & (Unsafe.SizeOf<Vector<ushort>>() - 1)) == 0);
Vector<ushort> vMatches = Vector.Equals(vComparison, Unsafe.Read<Vector<ushort>>(pCh));
if (Vector<ushort>.Zero.Equals(vMatches))
Debug.Assert(((int)(pChars + offset) % Unsafe.SizeOf<Vector<ushort>>()) == 0);
matches = Vector.Equals(values, Unsafe.Read<Vector<ushort>>(pChars + offset));
if (matches == Vector<ushort>.Zero)
{
pCh += Vector<ushort>.Count;
length -= Vector<ushort>.Count;
offset += (nuint)Vector<ushort>.Count;
continue;
}

goto Found;
}

if (pCh < pEndCh)
// The total length is at least Vector<ushort>.Count, so instead of falling back to a
// sequential scan for the remainder, we check the vector read from the end -- note: unaligned read necessary.
// We do this only if at least one element is left.
if (offset < (uint)length)
{
length = (int)(pEndCh - pCh);
goto SequentialScan;
matches = Vector.Equals(values, Unsafe.ReadUnaligned<Vector<ushort>>(pChars + (uint)length - (uint)Vector<ushort>.Count));
if (matches != Vector<ushort>.Zero)
{
goto Found;
}
}
}

return false;

Found:
Found:
return true;
}
}
Expand Down

0 comments on commit 1958c7e

Please sign in to comment.