Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V1.3.0 #38

Merged
merged 24 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
bde067e
added torch reinhard base support
andreped Dec 3, 2022
03d499a
removed redundant import and added newline in tf
andreped Dec 3, 2022
ce443d8
added all utils for torch
andreped Dec 3, 2022
18b539e
torch reinhard almost working
andreped Dec 3, 2022
ede14d7
fixed bug - torch reinhard works
andreped Dec 3, 2022
51f9448
updated README as all backends are supported with Reinhard
andreped Dec 3, 2022
3af7477
added reinhard test for all backends
andreped Dec 3, 2022
17c2d28
added support for modified reinhard, all backends
andreped Dec 22, 2022
afe810c
updated README regarding modified reinhard support
andreped Dec 22, 2022
f14df82
Removed deprecated torch.lstsq, replaced by torch.linalg.lstsq
raphaelattias Jan 11, 2023
6883215
removed .vscode, revert change to example.py and updated min version …
raphaelattias Jan 11, 2023
26091e4
fix example
raphaelattias Jan 11, 2023
c5b715c
add torch version checking for backward compatibility
raphaelattias Jan 11, 2023
ab16792
Merge pull request #27 from andreped/reinhard
carloalbertobarbano Jan 17, 2023
9b8e6cc
Merge pull request #30 from andreped/modified-reinhard
carloalbertobarbano Jan 17, 2023
fe23f25
use python float instead of np.float
carloalbertobarbano Jan 17, 2023
3d0bbad
Merge branch 'development' of github.com:EIDOSLAB/torchstain into dev…
carloalbertobarbano Jan 17, 2023
2b4dd90
Added torch version check in the __init__
raphaelattias Jan 17, 2023
64622cd
Merge pull request #33 from raphaelattias/deprecated-lstsq
carloalbertobarbano Jan 18, 2023
80fb609
Merge pull request #34 from EIDOSLAB/development
carloalbertobarbano Jan 20, 2023
dec5429
fix pytorch/python incompatibility
carloalbertobarbano Mar 2, 2023
13a338c
check for torch.linalg.lstsq instead of inconsistent torch.__version__
carloalbertobarbano Mar 2, 2023
602b5b1
remove ssim from tests
carloalbertobarbano Mar 2, 2023
435b547
bump version to 1.3.0
carloalbertobarbano Mar 2, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/tests_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,14 @@ jobs:
matrix:
os: [ windows-2019, ubuntu-18.04, macos-11 ]
python-version: [ 3.6, 3.7, 3.8, 3.9 ]
pytorch-version: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0]
pytorch-version: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0]
exclude:
- python-version: 3.6
pytorch-version: 1.11.0
- python-version: 3.6
pytorch-version: 1.12.0
- python-version: 3.6
pytorch-version: 1.13.0

steps:
- uses: actions/checkout@v1
Expand Down
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, and Numpy.
Normalization algorithms currently implemented:

- Macenko et al. [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python))
- Reinhard et al. [\[2\]](#reference) (only numpy & TensorFlow backend support)
- Macenko [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python))
- Reinhard [\[2\]](#reference)
- Modified Reinhard [\[3\]](#reference)

## Installation

Expand Down Expand Up @@ -49,7 +50,8 @@ norm, H, E = normalizer.normalize(I=t_to_transform, stains=True)
| Algorithm | numpy | torch | tensorflow |
|-|-|-|-|
| Macenko | ✓ | ✓ | ✓ |
| Reinhard | ✓ | ✗ | ✓ |
| Reinhard | ✓ | ✓ | ✓ |
| Modified Reinhard | ✓ | ✓ | ✓ |

## Backend comparison

Expand All @@ -68,8 +70,9 @@ Results with 10 runs per size on a Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz

## Reference

- [1] Macenko, Marc, et al. "A method for normalizing histology slides for quantitative analysis." 2009 IEEE International Symposium on Biomedical Imaging: From Nano to Macro. IEEE, 2009.
- [2] Reinhard, Erik, et al. "Color transfer between images." IEEE Computer Graphics and Applications. IEEE, 2001.
- [1] Macenko, Marc et al. "A method for normalizing histology slides for quantitative analysis." 2009 IEEE International Symposium on Biomedical Imaging: From Nano to Macro. IEEE, 2009.
- [2] Reinhard, Erik et al. "Color transfer between images." IEEE Computer Graphics and Applications. IEEE, 2001.
- [3] Roy, Santanu et al. "Modified Reinhard Algorithm for Color Normalization of Colorectal Cancer Histopathology Images". 2021 29th European Signal Processing Conference (EUSIPCO), IEEE, 2021.

## Citing

Expand Down
1 change: 0 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@
t_ = time.time()
norm, H, E = tf_normalizer.normalize(I=t_to_transform, stains=True)
print("tf runtime:", time.time() - t_)

plt.figure()
plt.suptitle('tensorflow normalizer')
plt.subplot(2, 2, 1)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name='torchstain',
version='1.2.0',
version='1.3.0',
description='Stain normalization tools for histological analysis and computational pathology',
long_description=README,
long_description_content_type='text/markdown',
Expand Down
38 changes: 32 additions & 6 deletions tests/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import torchstain
import torchstain.tf
import tensorflow as tf
import time
from skimage.metrics import structural_similarity as ssim
import numpy as np

def test_cov():
Expand All @@ -22,7 +20,7 @@ def test_percentile():

np.testing.assert_almost_equal(p_np, p_t)

def test_normalize_tf():
def test_macenko_tf():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
Expand All @@ -44,8 +42,36 @@ def test_normalize_tf():
result_tf, _, _ = tf_normalizer.normalize(I=t_to_transform, stains=True)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32")
result_tf = result_tf.numpy().astype("float32")
result_numpy = result_numpy.astype("float32") / 255.
result_tf = result_tf.numpy().astype("float32") / 255.

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True)
np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True)

