Skip to content

Commit

Permalink
support inherited members (#106)
Browse files Browse the repository at this point in the history
fix #93
  • Loading branch information
mgravell authored Dec 18, 2023
1 parent f439eb0 commit 20659c1
Show file tree
Hide file tree
Showing 6 changed files with 502 additions and 65 deletions.
139 changes: 74 additions & 65 deletions src/Dapper.AOT.Analyzers/Internal/Inspection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -802,85 +802,94 @@ internal static ImmutableArray<ElementMember> GetMembers(bool forParameters, ITy
}
else
{
var elMembers = elementType.GetMembers();
var builder = ImmutableArray.CreateBuilder<ElementMember>(elMembers.Length);
var constructorParameters = (constructor is not null) ? ParseMethodParameters(constructor) : null;
var factoryMethodParameters = (factoryMethod is not null) ? ParseMethodParameters(factoryMethod) : null;
foreach (var member in elMembers)
var tier = elementType;

var builder = ImmutableArray.CreateBuilder<ElementMember>();
while (tier is not null or IErrorTypeSymbol)
{
// instance only, must be able to access by name
if (member.IsStatic || !member.CanBeReferencedByName) continue;
var elMembers = tier.GetMembers(); // walk hierarchy model

// public or annotated only; not explicitly ignored
var dbValue = GetDapperAttribute(member, Types.DbValueAttribute);
var kind = ElementMemberKind.None;
if (GetDapperAttribute(member, Types.RowCountAttribute) is not null)
{
kind |= ElementMemberKind.RowCount;
}
if (GetDapperAttribute(member, Types.RowCountHintAttribute) is not null)
var constructorParameters = (constructor is not null) ? ParseMethodParameters(constructor) : null;
var factoryMethodParameters = (factoryMethod is not null) ? ParseMethodParameters(factoryMethod) : null;
foreach (var member in elMembers)
{
kind |= ElementMemberKind.RowCountHint;
}
// instance only, must be able to access by name
if (member.IsStatic || !member.CanBeReferencedByName) continue;

if (dbValue is null && member.DeclaredAccessibility != Accessibility.Public && kind == ElementMemberKind.None) continue;
if (TryGetAttributeValue(dbValue, "Ignore", out bool ignore) && ignore)
{
continue;
}
// public or annotated only; not explicitly ignored
var dbValue = GetDapperAttribute(member, Types.DbValueAttribute);
var kind = ElementMemberKind.None;
if (GetDapperAttribute(member, Types.RowCountAttribute) is not null)
{
kind |= ElementMemberKind.RowCount;
}
if (GetDapperAttribute(member, Types.RowCountHintAttribute) is not null)
{
kind |= ElementMemberKind.RowCountHint;
}

// field or property (not indexer)
ITypeSymbol memberType;
switch (member)
{
case IPropertySymbol { IsIndexer: false } prop:
memberType = prop.Type;
break;
case IFieldSymbol field:
memberType = field.Type;
break;
default:
if (dbValue is null && member.DeclaredAccessibility != Accessibility.Public && kind == ElementMemberKind.None) continue;
if (TryGetAttributeValue(dbValue, "Ignore", out bool ignore) && ignore)
{
continue;
}
if (memberType is null) continue;
}

int? constructorParameterOrder = constructorParameters?.TryGetValue(member.Name, out var constructorParameter) == true
? constructorParameter.Order
: null;
// field or property (not indexer)
ITypeSymbol memberType;
switch (member)
{
case IPropertySymbol { IsIndexer: false } prop:
memberType = prop.Type;
break;
case IFieldSymbol field:
memberType = field.Type;
break;
default:
continue;
}
if (memberType is null) continue;

int? factoryMethodParamOrder = factoryMethodParameters?.TryGetValue(member.Name, out var factoryMethodParam) == true
? factoryMethodParam.Order
: null;
int? constructorParameterOrder = constructorParameters?.TryGetValue(member.Name, out var constructorParameter) == true
? constructorParameter.Order
: null;

ElementMember.ElementMemberFlags flags = ElementMember.ElementMemberFlags.None;
if (CodeWriter.IsGettableInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsGettable;
if (CodeWriter.IsSettableInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsSettable;
if (CodeWriter.IsInitOnlyInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsInitOnly;
if (CodeWriter.IsRequired(member)) flags |= ElementMember.ElementMemberFlags.IsRequired;
int? factoryMethodParamOrder = factoryMethodParameters?.TryGetValue(member.Name, out var factoryMethodParam) == true
? factoryMethodParam.Order
: null;

if (forParameters)
{
// needs to be readable
if ((flags & ElementMember.ElementMemberFlags.IsGettable) == 0) continue;
}
else
{
// needs to be writable
if (constructorParameterOrder is null && factoryMethodParamOrder is null &&
(flags & (ElementMember.ElementMemberFlags.IsSettable | ElementMember.ElementMemberFlags.IsInitOnly)) == 0) continue;
}
ElementMember.ElementMemberFlags flags = ElementMember.ElementMemberFlags.None;
if (CodeWriter.IsGettableInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsGettable;
if (CodeWriter.IsSettableInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsSettable;
if (CodeWriter.IsInitOnlyInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsInitOnly;
if (CodeWriter.IsRequired(member)) flags |= ElementMember.ElementMemberFlags.IsRequired;

// see Dapper's TryStringSplit logic
if (IsCollectionType(memberType, out var innerType) && innerType is not null)
{
flags |= ElementMember.ElementMemberFlags.IsExpandable;
}
if (forParameters)
{
// needs to be readable
if ((flags & ElementMember.ElementMemberFlags.IsGettable) == 0) continue;
}
else
{
// needs to be writable
if (constructorParameterOrder is null && factoryMethodParamOrder is null &&
(flags & (ElementMember.ElementMemberFlags.IsSettable | ElementMember.ElementMemberFlags.IsInitOnly)) == 0) continue;
}

// see Dapper's TryStringSplit logic
if (IsCollectionType(memberType, out var innerType) && innerType is not null)
{
flags |= ElementMember.ElementMemberFlags.IsExpandable;
}

var columnAttributeData = ParseColumnAttributeData(member);

var columnAttributeData = ParseColumnAttributeData(member);
// all good, then!
builder.Add(new(member, dbValue, columnAttributeData, kind, flags, constructorParameterOrder, factoryMethodParamOrder));
}

// all good, then!
builder.Add(new(member, dbValue, columnAttributeData, kind, flags, constructorParameterOrder, factoryMethodParamOrder));
tier = tier.BaseType;
}

return builder.ToImmutable();
}

Expand Down
24 changes: 24 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/InheritedMembers.input.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using Dapper;
using System.Data.Common;

[DapperAot]
public static class Foo
{
static void SomeCode(DbConnection connection, string bar)
{
var args = new { Foo = 12, bar };
// these should support Id+Name
var obj = connection.QueryFirst<Entity1>("def", args);
connection.Execute("ghi @Id, @Name", obj);
}
}

public abstract class EntityBase
{
public long Id { get; set; }
}

public class Entity1 : EntityBase
{
public string Name { get; set; }
}
198 changes: 198 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/InheritedMembers.output.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
#nullable enable
namespace Dapper.AOT // interceptors must be in a known namespace
{
file static class DapperGeneratedInterceptors
{
[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\InheritedMembers.input.cs", 11, 30)]
internal static global::Entity1 QueryFirst0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Query, TypedResult, HasParameters, SingleRow, StoredProcedure, AtLeastOne, BindResultsByName, KnownParameters
// takes parameter: <anonymous type: int Foo, string bar>
// parameter map: bar Foo
// returns data: global::Entity1
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.StoredProcedure);
global::System.Diagnostics.Debug.Assert(param is not null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.StoredProcedure, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).QueryFirst(param, RowFactory0.Instance);

}

[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\InheritedMembers.input.cs", 12, 20)]
internal static int Execute1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Execute, HasParameters, Text, KnownParameters
// takes parameter: global::Entity1
// parameter map: Id Name
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
global::System.Diagnostics.Debug.Assert(param is not null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).Execute((global::Entity1)param!);

}

private class CommonCommandFactory<T> : global::Dapper.CommandFactory<T>
{
public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args)
{
var cmd = base.GetCommand(connection, sql, commandType, args);
// apply special per-provider command initialization logic for OracleCommand
if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0)
{
cmd0.BindByName = true;
cmd0.InitialLONGFetchSize = -1;

}
return cmd;
}

}

private static readonly CommonCommandFactory<object?> DefaultCommandFactory = new();

private sealed class RowFactory0 : global::Dapper.RowFactory<global::Entity1>
{
internal static readonly RowFactory0 Instance = new();
private RowFactory0() {}
public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span<int> tokens, int columnOffset)
{
for (int i = 0; i < tokens.Length; i++)
{
int token = -1;
var name = reader.GetName(columnOffset);
var type = reader.GetFieldType(columnOffset);
switch (NormalizedHash(name))
{
case 2369371622U when NormalizedEquals(name, "name"):
token = type == typeof(string) ? 0 : 2; // two tokens for right-typed and type-flexible
break;
case 926444256U when NormalizedEquals(name, "id"):
token = type == typeof(long) ? 1 : 3;
break;

}
tokens[i] = token;
columnOffset++;

}
return null;
}
public override global::Entity1 Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan<int> tokens, int columnOffset, object? state)
{
global::Entity1 result = new();
foreach (var token in tokens)
{
switch (token)
{
case 0:
result.Name = reader.IsDBNull(columnOffset) ? (string?)null : reader.GetString(columnOffset);
break;
case 2:
result.Name = reader.IsDBNull(columnOffset) ? (string?)null : GetValue<string>(reader, columnOffset);
break;
case 1:
result.Id = reader.GetInt64(columnOffset);
break;
case 3:
result.Id = GetValue<long>(reader, columnOffset);
break;

}
columnOffset++;

}
return result;

}

}

private sealed class CommandFactory0 : CommonCommandFactory<object?> // <anonymous type: int Foo, string bar>
{
internal static readonly CommandFactory0 Instance = new();
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { Foo = default(int), bar = default(string)! }); // expected shape
var ps = cmd.Parameters;
global::System.Data.Common.DbParameter p;
p = cmd.CreateParameter();
p.ParameterName = "Foo";
p.DbType = global::System.Data.DbType.Int32;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(typed.Foo);
ps.Add(p);

p = cmd.CreateParameter();
p.ParameterName = "bar";
p.DbType = global::System.Data.DbType.String;
p.Direction = global::System.Data.ParameterDirection.Input;
SetValueWithDefaultSize(p, typed.bar);
ps.Add(p);

}
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { Foo = default(int), bar = default(string)! }); // expected shape
var ps = cmd.Parameters;
ps[0].Value = AsValue(typed.Foo);
ps[1].Value = AsValue(typed.bar);

}
public override bool CanPrepare => true;

}

