diff --git a/.github/workflows/dotnet.yml b/.github/workflows/dotnet.yml index 6c6c4ebb..0e15dd97 100644 --- a/.github/workflows/dotnet.yml +++ b/.github/workflows/dotnet.yml @@ -36,7 +36,7 @@ jobs: run: dotnet build Build.csproj --no-restore -c Release - name: Test - run: dotnet test Build.csproj --no-build --verbosity normal -c Release -f net7.0 + run: dotnet test Build.csproj --no-build --verbosity normal -c Release -f net6.0 - name: Pack if: ${{ success() && !github.base_ref }} diff --git a/src/Dapper.AOT.Analyzers/Internal/Roslyn/TypeSymbolExtensions.cs b/src/Dapper.AOT.Analyzers/Internal/Roslyn/TypeSymbolExtensions.cs index d4b68d18..1b1583ca 100644 --- a/src/Dapper.AOT.Analyzers/Internal/Roslyn/TypeSymbolExtensions.cs +++ b/src/Dapper.AOT.Analyzers/Internal/Roslyn/TypeSymbolExtensions.cs @@ -216,6 +216,11 @@ private static bool ImplementsInterface( searchedInterface = null; return false; } + if (typeSymbol.SpecialType == interfaceType || typeSymbol.OriginalDefinition?.SpecialType == interfaceType) + { + searchedInterface = typeSymbol; + return true; + } if (searchFromStart) { diff --git a/test/Dapper.AOT.Test/Interceptors/Techempower.input.cs b/test/Dapper.AOT.Test/Interceptors/Techempower.input.cs new file mode 100644 index 00000000..d2954839 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/Techempower.input.cs @@ -0,0 +1,152 @@ + +using Dapper; +using Microsoft.Data.SqlClient; +using System.Data.Common; +using System.Text; +using System.Collections.Generic; +using System; +using System.Threading.Tasks; +using System.Linq; + +[module: DapperAot] + +var connectionString = new SqlConnectionStringBuilder +{ + DataSource = ".", + InitialCatalog = "master", + IntegratedSecurity = true, + TrustServerCertificate = true, +}.ConnectionString; + +var obj = new BenchRunner(connectionString, SqlClientFactory.Instance); +obj.Create(); +await obj.LoadMultipleUpdatesRows(50); + +class BenchRunner +{ + private readonly string _connectionString; + private readonly DbProviderFactory _dbProviderFactory; + public BenchRunner(string connectionString, DbProviderFactory dbProviderFactory) + { + _dbProviderFactory = dbProviderFactory; + _connectionString = connectionString; + } + + [DapperAot] + public void Create() + { + using var conn = _dbProviderFactory.CreateConnection(); + conn.ConnectionString = _connectionString; + + try { conn.Execute("DROP TABLE world;"); } catch { } + conn.Execute("CREATE TABLE world (id int not null primary key, randomNumber int not null);"); + conn.Execute("TRUNCATE TABLE world;"); + conn.Execute("INSERT World (id, randomNumber) VALUES (@Id, @RandomNumber)", Invent(10000)); + + static IEnumerable Invent(int count) + { + var rand = GetRandom(); + return Enumerable.Range(1, 10000).Select(i => new World { Id = i, RandomNumber = rand.Next() }); + } + } + + [DapperAot(false)] // dictionary usage isn't going to work today + public async Task LoadMultipleUpdatesRows(int count) + { + count = Clamp(count, 1, 500); + + var parameters = new Dictionary(); + + using var db = _dbProviderFactory.CreateConnection(); + + db!.ConnectionString = _connectionString; + await db.OpenAsync(); + + var results = new World[count]; + for (var i = 0; i < count; i++) + { + results[i] = await ReadSingleRow(db); + } + + var rand = GetRandom(); + for (var i = 0; i < count; i++) + { + var randomNumber = rand.Next(1, 10001); + parameters[$"@Rn_{i}"] = randomNumber; + parameters[$"@Id_{i}"] = results[i].Id; + + results[i].RandomNumber = randomNumber; + } + + await db.ExecuteAsync(BatchUpdateString.Query(count), parameters); + return results; + } + + // note that this QueryFirstOrDefaultAsync is unusual, hence DAP038 + // see: https://aot.dapperlib.dev/rules/DAP038 + // + // options: + // 0. leave it as-is + // 1. use QueryFirstAsync and eat the exception if no rows + // 2 use QueryFirstOrDefaultAsync which allows `null` to be expressed + [System.Diagnostics.CodeAnalysis.SuppressMessage("Library", "DAP038:Value-type single row 'OrDefault' usage", Justification = "Retain old behaviour for baseline")] + static Task ReadSingleRow(DbConnection db) + { + return db.QueryFirstOrDefaultAsync( + "SELECT id, randomnumber FROM world WHERE id = @Id", + new { Id = GetRandom().Next(1, 10001) }); + } + + static Random GetRandom() + { +#if NET6_OR_GREATER + return Random.Shared; +#else + return new Random(); +#endif + } + + static int Clamp(int value, int min, int max) + { +#if NET6_OR_GREATER + return Math.Clamp(value, min, max); +#else + if (value < min) value = min; + if (value > max) value = max; + return value; +#endif + } +} + +public struct World +{ + public int Id { get; set; } + + public int RandomNumber { get; set; } +} + + +internal class BatchUpdateString +{ + private const int MaxBatch = 500; + + private static readonly string[] _queries = new string[MaxBatch + 1]; + + public static string Query(int batchSize) + { + if (_queries[batchSize] != null) + { + return _queries[batchSize]; + } + + var lastIndex = batchSize - 1; + + var sb = new StringBuilder(); + + sb.Append("UPDATE world SET randomNumber = temp.randomNumber FROM (VALUES "); + Enumerable.Range(0, lastIndex).ToList().ForEach(i => sb.Append($"(@Id_{i}, @Rn_{i}), ")); + sb.Append($"(@Id_{lastIndex}, @Rn_{lastIndex}) ORDER BY 1) AS temp(id, randomNumber) WHERE temp.id = world.id"); + + return _queries[batchSize] = sb.ToString(); + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/Techempower.output.cs b/test/Dapper.AOT.Test/Interceptors/Techempower.output.cs new file mode 100644 index 00000000..e2d2aa19 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/Techempower.output.cs @@ -0,0 +1,204 @@ +#nullable enable +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 41, 20)] + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 42, 14)] + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 43, 14)] + internal static int Execute0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Execute, Text + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), DefaultCommandFactory).Execute(param); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 44, 14)] + internal static int Execute1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Execute, HasParameters, Text, KnownParameters + // takes parameter: global::System.Collections.Generic.IEnumerable + // parameter map: Id RandomNumber + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).Execute((global::System.Collections.Generic.IEnumerable)param!); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 95, 19)] + internal static global::System.Threading.Tasks.Task QueryFirstOrDefaultAsync2(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, Async, TypedResult, HasParameters, SingleRow, Text, BindResultsByName, KnownParameters + // takes parameter: + // parameter map: Id + // returns data: global::World + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).QueryFirstOrDefaultAsync(param, RowFactory0.Instance); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class RowFactory0 : global::Dapper.RowFactory + { + internal static readonly RowFactory0 Instance = new(); + private RowFactory0() {} + public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span tokens, int columnOffset) + { + for (int i = 0; i < tokens.Length; i++) + { + int token = -1; + var name = reader.GetName(columnOffset); + var type = reader.GetFieldType(columnOffset); + switch (NormalizedHash(name)) + { + case 926444256U when NormalizedEquals(name, "id"): + token = type == typeof(int) ? 0 : 2; // two tokens for right-typed and type-flexible + break; + case 843736943U when NormalizedEquals(name, "randomnumber"): + token = type == typeof(int) ? 1 : 3; + break; + + } + tokens[i] = token; + columnOffset++; + + } + return null; + } + public override global::World Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan tokens, int columnOffset, object? state) + { + global::World result = new(); + foreach (var token in tokens) + { + switch (token) + { + case 0: + result.Id = reader.GetInt32(columnOffset); + break; + case 2: + result.Id = GetValue(reader, columnOffset); + break; + case 1: + result.RandomNumber = reader.GetInt32(columnOffset); + break; + case 3: + result.RandomNumber = GetValue(reader, columnOffset); + break; + + } + columnOffset++; + + } + return result; + + } + + } + + private sealed class CommandFactory0 : CommonCommandFactory + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::World args) + { + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "Id"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(args.Id); + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "RandomNumber"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(args.RandomNumber); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::World args) + { + var ps = cmd.Parameters; + ps[0].Value = AsValue(args.Id); + ps[1].Value = AsValue(args.RandomNumber); + + } + public override bool CanPrepare => true; + + } + + private sealed class CommandFactory1 : CommonCommandFactory // + { + internal static readonly CommandFactory1 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Id = default(int) }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "Id"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.Id); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Id = default(int) }); // expected shape + var ps = cmd.Parameters; + ps[0].Value = AsValue(typed.Id); + + } + public override bool CanPrepare => true; + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/Techempower.output.netfx.cs b/test/Dapper.AOT.Test/Interceptors/Techempower.output.netfx.cs new file mode 100644 index 00000000..e2d2aa19 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/Techempower.output.netfx.cs @@ -0,0 +1,204 @@ +#nullable enable +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 41, 20)] + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 42, 14)] + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 43, 14)] + internal static int Execute0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Execute, Text + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), DefaultCommandFactory).Execute(param); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 44, 14)] + internal static int Execute1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Execute, HasParameters, Text, KnownParameters + // takes parameter: global::System.Collections.Generic.IEnumerable + // parameter map: Id RandomNumber + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).Execute((global::System.Collections.Generic.IEnumerable)param!); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 95, 19)] + internal static global::System.Threading.Tasks.Task QueryFirstOrDefaultAsync2(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, Async, TypedResult, HasParameters, SingleRow, Text, BindResultsByName, KnownParameters + // takes parameter: + // parameter map: Id + // returns data: global::World + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).QueryFirstOrDefaultAsync(param, RowFactory0.Instance); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class RowFactory0 : global::Dapper.RowFactory + { + internal static readonly RowFactory0 Instance = new(); + private RowFactory0() {} + public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span tokens, int columnOffset) + { + for (int i = 0; i < tokens.Length; i++) + { + int token = -1; + var name = reader.GetName(columnOffset); + var type = reader.GetFieldType(columnOffset); + switch (NormalizedHash(name)) + { + case 926444256U when NormalizedEquals(name, "id"): + token = type == typeof(int) ? 0 : 2; // two tokens for right-typed and type-flexible + break; + case 843736943U when NormalizedEquals(name, "randomnumber"): + token = type == typeof(int) ? 1 : 3; + break; + + } + tokens[i] = token; + columnOffset++; + + } + return null; + } + public override global::World Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan tokens, int columnOffset, object? state) + { + global::World result = new(); + foreach (var token in tokens) + { + switch (token) + { + case 0: + result.Id = reader.GetInt32(columnOffset); + break; + case 2: + result.Id = GetValue(reader, columnOffset); + break; + case 1: + result.RandomNumber = reader.GetInt32(columnOffset); + break; + case 3: + result.RandomNumber = GetValue(reader, columnOffset); + break; + + } + columnOffset++; + + } + return result; + + } + + } + + private sealed class CommandFactory0 : CommonCommandFactory + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::World args) + { + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "Id"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(args.Id); + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "RandomNumber"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(args.RandomNumber); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::World args) + { + var ps = cmd.Parameters; + ps[0].Value = AsValue(args.Id); + ps[1].Value = AsValue(args.RandomNumber); + + } + public override bool CanPrepare => true; + + } + + private sealed class CommandFactory1 : CommonCommandFactory // + { + internal static readonly CommandFactory1 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Id = default(int) }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "Id"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.Id); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Id = default(int) }); // expected shape + var ps = cmd.Parameters; + ps[0].Value = AsValue(typed.Id); + + } + public override bool CanPrepare => true; + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/Techempower.output.netfx.txt b/test/Dapper.AOT.Test/Interceptors/Techempower.output.netfx.txt new file mode 100644 index 00000000..4919c011 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/Techempower.output.netfx.txt @@ -0,0 +1,4 @@ +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 5 of 5 possible call-sites using 3 interceptors, 2 commands and 1 readers diff --git a/test/Dapper.AOT.Test/Interceptors/Techempower.output.txt b/test/Dapper.AOT.Test/Interceptors/Techempower.output.txt new file mode 100644 index 00000000..4919c011 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/Techempower.output.txt @@ -0,0 +1,4 @@ +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 5 of 5 possible call-sites using 3 interceptors, 2 commands and 1 readers