Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft release v1.3.0 rc1 #34

Merged
merged 19 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
bde067e
added torch reinhard base support
andreped Dec 3, 2022
03d499a
removed redundant import and added newline in tf
andreped Dec 3, 2022
ce443d8
added all utils for torch
andreped Dec 3, 2022
18b539e
torch reinhard almost working
andreped Dec 3, 2022
ede14d7
fixed bug - torch reinhard works
andreped Dec 3, 2022
51f9448
updated README as all backends are supported with Reinhard
andreped Dec 3, 2022
3af7477
added reinhard test for all backends
andreped Dec 3, 2022
17c2d28
added support for modified reinhard, all backends
andreped Dec 22, 2022
afe810c
updated README regarding modified reinhard support
andreped Dec 22, 2022
f14df82
Removed deprecated torch.lstsq, replaced by torch.linalg.lstsq
raphaelattias Jan 11, 2023
6883215
removed .vscode, revert change to example.py and updated min version …
raphaelattias Jan 11, 2023
26091e4
fix example
raphaelattias Jan 11, 2023
c5b715c
add torch version checking for backward compatibility
raphaelattias Jan 11, 2023
ab16792
Merge pull request #27 from andreped/reinhard
carloalbertobarbano Jan 17, 2023
9b8e6cc
Merge pull request #30 from andreped/modified-reinhard
carloalbertobarbano Jan 17, 2023
fe23f25
use python float instead of np.float
carloalbertobarbano Jan 17, 2023
3d0bbad
Merge branch 'development' of github.com:EIDOSLAB/torchstain into dev…
carloalbertobarbano Jan 17, 2023
2b4dd90
Added torch version check in the __init__
raphaelattias Jan 17, 2023
64622cd
Merge pull request #33 from raphaelattias/deprecated-lstsq
carloalbertobarbano Jan 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ GPU-accelerated stain normalization tools for histopathological images. Compatib
Normalization algorithms currently implemented:

- Macenko et al. [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python))
- Reinhard et al. [\[2\]](#reference) (only numpy & TensorFlow backend support)
- Reinhard et al. [\[2\]](#reference)

## Installation

Expand Down Expand Up @@ -49,7 +49,7 @@ norm, H, E = normalizer.normalize(I=t_to_transform, stains=True)
| Algorithm | numpy | torch | tensorflow |
|-|-|-|-|
| Macenko | ✓ | ✓ | ✓ |
| Reinhard | ✓ | ✗ | ✓ |
| Reinhard | ✓ | ✓ | ✓ |

## Backend comparison

Expand Down
30 changes: 29 additions & 1 deletion tests/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_percentile():

np.testing.assert_almost_equal(p_np, p_t)

def test_normalize_tf():
def test_macenko_tf():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
Expand All @@ -49,3 +49,31 @@ def test_normalize_tf():

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True)

def test_reinhard_tf():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))

# setup preprocessing and preprocess image to be normalized
T = lambda x: tf.convert_to_tensor(x, dtype=tf.float32)
t_to_transform = T(to_transform)

# initialize normalizers for each backend and fit to target image
normalizer = torchstain.normalizers.ReinhardNormalizer(backend='numpy')
normalizer.fit(target)

tf_normalizer = torchstain.normalizers.ReinhardNormalizer(backend='tensorflow')
tf_normalizer.fit(T(target))

# transform
result_numpy = normalizer.normalize(I=to_transform)
result_tf = tf_normalizer.normalize(I=t_to_transform)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32")
result_tf = result_tf.numpy().astype("float32")

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True)
35 changes: 33 additions & 2 deletions tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_percentile():

np.testing.assert_almost_equal(p_np, p_t)

def test_normalize_torch():
def test_macenko_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
Expand All @@ -36,7 +36,7 @@ def test_normalize_torch():
# setup preprocessing and preprocess image to be normalized
T = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x*255)
transforms.Lambda(lambda x: x * 255)
])
t_to_transform = T(to_transform)

Expand All @@ -57,3 +57,34 @@ def test_normalize_torch():

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True)

def test_reinhard_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))

