Skip to content

Commit

Permalink
tons of updates, but tests are still broken
Browse files Browse the repository at this point in the history
  • Loading branch information
Timothy Baldridge committed May 5, 2013
1 parent 6a4cab0 commit e697998
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 151 deletions.
123 changes: 60 additions & 63 deletions src/examples/mandelbrot.clj
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
(ns examples.mandelbrot
(:require [criterium.core :as crit])
(:require [mjolnir.constructors-init :as const]
(:require [mjolnir.constructors-init :refer [defnf]]
[mjolnir.types :as types :refer [I8* Int64 Float32 Float32* Float64x4 Float64x4* VoidT]]
[mjolnir.expressions :refer [build optimize dump ->ConstVector ->Do ->FPToSI ->SIToFP]]
[mjolnir.config :as config]
[mjolnir.targets.target :refer [emit-to-file as-dll]]
[mjolnir.intrinsics :as intr]
[mjolnir.targets.nvptx-intrinsics :refer [TID_X NTID_X CTAID_X TID_Y NTID_Y CTAID_Y]]
[mjolnir.targets.nvptx :as nvptx])
[mjolnir.targets.nvptx :as nvptx]
[mjolnir.core :refer [build-module]])
(:alias c mjolnir.constructors)
(:import [java.awt Color Image Dimension]
[javax.swing JPanel JFrame SwingUtilities]
Expand Down Expand Up @@ -49,55 +50,48 @@

;; Mjolnir Method (no SSE)

(c/defn square [Float32 x -> Float32]
(c/* x x))

(c/defn calc-iteration [Float32 xpx Float32 ypx Float32 max Float32 width Float32 height -> Float32]
(c/let [x0 (c/- (c/* (c/fdiv xpx width) 3.5) 2.5)
y0 (c/- (c/* (c/fdiv ypx height) 2.0) 1.0)]
(c/loop [iteration 0.0
x 0.0
y 0.0]
(c/if (c/and (c/< (c/+ (square x)
(square y))
(square 2.0))
(c/< iteration max))
(c/recur (c/+ iteration 1.0)
(c/+ (c/- (square x)
(square y))
x0)
(c/+ (c/* 2.0 x y)
y0)
-> Float32)
iteration))))

(defmacro lfor [[var [from to step] tp] & body]
`(c/let [to# ~to]
(c/loop [~var ~from]
(c/if (c/< ~var to#)
(c/do ~@body
(c/recur (c/+ ~var ~step) ~'-> ~tp))
~var))))

(c/defn ^:extern calc-mandelbrot [Float32* arr Float32 width Float32 height Float32 max -> Float32*]
(lfor [y [0.0 height 1.0] Float32]
(lfor [x [0.0 width 1.0] Float32]
(c/let [idx (->FPToSI (c/+ (c/* y width) x)
Int64)]
(c/aset arr idx (c/fdiv (calc-iteration x y max width height) max)))))
(defnf square [Float32 x -> Float32]
(* x x))



(defnf calc-iteration [Float32 xpx Float32 ypx Float32 max Float32 width Float32 height -> Float32]
(let [x0 (- (* (/ xpx width) 3.5) 2.5)
y0 (- (/ (/ ypx height) 2.0) 1.0)]
(loop [iteration 0.0
x 0.0
y 0.0]
(if (and (< (+ (square x)
(square y))
(square 2.0))
(< iteration max))
(recur (+ iteration 1.0)
(+ (- (square x)
(square y))
x0)
(+ (* 2.0 x y)
y0))
iteration))))

(defnf ^:extern calc-mandelbrot [Float32* arr Float32 width Float32 height Float32 max -> Float32*]
(for [y [0.0 height 1.0]]
(for [x [0.0 width 1.0]]
(let [idx (cast Int64 (+ (* y width) x))]
(aset arr idx (/ (calc-iteration x y max width height) max)))))
arr)

(c/defn ^:extern calc-mandelbrot-ptx [Float32* arr Float32 width Float32 height Float32 max -> VoidT]
(c/let [xpx (->SIToFP (c/+ (c/* (CTAID_X) (NTID_X))
(TID_X))
Float32)
ypx (->SIToFP (c/+ (c/* (CTAID_Y) (NTID_Y))
(TID_Y))
Float32)
idx (->FPToSI (c/+ (c/* ypx width) xpx)
Int64)
c (calc-iteration xpx ypx max width height)]
(c/aset arr idx (c/fdiv c max))))
(defnf ^:extern calc-mandelbrot-ptx [Float32* arr
Float32 width
Float32 height
Float32 max
-> VoidT]
(let [xpx (cast Float32 (+ (* (CTAID_X) (NTID_X))
(TID_X)))
ypx (cast Float32 (+ (* (CTAID_Y) (NTID_Y))
(TID_Y)))
idx (cast Int64 (+ (* ypx width) xpx))
c (calc-iteration xpx ypx max width height)]
(aset arr idx (/ c max))))

(defn memory-to-array [^Memory m size]
(let [arr (float-array size)]
Expand Down Expand Up @@ -202,20 +196,23 @@

(defmethod run-command [:benchmark :mjolnir]
[_ _]
(let [module (c/module ['examples.mandelbrot/square
'examples.mandelbrot/calc-iteration
'examples.mandelbrot/calc-mandelbrot])
built (optimize (build module))
_ (dump built)
dll (as-dll (config/default-target)
built
{:verbose true})
mbf (get dll calc-mandelbrot)
buf (Memory. (* SIZE 8))]
(assert (and mbf dll) "Compilation error")
(println "Running...")
(crit/bench
(mbf buf WIDTH HEIGHT 1000.0))))
(binding [config/*target* (nvptx/make-default-target)
config/*float-type* Float32
config/*int-type* Int64]
(let [module (c/module ['examples.mandelbrot/square
'examples.mandelbrot/calc-iteration
'examples.mandelbrot/calc-mandelbrot])
built (build-module module)
_ (dump built)
dll (as-dll (config/default-target)
built
{:verbose true})
mbf (get dll calc-mandelbrot)
buf (Memory. (* SIZE 8))]
(assert (and mbf dll) "Compilation error")
(println "Running...")
(crit/bench
(mbf buf WIDTH HEIGHT 1000.0)))))

(defmethod run-command [:benchmark :java]
[_ _]
Expand Down
37 changes: 26 additions & 11 deletions src/mjolnir/constructors_init.clj
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@
(next exprs)))

(defn c-* [& exprs]
(exp/->*Op exprs))
(gen-binops :* exprs))

(defn c-+ [& exprs]
(gen-binops :+ exprs))

(defn c-- [& exprs]
(gen-binops :- exprs))

(defn c-div [& exprs]
(gen-binops :div exprs))

(defn c-and [& exprs]
(reduce exp/->And
(first exprs)
(next exprs)))
(gen-binops :and exprs))

#_(defn c-module
[& body]
Expand Down Expand Up @@ -151,8 +152,8 @@
(defn c-eset [vec idx val]
(exp/->ESet vec idx val))

(defn c-bitcast [a tp]
(exp/->BitCast a tp))
(defn c-cast [tp a]
(exp/->Cast tp a))

(defmacro c-local [nm]
`(exp/->Local ~(name nm)))
Expand Down Expand Up @@ -263,6 +264,17 @@
~tp))))


(defmacro c-for [[var [from to step]] & body]
`(c-let [to# ~to]
(c-loop [~var ~from]
(c-if (c-< ~var to#)
(c-do ~@body
(c-recur (c-+ ~var ~step)))
~var))))




;; Black magic is here
(let [ns (create-ns 'mjolnir.constructors)]
(doseq [[nm var] (ns-publics *ns*)]
Expand All @@ -275,13 +287,15 @@


(defn- constructor [sym]
(let [s (symbol (str "c-" (name sym)))
(let [sym (if (= (name sym) "/")
(symbol "div")
sym)
s (symbol (str "c-" (name sym)))
var ((ns-publics (the-ns 'mjolnir.constructors-init)) s)]
(println "|" s var "|")
(println s)
var))

(defn- constructor? [sym]
(println "<" sym (not (nil? (constructor sym))) ">")
(not (nil? (constructor sym))))

(declare convert-form)
Expand All @@ -302,14 +316,15 @@

(and (symbol? x)
(constructor? x))
(symbol "mjolnir.constructors-init" (str "c-" (name x)))
(symbol "mjolnir.constructors-init" (str "c-" (if (= (name x) "/")
"div"
(symbol x))))


:else
x))

(defn- convert-form [body]
(println "converting " body)
(doall (map convert-form-1 body)))

(defmacro defnf [& body]
Expand Down
4 changes: 4 additions & 0 deletions src/mjolnir/core.clj
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,9 @@
(verify built)
built))

(defn build-module [m]
(-> (to-db m)
(to-llvm-module)))



79 changes: 25 additions & 54 deletions src/mjolnir/expressions.clj
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,15 @@
key)]
const)))

(defrecord BitCast [ptr tp]
Validatable
(validate [this]
(assure (Expression? ptr))
(assure (type? tp)))
Expression
(return-type [this]
tp)
(build [this]
(assert (not (keyword ptr)) ptr)
(llvm/BuildBitCast *builder* (build ptr) (llvm-type tp) (genname "bitcast_")))
(defrecord Cast [tp expr]
SSAWriter
(write-ssa [this]
(gen-plan
[tp-id (add-to-plan tp)
ptr (write-ssa ptr)
expr-id (write-ssa expr)
casted (add-instruction :inst.type/cast
{:inst.cast/type tp-id
:inst.arg/arg0 ptr
:inst.arg/arg0 expr-id
:node/return-type tp-id})]
casted)))

Expand Down Expand Up @@ -208,37 +198,20 @@
:>= llvm/LLVMRealOGE}})

(defrecord Cmp [pred a b]
Validatable
(validate [this]
(valid? a)
(valid? b)
(Expression? a)
(Expression? b)
(assure-same-type (return-type a) (return-type b)))
Expression
(return-type [this]
Int1)
(build [this]
(let [[tp f]
(cond
(integer-type? (return-type a)) [:int llvm/BuildICmp]
(float-type? (return-type a)) [:float llvm/BuildFCmp]
(vector-type? (return-type a)) [:float llvm/BuildFCmp])]
(assert (pred (tp cmp-maps)) "Invalid predicate symbol")
(f *builder* (pred (tp cmp-maps)) (build a) (build b) (genname "cmp_"))))
SSAWriter
(write-ssa [this]
(gen-plan
[tp (add-to-plan Int1)
lh (write-ssa a)
rh (write-ssa b)
nd (add-instruction :inst.type/cmp
{:node/return-type tp
:inst.arg/arg0 lh
:inst.arg/arg1 rh
:inst.cmp/pred pred}
nil)]
nd)))
(let [pred (keyword "inst.cmp.pred" (name pred))]
(gen-plan
[tp (add-to-plan Int1)
lh (write-ssa a)
rh (write-ssa b)
nd (add-instruction :inst.type/cmp
{:node/return-type tp
:inst.arg/arg0 lh
:inst.arg/arg1 rh
:inst.cmp/pred pred}
nil)]
nd))))

(defrecord Not [a]
Validatable
Expand Down Expand Up @@ -338,7 +311,11 @@

(def binop-maps
{:+ :inst.binop.type/add
:- :inst.binop.type/sub})
:- :inst.binop.type/sub
:* :inst.binop.type/mul
:div :inst.binop.type/div
:and :inst.binop.type/and
:or :inst.binop.type/or})



Expand Down Expand Up @@ -910,7 +887,7 @@
member)
_ (assert idx (pr-str "Idx error, did you validate first? " ptr " " member))
bptr (build ptr)
cptr (build (->BitCast ptr (->PointerType etp)))
cptr (build (->Cast (->PointerType etp) ptr))
gep (llvm/BuildStructGEP *builder* cptr idx (genname "set_"))]
(llvm/BuildStore *builder* (build val) gep)
bptr)))
Expand Down Expand Up @@ -946,7 +923,7 @@
idx (member-idx etp
member)
_ (assert idx "Idx error, did you validate first?")
cptr (build (->BitCast ptr (->PointerType etp)))
cptr (build (->Cast (->PointerType etp) ptr))
gep (llvm/BuildStructGEP *builder* cptr idx (genname "get_"))]
(llvm/BuildLoad *builder* gep (genname "load_")))))

Expand Down Expand Up @@ -1128,15 +1105,9 @@
*float-type*)
(build [this]
(encode-const *float-type* this))
#_ (comment IToDatoms
(-to-datoms [this conn]
(transact-new
conn
{:node/type :type/const
:node/return-type (-> (-to-datoms
*float-type*
conn)
:db/id)}))))
SSAWriter
(write-ssa [this]
(write-ssa (->Const *float-type* this))))



Expand Down
Loading

0 comments on commit e697998

Please sign in to comment.