forked from bd4/sycl-test
-
Notifications
You must be signed in to change notification settings - Fork 0
/
expr.hpp
156 lines (118 loc) · 3.11 KB
/
expr.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
/// Binary function object that adds two values of type T.
template <typename T>
struct plus {
  /// @param lhs  First operand.
  /// @param rhs  Second operand.
  /// @return lhs + rhs, converted to T.
  T operator()(T lhs, T rhs) const { return lhs + rhs; }
};
/// Binary function object that subtracts its second operand from its first.
template <typename T>
struct minus {
  /// @param left   Minuend.
  /// @param right  Subtrahend.
  /// @return left - right, converted to T.
  auto operator()(T left, T right) const -> T {
    return left - right;
  }
};
/// Binary function object that multiplies two values of type T.
template <typename T>
struct times {
  /// @param a  First factor.
  /// @param b  Second factor.
  /// @return a * b, converted to T.
  T operator()(T a, T b) const { return a * b; }
};
/// CRTP base class for expression types.
///
/// A concrete expression `D` derives from `expression<D>`. The `derived()`
/// accessors static_cast `*this` back to `D`, so no virtual dispatch is
/// involved. The default constructor is protected, so the base can only be
/// constructed as part of a derived object.
template <typename DerivedExpression>
class expression {
public:
  using derived_type = DerivedExpression;

  /// @return `*this` as a const reference to the derived type.
  const derived_type& derived() const& {
    return static_cast<const derived_type&>(*this);
  }

  /// @return `*this` as a mutable reference to the derived type.
  derived_type& derived() & {
    return static_cast<derived_type&>(*this);
  }

  /// Rvalue overload.
  /// @return A derived object constructed from `*this` cast to an rvalue
  ///         reference, so its move constructor is used when available.
  derived_type derived() && {
    return static_cast<derived_type&&>(*this);
  }

protected:
  expression() = default;
};
/// Unary function object that ignores its argument and always returns the
/// value it was constructed with.
template <typename Result, typename Arg>
class constfn {
public:
  using result_type = Result;
  using arg_type = Arg;

  /// @param value  The value returned by every call.
  /// Intentionally not `explicit`: implicit conversion from result_type is
  /// part of the existing interface.
  constfn(result_type value) : value_(value) {}

  /// @return The stored value; the argument is accepted but not used.
  result_type operator()(arg_type /*ignored*/) const { return value_; }

private:
  result_type value_;
};
/// Unary affine function object: f(x) = slope * x + intercept.
template <typename Result, typename Arg>
class linearfn {
public:
  using result_type = Result;
  using arg_type = Arg;

  /// @param slope      Factor multiplied with the argument.
  /// @param intercept  Offset added after the multiplication.
  linearfn(result_type slope, result_type intercept)
    : slope_(slope), intercept_(intercept) {}

  /// @return slope * x + intercept, converted to result_type.
  result_type operator()(arg_type x) const {
    // Kept as a single expression so floating-point contraction/rounding
    // behaves exactly as before.
    return slope_ * x + intercept_;
  }

private:
  result_type slope_;
  result_type intercept_;
};
/// Nullary function object that binds a binary callable to two arguments and
/// invokes it when called.
///
/// @tparam F  Callable type, invoked as `f(arg0, arg1)` through a const object.
/// @tparam T  Type of both bound arguments and of the call's result.
template <typename F, typename T>
class binaryclosure {
public:
  /// Stores the callable and both arguments.
  ///
  /// Parameters are taken by value rather than as `F&&` / `T&&`. On class
  /// template parameters those spellings are plain rvalue references, not
  /// forwarding references, so they reject lvalue arguments. With by-value
  /// parameters, callers may pass lvalues as well as temporaries.
  ///
  /// `std::forward` moves non-reference types into the members. It passes
  /// lvalue-reference instantiations (e.g. T = U&) through unchanged, so
  /// the members still bind to the caller's objects.
  binaryclosure(F f, T arg0, T arg1)
    : f_(std::forward<F>(f)),
      arg0_(std::forward<T>(arg0)),
      arg1_(std::forward<T>(arg1))
  {}

  /// @return f(arg0, arg1), converted to T.
  T operator()() const {
    return f_(arg0_, arg1_);
  }

private:
  F f_;
  T arg0_;
  T arg1_;
};
template <typename F, typename E1, typename E2>
class binaryexpr;

/// Expression node that combines two sub-expressions with a binary function.
///
/// Calling the node with `arg` calls both sub-expressions with `arg` and
/// passes their results to the stored function: `f(e1(arg), e2(arg))`.
/// It derives from the CRTP base `expression<binaryexpr<F, E1, E2>>`.
template <typename F, typename E1, typename E2>
class binaryexpr : public expression<binaryexpr<F, E1, E2>> {
public:
  using self_type = binaryexpr<F, E1, E2>;
  using base_type = expression<self_type>;
  using function_type = F;
  using expression_type_1 = E1;
  using expression_type_2 = E2;

  /// Stores the function and both sub-expressions.
  /// Each argument is forwarded into the corresponding member.
  binaryexpr(F&& fn, E1&& lhs, E2&& rhs)
    : fn_(std::forward<F>(fn)),
      lhs_(std::forward<E1>(lhs)),
      rhs_(std::forward<E2>(rhs))
  {}

  /// @return fn(lhs(arg), rhs(arg)); the return type is deduced.
  template <typename Arg>
  auto operator()(Arg arg) const {
    // Both sub-calls stay inside one call expression, so their relative
    // evaluation order is exactly as unspecified as before.
    return fn_(lhs_(arg), rhs_(arg));
  }

private:
  F fn_;
  E1 lhs_;
  E2 rhs_;
};
template <typename F, typename E1, typename E2>
auto mkexpr(F&& f, E1&& e1, E2&& e2)
{
return binaryexpr<F, E1, E2>(std::forward<F>(f),
std::forward<E1>(e1),
std::forward<E2>(e2));
}
// Forward declaration only; the class is not defined at this point.
// NOTE(review): Expr is presumably one of the expression types above
// (e.g. a binaryexpr) — confirm against the definition and its users.
template <typename Expr>
class SyclKernel;