Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-label classification dataset and metric #1572

Merged
merged 11 commits into from
Apr 5, 2024
Prev Previous commit
Next Next commit
Add image_folder_dataset_multilabel test
  • Loading branch information
laggui committed Apr 3, 2024
commit 813ae504f0615d4430ca7a5b50103fe5597a74fb
42 changes: 42 additions & 0 deletions crates/burn-dataset/src/vision/image_folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,48 @@ mod tests {
assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
}

#[test]
pub fn image_folder_dataset_multilabel() {
let root = Path::new(DATASET_ROOT);
let items = vec![
(
root.join("orange").join("dot.jpg"),
vec!["dot".to_string(), "orange".to_string()],
),
(
root.join("red").join("dot.jpg"),
vec!["dot".to_string(), "red".to_string()],
),
(
root.join("red").join("dot.png"),
vec!["dot".to_string(), "red".to_string()],
),
];
let dataset = ImageFolderDataset::new_multilabel_classification_with_items(
items,
&["dot", "orange", "red"],
)
.unwrap();

// Dataset has 3 elements
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.get(3), None);

// Dataset elements should be: [dot, orange] (0, 1), [dot, red] (0, 2), [dot, red] (0, 2)
assert_eq!(
dataset.get(0).unwrap().annotation,
Annotation::MultiLabel(vec![0, 1])
);
assert_eq!(
dataset.get(1).unwrap().annotation,
Annotation::MultiLabel(vec![0, 2])
);
assert_eq!(
dataset.get(2).unwrap().annotation,
Annotation::MultiLabel(vec![0, 2])
);
}

#[test]
#[should_panic]
pub fn image_folder_dataset_invalid_extension() {
Expand Down