def test_reinhard_tf():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))

# setup preprocessing and preprocess image to be normalized
T = lambda x: tf.convert_to_tensor(x, dtype=tf.float32)
t_to_transform = T(to_transform)

# initialize normalizers for each backend and fit to target image
normalizer = torchstain.normalizers.ReinhardNormalizer(backend='numpy')
normalizer.fit(target)

tf_normalizer = torchstain.normalizers.ReinhardNormalizer(backend='tensorflow')
tf_normalizer.fit(T(target))

# transform
result_numpy = normalizer.normalize(I=to_transform)
result_tf = tf_normalizer.normalize(I=t_to_transform)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32") / 255.
result_tf = result_tf.numpy().astype("float32") / 255.

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True)
45 changes: 38 additions & 7 deletions tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import torchstain.torch
import torch
import torchvision
import time
import numpy as np
from torchvision import transforms
from skimage.metrics import structural_similarity as ssim


def setup_function(fn):
print("torch version:", torch.__version__, "torchvision version:", torchvision.__version__)
Expand All @@ -27,7 +26,7 @@ def test_percentile():

np.testing.assert_almost_equal(p_np, p_t)

def test_normalize_torch():
def test_macenko_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
Expand All @@ -36,7 +35,7 @@ def test_normalize_torch():
# setup preprocessing and preprocess image to be normalized
T = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x*255)
transforms.Lambda(lambda x: x * 255)
])
t_to_transform = T(to_transform)

Expand All @@ -52,8 +51,40 @@ def test_normalize_torch():
result_torch, _, _ = torch_normalizer.normalize(I=t_to_transform, stains=True)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32")
result_torch = result_torch.numpy().astype("float32")
result_numpy = result_numpy.astype("float32") / 255.
result_torch = result_torch.numpy().astype("float32") / 255.

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)

def test_reinhard_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))

# setup preprocessing and preprocess image to be normalized
T = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x * 255)
])
t_to_transform = T(to_transform)

# initialize normalizers for each backend and fit to target image
normalizer = torchstain.normalizers.ReinhardNormalizer(backend='numpy')
normalizer.fit(target)

torch_normalizer = torchstain.normalizers.ReinhardNormalizer(backend='torch')
torch_normalizer.fit(T(target))

# transform
result_numpy = normalizer.normalize(I=to_transform)
result_torch = torch_normalizer.normalize(I=t_to_transform)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32") / 255.
result_torch = result_torch.numpy().astype("float32") / 255.


# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True)
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)
2 changes: 1 addition & 1 deletion torchstain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = '1.2.0'
__version__ = '1.3.0'

from torchstain.base import normalizers
9 changes: 5 additions & 4 deletions torchstain/base/normalizers/reinhard.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
def ReinhardNormalizer(backend='numpy'):
def ReinhardNormalizer(backend='numpy', method=None):
if backend == 'numpy':
from torchstain.numpy.normalizers import NumpyReinhardNormalizer
return NumpyReinhardNormalizer()
return NumpyReinhardNormalizer(method=method)
elif backend == "torch":
raise NotImplementedError
from torchstain.torch.normalizers import TorchReinhardNormalizer
return TorchReinhardNormalizer(method=method)
elif backend == "tensorflow":
from torchstain.tf.normalizers import TensorFlowReinhardNormalizer
return TensorFlowReinhardNormalizer()
return TensorFlowReinhardNormalizer(method=method)
else:
raise Exception(f'Unknown backend {backend}')
2 changes: 1 addition & 1 deletion torchstain/numpy/normalizers/macenko.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self):

def __convert_rgb2od(self, I, Io=240, beta=0.15):
# calculate optical density
OD = -np.log((I.astype(np.float)+1)/Io)
OD = -np.log((I.astype(float)+1)/Io)

