From 230c47fb00ffd2bcd194de56df35598a3560409a Mon Sep 17 00:00:00 2001 From: Wei Xu Date: Mon, 16 Dec 2024 15:46:19 -0800 Subject: [PATCH] Support log_prob(value, summed) for DiagMultivariateNormal --- alf/utils/dist_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/alf/utils/dist_utils.py b/alf/utils/dist_utils.py index 31f2978cf..b903a42b8 100644 --- a/alf/utils/dist_utils.py +++ b/alf/utils/dist_utils.py @@ -373,6 +373,12 @@ def __init__(self, loc, scale): def stddev(self): return self.base_dist.stddev + def log_prob(self, value, summed=True): + if summed: + return self.base_dist.log_prob(value).sum(-1) + else: + return self.base_dist.log_prob(value) + @alf.configurable(whitelist=['eps']) class Beta(td.Beta):