Skip to content

Commit

Permalink
Fix #1772: Support EnumeratorCancellationAttribute
Browse files Browse the repository at this point in the history
With this change, the async decompiler no longer gets confused by the logic disposing `this.<>x__combinedTokens`.
  • Loading branch information
dgrunwald committed Jun 28, 2020
1 parent c0b1119 commit d8e837e
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 58 deletions.
9 changes: 9 additions & 0 deletions ICSharpCode.Decompiler.Tests/TestCases/Pretty/AsyncStreams.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace ICSharpCode.Decompiler.Tests.TestCases.Pretty
Expand Down Expand Up @@ -47,6 +49,13 @@ public async IAsyncEnumerable<int> AwaitInFinally()
Console.WriteLine("end finally");
}
}

public static async IAsyncEnumerable<int> SimpleCancellation([EnumeratorCancellation] CancellationToken cancellationToken)
{
yield return 1;
await Task.Delay(100, cancellationToken);
yield return 2;
}
}

public struct TestStruct
Expand Down
174 changes: 129 additions & 45 deletions ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ enum AsyncMethodType
ILFunction moveNextFunction;
ILVariable cachedStateVar; // variable in MoveNext that caches the stateField.
TryCatch mainTryCatch;
Block setResultAndExitBlock; // block that is jumped to for return statements
// Note: for async enumerators, a jump to setResultAndExitBlock is a 'yield break;'
Block setResultReturnBlock; // block that is jumped to for return statements
// Note: for async enumerators, a jump to setResultReturnBlock is a 'yield break;'
int finalState; // final state after the setResultAndExitBlock
bool finalStateKnown;
ILVariable resultVar; // the variable that gets returned by the setResultAndExitBlock
Expand Down Expand Up @@ -351,8 +351,7 @@ call Start(ldloca V_1, ldloca V_0)
static bool MatchCall(ILInstruction inst, string name, out InstructionCollection<ILInstruction> args)
{
if (inst is CallInstruction call && (call.OpCode == OpCode.Call || call.OpCode == OpCode.CallVirt)
&& call.Method.Name == name && !call.Method.IsStatic)
{
&& call.Method.Name == name && !call.Method.IsStatic) {
args = call.Arguments;
return args.Count > 0;
}
Expand Down Expand Up @@ -435,7 +434,7 @@ private bool MatchAsyncEnumeratorCreationPattern(ILFunction function)
return false;
if (!returnValue.MatchLdLoc(v))
return false;

// HACK: the normal async/await logic expects 'initialState' to be the 'in progress' state
initialState = -1;
try {
Expand Down Expand Up @@ -548,6 +547,7 @@ void AnalyzeMoveNext()
throw new SymbolicAnalysisFailedException();
if (blockContainer.EntryPoint.IncomingEdgeCount != 1)
throw new SymbolicAnalysisFailedException();
bool[] blocksAnalyzed = new bool[blockContainer.Blocks.Count];
cachedStateVar = null;
int pos = 0;
while (blockContainer.EntryPoint.Instructions[pos] is StLoc stloc) {
Expand Down Expand Up @@ -581,18 +581,19 @@ void AnalyzeMoveNext()
}

Debug.Assert(blockContainer.Blocks[0] == blockContainer.EntryPoint); // already checked this block
blocksAnalyzed[0] = true;
pos = 1;
if (MatchYieldBlock(blockContainer, pos)) {
setResultYieldBlock = blockContainer.Blocks[pos];
blocksAnalyzed[pos] = true;
pos++;
} else {
setResultYieldBlock = null;
}

setResultAndExitBlock = blockContainer.Blocks.ElementAtOrDefault(pos);
CheckSetResultAndExitBlock(blockContainer);
setResultReturnBlock = CheckSetResultReturnBlock(blockContainer, pos, blocksAnalyzed);

if (pos + 1 < blockContainer.Blocks.Count)
if (!blocksAnalyzed.All(b => b))
throw new SymbolicAnalysisFailedException("too many blocks");
}

Expand All @@ -619,27 +620,113 @@ bool MatchYieldBlock(BlockContainer blockContainer, int pos)
return false;
return block.Instructions[1].MatchLeave(blockContainer);
}
void CheckSetResultAndExitBlock(BlockContainer blockContainer)

