From bf67f79bb8f37e0f29317880e42546743c5ffe4c Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Mon, 21 Oct 2024 16:40:11 +0100 Subject: [PATCH] Fix constant numbering in SLATE (#3808) --- firedrake/slate/slac/kernel_builder.py | 15 +++++++++++---- firedrake/slate/slac/tsfc_driver.py | 3 +-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/firedrake/slate/slac/kernel_builder.py b/firedrake/slate/slac/kernel_builder.py index d8b4b620d6..cbc9b6fed2 100644 --- a/firedrake/slate/slac/kernel_builder.py +++ b/firedrake/slate/slac/kernel_builder.py @@ -159,7 +159,11 @@ def collect_tsfc_kernel_data(self, mesh, tsfc_coefficients, tsfc_constants, wrap # Pick the constants associated with a Tensor()/TSFC kernel tsfc_constants = tuple(tsfc_constants[i] for i in kinfo.constant_numbers) - kernel_data.extend([(c, c.name) for c in wrapper_constants if c in tsfc_constants]) + kernel_data.extend([ + (constant, constant_name) + for constant, constant_name in wrapper_constants + if constant in tsfc_constants + ]) return kernel_data def loopify_tsfc_kernel_data(self, kernel_data): @@ -254,7 +258,10 @@ def collect_coefficients(self): def collect_constants(self): """ All constants of self.expression as a list """ - return self.expression.constants() + return tuple( + (constant, f"c_{i}") + for i, constant in enumerate(self.expression.constants()) + ) def initialise_terminals(self, var2terminal, coefficients): """ Initilisation of the variables in which coefficients @@ -361,9 +368,9 @@ def generate_wrapper_kernel_args(self, tensor2temp): dtype=self.tsfc_parameters["scalar_type"]) args.append(kernel_args.CoefficientKernelArg(coeff_loopy_arg)) - for constant in self.bag.constants: + for constant, constant_name in self.bag.constants: constant_loopy_arg = loopy.GlobalArg( - constant.name, + constant_name, shape=constant.dat.cdim, dtype=self.tsfc_parameters["scalar_type"] ) diff --git a/firedrake/slate/slac/tsfc_driver.py b/firedrake/slate/slac/tsfc_driver.py index 34baf10dbc..0f5fbf96d3 100644 --- a/firedrake/slate/slac/tsfc_driver.py +++ b/firedrake/slate/slac/tsfc_driver.py @@ -59,10 +59,9 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None): cxt_kernels = [] assert prefix is not None for orig_it_type, integrals in transformed_integrals.items(): - subkernel_prefix = prefix + "%s_to_" % orig_it_type form = Form(integrals) kernels = tsfc_compile(form, - subkernel_prefix, + f"{prefix}{orig_it_type}_to_", parameters=tsfc_parameters, split=False, diagonal=tensor.diagonal)