Skip to content

Commit

Permalink
lstm: Add LSTMMechanism + compiled support
Browse files Browse the repository at this point in the history
  • Loading branch information
SamKG committed Jul 26, 2020
1 parent 4d4c9bc commit 45c5d1d
Show file tree
Hide file tree
Showing 6 changed files with 657 additions and 10 deletions.
223 changes: 215 additions & 8 deletions psyneulink/core/components/functions/transferfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,7 @@
from psyneulink.core.components.functions.selectionfunctions import OneHot
from psyneulink.core.components.functions.statefulfunctions.integratorfunctions import SimpleIntegrator
from psyneulink.core.components.shellclasses import Projection
from psyneulink.core.globals.keywords import \
ADDITIVE_PARAM, ALL, BIAS, EXPONENTIAL_FUNCTION, \
GAIN, GAUSSIAN_DISTORT_FUNCTION, GAUSSIAN_FUNCTION, HAS_INITIALIZERS, HOLLOW_MATRIX, \
IDENTITY_FUNCTION, IDENTITY_MATRIX, INTERCEPT, LEAK, LINEAR_FUNCTION, LINEAR_MATRIX_FUNCTION, LOGISTIC_FUNCTION, \
TANH_FUNCTION, MATRIX_KEYWORD_NAMES, MATRIX, MATRIX_KEYWORD_VALUES, MAX_INDICATOR, MAX_VAL, MULTIPLICATIVE_PARAM, \
OFF, OFFSET, ON, PER_ITEM, PROB, PRODUCT, OUTPUT_TYPE, PROB_INDICATOR, \
RATE, RECEIVER, RELU_FUNCTION, SCALE, SLOPE, SOFTMAX_FUNCTION, STANDARD_DEVIATION, SUM,\
TRANSFER_FUNCTION_TYPE, TRANSFER_WITH_COSTS_FUNCTION, VARIANCE, VARIABLE, X_0, PREFERENCE_SET_NAME
from psyneulink.core.globals.keywords import ADDITIVE_PARAM, ALL, BIAS, EXPONENTIAL_FUNCTION, GAIN, GAUSSIAN_DISTORT_FUNCTION, GAUSSIAN_FUNCTION, HAS_INITIALIZERS, HOLLOW_MATRIX, IDENTITY_FUNCTION, IDENTITY_MATRIX, INTERCEPT, LEAK, LINEAR_FUNCTION, LINEAR_MATRIX_FUNCTION, LOGISTIC_FUNCTION, LSTM_FUNCTION, MATRIX, MATRIX_KEYWORD_NAMES, MATRIX_KEYWORD_VALUES, MAX_INDICATOR, MAX_VAL, MULTIPLICATIVE_PARAM, OFF, OFFSET, ON, OUTPUT_TYPE, PER_ITEM, PREFERENCE_SET_NAME, PROB, PROB_INDICATOR, PRODUCT, RANDOM_CONNECTIVITY_MATRIX, RATE, RECEIVER, RELU_FUNCTION, SCALE, SLOPE, SOFTMAX_FUNCTION, STANDARD_DEVIATION, SUM, TANH_FUNCTION, TRANSFER_FUNCTION_TYPE, TRANSFER_WITH_COSTS_FUNCTION, VARIABLE, VARIANCE, X_0
from psyneulink.core.globals.parameters import \
Parameter, get_validator_by_function
from psyneulink.core.globals.utilities import parameter_spec, get_global_seed, safe_len
Expand Down Expand Up @@ -2530,7 +2523,221 @@ def derivative(self, output, input=None, context=None):

return derivative

# **********************************************************************************************************************
# SoftMax
# **********************************************************************************************************************

class LSTM(TransferFunction):
componentName = LSTM_FUNCTION

def __init__(self,
default_variable=None,
params=None,
owner=None,
prefs: tc.optional(is_pref_set) = None):

super().__init__(
default_variable=default_variable,
params=params,
owner=owner,
prefs=prefs)

