Skip to content

Commit

Permalink
Fixed multithread issue, minor performance improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio0694 committed Sep 15, 2019
1 parent b5c1345 commit 381f99d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 44 deletions.
52 changes: 28 additions & 24 deletions src/ComputeSharp.Shaders/ShaderRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static void Run(GraphicsDevice device, int x, int y, int z, Action<Thread
/// <summary>
/// The mapping used to cache and reuse compiled shaders
/// </summary>
private static readonly Dictionary<(int Id, int ThreadsX, int ThreadsY, int ThreadsZ), (ShaderLoader Loader, ShaderBytecode Bytecode)> ShadersCache = new Dictionary<(int, int, int, int), (ShaderLoader, ShaderBytecode)>();
private static readonly Dictionary<(int Id, int ThreadsX, int ThreadsY, int ThreadsZ), (ShaderLoader, ShaderBytecode)> ShadersCache = new Dictionary<(int, int, int, int), (ShaderLoader, ShaderBytecode)>();

/// <summary>
/// Compiles and runs the input shader on a target <see cref="GraphicsDevice"/> instance, with the specified parameters
Expand All @@ -60,33 +60,37 @@ public static void Run(
Action<ThreadIds> action)
{
// Try to get the cache shader
var key = (ShaderLoader.GetHashCode(action), threadsX, threadsY, threadsZ);
if (!ShadersCache.TryGetValue(key, out var shaderData))
(ShaderLoader Loader, ShaderBytecode Bytecode) shaderData;
lock (ShadersCache)
{
// Load the input shader
ShaderLoader shaderLoader = ShaderLoader.Load(action);

// Render the loaded shader
ShaderInfo shaderInfo = new ShaderInfo
var key = (ShaderLoader.GetHashCode(action), threadsX, threadsY, threadsZ);
if (!ShadersCache.TryGetValue(key, out shaderData))
{
BuffersList = shaderLoader.BuffersList,
FieldsList = shaderLoader.FieldsList,
NumThreadsX = threadsX,
NumThreadsY = threadsY,
NumThreadsZ = threadsZ,
ThreadsIdsVariableName = shaderLoader.ThreadsIdsVariableName,
ShaderBody = shaderLoader.MethodBody,
FunctionsList = shaderLoader.FunctionsList,
LocalFunctionsList = shaderLoader.LocalFunctionsList
};
string shaderSource = ShaderRenderer.Instance.Render(shaderInfo);
// Load the input shader
ShaderLoader shaderLoader = ShaderLoader.Load(action);

// Render the loaded shader
ShaderInfo shaderInfo = new ShaderInfo
{
BuffersList = shaderLoader.BuffersList,
FieldsList = shaderLoader.FieldsList,
NumThreadsX = threadsX,
NumThreadsY = threadsY,
NumThreadsZ = threadsZ,
ThreadsIdsVariableName = shaderLoader.ThreadsIdsVariableName,
ShaderBody = shaderLoader.MethodBody,
FunctionsList = shaderLoader.FunctionsList,
LocalFunctionsList = shaderLoader.LocalFunctionsList
};
string shaderSource = ShaderRenderer.Instance.Render(shaderInfo);

// Compile the loaded shader to HLSL bytecode
ShaderBytecode shaderBytecode = ShaderCompiler.CompileShader(shaderSource);
// Compile the loaded shader to HLSL bytecode
ShaderBytecode shaderBytecode = ShaderCompiler.CompileShader(shaderSource);

// Cache for later use
shaderData = (shaderLoader, shaderBytecode);
ShadersCache.Add(key, shaderData);
// Cache for later use
shaderData = (shaderLoader, shaderBytecode);
ShadersCache.Add(key, shaderData);
}
}

// Create the root signature for the pipeline and get the pipeline state
Expand Down
32 changes: 12 additions & 20 deletions src/ComputeSharp.Shaders/Translation/MethodDecompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ internal sealed class MethodDecompiler
/// </summary>
public static MethodDecompiler Instance { get; } = new MethodDecompiler();

/// <summary>
/// The dummy object used to handle concurrent decompilation requests
/// </summary>
private readonly object Lock = new object();

/// <summary>
/// The mapping of available <see cref="CSharpDecompiler"/> instances targeting different assemblies
/// </summary>
Expand Down Expand Up @@ -141,25 +136,22 @@ private string DecompileMethodOrDeclaringType(MethodInfo methodInfo, bool method
/// <param name="semanticModel">The semantic model for the input method</param>
public void GetSyntaxTree(MethodInfo methodInfo, MethodType methodType, out MethodDeclarationSyntax rootNode, out SemanticModel semanticModel)
{
lock (Lock)
string sourceCode = methodType switch
{
string sourceCode = methodType switch
{
MethodType.Closure => GetSyntaxTreeForClosureMethod(methodInfo),
MethodType.Static => GetSyntaxTreeForStaticMethod(methodInfo),
_ => throw new ArgumentOutOfRangeException(nameof(methodType), $"Invalid method type: {methodType}")
};
MethodType.Closure => GetSyntaxTreeForClosureMethod(methodInfo),
MethodType.Static => GetSyntaxTreeForStaticMethod(methodInfo),
_ => throw new ArgumentOutOfRangeException(nameof(methodType), $"Invalid method type: {methodType}")
};

// Load the type syntax tree
SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(sourceCode);
// Load the type syntax tree
SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(sourceCode);

// Get the root node to return
rootNode = syntaxTree.GetRoot().DescendantNodes().OfType<MethodDeclarationSyntax>().First(node => node.GetLeadingTrivia().ToFullString().Contains(methodInfo.Name));
// Get the root node to return
rootNode = syntaxTree.GetRoot().DescendantNodes().OfType<MethodDeclarationSyntax>().First(node => node.GetLeadingTrivia().ToFullString().Contains(methodInfo.Name));

// Update the incremental compilation and retrieve the syntax tree for the method
_Compilation = _Compilation.AddSyntaxTrees(syntaxTree);
semanticModel = _Compilation.GetSemanticModel(syntaxTree);
}
// Update the incremental compilation and retrieve the syntax tree for the method
_Compilation = _Compilation.AddSyntaxTrees(syntaxTree);
semanticModel = _Compilation.GetSemanticModel(syntaxTree);
}

/// <summary>
Expand Down

0 comments on commit 381f99d

Please sign in to comment.