Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix buggy code for ordering switch statement cases for messages #147

Merged
merged 1 commit into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Collections.Immutable;
using System.Diagnostics;
using System.Runtime.InteropServices.ComTypes;
using Mediator.SourceGenerator.Extensions;
using Microsoft.CodeAnalysis.CSharp;

Expand Down Expand Up @@ -264,22 +266,35 @@ public void Analyze()
}
}

private sealed class InheritanceComparer : IComparer<INamedTypeSymbol>
private static ImmutableEquatableArray<TModel> ToModelsSortedByInheritanceDepth<TSource, TModel>(
HashSet<TSource> source,
Func<TSource, TModel> selector
)
where TSource : SymbolMetadata<TSource>
where TModel : SymbolMetadataModel, IEquatable<TModel>
{
public int Compare(INamedTypeSymbol x, INamedTypeSymbol y)
var analysis = new (TSource Message, int Depth)[source.Count];
int i = 0;
foreach (var message in source)
{
while (x.BaseType is not null)
var baseType = message.Symbol.BaseType;
int depth = 0;
while (baseType is not null && baseType.SpecialType != SpecialType.System_Object)
{
if (x.BaseType.SpecialType == SpecialType.System_Object)
break;

if (SymbolEqualityComparer.Default.Equals(x.BaseType, y))
return -1;
x = x.BaseType;
depth++;
baseType = baseType.BaseType;
}

return x.GetTypeSymbolFullName().CompareTo(y.GetTypeSymbolFullName());
Debug.Assert(i < source.Count);
analysis[i++] = (message, depth);
}

Array.Sort(analysis, (x, y) => y.Depth.CompareTo(x.Depth));
var models = new TModel[source.Count];
for (i = 0; i < source.Count; i++)
models[i] = selector(analysis[i].Message);

return new ImmutableEquatableArray<TModel>(models);
}

public CompilationModel ToModel()
Expand All @@ -289,28 +304,27 @@ public CompilationModel ToModel()