# setup preprocessing and preprocess image to be normalized
T = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x * 255)
])
t_to_transform = T(to_transform)

# initialize normalizers for each backend and fit to target image
normalizer = torchstain.normalizers.ReinhardNormalizer(backend='numpy')
normalizer.fit(target)

torch_normalizer = torchstain.normalizers.ReinhardNormalizer(backend='torch')
torch_normalizer.fit(T(target))

# transform
result_numpy = normalizer.normalize(I=to_transform)
result_torch = torch_normalizer.normalize(I=t_to_transform)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32")
result_torch = result_torch.numpy().astype("float32")

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True)
3 changes: 2 additions & 1 deletion torchstain/base/normalizers/reinhard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ def ReinhardNormalizer(backend='numpy'):
from torchstain.numpy.normalizers import NumpyReinhardNormalizer
return NumpyReinhardNormalizer()
elif backend == "torch":
raise NotImplementedError
from torchstain.torch.normalizers import TorchReinhardNormalizer
return TorchReinhardNormalizer()
elif backend == "tensorflow":
from torchstain.tf.normalizers import TensorFlowReinhardNormalizer
return TensorFlowReinhardNormalizer()
Expand Down
2 changes: 1 addition & 1 deletion torchstain/tf/normalizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer
from torchstain.tf.normalizers.reinhard import TensorFlowReinhardNormalizer
from torchstain.tf.normalizers.reinhard import TensorFlowReinhardNormalizer
1 change: 0 additions & 1 deletion torchstain/tf/utils/split.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import tensorflow as tf
from torchstain.tf.utils.rgb2lab import rgb2lab

def csplit(I):
return [I[..., i] for i in range(I.shape[-1])]
Expand Down
1 change: 1 addition & 0 deletions torchstain/torch/normalizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from torchstain.torch.normalizers.macenko import TorchMacenkoNormalizer
from torchstain.torch.normalizers.reinhard import TorchReinhardNormalizer
55 changes: 55 additions & 0 deletions torchstain/torch/normalizers/reinhard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
from torchstain.base.normalizers import HENormalizer
from torchstain.torch.utils.rgb2lab import rgb2lab
from torchstain.torch.utils.lab2rgb import lab2rgb
from torchstain.torch.utils.split import csplit, cmerge, lab_split, lab_merge
from torchstain.torch.utils.stats import get_mean_std, standardize

"""
Source code adapted from:
https://github.com/DigitalSlideArchive/HistomicsTK/blob/master/histomicstk/preprocessing/color_normalization/reinhard.py
https://github.com/Peter554/StainTools/blob/master/staintools/reinhard_color_normalizer.py
"""
class TorchReinhardNormalizer(HENormalizer):
def __init__(self):
super().__init__()
self.target_mus = None
self.target_stds = None

def fit(self, target):
# normalize
target = target.type(torch.float32) / 255

# convert to LAB
lab = rgb2lab(target)

# get summary statistics
stack_ = torch.tensor([get_mean_std(x) for x in lab_split(lab)])
self.target_means = stack_[:, 0]
self.target_stds = stack_[:, 1]

def normalize(self, I):
# normalize
I = I.type(torch.float32) / 255

# convert to LAB
lab = rgb2lab(I)
labs = lab_split(lab)

# get summary statistics from LAB
stack_ = torch.tensor([get_mean_std(x) for x in labs])
mus = stack_[:, 0]
stds = stack_[:, 1]

# standardize intensities channel-wise and normalize using target mus and stds
result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \
in zip(labs, mus, stds, self.target_means, self.target_stds)]

# rebuild LAB
lab = lab_merge(*result)

# convert back to RGB from LAB
lab = lab2rgb(lab)

# rescale to [0, 255] uint8
return (lab * 255).type(torch.uint8)
4 changes: 4 additions & 0 deletions torchstain/torch/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
from torchstain.torch.utils.cov import cov
from torchstain.torch.utils.percentile import percentile
from torchstain.torch.utils.stats import *
from torchstain.torch.utils.split import *
from torchstain.torch.utils.rgb2lab import *
from torchstain.torch.utils.lab2rgb import *
35 changes: 35 additions & 0 deletions torchstain/torch/utils/lab2rgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
from torchstain.torch.utils.rgb2lab import _rgb2xyz, _white

