Skip to content

Commit

Permalink
Merge pull request Pyomo#3338 from jsiirola/valsetindex
Browse files Browse the repository at this point in the history
Support validate / filter for IndexedSet components using the index
  • Loading branch information
blnicho authored Aug 20, 2024
2 parents 3ea08c9 + e64143d commit 084d057
Show file tree
Hide file tree
Showing 8 changed files with 676 additions and 253 deletions.
6 changes: 5 additions & 1 deletion examples/pyomo/tutorials/set.dat
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,9 @@ set S[5] := 2 3;
set T[2] := 1 3;
set T[5] := 2 3;

set T_indexed_validate[2] := 1;
set T_indexed_validate[3] := 1 2;
set T_indexed_validate[4] := 1 2 3;

set X[2] := 1;
set X[5] := 2 3;
set X[5] := 2 3;
9 changes: 7 additions & 2 deletions examples/pyomo/tutorials/set.out
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
24 Set Declarations
25 Set Declarations
A : Size=1, Index=None, Ordered=Insertion
Key : Dimen : Domain : Size : Members
None : 1 : Any : 3 : {1, 2, 3}
Expand Down Expand Up @@ -80,6 +80,11 @@
Key : Dimen : Domain : Size : Members
2 : 1 : Any : 2 : {1, 3}
5 : 1 : Any : 2 : {2, 3}
T_indexed_validate : Size=3, Index=B, Ordered=Insertion
Key : Dimen : Domain : Size : Members
2 : 1 : Any : 1 : {1,}
3 : 1 : Any : 2 : {1, 2}
4 : 1 : Any : 3 : {1, 2, 3}
U : Size=1, Index=None, Ordered=Insertion
Key : Dimen : Domain : Size : Members
None : 1 : Any : 5 : {1, 2, 6, 24, 120}
Expand All @@ -94,4 +99,4 @@
2 : 1 : S[2] : 1 : {1,}
5 : 1 : S[5] : 2 : {2, 3}

24 Declarations: A B C D E F G H Hsub I J K K_2 L M N O P R S X T U V
25 Declarations: A B C D E F G H Hsub I J K K_2 L M N O P R S X T T_indexed_validate U V
12 changes: 11 additions & 1 deletion examples/pyomo/tutorials/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,17 @@ def T_validate(model, value):
return value in model.A


model.T = Set(model.B, validate=M_validate)
model.T = Set(model.B, validate=T_validate)


#
# Validation also provides the index within the IndexedSet being validated:
#
def T_indexed_validate(model, value, i):
return value in model.A and value < i


model.T_indexed_validate = Set(model.B, validate=T_indexed_validate)


##
Expand Down
112 changes: 97 additions & 15 deletions pyomo/core/base/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections.abc import Sequence
from collections.abc import Mapping

from pyomo.common.autoslots import AutoSlots
from pyomo.common.dependencies import numpy, numpy_available, pandas, pandas_available
from pyomo.common.modeling import NOTSET
from pyomo.core.pyomoobject import PyomoObject
Expand All @@ -37,6 +38,7 @@ def Initializer(
allow_generators=False,
treat_sequences_as_mappings=True,
arg_not_specified=None,
additional_args=0,
):
"""Standardized processing of Component keyword arguments
Expand Down Expand Up @@ -69,9 +71,54 @@ def Initializer(
If ``arg`` is ``arg_not_specified``, then the function will
return None (and not an InitializerBase object).
additional_args: int
The number of additional arguments that will be passed to any
function calls (provided *before* the index value).
"""
if arg is arg_not_specified:
return None
if additional_args:
if arg.__class__ in function_types:
if allow_generators or inspect.isgeneratorfunction(arg):
raise ValueError(
"Generator functions are not allowed when passing additional args"
)
_args = inspect.getfullargspec(arg)
_nargs = len(_args.args)
if inspect.ismethod(arg) and arg.__self__ is not None:
# Ignore 'self' for bound instance methods and 'cls' for
# @classmethods
_nargs -= 1
if _nargs == 1 + additional_args and _args.varargs is None:
return ParameterizedScalarCallInitializer(arg, constant=True)
else:
return ParameterizedIndexedCallInitializer(arg)
else:
base_initializer = Initializer(
arg=arg,
allow_generators=allow_generators,
treat_sequences_as_mappings=treat_sequences_as_mappings,
arg_not_specified=arg_not_specified,
)
if type(base_initializer) in (
ScalarCallInitializer,
IndexedCallInitializer,
):
# This is an edge case: if we are providing additional
# args, but this is the first time we are seeing a
# callable type, we will (potentially) incorrectly
# categorize this as an IndexedCallInitializer. Re-try
# now that we know this is a function_type.
return Initializer(
arg=base_initializer._fcn,
allow_generators=allow_generators,
treat_sequences_as_mappings=treat_sequences_as_mappings,
arg_not_specified=arg_not_specified,
additional_args=additional_args,
)
return ParameterizedInitializer(base_initializer)
if arg.__class__ in initializer_map:
return initializer_map[arg.__class__](arg)
if arg.__class__ in sequence_types:
Expand Down Expand Up @@ -193,27 +240,13 @@ def Initializer(
return ConstantInitializer(arg)


class InitializerBase(object):
class InitializerBase(AutoSlots.Mixin, object):
"""Base class for all Initializer objects"""

__slots__ = ()

verified = False

def __getstate__(self):
"""Class serializer
This class must declare __getstate__ because it is slotized.
This implementation should be sufficient for simple derived
classes (where __slots__ are only declared on the most derived
class).
"""
return {k: getattr(self, k) for k in self.__slots__}

def __setstate__(self, state):
for key, val in state.items():
object.__setattr__(self, key, val)

def constant(self):
"""Return True if this initializer is constant across all indices"""
return False
Expand Down Expand Up @@ -316,6 +349,18 @@ def __call__(self, parent, idx):
return self._fcn(parent, idx)


class ParameterizedIndexedCallInitializer(IndexedCallInitializer):
"""IndexedCallInitializer that accepts additional arguments"""

__slots__ = ()

def __call__(self, parent, idx, *args):
if idx.__class__ is tuple:
return self._fcn(parent, *args, *idx)
else:
return self._fcn(parent, *args, idx)


class CountedCallGenerator(object):
"""Generator implementing the "counted call" initialization scheme
Expand Down Expand Up @@ -442,6 +487,15 @@ def constant(self):
return self._constant


class ParameterizedScalarCallInitializer(ScalarCallInitializer):
"""ScalarCallInitializer that accepts additional arguments"""

__slots__ = ()

def __call__(self, parent, idx, *args):
return self._fcn(parent, *args)


class DefaultInitializer(InitializerBase):
"""Initializer wrapper that maps exceptions to default values.
Expand Down Expand Up @@ -485,6 +539,34 @@ def indices(self):
return self._initializer.indices()


class ParameterizedInitializer(InitializerBase):
"""Base class for all Initializer objects"""

__slots__ = ('_base_initializer',)

def __init__(self, base):
self._base_initializer = base

def constant(self):
"""Return True if this initializer is constant across all indices"""
return self._base_initializer.constant()

def contains_indices(self):
"""Return True if this initializer contains embedded indices"""
return self._base_initializer.contains_indices()

def indices(self):
"""Return a generator over the embedded indices
This will raise a RuntimeError if this initializer does not
contain embedded indices
"""
return self._base_initializer.indices()

def __call__(self, parent, idx, *args):
return self._base_initializer(parent, idx)(parent, *args)


_bound_sequence_types = collections.defaultdict(None.__class__)


Expand Down
Loading

0 comments on commit 084d057

Please sign in to comment.