diff --git a/tsfc/driver.py b/tsfc/driver.py index c55c8352..7ef5b8e1 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -34,19 +34,16 @@ class TSFCFormData(object): r"""Mimic `ufl.FormData`. :arg form_data_tuple: A tuple of `ufl.FormData`s. + :arg extraarg_tuple: A tuple of extra `ufl.Argument`s + corresponding to form_data_tuple. These extra + arguments are eventually replaced by the user with + the associated functions in function_tuple after + compiling UFL but before compiling gem. These + arguments thus do not contribute to the rank of the form. + :arg function_tuple: A tuple of functions corresponding + to extraarg_tuple. :arg original_form: The form from which forms for `ufl.Formdata`s were extracted. - :arg form_data_extraarg_map: A map from `ufl.FormData`s to - extra `ufl.Argument`s: the user can apply arbitrary - linear transformations to these `ufl.Argument`s and - replace them with corresponding functions stored in - `form_data_function_map`. This must happen after - compiling UFL but before compiling gem. These - `ufl.Arguments` thus do not contribute to the rank - of the form. - :arg form_data_function_map: A map from `ufl.FormData`s to - functions corresponding to the face `ufl.Argument`s in - `form_data_extraarg_map`. :diagonal: A flag for diagonal matrix assembly. This class mimics `ufl.FormData`, but is to contain minimum @@ -69,8 +66,8 @@ class TSFCFormData(object): +--- form_N ---- gem_N' ---- gem_N ---+ After preprocessing `ufl.FormData`s here: + * Only essential information about the `ufl.FormData`s is retained. * TSFC can forget `ufl.FormData.original_form`, - * TSFC can forget `ufl.IntegralData.enabled_coefficients`, * `KernelBuilder`s only need to deal with raw `ufl.Coefficient`s. Illustration of the structures. @@ -91,68 +88,57 @@ class TSFCFormData(object): |____0___||____1___|_ _|____M___| ||________||________| |________|| |_____________________________________| """ - def __init__(self, form_data_tuple, original_form, form_data_extraarg_map, form_data_function_map, diagonal): + def __init__(self, form_data_tuple, extraarg_tuple, function_tuple, original_form, diagonal): arguments = set() - for fd in form_data_tuple: + for fd, extraarg in zip(form_data_tuple, extraarg_tuple): args = [] for arg in fd.preprocessed_form.arguments(): - if arg not in form_data_extraarg_map[fd]: + if arg not in extraarg: args.append(arg) arguments.update((tuple(args), )) if len(arguments) != 1: raise ValueError("Found inconsistent sets of arguments in `FormData`s.") self.arguments, = tuple(arguments) - # Gathere all coefficients. + # Gather all coefficients. # If a form contains extra arguments, those will be replaced by corresponding functions # after compiling UFL, so these functions must be included here, too. reduced_coefficients_set = set(c for fd in form_data_tuple for c in fd.reduced_coefficients) - for _, val in form_data_function_map.items(): - reduced_coefficients_set.update(val) + reduced_coefficients_set.update(chain(*function_tuple)) reduced_coefficients = sorted(reduced_coefficients_set, key=lambda c: c.count()) - if len(form_data_tuple) == 1: - self.reduced_coefficients = form_data_tuple[0].reduced_coefficients - self.original_coefficient_positions = form_data_tuple[0].original_coefficient_positions - self.function_replace_map = form_data_tuple[0].function_replace_map - else: - # Reconstruct `ufl.Coefficinet`s with count starting at 0. - function_replace_map = {} - for i, func in enumerate(reduced_coefficients): - for fd in form_data_tuple: - if func in fd.function_replace_map: - coeff = fd.function_replace_map[func] - new_coeff = Coefficient(coeff.ufl_function_space(), count=i) - function_replace_map[func] = new_coeff - break - else: - ufl_function_space = FunctionSpace(func.ufl_domain(), func.ufl_element()) - new_coeff = Coefficient(ufl_function_space, count=i) + # Reconstruct `ufl.Coefficinet`s with count starting at 0. + function_replace_map = {} + for i, func in enumerate(reduced_coefficients): + for fd in form_data_tuple: + if func in fd.function_replace_map: + coeff = fd.function_replace_map[func] + new_coeff = Coefficient(coeff.ufl_function_space(), count=i) function_replace_map[func] = new_coeff - self.reduced_coefficients = reduced_coefficients - self.original_coefficient_positions = [i for i, f in enumerate(original_form.coefficients()) - if f in self.reduced_coefficients] - self.function_replace_map = function_replace_map + break + else: + ufl_function_space = FunctionSpace(func.ufl_domain(), func.ufl_element()) + new_coeff = Coefficient(ufl_function_space, count=i) + function_replace_map[func] = new_coeff + self.reduced_coefficients = reduced_coefficients + self.original_coefficient_positions = [i for i, f in enumerate(original_form.coefficients()) + if f in self.reduced_coefficients] + self.function_replace_map = function_replace_map # Translate `ufl.IntegralData`s -> `TSFCIntegralData`. - intg_data_dict = {} - form_data_dict = {} - for form_data in form_data_tuple: + intg_data_info_dict = {} + for form_data_index, form_data in enumerate(form_data_tuple): for intg_data in form_data.integral_data: domain = intg_data.domain integral_type = intg_data.integral_type subdomain_id = intg_data.subdomain_id key = (domain, integral_type, subdomain_id) - # Add intg_data. - intg_data_dict.setdefault(key, []).append(intg_data) - # Remember which form_data this intg_data came from. - form_data_dict.setdefault(key, []).append(form_data) + # Add (intg_data, form_data, form_data_index). + intg_data_info_dict.setdefault(key, []).append((intg_data, form_data, form_data_index)) integral_data_list = [] - for key in intg_data_dict: - intg_data_list = intg_data_dict[key] - form_data_list = form_data_dict[key] + for key, intg_data_info in intg_data_info_dict.items(): domain, _, _ = key domain_number = original_form.domain_numbering()[domain] - integral_data_list.append(TSFCIntegralData(key, intg_data_list, form_data_list, - self, domain_number, form_data_function_map)) + integral_data_list.append(TSFCIntegralData(key, intg_data_info, + self, domain_number, function_tuple)) self.integral_data = tuple(integral_data_list) @@ -160,26 +146,25 @@ class TSFCIntegralData(object): r"""Mimics `ufl.IntegralData`. :arg integral_data_key: (domain, integral_type, subdomain_id) tuple. - :arg integral_data_list: A list of `ufl.IntegralData`. - :arg form_data_list: A list of `ufl.FormData`. - :arg tsfc_form_data: The `TSFCFormData` that will contain this + :arg integral_data_info: A tuple of the lists of integral_data, + form_data, and form_data_index. + :arg tsfc_form_data: The `TSFCFormData` that is to contain this `TSFCIntegralData` object. :arg domain_number: The domain number associated with `domain`. - :arg form_data_function_map: A map from `ufl.FormData`s to functions. + :arg function_tuple: A tuple of functions. - This class mimics `ufl.FormData`, but: - * extracts information required by TSFC. - * preprocesses integrals so that `KernelBuilder`s only - need to deal with raw `ufl.Coefficient`s. + After preprocessing `ufl.IntegralData`s here: + * Only essential information about the `ufl.IntegralData`s is retained. + * TSFC can forget `ufl.IntegralData.enabled_coefficients`, """ - def __init__(self, integral_data_key, integral_data_list, form_data_list, tsfc_form_data, domain_number, form_data_function_map): + def __init__(self, integral_data_key, intg_data_info, tsfc_form_data, domain_number, function_tuple): self.domain, self.integral_type, self.subdomain_id = integral_data_key self.domain_number = domain_number # Gather/preprocess integrals. integrals = [] - _integral_to_form_data_map = {} + _integral_index_to_form_data_index = [] functions = set() - for intg_data, form_data in zip(integral_data_list, form_data_list): + for intg_data, form_data, form_data_index in intg_data_info: for integral in intg_data.integrals: integrand = integral.integrand() # Replace functions with Coefficients here. @@ -187,12 +172,12 @@ def __init__(self, integral_data_key, integral_data_list, form_data_list, tsfc_f new_integral = integral.reconstruct(integrand=integrand) integrals.append(new_integral) # Remember which form_data this integral is associated with. - _integral_to_form_data_map[new_integral] = form_data + _integral_index_to_form_data_index.append(form_data_index) # Gather functions that are enabled in this `TSFCIntegralData`. functions.update(f for f, enabled in zip(form_data.reduced_coefficients, intg_data.enabled_coefficients) if enabled) - functions.update(form_data_function_map[form_data]) + functions.update(function_tuple[form_data_index]) self.integrals = tuple(integrals) - self._integral_to_form_data_map = _integral_to_form_data_map + self._integral_index_to_form_data_index = _integral_index_to_form_data_index self.arguments = tsfc_form_data.arguments # This is which coefficient in the original form the # current coefficient is. @@ -206,9 +191,9 @@ def __init__(self, integral_data_key, integral_data_list, form_data_list, tsfc_f self.coefficients = tuple(tsfc_form_data.function_replace_map[f] for f in functions) self.coefficient_numbers = tuple(tsfc_form_data.original_coefficient_positions[tsfc_form_data.reduced_coefficients.index(f)] for f in functions) - def integral_to_form_data(self, integral): - r"""Return `ufl.FormData` which `integral` is associated with.""" - return self._integral_to_form_data_map[integral] + def integral_index_to_form_data_index(self, integral_index): + r"""Return the form data index given an integral index.""" + return self._integral_index_to_form_data_index[integral_index] def compile_form(form, prefix="form", parameters=None, interface=None, coffee=True, diagonal=False): @@ -233,7 +218,7 @@ def compile_form(form, prefix="form", parameters=None, interface=None, coffee=Tr form_data = ufl_utils.compute_form_data(form, complex_mode=complex_mode) if interface: interface = partial(interface, function_replace_map=form_data.function_replace_map) - tsfc_form_data = TSFCFormData((form_data, ), form_data.original_form, {form_data: ()}, {form_data: ()}, diagonal) + tsfc_form_data = TSFCFormData((form_data, ), ((), ), ((), ), form_data.original_form, diagonal) logger.info(GREEN % "compute_form_data finished in %g seconds.", time.time() - cpu_time)