private Block CheckSetResultReturnBlock(BlockContainer blockContainer, int setResultReturnBlockIndex, bool[] blocksAnalyzed)
{
if (setResultAndExitBlock == null) {
if (setResultReturnBlockIndex >= blockContainer.Blocks.Count) {
// This block can be absent if the function never exits normally,
// but always throws an exception/loops infinitely.
resultVar = null;
finalStateKnown = false; // final state will be detected in ValidateCatchBlock() instead
return;
return null;
}

var block = blockContainer.Blocks[setResultReturnBlockIndex];
// stfld <>1__state(ldloc this, ldc.i4 -2)
// [optional] stfld <>u__N(ldloc this, ldnull)
// call SetResult(ldflda <>t__builder(ldloc this), ldloc result)
// [optional] call Complete(ldflda <>t__builder(ldloc this))
// leave IL_0000
int pos = 0;
if (!MatchStateAssignment(setResultAndExitBlock.Instructions[pos], out finalState))
if (!MatchStateAssignment(block.Instructions[pos], out finalState))
throw new SymbolicAnalysisFailedException();
finalStateKnown = true;
pos++;
MatchHoistedLocalCleanup(setResultAndExitBlock, ref pos);
if (!MatchCall(setResultAndExitBlock.Instructions[pos], "SetResult", out var args))
if (pos + 2 == block.Instructions.Count && block.MatchIfAtEndOfBlock(out var condition, out var trueInst, out var falseInst)) {
if (MatchDisposeCombinedTokens(blockContainer, condition, trueInst, falseInst, blocksAnalyzed, out var setResultAndExitBlock)) {
blocksAnalyzed[block.ChildIndex] = true;
block = setResultAndExitBlock;
pos = 0;
} else {
throw new SymbolicAnalysisFailedException();
}
}

// [optional] stfld <>u__N(ldloc this, ldnull)
MatchHoistedLocalCleanup(block, ref pos);
CheckSetResultAndExit(blockContainer, block, ref pos);
blocksAnalyzed[block.ChildIndex] = true;
return blockContainer.Blocks[setResultReturnBlockIndex];
}

private bool MatchDisposeCombinedTokens(BlockContainer blockContainer, ILInstruction condition, ILInstruction trueInst, ILInstruction falseInst, bool[] blocksAnalyzed, out Block setResultAndExitBlock)
{
setResultAndExitBlock = null;
// ...
// if (comp.o(ldfld <>x__combinedTokens(ldloc this) == ldnull)) br setResultAndExit
// br disposeCombinedTokens
// }
//
// Block disposeCombinedTokens (incoming: 1) {
// callvirt Dispose(ldfld <>x__combinedTokens(ldloc this))
// stfld <>x__combinedTokens(ldloc this, ldnull)
// br setResultAndExit
// }
if (!condition.MatchCompEqualsNull(out var testedInst))
return false;
if (!testedInst.MatchLdFld(out var target, out var ctsField))
return false;
if (!target.MatchLdThis())
return false;
if (!(ctsField.Type is ITypeDefinition { FullTypeName: { IsNested: false, TopLevelTypeName: { Name: "CancellationTokenSource", Namespace: "System.Threading" } } }))
return false;
if (!(trueInst.MatchBranch(out setResultAndExitBlock) && falseInst.MatchBranch(out var disposedCombinedTokensBlock)))
return false;
if (!(setResultAndExitBlock.Parent == blockContainer && disposedCombinedTokensBlock.Parent == blockContainer))
return false;

var block = disposedCombinedTokensBlock;
int pos = 0;
// callvirt Dispose(ldfld <>x__combinedTokens(ldloc this))
if (!(block.Instructions[pos] is CallVirt { Method: { Name: "Dispose" } } disposeCall))
return false;
if (disposeCall.Arguments.Count != 1)
return false;
if (!disposeCall.Arguments[0].MatchLdFld(out var target2, out var ctsField2))
return false;
if (!(target2.MatchLdThis() && ctsField2.Equals(ctsField)))
return false;
pos++;
// stfld <>x__combinedTokens(ldloc this, ldnull)
if (!block.Instructions[pos].MatchStFld(out var target3, out var ctsField3, out var storedValue))
return false;
if (!(target3.MatchLdThis() && ctsField3.Equals(ctsField)))
return false;
if (!storedValue.MatchLdNull())
return false;
pos++;
// br setResultAndExit
if (!block.Instructions[pos].MatchBranch(setResultAndExitBlock))
return false;
blocksAnalyzed[block.ChildIndex] = true;

return true;
}

private void MatchHoistedLocalCleanup(Block block, ref int pos)
{
while (block.Instructions[pos].MatchStFld(out var target, out _, out var value)) {
// https://github.com/dotnet/roslyn/pull/39735 hoisted local cleanup
if (!target.MatchLdThis())
throw new SymbolicAnalysisFailedException();
if (!(value.MatchLdNull() || value is DefaultValue))
throw new SymbolicAnalysisFailedException();
pos++;
}
}

void CheckSetResultAndExit(BlockContainer blockContainer, Block block, ref int pos)
{
// call SetResult(ldflda <>t__builder(ldloc this), ldloc result)
// [optional] call Complete(ldflda <>t__builder(ldloc this))
// leave IL_0000
if (!MatchCall(block.Instructions[pos], "SetResult", out var args))
throw new SymbolicAnalysisFailedException();
if (!IsBuilderOrPromiseFieldOnThis(args[0]))
throw new SymbolicAnalysisFailedException();
Expand All @@ -666,27 +753,15 @@ void CheckSetResultAndExitBlock(BlockContainer blockContainer)
break;
}
pos++;
if (MatchCall(setResultAndExitBlock.Instructions[pos], "Complete", out args)) {
if (MatchCall(block.Instructions[pos], "Complete", out args)) {
if (!(args.Count == 1 && IsBuilderFieldOnThis(args[0])))
throw new SymbolicAnalysisFailedException();
pos++;
}
if (!setResultAndExitBlock.Instructions[pos].MatchLeave(blockContainer))
if (!block.Instructions[pos].MatchLeave(blockContainer))
throw new SymbolicAnalysisFailedException();
}

private void MatchHoistedLocalCleanup(Block block, ref int pos)
{
while (block.Instructions[pos].MatchStFld(out var target, out _, out var value)) {
// https://github.com/dotnet/roslyn/pull/39735 hoisted local cleanup
if (!target.MatchLdThis())
throw new SymbolicAnalysisFailedException();
if (!(value.MatchLdNull() || value is DefaultValue))
throw new SymbolicAnalysisFailedException();
pos++;
}
}

void ValidateCatchBlock()
{
// catch E_143 : System.Exception if (ldc.i4 1) BlockContainer {
Expand All @@ -706,9 +781,10 @@ void ValidateCatchBlock()
throw new SymbolicAnalysisFailedException();
if (!handler.Filter.MatchLdcI4(1))
throw new SymbolicAnalysisFailedException();
var catchBlock = YieldReturnDecompiler.SingleBlock(handler.Body);
if (catchBlock == null)
if (!(handler.Body is BlockContainer handlerContainer))
throw new SymbolicAnalysisFailedException();
bool[] blocksAnalyzed = new bool[handlerContainer.Blocks.Count];
var catchBlock = handlerContainer.EntryPoint;
catchHandlerOffset = catchBlock.StartILOffset;
// stloc exception(ldloc E_143)
if (!(catchBlock.Instructions[0] is StLoc stloc))
Expand All @@ -726,6 +802,15 @@ void ValidateCatchBlock()
finalStateKnown = true;
}
int pos = 2;
if (pos + 2 == catchBlock.Instructions.Count && catchBlock.MatchIfAtEndOfBlock(out var condition, out var trueInst, out var falseInst)) {
if (MatchDisposeCombinedTokens(handlerContainer, condition, trueInst, falseInst, blocksAnalyzed, out var setResultAndExitBlock)) {
blocksAnalyzed[catchBlock.ChildIndex] = true;
catchBlock = setResultAndExitBlock;
pos = 0;
} else {
throw new SymbolicAnalysisFailedException();
}
}
MatchHoistedLocalCleanup(catchBlock, ref pos);
// call SetException(ldfld <>t__builder(ldloc this), ldloc exception)
if (!MatchCall(catchBlock.Instructions[pos], "SetException", out var args))
Expand All @@ -748,6 +833,9 @@ void ValidateCatchBlock()
// leave IL_0000
if (!catchBlock.Instructions[pos].MatchLeave((BlockContainer)moveNextFunction.Body))
throw new SymbolicAnalysisFailedException();
blocksAnalyzed[catchBlock.ChildIndex] = true;
if (!blocksAnalyzed.All(b => b))
throw new SymbolicAnalysisFailedException();
}