# remove transparent pixels
ODhat = OD[~np.any(OD < beta, axis=1)]
Expand Down
26 changes: 22 additions & 4 deletions torchstain/numpy/normalizers/reinhard.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
https://github.com/Peter554/StainTools/blob/master/staintools/reinhard_color_normalizer.py
"""
class NumpyReinhardNormalizer(HENormalizer):
def __init__(self):
def __init__(self, method=None):
super().__init__()
self.method = method
self.target_mus = None
self.target_stds = None

Expand Down Expand Up @@ -41,9 +42,26 @@ def normalize(self, I):
mus = stack_[:, 0]
stds = stack_[:, 1]

# standardize intensities channel-wise and normalize using target mus and stds
result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \
in zip(labs, mus, stds, self.target_means, self.target_stds)]
# normalize
if self.method is None:
# standardize intensities channel-wise and normalize using target mus and stds
result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \
in zip(labs, mus, stds, self.target_means, self.target_stds)]

elif self.method == "modified":
# calculate q
q = (self.target_stds[0] - stds[0]) / self.target_stds[0]
q = 0.05 if q <= 0 else q

# normalize each channel independently
l_norm = mus[0] + (labs[0] - mus[0]) * (1 + q)
a_norm = self.target_means[1] + (labs[1] - mus[1])
b_norm = self.target_means[2] + (labs[2] - mus[2])

result = [l_norm, a_norm, b_norm]

else:
raise ValueError("Unsupported 'method' was chosen. Choose either {None, 'modified'}.")

# rebuild LAB
lab = lab_merge(*result)
Expand Down
2 changes: 1 addition & 1 deletion torchstain/tf/normalizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer
from torchstain.tf.normalizers.reinhard import TensorFlowReinhardNormalizer
from torchstain.tf.normalizers.reinhard import TensorFlowReinhardNormalizer
26 changes: 22 additions & 4 deletions torchstain/tf/normalizers/reinhard.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
https://github.com/Peter554/StainTools/blob/master/staintools/reinhard_color_normalizer.py
"""
class TensorFlowReinhardNormalizer(HENormalizer):
def __init__(self):
def __init__(self, method=None):
super().__init__()
self.method = method
self.target_mus = None
self.target_stds = None

Expand Down Expand Up @@ -41,9 +42,26 @@ def normalize(self, I):
mus = stack_[:, 0]
stds = stack_[:, 1]

# standardize intensities channel-wise and normalize using target mus and stds
result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \
in zip(labs, mus, stds, self.target_means, self.target_stds)]
# normalize
if self.method is None:
# standardize intensities channel-wise and normalize using target mus and stds
result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \
in zip(labs, mus, stds, self.target_means, self.target_stds)]

elif self.method == "modified":
# calculate q
q = (self.target_stds[0] - stds[0]) / self.target_stds[0]
q = 0.05 if q <= 0 else q

# normalize each channel independently
l_norm = mus[0] + (labs[0] - mus[0]) * (1 + q)
a_norm = self.target_means[1] + (labs[1] - mus[1])
b_norm = self.target_means[2] + (labs[2] - mus[2])

result = [l_norm, a_norm, b_norm]

else:
raise ValueError("Unsupported 'method' was chosen. Choose either {None, 'modified'}.")

# rebuild LAB
lab = lab_merge(*result)
Expand Down
1 change: 0 additions & 1 deletion torchstain/tf/utils/split.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import tensorflow as tf
from torchstain.tf.utils.rgb2lab import rgb2lab

def csplit(I):
return [I[..., i] for i in range(I.shape[-1])]
Expand Down
1 change: 1 addition & 0 deletions torchstain/torch/normalizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from torchstain.torch.normalizers.macenko import TorchMacenkoNormalizer
from torchstain.torch.normalizers.reinhard import TorchReinhardNormalizer
10 changes: 8 additions & 2 deletions torchstain/torch/normalizers/macenko.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def __init__(self):
[0.4062, 0.5581]])
self.maxCRef = torch.tensor([1.9705, 1.0308])

# Avoid using deprecated torch.lstsq (since 1.9.0)
self.updated_lstsq = hasattr(torch.linalg, 'lstsq')

def __convert_rgb2od(self, I, Io, beta):
I = I.permute(1, 2, 0)

Expand Down Expand Up @@ -49,13 +52,16 @@ def __find_concentration(self, OD, HE):
Y = OD.T

# determine concentrations of the individual stains
return torch.lstsq(Y, HE)[0][:2]
if not self.updated_lstsq:
return torch.lstsq(Y, HE)[0][:2]

return torch.linalg.lstsq(HE, Y)[0]

def __compute_matrices(self, I, Io, alpha, beta):
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

# compute eigenvectors
_, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)
_, eigvecs = torch.linalg.eigh(cov(ODhat.T))
eigvecs = eigvecs[:, [1, 2]]

HE = self.__find_HE(ODhat, eigvecs, alpha)
Expand Down
Loading