Skip to content

Commit

Permalink
Fix detection of raw IEnumerableT etc (#72)
Browse files Browse the repository at this point in the history
* ImplementsInterface should detect raw IEnumerable<T> etc

* CI
  • Loading branch information
mgravell authored Nov 15, 2023
1 parent f344fd0 commit 82fce34
Show file tree
Hide file tree
Showing 7 changed files with 574 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/dotnet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: dotnet build Build.csproj --no-restore -c Release

- name: Test
run: dotnet test Build.csproj --no-build --verbosity normal -c Release -f net7.0
run: dotnet test Build.csproj --no-build --verbosity normal -c Release -f net6.0

- name: Pack
if: ${{ success() && !github.base_ref }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ private static bool ImplementsInterface(
searchedInterface = null;
return false;
}
if (typeSymbol.SpecialType == interfaceType || typeSymbol.OriginalDefinition?.SpecialType == interfaceType)
{
searchedInterface = typeSymbol;
return true;
}

if (searchFromStart)
{
Expand Down
152 changes: 152 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/Techempower.input.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@

using Dapper;
using Microsoft.Data.SqlClient;
using System.Data.Common;
using System.Text;
using System.Collections.Generic;
using System;
using System.Threading.Tasks;
using System.Linq;

[module: DapperAot]

var connectionString = new SqlConnectionStringBuilder
{
DataSource = ".",
InitialCatalog = "master",
IntegratedSecurity = true,
TrustServerCertificate = true,
}.ConnectionString;

var obj = new BenchRunner(connectionString, SqlClientFactory.Instance);
obj.Create();
await obj.LoadMultipleUpdatesRows(50);

class BenchRunner
{
private readonly string _connectionString;
private readonly DbProviderFactory _dbProviderFactory;
public BenchRunner(string connectionString, DbProviderFactory dbProviderFactory)
{
_dbProviderFactory = dbProviderFactory;
_connectionString = connectionString;
}

[DapperAot]
public void Create()
{
using var conn = _dbProviderFactory.CreateConnection();
conn.ConnectionString = _connectionString;

try { conn.Execute("DROP TABLE world;"); } catch { }
conn.Execute("CREATE TABLE world (id int not null primary key, randomNumber int not null);");
conn.Execute("TRUNCATE TABLE world;");
conn.Execute("INSERT World (id, randomNumber) VALUES (@Id, @RandomNumber)", Invent(10000));

static IEnumerable<World> Invent(int count)
{
var rand = GetRandom();
return Enumerable.Range(1, 10000).Select(i => new World { Id = i, RandomNumber = rand.Next() });
}
}

[DapperAot(false)] // dictionary usage isn't going to work today
public async Task<World[]> LoadMultipleUpdatesRows(int count)
{
count = Clamp(count, 1, 500);

var parameters = new Dictionary<string, object>();

using var db = _dbProviderFactory.CreateConnection();

db!.ConnectionString = _connectionString;
await db.OpenAsync();

var results = new World[count];
for (var i = 0; i < count; i++)
{
results[i] = await ReadSingleRow(db);
}

var rand = GetRandom();
for (var i = 0; i < count; i++)
{
var randomNumber = rand.Next(1, 10001);
parameters[$"@Rn_{i}"] = randomNumber;
parameters[$"@Id_{i}"] = results[i].Id;

results[i].RandomNumber = randomNumber;
}

await db.ExecuteAsync(BatchUpdateString.Query(count), parameters);
return results;
}

// note that this QueryFirstOrDefaultAsync<struct> is unusual, hence DAP038
// see: https://aot.dapperlib.dev/rules/DAP038
//
// options:
// 0. leave it as-is
// 1. use QueryFirstAsync and eat the exception if no rows
// 2 use QueryFirstOrDefaultAsync<World?> which allows `null` to be expressed
[System.Diagnostics.CodeAnalysis.SuppressMessage("Library", "DAP038:Value-type single row 'OrDefault' usage", Justification = "Retain old behaviour for baseline")]
static Task<World> ReadSingleRow(DbConnection db)
{
return db.QueryFirstOrDefaultAsync<World>(
"SELECT id, randomnumber FROM world WHERE id = @Id",
new { Id = GetRandom().Next(1, 10001) });
}

static Random GetRandom()
{
#if NET6_OR_GREATER
return Random.Shared;
#else
return new Random();
#endif
}

static int Clamp(int value, int min, int max)
{
#if NET6_OR_GREATER
return Math.Clamp(value, min, max);
#else
if (value < min) value = min;
if (value > max) value = max;
return value;
#endif
}
}

public struct World
{
public int Id { get; set; }

public int RandomNumber { get; set; }
}


internal class BatchUpdateString
{
private const int MaxBatch = 500;

private static readonly string[] _queries = new string[MaxBatch + 1];

public static string Query(int batchSize)
{
if (_queries[batchSize] != null)
{
return _queries[batchSize];
}

var lastIndex = batchSize - 1;

var sb = new StringBuilder();

sb.Append("UPDATE world SET randomNumber = temp.randomNumber FROM (VALUES ");
Enumerable.Range(0, lastIndex).ToList().ForEach(i => sb.Append($"(@Id_{i}, @Rn_{i}), "));
sb.Append($"(@Id_{lastIndex}, @Rn_{lastIndex}) ORDER BY 1) AS temp(id, randomNumber) WHERE temp.id = world.id");

return _queries[batchSize] = sb.ToString();
}
}
204 changes: 204 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/Techempower.output.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
#nullable enable
namespace Dapper.AOT // interceptors must be in a known namespace
{
file static class DapperGeneratedInterceptors
{
[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 41, 20)]
[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 42, 14)]
[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 43, 14)]
internal static int Execute0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Execute, Text
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
global::System.Diagnostics.Debug.Assert(param is null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), DefaultCommandFactory).Execute(param);

}