class Parameters(TransferFunction.Parameters):
i_input_matrix = Parameter(modulable=True)
i_hidden_matrix = Parameter(modulable=True)
i_gate_func = Parameter(default_value=Logistic())

f_input_matrix = Parameter(modulable=True)
f_hidden_matrix = Parameter(modulable=True)
f_gate_func = Parameter(default_value=Logistic())

g_input_matrix = Parameter(modulable=True)
g_hidden_matrix = Parameter(modulable=True)
g_gate_func = Parameter(default_value=Tanh())

o_input_matrix = Parameter(modulable=True)
o_hidden_matrix = Parameter(modulable=True)
o_gate_func = Parameter(default_value=Logistic())

h_gate_func = Parameter(default_value=Tanh())


def _instantiate_attributes_before_function(self, function=None, context=None):
input_size = len(self.variable[0])
hidden_size = len(self.variable[1])

# Instatiate input matrices
for param_id in ["i_input_matrix", "f_input_matrix", "g_input_matrix", "o_input_matrix"]:
param_val = getattr(self, param_id, None)
if param_val is None:
param_val = RANDOM_CONNECTIVITY_MATRIX

setattr(self, param_id, get_matrix(param_val, hidden_size, input_size, context=context))

# Instantiate hidden matrices
for param_id in ["i_hidden_matrix", "f_hidden_matrix", "g_hidden_matrix", "o_hidden_matrix"]:
param_val = getattr(self, param_id, None)
if param_val is None:
param_val = RANDOM_CONNECTIVITY_MATRIX

setattr(self, param_id, get_matrix(param_val, hidden_size, hidden_size, context=context))

# Instantiate function default variables
for param_id in ["i_gate_func", "f_gate_func","g_gate_func", "o_gate_func", "h_gate_func"]:
param_val = getattr(self, param_id)
param_val.default_variable = np.zeros(hidden_size)
param_val.defaults.variable = np.zeros(hidden_size)
param_val.variable = np.zeros(hidden_size)
param_val.default_value = np.zeros(hidden_size)
param_val.defaults.value = np.zeros(hidden_size)
param_val.value = np.zeros(hidden_size)

def _function(self,
variable=None,
context=None,
params=None,
):

x_t = variable[0]
h_prev = variable[1]
c_prev = variable[2]

# Calculate input
i_input_matrix = self._get_current_function_param("i_input_matrix", context=context)
i_hidden_matrix = self._get_current_function_param("i_hidden_matrix", context=context)
i_gate_func = self._get_current_function_param("i_gate_func", context=context)
i_t = i_gate_func(np.matmul(i_input_matrix, x_t) + np.matmul(i_hidden_matrix, h_prev))

# Calculate forget gate
f_input_matrix = self._get_current_function_param("f_input_matrix", context=context)
f_hidden_matrix = self._get_current_function_param("f_hidden_matrix", context=context)
f_gate_func = self._get_current_function_param("f_gate_func", context=context)
f_t = f_gate_func(np.matmul(f_input_matrix, x_t) + np.matmul(f_hidden_matrix, h_prev))

# Update cell state
g_input_matrix = self._get_current_function_param("g_input_matrix", context=context)
g_hidden_matrix = self._get_current_function_param("g_hidden_matrix", context=context)
g_gate_func = self._get_current_function_param("g_gate_func", context=context)
g_t = g_gate_func(np.matmul(g_input_matrix, x_t) + np.matmul(g_hidden_matrix, h_prev))
c_t = np.multiply(f_t, c_prev) + np.multiply(i_t, g_t)

# Calculate output gate
o_input_matrix = self._get_current_function_param("o_input_matrix", context=context)
o_hidden_matrix = self._get_current_function_param("o_hidden_matrix", context=context)
o_gate_func = self._get_current_function_param("o_gate_func", context=context)
o_t = o_gate_func(np.matmul(o_input_matrix, x_t) + np.matmul(o_hidden_matrix, h_prev))

