Skip to content

Commit

Permalink
Get model fit data in gen_metadata on client completion events (#2512)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2512

Differential Revision: D58261583
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jun 14, 2024
1 parent 8c65768 commit 3467b27
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 19 deletions.
54 changes: 45 additions & 9 deletions ax/telemetry/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

from ax.core.base_trial import BaseTrial

from ax.service.ax_client import AxClient
from ax.telemetry.common import _get_max_transformed_dimensionality
from ax.telemetry.experiment import ExperimentCompletedRecord, ExperimentCreatedRecord
Expand Down Expand Up @@ -100,24 +102,58 @@ class AxClientCompletedRecord:
experiment_completed_record: ExperimentCompletedRecord

best_point_quality: float
model_fit_quality: float
model_std_quality: float
model_fit_generalization: float
model_std_generalization: float
model_fit_quality: Optional[float]
model_std_quality: Optional[float]
model_fit_generalization: Optional[float]
model_std_generalization: Optional[float]

@classmethod
def from_ax_client(cls, ax_client: AxClient) -> AxClientCompletedRecord:
def from_ax_client(
cls, ax_client: AxClient, completed_trial: BaseTrial
) -> AxClientCompletedRecord:
return cls(
experiment_completed_record=ExperimentCompletedRecord.from_experiment(
experiment=ax_client.experiment
),
best_point_quality=float("nan"), # TODO[T147907632]
model_fit_quality=float("nan"), # TODO[T147907632]
model_std_quality=float("nan"),
model_fit_generalization=float("nan"),
model_std_generalization=float("nan"),
**cls._get_model_fit_data_for_trial(completed_trial=completed_trial),
)

@staticmethod
def _get_model_fit_data_for_trial(
completed_trial: BaseTrial,
) -> Dict[str, Optional[float]]:
"""Get model fit quality data for a completed trial. This method assumes that
there is only one generator run on the trial from a model for which fit data
is applicable. If there are multiple, it will use the first."""
fit_and_std_dict = {
"model_fit_quality": None, # TODO[T147907632]
"model_std_quality": None,
"model_fit_generalization": None,
"model_std_generalization": None,
}
empty_results = (float("nan"), None)
for gr in completed_trial.generator_runs:
gen_metadata = {} if gr.gen_metadata is None else gr.gen_metadata
if (
gen_metadata.get("model_fit_quality") not in empty_results
or gen_metadata.get("model_std_quality") not in empty_results
or gen_metadata.get("model_fit_generalization") not in empty_results
or gen_metadata.get("model_std_generalization") not in empty_results
):
fit_and_std_dict = {
"model_fit_quality": gen_metadata.get("model_fit_quality"),
"model_std_quality": gen_metadata.get("model_std_quality"),
"model_fit_generalization": gen_metadata.get(
"model_fit_generalization"
),
"model_std_generalization": gen_metadata.get(
"model_std_generalization"
),
}
break
return fit_and_std_dict

def flatten(self) -> Dict[str, Any]:
"""
Flatten into an appropriate format for logging to a tabular database.
Expand Down
6 changes: 5 additions & 1 deletion ax/telemetry/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union

from ax.core.base_trial import BaseTrial

from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy

Expand Down Expand Up @@ -514,9 +516,11 @@ def from_ax_client(
deployed_job_id: Optional[int],
estimated_early_stopping_savings: float,
estimated_global_stopping_savings: float,
completed_trial: BaseTrial,
) -> OptimizationCompletedRecord:
ax_client_completed_record = AxClientCompletedRecord.from_ax_client(
ax_client=ax_client
ax_client=ax_client,
completed_trial=completed_trial,
)
experiment_completed_record = (
ax_client_completed_record.experiment_completed_record
Expand Down
96 changes: 88 additions & 8 deletions ax/telemetry/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict

from typing import Dict, List, Sequence, Union
from unittest.mock import ANY

import numpy as np

Expand All @@ -17,6 +18,7 @@
from ax.telemetry.experiment import ExperimentCompletedRecord, ExperimentCreatedRecord
from ax.telemetry.generation_strategy import GenerationStrategyCreatedRecord
from ax.utils.common.testutils import TestCase
from ax.utils.measurement.synthetic_functions import branin


class TestAxClient(TestCase):
Expand Down Expand Up @@ -116,21 +118,97 @@ def test_ax_client_completed_record_from_ax_client(self) -> None:
objectives={"branin": ObjectiveProperties(minimize=True)},
is_test=True,
)
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(
trial_index=idx,
# pyre-ignore[6]: branin does return a float
raw_data={"branin": branin(params["x"], params["y"])},
)

record = AxClientCompletedRecord.from_ax_client(ax_client=ax_client)
record = AxClientCompletedRecord.from_ax_client(
ax_client=ax_client,
completed_trial=ax_client.get_trial(0),
)

expected = AxClientCompletedRecord(
experiment_completed_record=ExperimentCompletedRecord.from_experiment(
experiment=ax_client.experiment
),
best_point_quality=float("nan"),
model_fit_quality=float("nan"),
model_std_quality=float("nan"),
model_fit_generalization=float("nan"),
model_std_generalization=float("nan"),
model_fit_quality=None,
model_std_quality=None,
model_fit_generalization=None,
model_std_generalization=None,
)
self._compare_axclient_completed_records(record, expected)

def test_ax_client_completed_record_from_ax_client_for_model_that_fits(
self,
) -> None:
num_sobol_trials = 5
ax_client = AxClient()
ax_client.create_experiment(
name="test_experiment",
parameters=[
{"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
{"name": "y", "type": "range", "bounds": [0.0, 15.0]},
],
objectives={"branin": ObjectiveProperties(minimize=True)},
is_test=True,
choose_generation_strategy_kwargs={
"num_initialization_trials": num_sobol_trials
},
)
for _ in range(num_sobol_trials + 1):
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(
trial_index=idx,
# pyre-ignore[6]: branin does return a float
raw_data={"branin": branin(params["x"], params["y"])},
)

with self.subTest("sobol trial"):
record = AxClientCompletedRecord.from_ax_client(
ax_client=ax_client,
# the last trial is not sobol so the fit can be evaluated
completed_trial=ax_client.get_trial(num_sobol_trials - 1),
)

expected = AxClientCompletedRecord(
experiment_completed_record=ExperimentCompletedRecord.from_experiment(
experiment=ax_client.experiment
),
best_point_quality=float("nan"),
model_fit_quality=None,
model_std_quality=None,
model_fit_generalization=None,
model_std_generalization=None,
)
self._compare_axclient_completed_records(record, expected)

with self.subTest("non sobol trial"):
record = AxClientCompletedRecord.from_ax_client(
ax_client=ax_client,
# the last trial is not sobol so the fit can be evaluated
completed_trial=ax_client.get_trial(num_sobol_trials),
)

expected = AxClientCompletedRecord(
experiment_completed_record=ExperimentCompletedRecord.from_experiment(
experiment=ax_client.experiment
),
best_point_quality=float("nan"),
model_fit_quality=ANY,
model_std_quality=ANY,
model_fit_generalization=ANY,
model_std_generalization=ANY,
)
self._compare_axclient_completed_records(record, expected)
self.assertIsNotNone(record.model_fit_quality)
self.assertIsNotNone(record.model_std_quality)
self.assertIsNotNone(record.model_fit_generalization)
self.assertIsNotNone(record.model_std_generalization)

def test_batch_trial_warning(self) -> None:
ax_client = AxClient()
error_msg = (
Expand Down Expand Up @@ -166,7 +244,9 @@ def _compare_axclient_completed_records(
for field in numeric_fields:
rec_field = getattr(record, field)
exp_field = getattr(expected, field)
if np.isnan(rec_field):
self.assertTrue(np.isnan(exp_field))
if rec_field is None:
self.assertIsNone(exp_field, msg=field)
elif np.isnan(rec_field):
self.assertTrue(np.isnan(exp_field), msg=field)
else:
self.assertAlmostEqual(rec_field, exp_field)
self.assertAlmostEqual(rec_field, exp_field, msg=field)
11 changes: 10 additions & 1 deletion ax/telemetry/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,25 @@ def test_optimization_completed_record_from_ax_client(self) -> None:
objectives={"branin": ObjectiveProperties(minimize=True)},
is_test=True,
)
ax_client.get_next_trial()
ax_client.complete_trial(
trial_index=0,
raw_data={"branin": 10.0},
)
completed_trial = ax_client.get_trial(0)

record = OptimizationCompletedRecord.from_ax_client(
ax_client=ax_client,
unique_identifier="foo",
deployed_job_id=1118,
estimated_early_stopping_savings=19,
estimated_global_stopping_savings=98,
completed_trial=completed_trial,
)
expected_dict = {
**AxClientCompletedRecord.from_ax_client(ax_client=ax_client).flatten(),
**AxClientCompletedRecord.from_ax_client(
ax_client=ax_client, completed_trial=completed_trial
).flatten(),
"unique_identifier": "foo",
"deployed_job_id": 1118,
"estimated_early_stopping_savings": 19,
Expand Down

0 comments on commit 3467b27

Please sign in to comment.