[WIP] Laser forEach iterators #479

Draft · wants to merge 16 commits into base: master

Changes from 1 commit
Use the new forEach, warning! huge parallel perf regression to investigate (was already slower than serial)
mratsim committed Jan 3, 2021

commit ac2f2403001d234ba5a1207eaefc5a774b086e72
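For context, forEach is the Laser strided-iteration macro that this commit switches the GRU kernels to: it binds one element variable per tensor and fuses the element-wise work into a single strided traversal, with a parallel code path that the commit message flags as the suspected source of the regression. Below is a minimal sketch of the call pattern, mirroring the usage in the diff that follows; the tensor names and the import path are illustrative assumptions, not part of this PR.

import arraymancer
# Assumed import path inside this PR's tree; adjust if forEach
# is re-exported from the top-level arraymancer module.
import arraymancer/laser/strided_iteration/foreach

# Illustrative tensors (not taken from the PR).
var dst = zeros[float32](2, 3)
let a   = ones[float32](2, 3)
let b   = ones[float32](2, 3)

# One fused strided loop over all three tensors, replacing a chain of
# apply2_inline / apply3_inline calls.
forEach d in dst, x in a, y in b:
  d = x + 2'f32 * y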
25 changes: 18 additions & 7 deletions src/arraymancer/laser/strided_iteration/foreach_common.nim
@@ -7,12 +7,12 @@ import
std/[macros, strutils],
../compiler_optim_hints

-template isVar[T: object](x: T): bool =
+template isVar[T](x: T): bool =
## Workaround due to `is` operator not working for `var`
## https://github.com/nim-lang/Nim/issues/9443
compiles(addr(x))

-proc aliasTensor(id: int, tensor: NimNode): NimNode =
+proc aliasTensor(id: int, tensor: NimNode): tuple[alias: NimNode, isVar: NimNode] =
## Produce an alias variable for a tensor
## Supports:
## - identifiers
@@ -41,11 +41,22 @@ proc aliasTensor(id: int, tensor: NimNode): NimNode =

# Rewrite the AST to untyped
t = nnkBracketExpr.newTree(
-tensor[0][1],
-tensor[0][2]
+tensor[1]
)
+for i in 2 ..< tensor.len:
+t.add tensor[i]

var alias = ""
+let isVar = block:
+# Handle slicing cases like foo[0..<1, 0..<2]
+# that do not return `var` but are technically `var`
+# if `foo` is var
+if t.kind in {nnkDotExpr, nnkBracketExpr}:
+let t0 = t[0]
+quote do: isVar(`t0`)
+else:
+quote do: isVar(`t`)
+
while t.kind in {nnkDotExpr, nnkBracketExpr}:
if t[0].kind notin {nnkIdent, nnkSym}:
error "Expected a field name but found \"" & t[0].repr()
@@ -57,7 +68,7 @@ proc aliasTensor(id: int, tensor: NimNode): NimNode =

alias &= $t

-return newIdentNode($alias & "_alias" & $id & '_')
+return (newIdentNode($alias & "_alias" & $id & '_'), isVar)

proc initForEach*(
params: NimNode,
@@ -105,10 +116,10 @@ proc initForEach*(
aliases_stmt.add newCall(bindSym"withCompilerOptimHints")

for i, tensor in tensors:
-let alias = aliasTensor(i, tensor)
+let (alias, detectVar) = aliasTensor(i, tensor)
aliases.add alias
aliases_stmt.add quote do:
-when isVar(`tensor`):
+when `detectVar`:
var `alias`{.align_variable.} = `tensor`
else:
let `alias`{.align_variable.} = `tensor`
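The change above has two parts. isVar drops its object constraint but keeps the compiles(addr(x)) trick: taking an address only compiles for mutable locations, which works around the is-var check not being usable here (https://github.com/nim-lang/Nim/issues/9443). aliasTensor now also returns that mutability test as a NimNode, applied to the base object for dot/bracket expressions such as foo[0..<1, 0..<2], so initForEach can declare each alias with var or let accordingly. A small standalone sketch of the detection trick (not the PR's code, just the idea):

template isVar[T](x: T): bool =
  ## addr only compiles for addressable (mutable) locations,
  ## so this stands in for an `is var` check
  ## (https://github.com/nim-lang/Nim/issues/9443).
  compiles(addr(x))

proc demo(a: var seq[int], b: seq[int]) =
  echo isVar(a)   # true:  a is a var parameter
  echo isVar(b)   # false: b is immutable, addr(b) does not compile

var s = @[1, 2, 3]
demo(s, @[4, 5, 6])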
92 changes: 58 additions & 34 deletions src/arraymancer/nn_primitives/nnp_gru.nim
@@ -63,7 +63,6 @@ proc gru_cell_inference*[T: SomeFloat](
# Slices
sr = (0 ..< H)|1
sz = (H ..< 2*H)|1
-srz = (0 ..< 2*H)|1
s = (2*H ..< 3*H)|1


@@ -73,19 +72,29 @@ proc gru_cell_inference*[T: SomeFloat](
linear(input, W3, bW3, W3x)
linear(hidden, U3, bU3, U3h)

-# Step 2 - Computing reset (r) and update (z) gate
-var W2ru = W3x[_, srz] # shape [batch_size, 2*H] - we reuse the previous buffer
-apply2_inline(W2ru, U3h[_, srz]):
-sigmoid(x + y)
-
-# Step 3 - Computing candidate hidden state ñ
-var n = W3x[_, s] # shape [batch_size, H] - we reuse the previous buffer
-apply3_inline(n, W2ru[_, sr], U3h[_, s]):
-tanh(x + y * z)
-
-# Step 4 - Update the hidden state
-apply3_inline(hidden, W3x[_, sz], n):
-(1 - y) * z + y * x
+# Step 2 - Fused evaluation of the 4 GRU equations
+# r = σ(Wr * x + bWr + Ur * h + bUr)
+# z = σ(Wz * x + bWz + Uz * h + bUz)
+# n = tanh(W * x + bW + r *. (U * h + bU ))
+# h' = (1 - z) *. n + z *. h
+
+# shape [batch_size, H] - we reuse the previous buffers
+forEach wrx in W3x[_, sr], # Wr*x
+wzx in W3x[_, sz], # Wz*x
+wx in W3x[_, s], # W*x
+urh in U3h[_, sr], # Ur*h
+uzh in U3h[_, sz], # Uz*h
+uh in U3h[_, s], # U*h
+h in hidden: # hidden state
+# Reset (r) gate and Update (z) gate
+let r = sigmoid(wrx + urh)
+let z = sigmoid(wzx + uzh)
+
+# Candidate hidden state ñ
+let n = tanh(wx + r * uh)
+
+# h' = (1 - z) *. ñ + z *. h
+h = (1-z) * n + z*h

proc gru_cell_forward*[T: SomeFloat](
input,
@@ -124,26 +133,38 @@ proc gru_cell_forward*[T: SomeFloat](
linear(input, W3, bW3, W3x)
linear(hidden, U3, bU3, U3h)

-# # Saving for backprop
-apply2_inline(Uh, U3h[_, s]):
-y
-
-# Step 2 - Computing reset (r) and update (z) gate
-apply3_inline(r, W3x[_, sr], U3h[_, sr]):
-sigmoid(y + z)
-
-apply3_inline(z, W3x[_, sz], U3h[_, sz]):
-sigmoid(y + z)
-
-# Step 3 - Computing candidate hidden state ñ
-# TODO: need apply4 / loopfusion for efficient
-# buffer passing in Stacked GRU implementation
-n = map3_inline(W3x[_, s], r, U3h[_, s]):
-tanh(x + y * z)
-
-# Step 4 - Update the hidden state
-apply3_inline(hidden, z, n):
-(1 - y) * z + y * x
+# Step 2 - Fused evaluation of the 4 GRU equations
+# and saving for backprop
+# r = σ(Wr * x + bWr + Ur * h + bUr)
+# z = σ(Wz * x + bWz + Uz * h + bUz)
+# n = tanh(W * x + bW + r *. (U * h + bU ))
+# h' = (1 - z) *. n + z *. h
+
+# shape [batch_size, H] - we reuse the previous buffers
+forEach wrx in W3x[_, sr], # Wr*x
+wzx in W3x[_, sz], # Wz*x
+wx in W3x[_, s], # W*x
+urh in U3h[_, sr], # Ur*h
+uzh in U3h[_, sz], # Uz*h
+uh in U3h[_, s], # U*h
+h in hidden, # hidden state
+saveUh in Uh, # U*h cache for backprop
+reset in r, # reset gate cache for backprop
+update in z, # update gate cache for backprop
+candidate in n: # candidate hidden state cache for backprop
+
+# Cache for backprop
+saveUh = uh
+
+# Reset (r) gate and Update (z) gate
+reset = sigmoid(wrx + urh)
+update = sigmoid(wzx + uzh)
+
+# Candidate hidden state ñ
+candidate = tanh(wx + reset * uh)
+
+# h' = (1 - z) *. ñ + z *. h
+h = (1-update) * candidate + update*h

proc gru_cell_backward*[T: SomeFloat](
dx, dh, dW3, dU3, # input and weights gradients
@@ -162,6 +183,9 @@ proc gru_cell_backward*[T: SomeFloat](
## - dnext: gradient flowing back from the next layer
## - x, h, W3, U3: inputs saved from the forward pass
## - r, z, n, Uh: intermediate results saved from the forward pass of shape [batch_size, hidden_size]

+# TODO: fused backprop with forEach

# Backprop of step 4 - z part
let dz = (h - n) *. dnext
let dn = (1.0.T -. z) *. dnext
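The backward pass is left untouched here; the TODO above only marks it as a candidate for the same fusion. As a purely hypothetical illustration (not part of this PR), the two step-4 gradients shown above could be fused into one traversal, assuming dz and dn are preallocated to the same shape as dnext:

# Hypothetical sketch only: fuses
#   dz = (h - n) *. dnext
#   dn = (1 - z) *. dnext
# into a single strided traversal; the element names are illustrative.
forEach dzi in dz,
        dni in dn,
        hi in h,
        ni in n,
        zi in z,
        di in dnext:
  dzi = (hi - ni) * di
  dni = (1 - zi) * di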