Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor python kernel #2024

Closed
wants to merge 3 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
* Removing unused imports
* Optimizing control loop
* Using textwrap to remove leading spaces from the source code
  • Loading branch information
sacpis committed Jul 29, 2024
commit 935822cf901f5ffa8d298573dc1d2d835d8c0ea9
136 changes: 57 additions & 79 deletions python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #
import ast, sys, traceback
import ast
import importlib
import inspect
import json
from typing import Callable
import textwrap
from ..mlir.ir import *
from ..mlir.passmanager import *
from ..mlir.dialects import quake, cc
from ..mlir.dialects import cc
from .ast_bridge import compile_to_mlir, PyASTBridge
from .utils import mlirTypeFromPyType, nvqppPrefix, mlirTypeToPyType, globalAstRegistry, emitFatalError, emitErrorIfInvalidPauli
from .analysis import MidCircuitMeasurementAnalyzer, RewriteMeasures, HasReturnNodeVisitor
from .analysis import MidCircuitMeasurementAnalyzer, HasReturnNodeVisitor
from ..mlir._mlir_libs._quakeDialects import cudaq_runtime
from .captured_data import CapturedDataStorage

Expand Down Expand Up @@ -78,15 +78,10 @@ def __init__(self,
# We filter only types we accept: integers and floats.
# Note here we assume that the parent scope is 2 stack frames up
self.parentFrame = inspect.stack()[2].frame
if overrideGlobalScopedVars:
self.globalScopedVars = {
k: v for k, v in overrideGlobalScopedVars.items()
}
else:
self.globalScopedVars = {
k: v for k, v in dict(inspect.getmembers(self.parentFrame))
['f_locals'].items()
}
self.globalScopedVars = {
k: v for k, v in (
overrideGlobalScopedVars or self.parentFrame.f_locals).items()
}

# Once the kernel is compiled to MLIR, we
# want to know what capture variables, if any, were
Expand All @@ -108,8 +103,7 @@ def __init__(self,
'arg{}'.format(i): mlirTypeToPyType(v.type)
for i, v in enumerate(self.argTypes)
}
self.returnType = self.signature[
'return'] if 'return' in self.signature else None
self.returnType = self.signature.get('return')
return
else:
emitFatalError(
Expand All @@ -121,15 +115,11 @@ def __init__(self,
else:
# Get the function source
src = inspect.getsource(self.kernelFunction)

# Strip off the extra tabs
leadingSpaces = len(src) - len(src.lstrip())
self.funcSrc = '\n'.join(
[line[leadingSpaces:] for line in src.split('\n')])
self.funcSrc = textwrap.dedent(src).strip()

# Create the AST
self.astModule = ast.parse(self.funcSrc)
if verbose and importlib.util.find_spec('astpretty') is not None:
if verbose and importlib.util.find_spec('astpretty'):
import astpretty
astpretty.pprint(self.astModule.body[0])

Expand All @@ -141,8 +131,7 @@ def __init__(self,
self.arguments = [
(k, v) for k, v in self.signature.items() if k != 'return'
]
self.returnType = self.signature[
'return'] if 'return' in self.signature else None
self.returnType = self.signature.get('return')

# Validate that we have a return type annotation if necessary
hasRetNodeVis = HasReturnNodeVisitor()
Expand All @@ -168,23 +157,23 @@ def compile(self):
if the kernel is already compiled.
"""

if self.module is not None:
return

# Before we can execute, we need to make sure
# variables from the parent frame that we captured
# have not changed. If they have changed, we need to
# recompile with the new values.
s = inspect.currentframe()
while s:
if s == self.parentFrame:
current_frame = inspect.currentframe()
while current_frame:
if current_frame == self.parentFrame:
# We found the parent frame, now
# see if any of the variables we depend
# on have changed.
self.globalScopedVars = {
k: v
for k, v in dict(inspect.getmembers(s))['f_locals'].items()
}
if self.dependentCaptures != None:
self.globalScopedVars = current_frame.f_locals
if self.dependentCaptures:
for k, v in self.dependentCaptures.items():
if (isinstance(v, (list, np.ndarray))):
if isinstance(v, (list, np.ndarray)):
if not all(a == b for a, b in zip(
self.globalScopedVars[k], v)):
# Recompile if values in the list have changed.
Expand All @@ -195,23 +184,20 @@ def compile(self):
self.module = None
break
break
s = s.f_back

if self.module != None:
return
current_frame = current_frame.f_back

self.module, self.argTypes, extraMetadata = compile_to_mlir(
self.astModule,
self.metadata,
self.capturedDataStorage,
verbose=self.verbose,
returnType=self.returnType,
location=self.location,
parentVariables=self.globalScopedVars)
if self.module is None:
self.module, self.argTypes, extraMetadata = compile_to_mlir(
self.astModule,
self.metadata,
self.capturedDataStorage,
verbose=self.verbose,
returnType=self.returnType,
location=self.location,
parentVariables=self.globalScopedVars)

# Grab the dependent capture variables, if any
self.dependentCaptures = extraMetadata[
'dependent_captures'] if 'dependent_captures' in extraMetadata else None
# Grab the dependent capture variables, if any
self.dependentCaptures = extraMetadata.get('dependent_captures')

def __str__(self):
"""
Expand Down Expand Up @@ -278,7 +264,7 @@ def castPyList(self, fromEleTy, toEleTy, list):
return list

def createStorage(self):
ctx = None if self.module == None else self.module.context
ctx = self.module.context if self.module else None
return CapturedDataStorage(ctx=ctx,
loc=self.location,
name=self.name,
Expand All @@ -294,28 +280,22 @@ def type_to_str(t):
"""
if hasattr(t, '__origin__') and t.__origin__ is not None:
# Handle generic types from typing
origin = t.__origin__
args = t.__args__
args_str = ', '.join(
PyKernelDecorator.type_to_str(arg) for arg in args)
return f'{origin.__name__}[{args_str}]'
elif hasattr(t, '__name__'):
return t.__name__
else:
return str(t)
return f'{t.__origin__.__name__}[{", ".join(map(PyKernelDecorator.type_to_str, t.__args__))}]'
return getattr(t, '__name__', str(t))

def to_json(self):
"""
Convert `self` to a JSON-serialized version of the kernel such that
`from_json` can reconstruct it elsewhere.
"""
obj = dict()
obj['name'] = self.name
obj['location'] = self.location
obj['funcSrc'] = self.funcSrc
obj['signature'] = {
k: PyKernelDecorator.type_to_str(v)
for k, v in self.signature.items()
obj = {
'name': self.name,
'location': self.location,
'funcSrc': self.funcSrc,
'signature': {
k: PyKernelDecorator.type_to_str(v)
for k, v in self.signature.items()
}
}
return json.dumps(obj)

Expand Down Expand Up @@ -363,10 +343,10 @@ def __call__(self, *args):
emitErrorIfInvalidPauli(arg)
arg = cudaq_runtime.pauli_word(arg)

if issubclass(type(arg), list):
if all(isinstance(a, str) for a in arg):
[emitErrorIfInvalidPauli(a) for a in arg]
arg = [cudaq_runtime.pauli_word(a) for a in arg]
if issubclass(type(arg), list) and all(
isinstance(a, str) for a in arg):
[emitErrorIfInvalidPauli(a) for a in arg]
arg = [cudaq_runtime.pauli_word(a) for a in arg]

mlirType = mlirTypeFromPyType(type(arg),
self.module.context,
Expand All @@ -375,17 +355,15 @@ def __call__(self, *args):

# Support passing `list[int]` to a `list[float]` argument
# Support passing `list[int]` or `list[float]` to a `list[complex]` argument
if cc.StdvecType.isinstance(mlirType):
if cc.StdvecType.isinstance(self.argTypes[i]):
argEleTy = cc.StdvecType.getElementType(mlirType) # actual
eleTy = cc.StdvecType.getElementType(
self.argTypes[i]) # formal

if self.isCastable(argEleTy, eleTy):
processedArgs.append(
self.castPyList(argEleTy, eleTy, arg))
mlirType = self.argTypes[i]
continue
if cc.StdvecType.isinstance(mlirType) and cc.StdvecType.isinstance(
self.argTypes[i]):
argEleTy = cc.StdvecType.getElementType(mlirType) # actual
eleTy = cc.StdvecType.getElementType(self.argTypes[i]) # formal

if self.isCastable(argEleTy, eleTy):
processedArgs.append(self.castPyList(argEleTy, eleTy, arg))
mlirType = self.argTypes[i]
continue

if not cc.CallableType.isinstance(
mlirType) and mlirType != self.argTypes[i]:
Expand Down