Skip to content

Commit

Permalink
feature: support DbString in DapperAOT (#119)
Browse files Browse the repository at this point in the history
* start

* add dbstring example to query

* create test

* start processing

* support DbString!

* last checks

* regenerate for netfx

* fix DAP44 error appeared in refactor

* pre-review fixes

* change DbString configuration method declaration

* address initial comments

* DbStringHelpers to file class

* include `file` via preprocessor directive

* only in dapper AOT

* move to Dapper.AOT.Test.Integration.csproj

* setup the infra for test

* test is executed targeting docker db

* remove debug

* some other tries

* Revert "some other tries"

This reverts commit a95b6ff.

* Revert "remove debug"

This reverts commit 44e37d0.

* Revert "test is executed targeting docker db"

This reverts commit c5a4fee.

* Revert "setup the infra for test"

This reverts commit 63fd1a9.

* Revert "move to Dapper.AOT.Test.Integration.csproj"

This reverts commit e804648.

* address PR comments

* adjust a test

* use global:: for DapperSpecialType
  • Loading branch information
DeagleGross authored Jul 26, 2024
1 parent 42778fa commit 21b383a
Show file tree
Hide file tree
Showing 22 changed files with 927 additions and 57 deletions.
44 changes: 44 additions & 0 deletions docs/rules/DAP048.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# DAP048

[DbString](https://github.com/DapperLib/Dapper/blob/main/Dapper/DbString.cs) causes heap allocations, but achieves the same as
[DbValueAttribute](https://github.com/DapperLib/DapperAOT/blob/main/src/Dapper.AOT/DbValueAttribute.cs).

Bad:

``` c#
public void DapperCode(DbConnection conn)
{
var sql = "SELECT COUNT(*) FROM Foo WHERE Name = @Name;";
var cars = conn.Query<int>(sql,
new
{
Name = new DbString
{
Value = "MyFoo",
IsFixedLength = false,
Length = 5,
IsAnsi = true
}
});
}
```

Good:

``` c#
public void DapperCode(DbConnection conn)
{
var sql = "SELECT COUNT(*) FROM Foo WHERE Name = @Name;";
var cars = conn.Query<int>(sql,
new MyPoco
{
Name = "MyFoo"
});
}

class MyPoco
{
[DbValue(Length = 5, DbType = DbType.AnsiStringFixedLength)] // specify properties here
public string Name { get; set; }
}
```
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public static readonly DiagnosticDescriptor
CancellationDuplicated = LibraryWarning("DAP045", "Duplicate cancellation", "Multiple parameter values cannot define cancellation"),
AmbiguousProperties = LibraryWarning("DAP046", "Ambiguous properties", "Properties have same name '{0}' after normalization and can be conflated"),
AmbiguousFields = LibraryWarning("DAP047", "Ambiguous fields", "Fields have same name '{0}' after normalization and can be conflated"),
MoveFromDbString = LibraryWarning("DAP048", "Move from DbString to DbValue", "DbString achieves the same as [DbValue] does. Use it instead."),

// SQL parse specific
GeneralSqlError = SqlWarning("DAP200", "SQL error", "SQL error: {0}"),
Expand Down
77 changes: 48 additions & 29 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ private void ValidateDapperMethod(in OperationAnalysisContext ctx, IOperation sq
}

// check the types
var resultType = invoke.GetResultType(flags);
var resultType = invoke.GetResultType(flags);
if (resultType is not null && IdentifyDbType(resultType, out _) is null) // don't warn if handled as an inbuilt
{
var resultMap = MemberMap.CreateForResults(resultType, location);
Expand Down Expand Up @@ -250,34 +250,9 @@ private void ValidateDapperMethod(in OperationAnalysisContext ctx, IOperation sq
{
_ = AdditionalCommandState.Parse(GetSymbol(parseState, invoke), parameters, onDiagnostic);
}
if (parameters is not null)
{
if (flags.HasAny(OperationFlags.DoNotGenerate)) // using vanilla Dapper mode
{
if (parameters.Members.Any(s => s.IsCancellation) || IsCancellationToken(parameters.ElementType))
{
ctx.ReportDiagnostic(Diagnostic.Create(Diagnostics.CancellationNotSupported, parameters.Location));
}
}
else
{
bool first = true;
foreach(var member in parameters.Members)
{
if (member.IsCancellation)
{
if (first)
{
first = false;
}
else
{
ctx.ReportDiagnostic(Diagnostic.Create(Diagnostics.CancellationDuplicated, member.GetLocation()));
}
}
}
}
}

ValidateParameters(parameters, flags, onDiagnostic);

var args = SharedGetParametersToInclude(parameters, ref flags, sql, onDiagnostic, out var parseFlags);

ValidateSql(ctx, sqlSource, GetModeFlags(flags), SqlParameters.From(args), location);
Expand Down Expand Up @@ -850,6 +825,50 @@ enum ParameterMode
? null : new(rowCountHint, rowCountHintMember?.Member.Name, batchSize, cmdProps);
}

static void ValidateParameters(MemberMap? parameters, OperationFlags flags, Action<Diagnostic> onDiagnostic)
{
if (parameters is null) return;

var usingVanillaDapperMode = flags.HasAny(OperationFlags.DoNotGenerate); // using vanilla Dapper mode
if (usingVanillaDapperMode)
{
if (parameters.Members.Any(s => s.IsCancellation) || IsCancellationToken(parameters.ElementType))
{
onDiagnostic(Diagnostic.Create(Diagnostics.CancellationNotSupported, parameters.Location));
}
}

var isFirstCancellation = true;
foreach (var member in parameters.Members)
{
ValidateCancellationTokenParameter(member);
ValidateDbStringParameter(member);
}

void ValidateDbStringParameter(ElementMember member)
{
if (usingVanillaDapperMode)
{
// reporting ONLY in Dapper AOT
return;
}

if (member.DapperSpecialType == DapperSpecialType.DbString)
{
onDiagnostic(Diagnostic.Create(Diagnostics.MoveFromDbString, member.GetLocation()));
}
}

void ValidateCancellationTokenParameter(ElementMember member)
{
if (!usingVanillaDapperMode && member.IsCancellation)
{
if (isFirstCancellation) isFirstCancellation = false;
else onDiagnostic(Diagnostic.Create(Diagnostics.CancellationDuplicated, member.GetLocation()));
}
}
}

static void ValidateMembers(MemberMap memberMap, Action<Diagnostic> onDiagnostic)
{
if (memberMap.Members.Length == 0)
Expand Down
48 changes: 36 additions & 12 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Runtime.InteropServices.ComTypes;
using System.Text;
using System.Threading;
using static Dapper.Internal.Inspection;

namespace Dapper.CodeAnalysis;

Expand Down Expand Up @@ -423,16 +425,15 @@ internal void Generate(in GenerateState ctx)
WriteRowFactory(ctx, sb, pair.Type, pair.Index);
}


foreach (var tuple in factories)
{
WriteCommandFactory(ctx, baseCommandFactory, sb, tuple.Type, tuple.Index, tuple.Map, tuple.CacheCount, tuple.AdditionalCommandState);
}

sb.Outdent().Outdent(); // ends our generated file-scoped class and the namespace

var interceptsLocationWriter = new InterceptorsLocationAttributeWriter(sb);
interceptsLocationWriter.Write(ctx.Compilation);
var preGeneratedCodeWriter = new PreGeneratedCodeWriter(sb, ctx.Compilation);
preGeneratedCodeWriter.Write(ctx.GeneratorContext.IncludedGenerationTypes);

ctx.AddSource((ctx.Compilation.AssemblyName ?? "package") + ".generated.cs", sb.ToString());
ctx.ReportDiagnostic(Diagnostic.Create(Diagnostics.InterceptorsGenerated, null, callSiteCount, ctx.Nodes.Length, methodIndex, factories.Count(), readers.Count()));
Expand Down Expand Up @@ -490,11 +491,11 @@ private static void WriteCommandFactory(in GenerateState ctx, string baseFactory
else
{
sb.Append("public override void AddParameters(in global::Dapper.UnifiedCommand cmd, ").Append(declaredType).Append(" args)").Indent().NewLine();
WriteArgs(type, sb, WriteArgsMode.Add, map, ref flags);
WriteArgs(in ctx, type, sb, WriteArgsMode.Add, map, ref flags);
sb.Outdent().NewLine();

sb.Append("public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, ").Append(declaredType).Append(" args)").Indent().NewLine();
WriteArgs(type, sb, WriteArgsMode.Update, map, ref flags);
WriteArgs(in ctx, type, sb, WriteArgsMode.Update, map, ref flags);
sb.Outdent().NewLine();

if ((flags & (WriteArgsFlags.NeedsRowCount | WriteArgsFlags.NeedsPostProcess)) != 0)
Expand All @@ -507,11 +508,11 @@ private static void WriteCommandFactory(in GenerateState ctx, string baseFactory
sb.Append("public override void PostProcess(in global::Dapper.UnifiedCommand cmd, ").Append(declaredType).Append(" args, int rowCount)").Indent().NewLine();
if ((flags & WriteArgsFlags.NeedsPostProcess) != 0)
{
WriteArgs(type, sb, WriteArgsMode.PostProcess, map, ref flags);
WriteArgs(in ctx, type, sb, WriteArgsMode.PostProcess, map, ref flags);
}
if ((flags & WriteArgsFlags.NeedsRowCount) != 0)
{
WriteArgs(type, sb, WriteArgsMode.SetRowCount, map, ref flags);
WriteArgs(in ctx, type, sb, WriteArgsMode.SetRowCount, map, ref flags);
}
if (baseFactory != DapperBaseCommandFactory)
{
Expand All @@ -524,7 +525,7 @@ private static void WriteCommandFactory(in GenerateState ctx, string baseFactory
{
sb.Append("public override global::System.Threading.CancellationToken GetCancellationToken(").Append(declaredType).Append(" args)")
.Indent().NewLine();
WriteArgs(type, sb, WriteArgsMode.GetCancellationToken, map, ref flags);
WriteArgs(in ctx, type, sb, WriteArgsMode.GetCancellationToken, map, ref flags);
sb.Outdent().NewLine();
}
}
Expand Down Expand Up @@ -966,7 +967,7 @@ enum WriteArgsMode
GetCancellationToken
}

private static void WriteArgs(ITypeSymbol? parameterType, CodeWriter sb, WriteArgsMode mode, string map, ref WriteArgsFlags flags)
private static void WriteArgs(in GenerateState ctx, ITypeSymbol? parameterType, CodeWriter sb, WriteArgsMode mode, string map, ref WriteArgsFlags flags)
{
if (parameterType is null)
{
Expand Down Expand Up @@ -1073,8 +1074,19 @@ private static void WriteArgs(ITypeSymbol? parameterType, CodeWriter sb, WriteAr
switch (mode)
{
case WriteArgsMode.Add:
sb.Append("p = cmd.CreateParameter();").NewLine()
.Append("p.ParameterName = ").AppendVerbatimLiteral(member.DbName).Append(";").NewLine();
sb.Append("p = cmd.CreateParameter();").NewLine();
sb.Append("p.ParameterName = ").AppendVerbatimLiteral(member.DbName).Append(";").NewLine();

if (member.DapperSpecialType is DapperSpecialType.DbString)
{
ctx.GeneratorContext.IncludeGenerationType(IncludedGeneration.DbStringHelpers);

sb.Append("global::Dapper.Aot.Generated.DbStringHelpers.ConfigureDbStringDbParameter(p, ")
.Append(source).Append(".").Append(member.DbName).Append(");").NewLine();

sb.Append("ps.Add(p);").NewLine(); // dont forget to add parameter to command parameters collection
break;
}

var dbType = member.GetDbType(out _);
var size = member.TryGetValue<int>("Size");
Expand Down Expand Up @@ -1149,6 +1161,18 @@ private static void WriteArgs(ITypeSymbol? parameterType, CodeWriter sb, WriteAr
}
break;
case WriteArgsMode.Update:
if (member.DapperSpecialType is DapperSpecialType.DbString)
{
ctx.GeneratorContext.IncludeGenerationType(IncludedGeneration.DbStringHelpers);

sb.Append("global::Dapper.Aot.Generated.DbStringHelpers.ConfigureDbStringDbParameter")
.Append("(ps[").Append(parameterIndex).Append("], ")
.Append(source).Append(".").Append(member.CodeName)
.Append(");").NewLine();

break;
}

sb.Append("ps[");
if ((flags & WriteArgsFlags.NeedsTest) != 0) sb.AppendVerbatimLiteral(member.DbName);
else sb.Append(parameterIndex);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace Dapper.CodeAnalysis.Extensions
{
internal static class IncludedGenerationExtensions
{
public static bool HasAny(this IncludedGeneration value, IncludedGeneration flag) => (value & flag) != 0;
}
}
28 changes: 28 additions & 0 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/GeneratorContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
namespace Dapper.CodeAnalysis
{
/// <summary>
/// Contains data about current generation run.
/// </summary>
internal class GeneratorContext
{
/// <summary>
/// Specifies which generation types should be included in the output.
/// </summary>
public IncludedGeneration IncludedGenerationTypes { get; private set; }

public GeneratorContext()
{
// set default included generation types here
IncludedGenerationTypes = IncludedGeneration.InterceptsLocationAttribute;
}

/// <summary>
/// Adds another generation type to the list of already included types.
/// </summary>
/// <param name="anotherType">another generation type to include in the output</param>
public void IncludeGenerationType(IncludedGeneration anotherType)
{
IncludedGenerationTypes |= anotherType;
}
}
}
1 change: 1 addition & 0 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/ParseState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public GenerateState(SourceProductionContext ctx, in (Compilation Compilation, I
private readonly GenerateContextProxy? proxy;
public readonly ImmutableArray<SourceState> Nodes;
public readonly Compilation Compilation;
public readonly GeneratorContext GeneratorContext = new();

internal void ReportDiagnostic(Diagnostic diagnostic)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ void ReportDiagnosticInUsages(DiagnosticDescriptor diagnosticDescriptor)
}
});

var interceptsLocWriter = new InterceptorsLocationAttributeWriter(codeWriter);
interceptsLocWriter.Write(state.Compilation);
var preGenerator = new PreGeneratedCodeWriter(codeWriter, state.Compilation);
preGenerator.Write(IncludedGeneration.InterceptsLocationAttribute);

context.AddSource((state.Compilation.AssemblyName ?? "package") + ".generated.cs", sb.GetSourceText());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,41 @@
using Dapper.Internal;
using Dapper.CodeAnalysis.Extensions;
using Dapper.Internal;
using Microsoft.CodeAnalysis;

namespace Dapper.CodeAnalysis.Writers
{
internal struct InterceptorsLocationAttributeWriter
internal struct PreGeneratedCodeWriter
{
readonly Compilation _compilation;
readonly CodeWriter _codeWriter;

public InterceptorsLocationAttributeWriter(CodeWriter codeWriter)
public PreGeneratedCodeWriter(
CodeWriter codeWriter,
Compilation compilation)
{
_codeWriter = codeWriter;
_compilation = compilation;
}

/// <summary>
/// Writes the "InterceptsLocationAttribute" to inner <see cref="CodeWriter"/>.
/// </summary>
/// <remarks>Does so only when "InterceptsLocationAttribute" is NOT visible by <see cref="Compilation"/>.</remarks>
public void Write(Compilation compilation)
public void Write(IncludedGeneration includedGenerations)
{
var attrib = compilation.GetTypeByMetadataName("System.Runtime.CompilerServices.InterceptsLocationAttribute");
if (!IsAvailable(attrib, compilation))
if (includedGenerations.HasAny(IncludedGeneration.InterceptsLocationAttribute))
{
_codeWriter.NewLine().Append(Resources.ReadString("Dapper.InterceptsLocationAttribute.cs"));
WriteInterceptsLocationAttribute();
}

if (includedGenerations.HasAny(IncludedGeneration.DbStringHelpers))
{
_codeWriter.NewLine().Append(Resources.ReadString("Dapper.InGeneration.DapperHelpers.cs"));
}
}

void WriteInterceptsLocationAttribute()
{
var attrib = _compilation.GetTypeByMetadataName("System.Runtime.CompilerServices.InterceptsLocationAttribute");
if (!IsAvailable(attrib, _compilation))
{
_codeWriter.NewLine().Append(Resources.ReadString("Dapper.InGeneration.InterceptsLocationAttribute.cs"));
}

static bool IsAvailable(INamedTypeSymbol? type, Compilation compilation)
Expand Down
5 changes: 3 additions & 2 deletions src/Dapper.AOT.Analyzers/Dapper.AOT.Analyzers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

<ItemGroup>
<Compile Remove="AotGridReader.cs" />
<Compile Remove="InterceptsLocationAttribute.cs" />
<EmbeddedResource Include="AotGridReader.cs" />
<EmbeddedResource Include="InterceptsLocationAttribute.cs" />

<Compile Remove="InGeneration\*.cs" />
<EmbeddedResource Include="InGeneration\*.cs" />

<Compile Update="CodeAnalysis/DapperInterceptorGenerator.*.cs">
<DependentUpon>DapperInterceptorGenerator.cs</DependentUpon>
Expand Down
Loading

0 comments on commit 21b383a

Please sign in to comment.