Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cupy implementation of eq_hist #1129

Merged
merged 1 commit into from
Oct 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Cupy implementation of eq_hist
  • Loading branch information
ianthomas23 committed Oct 5, 2022
commit 2df3e349a60bca78b20af12fbc878dacd75b2351
5 changes: 5 additions & 0 deletions datashader/tests/test_transfer_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ def test_shade(agg, attr, span):

img = tf.shade(x, cmap=cmap, how='eq_hist', rescale_discrete_levels=True)
sol = tf.Image(eq_hist_sol_rescale_discrete_levels[attr], coords=coords, dims=dims)
if cupy and attr=='a' and isinstance(agg.a.data, cupy.ndarray):
# cupy eq_hist has slightly different numerics hence slightly different RGBA results
sol = sol.copy(deep=True)
sol[2, 0] = sol[2, 0] - 0x100

assert_eq_xr(img, sol)

img = tf.shade(x, cmap=cmap,
Expand Down
19 changes: 11 additions & 8 deletions datashader/transfer_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,36 +160,39 @@ def eq_hist(data, mask=None, nbins=256*256):
"""
if cupy and isinstance(data, cupy.ndarray):
from._cuda_utils import interp
array_module = cupy
elif not isinstance(data, np.ndarray):
raise TypeError("data must be an ndarray")
else:
interp = np.interp
array_module = np

data2 = data if mask is None else data[~mask]

# Run more accurate value counting if data is of boolean or integer type
# and unique value array is smaller than nbins.
if data2.dtype == bool or (np.issubdtype(data2.dtype, np.integer) and data2.ptp() < nbins):
values, counts = np.unique(data2, return_counts=True)
vmin, vmax = values[0], values[-1]
if data2.dtype == bool or (array_module.issubdtype(data2.dtype, array_module.integer) and
data2.ptp() < nbins):
values, counts = array_module.unique(data2, return_counts=True)
vmin, vmax = values[0].item(), values[-1].item() # Convert from arrays to scalars.
interval = vmax-vmin
bin_centers = np.arange(vmin, vmax+1)
hist = np.zeros(interval+1, dtype='uint64')
bin_centers = array_module.arange(vmin, vmax+1)
hist = array_module.zeros(interval+1, dtype='uint64')
hist[values-vmin] = counts
discrete_levels = len(values)
else:
hist, bin_edges = np.histogram(data2, bins=nbins)
hist, bin_edges = array_module.histogram(data2, bins=nbins)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
keep_mask = (hist > 0)
discrete_levels = np.count_nonzero(keep_mask)
discrete_levels = array_module.count_nonzero(keep_mask)
if discrete_levels != len(hist):
# Remove empty histogram bins.
hist = hist[keep_mask]
bin_centers = bin_centers[keep_mask]
cdf = hist.cumsum()
cdf = cdf / float(cdf[-1])
out = interp(data, bin_centers, cdf).reshape(data.shape)
return out if mask is None else np.where(mask, np.nan, out), discrete_levels
return out if mask is None else array_module.where(mask, array_module.nan, out), discrete_levels



Expand Down