Skip to content

Commit

Permalink
Match explicit entry points for DX12 shaders
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio0694 committed Sep 21, 2024
1 parent 57f0a9b commit 057ffa1
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,26 @@ namespace ComputeSharp.SourceGeneration.Extensions;
/// </summary>
internal static class ITypeSymbolExtensions
{
/// <summary>
/// Gets the method of this symbol that have a particular name.
/// </summary>
/// <param name="symbol">The input <see cref="ITypeSymbol"/> instance to check.</param>
/// <param name="name">The name of the method to find.</param>
/// <returns>The target method, if present.</returns>
public static IMethodSymbol? GetMethod(this ITypeSymbol symbol, string name)
{
foreach (ISymbol memberSymbol in symbol.GetMembers(name))
{
if (memberSymbol is IMethodSymbol methodSymbol &&
memberSymbol.Name == name)
{
return methodSymbol;
}
}

return null;
}

/// <summary>
/// Checks whether or not a given type symbol has a specified fully qualified metadata name.
/// </summary>
Expand All @@ -28,7 +48,7 @@ public static bool HasFullyQualifiedMetadataName(this ITypeSymbol symbol, string
/// Checks whether or not a given <see cref="ITypeSymbol"/> implements an interface of a specified type.
/// </summary>
/// <param name="typeSymbol">The target <see cref="ITypeSymbol"/> instance to check.</param>
/// <param name="interfaceSymbol">The <see cref="ITypeSymbol"/> instane to check for inheritance from.</param>
/// <param name="interfaceSymbol">The <see cref="ITypeSymbol"/> instance to check for inheritance from.</param>
/// <returns>Whether or not <paramref name="typeSymbol"/> has an interface of type <paramref name="interfaceSymbol"/>.</returns>
public static bool HasInterfaceWithType(this ITypeSymbol typeSymbol, ITypeSymbol interfaceSymbol)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;

namespace ComputeSharp.SourceGenerators;
Expand All @@ -10,9 +11,14 @@ partial class ComputeShaderDescriptorGenerator
/// </summary>
/// <param name="typeSymbol">The input <see cref="INamedTypeSymbol"/> instance to check.</param>
/// <param name="compilation">The <see cref="Compilation"/> instance currently in use.</param>
/// <param name="shaderInterfaceType">The (constructed) shader interface type implemented by the shader type.</param>
/// <param name="isPixelShaderLike">Whether <paramref name="typeSymbol"/> is a "pixel shader like" type.</param>
/// <returns>Whether <paramref name="typeSymbol"/> is a compute shader type at all.</returns>
private static bool TryGetIsPixelShaderLike(INamedTypeSymbol typeSymbol, Compilation compilation, out bool isPixelShaderLike)
private static bool TryGetIsPixelShaderLike(
INamedTypeSymbol typeSymbol,
Compilation compilation,
[NotNullWhen(true)] out INamedTypeSymbol? shaderInterfaceType,
out bool isPixelShaderLike)
{
INamedTypeSymbol computeShaderSymbol = compilation.GetTypeByMetadataName("ComputeSharp.IComputeShader")!;
INamedTypeSymbol pixelShaderSymbol = compilation.GetTypeByMetadataName("ComputeSharp.IComputeShader`1")!;
Expand All @@ -21,18 +27,21 @@ private static bool TryGetIsPixelShaderLike(INamedTypeSymbol typeSymbol, Compila
{
if (SymbolEqualityComparer.Default.Equals(interfaceSymbol, computeShaderSymbol))
{
shaderInterfaceType = interfaceSymbol;
isPixelShaderLike = false;

return true;
}
else if (SymbolEqualityComparer.Default.Equals(interfaceSymbol.ConstructedFrom, pixelShaderSymbol))
{
shaderInterfaceType = interfaceSymbol;
isPixelShaderLike = true;

return true;
}
}

shaderInterfaceType = null;
isPixelShaderLike = false;

return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ internal static partial class HlslSource
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="compilation">The input <see cref="Compilation"/> object currently in use.</param>
/// <param name="structDeclarationSymbol">The <see cref="INamedTypeSymbol"/> for the shader type.</param>
/// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
/// <param name="isPixelShaderLike">Whether <paramref name="structDeclarationSymbol"/> is a "pixel shader like" type.</param>
/// <param name="threadsX">The thread ids value for the X axis.</param>
/// <param name="threadsY">The thread ids value for the Y axis.</param>
/// <param name="threadsZ">The thread ids value for the Z axis.</param>
Expand All @@ -42,6 +44,8 @@ public static void GetInfo(
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
Compilation compilation,
INamedTypeSymbol structDeclarationSymbol,
INamedTypeSymbol shaderInterfaceType,
bool isPixelShaderLike,
int threadsX,
int threadsY,
int threadsZ,
Expand All @@ -53,6 +57,8 @@ public static void GetInfo(
// Detect any invalid properties
HlslDefinitionsSyntaxProcessor.DetectAndReportInvalidPropertyDeclarations(diagnostics, structDeclarationSymbol);

token.ThrowIfCancellationRequested();

// We need to sets to track all discovered custom types and static methods
HashSet<INamedTypeSymbol> discoveredTypes = new(SymbolEqualityComparer.Default);
Dictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods = new(SymbolEqualityComparer.Default);
Expand All @@ -62,9 +68,8 @@ public static void GetInfo(
Dictionary<IFieldSymbol, HlslStaticField> staticFieldDefinitions = new(SymbolEqualityComparer.Default);

// Setup the semantic model and basic properties
INamedTypeSymbol? pixelShaderSymbol = structDeclarationSymbol.AllInterfaces.FirstOrDefault(static interfaceSymbol => interfaceSymbol is { IsGenericType: true, Name: "IComputeShader" });
bool isComputeShader = pixelShaderSymbol is null;
string? implicitTextureType = isComputeShader ? null : HlslKnownTypes.GetMappedNameForPixelShaderType(pixelShaderSymbol!);
bool isComputeShader = !isPixelShaderLike;
string? implicitTextureType = HlslKnownTypes.GetMappedNameForPixelShaderType(shaderInterfaceType);

token.ThrowIfCancellationRequested();

Expand All @@ -90,6 +95,7 @@ public static void GetInfo(
(string entryPoint, ImmutableArray<HlslMethod> processedMethods, isSamplerUsed) = GetProcessedMethods(
diagnostics,
structDeclarationSymbol,
shaderInterfaceType,
semanticModelProvider,
discoveredTypes,
staticMethods,
Expand Down Expand Up @@ -360,6 +366,7 @@ private static ImmutableArray<HlslSharedBuffer> GetSharedBuffers(
/// </summary>
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
/// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
/// <param name="semanticModel">The <see cref="SemanticModelProvider"/> instance for the type to process.</param>
/// <param name="discoveredTypes">The collection of currently discovered types.</param>
/// <param name="staticMethods">The set of discovered and processed static methods.</param>
Expand All @@ -373,6 +380,7 @@ private static ImmutableArray<HlslSharedBuffer> GetSharedBuffers(
private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSamplerUser) GetProcessedMethods(
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
INamedTypeSymbol structDeclarationSymbol,
INamedTypeSymbol shaderInterfaceType,
SemanticModelProvider semanticModel,
ICollection<INamedTypeSymbol> discoveredTypes,
IDictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods,
Expand All @@ -385,6 +393,7 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSa
{
using ImmutableArrayBuilder<HlslMethod> methods = new();

IMethodSymbol entryPointInterfaceMethod = shaderInterfaceType.GetMethod("Execute")!;
string? entryPoint = null;
bool isSamplerUsed = false;

Expand All @@ -396,22 +405,17 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSa
continue;
}

// Ensure that we have accessible source information
if (!methodSymbol.TryGetSyntaxNode(token, out MethodDeclarationSyntax? methodDeclaration))
{
continue;
}

bool isShaderEntryPoint =
(isComputeShader &&
methodSymbol.Name == "Execute" &&
methodSymbol.ReturnsVoid &&
methodSymbol.TypeParameters.Length == 0 &&
methodSymbol.Parameters.Length == 0) ||
(!isComputeShader &&
methodSymbol.Name == "Execute" &&
methodSymbol.ReturnType is not null && // TODO: match for pixel type
methodSymbol.TypeParameters.Length == 0 &&
methodSymbol.Parameters.Length == 0);
// Check whether the current method is the entry point (ie. it's implementing 'Execute'). We use
// 'FindImplementationForInterfaceMember' to handle explicit interface implementations as well.
bool isShaderEntryPoint = SymbolEqualityComparer.Default.Equals(
structDeclarationSymbol.FindImplementationForInterfaceMember(entryPointInterfaceMethod),
methodSymbol);

// Except for the entry point, ignore explicit interface implementations
if (!isShaderEntryPoint && !methodSymbol.ExplicitInterfaceImplementations.IsDefaultOrEmpty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
}
// Check whether type is a compute shader, and if so, if it's pixel shader like
if (!TryGetIsPixelShaderLike(typeSymbol, context.SemanticModel.Compilation, out bool isPixelShaderLike))
if (!TryGetIsPixelShaderLike(
typeSymbol,
context.SemanticModel.Compilation,
out INamedTypeSymbol? shaderInterfaceType,
out bool isPixelShaderLike))
{
return default;
}
Expand Down Expand Up @@ -91,6 +95,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
diagnostics,
context.SemanticModel.Compilation,
typeSymbol,
shaderInterfaceType,
isPixelShaderLike,
threadsX,
threadsY,
threadsZ,
Expand Down
13 changes: 10 additions & 3 deletions src/ComputeSharp.SourceGenerators/Mappings/HlslKnownTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,18 @@ public static partial string GetMappedName(INamedTypeSymbol typeSymbol)
/// <summary>
/// Gets the mapped HLSL-compatible type name for the output texture of a pixel shader.
/// </summary>
/// <param name="typeSymbol">The pixel shader type to map.</param>
/// <param name="typeSymbol">The shader type to map.</param>
/// <returns>The HLSL-compatible type name that can be used in an HLSL shader.</returns>
public static string GetMappedNameForPixelShaderType(INamedTypeSymbol typeSymbol)
public static string? GetMappedNameForPixelShaderType(INamedTypeSymbol typeSymbol)
{
string genericArgumentName = ((INamedTypeSymbol)typeSymbol.TypeArguments.First()).GetFullyQualifiedMetadataName();
// If the shader type is not a pixel shader type (ie. it has a type argument), stop here.
// At this point the input is guaranteed to either be 'IComputeShader' or 'IComputeShader<TPixel>'.
if (typeSymbol.TypeArguments is not [INamedTypeSymbol pixelShaderType])
{
return null;
}

string genericArgumentName = pixelShaderType.GetFullyQualifiedMetadataName();

// If the current type is a custom type, format it as needed
if (!KnownHlslTypeMetadataNames.TryGetValue(genericArgumentName, out string? mappedElementType))
Expand Down

0 comments on commit 057ffa1

Please sign in to comment.