CDAT Migration Phase 1: Replace cdp.cdp_run (#641)
Hi @chengzhuzhang, this PR is ready for review at your convenience (no rush!). The build with the integration tests is passing.
I am just tackling some low-hanging fruit in the CDAT migration effort before 2023.
e3sm_diags/e3sm_diags_driver.py
Outdated
def _run_serially(parameters: List[CoreParameter]) -> List[CoreParameter]:
    """Run diagnostics with the parameters serially.

    Parameters
    ----------
    parameters : List[CoreParameter]
        The list of CoreParameter objects to run diagnostics on.

    Returns
    -------
    List[CoreParameter]
        The list of CoreParameter objects with results from the diagnostic runs.
    """
    results = []

    for p in parameters:
        results.append(run_diag(p))

    return results
Based on `cdp.cdp_run.serial`.
This function has been refactored in the latest commit.
e3sm_diags/e3sm_diags_driver.py
Outdated
def _run_with_dask_multiprocessing(
    parameters: List[CoreParameter],
    num_workers: Optional[int] = None,
) -> List[CoreParameter]:
    """Run diagnostics with the parameters in parallel using Dask.

    This function passes ``run_diag`` to ``dask.bag.map``, which gets computed
    with ``.compute``.

    Parameters
    ----------
    parameters : List[CoreParameter]
        The list of CoreParameter objects to run diagnostics on.
    num_workers : Optional[int], optional
        The number of workers for multiprocessing, by default None.

    Returns
    -------
    List[CoreParameter]
        The list of CoreParameter objects with results from the diagnostic runs.

    Notes
    -----
    https://docs.dask.org/en/stable/generated/dask.bag.map.html
    https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.compute.html
    """
    bag = dask.bag.from_sequence(parameters)

    config = {"scheduler": "processes", "context": "fork"}
    with dask.config.set(config):
        if num_workers:
            results = bag.map(run_diag).compute(num_workers=num_workers)
        elif hasattr(parameters[0], "num_workers"):
            results = bag.map(run_diag).compute(num_workers=parameters[0].num_workers)
        else:
            # The number of workers defaults to the number of logical processors.
            results = bag.map(run_diag).compute()

    return results
Based on `cdp.cdp_run.multiprocess()`.
I refactored this function to remove the `func` and `context` parameters because they were no longer needed:
- `func` is simply the `run_diag` function
- `context` has always been `"fork"`
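For reference, here is a minimal, self-contained sketch of the `dask.bag` pattern the refactored function uses (a standalone toy example, not code from this PR; the driver above additionally pins the multiprocessing start method to fork):

```python
import dask
import dask.bag


def double(x: int) -> int:
    """Stand-in for ``run_diag``; any function usable with multiprocessing works."""
    return x * 2


# Build a bag from a sequence and map the function over it, mirroring
# how `run_diag` is mapped over the list of parameters.
bag = dask.bag.from_sequence([1, 2, 3, 4])

# Use the multiprocessing ("processes") scheduler, as the driver does.
with dask.config.set({"scheduler": "processes"}):
    results = bag.map(double).compute(num_workers=2)

print(results)  # [2, 4, 6, 8]
```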
This function has been refactored in the latest commit.
e3sm_diags/e3sm_diags_driver.py
Outdated
 if parameters[0].multiprocessing:
-    parameters = cdp.cdp_run.multiprocess(run_diag, parameters, context="fork")
-elif parameters[0].distributed:
-    parameters = cdp.cdp_run.distribute(run_diag, parameters)
+    parameters = _run_with_dask_multiprocessing(parameters)
 else:
-    parameters = cdp.cdp_run.serial(run_diag, parameters)
+    parameters = _run_serially(parameters)

 parameters = _collapse_results(parameters)
Replaced the `cdp_run` function calls.
I also removed the distributed option because it seems like it never actually worked and nobody used it.
This run method requires a local Dask cluster setup and a scheduler address (the `scheduler_addr` param) passed into `cdp_run.distribute()`. The `scheduler_addr` is not set anywhere in the codebase, so this could never work.
More context:
If that arg wasn't used but this error still appears, then I'm not really sure what's up. If you've noticed, since users only run e3sm_diags either in serial or on a single machine/node with multiprocessing, I don't think dask is even needed. It was only selected since we thought we'd run stuff distributedly.
distributed: Set to True to run the diagnostics distributedly. It's False by default. multiprocessing and distributed cannot both be set to True. A Dask cluster needs to be up and running. You'll probably never use this.
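For illustration, connecting to an existing cluster with `dask.distributed` would look roughly like the sketch below; the scheduler address is a hypothetical placeholder, which is exactly the piece (`scheduler_addr`) that was never supplied anywhere in the codebase:

```python
from dask.distributed import Client

# Hypothetical address of an already-running Dask scheduler. Nothing in
# e3sm_diags ever set this value, so the distributed path could not work.
client = Client("tcp://127.0.0.1:8786")

# Map `run_diag` over the parameters on the cluster and gather the results
# (`run_diag` and `parameters` as in the driver above).
futures = client.map(run_diag, parameters)
results = client.gather(futures)
```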
This logic has been refactored in the latest commit.
def create_parameter_dict(parameters):
    d: Dict[type, int] = dict()
    for parameter in parameters:
        t = type(parameter)
        if t in d.keys():
            d[t] += 1
        else:
            d[t] = 1
    return d
`create_parameter_dict()` is not new. I just reordered functions.
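As an aside, the same type counting could be written more compactly with `collections.Counter` (a sketch of an equivalent, not a change made in this PR):

```python
from collections import Counter
from typing import Dict, List


def create_parameter_dict(parameters: List) -> Dict[type, int]:
    # Count how many parameter objects of each type were passed in.
    return dict(Counter(type(p) for p in parameters))
```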
Hey @tomvothecoder, thank you for starting work on this! A little background on the "running distributed" option: it was mostly some exploration done by Zeshawn and was never used in operation. Here is a wiki page for it: https://github.com/E3SM-Project/e3sm_diags/wiki/Running-diagnostics-distributedly
My approach will require a major rewrite of MPAS-Analysis. No idea what the implications for E3SM_Diags might be, but perhaps similar. Still, I think dask also has significant limitations (not cross-node, not task parallel?).
conda-env/dev.yml
Outdated
- cartopy=0.21.1
- cartopy_offlinedata=0.21.1
Updated the `cartopy` dependencies because the integration tests were failing in my local development environment due to SciTools/cartopy#2086.
for set_name in parameter.sets:
    # FIXME: The parameter and driver for a diagnostic should be mapped
    # together. If this is done, the `run_diag` function should be
    # accessible by the parameter class without needing to perform a static
    # string reference for the module name.
    parameter.current_set = set_name
    mod_str = "e3sm_diags.driver.{}_driver".format(set_name)

    # Check if there is a matching driver module for the `set_name`.
    try:
        module = importlib.import_module(mod_str)
    except ModuleNotFoundError as e:
        logger.error(f"Error with set name {set_name}", exc_info=e)
        continue

    # If the module exists, call the driver module's `run_diag` function.
    try:
        single_result = module.run_diag(parameter)
        results.append(single_result)
    except Exception:
        logger.exception(f"Error in {mod_str}", exc_info=True)

        if parameter.debug:
            sys.exit()

return results
The original `try` and `except` statement did not catch `ModuleNotFoundError` if a matching driver module could not be found for the `set_name`.
To handle this, I refactored the logic into two steps: (1) try to import the module, then (2) try to call the module's `run_diag` function.
Side-note: `except Exception` should be avoided because it is a catch-all for every possible exception, which makes debugging difficult. More specific exceptions should be used instead.
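For example, a narrower handler might look like the sketch below, where the exception types are hypothetical stand-ins for failure modes a driver is actually known to raise:

```python
try:
    single_result = module.run_diag(parameter)
    results.append(single_result)
except (FileNotFoundError, ValueError) as e:
    # Catch only the failures we expect and can report meaningfully;
    # anything else propagates with a full traceback.
    logger.error(f"Error in {mod_str}", exc_info=e)
```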
Force-pushed from a4258b2 to 56bea35.
@xylar I personally think that parallelization is not too critical for e3sm_diags right now, since the remapping work is off-loaded to …
Looking good!
- Port `serial` and `multiprocess` from `cdp_run`
- Remove logic for distributed because it never actually worked

Refactor `e3sm_diags_driver.py` functions:
- Add comments to `main()`
- Move function order based on callstack
- Move `_collapse_results()` call to `_run_serially()` and `run_with_dask_multiprocess()`
- Add docstrings and type annotations
- Rename `parameters` to `parameter_results` in `main()`

Add `pytest` and `pytest-cov`:
- Update `test.sh` to use `pytest`
- Add `pytest` config in `setup.cfg`
Force-pushed from 56bea35 to a0984aa.
- Rename to `_run_with_dask`
- Remove `num_workers` arg
Closes #426

This PR removes the `cdp.cdp_run` dependency, which mainly just involved porting code over from this module and doing some refactoring.

Summary of Changes:
- Port `serial` and `multiprocess` from `cdp.cdp_run`
- Add `dask` as a dependency in the conda env yml files

TODO:
- `_run_serially()`
- `_run_with_dask_multiprocessing`
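As a sketch of how the new `pytest` setup could cover the TODO items above (a hypothetical test, assuming the usual import paths and that `run_diag` can be patched at the driver module level; the actual tests may differ after later refactors):

```python
from unittest import mock

from e3sm_diags.e3sm_diags_driver import _run_serially
from e3sm_diags.parameter.core_parameter import CoreParameter


def test_run_serially_returns_a_result_per_parameter():
    params = [CoreParameter(), CoreParameter()]

    # Patch `run_diag` so the test does not execute real diagnostics;
    # each parameter is simply echoed back as its own "result".
    with mock.patch(
        "e3sm_diags.e3sm_diags_driver.run_diag", side_effect=lambda p: p
    ):
        results = _run_serially(params)

    assert results == params
```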