_xyz2rgb = torch.linalg.inv(_rgb2xyz)

def lab2rgb(lab):
lab = lab.type(torch.float32)

# rescale back from OpenCV format and extract LAB channel
L, a, b = lab[0] / 2.55, lab[1] - 128, lab[2] - 128

# vector scaling to produce X, Y, Z
y = (L + 16.) / 116.
x = (a / 500.) + y
z = y - (b / 200.)

# merge back to get reconstructed XYZ color image
out = torch.stack([x, y, z], axis=0)

# apply boolean transforms
mask = out > 0.2068966
not_mask = torch.logical_not(mask)
out.masked_scatter_(mask, torch.pow(torch.masked_select(out, mask), 3))
out.masked_scatter_(not_mask, (torch.masked_select(out, not_mask) - 16 / 116) / 7.787)

# rescale to the reference white (illuminant)
out = torch.mul(out, _white.type(out.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1))

# convert XYZ -> RGB color domain
arr = torch.tensordot(out, torch.t(_xyz2rgb).type(out.dtype), dims=([0], [0]))
mask = arr > 0.0031308
not_mask = torch.logical_not(mask)
arr.masked_scatter_(mask, 1.055 * torch.pow(torch.masked_select(arr, mask), 1 / 2.4) - 0.055)
arr.masked_scatter_(not_mask, torch.masked_select(arr, not_mask) * 12.92)
return torch.clamp(arr, 0, 1)
44 changes: 44 additions & 0 deletions torchstain/torch/utils/rgb2lab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch

# constant conversion matrices between color spaces: https://gist.github.com/bikz05/6fd21c812ef6ebac66e1
_rgb2xyz = torch.tensor([[0.412453, 0.357580, 0.180423],
[0.212671, 0.715160, 0.072169],
[0.019334, 0.119193, 0.950227]])

_white = torch.tensor([0.95047, 1., 1.08883])

def rgb2lab(rgb):
arr = rgb.type(torch.float32)

# convert rgb -> xyz color domain
mask = arr > 0.04045
not_mask = torch.logical_not(mask)
arr.masked_scatter_(mask, torch.pow((torch.masked_select(arr, mask) + 0.055) / 1.055, 2.4))
arr.masked_scatter_(not_mask, torch.masked_select(arr, not_mask) / 12.92)

xyz = torch.tensordot(torch.t(_rgb2xyz), arr, dims=([0], [0]))

# scale by CIE XYZ tristimulus values of the reference white point
arr = torch.mul(xyz, 1 / _white.type(xyz.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1))

# nonlinear distortion and linear transformation
mask = arr > 0.008856
not_mask = torch.logical_not(mask)
arr.masked_scatter_(mask, torch.pow(torch.masked_select(arr, mask), 1 / 3))
arr.masked_scatter_(not_mask, 7.787 * torch.masked_select(arr, not_mask) + 16 / 166)

# get each channel as individual tensors
x, y, z = arr[0], arr[1], arr[2]

# vector scaling
L = (116. * y) - 16.
a = 500.0 * (x - y)
b = 200.0 * (y - z)

# OpenCV format
L *= 2.55
a += 128
b += 128

# finally, get LAB color domain
return torch.stack([L, a, b], axis=0)
15 changes: 15 additions & 0 deletions torchstain/torch/utils/split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

def csplit(I):
return [I[i] for i in range(I.shape[0])]

def cmerge(I1, I2, I3):
return torch.stack([I1, I2, I3], dim=0)

def lab_split(I):
I = I.type(torch.float32)
I1, I2, I3 = csplit(I)
return I1 / 2.55, I2 - 128, I3 - 128

def lab_merge(I1, I2, I3):
return cmerge(I1 * 2.55, I2 + 128, I3 + 128)
7 changes: 7 additions & 0 deletions torchstain/torch/utils/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import torch

def get_mean_std(I):
return torch.mean(I), torch.std(I)

def standardize(x, mu, std):
return (x - mu) / std