Skip to content

Commit

Permalink
Create registries for result types
Browse files Browse the repository at this point in the history
Summary:
Created the following registries as base classes for training reporting:

- TrainingReport
- PublishingResult
- ValidationResult

Add `do_validate` and `do_publish` methods to `ModelValidator` and `ModelPublisher`, respectively. This simplifies subclasses, as they no longer have to create the union type themselves.

Reviewed By: czxttkl

Differential Revision: D20854559

fbshipit-source-id: 4d626dab6711eee35a12581c4e5bbe307b4985b8
  • Loading branch information
kittipatv authored and facebook-github-bot committed Apr 9, 2020
1 parent aee4d09 commit 4548933
Show file tree
Hide file tree
Showing 13 changed files with 207 additions and 66 deletions.
59 changes: 49 additions & 10 deletions ml/rl/core/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,42 @@
#!/usr/bin/env python3


import dataclasses
import logging
import os

# Redirection to make import simpler
from dataclasses import field # noqa
from typing import TYPE_CHECKING, Optional

import pydantic


class Config:
arbitrary_types_allowed = True
try:
import fblearner.flow.api # noqa

"""
Inside FBLearner, we don't use pydantic for option parsing. Some types don't have
validators, so this is necessary to keep pydantic from complaining about them.
"""
USE_VANILLA_DATACLASS = True

except ImportError:

USE_VANILLA_DATACLASS = False


try:
# Allowing override, e.g., in unit test
USE_VANILLA_DATACLASS = bool(int(os.environ["USE_VANILLA_DATACLASS"]))
except KeyError:
pass


logger = logging.getLogger(__name__)


logger.info(f"USE_VANILLA_DATACLASS: {USE_VANILLA_DATACLASS}")


if TYPE_CHECKING:
Expand All @@ -23,16 +50,28 @@ class Config:
def dataclass(
_cls: Optional[pydantic.typing.AnyType] = None, *, config=None, **kwargs
):
"""
Inside FB, we don't use pydantic for option parsing. Some types don't have
validators, so this is necessary to keep pydantic from complaining about them.
"""
def wrap(cls):
# We don't want to look at parent class
if "__post_init__" in cls.__dict__:
raise TypeError(
f"{cls} has __post_init__. "
"Please use __post_init_post_parse__ instead."
)

if config is None:
config = Config
if USE_VANILLA_DATACLASS:
try:
post_init_post_parse = cls.__dict__["__post_init_post_parse__"]
logger.info(
f"Setting {cls.__name__}.__post_init__ to its "
"__post_init_post_parse__"
)
cls.__post_init__ = post_init_post_parse
except KeyError:
pass

def wrap(cls):
return pydantic.dataclasses.dataclass(cls, config=config, **kwargs)
return dataclasses.dataclass(**kwargs)(cls)
else:
return pydantic.dataclasses.dataclass(cls, **kwargs)

if _cls is None:
return wrap
Expand Down
24 changes: 22 additions & 2 deletions ml/rl/core/registry_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,21 @@ def __init__(cls, name, bases, attrs):
logger.info("Adding REGISTRY to type {}".format(name))
cls.REGISTRY: Dict[str, Type] = {}
cls.REGISTRY_NAME = name
cls.REGISTRY_FROZEN = False

if not cls.__abstractmethods__:
assert not cls.REGISTRY_FROZEN, (
f"{cls.REGISTRY_NAME} has been used to fill a union. "
"Please rearrange your import orders"
)

