Skip to content
This repository has been archived by the owner on Dec 6, 2024. It is now read-only.

Commit

Permalink
use form data index as key instead of form data itself
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Dec 10, 2020
1 parent 6a2c5b5 commit 3e15b16
Showing 1 changed file with 54 additions and 69 deletions.
123 changes: 54 additions & 69 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,16 @@ class TSFCFormData(object):
r"""Mimic `ufl.FormData`.
:arg form_data_tuple: A tuple of `ufl.FormData`s.
:arg extraarg_tuple: A tuple of extra `ufl.Argument`s
corresponding to form_data_tuple. These extra
arguments are eventually replaced by the user with
the associated functions in function_tuple after
compiling UFL but before compiling gem. These
arguments thus do not contribute to the rank of the form.
:arg function_tuple: A tuple of functions corresponding
to extraarg_tuple.
:arg original_form: The form from which forms for
`ufl.Formdata`s were extracted.
:arg form_data_extraarg_map: A map from `ufl.FormData`s to
extra `ufl.Argument`s: the user can apply arbitrary
linear transformations to these `ufl.Argument`s and
replace them with corresponding functions stored in
`form_data_function_map`. This must happen after
compiling UFL but before compiling gem. These
`ufl.Arguments` thus do not contribute to the rank
of the form.
:arg form_data_function_map: A map from `ufl.FormData`s to
functions corresponding to the face `ufl.Argument`s in
`form_data_extraarg_map`.
:diagonal: A flag for diagonal matrix assembly.
This class mimics `ufl.FormData`, but is to contain minimum
Expand All @@ -69,8 +66,8 @@ class TSFCFormData(object):
+--- form_N ---- gem_N' ---- gem_N ---+
After preprocessing `ufl.FormData`s here:
* Only essential information about the `ufl.FormData`s is retained.
* TSFC can forget `ufl.FormData.original_form`,
* TSFC can forget `ufl.IntegralData.enabled_coefficients`,
* `KernelBuilder`s only need to deal with raw `ufl.Coefficient`s.
Illustration of the structures.
Expand All @@ -91,108 +88,96 @@ class TSFCFormData(object):
|____0___||____1___|_ _|____M___| ||________||________| |________||
|_____________________________________|
"""
def __init__(self, form_data_tuple, original_form, form_data_extraarg_map, form_data_function_map, diagonal):
def __init__(self, form_data_tuple, extraarg_tuple, function_tuple, original_form, diagonal):
arguments = set()
for fd in form_data_tuple:
for fd, extraarg in zip(form_data_tuple, extraarg_tuple):
args = []
for arg in fd.preprocessed_form.arguments():
if arg not in form_data_extraarg_map[fd]:
if arg not in extraarg:
args.append(arg)
arguments.update((tuple(args), ))
if len(arguments) != 1:
raise ValueError("Found inconsistent sets of arguments in `FormData`s.")
self.arguments, = tuple(arguments)
# Gathere all coefficients.
# Gather all coefficients.
# If a form contains extra arguments, those will be replaced by corresponding functions
# after compiling UFL, so these functions must be included here, too.
reduced_coefficients_set = set(c for fd in form_data_tuple for c in fd.reduced_coefficients)
for _, val in form_data_function_map.items():
reduced_coefficients_set.update(val)
reduced_coefficients_set.update(chain(*function_tuple))
reduced_coefficients = sorted(reduced_coefficients_set, key=lambda c: c.count())
if len(form_data_tuple) == 1:
self.reduced_coefficients = form_data_tuple[0].reduced_coefficients
self.original_coefficient_positions = form_data_tuple[0].original_coefficient_positions
self.function_replace_map = form_data_tuple[0].function_replace_map
else:
# Reconstruct `ufl.Coefficinet`s with count starting at 0.
function_replace_map = {}
for i, func in enumerate(reduced_coefficients):
for fd in form_data_tuple:
if func in fd.function_replace_map:
coeff = fd.function_replace_map[func]
new_coeff = Coefficient(coeff.ufl_function_space(), count=i)
function_replace_map[func] = new_coeff
break
else:
ufl_function_space = FunctionSpace(func.ufl_domain(), func.ufl_element())
new_coeff = Coefficient(ufl_function_space, count=i)
# Reconstruct `ufl.Coefficinet`s with count starting at 0.
function_replace_map = {}
for i, func in enumerate(reduced_coefficients):
for fd in form_data_tuple:
if func in fd.function_replace_map:
coeff = fd.function_replace_map[func]
new_coeff = Coefficient(coeff.ufl_function_space(), count=i)
function_replace_map[func] = new_coeff
self.reduced_coefficients = reduced_coefficients
self.original_coefficient_positions = [i for i, f in enumerate(original_form.coefficients())
if f in self.reduced_coefficients]
self.function_replace_map = function_replace_map
break
else:
ufl_function_space = FunctionSpace(func.ufl_domain(), func.ufl_element())
new_coeff = Coefficient(ufl_function_space, count=i)
function_replace_map[func] = new_coeff
self.reduced_coefficients = reduced_coefficients
self.original_coefficient_positions = [i for i, f in enumerate(original_form.coefficients())
if f in self.reduced_coefficients]
self.function_replace_map = function_replace_map

