Skip to content

Commit

Permalink
Random Resized Crop (keras-team#499)
Browse files Browse the repository at this point in the history
* Created random resized crop files

* Added test from keras-team#457

* Used `tf.image.crop_and_resize` instead of `ImageProjectionTransformV3`

* Minor bug

* Formatted

* Reformatted

* Update keras_cv/layers/preprocessing/random_resized_crop.py

Co-authored-by: Sayak Paul <[email protected]>

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* Doc changes

* Doc changes

* Made requested changes

* Reformatted and tested

* Update keras_cv/layers/preprocessing/random_resized_crop.py

Co-authored-by: Sayak Paul <[email protected]>

* Made requested changes

* Created random resized crop files

* Added test from keras-team#457

* Used `tf.image.crop_and_resize` instead of `ImageProjectionTransformV3`

* Minor bug

* Formatted

* Reformatted

* Doc changes

* Update keras_cv/layers/preprocessing/random_resized_crop.py

Co-authored-by: Sayak Paul <[email protected]>

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* Doc changes

* Made requested changes

* Reformatted and tested

* Luke edits

* remove merge conflict

* Fix broken test cases

* Made changes to rrc

* Minor changes

* Added checks

* Added checks

* Added checks for inputs

* Docstring updates

* Made requested changes

* Minor changes

* Added demo

* Formatted

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Luke Wood <[email protected]>
  • Loading branch information
3 people committed Jun 29, 2022
1 parent d383952 commit 23c8686
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 0 deletions.
21 changes: 21 additions & 0 deletions examples/layers/preprocessing/demo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,24 @@ def visualize_dataset(ds):
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
plt.show()


def gallery_show(images):
images = images.astype(int)
for i in range(9):
image = images[i]
plt.subplot(3, 3, i + 1)
plt.imshow(image.astype("uint8"))
plt.axis("off")
plt.show()


def load_elephant_tensor(output_size=(300, 300)):
elephants = tf.keras.utils.get_file(
"african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png"
)
elephants = tf.keras.utils.load_img(elephants, target_size=output_size)
elephants = tf.keras.utils.img_to_array(elephants)

many_elephants = tf.repeat(tf.expand_dims(elephants, axis=0), 9, axis=0)
return many_elephants
37 changes: 37 additions & 0 deletions examples/layers/preprocessing/random_resized_crop_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""random_resized_crop_demo.py.py shows how to use the RandomResizedCrop
preprocessing layer. Operates on an image of elephant. In this script the image
is loaded, then are passed through the preprocessing layers.
Finally, they are shown using matplotlib.
"""

import demo_utils
from keras_cv.layers.preprocessing import RandomResizedCrop


def main():
many_elephants = demo_utils.load_elephant_tensor(output_size=(300, 300))
layer = RandomResizedCrop(
target_size=(224, 224),
crop_area_factor=(0.08, 1.0),
aspect_ratio_factor=(3.0 / 4.0, 4.0 / 3.0),
)
augmented = layer(many_elephants)
demo_utils.gallery_show(augmented.numpy())


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions keras_cv/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from keras_cv.layers.preprocessing.random_gaussian_blur import RandomGaussianBlur
from keras_cv.layers.preprocessing.random_hue import RandomHue
from keras_cv.layers.preprocessing.random_jpeg_quality import RandomJpegQuality
from keras_cv.layers.preprocessing.random_resized_crop import RandomResizedCrop
from keras_cv.layers.preprocessing.random_saturation import RandomSaturation
from keras_cv.layers.preprocessing.random_sharpness import RandomSharpness
from keras_cv.layers.preprocessing.random_shear import RandomShear
Expand Down
1 change: 1 addition & 0 deletions keras_cv/layers/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from keras_cv.layers.preprocessing.random_gaussian_blur import RandomGaussianBlur
from keras_cv.layers.preprocessing.random_hue import RandomHue
from keras_cv.layers.preprocessing.random_jpeg_quality import RandomJpegQuality
from keras_cv.layers.preprocessing.random_resized_crop import RandomResizedCrop
from keras_cv.layers.preprocessing.random_saturation import RandomSaturation
from keras_cv.layers.preprocessing.random_sharpness import RandomSharpness
from keras_cv.layers.preprocessing.random_shear import RandomShear
Expand Down
182 changes: 182 additions & 0 deletions keras_cv/layers/preprocessing/random_resized_crop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class RandomResizedCrop(BaseImageAugmentationLayer):
"""Randomly crops a part of an image and resizes it to provided size.
This implementation takes an intuitive approach, where we crop the images to a
random height and width, and then resize them. To do this, we first sample a
random value for area using `crop_area_factor` and a value for aspect ratio using
`aspect_ratio_factor`. Further we get the new height and width by
dividing and multiplying the old height and width by the random area
respectively. We then sample offsets for height and width and clip them such
that the cropped area does not exceed image boundaries. Finally we do the
actual cropping operation and resize the image to `target_size`.
Args:
target_size: A tuple of two integers used as the target size to ultimately crop
images to.
crop_area_factor: A tuple of two floats, ConstantFactorSampler or
UniformFactorSampler. The ratio of area of the cropped part to
that of original image is sampled using this factor. Represents the
lower and upper bounds for the area relative to the original image
of the cropped image before resizing it to `target_size`. For
self-supervised pretraining a common value for this parameter is
`(0.08, 1.0)`. For fine tuning and classification a common value for this
is `0.8, 1.0`.
aspect_ratio_factor: A tuple of two floats, ConstantFactorSampler or
UniformFactorSampler. Aspect ratio means the ratio of width to
height of the cropped image. In the context of this layer, the aspect ratio
sampled represents a value to distort the aspect ratio by.
Represents the lower and upper bound for the aspect ratio of the
cropped image before resizing it to `target_size`. For most tasks, this
should be `(3/4, 4/3)`. To perform a no-op provide the value `(1.0, 1.0)`.
interpolation: (Optional) A string specifying the sampling method for
resizing. Defaults to "bilinear".
seed: (Optional) Used to create a random seed. Defaults to None.
"""

def __init__(
self,
target_size,
crop_area_factor,
aspect_ratio_factor,
interpolation="bilinear",
seed=None,
**kwargs,
):
super().__init__(seed=seed, **kwargs)

self._check_class_arguments(target_size, crop_area_factor, aspect_ratio_factor)

self.target_size = target_size
self.aspect_ratio_factor = preprocessing.parse_factor(
aspect_ratio_factor,
min_value=0.0,
max_value=None,
param_name="aspect_ratio_factor",
seed=seed,
)
self.crop_area_factor = preprocessing.parse_factor(
crop_area_factor,
max_value=1.0,
param_name="crop_area_factor",
seed=seed,
)

self.interpolation = interpolation
self.seed = seed

def get_random_transformation(
self, image=None, label=None, bounding_box=None, **kwargs
):
crop_area_factor = self.crop_area_factor()
aspect_ratio = self.aspect_ratio_factor()

new_height = tf.clip_by_value(
tf.sqrt(crop_area_factor / aspect_ratio), 0.0, 1.0
) # to avoid unwanted/unintuitive effects
new_width = tf.clip_by_value(tf.sqrt(crop_area_factor * aspect_ratio), 0.0, 1.0)

height_offset = self._random_generator.random_uniform(
(),
minval=tf.minimum(0.0, 1.0 - new_height),
maxval=tf.maximum(0.0, 1.0 - new_height),
dtype=tf.float32,
)

width_offset = self._random_generator.random_uniform(
(),
minval=tf.minimum(0.0, 1.0 - new_width),
maxval=tf.maximum(0.0, 1.0 - new_width),
dtype=tf.float32,
)

y1 = height_offset
y2 = height_offset + new_height
x1 = width_offset
x2 = width_offset + new_width

return [[y1, x1, y2, x2]]

def call(self, inputs, training=True):

if training:
return super().call(inputs, training)
else:
inputs = self._ensure_inputs_are_compute_dtype(inputs)
inputs, is_dict, use_targets = self._format_inputs(inputs)
output = inputs
# self._resize() returns valid results for both batched and
# unbatched
output["images"] = self._resize(inputs["images"])
return self._format_output(output, is_dict, use_targets)

def augment_image(self, image, transformation, **kwargs):
image = tf.expand_dims(image, axis=0)
boxes = transformation

# See bit.ly/tf_crop_resize for more details
augmented_image = tf.image.crop_and_resize(
image, # image shape: [B, H, W, C]
boxes, # boxes: (1, 4) in this case; represents area
# to be cropped from the original image
[0], # box_indices: maps boxes to images along batch axis
# [0] since there is only one image
self.target_size, # output size
)

return tf.squeeze(augmented_image, axis=0)

def _resize(self, image):
outputs = tf.keras.preprocessing.image.smart_resize(image, self.target_size)
# smart_resize will always output float32, so we need to re-cast.
return tf.cast(outputs, self.compute_dtype)

def _check_class_arguments(
self, target_size, crop_area_factor, aspect_ratio_factor
):
if (
not isinstance(target_size, (tuple, list))
or len(target_size) != 2
or not isinstance(target_size[0], int)
or not isinstance(target_size[1], int)
or isinstance(target_size, int)
):
raise ValueError(
"`target_size` must be tuple of two integers."
f"Received target_size={target_size}"
)

def get_config(self):
config = super().get_config()
config.update(
{
"target_size": self.target_size,
"crop_area_factor": self.crop_area_factor,
"aspect_ratio_factor": self.aspect_ratio_factor,
"interpolation": self.interpolation,
"seed": self.seed,
}
)
return config
72 changes: 72 additions & 0 deletions keras_cv/layers/preprocessing/random_resized_crop_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf

from keras_cv.layers import preprocessing


class RandomResizedCropTest(tf.test.TestCase):
height, width = 300, 300
batch_size = 4
target_size = (224, 224)
seed = 42

def test_train_augments_image(self):
# Checks if original and augmented images are different

input_image_shape = (self.batch_size, self.height, self.width, 3)
image = tf.random.uniform(shape=input_image_shape, seed=self.seed)

layer = preprocessing.RandomResizedCrop(
target_size=self.target_size,
aspect_ratio_factor=(3 / 4, 4 / 3),
crop_area_factor=(0.8, 1.0),
seed=self.seed,
)
output = layer(image, training=True)

input_image_resized = tf.image.resize(image, self.target_size)

self.assertNotAllClose(output, input_image_resized)

def test_grayscale(self):
input_image_shape = (self.batch_size, self.height, self.width, 1)
image = tf.random.uniform(shape=input_image_shape)

layer = preprocessing.RandomResizedCrop(
target_size=self.target_size,
aspect_ratio_factor=(3 / 4, 4 / 3),
crop_area_factor=(0.8, 1.0),
)
output = layer(image, training=True)

input_image_resized = tf.image.resize(image, self.target_size)

self.assertAllEqual(output.shape, (4, 224, 224, 1))
self.assertNotAllClose(output, input_image_resized)

def test_preserves_image(self):
image_shape = (self.batch_size, self.height, self.width, 3)
image = tf.random.uniform(shape=image_shape)

layer = preprocessing.RandomResizedCrop(
target_size=self.target_size,
aspect_ratio_factor=(3 / 4, 4 / 3),
crop_area_factor=(0.8, 1.0),
)

input_resized = tf.image.resize(image, self.target_size)
output = layer(image, training=False)

self.assertAllClose(output, input_resized)
11 changes: 11 additions & 0 deletions keras_cv/layers/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase):
"seed": 1,
},
),
(
"RandomResizedCrop",
preprocessing.RandomResizedCrop,
{
"target_size": (224, 224),
"crop_area_factor": (0.08, 1.0),
"aspect_ratio_factor": (3.0 / 4.0, 4.0 / 3.0),
"interpolation": "bilinear",
"seed": 1,
},
),
(
"DropBlock2D",
regularization.DropBlock2D,
Expand Down

0 comments on commit 23c8686

Please sign in to comment.