Skip to content

Commit

Permalink
- implement GetRowParser (#69)
Browse files Browse the repository at this point in the history
- general cleanup
  • Loading branch information
mgravell authored Nov 13, 2023
1 parent bab9f85 commit 9bb613d
Show file tree
Hide file tree
Showing 26 changed files with 695 additions and 174 deletions.
6 changes: 6 additions & 0 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationO
argExpression = null;
sql = null;
bool? buffered = null;

// check the args
foreach (var arg in op.Arguments)
{
Expand Down Expand Up @@ -524,6 +525,11 @@ internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationO
case "cnn":
case "commandTimeout":
case "transaction":
case "reader":
case "startIndex":
case "length":
case "returnNullIfFirstMissing":
case "concreteType" when arg.Value is IDefaultValueOperation || (arg.ConstantValue.HasValue && arg.ConstantValue.Value is null):
// nothing to do
break;
case "commandType":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ void WriteMultiExecExpression(ITypeSymbol elementType, string castType)
{
sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken"));
}
sb.Append(");");
sb.NewLine().Outdent().NewLine().NewLine();
sb.Append(");").NewLine();
}

void WriteBatchCommandArguments(ITypeSymbol elementType)
Expand Down Expand Up @@ -91,7 +90,7 @@ void WriteBatchCommandArguments(ITypeSymbol elementType)
// commandFactory
if (flags.HasAny(OperationFlags.HasParameters))
{
var index = factories.GetIndex(elementType, map, cache, true, additionalCommandState, out var subIndex);
var index = factories.GetIndex(elementType, map, cache, additionalCommandState, out var subIndex);
sb.Append("CommandFactory").Append(index).Append(".Instance").Append(subIndex);
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace Dapper.CodeAnalysis;

public sealed partial class DapperInterceptorGenerator
{
void WriteSingleImplementation(
static void WriteSingleImplementation(
CodeWriter sb,
IMethodSymbol method,
ITypeSymbol? resultType,
Expand Down Expand Up @@ -54,7 +54,7 @@ void WriteSingleImplementation(
sb.Append(", ").Append(Forward(methodParameters, "commandTimeout")).Append(HasParam(methodParameters, "commandTimeout") ? ".GetValueOrDefault()" : "").Append(", ");
if (flags.HasAny(OperationFlags.HasParameters))
{
var index = factories.GetIndex(parameterType!, map, cache, false, additionalCommandState, out var subIndex);
var index = factories.GetIndex(parameterType!, map, cache, additionalCommandState, out var subIndex);
sb.Append("CommandFactory").Append(index).Append(".Instance").Append(subIndex);
}
else
Expand Down Expand Up @@ -92,83 +92,7 @@ void WriteSingleImplementation(
break;
}
}
if (IsInbuilt(resultType, out var helper))
{
sb.Append("global::Dapper.RowFactory.Inbuilt.").Append(helper);
}
else
{
sb.Append("RowFactory").Append(readers.GetIndex(resultType!)).Append(".Instance");
}

static bool IsInbuilt(ITypeSymbol? type, out string? helper)
{
if (type is null || type.TypeKind == TypeKind.Dynamic)
{
helper = "Dynamic";
return true;
}
if (type.SpecialType == SpecialType.System_Object)
{
helper = "Object";
return true;
}
if (Inspection.IdentifyDbType(type, out _) is not null)
{
bool nullable = type.IsValueType && type.NullableAnnotation == NullableAnnotation.Annotated;
helper = (nullable ? "NullableValue<" : "Value<") + CodeWriter.GetTypeName(
nullable ? Inspection.MakeNonNullable(type) : type) + ">()";
return true;
}
if (type is INamedTypeSymbol { Arity: 0 })
{
if (type is
{
TypeKind: TypeKind.Interface,
Name: "IDataRecord",
ContainingType: null,
ContainingNamespace:
{
Name: "Data",
ContainingNamespace:
{
Name: "System",
ContainingNamespace.IsGlobalNamespace: true
}
}
})
{
helper = "IDataRecord";
return true;
}
if (type is
{
TypeKind: TypeKind.Class,
Name: "DbDataRecord",
ContainingType: null,
ContainingNamespace:
{
Name: "Common",
ContainingNamespace:
{
Name: "Data",
ContainingNamespace:
{
Name: "System",
ContainingNamespace.IsGlobalNamespace: true
}
}
}
})
{
helper = "DbDataRecord";
return true;
}
}
helper = null;
return false;

}
sb.AppendReader(resultType, readers);
}
else if (flags.HasAny(OperationFlags.Execute))
{
Expand Down Expand Up @@ -227,7 +151,7 @@ static bool IsInbuilt(ITypeSymbol? type, out string? helper)
sb.Append("!");
}
}
sb.Append(";").NewLine().Outdent().NewLine().NewLine();
sb.Append(";").NewLine();

static CodeWriter WriteTypedArg(CodeWriter sb, ITypeSymbol? parameterType)
{
Expand Down
74 changes: 49 additions & 25 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ public override void Initialize(IncrementalGeneratorInitializationContext contex
}

// very fast and light-weight; we'll worry about the rest later from the semantic tree
internal static bool IsCandidate(string methodName) => methodName.StartsWith("Execute") || methodName.StartsWith("Query");
internal static bool IsCandidate(string methodName) =>
methodName.StartsWith("Execute")
|| methodName.StartsWith("Query")
|| methodName.StartsWith("GetRowParser");

internal bool PreFilter(SyntaxNode node, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -213,6 +216,8 @@ private void Generate(SourceProductionContext ctx, (Compilation Compilation, Imm
}

const string DapperBaseCommandFactory = "global::Dapper.CommandFactory";

[System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1822:Mark members as static", Justification = "Allow expectation of state")]
internal void Generate(in GenerateState ctx)
{
if (!CheckPrerequisites(ctx)) // also reports per-item diagnostics
Expand All @@ -221,7 +226,7 @@ internal void Generate(in GenerateState ctx)
return;
}

var dbCommandTypes = IdentifyDbCommandTypes(ctx.Compilation, out var needsCommandPrep, out bool frameworkHasBatchAPI);
var dbCommandTypes = IdentifyDbCommandTypes(ctx.Compilation, out var needsCommandPrep);

bool allowUnsafe = ctx.Compilation.Options is CSharpCompilationOptions cSharp && cSharp.AllowUnsafe;
var sb = new CodeWriter().Append("#nullable enable").NewLine()
Expand Down Expand Up @@ -293,18 +298,22 @@ internal void Generate(in GenerateState ctx)
var commandTypeMode = flags & (OperationFlags.Text | OperationFlags.StoredProcedure | OperationFlags.TableDirect);
var methodParameters = grp.Key.Method.Parameters;
string? fixedSql = null;
if (flags.HasAny(OperationFlags.IncludeLocation))
{
var origin = grp.Single();
fixedSql = origin.Sql; // expect exactly one SQL
sb.Append("global::System.Diagnostics.Debug.Assert(sql == ")
.AppendVerbatimLiteral(fixedSql).Append(");").NewLine();
var path = origin.Location.GetMappedLineSpan();
fixedSql = $"-- {path.Path}#{path.StartLinePosition.Line + 1}\r\n{fixedSql}";
}
else

if (HasParam(methodParameters, "sql"))
{
sb.Append("global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));").NewLine();
if (flags.HasAny(OperationFlags.IncludeLocation))
{
var origin = grp.Single();
fixedSql = origin.Sql; // expect exactly one SQL
sb.Append("global::System.Diagnostics.Debug.Assert(sql == ")
.AppendVerbatimLiteral(fixedSql).Append(");").NewLine();
var path = origin.Location.GetMappedLineSpan();
fixedSql = $"-- {path.Path}#{path.StartLinePosition.Line + 1}\r\n{fixedSql}";
}
else
{
sb.Append("global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));").NewLine();
}
}
if (HasParam(methodParameters, "commandType"))
{
Expand All @@ -320,12 +329,28 @@ internal void Generate(in GenerateState ctx)
sb.Append("global::System.Diagnostics.Debug.Assert(buffered is ").Append((flags & OperationFlags.Buffered) != 0).Append(");").NewLine();
}

sb.Append("global::System.Diagnostics.Debug.Assert(param is ").Append(flags.HasAny(OperationFlags.HasParameters) ? "not " : "").Append("null);").NewLine().NewLine();
if (HasParam(methodParameters, "param"))
{
sb.Append("global::System.Diagnostics.Debug.Assert(param is ").Append(flags.HasAny(OperationFlags.HasParameters) ? "not " : "").Append("null);").NewLine();
}

if (HasParam(methodParameters, "concreteType"))
{
sb.Append("global::System.Diagnostics.Debug.Assert(concreteType is null);").NewLine();
}

if (!TryWriteMultiExecImplementation(sb, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, fixedSql, additionalCommandState))
sb.NewLine();

if (flags.HasAny(OperationFlags.GetRowParser))
{
WriteGetRowParser(sb, resultType, readers);
}
else if (!TryWriteMultiExecImplementation(sb, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, fixedSql, additionalCommandState))
{
WriteSingleImplementation(sb, method, resultType, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, readers, fixedSql, additionalCommandState);
}

sb.Outdent().NewLine().NewLine();
}

var baseCommandFactory = GetCommandFactory(ctx.Compilation, out var canConstruct) ?? DapperBaseCommandFactory;
Expand Down Expand Up @@ -384,7 +409,7 @@ internal void Generate(in GenerateState ctx)

foreach (var tuple in factories)
{
WriteCommandFactory(ctx, baseCommandFactory, sb, tuple.Type, tuple.Index, tuple.Map, tuple.CacheCount, tuple.AdditionalCommandState, tuple.SupportBatch && frameworkHasBatchAPI);
WriteCommandFactory(ctx, baseCommandFactory, sb, tuple.Type, tuple.Index, tuple.Map, tuple.CacheCount, tuple.AdditionalCommandState);
}

sb.Outdent().Outdent(); // ends our generated file-scoped class and the namespace
Expand All @@ -396,7 +421,13 @@ internal void Generate(in GenerateState ctx)
ctx.ReportDiagnostic(Diagnostic.Create(Diagnostics.InterceptorsGenerated, null, callSiteCount, ctx.Nodes.Length, methodIndex, factories.Count(), readers.Count()));
}

private static void WriteCommandFactory(in GenerateState ctx, string baseFactory, CodeWriter sb, ITypeSymbol type, int index, string map, int cacheCount, AdditionalCommandState? additionalCommandState, bool supportBatch)
private static void WriteGetRowParser(CodeWriter sb, ITypeSymbol? resultType, in RowReaderState readers)
{
sb.Append("return ").AppendReader(resultType, readers)
.Append(".GetRowParser(reader, startIndex, length, returnNullIfFirstMissing);").NewLine();
}

private static void WriteCommandFactory(in GenerateState ctx, string baseFactory, CodeWriter sb, ITypeSymbol type, int index, string map, int cacheCount, AdditionalCommandState? additionalCommandState)
{
var declaredType = type.IsAnonymousType ? "object?" : CodeWriter.GetTypeName(type);
sb.Append("private ").Append(cacheCount <= 1 ? "sealed" : "abstract").Append(" class CommandFactory").Append(index).Append(" : ")
Expand Down Expand Up @@ -467,11 +498,6 @@ private static void WriteCommandFactory(in GenerateState ctx, string baseFactory
}
sb.Outdent().NewLine();
}

if (supportBatch)
{
sb.Append("public override bool SupportBatch => true;").NewLine();
}
}

if ((flags & WriteArgsFlags.CanPrepare) != 0)
Expand Down Expand Up @@ -1174,17 +1200,15 @@ private enum SpecialCommandFlags
InitialLONGFetchSize = 1 << 1,
}

private static ImmutableArray<ITypeSymbol> IdentifyDbCommandTypes(Compilation compilation, out bool needsPrepare, out bool hasBatchAPI)
private static ImmutableArray<ITypeSymbol> IdentifyDbCommandTypes(Compilation compilation, out bool needsPrepare)
{
needsPrepare = false;
var dbCommand = compilation.GetTypeByMetadataName("System.Data.Common.DbCommand");
if (dbCommand is null)
{
// if we can't find DbCommand, we're out of luck
hasBatchAPI = false;
return ImmutableArray<ITypeSymbol>.Empty;
}
hasBatchAPI = compilation.GetTypeByMetadataName("System.Data.Common.DbBatch") is not null;

var pending = new Queue<INamespaceOrTypeSymbol>();
foreach (var assemblyName in compilation.References)
Expand Down
81 changes: 81 additions & 0 deletions src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -319,4 +319,85 @@ public string ToStringRecycle()
Interlocked.Exchange(ref s_Spare, this);
return s;
}

internal CodeWriter AppendReader(ITypeSymbol? resultType, RowReaderState readers)
{
if (IsInbuilt(resultType, out var helper))
{
return Append("global::Dapper.RowFactory.Inbuilt.").Append(helper);
}
else
{
return Append("RowFactory").Append(readers.GetIndex(resultType!)).Append(".Instance");
}

static bool IsInbuilt(ITypeSymbol? type, out string? helper)
{
if (type is null || type.TypeKind == TypeKind.Dynamic)
{
helper = "Dynamic";
return true;
}
if (type.SpecialType == SpecialType.System_Object)
{
helper = "Object";
return true;
}
if (Inspection.IdentifyDbType(type, out _) is not null)
{
bool nullable = type.IsValueType && type.NullableAnnotation == NullableAnnotation.Annotated;
helper = (nullable ? "NullableValue<" : "Value<") + CodeWriter.GetTypeName(
nullable ? Inspection.MakeNonNullable(type) : type) + ">()";
return true;
}
if (type is INamedTypeSymbol { Arity: 0 })
{
if (type is
{
TypeKind: TypeKind.Interface,
Name: "IDataRecord",
ContainingType: null,
ContainingNamespace:
{
Name: "Data",
ContainingNamespace:
{
Name: "System",
ContainingNamespace.IsGlobalNamespace: true
}
}
})
{
helper = "IDataRecord";
return true;
}
if (type is
{
TypeKind: TypeKind.Class,
Name: "DbDataRecord",
ContainingType: null,
ContainingNamespace:
{
Name: "Common",
ContainingNamespace:
{
Name: "Data",
ContainingNamespace:
{
Name: "System",
ContainingNamespace.IsGlobalNamespace: true
}
}
}
})
{
helper = "DbDataRecord";
return true;
}
}
helper = null;
return false;

}
}
}
Loading

0 comments on commit 9bb613d

Please sign in to comment.