diff --git a/src/Pure.DI.Core/Core/IVariator.cs b/src/Pure.DI.Core/Core/IVariator.cs index 6d414284..b9ef667b 100644 --- a/src/Pure.DI.Core/Core/IVariator.cs +++ b/src/Pure.DI.Core/Core/IVariator.cs @@ -2,9 +2,9 @@ namespace Pure.DI.Core; // ReSharper disable once IdentifierTypo internal interface IVariator + where T: class { bool TryGetNextVariants( IEnumerable> variations, - Predicate hasVariantsPredicate, [NotNullWhen(true)] out IReadOnlyCollection? variants); } \ No newline at end of file diff --git a/src/Pure.DI.Core/Core/ImplementationVariantsBuilder.cs b/src/Pure.DI.Core/Core/ImplementationVariantsBuilder.cs index c1af216f..d98d57bc 100644 --- a/src/Pure.DI.Core/Core/ImplementationVariantsBuilder.cs +++ b/src/Pure.DI.Core/Core/ImplementationVariantsBuilder.cs @@ -14,12 +14,12 @@ public IEnumerable Build(DpImplementation implementation) var variants = implementation.Methods.Select(method => CreateVariants(method, ImplementationVariantKind.Method)) .Concat(Enumerable.Repeat(CreateVariants(implementation.Constructor, ImplementationVariantKind.Ctor), 1)) - .Select(i => i.GetEnumerator()) + .Select(i => new SafeEnumerator(i.GetEnumerator())) .ToArray(); try { - while (implementationVariator.TryGetNextVariants(variants, _ => true, out var curVariants)) + while (implementationVariator.TryGetNextVariants(variants, out var curVariants)) { cancellationToken.ThrowIfCancellationRequested(); yield return curVariants.Aggregate( diff --git a/src/Pure.DI.Core/Core/Models/ImplementationVariant.cs b/src/Pure.DI.Core/Core/Models/ImplementationVariant.cs index 07bee662..b56bfc29 100644 --- a/src/Pure.DI.Core/Core/Models/ImplementationVariant.cs +++ b/src/Pure.DI.Core/Core/Models/ImplementationVariant.cs @@ -1,3 +1,3 @@ namespace Pure.DI.Core.Models; -internal readonly record struct ImplementationVariant(ImplementationVariantKind Kind, DpMethod Method); \ No newline at end of file +internal record ImplementationVariant(ImplementationVariantKind Kind, DpMethod Method); \ No newline at end of file diff --git a/src/Pure.DI.Core/Core/ProcessingNode.cs b/src/Pure.DI.Core/Core/ProcessingNode.cs index d5bea946..402747ef 100644 --- a/src/Pure.DI.Core/Core/ProcessingNode.cs +++ b/src/Pure.DI.Core/Core/ProcessingNode.cs @@ -2,9 +2,8 @@ namespace Pure.DI.Core; -internal readonly struct ProcessingNode : IEquatable +internal class ProcessingNode : IEquatable { - public readonly bool HasNode = false; public readonly DependencyNode Node; private readonly Lazy _isMarkerBased; private readonly Lazy> _injections; @@ -15,7 +14,6 @@ public ProcessingNode( ISet contracts, IMarker marker) { - HasNode = true; Node = node; Contracts = contracts; diff --git a/src/Pure.DI.Core/Core/SafeEnumerator.cs b/src/Pure.DI.Core/Core/SafeEnumerator.cs new file mode 100644 index 00000000..41a13faa --- /dev/null +++ b/src/Pure.DI.Core/Core/SafeEnumerator.cs @@ -0,0 +1,35 @@ +#pragma warning disable CS8766 // Nullability of reference types in return type doesn't match implicitly implemented member (possibly because of nullability attributes). +namespace Pure.DI.Core; + +internal class SafeEnumerator(IEnumerator source): IEnumerator + where T: class +{ + private T? _current; + private bool _result; + + public T? Current + { + get + { + if (!_result) + { + return _current; + } + + _current = source.Current; + return _current; + } + } + + object? IEnumerator.Current => Current; + + public bool MoveNext() + { + _result = source.MoveNext(); + return _result; + } + + public void Reset() => source.Reset(); + + public void Dispose() => source.Dispose(); +} \ No newline at end of file diff --git a/src/Pure.DI.Core/Core/VariationalDependencyGraphBuilder.cs b/src/Pure.DI.Core/Core/VariationalDependencyGraphBuilder.cs index 42dca9d6..bc80eceb 100644 --- a/src/Pure.DI.Core/Core/VariationalDependencyGraphBuilder.cs +++ b/src/Pure.DI.Core/Core/VariationalDependencyGraphBuilder.cs @@ -73,7 +73,7 @@ internal sealed class VariationalDependencyGraphBuilder( var maxIterations = globalOptions.MaxIterations; DependencyGraph? first = default; var maxAttempts = 0x2000; - while (variator.TryGetNextVariants(variants, node => !node.HasNode, out var nodes)) + while (variator.TryGetNextVariants(variants, out var nodes)) { if (maxAttempts-- == 0) { @@ -131,7 +131,7 @@ internal sealed class VariationalDependencyGraphBuilder( [SuppressMessage("ReSharper", "NotDisposedResourceIsReturned")] private static IEnumerable CreateVariants(IEnumerable nodes) => nodes.GroupBy(i => i.Node.Binding) - .Select(i => i.GetEnumerator()); + .Select(i => new SafeEnumerator(i.GetEnumerator())); private static IEnumerable SortByPriority(IEnumerable nodes) => nodes.GroupBy(i => i.Binding) diff --git a/src/Pure.DI.Core/Core/Variator.cs b/src/Pure.DI.Core/Core/Variator.cs index 0add7a62..88cb57c9 100644 --- a/src/Pure.DI.Core/Core/Variator.cs +++ b/src/Pure.DI.Core/Core/Variator.cs @@ -4,26 +4,32 @@ namespace Pure.DI.Core; internal sealed class Variator : IVariator + where T: class { public bool TryGetNextVariants( IEnumerable> variations, - Predicate hasVariantsPredicate, [NotNullWhen(true)] out IReadOnlyCollection? variants) { var hasNext = false; var curVariants = new List(); foreach (var enumerator in variations) { - if (!hasNext && enumerator.MoveNext()) + if (enumerator.Current is null) { - hasNext = true; - curVariants.Add(enumerator.Current); + enumerator.MoveNext(); + var current = enumerator.Current; + if (current is not null) + { + curVariants.Add(current); + hasNext = true; + } + continue; } - if (hasVariantsPredicate(enumerator.Current)) + if (!hasNext) { - enumerator.MoveNext(); + hasNext = enumerator.MoveNext(); } curVariants.Add(enumerator.Current);