Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-label classification dataset and metric #1572

Merged
merged 11 commits into from
Apr 5, 2024
3 changes: 2 additions & 1 deletion crates/burn-dataset/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fake = ["dep:fake"]
sqlite = ["__sqlite-shared", "dep:rusqlite"]
sqlite-bundled = ["__sqlite-shared", "rusqlite/bundled"]

vision = ["dep:flate2", "dep:globwalk", "dep:burn-common"]
vision = ["dep:bincode", "dep:flate2", "dep:globwalk", "dep:burn-common"]

# internal
__sqlite-shared = [
Expand All @@ -33,6 +33,7 @@ __sqlite-shared = [
]

[dependencies]
bincode = { workspace = true, optional = true }
burn-common = { path = "../burn-common", version = "0.13.0", optional = true, features = [
"network",
] }
Expand Down
259 changes: 231 additions & 28 deletions crates/burn-dataset/src/vision/image_folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ impl TryFrom<PixelDepth> for f32 {
}
}

/// Image target for different tasks.
/// Annotation type for different tasks.
#[derive(Debug, Clone, PartialEq)]
pub enum Annotation {
/// Image-level label.
Label(usize),
/// Multiple image-level labels.
MultiLabel(Vec<usize>),
/// Object bounding boxes.
BoundingBoxes(Vec<BoundingBox>),
/// Segmentation mask.
Expand Down Expand Up @@ -97,14 +99,47 @@ pub struct ImageDatasetItem {
pub annotation: Annotation,
}

/// Raw annotation types.
///
/// Class names are kept as strings here; `parse_image_annotation` later maps
/// them to label indices through the dataset's class map.
#[derive(Deserialize, Serialize, Debug, Clone)]
enum AnnotationRaw {
/// Class name of a single image-level label.
Label(String),
/// Class names of multiple image-level labels.
MultiLabel(Vec<String>),
// TODO: bounding boxes and segmentation mask
}

impl AnnotationRaw {
/// Returns the bincode configuration used by both `encode` and `decode`,
/// so the two always agree on the byte format.
fn bin_config() -> bincode::config::Configuration {
bincode::config::standard()
}

/// Serializes this annotation into bytes using the shared bincode configuration.
///
/// # Panics
/// Panics if bincode fails to encode the value.
fn encode(&self) -> Vec<u8> {
    let config = Self::bin_config();
    bincode::serde::encode_to_vec(self, config).unwrap()
}

/// Rebuilds an annotation from bytes previously produced by `encode`.
///
/// # Panics
/// Panics if the bytes cannot be decoded into an `AnnotationRaw`.
fn decode(annotation: &[u8]) -> Self {
    let config = Self::bin_config();
    // The second tuple element (a `usize` reported by bincode) is not needed here.
    let decoded: (AnnotationRaw, usize) =
        bincode::serde::decode_from_slice(annotation, config).unwrap();
    decoded.0
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use the serialization for what exactly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We decided on having annotations as bytes

struct ImageDatasetItemRaw {
    /// Image path.
    image_path: PathBuf,

    /// Image annotation.
    /// The annotation bytes can represent a string (category name) or path to annotation file.
    annotation: Vec<u8>,
}

But now that you mention it... I don't see any need for serialization just to have bytes 😅 we could simply change the annotation type in ImageDatasetItemRaw to the AnnotationRaw enum. And scrap the encode/decode.

Probably needed another coffee when I went over this part ☕

}

#[derive(Deserialize, Serialize, Debug, Clone)]
struct ImageDatasetItemRaw {
/// Image path.
pub image_path: PathBuf,
image_path: PathBuf,

/// Image annotation.
/// The annotation bytes can represent a string (category name) or path to annotation file.
pub annotation: Vec<u8>,
annotation: Vec<u8>,
}

impl ImageDatasetItemRaw {
fn new<P: AsRef<Path>>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw {
ImageDatasetItemRaw {
image_path: image_path.as_ref().to_path_buf(),
annotation: annotation.encode(),
}
}
}

struct PathToImageDatasetItem {
Expand All @@ -118,9 +153,18 @@ fn parse_image_annotation(annotation: &[u8], classes: &HashMap<String, usize>) -
// - [ ] Segmentation mask
// For now, only image classification labels are supported.

let annotation = AnnotationRaw::decode(annotation);

// Map class string to label id
let name = std::str::from_utf8(annotation).unwrap();
Annotation::Label(*classes.get(name).unwrap())
match annotation {
AnnotationRaw::Label(name) => Annotation::Label(*classes.get(&name).unwrap()),
AnnotationRaw::MultiLabel(names) => Annotation::MultiLabel(
names
.iter()
.map(|name| *classes.get(name).unwrap())
.collect(),
),
}
}

impl Mapper<ImageDatasetItemRaw, ImageDatasetItem> for PathToImageDatasetItem {
Expand Down Expand Up @@ -212,7 +256,7 @@ pub enum ImageLoaderError {
type ImageDatasetMapper =
MapperDataset<InMemDataset<ImageDatasetItemRaw>, PathToImageDatasetItem, ImageDatasetItemRaw>;

/// A generic dataset to load classification images from disk.
/// A generic dataset to load images from disk.
pub struct ImageFolderDataset {
dataset: ImageDatasetMapper,
}
Expand Down Expand Up @@ -259,26 +303,14 @@ impl ImageFolderDataset {
P: AsRef<Path>,
S: AsRef<str>,
{
/// Check if extension is supported.
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
let extension = extension.as_ref();
if !SUPPORTED_FILES.contains(&extension) {
Err(ImageLoaderError::InvalidFileExtensionError(
extension.to_string(),
))
} else {
Ok(extension.to_string())
}
}

// Glob all images with extensions
let walker = globwalk::GlobWalkerBuilder::from_patterns(
root.as_ref(),
&[format!(
"*.{{{}}}", // "*.{ext1,ext2,ext3}
extensions
.iter()
.map(check_extension)
.map(Self::check_extension)
.collect::<Result<Vec<_>, _>>()?
.join(",")
)],
Expand Down Expand Up @@ -312,21 +344,102 @@ impl ImageFolderDataset {

classes.insert(label.clone());

items.push(ImageDatasetItemRaw {
image_path: image_path.to_path_buf(),
annotation: label.into_bytes(),
})
items.push(ImageDatasetItemRaw::new(
image_path,
AnnotationRaw::Label(label),
))
}

// Sort class names
let mut classes = classes.into_iter().collect::<Vec<_>>();
classes.sort();

Self::with_items(items, &classes)
}

/// Create an image classification dataset with the specified items.
///
/// # Arguments
///
/// * `items` - List of dataset items, each item represented by a tuple `(image path, label)`.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
items: Vec<(P, String)>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// Parse items and check valid image extension types
let items = items
.into_iter()
.map(|(path, label)| {
// Map image path and label
let path = path.as_ref();
let label = AnnotationRaw::Label(label);

Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;

Ok(ImageDatasetItemRaw::new(path, label))
})
.collect::<Result<Vec<_>, _>>()?;

Self::with_items(items, classes)
}

/// Create a multi-label image classification dataset with the specified items.
///
/// # Arguments
///
/// * `items` - List of dataset items, each item represented by a tuple `(image path, labels)`.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
pub fn new_multilabel_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
items: Vec<(P, Vec<String>)>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// Parse items and check valid image extension types
let items = items
.into_iter()
.map(|(path, labels)| {
// Map image path and multi-label
let path = path.as_ref();
let labels = AnnotationRaw::MultiLabel(labels);

Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;

Ok(ImageDatasetItemRaw::new(path, labels))
})
.collect::<Result<Vec<_>, _>>()?;

Self::with_items(items, classes)
}

/// Create an image dataset with the specified items.
///
/// # Arguments
///
/// * `items` - Raw dataset items.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
fn with_items<S: AsRef<str>>(
items: Vec<ImageDatasetItemRaw>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// NOTE: right now we don't need to validate the supported image files since
// the method is private. We assume it's already validated.
let dataset = InMemDataset::new(items);

// Class names to index map
let mut classes = classes.into_iter().collect::<Vec<_>>();
classes.sort();
let classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();
let classes_map: HashMap<_, _> = classes
.into_iter()
.enumerate()
.map(|(idx, cls)| (cls, idx))
.map(|(idx, cls)| (cls.to_string(), idx))
.collect();

let mapper = PathToImageDatasetItem {
Expand All @@ -336,6 +449,18 @@ impl ImageFolderDataset {

Ok(Self { dataset })
}

/// Validates a file extension against `SUPPORTED_FILES`.
///
/// Returns the extension as an owned `String` when it is listed, otherwise
/// an `ImageLoaderError::InvalidFileExtensionError` carrying the rejected extension.
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
    let ext = extension.as_ref();

    if SUPPORTED_FILES.contains(&ext) {
        return Ok(ext.to_string());
    }

    Err(ImageLoaderError::InvalidFileExtensionError(ext.to_string()))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -370,6 +495,69 @@ mod tests {
assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
}

#[test]
pub fn image_folder_dataset_with_items() {
    let root = Path::new(DATASET_ROOT);

    // (class folder, file name) pairs; the folder name doubles as the label.
    let samples = [("orange", "dot.jpg"), ("red", "dot.jpg"), ("red", "dot.png")];
    let items: Vec<_> = samples
        .iter()
        .map(|(class, file)| (root.join(class).join(file), class.to_string()))
        .collect();

    let dataset =
        ImageFolderDataset::new_classification_with_items(items, &["orange", "red"]).unwrap();

    // Dataset has 3 elements
    assert_eq!(dataset.len(), 3);
    assert_eq!(dataset.get(3), None);

    // Dataset elements should be: orange (0), red (1), red (1)
    for (index, expected) in [0, 1, 1].into_iter().enumerate() {
        assert_eq!(
            dataset.get(index).unwrap().annotation,
            Annotation::Label(expected)
        );
    }
}

#[test]
pub fn image_folder_dataset_multilabel() {
    let root = Path::new(DATASET_ROOT);

    // Every sample is tagged "dot" plus the name of the folder it lives in.
    let sample = |folder: &str, file: &str| {
        (
            root.join(folder).join(file),
            vec!["dot".to_string(), folder.to_string()],
        )
    };
    let items = vec![
        sample("orange", "dot.jpg"),
        sample("red", "dot.jpg"),
        sample("red", "dot.png"),
    ];

    let dataset = ImageFolderDataset::new_multilabel_classification_with_items(
        items,
        &["dot", "orange", "red"],
    )
    .unwrap();

    // Dataset has 3 elements
    assert_eq!(dataset.len(), 3);
    assert_eq!(dataset.get(3), None);

    // Dataset elements should be: [dot, orange] (0, 1), [dot, red] (0, 2), [dot, red] (0, 2)
    let expected = [vec![0, 1], vec![0, 2], vec![0, 2]];
    for (index, labels) in expected.into_iter().enumerate() {
        assert_eq!(
            dataset.get(index).unwrap().annotation,
            Annotation::MultiLabel(labels)
        );
    }
}

#[test]
#[should_panic]
pub fn image_folder_dataset_invalid_extension() {
Expand Down Expand Up @@ -417,11 +605,26 @@ mod tests {
}

#[test]
pub fn parse_image_annotation_string() {
pub fn parse_image_annotation_label_string() {
let classes = HashMap::from([("0".to_string(), 0_usize), ("1".to_string(), 1_usize)]);
let anno = AnnotationRaw::Label("0".to_string()).encode();
assert_eq!(
parse_image_annotation(&"0".to_string().into_bytes(), &classes),
parse_image_annotation(&anno, &classes),
Annotation::Label(0)
);
}

#[test]
pub fn parse_image_annotation_multilabel_string() {
    // Class names "0", "1", "2" map to indices 0, 1, 2 respectively.
    let classes: HashMap<String, usize> = ["0", "1", "2"]
        .iter()
        .enumerate()
        .map(|(index, name)| (name.to_string(), index))
        .collect();

    let names = vec!["0".to_string(), "2".to_string()];
    let encoded = AnnotationRaw::MultiLabel(names).encode();
    let parsed = parse_image_annotation(&encoded, &classes);

    assert_eq!(parsed, Annotation::MultiLabel(vec![0, 2]));
}
}
Loading
Loading