# Update hidden state
h_gate_func = self._get_current_function_param("h_gate_func", context=context)
h_t = np.multiply(o_t, h_gate_func(c_t))
value = [h_t, c_t]

return value

def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
matmul = ctx.import_llvm_function("__pnl_builtin_mxm")
vecadd = ctx.import_llvm_function("__pnl_builtin_vec_add")
vechadamard = ctx.import_llvm_function("__pnl_builtin_vec_hadamard")

x_t = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(0)])
h_prev = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(1)])
c_prev = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(2)])

def _mxv(m, v):
tmp = builder.alloca(h_prev.type.pointee)
tmp_ptr = builder.gep(tmp, [ctx.int32_ty(0),
ctx.int32_ty(0)])
dim_x = len(m.type.pointee)
dim_y = len(m.type.pointee.elements[0])
m_ptr = builder.gep(m, [ctx.int32_ty(0),
ctx.int32_ty(0),
ctx.int32_ty(0)])
v_ptr = builder.gep(v, [ctx.int32_ty(0),
ctx.int32_ty(0)])

builder.call(matmul, [m_ptr,
v_ptr,
ctx.int32_ty(dim_x),
ctx.int32_ty(dim_y),
ctx.int32_ty(1),
tmp_ptr])

return tmp

def _vxv(v1, v2):
tmp = builder.alloca(h_prev.type.pointee)
tmp_ptr = builder.gep(tmp, [ctx.int32_ty(0),
ctx.int32_ty(0)])
dim_x = len(v1.type.pointee)
v1_ptr = builder.gep(v1, [ctx.int32_ty(0),
ctx.int32_ty(0)])
v2_ptr = builder.gep(v2, [ctx.int32_ty(0),
ctx.int32_ty(0)])

builder.call(vechadamard, [v1_ptr,
v2_ptr,
ctx.int32_ty(dim_x),
tmp_ptr])
return tmp

def _mac(m1, v1, m2, v2, mul_op=_mxv):
val1 = mul_op(m1, v1)
val2 = mul_op(m2, v2)
val1_ptr = builder.gep(val1, [ctx.int32_ty(0),
ctx.int32_ty(0)])
val2_ptr = builder.gep(val2, [ctx.int32_ty(0),
ctx.int32_ty(0)])
builder.call(vecadd, [val1_ptr,
val2_ptr,
ctx.int32_ty(len(m1.type.pointee)),
val1_ptr])
return val1

def _call_func(func_id, in_vec, out_vec):
param_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, func_id)
state_ptr = pnlvm.helpers.get_state_ptr(builder, self, state, func_id)

llvm_func = ctx.import_llvm_function(getattr(self, func_id), tags=tags)
builder.call(llvm_func, [param_ptr, state_ptr, in_vec, out_vec])

# Calculate input
i_input_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'i_input_matrix')
i_hidden_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'i_hidden_matrix')
i_t = _mac(i_input_matrix, x_t, i_hidden_matrix, h_prev)
_call_func("i_gate_func", i_t, i_t)

# Calculate forget gate
f_input_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'f_input_matrix')
f_hidden_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'f_hidden_matrix')
f_t = _mac(f_input_matrix, x_t, f_hidden_matrix, h_prev)
_call_func("f_gate_func", f_t, f_t)

# Update cell state
g_input_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'g_input_matrix')
g_hidden_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'g_hidden_matrix')
g_t = _mac(g_input_matrix, x_t, g_hidden_matrix, h_prev)
_call_func("g_gate_func", g_t, g_t)

c_t = _mac(f_t, c_prev, i_t, g_t, mul_op=_vxv)

# Calculate output gate
o_input_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'o_input_matrix')
o_hidden_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'o_hidden_matrix')
o_t = _mac(o_input_matrix, x_t, o_hidden_matrix, h_prev)
_call_func("o_gate_func", o_t, o_t)

