diff --git a/alf/utils/dist_utils.py b/alf/utils/dist_utils.py index 31f2978cf..b903a42b8 100644 --- a/alf/utils/dist_utils.py +++ b/alf/utils/dist_utils.py @@ -373,6 +373,12 @@ def __init__(self, loc, scale): def stddev(self): return self.base_dist.stddev + def log_prob(self, value, summed=True): + if summed: + return self.base_dist.log_prob(value).sum(-1) + else: + return self.base_dist.log_prob(value) + @alf.configurable(whitelist=['eps']) class Beta(td.Beta):