44 adf::port<input> pin[1];
45 adf::port<output> pout[1];
48 std::vector<TTPARAM> weights,
49 std::vector<int32_t> bias,
57 k[0] = adf::kernel::create_object<QGEMM<TT, TTPARAM, M, K, N>>(
58 weights, bias, x_scale, w_scale, y_scale, x_zero, w_zero, y_zero);
59 adf::source(k[0]) =
"qgemm.cc";
60 adf::headers(k[0]) = {
"qgemm.h"};
61 adf::runtime<ratio>(k[0]) = 0.6;
62 adf::heap_size(k[0]) = K + 1024;
64 adf::location_constraint tilePos = adf::location<adf::kernel>(k[0]);
65 adf::location<adf::parameter>(k[0].param[0]) = tilePos;
66 adf::location<adf::parameter>(k[0].param[0]) = adf::offset(0);
68 adf::connect<adf::stream> (pin[0], k[0].in[0]);
69 adf::connect<adf::stream> (k[0].out[0], pout[0]);
71 adf::samples_per_iteration(k[0].in[0]) = M*K;
72 adf::samples_per_iteration(k[0].out[0]) = M*N;
96 adf::port<input> pin[2];
97 adf::port<output> pout[1];
100 std::vector<int32_t> bias,
108 k[0] = adf::kernel::create_object<QGEMM<TT, TTPARAM, M, K, N>>(
109 bias, x_scale, w_scale, y_scale, x_zero, w_zero, y_zero);
110 adf::source(k[0]) =
"qgemm.cc";
111 adf::headers(k[0]) = {
"qgemm.h"};
112 adf::runtime<ratio>(k[0]) = 0.6;
113 adf::heap_size(k[0]) = 24576;
115 adf::connect<adf::stream> (pin[0], k[0].in[0]);
116 adf::connect<adf::stream> (pin[1], k[0].in[1]);
117 adf::connect<adf::stream> (k[0].out[0], pout[0]);
118 adf::samples_per_iteration(k[0].in[0]) = M*K;
119 adf::samples_per_iteration(k[0].out[0]) = M*N;
121 adf::location_constraint tilePos = adf::location<adf::kernel>(k[0]);
122 adf::location<adf::parameter>(k[0].param[0]) = tilePos;
123 adf::location<adf::parameter>(k[0].param[0]) = adf::offset(0);
146 static const int CHUNK_COUNT = (N + NCHUNK - 1) / NCHUNK;
147 adf::kernel k[CHUNK_COUNT];
148 ConcatStreamGraph<CONCAT, TT, CHUNK_COUNT, M, NCHUNK, N> concat_g;
150 adf::port<input> pin[1];
151 adf::port<output> pout[1];
154 std::vector<TTPARAM> weights,
155 std::vector<int32_t> bias,
163 static_assert(CHUNK_COUNT <= 8);
164 static_assert(M*K <= TILE_BYTES);
165 static_assert(K*NCHUNK <= MAX_PARAM_BYTES);
166 static_assert(M*NCHUNK <= TILE_BYTES);
168 std::vector<int32_t> bChunk;
170 for (
int i = 0; i < CHUNK_COUNT; i++) {
173 std::vector<TTPARAM> wChunk;
174 wChunk.reserve(NCHUNK*K);
175 for (
int j = 0; j < K*N; j+=N) {
176 wChunk.insert(wChunk.end(), weights.begin()+j+i*NCHUNK, weights.begin()+j+i*NCHUNK+NCHUNK);
180 bChunk = std::vector<int32_t>(bias.begin()+i*NCHUNK, bias.begin()+i*NCHUNK+NCHUNK);
181 bChunk.resize(NCHUNK, 0);
183 k[i] = adf::kernel::create_object<QGEMM<TT, TTPARAM, M, K, NCHUNK>>(
184 wChunk, bChunk, x_scale, w_scale, y_scale, x_zero, w_zero, y_zero);
185 adf::source(k[i]) =
"qgemm.cc";
186 adf::headers(k[i]) = {
"qgemm.h"};
187 adf::runtime<ratio>(k[i]) = 0.6;
188 adf::heap_size(k[i]) = K + 1024;
191 adf::location<adf::kernel>(k[i]) = adf::location<adf::kernel>(k[i-1]) + adf::relative_offset({.col_offset=0, .row_offset=1});
193 if (i == 2 || i == 6) {
194 adf::location<adf::kernel>(k[i]) = adf::location<adf::kernel>(k[i-1]) + adf::relative_offset({.col_offset=0, .row_offset=2});
197 adf::location_constraint tilePos = adf::location<adf::kernel>(k[i]);
198 adf::location<adf::parameter>(k[i].param[0]) = tilePos;
199 adf::location<adf::parameter>(k[i].param[0]) = adf::offset(0);
200 adf::location<adf::parameter>(k[i].param[1]) = tilePos;
201 adf::location<adf::parameter>(k[i].param[1]) = adf::offset((K*NCHUNK+31)/32*32);
205 for (
int i = 0; i < CHUNK_COUNT; i++) {
206 adf::connect<adf::stream> (pin[0], k[i].in[0]);
207 adf::connect<adf::stream> (k[i].out[0], concat_g.pin[i]);
208 adf::samples_per_iteration(k[i].out[0]) = M*NCHUNK;
210 adf::connect<adf::stream> (concat_g.pout[0], pout[0]);
212 for (
int i = 0; i < concat_g.k1.size(); i++) {
213 adf::location<adf::kernel>(concat_g.k1[i]) =
214 adf::location<adf::kernel>(k[i*2+1]) + adf::relative_offset({.col_offset=0, .row_offset=1});