Skip to content

Commit

Permalink
Fix simple factory for generic type markers
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolayPianikov committed Aug 10, 2024
1 parent a7ec131 commit 12abcf2
Show file tree
Hide file tree
Showing 12 changed files with 371 additions and 196 deletions.
52 changes: 12 additions & 40 deletions src/Pure.DI.Core/Core/ApiInvocationProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -387,69 +387,41 @@ private void VisitSimpleFactory(
ParenthesizedLambdaExpressionSyntax lambdaExpression)
{
CheckNotAsync(lambdaExpression);
var identifiers = lambdaExpression.ParameterList.Parameters.Select(i => i.Identifier).ToList();
var paramAttributes = lambdaExpression.ParameterList.Parameters.Select(i => i.AttributeLists.SelectMany(j => j.Attributes).ToList()).ToList();
const string ctxName = "ctx_1182D127";
var contextParameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier(ctxName));
var parameters = lambdaExpression.ParameterList.Parameters;
var paramAttributes = parameters.Select(i => i.AttributeLists.SelectMany(j => j.Attributes).ToList()).ToList();
var resolvers = new List<MdResolver>();
var block = new List<StatementSyntax>();
var namespaces = new HashSet<string>();
for (var i = 0; i < argsTypes.Count; i++)
{
var argTypeSyntax = argsTypes[i];
var argType = semantic.GetTypeSymbol<ITypeSymbol>(semanticModel, argTypeSyntax);
namespaces.Add(argType.ContainingNamespace.ToString());
var argNamespace = argType.ContainingNamespace;
if (argNamespace is not null)
{
namespaces.Add(argNamespace.ToString());
}

var attributes = paramAttributes[i];
resolvers.Add(new MdResolver
{
SemanticModel = semanticModel,
Source = argTypeSyntax,
ContractType = argType,
Tag = new MdTag(0, null),
ArgumentType = argTypeSyntax,
Parameter = parameters[i],
Position = i,
Attributes = attributes.ToImmutableArray()
});

var valueDeclaration = SyntaxFactory.DeclarationExpression(
argTypeSyntax,
SyntaxFactory.SingleVariableDesignation(identifiers[i]));

var valueArg =
SyntaxFactory.Argument(valueDeclaration)
.WithRefOrOutKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword));

var injection = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName(ctxName),
SyntaxFactory.IdentifierName(nameof(IContext.Inject))))
.AddArgumentListArguments(valueArg);

block.Add(SyntaxFactory.ExpressionStatement(injection));
}

if (lambdaExpression.Block is {} lambdaBlock)
{
block.AddRange(lambdaBlock.Statements);
}
else
{
if (lambdaExpression.ExpressionBody is { } body)
{
block.Add(SyntaxFactory.ReturnStatement(body));
}
}

var newLambdaExpression = SyntaxFactory.SimpleLambdaExpression(contextParameter)
.WithBlock(SyntaxFactory.Block(block));

metadataVisitor.VisitFactory(
new MdFactory(
semanticModel,
source,
returnType,
newLambdaExpression,
contextParameter,
lambdaExpression,
SyntaxFactory.Parameter(SyntaxFactory.Identifier("ctx_1182D127")),
resolvers.ToImmutableArray(),
false));

Expand Down
128 changes: 127 additions & 1 deletion src/Pure.DI.Core/Core/Code/FactoryCodeBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ internal class FactoryCodeBuilder(
ICompilations compilations)
: ICodeBuilder<DpFactory>
{
public static readonly ParenthesizedLambdaExpressionSyntax DefaultBindAttrParenthesizedLambda = SyntaxFactory.ParenthesizedLambdaExpression();
public static readonly ParameterSyntax DefaultCtxParameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier("ctx_1182D127"));
public const string DefaultInstanceValueName = "instance_1182D127";
private static readonly string InjectionStatement = $"{Names.InjectionMarker};";

public void Build(BuildContext ctx, in DpFactory factory)
Expand All @@ -25,11 +28,134 @@ public void Build(BuildContext ctx, in DpFactory factory)
lockIsRequired = default;
}

var originalLambda = factory.Source.Factory;
// Simple factory
if (originalLambda is ParenthesizedLambdaExpressionSyntax parenthesizedLambda)
{
var block = new List<StatementSyntax>();
foreach (var resolver in factory.Source.Resolvers)
{
if (resolver.ArgumentType is not { } argumentType || resolver.Parameter is not {} parameter)
{
continue;
}

var valueDeclaration = SyntaxFactory.DeclarationExpression(
argumentType,
SyntaxFactory.SingleVariableDesignation(parameter.Identifier));

var valueArg =
SyntaxFactory.Argument(valueDeclaration)
.WithRefOrOutKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword));

var injection = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName(DefaultCtxParameter.Identifier),
SyntaxFactory.IdentifierName(nameof(IContext.Inject))))
.AddArgumentListArguments(valueArg);

