Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Python/C++ interop via exposed JIT functionality #2214

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/Modules/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ set(CONFIG_FILES
CUDAQConfig.cmake
CUDAQEnsmallenConfig.cmake
CUDAQPlatformDefaultConfig.cmake
CUDAQPythonInteropConfig.cmake
)
set(LANG_FILES
CMakeCUDAQCompiler.cmake.in
Expand Down
3 changes: 3 additions & 0 deletions cmake/Modules/CUDAQConfig.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ find_dependency(CUDAQNlopt REQUIRED)
set (CUDAQEnsmallen_DIR "${CUDAQ_CMAKE_DIR}")
find_dependency(CUDAQEnsmallen REQUIRED)

set (CUDAQPythonInterop_DIR "${CUDAQ_CMAKE_DIR}")
find_dependency(CUDAQPythonInterop REQUIRED)

get_filename_component(PARENT_DIRECTORY ${CUDAQ_CMAKE_DIR} DIRECTORY)
get_filename_component(CUDAQ_LIBRARY_DIR ${PARENT_DIRECTORY} DIRECTORY)
get_filename_component(CUDAQ_INSTALL_DIR ${CUDAQ_LIBRARY_DIR} DIRECTORY)
Expand Down
13 changes: 13 additions & 0 deletions cmake/Modules/CUDAQPythonInteropConfig.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# ============================================================================ #
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

# Directory that contains this config file. The exported targets file is
# included from the same directory below.
get_filename_component(CUDAQ_PYTHONINTEROP_CMAKE_DIR
                       "${CMAKE_CURRENT_LIST_FILE}" DIRECTORY)

# Pull in the exported targets only if cudaq::cudaq-python-interop has not
# already been defined (e.g. by an earlier include of this config file).
if(NOT TARGET cudaq::cudaq-python-interop)
  include("${CUDAQ_PYTHONINTEROP_CMAKE_DIR}/CUDAQPythonInteropTargets.cmake")
endif()
5 changes: 3 additions & 2 deletions include/cudaq/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ std::unique_ptr<mlir::Pass> createRaiseToAffinePass();
std::unique_ptr<mlir::Pass> createUnwindLoweringPass();

std::unique_ptr<mlir::Pass>
createPySynthCallableBlockArgs(const std::vector<std::string> &);
createPySynthCallableBlockArgs(const std::vector<std::string> &,
bool removeBlockArg = false);
inline std::unique_ptr<mlir::Pass> createPySynthCallableBlockArgs() {
return createPySynthCallableBlockArgs({});
return createPySynthCallableBlockArgs({}, false);
}

/// Helper function to build an argument synthesis pass. The names of the
Expand Down
61 changes: 55 additions & 6 deletions lib/Optimizer/Transforms/PySynthCallableBlockArgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,41 @@ class ReplaceCallIndirect : public OpConversionPattern<func::CallIndirectOp> {
}
};

/// Conversion pattern that rewrites a `cudaq::cc::CallCallableOp` whose callee
/// is a block argument into a direct `func::CallOp`.
///
/// The block-argument number is looked up in `blockArgToNameMap` to obtain an
/// index into `names`; the call is then redirected to the `func::FuncOp` in
/// the enclosing module whose symbol is `__nvqpp__mlirgen__` + that name.
/// Both containers are held by reference, so they must outlive the pattern.
class ReplaceCallCallable
    : public OpConversionPattern<cudaq::cc::CallCallableOp> {
public:
  /// Candidate kernel names (without the `__nvqpp__mlirgen__` prefix).
  const std::vector<std::string> &names;
  /// Maps a block-argument number to an index into `names`.
  const std::map<std::size_t, std::size_t> &blockArgToNameMap;

  ReplaceCallCallable(MLIRContext *ctx,
                      const std::vector<std::string> &functionNames,
                      const std::map<std::size_t, std::size_t> &map)
      : OpConversionPattern<cudaq::cc::CallCallableOp>(ctx),
        names(functionNames), blockArgToNameMap(map) {}

  /// Returns success after replacing `op` with a direct call; returns failure
  /// when the callee is not a block argument or when the replacement function
  /// is not present in the module (an error is emitted on `op` in that case).
  LogicalResult
  matchAndRewrite(cudaq::cc::CallCallableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto calleeArg = dyn_cast<BlockArgument>(adaptor.getCallee());
    auto parentModule = op->getParentOp()->getParentOfType<ModuleOp>();

    // Only callables that come in as block arguments can be synthesized.
    if (!calleeArg)
      return failure();

    auto kernelName = names[blockArgToNameMap.at(calleeArg.getArgNumber())];
    auto target = parentModule.lookupSymbol<func::FuncOp>(
        "__nvqpp__mlirgen__" + kernelName);
    if (!target) {
      op.emitError("Invalid replacement function " + kernelName);
      return failure();
    }

    rewriter.replaceOpWithNewOp<func::CallOp>(op, target, adaptor.getArgs());
    return success();
  }
};