# Translate `ufl.IntegralData`s -> `TSFCIntegralData`.
intg_data_dict = {}
form_data_dict = {}
for form_data in form_data_tuple:
intg_data_info_dict = {}
for form_data_index, form_data in enumerate(form_data_tuple):
for intg_data in form_data.integral_data:
domain = intg_data.domain
integral_type = intg_data.integral_type
subdomain_id = intg_data.subdomain_id
key = (domain, integral_type, subdomain_id)
# Add intg_data.
intg_data_dict.setdefault(key, []).append(intg_data)
# Remember which form_data this intg_data came from.
form_data_dict.setdefault(key, []).append(form_data)
# Add (intg_data, form_data, form_data_index).
intg_data_info_dict.setdefault(key, []).append((intg_data, form_data, form_data_index))
integral_data_list = []
for key in intg_data_dict:
intg_data_list = intg_data_dict[key]
form_data_list = form_data_dict[key]
for key, intg_data_info in intg_data_info_dict.items():
domain, _, _ = key
domain_number = original_form.domain_numbering()[domain]
integral_data_list.append(TSFCIntegralData(key, intg_data_list, form_data_list,
self, domain_number, form_data_function_map))
integral_data_list.append(TSFCIntegralData(key, intg_data_info,
self, domain_number, function_tuple))
self.integral_data = tuple(integral_data_list)


class TSFCIntegralData(object):
r"""Mimics `ufl.IntegralData`.
:arg integral_data_key: (domain, integral_type, subdomain_id) tuple.
:arg integral_data_list: A list of `ufl.IntegralData`.
:arg form_data_list: A list of `ufl.FormData`.
:arg tsfc_form_data: The `TSFCFormData` that will contain this
:arg integral_data_info: A tuple of the lists of integral_data,
form_data, and form_data_index.
:arg tsfc_form_data: The `TSFCFormData` that is to contain this
`TSFCIntegralData` object.
:arg domain_number: The domain number associated with `domain`.
:arg form_data_function_map: A map from `ufl.FormData`s to functions.
:arg function_tuple: A tuple of functions.
This class mimics `ufl.FormData`, but:
* extracts information required by TSFC.
* preprocesses integrals so that `KernelBuilder`s only
need to deal with raw `ufl.Coefficient`s.
After preprocessing `ufl.IntegralData`s here:
* Only essential information about the `ufl.IntegralData`s is retained.
* TSFC can forget `ufl.IntegralData.enabled_coefficients`,
"""
def __init__(self, integral_data_key, integral_data_list, form_data_list, tsfc_form_data, domain_number, form_data_function_map):
def __init__(self, integral_data_key, intg_data_info, tsfc_form_data, domain_number, function_tuple):
self.domain, self.integral_type, self.subdomain_id = integral_data_key
self.domain_number = domain_number
# Gather/preprocess integrals.
integrals = []
_integral_to_form_data_map = {}
_integral_index_to_form_data_index = []
functions = set()
for intg_data, form_data in zip(integral_data_list, form_data_list):
for intg_data, form_data, form_data_index in intg_data_info:
for integral in intg_data.integrals:
integrand = integral.integrand()
# Replace functions with Coefficients here.
integrand = ufl.replace(integrand, tsfc_form_data.function_replace_map)
new_integral = integral.reconstruct(integrand=integrand)
integrals.append(new_integral)
# Remember which form_data this integral is associated with.
_integral_to_form_data_map[new_integral] = form_data
_integral_index_to_form_data_index.append(form_data_index)
# Gather functions that are enabled in this `TSFCIntegralData`.
functions.update(f for f, enabled in zip(form_data.reduced_coefficients, intg_data.enabled_coefficients) if enabled)
functions.update(form_data_function_map[form_data])
functions.update(function_tuple[form_data_index])
self.integrals = tuple(integrals)
self._integral_to_form_data_map = _integral_to_form_data_map
self._integral_index_to_form_data_index = _integral_index_to_form_data_index
self.arguments = tsfc_form_data.arguments
# This is which coefficient in the original form the
# current coefficient is.
Expand All @@ -206,9 +191,9 @@ def __init__(self, integral_data_key, integral_data_list, form_data_list, tsfc_f
self.coefficients = tuple(tsfc_form_data.function_replace_map[f] for f in functions)
self.coefficient_numbers = tuple(tsfc_form_data.original_coefficient_positions[tsfc_form_data.reduced_coefficients.index(f)] for f in functions)

def integral_to_form_data(self, integral):
r"""Return `ufl.FormData` which `integral` is associated with."""
return self._integral_to_form_data_map[integral]
def integral_index_to_form_data_index(self, integral_index):
r"""Return the form data index given an integral index."""
return self._integral_index_to_form_data_index[integral_index]


def compile_form(form, prefix="form", parameters=None, interface=None, coffee=True, diagonal=False):
Expand All @@ -233,7 +218,7 @@ def compile_form(form, prefix="form", parameters=None, interface=None, coffee=Tr
form_data = ufl_utils.compute_form_data(form, complex_mode=complex_mode)
if interface:
interface = partial(interface, function_replace_map=form_data.function_replace_map)
tsfc_form_data = TSFCFormData((form_data, ), form_data.original_form, {form_data: ()}, {form_data: ()}, diagonal)
tsfc_form_data = TSFCFormData((form_data, ), ((), ), ((), ), form_data.original_form, diagonal)

logger.info(GREEN % "compute_form_data finished in %g seconds.", time.time() - cpu_time)

Expand Down

0 comments on commit 3e15b16

Please sign in to comment.