block.Add(SyntaxFactory.ExpressionStatement(injection));
}

if (factory.Source.MemberResolver is {} memberResolver
&& memberResolver.Member is {} member
&& memberResolver.TypeConstructor is {} typeConstructor)
{
ExpressionSyntax? value = default;
var type = memberResolver.ContractType;
ExpressionSyntax instance = member.IsStatic
? SyntaxFactory.ParseTypeName(type.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat))
: SyntaxFactory.IdentifierName(DefaultInstanceValueName);

switch (member)
{
case IFieldSymbol fieldSymbol:
value = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
instance,
SyntaxFactory.IdentifierName(member.Name));
break;

case IPropertySymbol propertySymbol:
value = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
instance,
SyntaxFactory.IdentifierName(member.Name));
break;

case IMethodSymbol methodSymbol:
var args = methodSymbol.Parameters
.Select(i => SyntaxFactory.Argument(SyntaxFactory.IdentifierName(i.Name)))
.ToArray();

if (methodSymbol.IsGenericMethod)
{
var setup = variable.Setup;
var binding = variable.Node.Binding;
var typeArgs = new List<TypeSyntax>();
// ReSharper disable once ForeachCanBeConvertedToQueryUsingAnotherGetEnumerator
foreach (var typeArg in methodSymbol.TypeArguments)
{
var argType = typeConstructor.ConstructReversed(setup, binding.SemanticModel.Compilation, typeArg);
if (binding.TypeConstructor is { } bindingTypeConstructor)
{
argType = bindingTypeConstructor.Construct(setup, binding.SemanticModel.Compilation, argType);
}

var typeName = argType.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat);
typeArgs.Add(SyntaxFactory.ParseTypeName(typeName));
}

value = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
instance,
SyntaxFactory.GenericName(member.Name).AddTypeArgumentListArguments(typeArgs.ToArray()));
}
else
{
value = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
instance,
SyntaxFactory.IdentifierName(member.Name));
}

value = SyntaxFactory
.InvocationExpression(value)
.AddArgumentListArguments(args);

break;
}

if (value is not null)
{
block.Add(SyntaxFactory.ReturnStatement(value));
}
}
else
{
if (parenthesizedLambda.Block is {} lambdaBlock)
{
block.AddRange(lambdaBlock.Statements);
}
else
{
if (parenthesizedLambda.ExpressionBody is { } body)
{
block.Add(SyntaxFactory.ReturnStatement(body));
}
}
}

originalLambda = SyntaxFactory.SimpleLambdaExpression(DefaultCtxParameter)
.WithBlock(SyntaxFactory.Block(block));
}