# Update hidden state
h_t = builder.alloca(h_prev.type.pointee)
_call_func("h_gate_func", c_t, h_t)
h_t = _vxv(o_t, h_t)

# Writeback into value struct
builder.store(builder.load(h_t), builder.gep(arg_out, [ctx.int32_ty(0),
ctx.int32_ty(0)]))
builder.store(builder.load(c_t), builder.gep(arg_out, [ctx.int32_ty(0),
ctx.int32_ty(1)]))

return builder
# **********************************************************************************************************************
# LinearMatrix
# **********************************************************************************************************************
Expand Down
2 changes: 1 addition & 1 deletion psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6747,7 +6747,7 @@ def bfs(start):
pathways.append(p)
continue
for projection, efferent_node in [(p, p.receiver.owner) for p in curr_node.efferents]:
if (not hasattr(projection,'learnable')) or (projection.learnable is False) or efferent_node in prev:
if getattr(projection, 'learnable', False) is False or efferent_node in prev:
continue
prev[efferent_node] = projection
prev[projection] = curr_node
Expand Down
4 changes: 3 additions & 1 deletion psyneulink/core/globals/keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
'LEARNING_PATHWAY', 'LEARNING_PROJECTION', 'LEARNING_PROJECTION_PARAMS', 'LEARNING_RATE', 'LEARNING_SIGNAL',
'LEARNING_SIGNAL_SPECS', 'LEARNING_SIGNALS',
'LESS_THAN', 'LESS_THAN_OR_EQUAL', 'LINEAR', 'LINEAR_COMBINATION_FUNCTION', 'LINEAR_FUNCTION',
'LINEAR_MATRIX_FUNCTION', 'LOG_ENTRIES', 'LOGISTIC_FUNCTION', 'LOW', 'LVOC_CONTROL_MECHANISM', 'L0', 'L1',
'LINEAR_MATRIX_FUNCTION', 'LOG_ENTRIES', 'LOGISTIC_FUNCTION', 'LOW', 'LSTM_FUNCTION', 'LVOC_CONTROL_MECHANISM', 'L0', 'L1',
'MAPPING_PROJECTION', 'MAPPING_PROJECTION_PARAMS', 'MASKED_MAPPING_PROJECTION',
'MATRIX', 'MATRIX_KEYWORD_NAMES', 'MATRIX_KEYWORD_SET', 'MATRIX_KEYWORD_VALUES', 'MATRIX_KEYWORDS','MatrixKeywords',
'MAX_ABS_DIFF', 'MAX_ABS_INDICATOR', 'MAX_ONE_HOT', 'MAX_ABS_ONE_HOT', 'MAX_ABS_VAL',
Expand Down Expand Up @@ -526,6 +526,7 @@ def _is_metric(metric):
TRANSFER_MECHANISM = "TransferMechanism"
LEABRA_MECHANISM = "LeabraMechanism"
RECURRENT_TRANSFER_MECHANISM = "RecurrentTransferMechanism"
LSTM_MECHANISM = "LSTMMechanism"
CONTRASTIVE_HEBBIAN_MECHANISM = "ContrastiveHebbianMechanism"
LCA_MECHANISM = "LCAMechanism"
KOHONEN_MECHANISM = 'KohonenMechanism'
Expand Down Expand Up @@ -557,6 +558,7 @@ def _is_metric(metric):
GAUSSIAN_FUNCTION = "Gaussian Function"
GAUSSIAN_DISTORT_FUNCTION = "GaussianDistort Function"
SOFTMAX_FUNCTION = 'SoftMax Function'
LSTM_FUNCTION = 'LSTM Function'
LINEAR_MATRIX_FUNCTION = "LinearMatrix Function"
TRANSFER_WITH_COSTS_FUNCTION = "TransferWithCosts Function"

Expand Down
Loading

0 comments on commit 45c5d1d

Please sign in to comment.