if not cls.__abstractmethods__ and name != cls.REGISTRY_NAME:
# Only register fully-defined classes
logger.info("Registering {} to {}".format(name, cls.REGISTRY_NAME))
logger.info(f"Registering {name} to {cls.REGISTRY_NAME}")
if hasattr(cls, "__registry_name__"):
registry_name = cls.__registry_name__
logger.info(f"Using {registry_name} instead of {name}")
name = registry_name
assert name not in cls.REGISTRY
cls.REGISTRY[name] = cls
else:
logger.info(
Expand All @@ -33,6 +44,15 @@ def __init__(cls, name, bases, attrs):

def fill_union(cls):
def wrapper(union):
cls.REGISTRY_FROZEN = True

def make_union_instance(inst, instance_class=None):
inst_class = instance_class or type(inst)
key = getattr(inst_class, "__registry_name__", inst_class.__name__)
return union(**{key: inst})

union.make_union_instance = make_union_instance

if issubclass(union, TaggedUnion):
# OSS TaggedUnion
union.__annotations__ = {
Expand Down
3 changes: 2 additions & 1 deletion ml/rl/parameters_seq2slate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from dataclasses import dataclass
from enum import Enum
from typing import Optional

from ml.rl.core.dataclasses import dataclass


class LearningMethod(Enum):
TEACHER_FORCING = "teacher_forcing"
Expand Down
6 changes: 5 additions & 1 deletion ml/rl/workflow/model_managers/discrete_dqn_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ReaderOptions,
RewardOptions,
RLTrainingOutput,
RLTrainingReport,
TableSpec,
)
from ml.rl.workflow.utils import train_and_evaluate_generic
Expand Down Expand Up @@ -172,4 +173,7 @@ def train(
evaluation_page_handler,
reader_options=self.reader_options,
)
return RLTrainingOutput(training_report=reporter.generate_training_report())
training_report = RLTrainingReport.make_union_instance(
reporter.generate_training_report()
)
return RLTrainingOutput(training_report=training_report)
38 changes: 35 additions & 3 deletions ml/rl/workflow/publishers/model_publisher.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#!/usr/bin/env python3

import abc
import inspect
from typing import Optional

from ml.rl.core.registry_meta import RegistryMeta
from ml.rl.workflow.model_managers.model_manager import ModelManager
from ml.rl.workflow.types import PublishingOutput, RecurringPeriod, RLTrainingOutput
from ml.rl.workflow.result_registries import PublishingResult
from ml.rl.workflow.types import RecurringPeriod, RLTrainingOutput


class ModelPublisher(metaclass=RegistryMeta):
Expand All @@ -14,15 +16,45 @@ class ModelPublisher(metaclass=RegistryMeta):
they can be registered in the workflows.
"""

@abc.abstractmethod
def publish(
self,
model_manager: ModelManager,
training_output: RLTrainingOutput,
recurring_workflow_id: int,
child_workflow_id: int,
recurring_period: Optional[RecurringPeriod],
) -> PublishingOutput:
):
"""
This method takes RLTrainingOutput so that it can extract anything it
might need from it.
ModelManager is given here so that config can be shared
"""
result = self.do_publish(
model_manager,
training_output,
recurring_workflow_id,
child_workflow_id,
recurring_period,
)
# Avoid circular dependency at import time
from ml.rl.workflow.types import PublishingResult__Union

# We need to use inspection because the result can be a future when running on
# FBL
result_type = inspect.signature(self.do_publish).return_annotation
assert result_type != inspect.Signature.empty
return PublishingResult__Union.make_union_instance(result, result_type)

@abc.abstractmethod
def do_publish(
self,
model_manager: ModelManager,
training_output: RLTrainingOutput,
recurring_workflow_id: int,
child_workflow_id: int,
recurring_period: Optional[RecurringPeriod],
) -> PublishingResult:
"""
This method takes RLTrainingOutput so that it can extract anything it
might need from it.
Expand Down
13 changes: 5 additions & 8 deletions ml/rl/workflow/publishers/no_publishing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from ml.rl.workflow.model_managers.model_manager import ModelManager
from ml.rl.workflow.publishers.model_publisher import ModelPublisher
from ml.rl.workflow.result_types import NoPublishingResults, PublishingResults
from ml.rl.workflow.types import PublishingOutput, RecurringPeriod, RLTrainingOutput
from ml.rl.workflow.result_types import NoPublishingResults
from ml.rl.workflow.types import RecurringPeriod, RLTrainingOutput


@dataclass
Expand All @@ -17,15 +17,12 @@ class NoPublishing(ModelPublisher):
some publishing.
"""

def publish(
def do_publish(
self,
model_manager: ModelManager,
training_output: RLTrainingOutput,
recurring_workflow_id: int,
child_workflow_id: int,
recurring_period: Optional[RecurringPeriod],
) -> PublishingOutput:
return PublishingOutput(
success=True,
results=PublishingResults(no_publishing_results=NoPublishingResults()),
)
) -> NoPublishingResults:
return NoPublishingResults(success=True)
19 changes: 19 additions & 0 deletions ml/rl/workflow/result_registries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from ml.rl.core.dataclasses import dataclass
from ml.rl.core.registry_meta import RegistryMeta


class TrainingReport(metaclass=RegistryMeta):
    """Registry base class for training report types.

    Using ``RegistryMeta`` as the metaclass gives this class its own
    registry; the class itself declares no fields or methods.
    """

    pass


@dataclass
class PublishingResult(metaclass=RegistryMeta):
    """Registry base class for publishing result types.

    Using ``RegistryMeta`` as the metaclass gives this class its own
    registry. ``ModelPublisher.do_publish`` is annotated as returning a
    ``PublishingResult``.
    """

    # NOTE(review): presumably True when publishing succeeded -- confirm
    # against the publishers that populate it.
    success: bool


@dataclass
class ValidationResult(metaclass=RegistryMeta):
    """Registry base class for validation result types.

    Using ``RegistryMeta`` as the metaclass gives this class its own
    registry.
    """

    # NOTE(review): presumably tells the workflow whether the validated model
    # should go on to be published -- confirm against the callers.
    should_publish: bool
23 changes: 6 additions & 17 deletions ml/rl/workflow/result_types.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,15 @@
#!/usr/bin/env python3

from typing import NamedTuple, Optional
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from ml.rl.core.dataclasses import dataclass
from ml.rl.core.tagged_union import TaggedUnion


class NoPublishingResults(NamedTuple):
pass


class NoValidationResults(NamedTuple):
pass
from ml.rl.workflow.result_registries import PublishingResult, ValidationResult


@dataclass
class PublishingResults(TaggedUnion):
no_publishing_results: Optional[NoPublishingResults] = None
# Add your own publishing results type here
class NoPublishingResults(PublishingResult):
__registry_name__ = "no_publishing_results"


@dataclass
class ValidationResults(TaggedUnion):
no_validation_results: Optional[NoValidationResults] = None
# Add your own validation results type here
class NoValidationResults(ValidationResult):
__registry_name__ = "no_validation_results"
4 changes: 4 additions & 0 deletions ml/rl/workflow/tagged_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from ml.rl.core.tagged_union import TaggedUnion # noqa F401
15 changes: 9 additions & 6 deletions ml/rl/workflow/training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import dataclasses
import logging
from typing import Dict, NamedTuple, Optional, Tuple

Expand All @@ -8,10 +9,12 @@
from ml.rl.workflow.model_managers.union import ModelManager__Union
from ml.rl.workflow.publishers.union import ModelPublisher__Union
from ml.rl.workflow.types import (
PublishingResult__Union,
RecurringPeriod,
RewardOptions,
RLTrainingOutput,
TableSpec,
ValidationResult__Union,
)
from ml.rl.workflow.validators.union import ModelValidator__Union

Expand Down Expand Up @@ -159,11 +162,11 @@ def run_validator(
validator: ModelValidator__Union, training_output: RLTrainingOutput
) -> RLTrainingOutput:
assert (
training_output.validation_output is None
training_output.validation_result is None
), f"validation_output was set to f{training_output.validation_output}"
model_validator = validator.value
validation_output = model_validator.validate(training_output)
return training_output._replace(validation_output=validation_output)
validation_result = model_validator.validate(training_output)
return dataclasses.replace(training_output, validation_result=validation_result)


def run_publisher(
Expand All @@ -175,15 +178,15 @@ def run_publisher(
recurring_period: Optional[RecurringPeriod],
) -> RLTrainingOutput:
assert (
training_output.publishing_output is None
training_output.publishing_result is None
), f"publishing_output was set to f{training_output.publishing_output}"
model_publisher = publisher.value
model_manager = model_chooser.value
publishing_output = model_publisher.publish(
publishing_result = model_publisher.publish(
model_manager,
training_output,
recurring_workflow_id,
child_workflow_id,
recurring_period,
)
return training_output._replace(publishing_output=publishing_output)
return dataclasses.replace(training_output, publishing_result=publishing_result)
Loading

0 comments on commit 4548933

Please sign in to comment.