Skip to content

Commit

Permalink
Add histogram matching (scikit-image#3568)
Browse files Browse the repository at this point in the history
* Add histogram matching, by Solutus Immensus

Recovery from scikit-image#3208
The contributor closed his/her account.

* Histogram matching: Simplify implementation and improve documentation

* Add multichannel support for matching histogram
  • Loading branch information
sciunto authored and jni committed Dec 20, 2018
1 parent 3282eab commit 2ec9b0a
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@
- Lars Grüter
Flood-fill based local maxima detection

- Solutus Immensus
Histogram matching

- Laurent P. René de Cotret
Implementation of masked image translation registration

Expand Down
67 changes: 67 additions & 0 deletions doc/examples/transform/plot_histogram_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
==================
Histogram matching
==================
This example demonstrates the feature of histogram matching. It manipulates the
pixels of an input image so that its histogram matches the histogram of the
reference image. If the images have multiple channels, the matching is done
independently for each channel, as long as the number of channels is equal in
the input image and the reference.
Histogram matching can be used as a lightweight normalisation for image
processing, such as feature matching, especially in circumstances where the
images have been taken from different sources or in different conditions (e.g.
lighting).
"""

import matplotlib.pyplot as plt

from skimage import data
from skimage import exposure
from skimage.transform import match_histograms

reference = data.coffee()
image = data.chelsea()

# Both images are RGB, so request per-channel matching explicitly. Without
# ``multichannel=True`` the arrays would be matched as one pooled distribution
# of values, which is not what this example describes.
matched = match_histograms(image, reference, multichannel=True)

fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(8, 3),
                                    sharex=True, sharey=True)
for aa in (ax1, ax2, ax3):
    aa.set_axis_off()

ax1.imshow(image)
ax1.set_title('Source')
ax2.imshow(reference)
ax2.set_title('Reference')
ax3.imshow(matched)
ax3.set_title('Matched')

plt.tight_layout()
plt.show()


######################################################################
# To illustrate the effect of the histogram matching, we plot, for each
# RGB channel, the histogram and the cumulative histogram. Clearly,
# the matched image has the same cumulative histogram as the reference
# image for each channel.

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(8, 8))

# Grid layout: one column per image, one row per colour channel.
for col, picture in enumerate((image, reference, matched)):
    for row, channel_name in enumerate(('red', 'green', 'blue')):
        panel = axes[row, col]
        plane = picture[..., row]

        # Histogram scaled by its own maximum so it shares the axis with
        # the cumulative distribution drawn on top of it.
        hist, hist_x = exposure.histogram(plane, source_range='dtype')
        panel.plot(hist_x, hist / hist.max())

        cdf, cdf_x = exposure.cumulative_distribution(plane)
        panel.plot(cdf_x, cdf)

        # The channel name labels the row on its left-most panel.
        axes[row, 0].set_ylabel(channel_name)

axes[0, 0].set_title('Source')
axes[0, 1].set_title('Reference')
axes[0, 2].set_title('Matched')

plt.tight_layout()
plt.show()
4 changes: 3 additions & 1 deletion skimage/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .histogram_matching import match_histograms
from .hough_transform import (hough_line, hough_line_peaks,
probabilistic_hough_line, hough_circle,
hough_circle_peaks, hough_ellipse)
Expand All @@ -18,7 +19,8 @@
from .seam_carving import seam_carve


__all__ = ['hough_circle',
__all__ = ['match_histograms',
'hough_circle',
'hough_ellipse',
'hough_line',
'probabilistic_hough_line',
Expand Down
72 changes: 72 additions & 0 deletions skimage/transform/histogram_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np


def _match_cumulative_cdf(source, template):
"""
Return modified source array so that the cumulative density function of
its values matches the cumulative density function of the template.
"""
src_values, src_unique_indices, src_counts = np.unique(source.ravel(),
return_inverse=True,
return_counts=True)
tmpl_values, tmpl_counts = np.unique(template.ravel(), return_counts=True)

# calculate normalized quantiles for each array
src_quantiles = np.cumsum(src_counts) / source.size
tmpl_quantiles = np.cumsum(tmpl_counts) / template.size

interp_a_values = np.interp(src_quantiles, tmpl_quantiles, tmpl_values)
return interp_a_values[src_unique_indices].reshape(source.shape)



def match_histograms(image, reference, multichannel=False):
    """Adjust an image so that its cumulative histogram matches that of another.

    With ``multichannel=True`` the adjustment is applied separately for each
    channel (last axis). Otherwise all values of ``image`` and ``reference``
    are matched as a single distribution.

    Parameters
    ----------
    image : ndarray
        Input image. Can be gray-scale or in color.
    reference : ndarray
        Image to match histogram of. Must have the same number of dimensions
        as ``image`` and, when ``multichannel`` is True, the same number of
        channels.
    multichannel : bool, optional
        Apply the matching separately for each channel.

    Returns
    -------
    matched : ndarray
        Transformed input image, with the shape of ``image``. When
        ``multichannel`` is True the values are stored in an array of
        ``image.dtype``; otherwise the interpolated values are returned
        as computed.

    Raises
    ------
    ValueError
        Thrown when ``image`` and ``reference`` differ in their number of
        dimensions, or (with ``multichannel=True``) in their number of
        channels.

    References
    ----------
    .. [1] http://paulbourke.net/miscellaneous/equalisation/
    """
    if image.ndim != reference.ndim:
        raise ValueError('Image and reference must have the same number of '
                         'channels.')

    if not multichannel:
        # Treat both arrays as one pool of values.
        return _match_cumulative_cdf(image, reference)

    if image.shape[-1] != reference.shape[-1]:
        raise ValueError('Number of channels in the input image and reference '
                         'image must match!')

    # Channels are matched one by one and written into an output that keeps
    # the input dtype.
    matched = np.empty(image.shape, dtype=image.dtype)
    for channel in range(image.shape[-1]):
        matched[..., channel] = _match_cumulative_cdf(image[..., channel],
                                                      reference[..., channel])
    return matched
80 changes: 80 additions & 0 deletions skimage/transform/tests/test_histogram_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import numpy as np

from skimage.transform import histogram_matching
from skimage import transform
from skimage import data

from skimage._shared.testing import assert_array_almost_equal, \
assert_almost_equal

import pytest


@pytest.mark.parametrize(
    'array, template, expected_array',
    [
        (np.arange(10), np.arange(100), np.arange(9, 100, 10)),
        (np.random.rand(4), np.ones(3), np.ones(4)),
    ],
)
def test_match_array_values(array, template, expected_array):
    """Check the CDF-matching helper against known input/output pairs."""
    result = histogram_matching._match_cumulative_cdf(array, template)
    assert_array_almost_equal(result, expected_array)


class TestMatchHistogram:
    """Tests for ``transform.match_histograms`` on RGB sample images."""

    image_rgb = data.chelsea()
    template_rgb = data.astronaut()

    # TODO: also cover multichannel=False with single-channel inputs, e.g.
    # (image_rgb[:, :, 0], template_rgb[:, :, 0]).
    @pytest.mark.parametrize('image, reference', [
        (image_rgb, template_rgb)
    ])
    def test_match_histograms(self, image, reference):
        """Assert that pdf of matched image is close to the reference's pdf for
        all channels and all values of matched"""

        # when
        matched = transform.match_histograms(image, reference,
                                             multichannel=True)

        matched_pdf = self._calculate_image_empirical_pdf(matched)
        reference_pdf = self._calculate_image_empirical_pdf(reference)

        # then
        for matched_channel, reference_channel in zip(matched_pdf,
                                                      reference_pdf):
            matched_values, matched_quantiles = matched_channel
            reference_values, reference_quantiles = reference_channel

            # Compare in floating point: subtracting unsigned integer pixel
            # values would wrap around instead of giving a distance.
            reference_values = reference_values.astype(np.float64)

            for value, quantile in zip(matched_values, matched_quantiles):
                distances = np.abs(reference_values - float(value))
                closest_id = distances.argmin()
                assert_almost_equal(quantile,
                                    reference_quantiles[closest_id],
                                    decimal=1)

    @pytest.mark.parametrize('image, reference', [
        (image_rgb, template_rgb[:, :, 0]),
        (image_rgb[:, :, 0], template_rgb)
    ])
    def test_raises_value_error_on_channels_mismatch(self, image, reference):
        """A colour/gray-scale pairing must be rejected with ValueError."""
        with pytest.raises(ValueError):
            transform.match_histograms(image, reference)

    @classmethod
    def _calculate_image_empirical_pdf(cls, image):
        """Helper function for calculating empirical probability density
        function of a given image for all channels.

        Returns a list with one ``(values, quantiles)`` tuple per channel,
        where ``values`` are the distinct values of the channel and
        ``quantiles`` their cumulative counts normalised to end at 1.
        """
        if image.ndim > 2:
            # Move the channel axis first so iteration yields one channel
            # at a time.
            channels = image.transpose(2, 0, 1)
        else:
            # Present a single-channel image as a one-element channel stack.
            channels = image[np.newaxis, ...]

        channels_pdf = []
        for channel in channels:
            channel_values, counts = np.unique(channel, return_counts=True)
            channel_quantiles = np.cumsum(counts).astype(np.float64)
            channel_quantiles /= channel_quantiles[-1]

            channels_pdf.append((channel_values, channel_quantiles))

        # Return the plain list: channels may have different numbers of
        # distinct values, so the entries cannot form a regular array.
        return channels_pdf

0 comments on commit 2ec9b0a

Please sign in to comment.