Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update conversion script to fix model naming #2221

Merged
merged 2 commits into from
Dec 7, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions keras_cv/tools/convert_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@
# limitations under the License.

import os
import re

import keras_cv # noqa: E402

BUCKET = "keras-cv-kaggle"


def to_snake_case(name):
name = re.sub(r"\W+", "", name)
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
name = re.sub("([a-z])([A-Z])", r"\1_\2", name).lower()
return name


def convert_backbone_presets():
# Save and upload Backbone presets

Expand Down Expand Up @@ -57,7 +65,9 @@ def convert_backbone_presets():
]
for backbone_cls in backbone_models:
for preset in backbone_cls.presets:
backbone = backbone_cls.from_preset(preset)
backbone = backbone_cls.from_preset(
preset, name=to_snake_case(backbone_cls.__name__)
)
save_weights = preset in backbone_cls.presets_with_weights
save_to_preset(
backbone,
Expand Down Expand Up @@ -95,7 +105,7 @@ def convert_task_presets():
)
for preset in task_preset_keys:
save_weights = preset in task_cls.presets_with_weights
kwargs = {}
kwargs = {"name": to_snake_case(task_cls.__name__)}
nkovela1 marked this conversation as resolved.
Show resolved Hide resolved
if task_cls in [
keras_cv.models.RetinaNet,
keras_cv.models.YOLOV8Detector,
Expand Down
Loading