class UpdateQuakeApplyOp : public OpConversionPattern<quake::ApplyOp> {
public:
const std::vector<std::string> &names;
Expand Down Expand Up @@ -97,10 +132,13 @@ class UpdateQuakeApplyOp : public OpConversionPattern<quake::ApplyOp> {
class PySynthCallableBlockArgs
: public cudaq::opt::PySynthCallableBlockArgsBase<
PySynthCallableBlockArgs> {
private:
bool removeBlockArg = false;

public:
std::vector<std::string> names;
PySynthCallableBlockArgs(const std::vector<std::string> &_names)
: names(_names) {}
PySynthCallableBlockArgs(const std::vector<std::string> &_names, bool remove)
: removeBlockArg(remove), names(_names) {}

void runOnOperation() override {
auto op = getOperation();
Expand Down Expand Up @@ -129,8 +167,9 @@ class PySynthCallableBlockArgs
return;
}

patterns.insert<ReplaceCallIndirect, UpdateQuakeApplyOp>(
ctx, names, blockArgToNamesMap);
patterns
.insert<ReplaceCallIndirect, ReplaceCallCallable, UpdateQuakeApplyOp>(
ctx, names, blockArgToNamesMap);
ConversionTarget target(*ctx);
// We should remove these operations
target.addIllegalOp<func::CallIndirectOp>();
Expand All @@ -148,11 +187,21 @@ class PySynthCallableBlockArgs
"error synthesizing callable functions for python.\n");
signalPassFailure();
}

if (removeBlockArg) {
auto numArgs = op.getNumArguments();
BitVector argsToErase(numArgs);
for (std::size_t argIndex = 0; argIndex < numArgs; ++argIndex)
if (isa<cudaq::cc::CallableType>(op.getArgument(argIndex).getType()))
argsToErase.set(argIndex);

op.eraseArguments(argsToErase);
}
}
};
} // namespace

std::unique_ptr<Pass> cudaq::opt::createPySynthCallableBlockArgs(
const std::vector<std::string> &names) {
return std::make_unique<PySynthCallableBlockArgs>(names);
const std::vector<std::string> &names, bool removeBlockArg) {
return std::make_unique<PySynthCallableBlockArgs>(names, removeBlockArg);
}
2 changes: 2 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,5 @@ if(CUDAQ_BUILD_TESTS)
message(FATAL_ERROR "CUDA Quantum Python Warning - CUDAQ_BUILD_TESTS=TRUE but can't find numpy or pytest modules required for testing.")
endif()
endif()

add_subdirectory(runtime/interop)
26 changes: 26 additions & 0 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,32 @@ def visit_Call(self, node):
self.__insertDbgStmt(self.popValue(), node.func.attr)
return

# Handle the case of `mod.func`, where mod is not cudaq.
if isinstance(node.func, ast.Attribute) and isinstance(
node.func.value,
ast.Name) and node.func.value.id != 'cudaq':
# This could be a C++ generated kernel,
# if so we should get it and add it to
# the module
maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel(
self.module, node.func.value.id + '.' + node.func.attr)
if maybeKernelName != None:
otherKernel = SymbolTable(
self.module.operation)[maybeKernelName]
fType = otherKernel.type
if len(fType.inputs) != len(node.args):
funcName = node.func.id if hasattr(
node.func, 'id') else node.func.attr
self.emitFatalError(
f"invalid number of arguments passed to callable {funcName} ({len(node.args)} vs required {len(fType.inputs)})",
node)

[self.visit(arg) for arg in node.args]
values = [self.popValue() for _ in node.args]
values.reverse()
func.CallOp(otherKernel, values)
return

# If we did have module names, then this is what we are looking for
if len(moduleNames):
name = node.func.attr
Expand Down
45 changes: 44 additions & 1 deletion python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Callable
from ..mlir.ir import *
from ..mlir.passmanager import *
from ..mlir.dialects import quake, cc
from ..mlir.dialects import quake, cc, func
from .ast_bridge import compile_to_mlir, PyASTBridge
from .utils import mlirTypeFromPyType, nvqppPrefix, mlirTypeToPyType, globalAstRegistry, emitFatalError, emitErrorIfInvalidPauli, globalRegisteredTypes
from .analysis import MidCircuitMeasurementAnalyzer, HasReturnNodeVisitor
Expand Down Expand Up @@ -220,6 +220,49 @@ def compile(self):
self.dependentCaptures = extraMetadata[
'dependent_captures'] if 'dependent_captures' in extraMetadata else None

