Wow... Fixed a VERY old bug with LLVM/LSTM
csxeba committed May 8, 2019
1 parent 5876042 commit 4c78ba1
Showing 7 changed files with 33 additions and 86 deletions.
6 changes: 5 additions & 1 deletion brainforge/atomic/recurrent_op.py
@@ -63,7 +63,11 @@ def forward(self, X, W, b):
p[:, :outdim] = self.actfn.forward(p[:, :outdim])
p[:, outdim:] = sigmoid.forward(p[:, outdim:])

-cand[t], f[t], i[t], o[t] = np.split(p, 4, axis=1)
+cand[t] = p[:, :outdim]
+f[t] = p[:, outdim:2*outdim]
+i[t] = p[:, 2*outdim:3*outdim]
+o[t] = p[:, 3*outdim:]
+# cand[t], f[t], i[t], o[t] = np.split(p, 4, axis=1)

C[t] = C[t-1] * f[t] + cand[t] * i[t]

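Note: the explicit slices assign exactly the column blocks that np.split(p, 4, axis=1) returns, so the NumPy path should be numerically unchanged by this rewrite. A minimal standalone sanity check (a sketch, not repo code), assuming p has 4 * outdim columns:

    import numpy as np

    outdim = 3
    p = np.random.randn(5, 4 * outdim)      # batch of 5, four gate blocks side by side

    cand, f, i, o = np.split(p, 4, axis=1)  # the old formulation
    assert np.array_equal(cand, p[:, :outdim])
    assert np.array_equal(f, p[:, outdim:2*outdim])
    assert np.array_equal(i, p[:, 2*outdim:3*outdim])
    assert np.array_equal(o, p[:, 3*outdim:])
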
2 changes: 1 addition & 1 deletion brainforge/llatomic/_llactivation.py
@@ -22,7 +22,7 @@ def sigmoid_p(A):
return A * (s1 - A)


-@nb.vectorize(finfout, nopython=True)
+@nb.jit(nopython=True)
def tanh(Z):
return np.tanh(Z)

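Note: @nb.vectorize turns the scalar kernel into an elementwise ufunc, while @nb.jit(nopython=True) compiles tanh as a regular nopython function that receives the whole array (np.tanh is supported on arrays in nopython mode), presumably so it can be handed to the unified lstm_forward kernel below as its activation argument. A tiny sketch of the jitted form (illustrative only, import names assumed):

    import numba as nb
    import numpy as np

    @nb.jit(nopython=True)
    def tanh(Z):
        return np.tanh(Z)       # elementwise over the whole array

    print(tanh(np.linspace(-2.0, 2.0, 6).reshape(2, 3)))
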
47 changes: 10 additions & 37 deletions brainforge/llatomic/_lllstm.py
@@ -3,20 +3,8 @@
from ._llactivation import sigmoid, tanh, relu


-@nb.jit(nopython=True, cache=True)
-def _lstm_update_state(p, T, outdim):
-    # C[0], Ca[1], cand[2], f[3], i[4], o[5]
-    p[:, outdim:] = sigmoid(p[:, outdim:])  # sigmoid to gates
-    T[2] = p[:, :outdim]  # candidate
-    T[3] = p[:, outdim:2*outdim]  # forget
-    T[4] = p[:, 2*outdim:3*outdim]  # input
-    T[5] = p[:, 3*outdim:]  # output
-    T[0] = T[0] * T[3] + T[2] * T[4]  # Ct = Ct-1 * f + cand * i
-    return T
-
-
-# @nb.jit(nopython=True, cache=True)
-def lstm_forward_tanh(X, W, b):
+@nb.jit(nopython=True)
+def lstm_forward(X, W, b, activation):
    outdim = W.shape[-1] // 4
    time, batch, indim = X.shape

@@ -27,31 +15,16 @@ def lstm_forward_tanh(X, W, b):
    for t in range(time):
        Z[t] = np.concatenate((X[t], O[t-1]), axis=-1)
        p = np.dot(Z[t], W) + b
-        p[:, :outdim] = tanh(p[:, :outdim])  # nonlin to candidate
-
-        T[t] = _lstm_update_state(p, T[t], outdim)
-
-        T[t, 1] = tanh(T[t, 0])  # nonlin to state
-        O[t] = T[t, 1] * T[t, 5]  # O = f(C) * o
-    return np.concatenate((O.ravel(), Z.ravel(), T.ravel()))
-
-
-@nb.jit(nopython=True)
-def lstm_forward_relu(X, W, b):
-    outdim = W.shape[-1] // 4
-    time, batch, indim = X.shape
-
-    Z = np.zeros((time, batch, indim + outdim))
-    O = np.zeros((time, batch, outdim))
-    # C[0], Ca[1], cand[2], f[3], i[4], o[5]
-    T = np.zeros((time, 6, batch, outdim))
-    for t in range(time):
-        Z[t] = np.concatenate((X[t], O[t-1]), axis=-1)
-        p = np.dot(Z[t], W) + b
-        p[:, :outdim] = relu(p[:, :outdim])  # nonlin to candidate
-        T[t] = _lstm_update_state(p, T[t], outdim)
-        T[t, 1] = relu(T[t, 0])  # nonlin to state
+        p[:, :outdim] = activation(p[:, :outdim])  # nonlin to candidate
+        p[:, outdim:] = sigmoid(p[:, outdim:])  # sigmoid to gates
+        T[t, 2] = p[:, :outdim]  # candidate
+        T[t, 3] = p[:, outdim:2*outdim]  # forget
+        T[t, 4] = p[:, 2*outdim:3*outdim]  # input
+        T[t, 5] = p[:, 3*outdim:]  # output
+        T[t, 0] = T[t-1, 0] * T[t, 3] + T[t, 2] * T[t, 4]  # Ct = Ct-1 * f + cand * i
+
+        T[t, 1] = activation(T[t, 0])  # nonlin to state
        O[t] = T[t, 1] * T[t, 5]  # O = f(C) * o
    return np.concatenate((O.ravel(), Z.ravel(), T.ravel()))

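Note on the fix itself: the removed helper was handed the current step's slice (T[t] = _lstm_update_state(p, T[t], outdim)) and computed T[0] = T[0] * T[3] + T[2] * T[4], so the forget-gated term multiplied the current step's still-zero cell state instead of the previous step's C; the inlined update now reads T[t-1, 0] explicitly, which is presumably the "very old bug" the commit title refers to. A toy sketch of the difference (illustrative numbers, not repo code):

    import numpy as np

    time, batch, outdim = 3, 1, 2
    T = np.zeros((time, 6, batch, outdim))   # C[0], Ca[1], cand[2], f[3], i[4], o[5]
    cand = np.full((batch, outdim), 0.5)
    f = np.full((batch, outdim), 0.9)
    i = np.full((batch, outdim), 1.0)

    for t in range(time):
        T[t, 2], T[t, 3], T[t, 4] = cand, f, i
        buggy = T[t, 0] * T[t, 3] + T[t, 2] * T[t, 4]      # old: current (still zero) cell state
        fixed = T[t - 1, 0] * T[t, 3] + T[t, 2] * T[t, 4]  # new: previous step's cell state
        T[t, 0] = fixed
        print(t, buggy[0, 0], fixed[0, 0])
    # buggy stays at 0.5 every step; fixed accumulates: 0.5, 0.95, 1.355
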
10 changes: 2 additions & 8 deletions brainforge/llatomic/llrecurrent_op.py
@@ -1,7 +1,7 @@
import numpy as np

from ._llrecurrent import recurrent_forward_relu, recurrent_forward_tanh, recurrent_backward
-from ._lllstm import lstm_forward_tanh, lstm_backward
+from ._lllstm import lstm_forward, lstm_backward
from .llactivation_op import llactivations

sigmoid = llactivations["sigmoid"]()
@@ -48,12 +48,6 @@ def backward(self, Z, O, E, W):

class LSTMOp(ROpBase):

-def __init__(self, activation):
-    super().__init__(activation)
-    self.fwlow = {
-        "tanh": lstm_forward_tanh
-    }[activation.lower()]
-
def forward(self, X, W, b):
do = W.shape[-1] // 4
t, m, di = X.shape
@@ -64,7 +58,7 @@ def forward(self, X, W, b):
Obord = np.prod(Oshape)
Zbord = np.prod(Zshape) + Obord

-vector = self.fwlow(X, W, b)
+vector = lstm_forward(X, W, b, self.llact.llact)

O = vector[:Obord].reshape(*Oshape)
Z = vector[Obord:Zbord].reshape(*Zshape)
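Note: with the per-activation dispatch table gone, LSTMOp hands the compiled activation (self.llact.llact, presumably a Numba-jitted callable) straight to the unified lstm_forward kernel. Numba can specialize a jitted function on another jitted function passed in as an argument; a small self-contained sketch of that pattern (names invented for illustration):

    import numba as nb
    import numpy as np

    @nb.jit(nopython=True)
    def tanh(Z):
        return np.tanh(Z)

    @nb.jit(nopython=True)
    def gate(activation, Z):
        return activation(Z) * 0.5   # 'activation' is itself a jitted function

    print(gate(tanh, np.ones((2, 3))))
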
22 changes: 14 additions & 8 deletions tests/test_numba_ops.py
@@ -93,7 +93,7 @@ def test_lstm_op(self):
nbop = NbLSTM("tanh")

BSZE = 20
-TIME = 5
+TIME = 2
DDIM = 10
NEUR = 15

@@ -102,24 +102,30 @@ def test_lstm_op(self):
b = np.random.randn(NEUR * 4)
# E = np.random.randn(BSZE, TIME, NEUR)

-npO, npZ, cache = npop.forward(X, W, b)
-nbO, nbZ, cache = nbop.forward(X, W, b)
+npO, npZ, npcache = npop.forward(X, W, b)
+nbO, nbZ, nbcache = nbop.forward(X, W, b)

# visualize(X, npZ, nbZ)
+for i, array_type in enumerate(["C", "Ca", "cand", "f", "i", "o"]):
+    for t in range(TIME):
+        d = np.sum(np.abs(npcache[i, t] - nbcache[i, t]))
+        print("{} diff @ t {}: {}".format(array_type, t, d))
+self.assertTrue(np.allclose(npcache, nbcache))
self.assertTrue(np.allclose(npZ, nbZ))
self.assertTrue(np.allclose(npO, nbO))


def visualize(A, O1, O2, supt=None):
+    TAKE = 0
    d = O1 - O2
    vmax, vmin = max(O1.max(), O2.max()), min(O1.min(), O2.min())
    fig, axarr = plt.subplots(2, 2)
    axarr[0][0].imshow(A[0, 0], vmin=0, vmax=1, cmap="autumn")
    axarr[0][0].set_title("A")
-    axarr[0][1].imshow(d[0, 0], cmap="seismic")
+    print("Total deviance:", d.sum())
+    axarr[0][1].imshow(d[TAKE], cmap="seismic")
    axarr[0][1].set_title("d")
-    axarr[1][0].imshow(O1[0, 0], vmin=vmin, vmax=vmax, cmap="hot")
+    axarr[1][0].imshow(O1[TAKE], vmin=vmin, vmax=vmax, cmap="hot")
    axarr[1][0].set_title("npO")
-    axarr[1][1].imshow(O2[0, 0], vmin=vmin, vmax=vmax, cmap="hot")
+    axarr[1][1].imshow(O2[TAKE], vmin=vmin, vmax=vmax, cmap="hot")
    axarr[1][1].set_title("nbO")
    plt.suptitle(supt)
    plt.tight_layout()
30 changes: 0 additions & 30 deletions tests/test_persistance.py

This file was deleted.

2 changes: 1 addition & 1 deletion xperiments/xp_lstm.py
@@ -12,7 +12,7 @@
Y = np.random.randn(*OUTSHP)

net = BackpropNetwork(input_shape=DSHAPE[1:], layerstack=[
-LSTM(16, activation="tanh"),
+LSTM(16, activation="tanh", compiled=1),
DenseLayer(OUTSHP[1:], activation="linear", trainable=0)
], cost="mse", optimizer="sgd")

