diff --git a/ml/rl/core/dataclasses.py b/ml/rl/core/dataclasses.py index f35996175..b32e1270b 100644 --- a/ml/rl/core/dataclasses.py +++ b/ml/rl/core/dataclasses.py @@ -1,6 +1,10 @@ #!/usr/bin/env python3 +import dataclasses +import logging +import os + # Redirection to make import simpler from dataclasses import field # noqa from typing import TYPE_CHECKING, Optional @@ -8,8 +12,31 @@ import pydantic -class Config: - arbitrary_types_allowed = True +try: + import fblearner.flow.api # noqa + + """ + Inside FBLearner, we don't use pydantic for option parsing. Some types don't have + a validator. This is necessary to avoid pydantic complaining about validators. + """ + USE_VANILLA_DATACLASS = True + +except ImportError: + + USE_VANILLA_DATACLASS = False + + +try: + # Allowing override, e.g., in unit test + USE_VANILLA_DATACLASS = bool(int(os.environ["USE_VANILLA_DATACLASS"])) +except KeyError: + pass + + +logger = logging.getLogger(__name__) + + +logger.info(f"USE_VANILLA_DATACLASS: {USE_VANILLA_DATACLASS}") if TYPE_CHECKING: @@ -23,16 +50,28 @@ class Config: def dataclass( _cls: Optional[pydantic.typing.AnyType] = None, *, config=None, **kwargs ): - """ - Inside FB, we don't use pydantic for option parsing. Some types don't have - validator. This necessary to avoid pydantic complaining about validators. - """ + def wrap(cls): + # We don't want to look at parent class + if "__post_init__" in cls.__dict__: + raise TypeError( + f"{cls} has __post_init__. " + "Please use __post_init_post_parse__ instead." 
+ ) - if config is None: - config = Config + if USE_VANILLA_DATACLASS: + try: + post_init_post_parse = cls.__dict__["__post_init_post_parse__"] + logger.info( + f"Setting {cls.__name__}.__post_init__ to its " + "__post_init_post_parse__" + ) + cls.__post_init__ = post_init_post_parse + except KeyError: + pass - def wrap(cls): - return pydantic.dataclasses.dataclass(cls, config=config, **kwargs) + return dataclasses.dataclass(**kwargs)(cls) + else: + return pydantic.dataclasses.dataclass(cls, **kwargs) if _cls is None: return wrap diff --git a/ml/rl/core/registry_meta.py b/ml/rl/core/registry_meta.py index 27ad5705e..d15fd151b 100644 --- a/ml/rl/core/registry_meta.py +++ b/ml/rl/core/registry_meta.py @@ -19,10 +19,21 @@ def __init__(cls, name, bases, attrs): logger.info("Adding REGISTRY to type {}".format(name)) cls.REGISTRY: Dict[str, Type] = {} cls.REGISTRY_NAME = name + cls.REGISTRY_FROZEN = False - if not cls.__abstractmethods__: + assert not cls.REGISTRY_FROZEN, ( + f"{cls.REGISTRY_NAME} has been used to fill a union. 
" + "Please rearrange your import orders" + ) + + if not cls.__abstractmethods__ and name != cls.REGISTRY_NAME: # Only register fully-defined classes - logger.info("Registering {} to {}".format(name, cls.REGISTRY_NAME)) + logger.info(f"Registering {name} to {cls.REGISTRY_NAME}") + if hasattr(cls, "__registry_name__"): + registry_name = cls.__registry_name__ + logger.info(f"Using {registry_name} instead of {name}") + name = registry_name + assert name not in cls.REGISTRY cls.REGISTRY[name] = cls else: logger.info( @@ -33,6 +44,15 @@ def __init__(cls, name, bases, attrs): def fill_union(cls): def wrapper(union): + cls.REGISTRY_FROZEN = True + + def make_union_instance(inst, instance_class=None): + inst_class = instance_class or type(inst) + key = getattr(inst_class, "__registry_name__", inst_class.__name__) + return union(**{key: inst}) + + union.make_union_instance = make_union_instance + if issubclass(union, TaggedUnion): # OSS TaggedUnion union.__annotations__ = { diff --git a/ml/rl/parameters_seq2slate.py b/ml/rl/parameters_seq2slate.py index 19f7fa3b2..b29a2ed3d 100644 --- a/ml/rl/parameters_seq2slate.py +++ b/ml/rl/parameters_seq2slate.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
-from dataclasses import dataclass from enum import Enum from typing import Optional +from ml.rl.core.dataclasses import dataclass + class LearningMethod(Enum): TEACHER_FORCING = "teacher_forcing" diff --git a/ml/rl/workflow/model_managers/discrete_dqn_base.py b/ml/rl/workflow/model_managers/discrete_dqn_base.py index 8fbec00cb..d96765bd4 100644 --- a/ml/rl/workflow/model_managers/discrete_dqn_base.py +++ b/ml/rl/workflow/model_managers/discrete_dqn_base.py @@ -19,6 +19,7 @@ ReaderOptions, RewardOptions, RLTrainingOutput, + RLTrainingReport, TableSpec, ) from ml.rl.workflow.utils import train_and_evaluate_generic @@ -172,4 +173,7 @@ def train( evaluation_page_handler, reader_options=self.reader_options, ) - return RLTrainingOutput(training_report=reporter.generate_training_report()) + training_report = RLTrainingReport.make_union_instance( + reporter.generate_training_report() + ) + return RLTrainingOutput(training_report=training_report) diff --git a/ml/rl/workflow/publishers/model_publisher.py b/ml/rl/workflow/publishers/model_publisher.py index b7bbd13bc..700fb9b21 100644 --- a/ml/rl/workflow/publishers/model_publisher.py +++ b/ml/rl/workflow/publishers/model_publisher.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 import abc +import inspect from typing import Optional from ml.rl.core.registry_meta import RegistryMeta from ml.rl.workflow.model_managers.model_manager import ModelManager -from ml.rl.workflow.types import PublishingOutput, RecurringPeriod, RLTrainingOutput +from ml.rl.workflow.result_registries import PublishingResult +from ml.rl.workflow.types import RecurringPeriod, RLTrainingOutput class ModelPublisher(metaclass=RegistryMeta): @@ -14,7 +16,6 @@ class ModelPublisher(metaclass=RegistryMeta): they can be registered in the workflows. 
""" - @abc.abstractmethod def publish( self, model_manager: ModelManager, @@ -22,7 +23,38 @@ def publish( recurring_workflow_id: int, child_workflow_id: int, recurring_period: Optional[RecurringPeriod], - ) -> PublishingOutput: + ): + """ + This method takes RLTrainingOutput so that it can extract anything it + might need from it. + + ModelManager is given here so that config can be shared + """ + result = self.do_publish( + model_manager, + training_output, + recurring_workflow_id, + child_workflow_id, + recurring_period, + ) + # Avoid circular dependency at import time + from ml.rl.workflow.types import PublishingResult__Union + + # We need to use inspection because the result can be a future when running on + # FBL + result_type = inspect.signature(self.do_publish).return_annotation + assert result_type != inspect.Signature.empty + return PublishingResult__Union.make_union_instance(result, result_type) + + @abc.abstractmethod + def do_publish( + self, + model_manager: ModelManager, + training_output: RLTrainingOutput, + recurring_workflow_id: int, + child_workflow_id: int, + recurring_period: Optional[RecurringPeriod], + ) -> PublishingResult: """ This method takes RLTrainingOutput so that it can extract anything it might need from it. 
diff --git a/ml/rl/workflow/publishers/no_publishing.py b/ml/rl/workflow/publishers/no_publishing.py index 453464974..6f44af4ac 100644 --- a/ml/rl/workflow/publishers/no_publishing.py +++ b/ml/rl/workflow/publishers/no_publishing.py @@ -5,8 +5,8 @@ from ml.rl.workflow.model_managers.model_manager import ModelManager from ml.rl.workflow.publishers.model_publisher import ModelPublisher -from ml.rl.workflow.result_types import NoPublishingResults, PublishingResults -from ml.rl.workflow.types import PublishingOutput, RecurringPeriod, RLTrainingOutput +from ml.rl.workflow.result_types import NoPublishingResults +from ml.rl.workflow.types import RecurringPeriod, RLTrainingOutput @dataclass @@ -17,15 +17,12 @@ class NoPublishing(ModelPublisher): some publishing. """ - def publish( + def do_publish( self, model_manager: ModelManager, training_output: RLTrainingOutput, recurring_workflow_id: int, child_workflow_id: int, recurring_period: Optional[RecurringPeriod], - ) -> PublishingOutput: - return PublishingOutput( - success=True, - results=PublishingResults(no_publishing_results=NoPublishingResults()), - ) + ) -> NoPublishingResults: + return NoPublishingResults(success=True) diff --git a/ml/rl/workflow/result_registries.py b/ml/rl/workflow/result_registries.py new file mode 100644 index 000000000..1eaf08095 --- /dev/null +++ b/ml/rl/workflow/result_registries.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+ +from ml.rl.core.dataclasses import dataclass +from ml.rl.core.registry_meta import RegistryMeta + + +class TrainingReport(metaclass=RegistryMeta): + pass + + +@dataclass +class PublishingResult(metaclass=RegistryMeta): + success: bool + + +@dataclass +class ValidationResult(metaclass=RegistryMeta): + should_publish: bool diff --git a/ml/rl/workflow/result_types.py b/ml/rl/workflow/result_types.py index e86ad92bd..5718809c4 100644 --- a/ml/rl/workflow/result_types.py +++ b/ml/rl/workflow/result_types.py @@ -1,26 +1,15 @@ #!/usr/bin/env python3 - -from typing import NamedTuple, Optional +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. from ml.rl.core.dataclasses import dataclass -from ml.rl.core.tagged_union import TaggedUnion - - -class NoPublishingResults(NamedTuple): - pass - - -class NoValidationResults(NamedTuple): - pass +from ml.rl.workflow.result_registries import PublishingResult, ValidationResult @dataclass -class PublishingResults(TaggedUnion): - no_publishing_results: Optional[NoPublishingResults] = None - # Add your own validation results type here +class NoPublishingResults(PublishingResult): + __registry_name__ = "no_publishing_results" @dataclass -class ValidationResults(TaggedUnion): - no_validation_results: Optional[NoValidationResults] = None - # Add your own validation results type here +class NoValidationResults(ValidationResult): + __registry_name__ = "no_validation_results" diff --git a/ml/rl/workflow/tagged_union.py b/ml/rl/workflow/tagged_union.py new file mode 100644 index 000000000..1c5d69376 --- /dev/null +++ b/ml/rl/workflow/tagged_union.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+ +from ml.rl.core.tagged_union import TaggedUnion # noqa F401 diff --git a/ml/rl/workflow/training.py b/ml/rl/workflow/training.py index 6530d3055..b281b6fa9 100644 --- a/ml/rl/workflow/training.py +++ b/ml/rl/workflow/training.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import dataclasses import logging from typing import Dict, NamedTuple, Optional, Tuple @@ -8,10 +9,12 @@ from ml.rl.workflow.model_managers.union import ModelManager__Union from ml.rl.workflow.publishers.union import ModelPublisher__Union from ml.rl.workflow.types import ( + PublishingResult__Union, RecurringPeriod, RewardOptions, RLTrainingOutput, TableSpec, + ValidationResult__Union, ) from ml.rl.workflow.validators.union import ModelValidator__Union @@ -159,11 +162,11 @@ def run_validator( validator: ModelValidator__Union, training_output: RLTrainingOutput ) -> RLTrainingOutput: assert ( - training_output.validation_output is None - ), f"validation_output was set to f{training_output.validation_output}" + training_output.validation_result is None + ), f"validation_result was set to {training_output.validation_result}" model_validator = validator.value - validation_output = model_validator.validate(training_output) - return training_output._replace(validation_output=validation_output) + validation_result = model_validator.validate(training_output) + return dataclasses.replace(training_output, validation_result=validation_result) def run_publisher( @@ -175,15 +178,15 @@ def run_publisher( recurring_period: Optional[RecurringPeriod], ) -> RLTrainingOutput: assert ( - training_output.publishing_output is None - ), f"publishing_output was set to f{training_output.publishing_output}" + training_output.publishing_result is None + ), f"publishing_result was set to {training_output.publishing_result}" model_publisher = publisher.value model_manager = model_chooser.value - publishing_output = model_publisher.publish( + publishing_result = model_publisher.publish( model_manager, training_output, recurring_workflow_id, child_workflow_id, recurring_period, ) - return training_output._replace(publishing_output=publishing_output) + return 
dataclasses.replace(training_output, publishing_result=publishing_result) diff --git a/ml/rl/workflow/types.py b/ml/rl/workflow/types.py index 11dcaa7c0..cde1fd0bc 100644 --- a/ml/rl/workflow/types.py +++ b/ml/rl/workflow/types.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from dataclasses import dataclass from datetime import datetime as RecurringPeriod # noqa from typing import Dict, List, NamedTuple, Optional -from ml.rl.core.tagged_union import TaggedUnion # noqa F401 +# Triggering registration to registries +import ml.rl.workflow.result_types # noqa +import ml.rl.workflow.training_reports # noqa +from ml.rl.core.dataclasses import dataclass from ml.rl.preprocessing.normalization import ( DEFAULT_MAX_QUANTILE_SIZE, DEFAULT_MAX_UNIQUE_ENUM, @@ -13,7 +15,12 @@ DEFAULT_QUANTILE_K2_THRESHOLD, ) from ml.rl.types import BaseDataClass -from ml.rl.workflow.result_types import PublishingResults, ValidationResults +from ml.rl.workflow.result_registries import ( + PublishingResult, + TrainingReport, + ValidationResult, +) +from ml.rl.workflow.tagged_union import TaggedUnion # noqa F401 @dataclass @@ -36,11 +43,23 @@ class PreprocessingOptions(BaseDataClass): assert_whitelist_feature_coverage: bool = True -class PublishingOutput(NamedTuple): - success: bool - results: PublishingResults +@PublishingResult.fill_union() +class PublishingResult__Union(TaggedUnion): + pass + + +@ValidationResult.fill_union() +class ValidationResult__Union(TaggedUnion): + pass -class ValidationOutput(NamedTuple): - should_publish: bool - results: ValidationResults +@TrainingReport.fill_union() +class RLTrainingReport(TaggedUnion): + pass + + +@dataclass +class RLTrainingOutput: + validation_result: Optional[ValidationResult__Union] = None + publishing_result: Optional[PublishingResult__Union] = None + training_report: Optional[RLTrainingReport] = None diff --git a/ml/rl/workflow/validators/model_validator.py 
b/ml/rl/workflow/validators/model_validator.py index e439c37fe..1e1860198 100644 --- a/ml/rl/workflow/validators/model_validator.py +++ b/ml/rl/workflow/validators/model_validator.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 import abc +import inspect import logging from ml.rl.core.registry_meta import RegistryMeta -from ml.rl.workflow.types import RLTrainingOutput, ValidationOutput +from ml.rl.workflow.result_registries import ValidationResult +from ml.rl.workflow.types import RLTrainingOutput logger = logging.getLogger(__name__) @@ -16,8 +18,23 @@ class ModelValidator(metaclass=RegistryMeta): they can be registered in the workflows. """ + def validate(self, training_output: RLTrainingOutput): + """ + This method takes RLTrainingOutput so that it can extract anything it + might need from it. + """ + result = self.do_validate(training_output) + # Avoid circular dependency at import time + from ml.rl.workflow.types import ValidationResult__Union + + # We need to use inspection because the result can be a future when running on + # FBL + result_type = inspect.signature(self.do_validate).return_annotation + assert result_type != inspect.Signature.empty + return ValidationResult__Union.make_union_instance(result, result_type) + @abc.abstractmethod - def validate(self, training_output: RLTrainingOutput) -> ValidationOutput: + def do_validate(self, training_output: RLTrainingOutput) -> ValidationResult: """ This method takes RLTrainingOutput so that it can extract anything it might need from it. 
diff --git a/ml/rl/workflow/validators/no_validation.py b/ml/rl/workflow/validators/no_validation.py index a1c930973..8d0e6dea6 100644 --- a/ml/rl/workflow/validators/no_validation.py +++ b/ml/rl/workflow/validators/no_validation.py @@ -2,8 +2,8 @@ from dataclasses import dataclass -from ml.rl.workflow.result_types import NoValidationResults, ValidationResults -from ml.rl.workflow.types import RLTrainingOutput, ValidationOutput +from ml.rl.workflow.result_types import NoValidationResults +from ml.rl.workflow.types import RLTrainingOutput from ml.rl.workflow.validators.model_validator import ModelValidator @@ -15,8 +15,5 @@ class NoValidation(ModelValidator): some validation. """ - def validate(self, training_output: RLTrainingOutput) -> ValidationOutput: - return ValidationOutput( - should_publish=True, - results=ValidationResults(no_validation_results=NoValidationResults()), - ) + def do_validate(self, training_output: RLTrainingOutput) -> NoValidationResults: + return NoValidationResults(should_publish=True)