onnx2versal
Loading...
Searching...
No Matches
graph_qgemm.h
1#ifndef __QGEMM_GRAPH_H_
2#define __QGEMM_GRAPH_H_
3
4#include <adf.h>
5#include "qgemm.h"
6#include "graph_concat.h"
7#include "graph_utils.h"
8
9
36template <template<typename, typename, int, int, int> class QGEMM,
37 typename TT, typename TTPARAM, int M, int K, int N>
38class QgemmGraph : public adf::graph {
39
40 private:
41 adf::kernel k[1];
42
43 public:
44 adf::port<input> pin[1];
45 adf::port<output> pout[1];
46
48 std::vector<TTPARAM> weights,
49 std::vector<int32_t> bias,
50 float x_scale,
51 float w_scale,
52 float y_scale,
53 TT x_zero,
54 TTPARAM w_zero,
55 TT y_zero
56 ) {
57 k[0] = adf::kernel::create_object<QGEMM<TT, TTPARAM, M, K, N>>(
58 weights, bias, x_scale, w_scale, y_scale, x_zero, w_zero, y_zero);
59 adf::source(k[0]) = "qgemm.cc";
60 adf::headers(k[0]) = {"qgemm.h"};
61 adf::runtime<ratio>(k[0]) = 0.6;
62 adf::heap_size(k[0]) = K + 1024;
63
64 adf::location_constraint tilePos = adf::location<adf::kernel>(k[0]);
65 adf::location<adf::parameter>(k[0].param[0]) = tilePos;
66 adf::location<adf::parameter>(k[0].param[0]) = adf::offset(0);
67
68 adf::connect<adf::stream> (pin[0], k[0].in[0]);
69 adf::connect<adf::stream> (k[0].out[0], pout[0]);
70
71 adf::samples_per_iteration(k[0].in[0]) = M*K;
72 adf::samples_per_iteration(k[0].out[0]) = M*N;
73 }
74
75};
76
77
78
88template <template<typename, typename, int, int, int> class QGEMM,
89 typename TT, typename TTPARAM, int M, int K, int N>
90class QgemmStreamGraph : public adf::graph {
91
92 private:
93 adf::kernel k[1];
94
95 public:
96 adf::port<input> pin[2];
97 adf::port<output> pout[1];
98
100 std::vector<int32_t> bias,
101 float x_scale,
102 float w_scale,
103 float y_scale,
104 TT x_zero,
105 TTPARAM w_zero,
106 TT y_zero
107 ) {
108 k[0] = adf::kernel::create_object<QGEMM<TT, TTPARAM, M, K, N>>(
109 bias, x_scale, w_scale, y_scale, x_zero, w_zero, y_zero);
110 adf::source(k[0]) = "qgemm.cc";
111 adf::headers(k[0]) = {"qgemm.h"};
112 adf::runtime<ratio>(k[0]) = 0.6;
113 adf::heap_size(k[0]) = 24576; // assume KxN > MxN
114
115 adf::connect<adf::stream> (pin[0], k[0].in[0]);
116 adf::connect<adf::stream> (pin[1], k[0].in[1]);
117 adf::connect<adf::stream> (k[0].out[0], pout[0]);
118 adf::samples_per_iteration(k[0].in[0]) = M*K;
119 adf::samples_per_iteration(k[0].out[0]) = M*N;
120
121 adf::location_constraint tilePos = adf::location<adf::kernel>(k[0]);
122 adf::location<adf::parameter>(k[0].param[0]) = tilePos;
123 adf::location<adf::parameter>(k[0].param[0]) = adf::offset(0);
124 }
125
126};
127
128
136template <
137 template<typename, typename, int, int, int> class QGEMM,
138 template<typename, int, int, int, int> class CONCAT,
139 int NCHUNK,
140 typename TT, typename TTPARAM, int M, int K, int N>
141class QgemmChunkNGraph : public adf::graph {
142
143 private:
144
145 public:
146 static const int CHUNK_COUNT = (N + NCHUNK - 1) / NCHUNK; // ceiling
147 adf::kernel k[CHUNK_COUNT];
148 ConcatStreamGraph<CONCAT, TT, CHUNK_COUNT, M, NCHUNK, N> concat_g;
149
150 adf::port<input> pin[1];
151 adf::port<output> pout[1];
152
154 std::vector<TTPARAM> weights, // KxN
155 std::vector<int32_t> bias, // N
156 float x_scale,
157 float w_scale,
158 float y_scale,
159 TT x_zero,
160 TTPARAM w_zero,
161 TT y_zero
162 ) {
163 static_assert(CHUNK_COUNT <= 8);
164 static_assert(M*K <= TILE_BYTES);
165 static_assert(K*NCHUNK <= MAX_PARAM_BYTES);
166 static_assert(M*NCHUNK <= TILE_BYTES);
167
168 std::vector<int32_t> bChunk;
169
170 for (int i = 0; i < CHUNK_COUNT; i++) {
171
172 // build wchunk
173 std::vector<TTPARAM> wChunk;
174 wChunk.reserve(NCHUNK*K);
175 for (int j = 0; j < K*N; j+=N) {
176 wChunk.insert(wChunk.end(), weights.begin()+j+i*NCHUNK, weights.begin()+j+i*NCHUNK+NCHUNK);
177 }
178
179 // build bChunk
180 bChunk = std::vector<int32_t>(bias.begin()+i*NCHUNK, bias.begin()+i*NCHUNK+NCHUNK);
181 bChunk.resize(NCHUNK, 0);
182
183 k[i] = adf::kernel::create_object<QGEMM<TT, TTPARAM, M, K, NCHUNK>>(
184 wChunk, bChunk, x_scale, w_scale, y_scale, x_zero, w_zero, y_zero);
185 adf::source(k[i]) = "qgemm.cc";
186 adf::headers(k[i]) = {"qgemm.h"};
187 adf::runtime<ratio>(k[i]) = 0.6;
188 adf::heap_size(k[i]) = K + 1024;
189
190 if ((i&0x1) == 1) {
191 adf::location<adf::kernel>(k[i]) = adf::location<adf::kernel>(k[i-1]) + adf::relative_offset({.col_offset=0, .row_offset=1});
192 }
193 if (i == 2 || i == 6) {
194 adf::location<adf::kernel>(k[i]) = adf::location<adf::kernel>(k[i-1]) + adf::relative_offset({.col_offset=0, .row_offset=2});
195 }
196
197 adf::location_constraint tilePos = adf::location<adf::kernel>(k[i]);
198 adf::location<adf::parameter>(k[i].param[0]) = tilePos;
199 adf::location<adf::parameter>(k[i].param[0]) = adf::offset(0);
200 adf::location<adf::parameter>(k[i].param[1]) = tilePos;
201 adf::location<adf::parameter>(k[i].param[1]) = adf::offset((K*NCHUNK+31)/32*32);
202 // arbitrary input/output buffer location due to interconnect design
203 }
204
205 for (int i = 0; i < CHUNK_COUNT; i++) {
206 adf::connect<adf::stream> (pin[0], k[i].in[0]);
207 adf::connect<adf::stream> (k[i].out[0], concat_g.pin[i]);
208 adf::samples_per_iteration(k[i].out[0]) = M*NCHUNK;
209 }
210 adf::connect<adf::stream> (concat_g.pout[0], pout[0]);
211
212 for (int i = 0; i < concat_g.k1.size(); i++) {
213 adf::location<adf::kernel>(concat_g.k1[i]) =
214 adf::location<adf::kernel>(k[i*2+1]) + adf::relative_offset({.col_offset=0, .row_offset=1});
215 }
216 }
217
218};
222#endif // __QGEMM_GRAPH_H_
Multi-instance graph for MxK times KxN that stores weights and biases. Requires KxN_RND weight,...
Definition graph_qgemm.h:141
Single instance graph that stores weights and biases.
Definition graph_qgemm.h:38
Single instance graph that stores biases and takes a second input stream (presumably the weights).
Definition graph_qgemm.h:90