Make Trainer API more user-friendly
Summary:
- Added `make_config_class` decorator to derive a config class from a constructor
- Added `resolve_defaults` decorator to resolve default factories automatically
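
A minimal usage sketch (hypothetical `MyTrainer` / `MyTrainerParameters`, not part of this diff): the trainer declares its arguments once in `__init__` under `@resolve_defaults`, and `@make_config_class` then derives a frozen parameters dataclass from that same signature.

from dataclasses import field

from ml.rl.core.configuration import make_config_class, resolve_defaults


class MyTrainer:
    @resolve_defaults
    def __init__(
        self,
        learning_rate: float = 1e-3,
        # Mutable default: declared with field(default_factory=...) and
        # materialized on every call by @resolve_defaults.
        layer_sizes: list = field(default_factory=lambda: [256, 128]),  # noqa B008
        use_gpu: bool = False,
    ) -> None:
        self.learning_rate = learning_rate
        self.layer_sizes = layer_sizes
        self.use_gpu = use_gpu


# Derive a frozen dataclass from the constructor; use_gpu stays constructor-only.
@make_config_class(MyTrainer.__init__, blacklist=["use_gpu"])
class MyTrainerParameters:
    pass


params = MyTrainerParameters(learning_rate=3e-4)  # type: ignore
trainer = MyTrainer(use_gpu=False, **params.asdict())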

Reviewed By: badrinarayan

Differential Revision: D20756863

fbshipit-source-id: 93c3f2d92881e0e7ba1c6435691785bd39af3f6b
kittipatv authored and facebook-github-bot committed Apr 2, 2020
1 parent 460f74a commit dedd844
Showing 5 changed files with 160 additions and 46 deletions.
127 changes: 127 additions & 0 deletions ml/rl/core/configuration.py
@@ -0,0 +1,127 @@
#!/usr/bin/python3

import functools
from dataclasses import MISSING, Field, dataclass, fields
from inspect import Parameter, signature
from typing import List, Optional, Type

from torch import nn


BLACKLIST_TYPES = [nn.Module]


def make_config_class(
func,
whitelist: Optional[List[str]] = None,
blacklist: Optional[List[str]] = None,
blacklist_types: List[Type] = BLACKLIST_TYPES,
):
"""
Create a decorator that turns a class into a dataclass with the arguments of
`func` as fields. Only annotated arguments are converted to fields. If the
default value is mutable, you must use `dataclasses.field(default_factory=...)`
as the default. In that case, `func` has to be wrapped with @resolve_defaults below.
`whitelist` & `blacklist` are mutually exclusive.
"""

parameters = signature(func).parameters

assert (
whitelist is None or blacklist is None
), "whitelist & blacklist are mutually exclusive"

blacklist_set = set(blacklist or [])

def _is_type_blacklisted(t):
return any(issubclass(t, blacklist_type) for blacklist_type in blacklist_types)

whitelist = whitelist or [
p.name
for p in parameters.values()
if p.name not in blacklist_set
and p.annotation != Parameter.empty
and not _is_type_blacklisted(p.annotation)
]

for field_name in whitelist:
p = parameters[field_name]
assert p.annotation != Parameter.empty and not _is_type_blacklisted(
p.annotation
), f"{field_name} has wrong annotation: {p.annotation}"

def wrapper(config_cls):
# Add __annotations__ for dataclass
config_cls.__annotations__ = {
field_name: parameters[field_name].annotation for field_name in whitelist
}
# Set default values
for field_name in whitelist:
default = parameters[field_name].default
if default != Parameter.empty:
setattr(config_cls, field_name, default)

# Add hashing to support hashing list and dict
config_cls.__hash__ = param_hash

# Add non-recursive asdict(). dataclasses.asdict() is recursive
def asdict(self):
return {field.name: getattr(self, field.name) for field in fields(self)}

config_cls.asdict = asdict

return dataclass(frozen=True)(config_cls)

return wrapper


def _resolve_default(val):
if not isinstance(val, Field):
return val
if val.default != MISSING:
return val.default
if val.default_factory != MISSING:
return val.default_factory()
raise ValueError("No default value")


def resolve_defaults(func):
"""
Use this decorator to resolve default field values in the constructor.
"""

field_parameters = [
p for p in signature(func).parameters.values() if isinstance(p.default, Field)
]

@functools.wraps(func)
def wrapper(*args, **kwargs):
for p in field_parameters:
if p.name not in kwargs:
kwargs[p.name] = _resolve_default(p.default)
return func(*args, **kwargs)

return wrapper


def param_hash(p):
"""
Use this to make parameters hashable. This is required because __hash__()
is not inherited when subclass redefines __eq__(). We only need this when
the parameter dataclass has a list or dict field.
"""
return hash(tuple(_hash_field(getattr(p, f.name)) for f in fields(p)))


def _hash_field(val):
"""
Returns hashable value of the argument. A list is converted to a tuple.
A dict is converted to a tuple of sorted pairs of key and value.
"""
if isinstance(val, list):
return tuple(val)
elif isinstance(val, dict):
return tuple(sorted(val.items()))
else:
return val
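
Two behaviors of the helpers above, sketched with a hypothetical `ReplayBuffer` class (not part of this commit): `resolve_defaults` calls the `default_factory` on every invocation, so instances never share a mutable default, and `param_hash` keeps the derived frozen config hashable even when it carries list or dict fields.

from dataclasses import field

from ml.rl.core.configuration import make_config_class, resolve_defaults


class ReplayBuffer:
    @resolve_defaults
    def __init__(self, items: list = field(default_factory=list)) -> None:  # noqa B008
        self.items = items


a, b = ReplayBuffer(), ReplayBuffer()
a.items.append(1)
assert b.items == []  # each call received a fresh list from the factory


@make_config_class(ReplayBuffer.__init__)
class ReplayBufferConfig:
    pass


cfg = ReplayBufferConfig(items=[1, 2, 3])  # type: ignore
assert cfg.asdict() == {"items": [1, 2, 3]}
_ = hash(cfg)  # works: param_hash converts the list field to a tuple before hashing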
24 changes: 1 addition & 23 deletions ml/rl/parameters.py
@@ -1,36 +1,14 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import dataclasses
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from ml.rl.core.configuration import param_hash
from ml.rl.parameters_seq2slate import LearningMethod
from ml.rl.types import BaseDataClass


def param_hash(p):
"""
Use this to make parameters hashable. This is required because __hash__()
is not inherited when subclass redefines __eq__(). We only need this when
the parameter dataclass has a list or dict field.
"""
return hash(tuple(_hash_field(getattr(p, f.name)) for f in dataclasses.fields(p)))


def _hash_field(val):
"""
Returns hashable value of the argument. A list is converted to a tuple.
A dict is converted to a tuple of sorted pairs of key and value.
"""
if isinstance(val, list):
return tuple(val)
elif isinstance(val, dict):
return tuple(sorted(val.items()))
else:
return val


@dataclass(frozen=True)
class RLParameters(BaseDataClass):
__hash__ = param_hash
5 changes: 3 additions & 2 deletions ml/rl/test/gridworld/test_gridworld_parametric.py
@@ -34,7 +34,7 @@ def setUp(self):
super().setUp()

def get_sarsa_parameters(self) -> ParametricDQNTrainerParameters:
return ParametricDQNTrainerParameters(
return ParametricDQNTrainerParameters( # type: ignore
rl=RLParameters(
gamma=DISCOUNT, target_update_rate=1.0, maxq_learning=False
),
@@ -69,8 +69,9 @@ def get_trainer(
reward_network = reward_network.get_distributed_data_parallel_model()

q_network_target = q_network.get_target_network()
param_dict = parameters.asdict() # type: ignore
trainer = ParametricDQNTrainer(
q_network, q_network_target, reward_network, parameters=parameters
q_network, q_network_target, reward_network, **param_dict
)
return trainer

42 changes: 23 additions & 19 deletions ml/rl/training/parametric_dqn_trainer.py
@@ -10,6 +10,7 @@
import numpy as np
import torch
import torch.nn.functional as F
from ml.rl.core.configuration import make_config_class, resolve_defaults
from ml.rl.parameters import ContinuousActionModelParameters
from ml.rl.training.dqn_trainer_base import DQNTrainerBase
from ml.rl.training.training_data_page import TrainingDataPage
@@ -18,44 +19,42 @@
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ParametricDQNTrainerParameters:
rl: rlp.RLParameters = field(default_factory=rlp.RLParameters)
double_q_learning: bool = True
minibatch_size: int = 1024
minibatches_per_step: int = 1
optimizer: rlp.OptimizerParameters = field(default_factory=rlp.OptimizerParameters)


class ParametricDQNTrainer(DQNTrainerBase):
@resolve_defaults
def __init__(
self,
q_network,
q_network_target,
reward_network,
parameters: ParametricDQNTrainerParameters,
rl: rlp.RLParameters = field(default_factory=rlp.RLParameters), # noqa B008
double_q_learning: bool = True,
minibatch_size: int = 1024,
minibatches_per_step: int = 1,
optimizer: rlp.OptimizerParameters = field( # noqa B008
default_factory=rlp.OptimizerParameters
),
use_gpu: bool = False,
) -> None:
super().__init__(parameters.rl, use_gpu=use_gpu)
super().__init__(rl, use_gpu=use_gpu)

self.double_q_learning = parameters.double_q_learning
self.minibatch_size = parameters.minibatch_size
self.minibatches_per_step = parameters.minibatches_per_step or 1
self.double_q_learning = double_q_learning
self.minibatch_size = minibatch_size
self.minibatches_per_step = minibatches_per_step or 1

self.q_network = q_network
self.q_network_target = q_network_target
self._set_optimizer(parameters.optimizer.optimizer)
self._set_optimizer(optimizer.optimizer)
self.q_network_optimizer = self.optimizer_func(
self.q_network.parameters(),
lr=parameters.optimizer.learning_rate,
weight_decay=parameters.optimizer.l2_decay,
lr=optimizer.learning_rate,
weight_decay=optimizer.l2_decay,
)

self.reward_network = reward_network
self.reward_network_optimizer = self.optimizer_func(
self.reward_network.parameters(),
lr=parameters.optimizer.learning_rate,
weight_decay=parameters.optimizer.l2_decay,
lr=optimizer.learning_rate,
weight_decay=optimizer.l2_decay,
)

def warm_start_components(self):
@@ -180,3 +179,8 @@ def internal_reward_estimation(self, state, action):
)
self.reward_network.train()
return reward_estimates.q_value.cpu()


@make_config_class(ParametricDQNTrainer.__init__, blacklist=["use_gpu"])
class ParametricDQNTrainerParameters:
pass
8 changes: 6 additions & 2 deletions ml/rl/workflow_utils/transitional.py
@@ -202,7 +202,7 @@ def create_parametric_dqn_trainer_from_params(
q_network_target = q_network_target.get_distributed_data_parallel_model()
reward_network = reward_network.get_distributed_data_parallel_model()

trainer_parameters = ParametricDQNTrainerParameters(
trainer_parameters = ParametricDQNTrainerParameters( # type: ignore
rl=model.rl,
double_q_learning=model.rainbow.double_q_learning,
minibatch_size=model.training.minibatch_size,
@@ -214,7 +214,11 @@
)

return ParametricDQNTrainer(
q_network, q_network_target, reward_network, trainer_parameters, use_gpu
q_network,
q_network_target,
reward_network,
use_gpu=use_gpu,
**trainer_parameters.asdict() # type: ignore
)