def merge_kernel(self, otherMod):
    """
    Compile this kernel, merge its ModuleOp with `otherMod`, and return a
    new PyKernelDecorator wrapping the merged module.

    `otherMod` may be a string or any object whose `str()` is the MLIR
    text to merge. The returned decorator is named after a `FuncOp` in the
    merged module that carries a `cudaq-entrypoint` attribute (with the
    `nvqppPrefix` removed); if several do, the last one in module order
    wins, and if none does, this kernel's own name is used.
    """
    self.compile()
    otherModStr = otherMod if isinstance(otherMod, str) else str(otherMod)
    mergedMod = cudaq_runtime.mergeExternalMLIR(self.module, otherModStr)

    # Default to this kernel's name unless the merged module says otherwise.
    entryName = self.name
    for candidate in mergedMod.body:
        if not isinstance(candidate, func.FuncOp):
            continue
        if any('cudaq-entrypoint' == a.name for a in candidate.attributes):
            entryName = candidate.name.value.replace(nvqppPrefix, '')

    return PyKernelDecorator(None, kernelName=entryName, module=mergedMod)

def synthesize_callable_arguments(self, funcNames):
    """
    Replace this kernel's callable block arguments with the in-module
    FuncOps named in `funcNames`.

    Names are matched positionally: `funcNames[0]` is used for the first
    callable block argument, `funcNames[1]` for the second, and so on.
    The kernel is compiled first, the module is rewritten in place, and
    every callable type is then dropped from `self.argTypes`.
    """
    self.compile()
    cudaq_runtime.synthPyCallable(self.module, funcNames)

    # The callable arguments no longer exist, so stop advertising them.
    remainingTypes = []
    for argType in self.argTypes:
        if cc.CallableType.isinstance(argType):
            continue
        remainingTypes.append(argType)
    self.argTypes = remainingTypes

def extract_c_function_pointer(self, name=None):
    """
    Compile this kernel, JIT its module, and return the C function pointer
    for the function called `name`.

    When `name` is None the lookup uses this kernel's own name prefixed
    with `nvqppPrefix`; an explicit `name` is passed through unchanged.
    """
    self.compile()
    if name is None:
        name = nvqppPrefix + self.name
    return cudaq_runtime.jitAndGetFunctionPointer(self.module, name)

