Skip to content

Commit

Permalink
Factory->Builder
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 committed Aug 26, 2024
1 parent b0fad3d commit 5a4d542
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 40 deletions.
36 changes: 18 additions & 18 deletions movement/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
)

__all__ = [
"ConstantMonitorFactory",
"BallMonitorFactory",
"GradientMonitorFactory",
"HessianMonitorFactory",
"GradientHessianMonitorFactory",
"ConstantMonitorBuilder",
"BallMonitorBuilder",
"GradientMonitorBuilder",
"HessianMonitorBuilder",
"GradientHessianMonitorBuilder",
]


class MonitorFactory(metaclass=abc.ABCMeta):
class MonitorBuilder(metaclass=abc.ABCMeta):
"""
Abstract base class for monitor function factories.
"""
Expand Down Expand Up @@ -70,9 +70,9 @@ def __call__(self):
return self.get_monitor()


class ConstantMonitorFactory(MonitorFactory):
class ConstantMonitorBuilder(MonitorBuilder):
"""
Factory class for constant monitor functions.
Builder class for constant monitor functions.
"""

def monitor(self, mesh):
Expand All @@ -87,9 +87,9 @@ def monitor(self, mesh):
return Constant(1.0)


class BallMonitorFactory(MonitorFactory):
class BallMonitorBuilder(MonitorBuilder):
r"""
Factory class for monitor functions focused around ball shapes:
Builder class for monitor functions focused around ball shapes:
.. math::
m(\mathbf{x}) = 1 + \frac{\alpha}
Expand Down Expand Up @@ -141,7 +141,7 @@ def monitor(self, mesh):
)


class SolutionBasedMonitorFactory(MonitorFactory, metaclass=abc.ABCMeta):
class SolutionBasedMonitorBuilder(MonitorBuilder, metaclass=abc.ABCMeta):
"""
Abstract base class for monitor factories based on solution data.
"""
Expand All @@ -160,9 +160,9 @@ def __init__(self, dim, solution):


# TODO: Support computing gradient with Clement interpolant
class GradientMonitorFactory(SolutionBasedMonitorFactory):
class GradientMonitorBuilder(SolutionBasedMonitorBuilder):
r"""
Factory class for monitor functions based on gradients of solutions:
Builder class for monitor functions based on gradients of solutions:
.. math::
m(\mathbf{x}) = 1 + \alpha\frac{\nabla u\cdot\nabla u}
Expand Down Expand Up @@ -213,9 +213,9 @@ def monitor(self, mesh):


# TODO: Support computing Hessian with double L2 projection
class HessianMonitorFactory(SolutionBasedMonitorFactory):
class HessianMonitorBuilder(SolutionBasedMonitorBuilder):
r"""
Factory class for monitor functions based on Hessians of solutions.
Builder class for monitor functions based on Hessians of solutions.
.. math::
m(\mathbf{x}) = 1 + \alpha\frac{\nabla u\cdot\nabla u}
Expand Down Expand Up @@ -266,9 +266,9 @@ def monitor(self, mesh):
)


class GradientHessianMonitorFactory(GradientMonitorFactory, HessianMonitorFactory):
class GradientHessianMonitorBuilder(GradientMonitorBuilder, HessianMonitorBuilder):
r"""
Factory class for monitor functions based on both gradients and Hessians of
Builder class for monitor functions based on both gradients and Hessians of
solutions.
.. math::
Expand All @@ -293,7 +293,7 @@ def __init__(self, dim, solution, gradient_scale_factor, hessian_scale_factor):
:arg solution: solution to recover the gradient and Hessian of
:type solution: :class:`firedrake.function.Function`
"""
SolutionBasedMonitorFactory.__init__(self, dim, solution)
SolutionBasedMonitorBuilder.__init__(self, dim, solution)
self.gradient_scale_factor = Constant(gradient_scale_factor)
self.hessian_scale_factor = Constant(hessian_scale_factor)

Expand Down
44 changes: 22 additions & 22 deletions test/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class BaseClasses:
Base classes for monitor factories.
"""

class TestMonitorFactory(unittest.TestCase):
class TestMonitorBuilder(unittest.TestCase):
"""
Base class for monitor factory unit tests.
"""
Expand All @@ -29,75 +29,75 @@ def setUp(self):
self.solution = Function(self.P1)


class TestConstant(BaseClasses.TestMonitorFactory):
class TestConstant(BaseClasses.TestMonitorBuilder):
"""
Unit tests for :class:`~.ConstantMonitorFactory`.
Unit tests for :class:`~.ConstantMonitorBuilder`.
"""

def test_value(self):
mf = ConstantMonitorFactory(self.mesh.topological_dimension())
mf = ConstantMonitorBuilder(self.mesh.topological_dimension())
self.assertTrue(np.allclose(mf.get_monitor()(self.mesh).dat.data, 1))


