Skip to content

Commit

Permalink
util: config: inspect: Add make_config_inspect()
Browse files Browse the repository at this point in the history
The make_config_inspect() function is the same as the
make_pytorch_config() function from the PyTorch models.

We move it to dffml.util.config.inspect:make_config_inspect() so that it
can be accessed from the main package.

The function can be used to create a config class out of any function
that has type hints on it's arguments, it also handles default
arguments.

Signed-off-by: John Andersen <[email protected]>
  • Loading branch information
pdxjohnny committed Mar 18, 2021
1 parent 2799195 commit 3dfb033
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import inspect
from dffml.base import field, make_config
from typing import Callable, Optional, Dict, Tuple, Type

from ...base import field, make_config

def inspect_pytorch_params(cls: Callable):

def inspect_params(cls: Callable):
parameters = inspect.signature(cls).parameters
args = {}

Expand All @@ -21,7 +22,7 @@ def inspect_pytorch_params(cls: Callable):
return args


def make_pytorch_config(
def make_config_inspect(
name: str,
cls: Type,
properties: Optional[Dict[str, Tuple[Type, field]]] = None,
Expand All @@ -34,7 +35,7 @@ def make_pytorch_config(
if properties is None:
properties = {}

properties.update(inspect_pytorch_params(cls))
properties.update(inspect_params(cls))

return make_config(
name, [tuple([key] + list(value)) for key, value in properties.items()]
Expand Down
4 changes: 2 additions & 2 deletions model/pytorch/dffml_model_pytorch/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dffml.base import BaseDataFlowFacilitatorObject
from dffml.util.entrypoint import base_entry_point, entrypoint
from .config import make_pytorch_config
from dffml.util.config.inspect import make_config_inspect


def create_layer(layer_dict):
Expand Down Expand Up @@ -110,7 +110,7 @@ def load(cls, class_name: str = None):
for name, loss_class in inspect.getmembers(nn, inspect.isclass):
if name.endswith("Loss"):

cls_config = make_pytorch_config(name + "Config", loss_class)
cls_config = make_config_inspect(name + "Config", loss_class)

cls = entrypoint(name.lower())(
type(
Expand Down

0 comments on commit 3dfb033

Please sign in to comment.