try
{
var comparer = new InheritanceComparer();
if (_notificationPublisherImplementationSymbol is null)
throw new Exception("Unexpected state: NotificationPublisherImplementationSymbol is null");

var model = new CompilationModel(
_requestMessages
.OrderBy(m => m.Symbol, comparer)
.Select(x => new RequestMessageModel(
x.Symbol,
x.ResponseSymbol,
x.MessageType,
x.Handler?.ToModel(),
x.WrapperType
))
.ToImmutableEquatableArray(),
_notificationMessages
.OrderBy(m => m.Symbol, comparer)
.Select(x => x.ToModel())
.ToImmutableEquatableArray(),
ToModelsSortedByInheritanceDepth(
_requestMessages,
m => new RequestMessageModel(
m.Symbol,
m.ResponseSymbol,
m.MessageType,
m.Handler?.ToModel(),
m.WrapperType
)
),
ToModelsSortedByInheritanceDepth(_notificationMessages, m => m.ToModel()),
_requestMessageHandlers.Select(x => x.ToModel()).ToImmutableEquatableArray(),
_notificationMessageHandlers.Select(x => x.ToModel()).ToImmutableEquatableArray(),
RequestMessageHandlerWrappers.ToImmutableEquatableArray(),
new NotificationPublisherTypeModel(
_notificationPublisherImplementationSymbol!.GetTypeSymbolFullName(),
_notificationPublisherImplementationSymbol!.Name
_notificationPublisherImplementationSymbol.GetTypeSymbolFullName(),
_notificationPublisherImplementationSymbol.Name
),
HasErrors,
MediatorNamespace,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public sealed class ImmutableEquatableArray<T> : IEquatable<ImmutableEquatableAr
public T this[int index] => _values[index];
public int Count => _values.Length;

public ImmutableEquatableArray(T[] values) => _values = values;

public ImmutableEquatableArray(IEnumerable<T> values) => _values = values.ToArray();

public bool Equals(ImmutableEquatableArray<T>? other) =>
Expand Down
64 changes: 63 additions & 1 deletion test/Mediator.SourceGenerator.Tests/MessageOrderingTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ await inputCompilation.AssertAndVerify(
Assert.Equal(5, notifications.Count);

var last = notifications[^1];
last.Name.Equals("DomainEvent");
last.Name.Should().Be("DomainEvent");

var index0 = notifications.FindIndex(n => n.Name == "RoundSucceededActually");
var index1 = notifications.FindIndex(n => n.Name == "RoundSucceeded");
Expand All @@ -55,4 +55,66 @@ await inputCompilation.AssertAndVerify(
}
);
}

[Fact]
public async Task Test_Notifications_Ordering_Bigger()
{
var inputCompilation = Fixture.CreateLibrary(
"""
using Mediator;
using System.Threading.Tasks;
using System;

namespace TestCode;

public class Program
{
public static void Main()
{
}
}

public record DomainEvent(DateTimeOffset Timestamp) : INotification;
public record RoundCreated(long Id, DateTimeOffset Timestamp) : DomainEvent(Timestamp);
public record RoundResulted(long Id, long Win, DateTimeOffset Timestamp) : DomainEvent(Timestamp);
public record RoundSucceeded(long Id, DateTimeOffset Timestamp) : DomainEvent(Timestamp);
public record RoundSucceededActually(long Id, string Because, DateTimeOffset Timestamp) : RoundSucceeded(Id, Timestamp);

public record DomainEvent2(DateTimeOffset Timestamp) : INotification;
public record Round2Created(long Id, DateTimeOffset Timestamp) : DomainEvent2(Timestamp);
public record Round2Resulted(long Id, long Win, DateTimeOffset Timestamp) : DomainEvent2(Timestamp);
public record Round2Succeeded(long Id, DateTimeOffset Timestamp) : DomainEvent2(Timestamp);
public record Round2SucceededActually(long Id, string Because, DateTimeOffset Timestamp) : RoundSucceeded(Id, Timestamp);

public record DomainEvent10(DateTimeOffset Timestamp) : INotification;
public record Sound2Created(long Id, DateTimeOffset Timestamp) : DomainEvent10(Timestamp);
public record Sound2Resulted(long Id, long Win, DateTimeOffset Timestamp) : DomainEvent10(Timestamp);
public record Sound2Succeeded(long Id, DateTimeOffset Timestamp) : DomainEvent10(Timestamp);
public record Sound2SucceededActually(long Id, string Because, DateTimeOffset Timestamp) : RoundSucceeded(Id, Timestamp);

public record DomainEvent11(DateTimeOffset Timestamp) : INotification;
public record Sound20Created(long Id, DateTimeOffset Timestamp) : DomainEvent11(Timestamp);
public record Sound20Resulted(long Id, long Win, DateTimeOffset Timestamp) : DomainEvent11(Timestamp);
public record Sound20Succeeded(long Id, DateTimeOffset Timestamp) : DomainEvent11(Timestamp);
public record Sound20SucceededActually(long Id, string Because, DateTimeOffset Timestamp) : Sound20Succeeded(Id, Timestamp);
"""
);

await inputCompilation.AssertAndVerify(
Assertions.CompilesWithoutDiagnostics,
result =>
{
var model = result.Generator.CompilationModel;
Assert.NotNull(model);
var notifications = model.NotificationMessages.ToList();
Assert.Equal(5 * 4, notifications.Count);

Assert.All(notifications.AsEnumerable().Take(4), n => n.Name.Should().EndWith("Actually"));
Assert.All(
notifications.AsEnumerable().Reverse().Take(4),
n => n.Name.Should().StartWith("DomainEvent")
);
}
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -536,9 +536,9 @@ private readonly struct DICache
private readonly global::System.IServiceProvider _sp;

public readonly global::Mediator.INotificationHandler<global::TestCode.RoundSucceededActually>[] Handlers_For_TestCode_RoundSucceededActually;
public readonly global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[] Handlers_For_TestCode_RoundSucceeded;
public readonly global::Mediator.INotificationHandler<global::TestCode.RoundResulted>[] Handlers_For_TestCode_RoundResulted;
public readonly global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[] Handlers_For_TestCode_RoundCreated;
public readonly global::Mediator.INotificationHandler<global::TestCode.RoundResulted>[] Handlers_For_TestCode_RoundResulted;
public readonly global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[] Handlers_For_TestCode_RoundSucceeded;
public readonly global::Mediator.INotificationHandler<global::TestCode.DomainEvent>[] Handlers_For_TestCode_DomainEvent;

public readonly global::Mediator.ForeachAwaitPublisher InternalNotificationPublisherImpl;
Expand All @@ -562,18 +562,18 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundSucceededActually is not global::Mediator.INotificationHandler<global::TestCode.RoundSucceededActually>[]);
Handlers_For_TestCode_RoundSucceededActually = handlers_For_TestCode_RoundSucceededActually.ToArray();
}
var handlers_For_TestCode_RoundSucceeded = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>>();
var handlers_For_TestCode_RoundCreated = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.RoundCreated>>();
if (containerMetadata.ServicesUnderlyingTypeIsArray)
{
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundSucceeded is global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]);
Handlers_For_TestCode_RoundSucceeded = global::System.Runtime.CompilerServices.Unsafe.As<global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]>(
handlers_For_TestCode_RoundSucceeded
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundCreated is global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]);
Handlers_For_TestCode_RoundCreated = global::System.Runtime.CompilerServices.Unsafe.As<global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]>(
handlers_For_TestCode_RoundCreated
);
}
else
{
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundSucceeded is not global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]);
Handlers_For_TestCode_RoundSucceeded = handlers_For_TestCode_RoundSucceeded.ToArray();
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundCreated is not global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]);
Handlers_For_TestCode_RoundCreated = handlers_For_TestCode_RoundCreated.ToArray();
}
var handlers_For_TestCode_RoundResulted = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.RoundResulted>>();
if (containerMetadata.ServicesUnderlyingTypeIsArray)
Expand All @@ -588,18 +588,18 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundResulted is not global::Mediator.INotificationHandler<global::TestCode.RoundResulted>[]);
Handlers_For_TestCode_RoundResulted = handlers_For_TestCode_RoundResulted.ToArray();
}
var handlers_For_TestCode_RoundCreated = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.RoundCreated>>();
var handlers_For_TestCode_RoundSucceeded = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>>();
if (containerMetadata.ServicesUnderlyingTypeIsArray)
{
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundCreated is global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]);
Handlers_For_TestCode_RoundCreated = global::System.Runtime.CompilerServices.Unsafe.As<global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]>(
handlers_For_TestCode_RoundCreated
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundSucceeded is global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]);
Handlers_For_TestCode_RoundSucceeded = global::System.Runtime.CompilerServices.Unsafe.As<global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]>(
handlers_For_TestCode_RoundSucceeded
);
}
else
{
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundCreated is not global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]);
Handlers_For_TestCode_RoundCreated = handlers_For_TestCode_RoundCreated.ToArray();
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundSucceeded is not global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]);
Handlers_For_TestCode_RoundSucceeded = handlers_For_TestCode_RoundSucceeded.ToArray();
}
var handlers_For_TestCode_DomainEvent = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.DomainEvent>>();
if (containerMetadata.ServicesUnderlyingTypeIsArray)
Expand Down Expand Up @@ -834,9 +834,9 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
switch (notification)
{
case global::TestCode.RoundSucceededActually n: return Publish(n, cancellationToken);
case global::TestCode.RoundSucceeded n: return Publish(n, cancellationToken);
case global::TestCode.RoundResulted n: return Publish(n, cancellationToken);
case global::TestCode.RoundCreated n: return Publish(n, cancellationToken);
case global::TestCode.RoundResulted n: return Publish(n, cancellationToken);
case global::TestCode.RoundSucceeded n: return Publish(n, cancellationToken);
case global::TestCode.DomainEvent n: return Publish(n, cancellationToken);
default:
{
Expand Down Expand Up @@ -876,30 +876,30 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
);
}
/// <summary>
/// Send a notification of type global::TestCode.RoundSucceeded.
/// Send a notification of type global::TestCode.RoundCreated.
/// Throws <see cref="global::System.ArgumentNullException"/> if message is null.
/// Throws <see cref="global::System.AggregateException"/> if handlers throw exception(s).
/// </summary>
/// <param name="notification">Incoming message</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Awaitable task</returns>
public global::System.Threading.Tasks.ValueTask Publish(
global::TestCode.RoundSucceeded notification,
global::TestCode.RoundCreated notification,
global::System.Threading.CancellationToken cancellationToken = default
)
{
ThrowIfNull(notification, nameof(notification));


var handlers = _diCacheLazy.Value.Handlers_For_TestCode_RoundSucceeded;
var handlers = _diCacheLazy.Value.Handlers_For_TestCode_RoundCreated;

if (handlers.Length == 0)
{
return default;
}
var publisher = _diCacheLazy.Value.InternalNotificationPublisherImpl;
return publisher.Publish(
new global::Mediator.NotificationHandlers<global::TestCode.RoundSucceeded>(handlers, isArray: true),
new global::Mediator.NotificationHandlers<global::TestCode.RoundCreated>(handlers, isArray: true),
notification,
cancellationToken
);
Expand Down Expand Up @@ -934,30 +934,30 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
);
}
/// <summary>
/// Send a notification of type global::TestCode.RoundCreated.
/// Send a notification of type global::TestCode.RoundSucceeded.
/// Throws <see cref="global::System.ArgumentNullException"/> if message is null.
/// Throws <see cref="global::System.AggregateException"/> if handlers throw exception(s).
/// </summary>
/// <param name="notification">Incoming message</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Awaitable task</returns>
public global::System.Threading.Tasks.ValueTask Publish(
global::TestCode.RoundCreated notification,
global::TestCode.RoundSucceeded notification,
global::System.Threading.CancellationToken cancellationToken = default
)
{
ThrowIfNull(notification, nameof(notification));


var handlers = _diCacheLazy.Value.Handlers_For_TestCode_RoundCreated;
var handlers = _diCacheLazy.Value.Handlers_For_TestCode_RoundSucceeded;

if (handlers.Length == 0)
{
return default;
}
var publisher = _diCacheLazy.Value.InternalNotificationPublisherImpl;
return publisher.Publish(
new global::Mediator.NotificationHandlers<global::TestCode.RoundCreated>(handlers, isArray: true),
new global::Mediator.NotificationHandlers<global::TestCode.RoundSucceeded>(handlers, isArray: true),
notification,
cancellationToken
);
Expand Down Expand Up @@ -1010,9 +1010,9 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
switch (notification)
{
case global::TestCode.RoundSucceededActually n: return Publish(n, cancellationToken);
case global::TestCode.RoundSucceeded n: return Publish(n, cancellationToken);
case global::TestCode.RoundResulted n: return Publish(n, cancellationToken);
case global::TestCode.RoundCreated n: return Publish(n, cancellationToken);
case global::TestCode.RoundResulted n: return Publish(n, cancellationToken);
case global::TestCode.RoundSucceeded n: return Publish(n, cancellationToken);
case global::TestCode.DomainEvent n: return Publish(n, cancellationToken);
default:
{
Expand Down
Loading
Loading