diff --git a/interfere/interventions.py b/interfere/interventions.py index 6346a05..7f1a82b 100644 --- a/interfere/interventions.py +++ b/interfere/interventions.py @@ -150,11 +150,11 @@ def __init__( intervention([10.0, 4.0, 4.0], 0) == [1.6, 4.0, 0.7] # (True.) """ # The case where indexs and constants are floats or ints - if isinstance(intervened_idxs, int) and isinstance(constants, (int, float)): - i = intervened_idxs - c = float(constants) - intervened_idxs = [intervened_idxs] - constants = [c] + if isinstance(intervened_idxs, (int, float)): + intervened_idxs = [int(intervened_idxs)] + + if isinstance(constants, (int, float)): + constants = [float(constants)] if len(constants) != len(intervened_idxs): raise ValueError(