diff --git a/model/pytorch/dffml_model_pytorch/utils/config.py b/dffml/util/config/inspect.py similarity index 85% rename from model/pytorch/dffml_model_pytorch/utils/config.py rename to dffml/util/config/inspect.py index 803e0c8362..ac77dc81a7 100644 --- a/model/pytorch/dffml_model_pytorch/utils/config.py +++ b/dffml/util/config/inspect.py @@ -1,9 +1,10 @@ import inspect -from dffml.base import field, make_config from typing import Callable, Optional, Dict, Tuple, Type +from ...base import field, make_config -def inspect_pytorch_params(cls: Callable): + +def inspect_params(cls: Callable): parameters = inspect.signature(cls).parameters args = {} @@ -21,7 +22,7 @@ def inspect_pytorch_params(cls: Callable): return args -def make_pytorch_config( +def make_config_inspect( name: str, cls: Type, properties: Optional[Dict[str, Tuple[Type, field]]] = None, @@ -34,7 +35,7 @@ def make_pytorch_config( if properties is None: properties = {} - properties.update(inspect_pytorch_params(cls)) + properties.update(inspect_params(cls)) return make_config( name, [tuple([key] + list(value)) for key, value in properties.items()] diff --git a/model/pytorch/dffml_model_pytorch/utils/utils.py b/model/pytorch/dffml_model_pytorch/utils/utils.py index e7a593c874..a3b5da176a 100644 --- a/model/pytorch/dffml_model_pytorch/utils/utils.py +++ b/model/pytorch/dffml_model_pytorch/utils/utils.py @@ -8,7 +8,7 @@ from dffml.base import BaseDataFlowFacilitatorObject from dffml.util.entrypoint import base_entry_point, entrypoint -from .config import make_pytorch_config +from dffml.util.config.inspect import make_config_inspect def create_layer(layer_dict): @@ -110,7 +110,7 @@ def load(cls, class_name: str = None): for name, loss_class in inspect.getmembers(nn, inspect.isclass): if name.endswith("Loss"): - cls_config = make_pytorch_config(name + "Config", loss_class) + cls_config = make_config_inspect(name + "Config", loss_class) cls = entrypoint(name.lower())( type(