Skip to content

Commit

Permalink
Support log_prob(value, summed) for DiagMultivariateNormal
Browse files Browse the repository at this point in the history
  • Loading branch information
emailweixu committed Dec 16, 2024
1 parent 3530d14 commit 230c47f
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions alf/utils/dist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,12 @@ def __init__(self, loc, scale):
def stddev(self):
return self.base_dist.stddev

def log_prob(self, value, summed=True):
if summed:
return self.base_dist.log_prob(value).sum(-1)
else:
return self.base_dist.log_prob(value)


@alf.configurable(whitelist=['eps'])
class Beta(td.Beta):
Expand Down

0 comments on commit 230c47f

Please sign in to comment.