diff --git a/.github/workflows/tests_full.yml b/.github/workflows/tests_full.yml index e550c88..3499e72 100644 --- a/.github/workflows/tests_full.yml +++ b/.github/workflows/tests_full.yml @@ -73,12 +73,14 @@ jobs: matrix: os: [ windows-2019, ubuntu-18.04, macos-11 ] python-version: [ 3.6, 3.7, 3.8, 3.9 ] - pytorch-version: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0] + pytorch-version: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] exclude: - python-version: 3.6 pytorch-version: 1.11.0 - python-version: 3.6 pytorch-version: 1.12.0 + - python-version: 3.6 + pytorch-version: 1.13.0 steps: - uses: actions/checkout@v1 diff --git a/README.md b/README.md index 9047b49..74af142 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,9 @@ GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, and Numpy. Normalization algorithms currently implemented: -- Macenko et al. [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python)) -- Reinhard et al. [\[2\]](#reference) (only numpy & TensorFlow backend support) +- Macenko [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python)) +- Reinhard [\[2\]](#reference) +- Modified Reinhard [\[3\]](#reference) ## Installation @@ -49,7 +50,8 @@ norm, H, E = normalizer.normalize(I=t_to_transform, stains=True) | Algorithm | numpy | torch | tensorflow | |-|-|-|-| | Macenko | ✓ | ✓ | ✓ | -| Reinhard | ✓ | ✗ | ✓ | +| Reinhard | ✓ | ✓ | ✓ | +| Modified Reinhard | ✓ | ✓ | ✓ | ## Backend comparison @@ -68,8 +70,9 @@ Results with 10 runs per size on a Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz ## Reference -- [1] Macenko, Marc, et al. "A method for normalizing histology slides for quantitative analysis." 2009 IEEE International Symposium on Biomedical Imaging: From Nano to Macro. IEEE, 2009. -- [2] Reinhard, Erik, et al. "Color transfer between images." IEEE Computer Graphics and Applications. IEEE, 2001. +- [1] Macenko, Marc et al. "A method for normalizing histology slides for quantitative analysis." 2009 IEEE International Symposium on Biomedical Imaging: From Nano to Macro. IEEE, 2009. +- [2] Reinhard, Erik et al. "Color transfer between images." IEEE Computer Graphics and Applications. IEEE, 2001. +- [3] Roy, Santanu et al. "Modified Reinhard Algorithm for Color Normalization of Colorectal Cancer Histopathology Images". 2021 29th European Signal Processing Conference (EUSIPCO), IEEE, 2021. ## Citing diff --git a/example.py b/example.py index 6de0f36..98ba5b9 100644 --- a/example.py +++ b/example.py @@ -83,7 +83,6 @@ t_ = time.time() norm, H, E = tf_normalizer.normalize(I=t_to_transform, stains=True) print("tf runtime:", time.time() - t_) - plt.figure() plt.suptitle('tensorflow normalizer') plt.subplot(2, 2, 1) diff --git a/setup.py b/setup.py index d27acf7..a8b3628 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name='torchstain', - version='1.2.0', + version='1.3.0', description='Stain normalization tools for histological analysis and computational pathology', long_description=README, long_description_content_type='text/markdown', diff --git a/tests/test_tf.py b/tests/test_tf.py index b24757a..8edb65a 100644 --- a/tests/test_tf.py +++ b/tests/test_tf.py @@ -3,8 +3,6 @@ import torchstain import torchstain.tf import tensorflow as tf -import time -from skimage.metrics import structural_similarity as ssim import numpy as np def test_cov(): @@ -22,7 +20,7 @@ def test_percentile(): np.testing.assert_almost_equal(p_np, p_t) -def test_normalize_tf(): +def test_macenko_tf(): size = 1024 curr_file_path = os.path.dirname(os.path.realpath(__file__)) target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) @@ -44,8 +42,36 @@ def test_normalize_tf(): result_tf, _, _ = tf_normalizer.normalize(I=t_to_transform, stains=True) # convert to numpy and set dtype - result_numpy = result_numpy.astype("float32") - result_tf = result_tf.numpy().astype("float32") + result_numpy = result_numpy.astype("float32") / 255. + result_tf = result_tf.numpy().astype("float32") / 255. # assess whether the normalized images are identical across backends - np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True) + np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True) + +def test_reinhard_tf(): + size = 1024 + curr_file_path = os.path.dirname(os.path.realpath(__file__)) + target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) + to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) + + # setup preprocessing and preprocess image to be normalized + T = lambda x: tf.convert_to_tensor(x, dtype=tf.float32) + t_to_transform = T(to_transform) + + # initialize normalizers for each backend and fit to target image + normalizer = torchstain.normalizers.ReinhardNormalizer(backend='numpy') + normalizer.fit(target) + + tf_normalizer = torchstain.normalizers.ReinhardNormalizer(backend='tensorflow') + tf_normalizer.fit(T(target)) + + # transform + result_numpy = normalizer.normalize(I=to_transform) + result_tf = tf_normalizer.normalize(I=t_to_transform) + + # convert to numpy and set dtype + result_numpy = result_numpy.astype("float32") / 255. + result_tf = result_tf.numpy().astype("float32") / 255. + + # assess whether the normalized images are identical across backends + np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True) diff --git a/tests/test_torch.py b/tests/test_torch.py index 74b6fd2..eb7b60c 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -4,10 +4,9 @@ import torchstain.torch import torch import torchvision -import time import numpy as np from torchvision import transforms -from skimage.metrics import structural_similarity as ssim + def setup_function(fn): print("torch version:", torch.__version__, "torchvision version:", torchvision.__version__) @@ -27,7 +26,7 @@ def test_percentile(): np.testing.assert_almost_equal(p_np, p_t) -def test_normalize_torch(): +def test_macenko_torch(): size = 1024 curr_file_path = os.path.dirname(os.path.realpath(__file__)) target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) @@ -36,7 +35,7 @@ def test_normalize_torch(): # setup preprocessing and preprocess image to be normalized T = transforms.Compose([ transforms.ToTensor(), - transforms.Lambda(lambda x: x*255) + transforms.Lambda(lambda x: x * 255) ]) t_to_transform = T(to_transform) @@ -52,8 +51,40 @@ def test_normalize_torch(): result_torch, _, _ = torch_normalizer.normalize(I=t_to_transform, stains=True) # convert to numpy and set dtype - result_numpy = result_numpy.astype("float32") - result_torch = result_torch.numpy().astype("float32") + result_numpy = result_numpy.astype("float32") / 255. + result_torch = result_torch.numpy().astype("float32") / 255. + + # assess whether the normalized images are identical across backends + np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True) + +def test_reinhard_torch(): + size = 1024 + curr_file_path = os.path.dirname(os.path.realpath(__file__)) + target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) + to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) + + # setup preprocessing and preprocess image to be normalized + T = transforms.Compose([ + transforms.ToTensor(), + transforms.Lambda(lambda x: x * 255) + ]) + t_to_transform = T(to_transform) + + # initialize normalizers for each backend and fit to target image + normalizer = torchstain.normalizers.ReinhardNormalizer(backend='numpy') + normalizer.fit(target) + + torch_normalizer = torchstain.normalizers.ReinhardNormalizer(backend='torch') + torch_normalizer.fit(T(target)) + + # transform + result_numpy = normalizer.normalize(I=to_transform) + result_torch = torch_normalizer.normalize(I=t_to_transform) + + # convert to numpy and set dtype + result_numpy = result_numpy.astype("float32") / 255. + result_torch = result_torch.numpy().astype("float32") / 255. + # assess whether the normalized images are identical across backends - np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True) + np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True) diff --git a/torchstain/__init__.py b/torchstain/__init__.py index dab8618..4b11e31 100644 --- a/torchstain/__init__.py +++ b/torchstain/__init__.py @@ -1,3 +1,3 @@ -__version__ = '1.2.0' +__version__ = '1.3.0' from torchstain.base import normalizers \ No newline at end of file diff --git a/torchstain/base/normalizers/reinhard.py b/torchstain/base/normalizers/reinhard.py index 96947aa..7800933 100644 --- a/torchstain/base/normalizers/reinhard.py +++ b/torchstain/base/normalizers/reinhard.py @@ -1,11 +1,12 @@ -def ReinhardNormalizer(backend='numpy'): +def ReinhardNormalizer(backend='numpy', method=None): if backend == 'numpy': from torchstain.numpy.normalizers import NumpyReinhardNormalizer - return NumpyReinhardNormalizer() + return NumpyReinhardNormalizer(method=method) elif backend == "torch": - raise NotImplementedError + from torchstain.torch.normalizers import TorchReinhardNormalizer + return TorchReinhardNormalizer(method=method) elif backend == "tensorflow": from torchstain.tf.normalizers import TensorFlowReinhardNormalizer - return TensorFlowReinhardNormalizer() + return TensorFlowReinhardNormalizer(method=method) else: raise Exception(f'Unknown backend {backend}') diff --git a/torchstain/numpy/normalizers/macenko.py b/torchstain/numpy/normalizers/macenko.py index 37dac55..faada97 100644 --- a/torchstain/numpy/normalizers/macenko.py +++ b/torchstain/numpy/normalizers/macenko.py @@ -16,7 +16,7 @@ def __init__(self): def __convert_rgb2od(self, I, Io=240, beta=0.15): # calculate optical density - OD = -np.log((I.astype(np.float)+1)/Io) + OD = -np.log((I.astype(float)+1)/Io) # remove transparent pixels ODhat = OD[~np.any(OD < beta, axis=1)] diff --git a/torchstain/numpy/normalizers/reinhard.py b/torchstain/numpy/normalizers/reinhard.py index fbb8d13..d299dd3 100644 --- a/torchstain/numpy/normalizers/reinhard.py +++ b/torchstain/numpy/normalizers/reinhard.py @@ -11,8 +11,9 @@ https://github.com/Peter554/StainTools/blob/master/staintools/reinhard_color_normalizer.py """ class NumpyReinhardNormalizer(HENormalizer): - def __init__(self): + def __init__(self, method=None): super().__init__() + self.method = method self.target_mus = None self.target_stds = None @@ -41,9 +42,26 @@ def normalize(self, I): mus = stack_[:, 0] stds = stack_[:, 1] - # standardize intensities channel-wise and normalize using target mus and stds - result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \ - in zip(labs, mus, stds, self.target_means, self.target_stds)] + # normalize + if self.method is None: + # standardize intensities channel-wise and normalize using target mus and stds + result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \ + in zip(labs, mus, stds, self.target_means, self.target_stds)] + + elif self.method == "modified": + # calculate q + q = (self.target_stds[0] - stds[0]) / self.target_stds[0] + q = 0.05 if q <= 0 else q + + # normalize each channel independently + l_norm = mus[0] + (labs[0] - mus[0]) * (1 + q) + a_norm = self.target_means[1] + (labs[1] - mus[1]) + b_norm = self.target_means[2] + (labs[2] - mus[2]) + + result = [l_norm, a_norm, b_norm] + + else: + raise ValueError("Unsupported 'method' was chosen. Choose either {None, 'modified'}.") # rebuild LAB lab = lab_merge(*result) diff --git a/torchstain/tf/normalizers/__init__.py b/torchstain/tf/normalizers/__init__.py index 0916e1b..fb0718f 100644 --- a/torchstain/tf/normalizers/__init__.py +++ b/torchstain/tf/normalizers/__init__.py @@ -1,2 +1,2 @@ from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer -from torchstain.tf.normalizers.reinhard import TensorFlowReinhardNormalizer \ No newline at end of file +from torchstain.tf.normalizers.reinhard import TensorFlowReinhardNormalizer diff --git a/torchstain/tf/normalizers/reinhard.py b/torchstain/tf/normalizers/reinhard.py index 8a2b601..977f084 100644 --- a/torchstain/tf/normalizers/reinhard.py +++ b/torchstain/tf/normalizers/reinhard.py @@ -11,8 +11,9 @@ https://github.com/Peter554/StainTools/blob/master/staintools/reinhard_color_normalizer.py """ class TensorFlowReinhardNormalizer(HENormalizer): - def __init__(self): + def __init__(self, method=None): super().__init__() + self.method = method self.target_mus = None self.target_stds = None @@ -41,9 +42,26 @@ def normalize(self, I): mus = stack_[:, 0] stds = stack_[:, 1] - # standardize intensities channel-wise and normalize using target mus and stds - result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \ - in zip(labs, mus, stds, self.target_means, self.target_stds)] + # normalize + if self.method is None: + # standardize intensities channel-wise and normalize using target mus and stds + result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \ + in zip(labs, mus, stds, self.target_means, self.target_stds)] + + elif self.method == "modified": + # calculate q + q = (self.target_stds[0] - stds[0]) / self.target_stds[0] + q = 0.05 if q <= 0 else q + + # normalize each channel independently + l_norm = mus[0] + (labs[0] - mus[0]) * (1 + q) + a_norm = self.target_means[1] + (labs[1] - mus[1]) + b_norm = self.target_means[2] + (labs[2] - mus[2]) + + result = [l_norm, a_norm, b_norm] + + else: + raise ValueError("Unsupported 'method' was chosen. Choose either {None, 'modified'}.") # rebuild LAB lab = lab_merge(*result) diff --git a/torchstain/tf/utils/split.py b/torchstain/tf/utils/split.py index 226cde0..5717b87 100644 --- a/torchstain/tf/utils/split.py +++ b/torchstain/tf/utils/split.py @@ -1,5 +1,4 @@ import tensorflow as tf -from torchstain.tf.utils.rgb2lab import rgb2lab def csplit(I): return [I[..., i] for i in range(I.shape[-1])] diff --git a/torchstain/torch/normalizers/__init__.py b/torchstain/torch/normalizers/__init__.py index febcd90..c78c273 100644 --- a/torchstain/torch/normalizers/__init__.py +++ b/torchstain/torch/normalizers/__init__.py @@ -1 +1,2 @@ from torchstain.torch.normalizers.macenko import TorchMacenkoNormalizer +from torchstain.torch.normalizers.reinhard import TorchReinhardNormalizer diff --git a/torchstain/torch/normalizers/macenko.py b/torchstain/torch/normalizers/macenko.py index 710b7e3..74d5a00 100644 --- a/torchstain/torch/normalizers/macenko.py +++ b/torchstain/torch/normalizers/macenko.py @@ -15,6 +15,9 @@ def __init__(self): [0.4062, 0.5581]]) self.maxCRef = torch.tensor([1.9705, 1.0308]) + # Avoid using deprecated torch.lstsq (since 1.9.0) + self.updated_lstsq = hasattr(torch.linalg, 'lstsq') + def __convert_rgb2od(self, I, Io, beta): I = I.permute(1, 2, 0) @@ -49,13 +52,16 @@ def __find_concentration(self, OD, HE): Y = OD.T # determine concentrations of the individual stains - return torch.lstsq(Y, HE)[0][:2] + if not self.updated_lstsq: + return torch.lstsq(Y, HE)[0][:2] + + return torch.linalg.lstsq(HE, Y)[0] def __compute_matrices(self, I, Io, alpha, beta): OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) # compute eigenvectors - _, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True) + _, eigvecs = torch.linalg.eigh(cov(ODhat.T)) eigvecs = eigvecs[:, [1, 2]] HE = self.__find_HE(ODhat, eigvecs, alpha) diff --git a/torchstain/torch/normalizers/reinhard.py b/torchstain/torch/normalizers/reinhard.py new file mode 100644 index 0000000..dccc63f --- /dev/null +++ b/torchstain/torch/normalizers/reinhard.py @@ -0,0 +1,73 @@ +import torch +from torchstain.base.normalizers import HENormalizer +from torchstain.torch.utils.rgb2lab import rgb2lab +from torchstain.torch.utils.lab2rgb import lab2rgb +from torchstain.torch.utils.split import csplit, cmerge, lab_split, lab_merge +from torchstain.torch.utils.stats import get_mean_std, standardize + +""" +Source code adapted from: +https://github.com/DigitalSlideArchive/HistomicsTK/blob/master/histomicstk/preprocessing/color_normalization/reinhard.py +https://github.com/Peter554/StainTools/blob/master/staintools/reinhard_color_normalizer.py +""" +class TorchReinhardNormalizer(HENormalizer): + def __init__(self, method=None): + super().__init__() + self.method = method + self.target_mus = None + self.target_stds = None + + def fit(self, target): + # normalize + target = target.type(torch.float32) / 255 + + # convert to LAB + lab = rgb2lab(target) + + # get summary statistics + stack_ = torch.tensor([get_mean_std(x) for x in lab_split(lab)]) + self.target_means = stack_[:, 0] + self.target_stds = stack_[:, 1] + + def normalize(self, I): + # normalize + I = I.type(torch.float32) / 255 + + # convert to LAB + lab = rgb2lab(I) + labs = lab_split(lab) + + # get summary statistics from LAB + stack_ = torch.tensor([get_mean_std(x) for x in labs]) + mus = stack_[:, 0] + stds = stack_[:, 1] + + # normalize + if self.method is None: + # standardize intensities channel-wise and normalize using target mus and stds + result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \ + in zip(labs, mus, stds, self.target_means, self.target_stds)] + + elif self.method == "modified": + # calculate q + q = (self.target_stds[0] - stds[0]) / self.target_stds[0] + q = 0.05 if q <= 0 else q + + # normalize each channel independently + l_norm = mus[0] + (labs[0] - mus[0]) * (1 + q) + a_norm = self.target_means[1] + (labs[1] - mus[1]) + b_norm = self.target_means[2] + (labs[2] - mus[2]) + + result = [l_norm, a_norm, b_norm] + + else: + raise ValueError("Unsupported 'method' was chosen. Choose either {None, 'modified'}.") + + # rebuild LAB + lab = lab_merge(*result) + + # convert back to RGB from LAB + lab = lab2rgb(lab) + + # rescale to [0, 255] uint8 + return (lab * 255).type(torch.uint8) diff --git a/torchstain/torch/utils/__init__.py b/torchstain/torch/utils/__init__.py index 5e0de3e..4acea5a 100644 --- a/torchstain/torch/utils/__init__.py +++ b/torchstain/torch/utils/__init__.py @@ -1,2 +1,6 @@ from torchstain.torch.utils.cov import cov from torchstain.torch.utils.percentile import percentile +from torchstain.torch.utils.stats import * +from torchstain.torch.utils.split import * +from torchstain.torch.utils.rgb2lab import * +from torchstain.torch.utils.lab2rgb import * diff --git a/torchstain/torch/utils/lab2rgb.py b/torchstain/torch/utils/lab2rgb.py new file mode 100644 index 0000000..250c1c9 --- /dev/null +++ b/torchstain/torch/utils/lab2rgb.py @@ -0,0 +1,35 @@ +import torch +from torchstain.torch.utils.rgb2lab import _rgb2xyz, _white + +_xyz2rgb = torch.linalg.inv(_rgb2xyz) + +def lab2rgb(lab): + lab = lab.type(torch.float32) + + # rescale back from OpenCV format and extract LAB channel + L, a, b = lab[0] / 2.55, lab[1] - 128, lab[2] - 128 + + # vector scaling to produce X, Y, Z + y = (L + 16.) / 116. + x = (a / 500.) + y + z = y - (b / 200.) + + # merge back to get reconstructed XYZ color image + out = torch.stack([x, y, z], axis=0) + + # apply boolean transforms + mask = out > 0.2068966 + not_mask = torch.logical_not(mask) + out.masked_scatter_(mask, torch.pow(torch.masked_select(out, mask), 3)) + out.masked_scatter_(not_mask, (torch.masked_select(out, not_mask) - 16 / 116) / 7.787) + + # rescale to the reference white (illuminant) + out = torch.mul(out, _white.type(out.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1)) + + # convert XYZ -> RGB color domain + arr = torch.tensordot(out, torch.t(_xyz2rgb).type(out.dtype), dims=([0], [0])) + mask = arr > 0.0031308 + not_mask = torch.logical_not(mask) + arr.masked_scatter_(mask, 1.055 * torch.pow(torch.masked_select(arr, mask), 1 / 2.4) - 0.055) + arr.masked_scatter_(not_mask, torch.masked_select(arr, not_mask) * 12.92) + return torch.clamp(arr, 0, 1) diff --git a/torchstain/torch/utils/rgb2lab.py b/torchstain/torch/utils/rgb2lab.py new file mode 100644 index 0000000..3b1aa50 --- /dev/null +++ b/torchstain/torch/utils/rgb2lab.py @@ -0,0 +1,44 @@ +import torch + +# constant conversion matrices between color spaces: https://gist.github.com/bikz05/6fd21c812ef6ebac66e1 +_rgb2xyz = torch.tensor([[0.412453, 0.357580, 0.180423], + [0.212671, 0.715160, 0.072169], + [0.019334, 0.119193, 0.950227]]) + +_white = torch.tensor([0.95047, 1., 1.08883]) + +def rgb2lab(rgb): + arr = rgb.type(torch.float32) + + # convert rgb -> xyz color domain + mask = arr > 0.04045 + not_mask = torch.logical_not(mask) + arr.masked_scatter_(mask, torch.pow((torch.masked_select(arr, mask) + 0.055) / 1.055, 2.4)) + arr.masked_scatter_(not_mask, torch.masked_select(arr, not_mask) / 12.92) + + xyz = torch.tensordot(torch.t(_rgb2xyz), arr, dims=([0], [0])) + + # scale by CIE XYZ tristimulus values of the reference white point + arr = torch.mul(xyz, 1 / _white.type(xyz.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1)) + + # nonlinear distortion and linear transformation + mask = arr > 0.008856 + not_mask = torch.logical_not(mask) + arr.masked_scatter_(mask, torch.pow(torch.masked_select(arr, mask), 1 / 3)) + arr.masked_scatter_(not_mask, 7.787 * torch.masked_select(arr, not_mask) + 16 / 166) + + # get each channel as individual tensors + x, y, z = arr[0], arr[1], arr[2] + + # vector scaling + L = (116. * y) - 16. + a = 500.0 * (x - y) + b = 200.0 * (y - z) + + # OpenCV format + L *= 2.55 + a += 128 + b += 128 + + # finally, get LAB color domain + return torch.stack([L, a, b], axis=0) diff --git a/torchstain/torch/utils/split.py b/torchstain/torch/utils/split.py new file mode 100644 index 0000000..d6f6fdb --- /dev/null +++ b/torchstain/torch/utils/split.py @@ -0,0 +1,15 @@ +import torch + +def csplit(I): + return [I[i] for i in range(I.shape[0])] + +def cmerge(I1, I2, I3): + return torch.stack([I1, I2, I3], dim=0) + +def lab_split(I): + I = I.type(torch.float32) + I1, I2, I3 = csplit(I) + return I1 / 2.55, I2 - 128, I3 - 128 + +def lab_merge(I1, I2, I3): + return cmerge(I1 * 2.55, I2 + 128, I3 + 128) diff --git a/torchstain/torch/utils/stats.py b/torchstain/torch/utils/stats.py new file mode 100644 index 0000000..0fa45bb --- /dev/null +++ b/torchstain/torch/utils/stats.py @@ -0,0 +1,7 @@ +import torch + +def get_mean_std(I): + return torch.mean(I), torch.std(I) + +def standardize(x, mu, std): + return (x - mu) / std