// Rewrites syntax tree
var finishLabel = $"{variable.VariableDeclarationName}Finish";
var injections = new List<FactoryRewriter.Injection>();
var localVariableRenamingRewriter = new LocalVariableRenamingRewriter(idGenerator, factory.Source.SemanticModel);
var factoryExpression = localVariableRenamingRewriter.Rewrite(factory.Source.Factory);
var factoryExpression = localVariableRenamingRewriter.Rewrite(originalLambda);
var factoryRewriter = new FactoryRewriter(arguments, compilations, factory, variable, finishLabel, injections);
var lambda = factoryRewriter.Rewrite(factoryExpression);
new FactoryValidator(factory).Validate(lambda);
Expand Down
6 changes: 4 additions & 2 deletions src/Pure.DI.Core/Core/DependencyGraphBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ private MdBinding CreateGenericBinding(
return sourceNode.Binding with
{
Id = newId,
TypeConstructor = typeConstructor,
Contracts = newContracts,
Implementation = sourceNode.Binding.Implementation.HasValue
? sourceNode.Binding.Implementation.Value with
Expand Down Expand Up @@ -518,9 +519,9 @@ private MdBinding CreateAutoBinding(
var semanticModel = targetNode.Binding.SemanticModel;
var compilation = semanticModel.Compilation;
var sourceType = injection.Type;
var typeConstructor = typeConstructorFactory();
if (marker.IsMarkerBased(setup, injection.Type))
{
var typeConstructor = typeConstructorFactory();
typeConstructor.TryBind(setup, injection.Type, injection.Type);
sourceType = typeConstructor.Construct(setup, compilation, injection.Type);
}
Expand All @@ -538,7 +539,8 @@ private MdBinding CreateAutoBinding(
newContracts,
newTags,
new MdLifetime(semanticModel, setup.Source, Lifetime.Transient),
new MdImplementation(semanticModel, setup.Source, sourceType));
new MdImplementation(semanticModel, setup.Source, sourceType),
TypeConstructor: typeConstructor);
return newBinding;
}

Expand Down
92 changes: 33 additions & 59 deletions src/Pure.DI.Core/Core/FactoryTypeRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public MdFactory Build(RewriterContext<MdFactory> context)
{
_context = context;
var factory = context.State;
var newFactory = (LambdaExpressionSyntax)Visit(factory.Factory);
var newFactory = (LambdaExpressionSyntax)Visit(factory.Factory)!;
return factory with
{
Type = context.TypeConstructor.Construct(context.Setup, factory.SemanticModel.Compilation, factory.Type),
Expand Down Expand Up @@ -42,77 +42,51 @@ public MdFactory Build(RewriterContext<MdFactory> context)
return default;
}

public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) =>
TryCreateTypeSyntax(node) ?? base.VisitIdentifierName(node);

public override SyntaxNode? VisitQualifiedName(QualifiedNameSyntax node) =>
TryCreateTypeSyntax(node) ?? base.VisitQualifiedName(node);

private SyntaxNode? TryCreateTypeSyntax(SyntaxNode node) =>
TryGetNewTypeName(node, out var newTypeName)
? SyntaxFactory.ParseTypeName(newTypeName)
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(node.GetTrailingTrivia())
: default(SyntaxNode?);

private bool TryGetNewTypeName(SyntaxNode? node, [NotNullWhen(true)] out string? newTypeName)
{
var identifier = base.VisitIdentifierName(node) as IdentifierNameSyntax;
if (identifier is null)
newTypeName = default;
if (node is null)
{
return identifier;
return false;
}

var semanticModel = _context.State.SemanticModel;
if (identifier.SyntaxTree != semanticModel.SyntaxTree)
if (semanticModel.GetSymbolInfo(node).Symbol is ITypeSymbol type)
{
return identifier;
return TryGetNewTypeName(type, true, out newTypeName);
}

var symbol = semanticModel.GetSymbolInfo(identifier).Symbol;
if (symbol is not ITypeSymbol type)
{
return identifier;
}

return false;
}

private bool TryGetNewTypeName(ITypeSymbol type, bool inTree, [NotNullWhen(true)] out string? newTypeName)
{
newTypeName = default;
if (!marker.IsMarkerBased(_context.Setup, type))
{
return identifier;
return false;
}

var newType = _context.TypeConstructor.Construct(_context.Setup, semanticModel.Compilation, type);
var newTypeName = typeResolver.Resolve(_context.Setup, newType).Name;
return node.WithIdentifier(
SyntaxFactory.Identifier(newTypeName))
.WithLeadingTrivia(node.Identifier.LeadingTrivia)
.WithTrailingTrivia(node.Identifier.TrailingTrivia);
}

public override SyntaxNode? VisitTypeArgumentList(TypeArgumentListSyntax node)
{
var newArgs = new List<TypeSyntax>();
var hasMarkerBased = false;
var semanticModel = _context.Setup.SemanticModel;
foreach (var arg in node.Arguments)
var newType = _context.TypeConstructor.Construct(_context.Setup, _context.State.SemanticModel.Compilation, type);
if (!inTree && SymbolEqualityComparer.Default.Equals(newType, type))
{
var typeName = arg.ToString();
var isFound = false;
foreach (var type in semanticModel.Compilation.GetTypesByMetadataName(typeName))
{
if (!marker.IsMarkerBased(_context.Setup, type))
{
newArgs.Add(arg);
isFound = true;
break;
}

hasMarkerBased = true;
var constructedType = _context.TypeConstructor.Construct(_context.Setup, semanticModel.Compilation, type);
if (SymbolEqualityComparer.Default.Equals(type, constructedType))
{
continue;
}

newArgs.Add(SyntaxFactory.ParseTypeName(constructedType.ToString()));
isFound = true;
break;
}

if (!isFound)
{
return base.VisitTypeArgumentList(node);
}
return false;
}

return hasMarkerBased
? SyntaxFactory.TypeArgumentList().AddArguments(newArgs.ToArray())
: base.VisitTypeArgumentList(node);

newTypeName = typeResolver.Resolve(_context.Setup, newType).Name;
return true;
}
}
3 changes: 2 additions & 1 deletion src/Pure.DI.Core/Core/Models/MdBinding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ internal record MdBinding(
in MdImplementation? Implementation = default,
in MdFactory? Factory = default,
in MdArg? Arg = default,
in MdConstruct? Construct = default)
in MdConstruct? Construct = default,
ITypeConstructor? TypeConstructor = default)
{
public override string ToString()
{
Expand Down
3 changes: 2 additions & 1 deletion src/Pure.DI.Core/Core/Models/MdFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ internal readonly record struct MdFactory(
LambdaExpressionSyntax Factory,
ParameterSyntax Context,
in ImmutableArray<MdResolver> Resolvers,
bool HasContextTag)
bool HasContextTag,
in MdResolver? MemberResolver = default)
{
public override string ToString() => $"To<{Type}>({Factory})";
}
Loading

0 comments on commit 12abcf2

Please sign in to comment.