class TestBall(BaseClasses.TestMonitorFactory):
class TestBall(BaseClasses.TestMonitorBuilder):
"""
Unit tests for :class:`~.BallMonitorFactory`.
Unit tests for :class:`~.BallMonitorBuilder`.
"""

def test_tiny_amplitude(self):
mf = BallMonitorFactory(
mf = BallMonitorBuilder(
dim=2, centre=(0, 0), radius=0.1, amplitude=1e-8, width=0.1
)
self.assertTrue(np.allclose(mf.get_monitor()(self.mesh).dat.data, 1))


class TestGradient(BaseClasses.TestMonitorFactory):
class TestGradient(BaseClasses.TestMonitorBuilder):
"""
Unit tests for :class:`~.GradientMonitorFactory`.
Unit tests for :class:`~.GradientMonitorBuilder`.
"""

def test_tiny_scale_factor(self):
x, y = SpatialCoordinate(self.mesh)
self.solution.interpolate(x**2)
mf = GradientMonitorFactory(dim=2, solution=self.solution, scale_factor=1e-8)
mf = GradientMonitorBuilder(dim=2, solution=self.solution, scale_factor=1e-8)
self.assertTrue(np.allclose(mf.get_monitor()(self.mesh).dat.data, 1))

def test_linear(self):
x, y = SpatialCoordinate(self.mesh)
self.solution.interpolate(x)
mf = GradientMonitorFactory(dim=2, solution=self.solution, scale_factor=1)
mf = GradientMonitorBuilder(dim=2, solution=self.solution, scale_factor=1)
self.assertTrue(np.allclose(mf.get_monitor()(self.mesh).dat.data, 2))


class TestHessian(BaseClasses.TestMonitorFactory):
class TestHessian(BaseClasses.TestMonitorBuilder):
"""
Unit tests for :class:`~.HessianMonitorFactory`.
Unit tests for :class:`~.HessianMonitorBuilder`.
"""

def test_tiny_scale_factor(self):
x, y = SpatialCoordinate(self.mesh)
self.solution.interpolate(x**3)
mf = HessianMonitorFactory(dim=2, solution=self.solution, scale_factor=1e-8)
mf = HessianMonitorBuilder(dim=2, solution=self.solution, scale_factor=1e-8)
self.assertTrue(np.allclose(mf.get_monitor()(self.mesh).dat.data, 1))

def test_quadratic(self):
x, y = SpatialCoordinate(self.mesh)
P2 = FunctionSpace(self.mesh, "CG", 2)
solution = Function(P2)
solution.interpolate(0.5 * x**2)
mf = HessianMonitorFactory(dim=2, solution=solution, scale_factor=1)
mf = HessianMonitorBuilder(dim=2, solution=solution, scale_factor=1)
self.assertTrue(np.allclose(mf.get_monitor()(self.mesh).dat.data, 2))


class TestGradientHessian(BaseClasses.TestMonitorFactory):
class TestGradientHessian(BaseClasses.TestMonitorBuilder):
"""
Unit tests for :class:`~.GradientHessianMonitorFactory`.
Unit tests for :class:`~.GradientHessianMonitorBuilder`.
"""

def test_tiny_scale_factors(self):
x, y = SpatialCoordinate(self.mesh)
self.solution.interpolate(x**3)
mf = GradientHessianMonitorFactory(
mf = GradientHessianMonitorBuilder(
dim=2,
solution=self.solution,
gradient_scale_factor=1e-8,
Expand All @@ -108,13 +108,13 @@ def test_tiny_scale_factors(self):
def test_tiny_hessian_scale_factor(self):
x, y = SpatialCoordinate(self.mesh)
self.solution.interpolate(x**3)
mf1 = GradientHessianMonitorFactory(
mf1 = GradientHessianMonitorBuilder(
dim=2,
solution=self.solution,
gradient_scale_factor=1,
hessian_scale_factor=1e-8,
)
mf2 = GradientMonitorFactory(
mf2 = GradientMonitorBuilder(
dim=2,
solution=self.solution,
scale_factor=1,
Expand All @@ -126,13 +126,13 @@ def test_tiny_hessian_scale_factor(self):
def test_tiny_gradient_scale_factor(self):
x, y = SpatialCoordinate(self.mesh)
self.solution.interpolate(x**3)
mf1 = GradientHessianMonitorFactory(
mf1 = GradientHessianMonitorBuilder(
dim=2,
solution=self.solution,
gradient_scale_factor=1e-8,
hessian_scale_factor=1,
)
mf2 = HessianMonitorFactory(
mf2 = HessianMonitorBuilder(
dim=2,
solution=self.solution,
scale_factor=1,
Expand Down

0 comments on commit 5a4d542

Please sign in to comment.