bool IsBuilderFieldOnThis(ILInstruction inst)
Expand Down Expand Up @@ -781,8 +869,7 @@ bool MatchStateAssignment(ILInstruction inst, out int newState)
if (inst.MatchStFld(out var target, out var field, out var value)
&& StackSlotValue(target).MatchLdThis()
&& field.MemberDefinition == stateField
&& StackSlotValue(value).MatchLdcI4(out newState))
{
&& StackSlotValue(value).MatchLdcI4(out newState)) {
return true;
}
newState = 0;
Expand Down Expand Up @@ -830,7 +917,7 @@ void InlineBodyOfMoveNext(ILFunction function)
moveNextFunction.Variables.Clear();
moveNextFunction.ReleaseRef();
foreach (var branch in function.Descendants.OfType<Branch>()) {
if (branch.TargetBlock == setResultAndExitBlock) {
if (branch.TargetBlock == setResultReturnBlock) {
branch.ReplaceWith(new Leave((BlockContainer)function.Body, resultVar == null ? null : new LdLoc(resultVar)).WithILRange(branch));
}
}
Expand Down Expand Up @@ -864,8 +951,7 @@ void FinalizeInlineMoveNext(ILFunction function)
for (int i = block.Instructions.Count - 1; i >= 0; i--) {
if (block.Instructions[i].MatchStLoc(out var v, out var value)
&& v.IsSingleDefinition && v.LoadCount == 0
&& value.MatchLdLoc(cachedStateVar))
{
&& value.MatchLdLoc(cachedStateVar)) {
block.Instructions.RemoveAt(i);
}
}
Expand Down Expand Up @@ -1118,8 +1204,7 @@ bool AnalyzeAwaitBlock(Block block, out ILVariable awaiter, out IField awaiterFi
while (pos > 0 && block.Instructions[pos - 1] is StLoc stloc2
&& stloc2.Variable.IsSingleDefinition && stloc2.Variable.LoadCount == 0
&& stloc2.Variable.Kind == VariableKind.StackSlot
&& SemanticHelper.IsPure(stloc2.Value.Flags))
{
&& SemanticHelper.IsPure(stloc2.Value.Flags)) {
pos--;
}
block.Instructions.RemoveRange(pos, block.Instructions.Count - pos);
Expand Down Expand Up @@ -1332,8 +1417,7 @@ bool CheckResumeBlock(Block block, ILVariable awaiterVar, IField awaiterField, B
if (block.Instructions[pos].MatchStFld(out target, out field, out value)
&& target.MatchLdThis()
&& field.Equals(awaiterField)
&& (value.OpCode == OpCode.DefaultValue || value.OpCode == OpCode.LdNull))
{
&& (value.OpCode == OpCode.DefaultValue || value.OpCode == OpCode.LdNull)) {
pos++;
} else {
// {stloc V_6(default.value System.Runtime.CompilerServices.TaskAwaiter)}
Expand All @@ -1355,7 +1439,7 @@ bool CheckResumeBlock(Block block, ILVariable awaiterVar, IField awaiterField, B
pos++;
}
if (block.Instructions[pos] is StLoc stlocCachedState) {
if (stlocCachedState.Variable.Kind == VariableKind.Local && stlocCachedState.Variable.Index == cachedStateVar?.Index) {
if (stlocCachedState.Variable.Kind == VariableKind.Local && stlocCachedState.Variable.Index == cachedStateVar?.Index) {
if (stlocCachedState.Value.MatchLdLoc(m1Var) || stlocCachedState.Value.MatchLdcI4(initialState))
pos++;
}
Expand All @@ -1371,7 +1455,7 @@ bool CheckResumeBlock(Block block, ILVariable awaiterVar, IField awaiterField, B
} else {
return false;
}

return block.Instructions[pos].MatchBranch(completedBlock);
}

Expand Down
4 changes: 1 addition & 3 deletions ICSharpCode.Decompiler/IL/Instructions/Block.cs
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,7 @@ public bool MatchIfAtEndOfBlock(out ILInstruction condition, out ILInstruction t
falseInst = Instructions.Last();
while (condition.MatchLogicNot(out var arg)) {
condition = arg;
ILInstruction tmp = trueInst;
trueInst = falseInst;
falseInst = tmp;
(trueInst, falseInst) = (falseInst, trueInst);
}
return true;
}
Expand Down
Loading

0 comments on commit d8e837e

Please sign in to comment.