// Fragment of a graph that wraps a single GEMM kernel with window-based I/O.
// The enclosing class/constructor lines and the declaration of `k` are not
// part of this excerpt.
// One graph-level input port and one graph-level output port.
49 adf::port<input> pin[1];
50 adf::port<output> pout[1];
// Constructor parameters (signature only partly visible); both vectors are
// passed unchanged to the GEMM kernel object created below.
53 std::vector<float> weights,
54 std::vector<float> bias,
// Compile-time size limits on the input (M*K), weight (K*N) and output (M*N)
// buffers. The factor 4 is presumably sizeof(float) — TODO confirm.
57 static_assert(M*K*4 <= TILE_BYTES);
58 static_assert(K*N*4 <= MAX_PARAM_BYTES);
59 static_assert(M*N*4 <= TILE_BYTES);
// Instantiate the kernel class with the weights and bias as constructor arguments.
60 k[0] = adf::kernel::create_object<GEMM<M, K, N, IS_RELU>>(weights, bias);
61 adf::source(k[0]) =
"gemm.cc";
62 adf::headers(k[0]) = {
"gemm.h"};
63 adf::runtime<ratio>(k[0]) = 0.6;
// repeat_cnt is defined outside this excerpt (presumably a constructor argument).
64 adf::repetition_count(k[0]) = repeat_cnt;
// Window connections sized M*K*4 on the input side and M*N*4 on the output side.
66 adf::connect<adf::window<M*K*4>> (pin[0], k[0].in[0]);
67 adf::connect<adf::window<M*N*4>> (k[0].out[0], pout[0]);
// Constrain both kernel parameters to the kernel's own location:
// param[0] at offset 0, param[1] at K*N*4 rounded up to a multiple of 32.
// Presumably param[0] is the weights and param[1] the bias — TODO confirm
// against the GEMM kernel declaration.
69 adf::location_constraint tilePos = adf::location<adf::kernel>(k[0]);
70 adf::location<adf::parameter>(k[0].param[0]) = tilePos;
71 adf::location<adf::parameter>(k[0].param[0]) = adf::offset(0);
72 adf::location<adf::parameter>(k[0].param[1]) = tilePos;
73 adf::location<adf::parameter>(k[0].param[1]) = adf::offset((K*N*4+31)/32*32);
// Fragment of a graph that wraps a single GEMM kernel fed by two streams.
// The enclosing class/constructor lines are not part of this excerpt.
// Two graph-level inputs (connected to kernel inputs 0 and 1 below) and one output.
97 adf::port<input> pin[2];
98 adf::port<output> pout[1];
// Only the bias is a constructor parameter here; the kernel's second operand
// arrives through pin[1] (presumably the weights — TODO confirm).
101 std::vector<float> bias
103 k[0] = adf::kernel::create_object<GEMM<M, K, N, IS_RELU>>(bias);
104 adf::source(k[0]) =
"gemm.cc";
105 adf::headers(k[0]) = {
"gemm.h"};
106 adf::runtime<ratio>(k[0]) = 0.6;
// Fixed heap size for the kernel; the rationale for 24576 is not visible here.
107 adf::heap_size(k[0]) = 24576;
// Both inputs are stream connections; the output is a window of M*N*4.
109 adf::connect<adf::stream> (pin[0], k[0].in[0]);
110 adf::connect<adf::stream> (pin[1], k[0].in[1]);
111 adf::connect<adf::window<M*N*4>> (k[0].out[0], pout[0]);
// Input 0 carries M*K samples per iteration.
// NOTE(review): no matching setting for k[0].in[1] is visible in this excerpt — confirm
// whether one exists on a line not shown.
112 adf::samples_per_iteration(k[0].in[0]) = M*K;
// Fragment of a graph that splits the N dimension across several GEMM kernels
// of width NCHUNK and merges their outputs through `concat_g` (declared outside
// this excerpt). The enclosing class/constructor lines are not shown.
// Width of the final partial chunk; 0 when N is a multiple of NCHUNK.
138 static const int NCUTCHUNK = N % NCHUNK;
// Eight relative tile positions: every neighbour of (0, 0) except (0, 0) itself.
// Entry i is used below to place gemms[i] relative to concat_g.k[0].
140 adf::relative_coordinate tileOffsets[8] = {
141 {.col_offset = -1, .row_offset = 0},
142 {.col_offset = 1, .row_offset = 0},
143 {.col_offset = -1, .row_offset = 1},
144 {.col_offset = 0, .row_offset = 1},
145 {.col_offset = 1, .row_offset = 1},
146 {.col_offset = -1, .row_offset = -1},
147 {.col_offset = 0, .row_offset = -1},
148 {.col_offset = 1, .row_offset = -1},
// Number of chunks: N / NCHUNK rounded up.
152 static const int CHUNK_COUNT = (N + NCHUNK - 1) / NCHUNK;
153 adf::kernel gemms[CHUNK_COUNT];
156 adf::port<input> pin[1];
157 adf::port<output> pout[1];
// Constructor parameters (signature only partly visible): full weights and bias,
// sliced per chunk below.
160 std::vector<float> weights,
161 std::vector<float> bias,
// CHUNK_COUNT may not exceed the 8 entries of tileOffsets, which is indexed by chunk.
164 static_assert(CHUNK_COUNT <= 8);
165 static_assert(M*K*4 <= TILE_BYTES);
// NOTE(review): the next two limits use the full N, although each kernel below
// receives NCHUNK*K weights and emits an M*NCHUNK*4 window — confirm whether
// NCHUNK was intended.
166 static_assert(K*N*4 <= MAX_PARAM_BYTES);
167 static_assert(M*N*4 <= TILE_BYTES);
169 std::vector<float> wChunk;
170 std::vector<float> bChunk;
// Create, configure and place one GEMM kernel per chunk.
172 for (
int i = 0; i < CHUNK_COUNT; i++) {
// The last chunk is shortened to NCUTCHUNK when it would run past N.
173 int chunkSize = (i*NCHUNK + NCHUNK > N) ? NCUTCHUNK : NCHUNK;
// Copy chunkSize*K consecutive weights starting at i*NCHUNK*K, then zero-pad
// to NCHUNK*K. Presumably weights are stored N x K row-major — TODO confirm.
174 wChunk = std::vector<float>(weights.begin()+i*NCHUNK*K, weights.begin()+(i*NCHUNK+chunkSize)*K);
175 wChunk.resize(NCHUNK*K, 0);
// Same slicing and zero-padding for the bias.
176 bChunk = std::vector<float>(bias.begin()+i*NCHUNK, bias.begin()+i*NCHUNK+chunkSize);
177 bChunk.resize(NCHUNK, 0);
179 gemms[i] = adf::kernel::create_object<GEMM<M, K, NCHUNK, IS_RELU>>(wChunk, bChunk);
180 adf::source(gemms[i]) =
"gemm.cc";
181 adf::headers(gemms[i]) = {
"gemm.h"};
182 adf::runtime<ratio>(gemms[i]) = 0.6;
// repeat_cnt is defined outside this excerpt.
183 adf::repetition_count(gemms[i]) = repeat_cnt;
// Place this kernel at the i-th relative offset from concat_g.k[0].
185 adf::location<adf::kernel>(gemms[i]) = adf::location<adf::kernel>(concat_g.k[0]) +
186 adf::relative_offset(tileOffsets[i]);
// Constrain both parameters to the kernel's location: param[0] at offset 0,
// param[1] at NCHUNK*K*4 rounded up to a multiple of 32.
188 adf::location_constraint tilePos = adf::location<adf::kernel>(gemms[i]);
189 adf::location<adf::parameter>(gemms[i].param[0]) = tilePos;
190 adf::location<adf::parameter>(gemms[i].param[0]) = adf::offset(0);
191 adf::location<adf::parameter>(gemms[i].param[1]) = tilePos;
192 adf::location<adf::parameter>(gemms[i].param[1]) = adf::offset((NCHUNK*K*4+31)/32*32);
// Wire every kernel to the same graph input and to its own concat input.
196 for (
int i = 0; i < CHUNK_COUNT; i++) {
197 adf::connect<adf::window<M*K*4>> (pin[0], gemms[i].in[0]);
198 adf::connect<adf::window<M*NCHUNK*4>> (gemms[i].out[0], concat_g.pin[i]);
// The concatenated result leaves the graph as a stream.
201 adf::connect<adf::stream> (concat_g.pout[0], pout[0]);
// Fragment of a graph that splits the N dimension across several GEMM kernels
// of width NCHUNK, gathering each kernel's weights with a stride of N.
// The enclosing class/constructor lines are not part of this excerpt.
// Number of chunks: N / NCHUNK rounded up.
221 static const int CHUNK_COUNT = (N + NCHUNK - 1) / NCHUNK;
222 adf::kernel gemms[CHUNK_COUNT];
// Sub-graph that merges the CHUNK_COUNT kernel outputs into one stream.
223 ConcatStreamGraph<CONCAT, float_t, CHUNK_COUNT, M, NCHUNK, N> concat_g;
226 adf::port<input> pin[1];
227 adf::port<output> pout[1];
// Constructor parameters (signature only partly visible): full weights and bias,
// sliced per chunk below.
230 std::vector<float> weights,
231 std::vector<float> bias,
// Compile-time limits, expressed per chunk (NCHUNK) for the weight and output sizes.
// The factor 4 is presumably sizeof(float) — TODO confirm.
234 static_assert(CHUNK_COUNT <= 8);
235 static_assert(M*K*4 <= TILE_BYTES);
236 static_assert(K*NCHUNK*4 <= MAX_PARAM_BYTES);
237 static_assert(M*NCHUNK*4 <= TILE_BYTES);
239 std::vector<float> bChunk;
// Create and configure one GEMM kernel per chunk.
241 for (
int i = 0; i < CHUNK_COUNT; i++) {
244 std::vector<float> wChunk;
245 wChunk.reserve(NCHUNK*K);
// Step through weights in strides of N and take NCHUNK consecutive values at
// column offset i*NCHUNK from each stride. Presumably weights are stored
// K x N row-major — TODO confirm.
// NOTE(review): the range always spans NCHUNK elements. If N is not a multiple
// of NCHUNK, the last chunk reads into the following stride and, on the final
// stride, past weights.end(). No guard is visible in this excerpt — confirm
// that N % NCHUNK == 0 is enforced elsewhere.
246 for (
int j = 0; j < K*N; j+=N) {
247 wChunk.insert(wChunk.end(), weights.begin()+j+i*NCHUNK, weights.begin()+j+i*NCHUNK+NCHUNK);
// NOTE(review): this range also always spans NCHUNK elements, so the resize
// below never changes the size, and the last chunk may read past bias.end()
// when N % NCHUNK != 0 — confirm.
251 bChunk = std::vector<float>(bias.begin()+i*NCHUNK, bias.begin()+i*NCHUNK+NCHUNK);
252 bChunk.resize(NCHUNK, 0);
254 gemms[i] = adf::kernel::create_object<GEMM<M, K, NCHUNK, IS_RELU>>(wChunk, bChunk);
255 adf::source(gemms[i]) =
"gemm.cc";
256 adf::headers(gemms[i]) = {
"gemm.h"};
257 adf::runtime<ratio>(gemms[i]) = 0.6;
// repeat_cnt is defined outside this excerpt.
258 adf::repetition_count(gemms[i]) = repeat_cnt;
// Constrain both parameters to the kernel's location: param[0] at offset 0,
// param[1] at NCHUNK*K*4 rounded up to a multiple of 32.
260 adf::location_constraint tilePos = adf::location<adf::kernel>(gemms[i]);
261 adf::location<adf::parameter>(gemms[i].param[0]) = tilePos;
262 adf::location<adf::parameter>(gemms[i].param[0]) = adf::offset(0);
263 adf::location<adf::parameter>(gemms[i].param[1]) = tilePos;
264 adf::location<adf::parameter>(gemms[i].param[1]) = adf::offset((NCHUNK*K*4+31)/32*32);
// Wire every kernel to the same graph input and to its own concat input.
268 for (
int i = 0; i < CHUNK_COUNT; i++) {
269 adf::connect<adf::window<M*K*4>> (pin[0], gemms[i].in[0]);
270 adf::connect<adf::window<M*NCHUNK*4>> (gemms[i].out[0], concat_g.pin[i]);
// The concatenated result leaves the graph as a stream.
272 adf::connect<adf::stream> (concat_g.pout[0], pout[0]);
// Fragment of a graph that splits the N dimension across several stream-fed
// GEMM kernels of width NCHUNK and merges their outputs through concat_graph.
// The enclosing class/constructor lines are not part of this excerpt.
// Number of chunks: N / NCHUNK rounded up.
288 static const int CHUNK_COUNT = (N + NCHUNK - 1) / NCHUNK;
289 adf::kernel gemms[CHUNK_COUNT];
290 ConcatStreamGraph<CONCAT, float_t, CHUNK_COUNT, M, NCHUNK, N> concat_graph;
// pin[0] feeds input 0 of every kernel; pin[1+i] feeds input 1 of kernel i.
293 adf::port<input> pin[1 + CHUNK_COUNT];
294 adf::port<output> pout[1];
// Only the bias is a constructor parameter; it is sliced per chunk below.
297 std::vector<float> bias
299 static_assert(CHUNK_COUNT <= 8);
300 static_assert(M*NCHUNK*4 <= TILE_BYTES);
302 std::vector<float> bChunk;
// Create, configure and connect one GEMM kernel per chunk.
304 for (
int i = 0; i < CHUNK_COUNT; i++) {
// NOTE(review): this range always spans NCHUNK elements, so the resize below
// never changes the size, and the last chunk may read past bias.end() when
// N % NCHUNK != 0 — confirm that divisibility is enforced elsewhere.
306 bChunk = std::vector<float>(bias.begin()+i*NCHUNK, bias.begin()+i*NCHUNK+NCHUNK);
307 bChunk.resize(NCHUNK, 0);
309 gemms[i] = adf::kernel::create_object<GEMM<M, K, NCHUNK, IS_RELU>>(bChunk);
310 adf::source(gemms[i]) =
"gemm.cc";
311 adf::headers(gemms[i]) = {
"gemm.h"};
312 adf::runtime<ratio>(gemms[i]) = 0.6;
// Fixed heap size for the kernel; the rationale for 24576 is not visible here.
313 adf::heap_size(gemms[i]) = 24576;
// Stream inputs: the shared pin[0] and the per-kernel pin[1+i]; window output to the concat.
315 adf::connect<adf::stream> (pin[0], gemms[i].in[0]);
316 adf::connect<adf::stream> (pin[1+i], gemms[i].in[1]);
317 adf::connect<adf::window<M*NCHUNK*4>> (gemms[i].out[0], concat_graph.pin[i]);
// Samples consumed per iteration: M*K on input 0 and K*NCHUNK on input 1.
319 adf::samples_per_iteration(gemms[i].in[0]) = M*K;
320 adf::samples_per_iteration(gemms[i].in[1]) = K*NCHUNK;
// Constrain the single kernel parameter to the kernel's location at offset 0.
322 adf::location<adf::parameter>(gemms[i].param[0]) = adf::location<adf::kernel>(gemms[i]);
323 adf::location<adf::parameter>(gemms[i].param[0]) = adf::offset(0);
// The concatenated result leaves the graph as a stream.
327 adf::connect<adf::stream> (concat_graph.pout[0], pout[0]);
// Place each concat kernel k1[i] at row_offset +1 relative to gemms[i*2], and
// constrain the stacks of both kernels to that concat kernel's location.
// NOTE(review): gemms[i*2] is only in range while i*2 < CHUNK_COUNT; this relies
// on concat_graph.k1.size(), which is defined outside this excerpt — confirm.
329 for (
int i = 0; i < concat_graph.k1.size(); i++) {
330 adf::location<adf::kernel>(concat_graph.k1[i]) =
331 adf::location<adf::kernel>(gemms[i*2]) + adf::relative_offset({.col_offset=0, .row_offset=1});
333 adf::location_constraint cTilePos = adf::location<adf::kernel>(concat_graph.k1[i]);
334 adf::location<adf::stack>(gemms[i*2]) = cTilePos;
335 adf::location<adf::stack>(concat_graph.k1[i]) = cTilePos;