Skip to content

Commit

Permalink
feat(shader-ast): update/improve AST optimizer
Browse files Browse the repository at this point in the history
- add support for lit hoisting & single comp swizzling
- add logger support in replaceNode()
- update constantFolding() to run iteratively as many times as needed
- fix op1/op2 optimizers to use correct node predicates
  • Loading branch information
postspectacular committed Aug 13, 2021
1 parent 24c8ad5 commit ad60add
Showing 1 changed file with 116 additions and 25 deletions.
141 changes: 116 additions & 25 deletions packages/shader-ast/src/optimize.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,42 @@
import { NO_OP } from "@thi.ng/api";
import { LogLevel } from "@thi.ng/api";
import { DEFAULT, defmulti } from "@thi.ng/defmulti";
import type { Lit, Op1, Op2, Term } from "./api/nodes";
import type { Lit, Op1, Op2, Swizzle, Term } from "./api/nodes";
import type { Operator } from "./api/ops";
import { isLitNumeric } from "./ast/checks";
import { lit } from "./ast/lit";
import type { Swizzle4_1 } from "./api/swizzles";
import {
isFloat,
isInt,
isLitNumericConst,
isLitVecConst,
isUint,
} from "./ast/checks";
import { float, int, lit, uint } from "./ast/lit";
import { allChildren, walk } from "./ast/scope";
import { LOGGER } from "./logger";

/**
* Replaces contents of `node` with those of `next`. All other existing props in
* `node` will be removed.
*
* @param node
* @param next
*
* @internal
*/
const replaceNode = (node: any, next: any) => {
if (LOGGER.level <= LogLevel.DEBUG) {
LOGGER.debug(`replacing AST node:`);
LOGGER.debug("\told: " + JSON.stringify(node));
LOGGER.debug("\tnew: " + JSON.stringify(next));
}
for (let k in node) {
!next.hasOwnProperty(k) && delete node[k];
}
return Object.assign(node, next);
Object.assign(node, next);
return true;
};

/** @internal */
const maybeFoldMath = (op: Operator, l: any, r: any) =>
op === "+"
? l + r
Expand All @@ -24,45 +48,90 @@ const maybeFoldMath = (op: Operator, l: any, r: any) =>
? l / r
: undefined;

export const foldNode = defmulti<Term<any>, void>((t) => t.tag);
foldNode.add(DEFAULT, NO_OP);
/** @internal */
const COMPS: Record<Swizzle4_1, number> = { x: 0, y: 1, z: 2, w: 3 };

/** @internal */
export const foldNode = defmulti<Term<any>, boolean | undefined>((t) => t.tag);
foldNode.add(DEFAULT, () => false);

foldNode.addAll({
op1: (t) => {
const op = <Op1<any>>t;
if (op.op == "-" && isLitNumeric(op.val)) {
replaceNode(t, <Lit<"float">>op.val);
(<any>op).val = -(<any>op).val;
op1: (node) => {
const $node = <Op1<any>>node;
if ($node.op == "-" && isLitNumericConst($node.val)) {
(<Lit<"float">>$node.val).val *= -1;
return replaceNode(node, <Lit<"float">>$node.val);
}
},

op2: (node) => {
const op = <Op2<any>>node;
if (isLitNumeric(op.l) && isLitNumeric(op.r)) {
const vl = (<Lit<"float">>op.l).val;
const vr = (<Lit<"float">>op.r).val;
let res = maybeFoldMath(op.op, vl, vr);
const $node = <Op2<any>>node;
if (isLitNumericConst($node.l) && isLitNumericConst($node.r)) {
const vl = (<Lit<"float">>$node.l).val;
const vr = (<Lit<"float">>$node.r).val;
let res = maybeFoldMath($node.op, vl, vr);
if (res !== undefined) {
op.type === "int" && (res |= 0);
op.type === "uint" && (res >>>= 0);
replaceNode(node, lit(op.type, res));
$node.type === "int" && (res |= 0);
$node.type === "uint" && (res >>>= 0);
return replaceNode(node, lit($node.type, res));
}
}
},

lit: (node) => {
const $node = <Lit<any>>node;
if (isLitNumericConst($node.val)) {
if (isFloat($node.val)) {
return replaceNode(node, float($node.val.val));
}
if (isInt($node.val)) {
return replaceNode(node, int($node.val.val));
}
if (isUint($node.val)) {
return replaceNode(node, uint($node.val.val));
}
}
},

swizzle: (node) => {
const $node = <Swizzle<any>>node;
const val = $node.val;
if (isLitVecConst(val)) {
if (isFloat(node)) {
return replaceNode(
node,
float(val.val[COMPS[<Swizzle4_1>$node.id]])
);
}
}
},
});

/**
* Traverses given AST and applies constant folding optimizations where
* possible. Returns possibly updated tree (mutates original).
* Currently, only scalar operations are supported / considered.
* Traverses given AST (potentially several times) and applies constant folding
* optimizations where possible. Returns possibly updated tree (mutates
* original).
*
* @remarks
* Currently, only the following operations are supported / considered:
*
* - scalar math ops
* - single component vector swizzling
* - literal hoisting
*
* @example
* ```ts
* const foo = defn("float", "foo", ["float"], (x) => [
* ret(mul(x, add(neg(float(10)), float(42))))]
* )
*
* const prog = scope([foo, foo(add(float(1), float(2)))], true);
* const bar = vec2(100, 200);
*
* const prog = scope([
* foo,
* foo(add(float(1), float(2))),
* foo(add($x(bar), $y(bar)))
* ], true);
*
* // serialized (GLSL)
* glsl(prog);
Expand All @@ -71,6 +140,7 @@ foldNode.addAll({
* // return (_sym0 * (-10.0 + 42.0));
* // };
* // foo((1.0 + 2.0));
* // foo((vec2(100.0, 200.0).x + vec2(100.0, 200.0).y));
*
* // with constant folding
* glsl(constantFolding(prog))
Expand All @@ -79,11 +149,32 @@ foldNode.addAll({
* // return (_sym0 * 32.0);
* // };
* // foo(3.0);
* // foo(300.0);
*
* const expr = mul(float(4), $x(vec2(2)))
*
* glsl(expr)
* // (4.0 * vec2(2.0).x)
*
* glsl(constantFolding(expr))
* // 8.0
* ```
*
* @param tree -
*/
export const constantFolding = (tree: Term<any>) => {
walk((_, node) => foldNode(node), allChildren, <any>null, tree, false);
let exec = true;
while (exec) {
exec = false;
walk(
(_, node) => {
exec = foldNode(node) || exec;
},
allChildren,
<any>null,
tree,
false
);
}
return tree;
};

0 comments on commit ad60add

Please sign in to comment.