def __str__(self):
"""
Return the MLIR Module string representation for this kernel.
Expand Down
6 changes: 6 additions & 0 deletions python/cudaq/kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,12 @@ def mlirTypeToPyType(argType):
if F32Type.isinstance(argType):
return np.float32

if quake.VeqType.isinstance(argType):
return qvector

if cc.CallableType.isinstance(argType):
return Callable

if ComplexType.isinstance(argType):
if F64Type.isinstance(ComplexType(argType).element_type):
return complex
Expand Down
1 change: 1 addition & 0 deletions python/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ target_link_libraries(CUDAQuantumPythonSources.Extension INTERFACE
cudaq-common
cudaq-em-default
cudaq-em-photonics
cudaq-python-interop
fmt::fmt-header-only
unzip_util
)
Expand Down
32 changes: 32 additions & 0 deletions python/extension/CUDAQuantumExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@
#include "utils/OpaqueArguments.h"

#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"

#include "runtime/interop/PythonCppInterop.h"

#include <pybind11/complex.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
Expand Down Expand Up @@ -221,4 +224,33 @@ PYBIND11_MODULE(_quakeDialects, m) {
cudaqRuntime.def("isTerminator", [](MlirOperation op) {
return unwrap(op)->hasTrait<mlir::OpTrait::IsTerminator>();
});

cudaqRuntime.def(
    "checkRegisteredCppDeviceKernel",
    [](MlirModule mod,
       const std::string &moduleName) -> std::optional<std::string> {
      // Look up the (kernel name, MLIR code) pair registered under
      // `moduleName`. Any exception from the lookup is reported to the
      // caller as "no such kernel" rather than as an error.
      std::tuple<std::string, std::string> ret;
      try {
        ret = cudaq::getDeviceKernel(moduleName);
      } catch (...) {
        return std::nullopt;
      }

      // Take the code for the kernel we found and add it to the input
      // module. Bind by reference to avoid copying both strings.
      const auto &[kName, code] = ret;
      auto moduleA = unwrap(mod);
      auto ctx = moduleA.getContext();
      auto moduleB = mlir::parseSourceString<ModuleOp>(code, ctx);
      // parseSourceString yields a null module on failure; dereferencing it
      // below would crash the interpreter, so surface the problem instead.
      if (!moduleB)
        throw std::runtime_error(
            "failed to parse the MLIR code registered for C++ device "
            "kernel '" +
            moduleName + "'.");

      // Copy over every function that the input module does not already
      // define; existing symbols are left untouched.
      moduleB->walk([&moduleA](func::FuncOp op) {
        if (!moduleA.lookupSymbol<func::FuncOp>(op.getName()))
          moduleA.push_back(op.clone());
        return WalkResult::advance();
      });
      return kName;
    },
    "Given a python module name like `mod1.mod2.func`, see if there is a "
    "registered C++ quantum kernel. If so, add the kernel to the Module and "
    "return its name.");
}
66 changes: 66 additions & 0 deletions python/runtime/cudaq/platform/py_alt_launch_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include <fmt/core.h>
Expand Down Expand Up @@ -746,5 +747,70 @@ void bindAltLaunchKernel(py::module &mod) {
}
},
"Remove our pointers to the cudaq states.");

mod.def(
    "mergeExternalMLIR",
    [](MlirModule modA, const std::string &modBStr) {
      auto ctx = unwrap(modA).getContext();
      auto moduleB = parseSourceString<ModuleOp>(modBStr, ctx);
      // parseSourceString yields a null module when the text cannot be
      // parsed. Check before cloning so that nothing is leaked and the
      // caller gets an exception instead of a null dereference.
      if (!moduleB)
        throw std::runtime_error(
            "mergeExternalMLIR failed to parse the provided MLIR module.");

      // Work on a clone so that the input module is left unmodified.
      auto moduleA = unwrap(modA).clone();
      // Copy over every function that modA does not already define;
      // symbols present in both modules keep modA's definition.
      moduleB->walk([&moduleA](func::FuncOp op) {
        if (!moduleA.lookupSymbol<func::FuncOp>(op.getName()))
          moduleA.push_back(op.clone());
        return WalkResult::advance();
      });
      return wrap(moduleA);
    },
    "Merge the two Modules into a single Module.");

mod.def(
    "synthPyCallable",
    [](MlirModule modA, const std::vector<std::string> &funcNames) {
      auto module = unwrap(modA);
      auto ctx = module.getContext();

      // Run the callable-synthesis pass on every function, asking it to
      // also erase the callable block arguments.
      PassManager passManager(ctx);
      passManager.addNestedPass<func::FuncOp>(
          cudaq::opt::createPySynthCallableBlockArgs(funcNames, true));
      if (failed(passManager.run(module)))
        throw std::runtime_error(
            "cudaq::jit failed to remove callable block arguments.");

      // Fix up the mangled name map: build a single-entry dictionary for
      // the first function found that is tagged as an entry point.
      DictionaryAttr mangledNameMap;
      module.walk([&](func::FuncOp funcOp) {
        if (!funcOp->hasAttrOfType<UnitAttr>("cudaq-entrypoint"))
          return WalkResult::advance();

        auto rewriteName = StringAttr::get(
            ctx, funcOp.getName().str() + "_PyKernelEntryPointRewrite");
        auto kernelName = StringAttr::get(ctx, funcOp.getName());
        mangledNameMap =
            DictionaryAttr::get(ctx, {NamedAttribute(kernelName, rewriteName)});
        return WalkResult::interrupt();
      });

      // Leave the module attribute alone when no entry point was found.
      if (mangledNameMap)
        module->setAttr("quake.mangled_name_map", mangledNameMap);
    },
    "Synthesize away the callable block argument from the entrypoint in modA "
    "with the FuncOp of given name.");

mod.def(
    "jitAndGetFunctionPointer",
    [](MlirModule mod, const std::string &funcName) {
      // JIT the module with no runtime arguments and no return type.
      OpaqueArguments runtimeArgs;
      auto noneType = mlir::NoneType::get(unwrap(mod).getContext());
      // NOTE(review): rawArgs/size/returnOffset are unused here; confirm
      // whether jitAndCreateArgs expects the caller to release rawArgs.
      auto [jit, rawArgs, size, returnOffset] =
          jitAndCreateArgs(funcName, mod, runtimeArgs, {}, noneType);

      auto funcPtr = jit->lookup(funcName);
      if (!funcPtr) {
        // The failed lookup carries an llvm::Error that must be consumed
        // before it is destroyed; fold its text into the exception so the
        // caller sees which symbol was missing and why.
        throw std::runtime_error("cudaq::jit failed to look up function '" +
                                 funcName + "': " +
                                 llvm::toString(funcPtr.takeError()));
      }

      return py::capsule(*funcPtr);
    },
    "JIT compile and return the C function pointer for the FuncOp of given "
    "name.");
}
} // namespace cudaq
Loading
Loading