Skip to content

Commit

Permalink
add support for converting a cps call result
Browse files Browse the repository at this point in the history
Conversions are simple and side-effect free, so to support them in
assignments we only have to tweak the shim so that it look for the call
within the expression and replace it with the result.

For `discard`, we rewrite the statement into an assignment with a
temporary to simulate the effect since we cannot "skip" the conversion
due to their interactions with destructors.
  • Loading branch information
alaviss authored and disruptek committed Nov 7, 2021
1 parent c4c34e1 commit e91cc02
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 11 deletions.
29 changes: 28 additions & 1 deletion cps/normalizedast.nim
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ type
## opaque sum: call node of some variety, see `macros.nnkCallKinds`,
## this is an alias as it's not really useful to distinguish the two.

Conv* = distinct NormNode
## an nnkConv node

Pragma* = distinct NormNode
## opaque sum: `PragmaStmt`, `PragmaBlock`, and `PragmaExpr`
PragmaStmt* = distinct Pragma
Expand Down Expand Up @@ -256,7 +259,7 @@ defineToNimNodeConverter(

# all the types that can convert down to `NormNode`
allowAutoDowngradeNormalizedNode(
Name, TypeExpr, Call, PragmaStmt, PragmaAtom, IdentDef, RoutineDef,
Name, TypeExpr, Call, Conv, PragmaStmt, PragmaAtom, IdentDef, RoutineDef,
ProcDef, FormalParams, RoutineParam, VarSection, LetSection, VarLet,
VarLetIdentDef, VarLetTuple, DefVarLet, IdentDefLet, Sym
)
Expand Down Expand Up @@ -346,6 +349,18 @@ template findChild*(n: NormNode; cond: untyped): NormNode =
## finds the first child node matching the condition or nil
NormNode macros.findChild(n, cond)

proc findChildRecursive*(n: NormNode, cmp: proc(n: NormNode): bool): NormNode =
## finds the first child node where `cmp(node)` returns true, recursively
##
## returns nil if none found
if cmp(n):
result = n
else:
for child in n.items:
result = findChildRecursive(NormNode(child), cmp)
if not result.isNil:
return

proc getImpl*(n: NormNode): NormNode {.borrow.}
## the implementaiton of a normalized node should be normalized itself

Expand Down Expand Up @@ -1059,6 +1074,18 @@ proc desym*(n: Call) =
## desyms the callee name
n.name = desym n.name

# fn-Conv

createAsTypeFunc(Conv, {nnkConv}, "node is not a conv node")

proc typ*(n: Conv): TypeExpr =
## the type being converted to
n[0].asTypeExpr

proc expr*(n: Conv): NormNode =
## the expression being converted
n[1]

# fn-FormalParams

proc newFormalParams*(ret: TypeExpr, ps: varargs[IdentDef]): FormalParams =
Expand Down
9 changes: 9 additions & 0 deletions cps/spec.nim
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,15 @@ proc isCpsCall*(n: NormNode): bool =
# or it could be a completely new continuation
result = it.impl.hasPragma "cpsMustJump"

proc isCpsConvCall*(n: NormNode): bool =
## true if this node holds a cps call that might be nested within one or more
## conversions.
case n.kind
of nnkConv:
isCpsConvCall(n.last)
else:
isCpsCall(n)

proc isCpsBlock*(n: NormNode): bool =
## `true` if the block `n` contains a cps call anywhere at all;
## this is used to figure out if a block needs tailcall handling...
Expand Down
48 changes: 38 additions & 10 deletions cps/transform.nim
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,15 @@ proc setupChildContinuation(env: var Env; call: Call): (Name, NormNode) =
# XXX: also sniff and return the child continuation's "user type"?
result = (child, etype)

proc shimAssign(env: var Env; store: NormNode, call: Call, tail: NormNode): NormNode =
## this rewrite supports `x = contProc()` and `let x = contProc()`
proc shimAssign(env: var Env; store: NormNode, expr: NormNode, tail: NormNode): NormNode =
## this rewrite supports `x = contProc()`, `let x = contProc()`,
## `x = Type(contProc())` and `let x = Type(contProc())`
if not expr.isCpsConvCall:
return expr.errorAst("cps don't know how to shim this")

let call = asCall:
expr.findChildRecursive(isCpsCall)

# create the unshimmed assignment
var assign = newStmtList()
case store.kind
Expand All @@ -611,9 +618,9 @@ proc shimAssign(env: var Env; store: NormNode, call: Call, tail: NormNode): Norm
of nnkSym, nnkIdent, nnkDotExpr:
# perform a normal assignment
assign.add:
newAssignment(store, call)
newAssignment(store, expr)
else:
raise Defect.newException "unsupported store kind: " & $store.kind
return store.errorAst("unsupported store")

# swap the call in the assignment statement(s)
let (child, etype) = setupChildContinuation(env, call)
Expand Down Expand Up @@ -752,21 +759,23 @@ proc annotate(parent: var Env; n: NormNode): NormNode =

of nnkVarSection, nnkLetSection:
let section = asVarLet nc
if section.val.isCpsCall or section.val.isCpsBlock:
if section.val.isCpsConvCall:
let assign = section
result.add: # shimming `let x = foo()` or `let (a, b) = bar()`
env.shimAssign(assign, asCall(assign.val)):
# shimming `let x = foo()` or `let (a, b) = bar()` or `let x = T(foo())`
result.add:
env.shimAssign(assign, assign.val):
anyTail()
return
endAndReturn()
else:
# add definitions into the environment
env.localSection(section, result)

of nnkAsgn:
if nc.last.isCpsCall:
if nc.last.isCpsConvCall:
# shimming `x = foo()` or `x = T(foo())`
result.add:
if nc.len >= 2:
env.shimAssign(nc[0], asCall(nc[1])): # shimming `x = foo()`
env.shimAssign(nc[0], nc[1]):
anyTail()
else:
nc.errorAst "i expected at least two kids on an nkAsgn"
Expand Down Expand Up @@ -864,6 +873,25 @@ proc annotate(parent: var Env; n: NormNode): NormNode =
# Rewrite it into an inline call, which cause them to be treated like
# result-less calls by the transformation.
env.annotate newStmtList(nc[0])
elif nc[0].isCpsConvCall:
# If the discarded expression is a conversion of a call result
let conv = nc[0].asConv

# Rewrite it into
#
# let tmp: Type = Type(call)
#
# And let assignment rewrite handle the rest.
#
# TODO: either make `tmp` not stored in the continuation, or get rid of
# it entirely.
result.add:
env.annotate:
newStmtList:
newLetIdentDef(genSymLet(info = nc), conv.typ):
# XXX: Not sure why I have to convert here, the type is already
# specified in allowAutoDowngradeNormalizedNode
NormNode conv
else:
result.add env.annotate(nc)

Expand Down
28 changes: 28 additions & 0 deletions tests/treturns.nim
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,31 @@ suite "returns and results":
check (x, y) == (10, 20)

foo()

block:
## converting a cps return value
var k = newKiller(3)

proc bar(): int {.cps: Cont.} =
noop()
42

proc foo() {.cps: Cont.} =
let x = Natural bar()
let x1 = Natural int Natural bar()

step 1
check x == 42

var y: Natural
y = Natural bar()
y = Natural int Natural bar()

step 2
check y == 42

discard Natural bar()
discard Natural int Natural bar()
step 3

foo()

0 comments on commit e91cc02

Please sign in to comment.