[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 44, 14)]
internal static int Execute1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Execute, HasParameters, Text, KnownParameters
// takes parameter: global::System.Collections.Generic.IEnumerable<global::World>
// parameter map: Id RandomNumber
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
global::System.Diagnostics.Debug.Assert(param is not null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).Execute((global::System.Collections.Generic.IEnumerable<global::World>)param!);

}

[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\Techempower.input.cs", 95, 19)]
internal static global::System.Threading.Tasks.Task<global::World> QueryFirstOrDefaultAsync2(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Query, Async, TypedResult, HasParameters, SingleRow, Text, BindResultsByName, KnownParameters
// takes parameter: <anonymous type: int Id>
// parameter map: Id
// returns data: global::World
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
global::System.Diagnostics.Debug.Assert(param is not null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).QueryFirstOrDefaultAsync(param, RowFactory0.Instance);

}

private class CommonCommandFactory<T> : global::Dapper.CommandFactory<T>
{
public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args)
{
var cmd = base.GetCommand(connection, sql, commandType, args);
// apply special per-provider command initialization logic for OracleCommand
if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0)
{
cmd0.BindByName = true;
cmd0.InitialLONGFetchSize = -1;

}
return cmd;
}

}

private static readonly CommonCommandFactory<object?> DefaultCommandFactory = new();

private sealed class RowFactory0 : global::Dapper.RowFactory<global::World>
{
internal static readonly RowFactory0 Instance = new();
private RowFactory0() {}
public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span<int> tokens, int columnOffset)
{
for (int i = 0; i < tokens.Length; i++)
{
int token = -1;
var name = reader.GetName(columnOffset);
var type = reader.GetFieldType(columnOffset);
switch (NormalizedHash(name))
{
case 926444256U when NormalizedEquals(name, "id"):
token = type == typeof(int) ? 0 : 2; // two tokens for right-typed and type-flexible
break;
case 843736943U when NormalizedEquals(name, "randomnumber"):
token = type == typeof(int) ? 1 : 3;
break;

}
tokens[i] = token;
columnOffset++;

}
return null;
}
public override global::World Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan<int> tokens, int columnOffset, object? state)
{
global::World result = new();
foreach (var token in tokens)
{
switch (token)
{
case 0:
result.Id = reader.GetInt32(columnOffset);
break;
case 2:
result.Id = GetValue<int>(reader, columnOffset);
break;
case 1:
result.RandomNumber = reader.GetInt32(columnOffset);
break;
case 3:
result.RandomNumber = GetValue<int>(reader, columnOffset);
break;

}
columnOffset++;

}
return result;

}

}

private sealed class CommandFactory0 : CommonCommandFactory<global::World>
{
internal static readonly CommandFactory0 Instance = new();
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::World args)
{
var ps = cmd.Parameters;
global::System.Data.Common.DbParameter p;
p = cmd.CreateParameter();
p.ParameterName = "Id";
p.DbType = global::System.Data.DbType.Int32;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(args.Id);
ps.Add(p);

p = cmd.CreateParameter();
p.ParameterName = "RandomNumber";
p.DbType = global::System.Data.DbType.Int32;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(args.RandomNumber);
ps.Add(p);

}
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::World args)
{
var ps = cmd.Parameters;
ps[0].Value = AsValue(args.Id);
ps[1].Value = AsValue(args.RandomNumber);

}
public override bool CanPrepare => true;

}

private sealed class CommandFactory1 : CommonCommandFactory<object?> // <anonymous type: int Id>
{
internal static readonly CommandFactory1 Instance = new();
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { Id = default(int) }); // expected shape
var ps = cmd.Parameters;
global::System.Data.Common.DbParameter p;
p = cmd.CreateParameter();
p.ParameterName = "Id";
p.DbType = global::System.Data.DbType.Int32;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(typed.Id);
ps.Add(p);

}
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { Id = default(int) }); // expected shape
var ps = cmd.Parameters;
ps[0].Value = AsValue(typed.Id);

}
public override bool CanPrepare => true;

}


}
}
namespace System.Runtime.CompilerServices
{
// this type is needed by the compiler to implement interceptors - it doesn't need to
// come from the runtime itself, though

[global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate
[global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)]
sealed file class InterceptsLocationAttribute : global::System.Attribute
{
public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber)
{
_ = path;
_ = lineNumber;
_ = columnNumber;
}
}
}
Loading

0 comments on commit 82fce34

Please sign in to comment.