Skip to content

Commit

Permalink
Add exp_pauli quantum instruction to the runtime and ASTBridge (#660)
Browse files Browse the repository at this point in the history
* Add exp_pauli quantum instruction to the runtime and ASTBridge

Signed-off-by: Alex McCaskey <[email protected]>

* Drop the cc::StringType for the moment.

Convert the headers to use ASCIIZ string literals for the final argument
to exp_pauli(). (Simplifies the AST presented to the bridge and falls
back to using code that handles string literals.)

Update LowerToQIR for the new types, fix bugs.

Fix up tests.

* build fixes, still have to fix the wheel validation

Signed-off-by: Alex McCaskey <[email protected]>

* cleanup, add qpp applyExpPauli impl

Signed-off-by: Alex McCaskey <[email protected]>

* fix docs gen

Signed-off-by: Alex McCaskey <[email protected]>

* clean up

Signed-off-by: Alex McCaskey <[email protected]>

* clean up, provide docs

Signed-off-by: Alex McCaskey <[email protected]>

* Update runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cu

Co-authored-by: Thien Nguyen <[email protected]>

---------

Signed-off-by: Alex McCaskey <[email protected]>
Co-authored-by: Eric Schweitz <[email protected]>
Co-authored-by: Thien Nguyen <[email protected]>
  • Loading branch information
3 people authored Sep 29, 2023
1 parent 8369b4e commit da6d0b9
Show file tree
Hide file tree
Showing 35 changed files with 842 additions and 438 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ _version.py

# third party integrations
simulators/
apps/

# macOS
.DS_Store
Expand Down
1 change: 1 addition & 0 deletions docs/sphinx/api/languages/python_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Program Construction
.. automethod:: rz
.. automethod:: r1
.. automethod:: swap
.. automethod:: exp_pauli
.. automethod:: mx
.. automethod:: my
.. automethod:: mz
Expand Down
1 change: 1 addition & 0 deletions docs/sphinx/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def setup(app):
('cpp:identifier', 'mlir::ImplicitLocOpBuilder'),
('cpp:identifier', 'BinarySymplecticForm'),
('cpp:identifier', 'CountsDictionary'),
('cpp:identifier', 'QuakeValueOrNumericType'),
('py:class', 'function'),
('py:class', 'type'),
('py:class', 'cudaq::spin_op'),
Expand Down
12 changes: 12 additions & 0 deletions include/cudaq/Frontend/nvqpp/ASTBridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -704,4 +704,16 @@ inline bool isCallOperator(clang::OverloadedOperatorKind kindValue) {
return kindValue == clang::OverloadedOperatorKind::OO_Call;
}

// Is \p t of type `char *`?
inline bool isCharPointerType(mlir::Type t) {
if (auto ptrTy = dyn_cast<cc::PointerType>(t)) {
mlir::Type eleTy = ptrTy.getElementType();
if (auto arrTy = dyn_cast<cc::ArrayType>(eleTy))
eleTy = arrTy.getElementType();
if (auto intTy = dyn_cast<mlir::IntegerType>(eleTy))
return intTy.getWidth() == 8;
}
return false;
}

} // namespace cudaq
18 changes: 18 additions & 0 deletions include/cudaq/Optimizer/Dialect/CC/CCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1369,4 +1369,22 @@ def cc_CallableClosureOp : CCOp<"callable_closure", [Pure]> {
}];
}

def cc_CreateStringLiteralOp : CCOp<"string_literal"> {
let summary = "Create a constant string literal.";
let description = [{
This operation creates a ASCIIZ string literal value. It's argument is a
constant MLIR String Attribute. The literal will have a null character
appended automatically.

```mlir
%0 = cc.string_literal "Quantum Computing" : !cc.ptr<!cc.array<i8 x 18>>
```
}];

let arguments = (ins StrAttr:$stringLiteral);
let results = (outs cc_PointerType:$result);
let assemblyFormat = [{
$stringLiteral `:` qualified(type(results)) attr-dict
}];
}
#endif // CUDAQ_OPTIMIZER_DIALECT_CC_OPS
19 changes: 18 additions & 1 deletion include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,23 @@ class TwoTargetOp<string mnemonic, list<Trait> traits = []> :
// Quantum operators (gates)
//===----------------------------------------------------------------------===//

