diff --git a/scinim/fuse_loops.nim b/scinim/fuse_loops.nim new file mode 100644 index 0000000..8e44b4c --- /dev/null +++ b/scinim/fuse_loops.nim @@ -0,0 +1,189 @@ +import std / [macros, options, algorithm] + +type + ForLoop = object + n: NimNode # the actual node + body: NimNode # body of the loop *WITHOUT* any inner loops! + idx: NimNode # the loop index + start: NimNode # start of the loop + stop: NimNode # stop of the loop + +template nofuse*(arg: untyped): untyped = + ## Just a dummy template, which can be easily used to disable fusing of + ## a nested loop + arg + +proc extractBody(n: NimNode): NimNode = + ## Returns the input tree without any possible nested for loops. Nested + ## loops are replaced by `nnkEmpty` nodes to be filled again later in `bodies`. + case n.kind + of nnkForStmt: + if n[1].kind == nnkInfix and n[1][0].strVal == "..<": + result = newEmptyNode() ## Flattened nested loop body will be inserted here + else: + result = n + else: + if n.len > 0: + result = newTree(n.kind) + for ch in n: + let bd = extractBody(ch) + if bd != nil: + result.add bd + else: + result = n + +proc toForLoop(n: NimNode): Option[ForLoop] = + ## Returns a `some(ForLoop)` if the given node is a fuse-able for loop + doAssert n.kind == nnkForStmt + if n[1].kind != nnkInfix: return + if n[1][0].strVal != "..<": + error("Unexpected iterator: " & $n[1].repr & + ". It must be of the form `0 ..< X`.") + if not (n[1][1].kind == nnkIntLit and n[1][1].intVal == 0): + error("Starting iteration index must be 0!") + result = some(ForLoop(n: n, + body: extractBody(n[2]), + idx: n[0], + start: n[1][1], + stop: n[1][2])) + +template addIf(s, opt): untyped = + if opt.isSome: + s.add opt.unsafeGet + +proc extractLoops(n: NimNode): seq[ForLoop] = + ## Extracts (fuse-able) loops from the given Nim node and errors if more than + ## one for loop found at the same level. + case n.kind + of nnkForStmt: + result.addIf toForLoop(n) + result.add extractLoops(n[2]) # go over body + else: + var foundLoops = 0 # counter for number of loops at current body + for ch in n: + let loops = extractLoops(ch) + if loops.len > 0: + result.add loops + inc foundLoops + if foundLoops > 1: + error("Found more than one loop (" & $foundLoops & ") at the level of node: " & + n.repr & ". Please wrap " & "these loops as `nofuse`, i.e. `nofuse(0 ..< X)`") + +proc genFusedLoop(idx: NimNode, stop: NimNode, ompStr = ""): NimNode = + ## Generate either regular or OpenMP for loop + let loopIter = if ompStr.len == 0: + nnkInfix.newTree(ident"..<", + newLit 0, + stop) + else: + nnkCall.newTree(ident"||", + newLit 0, + stop, + newLit ompStr) + result = nnkForStmt.newTree( + idx, + loopIter + ) + +proc calcStop(loops: seq[ForLoop]): NimNode = + ## Returns `N * T * U * ...` expression where the indices are + ## the stop indices of the loops to be fused. + case loops.len + of 0: doAssert false, "Must not happen" + of 1: result = loops[0].stop + else: + var ml = loops.reversed # want last elements first + let x = ml.pop + result = nnkInfix.newTree(ident"*", x.stop, + calcStop(ml.reversed)) + +proc modOrDiv(prefix, suffix: NimNode, isDiv: bool): NimNode = + if isDiv: + result = quote do: + `prefix` div `suffix` + else: + result = quote do: + `prefix` mod `suffix` + +proc asLet(v, val: NimNode): NimNode = + result = quote do: + let `v` = `val` + +proc genPrelude(idx: NimNode, loops: seq[ForLoop]): NimNode = + ## The basic algorithm for generating the correct index for fused loops is + ## + ## Notation: + ## `i` = Loop index of single remaining outer loop + ## `N_i` = Stopping index (-1) of the inner loop `i` + ## `n` = Total number of nested loops + ## + ## Whichever is easiest to read for you: + ## + ## `let i0 = i div (N_0 * N_1 ... N_n)` + ## `let i1 = (i mod (N_0 * N_1 ... N_n)) div (N_1 * N_2 * ... N_n)` + ## `let i2 = ((i mod (N_0 * N_1 ... N_n)) mod (N_1 * N_2 * ... N_n)) div (N_2 * ... * N_n)` + ## ... + ## + ## ... or + ## + ## `let i0 = i div Π_i=0^n N_i` + ## `let i1 = (i mod Π_i=0^n N_i) div Π_i=1^n N_i` + ## `let i2 = ((i mod Π_i=0^n N_i) mod Π_i=1^n N_i) div Π_i=2^n N_i` + ## + ## ...or + ## + ## `let i0 = Idx div [Product of remaining N-1 loops]` + ## `let i1 = (Idx mod [Product of remaining loops]) div [Product of remaining N-2 loops]` + ## `let i2 = (Idx mod [Product of remaining loops]) mod [Product of remaining N-2 loops]` + result = newStmtList() + var prefix = idx + var ml = loops.reversed + var lIdx = ml.pop # drop first element + var suffix = ml.calcStop() + while ml.len > 0: + result.add asLet(lIdx.idx, modOrDiv(prefix, suffix, isDiv = true)) + lIdx = ml.pop # get next loop index & adjust remaining loops + # now adjust prefix and suffix + prefix = modOrDiv(prefix, suffix, isDiv = false) + if ml.len > 0: # adjust suffix + suffix = ml.calcStop() + else: # simply add last 'prefix' + result.add asLet(lIdx.idx, prefix) + +proc bodies(loops: seq[ForLoop]): NimNode = + ## Concatenates all loop bodies, by placing the next loop into the + ## `nnkEmpty` node of the current node + var ml = loops.reversed + #echo ml.repr + var cur = ml.pop + result = cur.body + for i in 0 ..< result.len: + let ch = result[i] + if ch.kind == nnkEmpty: + # insert next loop + result[i] = bodies(ml.reversed) # revert order again + break # there can only be a single `nnkEmpty` (multiple loops not allowed, + # yields CT error) + +proc fuseLoopImpl(ompStr: string, body: NimNode): NimNode = + # 1. extract all loops from the body + let loops = extractLoops(body) + # 2. generate identifier for the final loop + let idx = genSym(nskForVar, "idx") + # 3. generate the fused outer loop + result = genFusedLoop(idx, calcStop(loops), ompStr) + # 4. generate final loop body by... + var loopBody = newStmtList() + # 4a. generate prelude of loop variables of original loops + loopBody.add genPrelude(idx, loops) # gen code to produce the old loop variables + # 4b. insert old loop bodies into respective positions + loopBody.add bodies(loops) + result.add loopBody + when defined(DebugFuseLoop): + echo result.repr + +macro fuseLoops*(body: untyped): untyped = + result = fuseLoopImpl("", body) + +macro fuseLoops*(ompStr: untyped{lit}, body: untyped): untyped = + result = fuseLoopImpl(ompStr.strVal, body) diff --git a/tests/tFuseLoops.nim b/tests/tFuseLoops.nim new file mode 100644 index 0000000..474da11 --- /dev/null +++ b/tests/tFuseLoops.nim @@ -0,0 +1,54 @@ +import ../scinim/fuse_loops +import std / unittest + +suite "fuseLoops": + test "Compiles test for different `fuseLoops` setups": + const N = 5 + const T = 10 + const X = 3 + + ## XXX: These should probably become proper tests. :) + + fuseLoops: + for i in 0 ..< N: + let x = i * 2 + for j in 0 ..< T: + let z = x * j + echo i, j, x, z + echo x + + fuseLoops: + for i in 0 ..< N: + let x = i * 2 + for j in 0 ..< T: + let z = x * j + echo i, j, x, z + for k in nofuse(0 ..< T): + echo k + echo x + + fuseLoops("parallel for"): + for i in 0 ..< N: + let x = i * 2 + for j in 0 ..< T: + let z = x * j + for k in 0 ..< X: + echo i, j, k, x, z + echo x + + ## The following raises a CT error + when compiles(( + fuseLoops: + for i in 0 ..< N: + let x = i * 2 + var zsum = 0 + for j in 0 ..< T: + let z = x * j + zsum += z + echo i, x, z + echo x + for j in 0 ..< 2 * T: + zsum += j + echo zsum + )): + doAssert false