private sealed class CommandFactory1 : CommonCommandFactory<global::Entity1>
{
internal static readonly CommandFactory1 Instance = new();
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::Entity1 args)
{
var ps = cmd.Parameters;
global::System.Data.Common.DbParameter p;
p = cmd.CreateParameter();
p.ParameterName = "Name";
p.DbType = global::System.Data.DbType.String;
p.Direction = global::System.Data.ParameterDirection.Input;
SetValueWithDefaultSize(p, args.Name);
ps.Add(p);

p = cmd.CreateParameter();
p.ParameterName = "Id";
p.DbType = global::System.Data.DbType.Int64;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(args.Id);
ps.Add(p);

}
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::Entity1 args)
{
var ps = cmd.Parameters;
ps[0].Value = AsValue(args.Name);
ps[1].Value = AsValue(args.Id);

}
public override bool CanPrepare => true;

}


}
}
namespace System.Runtime.CompilerServices
{
// this type is needed by the compiler to implement interceptors - it doesn't need to
// come from the runtime itself, though

[global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate
[global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)]
sealed file class InterceptsLocationAttribute : global::System.Attribute
{
public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber)
{
_ = path;
_ = lineNumber;
_ = columnNumber;
}
}
}
Loading

0 comments on commit 20659c1

Please sign in to comment.