# Activation Functions
import jax.numpy as jnp
from jax.nn import softmax as jnp_softmax

__all__ = ['Tanh', 'Softmax', 'ReLu', 'Sigmoid']

class Activation:
    # Base layer for activation functions: forward applies f(x) and backward
    # applies the chain rule dL/dx = dL/dy * f'(x) elementwise.
    def __init__(self, activation, activation_derivative):
        self.activation = activation
        self.activation_derivative = activation_derivative
        self.tag = "Activation"
        self.input_shape = None
        self.output_shape = None

    def forward(self, input):
        self.input = input
        return self.activation(self.input)

    def backward(self, output_gradient, learning_rate):
        # learning_rate is unused; activation layers have no trainable parameters.
        return jnp.multiply(output_gradient, self.activation_derivative(self.input))

class Tanh(Activation):
    def __init__(self):
        def tanh(X):
            return jnp.tanh(X)

        def tanh_derivative(X):
            return 1 - jnp.tanh(X) ** 2

        super().__init__(tanh, tanh_derivative)

class Softmax(Activation):
    def __init__(self):
        def softmax(X, axis=0):
            # jax.nn.softmax already subtracts the per-axis max, so no extra
            # numerical stabilisation is needed here.
            return jnp_softmax(X, axis)

        def softmax_derivative(X):
            # Full softmax Jacobian J[i, j] = s_i * (delta_ij - s_j),
            # assuming X is a column vector of shape (n, 1).
            size = jnp.size(X)
            self.output = jnp_softmax(X, 0)
            tiled = jnp.tile(self.output, size)
            return tiled * (jnp.identity(size) - jnp.transpose(tiled))

        super().__init__(softmax, softmax_derivative)

    def backward(self, output_gradient, learning_rate):
        # Softmax is not elementwise, so its Jacobian is applied to the
        # incoming gradient rather than multiplied elementwise with it.
        return jnp.dot(self.activation_derivative(self.input), output_gradient)

class Sigmoid(Activation):
    def __init__(self):
        def sigmoid(x):
            return 1 / (1 + jnp.exp(-x))

        def sigmoid_prime(x):
            s = sigmoid(x)
            return s * (1 - s)

        super().__init__(sigmoid, sigmoid_prime)

class ReLu(Activation):
    def __init__(self):
        def relu(X):
            return jnp.maximum(0, X)

        def relu_derivative(X):
            # Use a tiny epsilon instead of an exact zero gradient for
            # negative inputs so those units are not completely dead.
            epsilon = 1e-10
            return jnp.where(X >= 0, 1.0, epsilon)

        super().__init__(relu, relu_derivative)
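
# Minimal usage sketch: run a column vector through each activation's
# forward and backward pass. The (3, 1) input shape, the all-ones upstream
# gradient, and the learning_rate value are illustrative assumptions only.
if __name__ == "__main__":
    x = jnp.array([[1.0], [-2.0], [0.5]])
    upstream = jnp.ones_like(x)
    for layer in (Tanh(), Sigmoid(), ReLu(), Softmax()):
        y = layer.forward(x)
        grad = layer.backward(upstream, learning_rate=0.1)
        print(layer.__class__.__name__, y.ravel(), grad.ravel())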