Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update WireSetsToProfileQIR pass: remove WireSets and fix LLVM lowering issue #2069

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/cudaq/Optimizer/CodeGen/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,16 @@ def WireSetToProfileQIR : Pass<"wireset-to-profile-qir", "mlir::func::FuncOp"> {
];
}

def WireSetToProfileQIRPost :
Pass<"wireset-to-profile-qir-post", "mlir::ModuleOp"> {
let summary = "Post processing for lowering wire sets to a profile of QIR";
let description = [{
This pass should be run immediately after wireset-to-profile-qir.
}];

let dependentDialects = ["cudaq::cc::CCDialect", "mlir::func::FuncDialect"];
}

def WireSetToProfileQIRPrep :
Pass<"wireset-to-profile-qir-prep", "mlir::ModuleOp"> {
let summary = "Prepare for lowering wire sets to a profile of QIR";
Expand Down
45 changes: 42 additions & 3 deletions lib/Optimizer/CodeGen/WireSetsToProfileQIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

namespace cudaq::opt {
#define GEN_PASS_DEF_WIRESETTOPROFILEQIR
#define GEN_PASS_DEF_WIRESETTOPROFILEQIRPOST
#define GEN_PASS_DEF_WIRESETTOPROFILEQIRPREP
#include "cudaq/Optimizer/CodeGen/Passes.h.inc"
} // namespace cudaq::opt
Expand Down Expand Up @@ -145,6 +146,17 @@ struct ReturnWireRewrite : OpConversionPattern<quake::ReturnWireOp> {
}
};

struct WireSetRewrite : OpConversionPattern<quake::WireSetOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(quake::WireSetOp wireSetOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(wireSetOp);
return success();
}
};

struct MzRewrite : OpConversionPattern<quake::MzOp> {
using Base = OpConversionPattern;
explicit MzRewrite(TypeConverter &typeConverter, unsigned &counter,
Expand Down Expand Up @@ -195,13 +207,18 @@ struct DiscriminateRewrite : OpConversionPattern<quake::DiscriminateOp> {
// NB: This is thread safe as it should never do an insertion, just a
// lookup.
auto nameObj = irb.genCStringLiteralAppendNul(loc, mod, iter->second);
auto arrI8Ty = mlir::LLVM::LLVMArrayType::get(rewriter.getI8Type(),
iter->second.size() + 1);
auto ptrArrTy = cudaq::cc::PointerType::get(arrI8Ty);
Value nameVal = rewriter.create<cudaq::cc::AddressOfOp>(loc, ptrArrTy,
nameObj.getName());
auto cstrTy = cudaq::cc::PointerType::get(rewriter.getI8Type());
Value nameVal =
rewriter.create<cudaq::cc::AddressOfOp>(loc, cstrTy, nameObj.getName());
Value nameValCStr =
rewriter.create<cudaq::cc::CastOp>(loc, cstrTy, nameVal);

rewriter.create<func::CallOp>(
loc, std::nullopt, "__quantum__rt__result_record_output",
ValueRange{adaptor.getMeasurement(), nameVal});
ValueRange{adaptor.getMeasurement(), nameValCStr});
if (isAdaptiveProfile) {
std::string funcName = toQisBodyName(std::string("read_result"));
rewriter.replaceOpWithNewOp<func::CallOp>(
Expand Down Expand Up @@ -371,6 +388,27 @@ struct WireSetToProfileQIRPrepPass
LLVM_DEBUG(llvm::dbgs() << "Module after prep:\n"; op->dump());
}
};

struct WireSetToProfileQIRPostPass
: public cudaq::opt::impl::WireSetToProfileQIRPostBase<
WireSetToProfileQIRPostPass> {
using WireSetToProfileQIRPostBase::WireSetToProfileQIRPostBase;

void runOnOperation() override {
ModuleOp op = getOperation();
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
QuakeTypeConverter quakeTypeConverter;
patterns.insert<WireSetRewrite>(quakeTypeConverter, ctx);
ConversionTarget target(*ctx);
target.addIllegalDialect<quake::QuakeDialect>();

LLVM_DEBUG(llvm::dbgs() << "Module before:\n"; op.dump());
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
LLVM_DEBUG(llvm::dbgs() << "Module after:\n"; op.dump());
}
};
} // namespace

void cudaq::opt::addWiresetToProfileQIRPipeline(OpPassManager &pm,
Expand All @@ -380,6 +418,7 @@ void cudaq::opt::addWiresetToProfileQIRPipeline(OpPassManager &pm,
if (!profile.empty())
wopt.convertTo = profile.str();
pm.addNestedPass<func::FuncOp>(cudaq::opt::createWireSetToProfileQIR(wopt));
pm.addPass(cudaq::opt::createWireSetToProfileQIRPost());
}

// Pipeline option: let the user specify the profile name.
Expand Down
Loading