From 32ac46b376648e1daf4b15dfadc6c3c4e9f77354 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Jan 2025 09:00:21 -0600 Subject: [PATCH] Implement empty() for BaseForm subclasses --- ufl/algorithms/formsplitter.py | 10 ++++------ ufl/form.py | 10 +++++++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/ufl/algorithms/formsplitter.py b/ufl/algorithms/formsplitter.py index 7ab0f6a1e..a81176b5c 100644 --- a/ufl/algorithms/formsplitter.py +++ b/ufl/algorithms/formsplitter.py @@ -10,6 +10,8 @@ from typing import Optional +import numpy as np + from ufl.algorithms.map_integrands import map_expr_dag, map_integrand_dags from ufl.argument import Argument from ufl.classes import FixedIndex, ListTensor @@ -53,14 +55,10 @@ def argument(self, obj): Q_i = FunctionSpace(dom, sub_elem) a = Argument(Q_i, obj.number(), part=obj.part()) - indices = [()] - for m in a.ufl_shape: - indices = [(*k, j) for k in indices for j in range(m)] - if i == self.idx[obj.number()]: - args.extend(a[j] for j in indices) + args.extend(a[j] for j in np.ndindex(a.ufl_shape)) else: - args.extend(Zero() for j in indices) + args.extend(Zero() for j in np.ndindex(a.ufl_shape)) return as_vector(args) diff --git a/ufl/form.py b/ufl/form.py index c4b672330..0560dd0f7 100644 --- a/ufl/form.py +++ b/ufl/form.py @@ -307,7 +307,7 @@ def integrals_by_domain(self, domain): def empty(self): """Returns whether the form has no integrals.""" - return self.integrals() == () + return len(self.integrals()) == 0 def ufl_domains(self): """Return the geometric integration domains occuring in the form. @@ -812,6 +812,10 @@ def equals(self, other): a == b for a, b in zip(self.components(), other.components()) ) + def empty(self): + """Returns whether the FormSum has no components.""" + return len(self.components()) == 0 + def __str__(self): """Compute shorter string representation of form. This can be huge for complicated forms.""" # Warning used for making sure we don't use this in the general pipeline: @@ -878,6 +882,10 @@ def ufl_domains(self): self._analyze_domains() return self._domains + def empty(self): + """Returns whether the form has no integrals.""" + return True + def __ne__(self, other): """Overwrite BaseForm.__neq__ which relies on `equals`.""" return not self == other