Skip to content

Commit

Permalink
Merge pull request #526 from CUQI-DTU/add_gradient_to_uniform
Browse files Browse the repository at this point in the history
Add gradient to uniform distribution
  • Loading branch information
chaozg authored Sep 26, 2024
2 parents 59ae0d0 + db2ccba commit d00b1c9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
12 changes: 12 additions & 0 deletions cuqi/distribution/_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def __init__(self, low=None, high=None, is_symmetric=True, **kwargs):
self.high = high

def logpdf(self, x):
"""
Evaluate the logarithm of the PDF at the given values of x.
"""
# First check whether x is outside bounds.
# It is outside if any coordinate is outside the interval.
if np.any(x < self.low) or np.any(x > self.high):
Expand All @@ -38,6 +41,15 @@ def logpdf(self, x):
return_val = np.log(1.0/v)
return return_val

def gradient(self, x):
"""
Computes the gradient of logpdf at the given values of x.
"""
if np.any(x < self.low) or np.any(x > self.high):
return np.NaN*np.ones_like(x)
else:
return np.zeros_like(x)

def _sample(self,N=1, rng=None):

if rng is not None:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_distributions_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def test_multivariate_scalar_vars_gradient(dist):
try:
assert np.allclose(
dist_from_vec.gradient(val),
dist_from_dim.gradient(val)
dist_from_dim.gradient(val),
equal_nan=True
)
except NotImplementedError:
pass # Pass the test if NotImplementedError is raised

0 comments on commit d00b1c9

Please sign in to comment.