def ExpPauliOp : QuakeOp<"exp_pauli", []> {
let summary = "General Pauli tensor product rotation";
let description = [{
This operation affects a general Pauli tensor product rotation on
the input qubits. The number of Pauli characters in the input Pauli word
string must equal the number of qubits in the veq. Mathematically, this operation
applies exp(i theta P) where P is a general Pauli tensor product.
}];

let arguments = (ins AnyFloat:$parameter, VeqType:$qubits, cc_PointerType:$pauli);
let results = (outs );

let assemblyFormat = [{
`(` $parameter `)` $qubits `,` $pauli `:` functional-type(operands, results) attr-dict
}];
}

def HOp : OneTargetOp<"h", [Hermitian]> {
let summary = "Hadamard operation";
let description = [{
Expand Down Expand Up @@ -815,7 +832,7 @@ def PhasedRxOp : QuakeOperator<"phased_rx",
Matrix representation:
```
PhasedRx(θ,φ) = | cos(θ/2) -iexp(-iφ) * sin(θ/2) |
| -iexp(iφ)) * sin(θ/2) cos(θ/2) |
| -iexp(iφ) * sin(θ/2) cos(θ/2) |
```

Circuit symbol:
Expand Down
7 changes: 4 additions & 3 deletions lib/Frontend/nvqpp/ConvertDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@ void QuakeBridgeVisitor::addArgumentSymbols(
// Transform pass-by-value arguments to stack slots.
auto loc = toLocation(argVal);
auto parmTy = entryBlock->getArgument(index).getType();
if (isa<cc::CallableType, cc::StdvecType, cc::ArrayType, cc::StructType,
LLVM::LLVMStructType, FunctionType, quake::RefType,
quake::VeqType>(parmTy)) {
if (isa<FunctionType, cc::ArrayType, cc::CallableType, cc::PointerType,
cc::StdvecType, cc::StructType, LLVM::LLVMStructType,
quake::ControlType, quake::RefType, quake::VeqType,
quake::WireType>(parmTy)) {
symbolTable.insert(name, entryBlock->getArgument(index));
} else {
auto stackSlot = builder.create<cc::AllocaOp>(loc, parmTy);
Expand Down
49 changes: 46 additions & 3 deletions lib/Frontend/nvqpp/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,47 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
isAdjoint = structTypeAsRecord->getName() == "adj";
}

if (funcName.equals("exp_pauli")) {
assert(args.size() > 2);
SmallVector<Value> processedArgs;
auto addTheString = [&](Value v) {
// The C-string argument (char*) may be loaded by an lvalue to rvalue
// cast. Here, we must pass the pointer and not the first character's
// value.
if (isCharPointerType(v.getType())) {
processedArgs.push_back(v);
} else if (auto load = v.getDefiningOp<cudaq::cc::LoadOp>()) {
processedArgs.push_back(load.getPtrvalue());
} else {
reportClangError(x, mangler, "could not determine string argument");
}
};
if (args.size() == 3 && isa<quake::VeqType>(args[1].getType())) {
// Have f64, veq, string
processedArgs.push_back(args[0]);
processedArgs.push_back(args[1]);
addTheString(args[2]);
} else {
// should have f64, string, qubits...
// need f64, veq, string, so process here

// add f64 value
processedArgs.push_back(args[0]);

// concat the qubits to a veq
SmallVector<Value> quantumArgs;
for (std::size_t i = 2; i < args.size(); i++)
quantumArgs.push_back(args[i]);
processedArgs.push_back(builder.create<quake::ConcatOp>(
loc, quake::VeqType::get(builder.getContext(), quantumArgs.size()),
quantumArgs));
addTheString(args[1]);
}

builder.create<quake::ExpPauliOp>(loc, TypeRange{}, processedArgs);
return true;
}

if (funcName.equals("mx") || funcName.equals("my") ||
funcName.equals("mz")) {
// Measurements always return a bool or a std::vector<bool>.
Expand Down Expand Up @@ -2140,7 +2181,7 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {

// TODO: remove this when we can handle ctors more generally.
if (!ctor->isDefaultConstructor()) {
LLVM_DEBUG(llvm::dbgs() << "unhandled ctor:\n"; x->dump());
LLVM_DEBUG(llvm::dbgs() << ctorName << " - unhandled ctor:\n"; x->dump());
TODO_loc(loc, "C++ ctor (not-default)");
}

Expand Down Expand Up @@ -2206,8 +2247,10 @@ bool QuakeBridgeVisitor::VisitDeclRefExpr(clang::DeclRefExpr *x) {
}

bool QuakeBridgeVisitor::VisitStringLiteral(clang::StringLiteral *x) {
TODO_x(toLocation(x->getSourceRange()), x, mangler, "string literal");
return false;
auto strLitTy = cc::PointerType::get(cc::ArrayType::get(
builder.getContext(), builder.getI8Type(), x->getString().size() + 1));
return pushValue(builder.create<cc::CreateStringLiteralOp>(
toLocation(x), strLitTy, builder.getStringAttr(x->getString())));
}

} // namespace cudaq::details
3 changes: 2 additions & 1 deletion lib/Frontend/nvqpp/ConvertType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ QuakeBridgeVisitor::findCallOperator(const clang::CXXRecordDecl *decl) {

bool QuakeBridgeVisitor::TraverseRecordType(clang::RecordType *t) {
auto *recDecl = t->getDecl();

if (ignoredClass(recDecl))
return true;
auto reci = records.find(t);
Expand Down Expand Up @@ -311,7 +312,7 @@ bool QuakeBridgeVisitor::doSyntaxChecks(const clang::FunctionDecl *x) {
// device kernels may take veq and/or ref arguments.
if (isArithmeticType(t) || isArithmeticSequenceType(t) ||
isQuantumType(t) || isKernelCallable(t) || isFunctionCallable(t) ||
isReferenceToCallableRecord(t, p))
isCharPointerType(t) || isReferenceToCallableRecord(t, p))
continue;
reportClangError(p, mangler, "kernel argument type not supported");
return false;
Expand Down
116 changes: 91 additions & 25 deletions lib/Optimizer/CodeGen/LowerToQIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,16 +316,14 @@ class SubveqOpRewrite : public ConvertOpToLLVMPattern<quake::SubVeqOp> {
};

/// Lower the quake.reset op to QIR
template <typename ResetOpType>
class ResetRewrite : public ConvertOpToLLVMPattern<ResetOpType> {
class ResetRewrite : public ConvertOpToLLVMPattern<quake::ResetOp> {
public:
using Base = ConvertOpToLLVMPattern<ResetOpType>;
using Base::Base;
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(ResetOpType instOp, typename Base::OpAdaptor adaptor,
matchAndRewrite(quake::ResetOp instOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto parentModule = instOp->template getParentOfType<ModuleOp>();
auto parentModule = instOp->getParentOfType<ModuleOp>();
auto context = parentModule->getContext();
std::string qirQisPrefix(cudaq::opt::QIRQISPrefix);
std::string instName = instOp->getName().stripDialect().str();
Expand All @@ -348,6 +346,37 @@ class ResetRewrite : public ConvertOpToLLVMPattern<ResetOpType> {
}
};

/// Lower exp_pauli(f64, veq, cc.string) to __quantum__qis__exp_pauli
class ExpPauliRewrite : public ConvertOpToLLVMPattern<quake::ExpPauliOp> {
public:
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(quake::ExpPauliOp instOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = instOp->getLoc();
auto parentModule = instOp->getParentOfType<ModuleOp>();
auto *context = rewriter.getContext();
std::string qirQisPrefix(cudaq::opt::QIRQISPrefix);
auto qirFunctionName = qirQisPrefix + "exp_pauli";
FlatSymbolRefAttr symbolRef = cudaq::opt::factory::createLLVMFunctionSymbol(
qirFunctionName, /*return type=*/LLVM::LLVMVoidType::get(context),
{rewriter.getF64Type(), cudaq::opt::getArrayType(context),
cudaq::opt::factory::getPointerType(context)},
parentModule);
SmallVector<Value> operands = adaptor.getOperands();
// Make sure to drop any length information from the type of the Pauli word.
auto pauliWord = operands.back();
operands.pop_back();
auto castedPauli = rewriter.create<LLVM::BitcastOp>(
loc, cudaq::opt::factory::getPointerType(context), pauliWord);
operands.push_back(castedPauli);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(instOp, TypeRange{}, symbolRef,
operands);
return success();
}
};

/// Lower single target Quantum ops with no parameter to QIR:
/// h, x, y, z, s, t
template <typename OP>
Expand Down Expand Up @@ -1310,6 +1339,42 @@ class StdvecSizeOpPattern
}
};

class CreateStringLiteralOpPattern
: public ConvertOpToLLVMPattern<cudaq::cc::CreateStringLiteralOp> {
public:
using Base = ConvertOpToLLVMPattern<cudaq::cc::CreateStringLiteralOp>;
using Base::Base;

LogicalResult
matchAndRewrite(cudaq::cc::CreateStringLiteralOp stringLiteralOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = stringLiteralOp.getLoc();
auto parentModule = stringLiteralOp->getParentOfType<ModuleOp>();
StringRef stringLiteral = stringLiteralOp.getStringLiteral();

// Write to the module body
auto insertPoint = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(parentModule.getBody());

// Create the register name global
auto builder = cudaq::IRBuilder::atBlockEnd(parentModule.getBody());
auto slGlobal =
builder.genCStringLiteralAppendNul(loc, parentModule, stringLiteral);

// Shift back to the function
rewriter.restoreInsertionPoint(insertPoint);

// Get the string address
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
stringLiteralOp,
cudaq::opt::factory::getPointerType(slGlobal.getType()),
slGlobal.getSymName());

return success();
}
};

class StoreOpPattern : public ConvertOpToLLVMPattern<cudaq::cc::StoreOp> {
public:
using Base = ConvertOpToLLVMPattern<cudaq::cc::StoreOp>;
Expand Down Expand Up @@ -1420,25 +1485,26 @@ class QuakeToQIRRewrite : public cudaq::opt::QuakeToQIRBase<QuakeToQIRRewrite> {

patterns.insert<GetVeqSizeOpRewrite, MxToMz, MyToMz, ReturnBitRewrite>(
context);
patterns.insert<
AllocaOpRewrite, AllocaOpPattern, CallableClosureOpPattern,
CallableFuncOpPattern, CallCallableOpPattern, CastOpPattern,
ComputePtrOpPattern, ConcatOpRewrite, DeallocOpRewrite,
ExtractQubitOpRewrite, ExtractValueOpPattern, FuncToPtrOpPattern,
InsertValueOpPattern, InstantiateCallableOpPattern, LoadOpPattern,
OneTargetRewrite<quake::HOp>, OneTargetRewrite<quake::XOp>,
OneTargetRewrite<quake::YOp>, OneTargetRewrite<quake::ZOp>,
OneTargetRewrite<quake::SOp>, OneTargetRewrite<quake::TOp>,
OneTargetOneParamRewrite<quake::R1Op>,
OneTargetTwoParamRewrite<quake::PhasedRxOp>,
OneTargetOneParamRewrite<quake::RxOp>,
OneTargetOneParamRewrite<quake::RyOp>,
OneTargetOneParamRewrite<quake::RzOp>,
OneTargetTwoParamRewrite<quake::U2Op>,
OneTargetTwoParamRewrite<quake::U3Op>, ResetRewrite<quake::ResetOp>,
StdvecDataOpPattern, StdvecInitOpPattern, StdvecSizeOpPattern,
StoreOpPattern, SubveqOpRewrite, TwoTargetRewrite<quake::SwapOp>,
UndefOpPattern>(typeConverter);
patterns
.insert<AllocaOpRewrite, AllocaOpPattern, CallableClosureOpPattern,
CallableFuncOpPattern, CallCallableOpPattern, CastOpPattern,
ComputePtrOpPattern, ConcatOpRewrite, DeallocOpRewrite,
CreateStringLiteralOpPattern, ExtractQubitOpRewrite,
ExtractValueOpPattern, FuncToPtrOpPattern, InsertValueOpPattern,
InstantiateCallableOpPattern, LoadOpPattern, ExpPauliRewrite,
OneTargetRewrite<quake::HOp>, OneTargetRewrite<quake::XOp>,
OneTargetRewrite<quake::YOp>, OneTargetRewrite<quake::ZOp>,
OneTargetRewrite<quake::SOp>, OneTargetRewrite<quake::TOp>,
OneTargetOneParamRewrite<quake::R1Op>,
OneTargetTwoParamRewrite<quake::PhasedRxOp>,
OneTargetOneParamRewrite<quake::RxOp>,
OneTargetOneParamRewrite<quake::RyOp>,
OneTargetOneParamRewrite<quake::RzOp>,
OneTargetTwoParamRewrite<quake::U2Op>,
OneTargetTwoParamRewrite<quake::U3Op>, ResetRewrite,
StdvecDataOpPattern, StdvecInitOpPattern, StdvecSizeOpPattern,
StoreOpPattern, SubveqOpRewrite,
TwoTargetRewrite<quake::SwapOp>, UndefOpPattern>(typeConverter);
patterns.insert<MeasureRewrite<quake::MzOp>>(typeConverter, measureCounter);

target.addLegalDialect<LLVM::LLVMDialect>();
Expand Down
4 changes: 2 additions & 2 deletions lib/Optimizer/Dialect/CC/CCTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,14 @@ void cc::ArrayType::print(AsmPrinter &printer) const {
#define GET_TYPEDEF_CLASSES
#include "cudaq/Optimizer/Dialect/CC/CCTypes.cpp.inc"

//===----------------------------------------------------------------------===//

namespace cudaq {

cc::CallableType cc::CallableType::getNoSignature(MLIRContext *ctx) {
return CallableType::get(ctx, FunctionType::get(ctx, {}, {}));
}

//===----------------------------------------------------------------------===//

void cc::CCDialect::registerTypes() {
addTypes<ArrayType, CallableType, PointerType, StdvecType, StructType>();
}
Expand Down
19 changes: 18 additions & 1 deletion python/runtime/cudaq/builder/py_kernel_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,24 @@ provided `function` will be applied within `self` at each iteration.
.def("to_quake", &kernel_builder<>::to_quake, "See :func:`__str__`.")
.def("__str__", &kernel_builder<>::to_quake,
"Return the :class:`Kernel` as a string in its MLIR representation "
"using the Quake dialect.\n");
"using the Quake dialect.\n")
.def(
"exp_pauli",
[](kernel_builder<> &self, py::object theta, const QuakeValue &qubits,
const std::string &pauliWord) {
if (py::isinstance<py::float_>(theta))
self.exp_pauli(theta.cast<double>(), qubits, pauliWord);
else if (py::isinstance<QuakeValue>(theta))
self.exp_pauli(theta.cast<QuakeValue &>(), qubits, pauliWord);
else
throw std::runtime_error(
"Invalid `theta` argument type. Must be a "
"`float` or a `QuakeValue`.");
},
"Apply a general Pauli tensor product rotation, `exp(i theta P)`, on "
"the specified qubit register. The Pauli tensor product is provided "
"as a string, e.g. `XXYX` for a 4-qubit term. The angle parameter "
"can be provided as a concrete float or a `QuakeValue`.");
}

void bindBuilder(py::module &mod) {
Expand Down
Loading

0 comments on commit da6d0b9

Please sign in to comment.