-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph.h
151 lines (111 loc) · 3.17 KB
/
graph.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#ifndef _graph_h_
#define _graph_h_
#include <list>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ndarray.h"
// Forward declarations: the graph classes refer to one another, and the
// kernel classes are only handled through pointers in this header.
class Node;
class Graph;
class Op;
class Variable;
class Kernel;
class ValueKernel;

// Shared-ownership handles used throughout the graph API.
using NodeRef     = std::shared_ptr<Node>;
using GraphRef    = std::shared_ptr<Graph>;
using OpRef       = std::shared_ptr<Op>;
using VariableRef = std::shared_ptr<Variable>;
using KernelRef   = std::shared_ptr<Kernel>;
/// Abstract base for every vertex of the computation graph.
///
/// Node derives from std::enable_shared_from_this, so nodes are meant to be
/// owned by std::shared_ptr (NodeRef). ref() presumably wraps
/// shared_from_this(), which fails on an object not yet owned by a
/// shared_ptr -- TODO confirm against the implementation.
///
/// NOTE(review): each node keeps a strong GraphRef (graph_), while Graph keeps
/// strong NodeRefs in adj_, top_order_ and gradients_. If Graph::add() retains
/// the node (not visible here), Node <-> Graph form a std::shared_ptr cycle
/// that is never released. Holding the graph as std::weak_ptr<Graph> would
/// break it, but needs matching changes in the implementation file.
class Node : public std::enable_shared_from_this<Node> {
public:
    virtual ~Node();

    /// Shared handle to this node (see the ownership note above).
    NodeRef ref();

    /// Graph this node belongs to (presumably returns graph_).
    GraphRef graph();

    /// Kernel that computes this node's value; supplied by subclasses.
    virtual KernelRef kernel() const = 0;

    /// Whether gradients should be tracked for this node.
    virtual bool requires_grad() const = 0;

    /// Current value of the node (presumably obtained from its kernel --
    /// defined elsewhere).
    const NDArray& get_value() const;

    /// Human-readable description of the node.
    virtual std::string str() const;

protected:
    /// Binds the node to graph `g`. Marked explicit per the single-argument
    /// constructor convention.
    explicit Node(GraphRef g);

    // Nodes have identity (they are graph vertices shared through NodeRef),
    // so neither copying nor copy-assigning makes sense. Deleting only the
    // copy constructor would leave a public implicit copy assignment, which
    // permits slicing assignment through Node&; delete both, as Op and
    // Variable already do.
    Node(const Node&) = delete;
    Node& operator=(const Node&) = delete;

    GraphRef graph_;
};
/// Graph node produced by an operation (add, mul, mm, softmax, relu, ...).
///
/// Construction uses a passkey: the constructor is public (presumably so
/// std::make_shared can call it), but its first argument is an Op::protected_
/// token that only Op and its subclasses can name.
///
/// The builder methods each return a new OpRef. They are presumably
/// implemented through the private op<K>() helper, which would create a node
/// with a kernel of type K -- TODO confirm in the implementation file.
class Op : public Node {
protected:
    struct protected_;

public:
    explicit Op(const protected_&, GraphRef graph);
    virtual ~Op();

    /// Shared handle to this op.
    OpRef ref();

    /// Implicit conversion so an Op can be passed where a NodeRef is expected.
    operator NodeRef();

    // Builders: each returns a new op combining this node with its
    // argument(s), if any.
    OpRef add(NodeRef other);
    OpRef sub(NodeRef other);
    OpRef mul(NodeRef other);
    OpRef dot(NodeRef other);
    OpRef mm(NodeRef other);
    OpRef bmm(NodeRef other);
    OpRef softmax();
    OpRef softmax_ce(NodeRef other);
    OpRef relu();

    virtual std::string str() const override;

protected:
    // NOTE(review): these overrides are protected while Node declares them
    // public, so they are callable through a NodeRef but not through an OpRef.
    virtual KernelRef kernel() const override {
        return kernel_;
    }

    /// A plain op never requests gradient tracking itself.
    virtual bool requires_grad() const override {
        return false;
    }

    /// Passkey token. Outside code cannot name this protected type, and the
    /// explicit constructor stops a token from being conjured by implicit
    /// conversion (e.g. from `0` or `{0}`).
    struct protected_ {
        explicit protected_(int) { }
    };

private:
    Op(const Op&) = delete;
    Op& operator=(const Op&) = delete;

    /// Installs the kernel that computes this op's value.
    void set_kernel(KernelRef kernel) {
        // By-value sink parameter: move it in rather than paying a second
        // atomic reference-count increment for a copy.
        kernel_ = std::move(kernel);
    }

    template <class K, class... Args>
    OpRef op(const std::string& name, Args... args);

    KernelRef kernel_;
};
/// Graph node whose value is assigned from outside via set_value() rather
/// than produced by an operation. It is presumably used for graph inputs, and
/// for trainable parameters when requires_grad is true -- TODO confirm.
class Variable : public Op {
public:
    /// Creates a Variable of the given shape in `graph`.
    /// @param graph          graph the variable belongs to
    /// @param shape          shape of the variable's value
    /// @param requires_grad  whether gradients are tracked (default false)
    static VariableRef create(GraphRef graph,
                              const Shape& shape, bool requires_grad=false);

    /// Public presumably for std::make_shared; callers outside the Op
    /// hierarchy cannot obtain an Op::protected_ token, so they use create().
    Variable(const Op::protected_&, GraphRef graph,
             const Shape& shape, bool requires_grad);
    virtual ~Variable();

    /// Shared handle to this variable.
    VariableRef ref();

    // Implicit up-conversions so a Variable can be used where an OpRef or a
    // NodeRef is expected.
    operator OpRef();
    operator NodeRef();

    /// Shape given at construction.
    const Shape& shape() const;

    /// Assigns this variable's value (presumably stored in its ValueKernel).
    void set_value(const NDArray& value);

    virtual std::string str() const override;

protected:
    virtual bool requires_grad() const override;
    virtual KernelRef kernel() const override;

private:
    Variable() = delete;
    Variable(const Variable&) = delete;
    Variable& operator=(const Variable&) = delete;

    // Distinct from Op's private kernel_ (inaccessible here); this member
    // keeps the concrete ValueKernel type.
    std::shared_ptr<ValueKernel> kernel_;
    Shape shape_;
    // Default-initialized so the flag is never indeterminate, even if a
    // constructor fails to set it.
    bool requires_grad_{false};
};
class Graph {
public:
void add(NodeRef node);
void forward();
void backward(NodeRef node);
std::vector<VariableRef> get_variables() const;
NDArray gradient(const NodeRef& node) const;
protected:
std::unordered_map<NodeRef, std::list<NodeRef>> adj_;
std::vector<NodeRef> top_order_;
std::unordered_map<NodeRef, NDArray> gradients_;
};
// Stream insertion for the graph handle types (definitions live elsewhere).
// Each handle type gets its own exact overload. Without one, printing an
// OpRef or VariableRef would select the standard library's
// operator<<(ostream&, const shared_ptr<T>&), an exact template match that
// prints the pointer address, over the NodeRef overload, which would need a
// conversion.
std::ostream& operator<<(std::ostream& out, const NodeRef& n);
std::ostream& operator<<(std::ostream& out, const OpRef& n);
std::ostream& operator<<(std::ostream& out, const VariableRef& n);
#endif // _graph_h_