Add pixel-wise weights, add entropy-based sampling
edornd committed Dec 21, 2021
1 parent 43793e6 commit b75acc1
Showing 9 changed files with 147 additions and 10 deletions.
2 changes: 2 additions & 0 deletions floods/config/preproc.py
@@ -34,6 +34,8 @@ class PreparationConfig(EnvConfig):
morphology: bool = Field(True, description="whether to use morphological operators or not")
morph_kernel: int = Field(5, description="Kernel size for mask preprocessing using opening/closing")
nan_threshold: float = Field(0.75, description="Percentage of invalid pixels before discarding the tile")
vv_multiplier: float = Field(5.0, description="Fixed multiplier for threshold-based pseudolabeling (1st channel)")
vh_multiplier: float = Field(10.0, description="Fixed multiplier for threshold-based pseudolabeling (2nd channel)")

def subset_exists(cls, v):
allowed = {"train", "test", "val"}
2 changes: 1 addition & 1 deletion floods/config/training.py
@@ -113,7 +113,7 @@ class DatasetConfig(EnvConfig):
class_weights: str = Field(None, description="Optional path to a class weight array (npy format)")
mask_body_ratio: float = Field(None, description="Percentage of ones in the mask before discarding the tile")
weighted_sampling: bool = Field(False, description="Whether to sample images based on flooded ratio")
sample_smoothing: float = Field(0.8, description="Value between 0 and 1 to smooth out the weights (1 = maximum)")
sample_smoothing: float = Field(0.8, description="Value between 0 and 1 to smooth out the weights (1 = no smoothing)")


class ModelConfig(EnvConfig):
36 changes: 36 additions & 0 deletions floods/datasets/flood.py
@@ -129,3 +129,39 @@ def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:

def __len__(self) -> int:
return len(self.image_files)


class WeightedFloodDataset(FloodDataset):
def __init__(self,
path: Path,
subset: str = "train",
include_dem: bool = False,
transform_base: Callable = None,
transform_sar: Callable = None,
transform_dem: Callable = None,
normalization: Callable = None,
class_weights: Tuple[int, int, int] = (1.0, 0.5, 5.0)) -> None:
super().__init__(path,
subset=subset,
include_dem=include_dem,
transform_base=transform_base,
transform_sar=transform_sar,
transform_dem=transform_dem,
normalization=normalization)
# we need 256 positions so that index 255 (the ignore index) maps to a zero weight
weights_array = np.zeros(256, dtype=np.float32)
weights_array[:len(class_weights)] = np.array(class_weights)
self.class_weights = weights_array
self.weight_files = sorted(glob(str(path / subset / "weight" / "*.tif")))
assert len(self.image_files) == len(self.weight_files), \
f"Length mismatch between tiles and weights: {len(self.image_files)} != {len(self.weight_files)}"

def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, Tensor]:
image, label = super().__getitem__(index)
# read the weight map from file
# 0 = background, 1 = either threshold or ground truth (but not both), 2 = threshold ∩ ground truth
# based on this, we produce a pixel-wise weight map, giving more weight to areas
# where the flood label and the threshold agree, and less where they disagree
weight_indices = imread(self.weight_files[index]).squeeze(0).astype(np.uint8)
weight = self.class_weights[weight_indices]
return image, label, weight
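
(Note: a minimal sketch of how the returned per-pixel weight map could be consumed by a loss. The function below is hypothetical and not part of this commit; it assumes logits of shape (B, C, H, W), integer labels with 255 as ignore index, and the weight map returned above.)

import torch
import torch.nn.functional as F

def pixel_weighted_ce(logits: torch.Tensor, label: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # per-pixel cross-entropy, zeroed at ignored positions by ignore_index
    loss = F.cross_entropy(logits, label.long(), ignore_index=255, reduction="none")
    valid = (label != 255).float()
    # scale each pixel by its weight (e.g. 0.5 where threshold and GT disagree, 5.0 where they agree)
    return (loss * weight * valid).sum() / valid.sum().clamp(min=1.0)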
10 changes: 6 additions & 4 deletions floods/prepare.py
@@ -17,7 +17,7 @@
from floods.models.modules import SegmentationHead
from floods.transforms import ClipNormalize, Denormalize
from floods.utils.common import get_logger
from floods.utils.tiling.functional import mask_body_ratio_from_threshold, weights_from_body_ratio
from floods.utils.tiling.functional import entropy_weights, mask_body_ratio_from_threshold

LOG = get_logger(__name__)

@@ -127,15 +127,15 @@ def prepare_datasets(config: TrainConfig) -> Tuple[DatasetBase, DatasetBase]:
return train_dataset, valid_dataset


def prepare_sampler(data_root: str, dataset: FloodDataset, smoothing: float = 0.9) -> WeightedRandomSampler:
def prepare_sampler(data_root: str, dataset: FloodDataset, smoothing: float = 0.8) -> WeightedRandomSampler:
data_name = Path(data_root).stem
target_file = Path("data") / f"{data_name}_sample-weights_smooth-{smoothing:.2f}.npy"
if target_file.exists() and target_file.is_file():
LOG.info("Found an existing array of sample weights")
weights = np.load(str(target_file))
else:
LOG.info("Computing weights for weighted random sampling")
weights = weights_from_body_ratio(dataset.label_files, smoothing=smoothing)
weights = entropy_weights(dataset.label_files, smoothing=smoothing)
np.save(str(target_file), weights)
# completely arbitrary, this is just here to maximize the number of images we look at
num_samples = len(dataset) * 2
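
(Note: the hunk is truncated before the sampler is actually built. A plausible sketch of how the per-tile weights would feed torch's WeightedRandomSampler; the exact call is an assumption, not shown in this diff, and names such as train_dataset and the batch size are placeholders.)

from torch.utils.data import DataLoader, WeightedRandomSampler

# one weight per training tile: higher-entropy tiles are drawn more often
sampler = WeightedRandomSampler(weights=weights.tolist(), num_samples=num_samples, replacement=True)
loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)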
@@ -185,7 +185,9 @@ def prepare_model(config: TrainConfig, num_classes: int) -> nn.Module:
extract_features = False
LOG.info("Returning intermediate features: %s", str(extract_features))
# create final segmentation head and build model
head = SegmentationHead(in_channels=decoder.out_channels(), num_classes=num_classes, upscale=decoder.out_reduction())
head = SegmentationHead(in_channels=decoder.out_channels(),
num_classes=num_classes,
upscale=decoder.out_reduction())
model = Segmenter(encoder, decoder, head, return_features=extract_features)
return model

55 changes: 52 additions & 3 deletions floods/preproc.py
@@ -8,12 +8,15 @@
import cv2
import numpy as np
import rasterio
from joblib import Parallel, delayed
from rasterio.enums import Resampling
from rasterio.io import MemoryFile
from rasterio.windows import Window
from skimage.restoration import denoise_nl_means
from tqdm import tqdm

from floods.config.preproc import ImageType, PreparationConfig, StatsConfig
from floods.datasets.flood import FloodDataset
from floods.utils.common import check_or_make_dir, print_config
from floods.utils.gis import imread, mask_raster, write_window
from floods.utils.ml import F16_EPS, identity
@@ -26,11 +29,11 @@ class MorphologyTransform:
"""Callable operator that applies morphological transforms to the masks.
"""
def __init__(self, kernel_size: int = 5, channels_first: bool = True) -> None:
self.kernel = self._create_round_kernel(kernel_size=kernel_size)
self.kernel = self.create_round_kernel(kernel_size=kernel_size)
self.channels_first = channels_first

def _create_round_kernel(self, kernel_size: int):
# ocmpute center and radius, suppose symmetrical and centered
def create_round_kernel(self, kernel_size: int):
# compute center and radius, suppose symmetrical and centered
center = kernel_size // 2
radius = min(center, kernel_size - center)
# compute a distance grid from the given center
@@ -445,3 +448,49 @@ def compute_statistics(config: StatsConfig):
print("channel-wise std: ", ch_std)
print("normalized avg: ", (ch_avg - ch_min) / (ch_max - ch_min))
print("normalized std: ", ch_std / (ch_max - ch_min))


def generate_pseudolabels(config: PreparationConfig):
LOG.info("Generating weight pseudolabels...")
data_path = Path(config.data_processed)
assert data_path.exists() and data_path.is_dir(), "The given path is not a valid directory"

# this is just needed for the training set
dataset = FloodDataset(path=data_path,
subset="train",
include_dem=False,
transform_base=None)
# prepare directory to store resulting images
result_path = data_path / "train" / "weight"
check_or_make_dir(result_path)
morph_kernel = MorphologyTransform().create_round_kernel(kernel_size=config.morph_kernel)

# we could iterate the dataset normally, but we also need the filename to store results
# iterating serially would take forever with NL-means denoising and morphology, so we parallelize
def process_image(index: int):
image = imread(dataset.image_files[index], channels_first=False)
label, profile = imread(dataset.label_files[index], return_metadata=True)
label = label.squeeze(0).astype(np.uint8)
# multiply VV and VH by fixed constants, more practical for thresholding
image[:, :, 0] *= config.vv_multiplier
image[:, :, 1] *= config.vh_multiplier
# produce a smoother SAR image for a less noisy threshold
# then further clean it up using morphological opening
denoised = denoise_nl_means(image, h=0.1, multichannel=True)
flooded = ((denoised[:, :, 0] <= 0.1) * (denoised[:, :, 1] <= 0.1)).astype(np.uint8)
flooded = cv2.morphologyEx(flooded, cv2.MORPH_OPEN, morph_kernel)
# combine them so that background gets index 0, agreement from only one source gets 1, the intersection gets 2
result = flooded + label
# store results to file
image_name = Path(dataset.image_files[index]).name
with rasterio.open(str(result_path / image_name), "w", **profile) as dst:
dst.write(result[np.newaxis, ...])

# Run a bunch of parallel jobs, using the same function
Parallel(n_jobs=12)(delayed(process_image)(i) for i in tqdm(range(len(dataset))))
# final sanity checks on the generated weight maps
LOG.info("Validating results...")
result_images = glob(str(result_path / "*.tif"))
assert len(result_images) == len(dataset), \
f"Length mismatch between dataset ({len(dataset)}) and result ({len(result_images)})"
LOG.info("Done!")
17 changes: 15 additions & 2 deletions floods/utils/gis.py
@@ -7,7 +7,7 @@
from rasterio.windows import Window


def imread(path: Path, channels_first: bool = True) -> np.ndarray:
def imread(path: Path, channels_first: bool = True, return_metadata: bool = False) -> np.ndarray:
"""Wraps rasterio open functionality to read the numpy array and exit the context.
Args:
@@ -19,7 +19,11 @@ def imread(path: Path, channels_first: bool = True) -> np.ndarray:
"""
with rasterio.open(str(path), mode="r", driver="GTiff") as src:
image = src.read()
return image if channels_first else image.transpose(1, 2, 0)
metadata = src.profile.copy()
image = image if channels_first else image.transpose(1, 2, 0)
if return_metadata:
return image, metadata
return image


def mask_raster(path: Path, mask: np.ndarray, mask_value: int = 0) -> None:
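
(Note: a short usage sketch of the new return_metadata flag, mirroring what generate_pseudolabels does above; the file paths are hypothetical.)

from pathlib import Path
import rasterio
from floods.utils.gis import imread

# read a tile together with its GeoTIFF profile, then write a derived raster with the same georeferencing
label, profile = imread(Path("data/train/mask/tile_0001.tif"), return_metadata=True)
with rasterio.open("data/train/weight/tile_0001.tif", "w", **profile) as dst:
    dst.write(label)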
@@ -41,6 +45,15 @@ def mask_raster(path: Path, mask: np.ndarray, mask_value: int = 0) -> None:


def write_window(window: Window, source: DatasetReader, path: Path, transform: Affine = None) -> None:
"""Stores the data inside the given window, cutting it from the source dataset.
The image is stored in the given path. When the transform is None, the source transform is used instead.
Args:
window (Window): rasterio Window to delimit the target image
source (DatasetReader): source TIFF to be cut
path (Path): path to the target file to be created
transform (Affine, optional): Optional alternative transform. Defaults to None.
"""
kwargs = source.meta.copy()
transform = transform or source.transform
transform = rasterio.windows.transform(window, transform)
8 changes: 8 additions & 0 deletions floods/utils/ml.py
@@ -145,3 +145,11 @@ def apply_leaf(module, f):

def set_trainable(layers: list, value: bool) -> None:
apply_leaf(layers, lambda m: set_trainable_attr(m, value))


def entropy(label: np.ndarray, ignore: int = 255) -> float:
# binary Shannon entropy (in bits) of the label mask; ignored pixels count as background
valid = label.copy()
valid[valid == ignore] = 0
# marginal distribution over the two classes, dropping empty bins before the log
marg = np.histogramdd(valid.ravel(), bins=2)[0] / label.size
marg = list(filter(lambda p: p > 0, np.ravel(marg)))
return -np.sum(np.multiply(marg, np.log2(marg)))
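
(Note: a quick sanity check of the expected values, assuming the function above; for binary labels this is the standard Shannon entropy of the flooded/background split, in bits.)

import numpy as np
from floods.utils.ml import entropy

entropy(np.zeros((4, 4), dtype=np.uint8))            # 0.0, a single-class tile carries no information
entropy(np.array([[0, 1], [1, 0]], dtype=np.uint8))  # 1.0, a 50/50 tile has maximal binary entropy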
22 changes: 22 additions & 0 deletions floods/utils/tiling/functional.py
@@ -5,6 +5,7 @@
from tqdm import tqdm

from floods.utils.gis import imread
from floods.utils.ml import entropy


def tile_overlapped(image: np.ndarray,
@@ -169,3 +170,24 @@ def weights_from_body_ratio(labels: List[Path], normalize: bool = True, smoothin
if normalize:
weights /= weights.max()
return weights


def entropy_weights(labels: List[Path], smoothing: float = 0.8) -> np.ndarray:
"""Computes the entropy from the given list of labels (binary labels).
Args:
labels (List[Path]): list of filenames to be read
smoothing (float, optional): Value to smooth out the final array. Defaults to 0.8.
Returns:
np.ndarray: array of smoothed entropy values (max = 1.0, min = 0.0)
"""
assert 0.0 <= smoothing <= 1.0, "Smoothing factor must be between 0 and 1"
minval = 1.0 - smoothing
entropies = list()
for label_path in tqdm(labels):
label = imread(label_path)
entropies.append(entropy(label))

entropies = np.array(entropies)
return np.clip(entropies * smoothing + minval, 0, 1)
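
(Note: the smoothing term keeps low-entropy tiles from being starved by the sampler; the numbers below follow directly from the clip expression above with the default smoothing of 0.8.)

import numpy as np

smoothing = 0.8
entropies = np.array([0.0, 0.5, 1.0])
weights = np.clip(entropies * smoothing + (1.0 - smoothing), 0, 1)
# weights == [0.2, 0.6, 1.0]: single-class tiles keep a small sampling probability,
# balanced tiles keep the maximum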
5 changes: 5 additions & 0 deletions run.py
@@ -20,6 +20,11 @@ def stats(config: StatsConfig):
return preproc.compute_statistics(config=config)


@cli.command()
def pseudolabel(config: PreparationConfig):
preproc.generate_pseudolabels(config=config)


@cli.command()
def train(config: TrainConfig):
training.train(config=config)
