Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unnecessary serialisation of PassManager in serial contexts (backport #12410) #12500

Merged
merged 3 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions qiskit/passmanager/passmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import dill

from qiskit.utils.parallel import parallel_map
from qiskit.utils.parallel import parallel_map, should_run_in_parallel
from .base_tasks import Task, PassManagerIR
from .exceptions import PassManagerError
from .flow_controllers import FlowControllerLinear
Expand Down Expand Up @@ -220,16 +220,16 @@ def callback_func(**kwargs):
in_programs = [in_programs]
is_list = False

if len(in_programs) == 1:
out_program = _run_workflow(
program=in_programs[0],
pass_manager=self,
callback=callback,
**kwargs,
)
if is_list:
return [out_program]
return out_program
# If we're not going to run in parallel, we want to avoid spending time `dill` serialising
# ourselves, since that can be quite expensive.
if len(in_programs) == 1 or not should_run_in_parallel():
out = [
_run_workflow(program=program, pass_manager=self, callback=callback, **kwargs)
for program in in_programs
]
if len(in_programs) == 1 and not is_list:
return out[0]
return out

del callback
del kwargs
Expand Down
4 changes: 2 additions & 2 deletions qiskit/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@
from .backend_utils import has_ibmq, has_aer
from .name_unnamed_args import name_args
from .algorithm_globals import algorithm_globals

from .parallel import parallel_map
from .parallel import parallel_map, should_run_in_parallel


__all__ = [
Expand All @@ -114,4 +113,5 @@
"is_main_process",
"apply_prefix",
"parallel_map",
"should_run_in_parallel",
]
38 changes: 24 additions & 14 deletions qiskit/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
from the multiprocessing library.
"""

from __future__ import annotations

import os
from concurrent.futures import ProcessPoolExecutor
import sys
Expand Down Expand Up @@ -103,6 +105,21 @@ def _task_wrapper(param):
return task(value, *task_args, **task_kwargs)


def should_run_in_parallel(num_processes: int | None = None) -> bool:
"""Return whether the current parallelisation configuration suggests that we should run things
like :func:`parallel_map` in parallel (``True``) or degrade to serial (``False``).

Args:
num_processes: the number of processes requested for use (if given).
"""
num_processes = CPU_COUNT if num_processes is None else num_processes
return (
num_processes > 1
and os.getenv("QISKIT_IN_PARALLEL", "FALSE") == "FALSE"
and CONFIG.get("parallel_enabled", PARALLEL_DEFAULT)
)


def parallel_map( # pylint: disable=dangerous-default-value
task, values, task_args=(), task_kwargs={}, num_processes=CPU_COUNT
):
Expand All @@ -112,21 +129,20 @@ def parallel_map( # pylint: disable=dangerous-default-value

result = [task(value, *task_args, **task_kwargs) for value in values]

On Windows this function defaults to a serial implementation to avoid the
overhead from spawning processes in Windows.
This will parallelise the results if the number of ``values`` is greater than one, and the
current system configuration permits parallelization.

Args:
task (func): Function that is to be called for each value in ``values``.
values (array_like): List or array of values for which the ``task``
function is to be evaluated.
values (array_like): List or array of values for which the ``task`` function is to be
evaluated.
task_args (list): Optional additional arguments to the ``task`` function.
task_kwargs (dict): Optional additional keyword argument to the ``task`` function.
num_processes (int): Number of processes to spawn.

Returns:
result: The result list contains the value of
``task(value, *task_args, **task_kwargs)`` for
each value in ``values``.
result: The result list contains the value of ``task(value, *task_args, **task_kwargs)`` for
each value in ``values``.

Raises:
QiskitError: If user interrupts via keyboard.
Expand Down Expand Up @@ -155,11 +171,7 @@ def _callback(_):
Publisher().publish("terra.parallel.done", nfinished[0])

# Run in parallel if not Win and not in parallel already
if (
num_processes > 1
and os.getenv("QISKIT_IN_PARALLEL") == "FALSE"
and CONFIG.get("parallel_enabled", PARALLEL_DEFAULT)
):
if should_run_in_parallel(num_processes):
os.environ["QISKIT_IN_PARALLEL"] = "TRUE"
try:
results = []
Expand All @@ -183,8 +195,6 @@ def _callback(_):
os.environ["QISKIT_IN_PARALLEL"] = "FALSE"
return results

# Cannot do parallel on Windows , if another parallel_map is running in parallel,
# or len(values) == 1.
results = []
for _, value in enumerate(values):
result = task(value, *task_args, **task_kwargs)
Expand Down
5 changes: 5 additions & 0 deletions releasenotes/notes/parallel-check-8186a8f074774a1f.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
:meth:`.PassManager.run` will no longer waste time serializing itself when given multiple inputs
if it is only going to work in serial.
Loading