Skip to content

Commit

Permalink
JIT: Support subtraction in scalar evolution (#99154)
Browse files Browse the repository at this point in the history
Represent `x - y` as `x + y * -1` (this works even for MinValue).

For example:
```csharp
public static int Foo(int[] arr, int k)
{
    int sum = 0;
    for (int i = arr.Length - 1; i >= 0; i -= k)
    {
        sum += arr[i];
    }

    return sum;
}
```
analyzes to
```
STMT00007 ( ??? ... ??? )
N004 (  0,  0) [000044] DA---------                         ▌  STORE_LCL_VAR int    V03 loc1         d:3 $VN.Void
N003 (  0,  0) [000043] -----------                         └──▌  PHI       int    $241
N001 (  0,  0) [000053] ----------- pred BB03                  ├──▌  PHI_ARG   int    V03 loc1         u:4
N002 (  0,  0) [000051] ----------- pred BB02                  └──▌  PHI_ARG   int    V03 loc1         u:2 $201
  => <L00, V03.2, (V01.1 * -1)>
```

```csharp
public static int Bar(int n)
{
    int sum = n * n;
    for (int i = 0; i < n; i++)
    {
        sum -= i;
    }

    return sum;
}
```

analyzes to
```
N004 (  0,  0) [000029] DA---------                         ▌  STORE_LCL_VAR int    V01 loc0         d:3 $VN.Void
N003 (  0,  0) [000028] -----------                         └──▌  PHI       int    $140
N001 (  0,  0) [000033] ----------- pred BB03                  ├──▌  PHI_ARG   int    V01 loc0         u:4
N002 (  0,  0) [000031] ----------- pred BB02                  └──▌  PHI_ARG   int    V01 loc0         u:2 $100
  => <L00, V01.2, <L00, (V02.2 (0) * -1), -1>>

```
for `sum`. It would be `<L00, V01.2, <L00, 0, -1>>` if we resolved SSA
defs during simplification -- I'll make that change in a future PR.

Also add some more documentation around the symbolic way we resolve
addrecs.
  • Loading branch information
jakobbotsch authored Mar 4, 2024
1 parent 4229d61 commit 77c5dba
Showing 1 changed file with 48 additions and 1 deletion.
49 changes: 48 additions & 1 deletion src/coreclr/jit/scev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,12 +469,53 @@ Scev* ScalarEvolutionContext::AnalyzeNew(BasicBlock* block, GenTree* tree, int d

assert(ssaDsc->GetBlock() != nullptr);

// Try simple but most common case first, where we have a direct
// add recurrence like i = i + 1.
Scev* simpleAddRec = CreateSimpleAddRec(store, enterScev, ssaDsc->GetBlock(), ssaDsc->GetDefNode()->Data());
if (simpleAddRec != nullptr)
{
return simpleAddRec;
}

// Otherwise try a more powerful approach; we create a symbolic
// node representing the recurrence and then invoke the analysis
// recursively. This handles for example cases like
//
// int i = start;
// while (i < n)
// {
// int j = i + 1;
// ...
// i = j;
// }
// => <L, start, 1>
//
// where we need to follow SSA defs. In this case the analysis will result in
// <symbolic node> + 1. The symbolic node represents a recurrence,
// so this corresponds to the infinite sequence [start, start + 1,
// start + 1 + 1, ...] which can be represented by <L, start, 1>.
//
// This approach also generalizes to handle chains of recurrences.
// For example:
//
// int i = 0;
// int j = 0;
// while (i < n)
// {
// j++;
// i += j;
// }
// => <L, 0, <L, 1, 1>>
//
// Here `i` will analyze to <symbolic node> + <L, [initial value of j], 1>.
// Like before this corresponds to an infinite sequence
// [start, start + <L, [initial value of j], 1>, start + 2 * <L, [initial value of j], 1>, ...]
// which again can be represented as <L, start, <L, [initial value of j], 1>>.
//
// More generally, as long as we have only additions and only a
// single operand is the recurrence, we can represent it as an add
// recurrence. See MakeAddRecFromRecursiveScev for the details.
//
ScevConstant* symbolicAddRec = NewConstant(data->TypeGet(), 0xdeadbeef);
m_ephemeralCache.Emplace(store, symbolicAddRec);

Expand Down Expand Up @@ -515,6 +556,7 @@ Scev* ScalarEvolutionContext::AnalyzeNew(BasicBlock* block, GenTree* tree, int d
return NewExtension(cast->IsUnsigned() ? ScevOper::ZeroExtend : ScevOper::SignExtend, TYP_LONG, op);
}
case GT_ADD:
case GT_SUB:
case GT_MUL:
case GT_LSH:
{
Expand All @@ -532,6 +574,10 @@ Scev* ScalarEvolutionContext::AnalyzeNew(BasicBlock* block, GenTree* tree, int d
case GT_ADD:
oper = ScevOper::Add;
break;
case GT_SUB:
oper = ScevOper::Add;
op2 = NewBinop(ScevOper::Mul, op2, NewConstant(op2->Type, -1));
break;
case GT_MUL:
oper = ScevOper::Mul;
break;
Expand Down Expand Up @@ -651,7 +697,8 @@ void ScalarEvolutionContext::ExtractAddOperands(ScevBinop* binop, ArrayStack<Sce
// recursiveScev - A symbolic node whose appearance represents the value of "scev"
//
// Returns:
// A non-recursive addrec
// A non-recursive addrec, or nullptr if the recursive SCEV is not
// representable as an add recurrence.
//
Scev* ScalarEvolutionContext::MakeAddRecFromRecursiveScev(Scev* startScev, Scev* scev, Scev* recursiveScev)
{
Expand Down

0 comments on commit 77c5dba

Please sign in to comment.