Add multi-label classification dataset and metric #1572

Merged
11 commits merged on Apr 5, 2024
crates/burn-dataset/src/vision/image_folder.rs (247 changes: 217 additions & 30 deletions)
@@ -57,11 +57,13 @@ impl TryFrom<PixelDepth> for f32 {
}
}

/// Image target for different tasks.
/// Annotation type for different tasks.
#[derive(Debug, Clone, PartialEq)]
pub enum Annotation {
/// Image-level label.
Label(usize),
/// Multiple image-level labels.
MultiLabel(Vec<usize>),
/// Object bounding boxes.
BoundingBoxes(Vec<BoundingBox>),
/// Segmentation mask.
@@ -97,30 +99,56 @@ pub struct ImageDatasetItem {
pub annotation: Annotation,
}

/// Raw annotation types.
#[derive(Deserialize, Serialize, Debug, Clone)]
enum AnnotationRaw {
Label(String),
MultiLabel(Vec<String>),
// TODO: bounding boxes and segmentation mask
}

#[derive(Deserialize, Serialize, Debug, Clone)]
struct ImageDatasetItemRaw {
/// Image path.
pub image_path: PathBuf,
image_path: PathBuf,

/// Image annotation.
/// The annotation bytes can represent a string (category name) or path to annotation file.
pub annotation: Vec<u8>,
annotation: AnnotationRaw,
}

impl ImageDatasetItemRaw {
fn new<P: AsRef<Path>>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw {
ImageDatasetItemRaw {
image_path: image_path.as_ref().to_path_buf(),
annotation,
}
}
}

struct PathToImageDatasetItem {
classes: HashMap<String, usize>,
}

/// Parse the image annotation to the corresponding type.
fn parse_image_annotation(annotation: &[u8], classes: &HashMap<String, usize>) -> Annotation {
fn parse_image_annotation(
annotation: &AnnotationRaw,
classes: &HashMap<String, usize>,
) -> Annotation {
// TODO: add support for other annotations
// - [ ] Object bounding boxes
// - [ ] Segmentation mask
// For now, only image classification labels are supported.

// Map class string to label id
let name = std::str::from_utf8(annotation).unwrap();
Annotation::Label(*classes.get(name).unwrap())
match annotation {
AnnotationRaw::Label(name) => Annotation::Label(*classes.get(name).unwrap()),
AnnotationRaw::MultiLabel(names) => Annotation::MultiLabel(
names
.iter()
.map(|name| *classes.get(name).unwrap())
.collect(),
),
}
}
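
For illustration only (not part of this diff): a raw multi-label annotation is resolved against the class map as follows. The `colors` map and its values are hypothetical, and the snippet assumes the imports already present in this file.

let colors = HashMap::from([("red".to_string(), 0_usize), ("orange".to_string(), 1_usize)]);
let raw = AnnotationRaw::MultiLabel(vec!["red".to_string(), "orange".to_string()]);
assert_eq!(parse_image_annotation(&raw, &colors), Annotation::MultiLabel(vec![0, 1]));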

impl Mapper<ImageDatasetItemRaw, ImageDatasetItem> for PathToImageDatasetItem {
@@ -212,7 +240,7 @@ pub enum ImageLoaderError {
type ImageDatasetMapper =
MapperDataset<InMemDataset<ImageDatasetItemRaw>, PathToImageDatasetItem, ImageDatasetItemRaw>;

/// A generic dataset to load classification images from disk.
/// A generic dataset to load images from disk.
pub struct ImageFolderDataset {
dataset: ImageDatasetMapper,
}
@@ -259,26 +287,14 @@ impl ImageFolderDataset {
P: AsRef<Path>,
S: AsRef<str>,
{
/// Check if extension is supported.
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
let extension = extension.as_ref();
if !SUPPORTED_FILES.contains(&extension) {
Err(ImageLoaderError::InvalidFileExtensionError(
extension.to_string(),
))
} else {
Ok(extension.to_string())
}
}

// Glob all images with extensions
let walker = globwalk::GlobWalkerBuilder::from_patterns(
root.as_ref(),
&[format!(
"*.{{{}}}", // "*.{ext1,ext2,ext3}
extensions
.iter()
.map(check_extension)
.map(Self::check_extension)
.collect::<Result<Vec<_>, _>>()?
.join(",")
)],
@@ -312,21 +328,102 @@ impl ImageFolderDataset {

classes.insert(label.clone());

items.push(ImageDatasetItemRaw {
image_path: image_path.to_path_buf(),
annotation: label.into_bytes(),
})
items.push(ImageDatasetItemRaw::new(
image_path,
AnnotationRaw::Label(label),
))
}

// Sort class names
let mut classes = classes.into_iter().collect::<Vec<_>>();
classes.sort();

Self::with_items(items, &classes)
}

/// Create an image classification dataset with the specified items.
///
/// # Arguments
///
/// * `items` - List of dataset items, each item represented by a tuple `(image path, label)`.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
items: Vec<(P, String)>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// Parse items and check valid image extension types
let items = items
.into_iter()
.map(|(path, label)| {
// Map image path and label
let path = path.as_ref();
let label = AnnotationRaw::Label(label);

Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;

Ok(ImageDatasetItemRaw::new(path, label))
})
.collect::<Result<Vec<_>, _>>()?;

Self::with_items(items, classes)
}

/// Create a multi-label image classification dataset with the specified items.
///
/// # Arguments
///
/// * `items` - List of dataset items, each item represented by a tuple `(image path, labels)`.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
pub fn new_multilabel_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
items: Vec<(P, Vec<String>)>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// Parse items and check valid image extension types
let items = items
.into_iter()
.map(|(path, labels)| {
// Map image path and multi-label
let path = path.as_ref();
let labels = AnnotationRaw::MultiLabel(labels);

Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;

Ok(ImageDatasetItemRaw::new(path, labels))
})
.collect::<Result<Vec<_>, _>>()?;

Self::with_items(items, classes)
}
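
A minimal usage sketch of this constructor (illustrative, not part of the diff); the image paths and class names below are hypothetical, and only the file extensions are validated at construction time.

let items = vec![
    ("images/cat_park.jpg", vec!["cat".to_string(), "outdoor".to_string()]),
    ("images/dog_home.png", vec!["dog".to_string()]),
];
let dataset = ImageFolderDataset::new_multilabel_classification_with_items(
    items,
    &["cat", "dog", "outdoor"],
)
.unwrap();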

/// Create an image dataset with the specified items.
///
/// # Arguments
///
/// * `items` - Raw dataset items.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
fn with_items<S: AsRef<str>>(
items: Vec<ImageDatasetItemRaw>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// NOTE: right now we don't need to validate the supported image files since
// the method is private. We assume it's already validated.
let dataset = InMemDataset::new(items);

// Class names to index map
let mut classes = classes.into_iter().collect::<Vec<_>>();
classes.sort();
let classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();
let classes_map: HashMap<_, _> = classes
.into_iter()
.enumerate()
.map(|(idx, cls)| (cls, idx))
.map(|(idx, cls)| (cls.to_string(), idx))
.collect();

let mapper = PathToImageDatasetItem {
@@ -336,6 +433,18 @@ impl ImageFolderDataset {

Ok(Self { dataset })
}

/// Check if extension is supported.
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
let extension = extension.as_ref();
if !SUPPORTED_FILES.contains(&extension) {
Err(ImageLoaderError::InvalidFileExtensionError(
extension.to_string(),
))
} else {
Ok(extension.to_string())
}
}
}

#[cfg(test)]
@@ -370,6 +479,69 @@ mod tests {
assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
}

#[test]
pub fn image_folder_dataset_with_items() {
let root = Path::new(DATASET_ROOT);
let items = vec![
(root.join("orange").join("dot.jpg"), "orange".to_string()),
(root.join("red").join("dot.jpg"), "red".to_string()),
(root.join("red").join("dot.png"), "red".to_string()),
];
let dataset =
ImageFolderDataset::new_classification_with_items(items, &["orange", "red"]).unwrap();

// Dataset has 3 elements
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.get(3), None);

// Dataset elements should be: orange (0), red (1), red (1)
assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0));
assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
assert_eq!(dataset.get(2).unwrap().annotation, Annotation::Label(1));
}

#[test]
pub fn image_folder_dataset_multilabel() {
let root = Path::new(DATASET_ROOT);
let items = vec![
(
root.join("orange").join("dot.jpg"),
vec!["dot".to_string(), "orange".to_string()],
),
(
root.join("red").join("dot.jpg"),
vec!["dot".to_string(), "red".to_string()],
),
(
root.join("red").join("dot.png"),
vec!["dot".to_string(), "red".to_string()],
),
];
let dataset = ImageFolderDataset::new_multilabel_classification_with_items(
items,
&["dot", "orange", "red"],
)
.unwrap();

// Dataset has 3 elements
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.get(3), None);

// Dataset elements should be: [dot, orange] (0, 1), [dot, red] (0, 2), [dot, red] (0, 2)
assert_eq!(
dataset.get(0).unwrap().annotation,
Annotation::MultiLabel(vec![0, 1])
);
assert_eq!(
dataset.get(1).unwrap().annotation,
Annotation::MultiLabel(vec![0, 2])
);
assert_eq!(
dataset.get(2).unwrap().annotation,
Annotation::MultiLabel(vec![0, 2])
);
}

#[test]
#[should_panic]
pub fn image_folder_dataset_invalid_extension() {
@@ -417,11 +589,26 @@ mod tests {
}

#[test]
pub fn parse_image_annotation_string() {
pub fn parse_image_annotation_label_string() {
let classes = HashMap::from([("0".to_string(), 0_usize), ("1".to_string(), 1_usize)]);
let anno = AnnotationRaw::Label("0".to_string());
assert_eq!(
parse_image_annotation(&"0".to_string().into_bytes(), &classes),
parse_image_annotation(&anno, &classes),
Annotation::Label(0)
);
}

#[test]
pub fn parse_image_annotation_multilabel_string() {
let classes = HashMap::from([
("0".to_string(), 0_usize),
("1".to_string(), 1_usize),
("2".to_string(), 2_usize),
]);
let anno = AnnotationRaw::MultiLabel(vec!["0".to_string(), "2".to_string()]);
assert_eq!(
parse_image_annotation(&anno, &classes),
Annotation::MultiLabel(vec![0, 2])
);
}
}
crates/burn-train/src/learner/classification.rs (27 changes: 26 additions & 1 deletion)
@@ -1,4 +1,4 @@
use crate::metric::{AccuracyInput, Adaptor, LossInput};
use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor};

@@ -26,3 +26,28 @@ impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
LossInput::new(self.loss.clone())
}
}

/// Multi-label classification output adapted for multiple metrics.
#[derive(new)]
pub struct MultiLabelClassificationOutput<B: Backend> {
/// The loss.
pub loss: Tensor<B, 1>,

/// The output.
pub output: Tensor<B, 2>,

/// The targets.
pub targets: Tensor<B, 2, Int>,
}

impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> HammingScoreInput<B> {
HammingScoreInput::new(self.output.clone(), self.targets.clone())
}
}

impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}
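
For illustration (not part of this diff): a multi-label train or validation step can return `MultiLabelClassificationOutput`, and the two adaptors above let the learner feed both the Hamming score metric and the loss metric from the same struct. The function name and tensor shapes below are assumptions.

fn to_multilabel_output<B: Backend>(
    loss: Tensor<B, 1>,         // batch loss
    output: Tensor<B, 2>,       // [batch_size, num_classes] predictions (e.g. sigmoid probabilities)
    targets: Tensor<B, 2, Int>, // [batch_size, num_classes] 0/1 target matrix
) -> MultiLabelClassificationOutput<B> {
    // `#[derive(new)]` generates this constructor; the metrics consume the result
    // through the `Adaptor<HammingScoreInput<B>>` and `Adaptor<LossInput<B>>` impls above.
    MultiLabelClassificationOutput::new(loss, output, targets)
}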