onnx2versal
Loading...
Searching...
No Matches
graph_gemm.h
1#ifndef __GEMM_GRAPH_H_
2#define __GEMM_GRAPH_H_
3
4#include <assert.h>
5#include <adf.h>
6#include "gemm.h"
7#include "graph_concat.h"
8#include "graph_utils.h"
9
10
41template <template<int, int, int, int> class GEMM,
42 int M, int K, int N, int IS_RELU>
43class GemmReluGraph : public adf::graph {
44
45 private:
46 adf::kernel k[1];
47
48 public:
49 adf::port<input> pin[1];
50 adf::port<output> pout[1];
51
53 std::vector<float> weights,
54 std::vector<float> bias,
55 int repeat_cnt = 1
56 ) {
57 static_assert(M*K*4 <= TILE_BYTES);
58 static_assert(K*N*4 <= MAX_PARAM_BYTES);
59 static_assert(M*N*4 <= TILE_BYTES);
60 k[0] = adf::kernel::create_object<GEMM<M, K, N, IS_RELU>>(weights, bias);
61 adf::source(k[0]) = "gemm.cc";
62 adf::headers(k[0]) = {"gemm.h"};
63 adf::runtime<ratio>(k[0]) = 0.6;
64 adf::repetition_count(k[0]) = repeat_cnt;
65
66 adf::connect<adf::window<M*K*4>> (pin[0], k[0].in[0]);
67 adf::connect<adf::window<M*N*4>> (k[0].out[0], pout[0]);
68
69 adf::location_constraint tilePos = adf::location<adf::kernel>(k[0]);
70 adf::location<adf::parameter>(k[0].param[0]) = tilePos;
71 adf::location<adf::parameter>(k[0].param[0]) = adf::offset(0);
72 adf::location<adf::parameter>(k[0].param[1]) = tilePos;
73 adf::location<adf::parameter>(k[0].param[1]) = adf::offset((K*N*4+31)/32*32);
74 }
75
76};
77
78
89template <template<int, int, int, int> class GEMM,
90 int M, int K, int N, int IS_RELU>
91class GemmReluStreamGraph : public adf::graph {
92
93 private:
94 adf::kernel k[1];
95
96 public:
97 adf::port<input> pin[2];
98 adf::port<output> pout[1];
99
101 std::vector<float> bias
102 ) {
103 k[0] = adf::kernel::create_object<GEMM<M, K, N, IS_RELU>>(bias);
104 adf::source(k[0]) = "gemm.cc";
105 adf::headers(k[0]) = {"gemm.h"};
106 adf::runtime<ratio>(k[0]) = 0.6;
107 adf::heap_size(k[0]) = 24576; // assume KxN > MxN
108
109 adf::connect<adf::stream> (pin[0], k[0].in[0]);
110 adf::connect<adf::stream> (pin[1], k[0].in[1]);
111 adf::connect<adf::window<M*N*4>> (k[0].out[0], pout[0]);
112 adf::samples_per_iteration(k[0].in[0]) = M*K;
113 }
114
115};
116
117
131template <
132 template<int, int, int, int> class GEMM,
133 template<typename, int, int, int, int> class CONCAT,
134 int NCHUNK, int M, int K, int N, int IS_RELU>
135class GemmReluMknkChunkGraph : public adf::graph {
136
137 private:
138 static const int NCUTCHUNK = N % NCHUNK;
139
140 adf::relative_coordinate tileOffsets[8] = {
141 {.col_offset = -1, .row_offset = 0}, // left, right
142 {.col_offset = 1, .row_offset = 0},
143 {.col_offset = -1, .row_offset = 1}, // bottom row
144 {.col_offset = 0, .row_offset = 1},
145 {.col_offset = 1, .row_offset = 1},
146 {.col_offset = -1, .row_offset = -1}, // top row
147 {.col_offset = 0, .row_offset = -1},
148 {.col_offset = 1, .row_offset = -1},
149 };
150
151 public:
152 static const int CHUNK_COUNT = (N + NCHUNK - 1) / NCHUNK; // ceiling
153 adf::kernel gemms[CHUNK_COUNT];
155
156 adf::port<input> pin[1];
157 adf::port<output> pout[1];
158
160 std::vector<float> weights,
161 std::vector<float> bias,
162 int repeat_cnt = 1
163 ) {
164 static_assert(CHUNK_COUNT <= 8);
165 static_assert(M*K*4 <= TILE_BYTES);
166 static_assert(K*N*4 <= MAX_PARAM_BYTES);
167 static_assert(M*N*4 <= TILE_BYTES);
168
169 std::vector<float> wChunk;
170 std::vector<float> bChunk;
171
172 for (int i = 0; i < CHUNK_COUNT; i++) {
173 int chunkSize = (i*NCHUNK + NCHUNK > N) ? NCUTCHUNK : NCHUNK;
174 wChunk = std::vector<float>(weights.begin()+i*NCHUNK*K, weights.begin()+(i*NCHUNK+chunkSize)*K);
175 wChunk.resize(NCHUNK*K, 0);
176 bChunk = std::vector<float>(bias.begin()+i*NCHUNK, bias.begin()+i*NCHUNK+chunkSize);
177 bChunk.resize(NCHUNK, 0);
178
179 gemms[i] = adf::kernel::create_object<GEMM<M, K, NCHUNK, IS_RELU>>(wChunk, bChunk);
180 adf::source(gemms[i]) = "gemm.cc";
181 adf::headers(gemms[i]) = {"gemm.h"};
182 adf::runtime<ratio>(gemms[i]) = 0.6;
183 adf::repetition_count(gemms[i]) = repeat_cnt;
184
185 adf::location<adf::kernel>(gemms[i]) = adf::location<adf::kernel>(concat_g.k[0]) +
186 adf::relative_offset(tileOffsets[i]);
187
188 adf::location_constraint tilePos = adf::location<adf::kernel>(gemms[i]);
189 adf::location<adf::parameter>(gemms[i].param[0]) = tilePos; // weight (<= 16384B)
190 adf::location<adf::parameter>(gemms[i].param[0]) = adf::offset(0);
191 adf::location<adf::parameter>(gemms[i].param[1]) = tilePos; // bias (<= 4096B)
192 adf::location<adf::parameter>(gemms[i].param[1]) = adf::offset((NCHUNK*K*4+31)/32*32);
193 // arbitrary input/output buffer location due to interconnect design
194 }
195
196 for (int i = 0; i < CHUNK_COUNT; i++) {
197 adf::connect<adf::window<M*K*4>> (pin[0], gemms[i].in[0]);
198 adf::connect<adf::window<M*NCHUNK*4>> (gemms[i].out[0], concat_g.pin[i]);
199 }
200
201 adf::connect<adf::stream> (concat_g.pout[0], pout[0]);
202 }
203
204};
205
206
214template <
215 template<int, int, int, int> class GEMM,
216 template<typename, int, int, int, int> class CONCAT,
217 int NCHUNK, int M, int K, int N, int IS_RELU>
218class GemmReluMkknChunkGraph : public adf::graph {
219
220 private:
221 static const int CHUNK_COUNT = (N + NCHUNK - 1) / NCHUNK; // ceiling
222 adf::kernel gemms[CHUNK_COUNT];
223 ConcatStreamGraph<CONCAT, float_t, CHUNK_COUNT, M, NCHUNK, N> concat_g;
224
225 public:
226 adf::port<input> pin[1];
227 adf::port<output> pout[1];
228
230 std::vector<float> weights, // KxN_RND
231 std::vector<float> bias, // N
232 int repeat_cnt = 1
233 ) {
234 static_assert(CHUNK_COUNT <= 8);
235 static_assert(M*K*4 <= TILE_BYTES);
236 static_assert(K*NCHUNK*4 <= MAX_PARAM_BYTES);
237 static_assert(M*NCHUNK*4 <= TILE_BYTES);
238
239 std::vector<float> bChunk;
240
241 for (int i = 0; i < CHUNK_COUNT; i++) {
242
243 // build wchunk
244 std::vector<float> wChunk;
245 wChunk.reserve(NCHUNK*K);
246 for (int j = 0; j < K*N; j+=N) {
247 wChunk.insert(wChunk.end(), weights.begin()+j+i*NCHUNK, weights.begin()+j+i*NCHUNK+NCHUNK);
248 }
249
250 // build bChunk
251 bChunk = std::vector<float>(bias.begin()+i*NCHUNK, bias.begin()+i*NCHUNK+NCHUNK);
252 bChunk.resize(NCHUNK, 0);
253
254 gemms[i] = adf::kernel::create_object<GEMM<M, K, NCHUNK, IS_RELU>>(wChunk, bChunk);
255 adf::source(gemms[i]) = "gemm.cc";
256 adf::headers(gemms[i]) = {"gemm.h"};
257 adf::runtime<ratio>(gemms[i]) = 0.6;
258 adf::repetition_count(gemms[i]) = repeat_cnt;
259
260 adf::location_constraint tilePos = adf::location<adf::kernel>(gemms[i]);
261 adf::location<adf::parameter>(gemms[i].param[0]) = tilePos;
262 adf::location<adf::parameter>(gemms[i].param[0]) = adf::offset(0);
263 adf::location<adf::parameter>(gemms[i].param[1]) = tilePos;
264 adf::location<adf::parameter>(gemms[i].param[1]) = adf::offset((NCHUNK*K*4+31)/32*32);
265 // arbitrary input/output buffer location due to interconnect design
266 }
267
268 for (int i = 0; i < CHUNK_COUNT; i++) {
269 adf::connect<adf::window<M*K*4>> (pin[0], gemms[i].in[0]);
270 adf::connect<adf::window<M*NCHUNK*4>> (gemms[i].out[0], concat_g.pin[i]);
271 }
272 adf::connect<adf::stream> (concat_g.pout[0], pout[0]);
273 }
274
275};
276
277
281template <
282 template<int, int, int, int> class GEMM,
283 template<typename, int, int, int, int> class CONCAT,
284 int NCHUNK, int M, int K, int N, int IS_RELU>
285class GemmReluMkknChunkNStreamGraph : public adf::graph {
286
287 private:
288 static const int CHUNK_COUNT = (N + NCHUNK - 1) / NCHUNK; // ceiling
289 adf::kernel gemms[CHUNK_COUNT];
290 ConcatStreamGraph<CONCAT, float_t, CHUNK_COUNT, M, NCHUNK, N> concat_graph;
291
292 public:
293 adf::port<input> pin[1 + CHUNK_COUNT];
294 adf::port<output> pout[1];
295
297 std::vector<float> bias
298 ) {
299 static_assert(CHUNK_COUNT <= 8);
300 static_assert(M*NCHUNK*4 <= TILE_BYTES);
301
302 std::vector<float> bChunk;
303
304 for (int i = 0; i < CHUNK_COUNT; i++) {
305 // build bChunk
306 bChunk = std::vector<float>(bias.begin()+i*NCHUNK, bias.begin()+i*NCHUNK+NCHUNK);
307 bChunk.resize(NCHUNK, 0);
308
309 gemms[i] = adf::kernel::create_object<GEMM<M, K, NCHUNK, IS_RELU>>(bChunk);
310 adf::source(gemms[i]) = "gemm.cc";
311 adf::headers(gemms[i]) = {"gemm.h"};
312 adf::runtime<ratio>(gemms[i]) = 0.6;
313 adf::heap_size(gemms[i]) = 24576;
314
315 adf::connect<adf::stream> (pin[0], gemms[i].in[0]);
316 adf::connect<adf::stream> (pin[1+i], gemms[i].in[1]);
317 adf::connect<adf::window<M*NCHUNK*4>> (gemms[i].out[0], concat_graph.pin[i]);
318
319 adf::samples_per_iteration(gemms[i].in[0]) = M*K;
320 adf::samples_per_iteration(gemms[i].in[1]) = K*NCHUNK;
321
322 adf::location<adf::parameter>(gemms[i].param[0]) = adf::location<adf::kernel>(gemms[i]);
323 adf::location<adf::parameter>(gemms[i].param[0]) = adf::offset(0);
324 // arbitrary input/output buffer location due to interconnect design
325 }
326
327 adf::connect<adf::stream> (concat_graph.pout[0], pout[0]);
328
329 for (int i = 0; i < concat_graph.k1.size(); i++) {
330 adf::location<adf::kernel>(concat_graph.k1[i]) =
331 adf::location<adf::kernel>(gemms[i*2]) + adf::relative_offset({.col_offset=0, .row_offset=1});
332
333 adf::location_constraint cTilePos = adf::location<adf::kernel>(concat_graph.k1[i]);
334 adf::location<adf::stack>(gemms[i*2]) = cTilePos;
335 adf::location<adf::stack>(concat_graph.k1[i]) = cTilePos;
336 }
337 }
338
339};
343#endif // __GEMM_GRAPH_H_
Graph wrapper for arbitrary concat kernel implementation and lanes.
Definition graph_concat.h:37
Single-instance graph that stores weights and biases. Maximum sizes are 16384 bytes (weights) and 4096 bytes (biases).
Definition graph_gemm.h:43
Multi-instance graph for MxK times KxN that stores weights and biases. Requires KxN_RND weight,...
Definition graph_gemm.h:218
Multi-instance graph for MxK times KxN that stores biases.
Definition graph_gemm.h:285
Multi-instance graph for MxK times NxK that stores weights and biases. Requires NxK weight,...
Definition graph_gemm.h:135
Single-instance graph that streams weights (the bias is passed at construction); significantly slower.
Definition graph_gemm.h:91