onnx2versal
Loading...
Searching...
No Matches
graph_concat.h
1#ifndef __CONCAT_GRAPH_H__
2#define __CONCAT_GRAPH_H__
3
4#include <assert.h>
5#include <adf.h>
6#include "concat.h"
7
35template <template<typename, int, int, int, int> class CONCAT,
36 typename TT, int LCNT, int H, int INP_W, int OUT_W>
37class ConcatGraph : public adf::graph {
38
39 public:
40 adf::kernel k[1];
41 adf::port<input> pin[LCNT];
42 adf::port<output> pout[1];
43
44 ConcatGraph() {
45 static_assert(LCNT <= 8);
46 k[0] = adf::kernel::create_object<CONCAT<TT, LCNT, H, INP_W, OUT_W>>();
47 adf::source(k[0]) = "concat.cc";
48 adf::headers(k[0]) = {"concat.h"};
49 adf::runtime<ratio>(k[0]) = 0.6;
50
51 for (int i = 0; i < LCNT; i++) {
52 adf::connect<adf::window<H*INP_W*sizeof(TT)>> (pin[i], k[0].in[i]);
53 adf::single_buffer(k[0].in[i]);
54 }
55
56 // OUT_W <= H*INP_W
57 adf::connect<adf::stream> (k[0].out[0], pout[0]);
58 adf::samples_per_iteration(k[0].out[0]) = H*OUT_W;
59 }
60
61};
62
63
64template <template<typename, int, int, int, int> class CONCAT_STREAM,
65 typename TT, int LCNT, int H, int INP_W, int OUT_W>
66class ConcatStreamGraph : public adf::graph {
67
68 private:
69 static constexpr int L1_LCNT = LCNT / 2;
70
71 public:
72 adf::vector<adf::kernel> k1;
73 adf::vector<adf::kernel> k2;
74 adf::vector<adf::kernel> k;
75
76 adf::port<input> pin[LCNT];
77 adf::port<output> pout[1];
78
79 template<int mINP_W1, int mINP_W2, int mOUT_W>
80 adf::kernel create_concat_kernel() {
81 adf::kernel new_k = adf::kernel::create_object<CONCAT_STREAM<TT, H, mINP_W1, mINP_W2, mOUT_W>>();
82 adf::source(new_k) = "concat.cc";
83 adf::headers(new_k) = {"concat.h"};
84 adf::runtime<ratio>(new_k) = 0.6;
85 return new_k;
86 }
87
88 // separate tiles since each tile only 2 input streams
89 ConcatStreamGraph() {
90 static_assert(LCNT <= 8 && LCNT > 1);
91
92 adf::kernel _k;
93 for (int i = 0; i < LCNT-1; i+=2) {
94 _k = create_concat_kernel<INP_W, INP_W, 2*INP_W>();
95 k1.push_back(_k);
96 adf::connect<adf::stream> (pin[i], _k.in[0]);
97 adf::connect<adf::stream> (pin[i+1], _k.in[1]);
98 adf::samples_per_iteration(_k.in[0]) = H*INP_W;
99 adf::samples_per_iteration(_k.in[1]) = H*INP_W;
100 adf::location<adf::stack> (_k) = adf::location<adf::kernel>(_k);
101 }
102 if (LCNT > 3) { // kernel in loop will be included during compilation otherwise
103 for (int i = 0; i < LCNT-3; i+=4) {
104 _k = create_concat_kernel<2*INP_W, 2*INP_W, 4*INP_W>();
105 k2.push_back(_k);
106 adf::connect<adf::stream> (k1[i/2].out[0], _k.in[0]);
107 adf::connect<adf::stream> (k1[i/2+1].out[0], _k.in[1]);
108 adf::samples_per_iteration(_k.in[0]) = H*2*INP_W;
109 adf::samples_per_iteration(_k.in[1]) = H*2*INP_W;
110 adf::location<adf::stack> (_k) = adf::location<adf::kernel>(_k);
111 }
112 }
113
114
115 if (LCNT == 2) {
116 adf::connect<adf::stream> (k1[0].out[0], pout[0]);
117 adf::samples_per_iteration(k1[0].out[0]) = H*OUT_W;
118 }
119 else if (LCNT == 3) {
120 _k = create_concat_kernel<2*INP_W, INP_W, OUT_W>();
121 k.push_back(_k);
122 adf::connect<adf::stream> (k1[0].out[0], _k.in[0]);
123 adf::connect<adf::stream> (pin[LCNT-1], _k.in[1]);
124
125 adf::connect<adf::stream> (_k.out[0], pout[0]);
126 adf::samples_per_iteration(_k.out[0]) = H*OUT_W;
127 }
128 else if (LCNT == 4) {
129 adf::connect<adf::stream> (k2[0].out[0], pout[0]);
130 adf::samples_per_iteration(k2[0].out[0]) = H*OUT_W;
131 }
132 else if (LCNT == 5) {
133 _k = create_concat_kernel<4*INP_W, INP_W, OUT_W>();
134 k.push_back(_k);
135 adf::connect<adf::stream> (k2[0].out[0], _k.in[0]);
136 adf::connect<adf::stream> (pin[LCNT-1], _k.in[1]);
137
138 adf::connect<adf::stream> (_k.out[0], pout[0]);
139 adf::samples_per_iteration(_k.out[0]) = H*OUT_W;
140 }
141 else if (LCNT == 6) {
142 _k = create_concat_kernel<4*INP_W, 2*INP_W, OUT_W>();
143 k.push_back(_k);
144 adf::connect<adf::stream> (k2[0].out[0], _k.in[0]);
145 adf::connect<adf::stream> (k1[2].out[0], _k.in[1]);
146
147 adf::connect<adf::stream> (_k.out[0], pout[0]);
148 adf::samples_per_iteration(_k.out[0]) = H*OUT_W;
149 }
150 else if (LCNT == 7) {
151 adf::kernel _k1 = create_concat_kernel<2*INP_W, INP_W, 4*INP_W>();
152 k.push_back(_k1);
153 adf::connect<adf::stream> (k1[2].out[0], _k1.in[0]);
154 adf::connect<adf::stream> (pin[LCNT-1], _k1.in[1]);
155
156 adf::kernel _k2 = create_concat_kernel<4*INP_W, 4*INP_W, OUT_W>();
157 k.push_back(_k2);
158 adf::connect<adf::stream> (k2[0].out[0], _k2.in[0]);
159 adf::connect<adf::stream> (_k1.out[0], _k2.in[1]);
160
161 adf::connect<adf::stream> (_k2.out[0], pout[0]);
162 adf::samples_per_iteration(_k2.out[0]) = H*OUT_W;
163 }
164 else if (LCNT == 8) {
165 _k = create_concat_kernel<4*INP_W, 4*INP_W, OUT_W>();
166 k.push_back(_k);
167 adf::connect<adf::stream> (k2[0].out[0], _k.in[0]);
168 adf::connect<adf::stream> (k2[1].out[0], _k.in[1]);
169
170 adf::connect<adf::stream> (_k.out[0], pout[0]);
171 adf::samples_per_iteration(_k.out[0]) = H*OUT_W;
172 }
173
174 }
175
176};
177
178
179template <template<typename> class CONCAT_STREAM,
180 typename TT, int LCNT, int H, int INP_W, int OUT_W>
181class ConcatStreamSequentiallyGraph : public adf::graph {
182
183 public:
184 adf::kernel k[LCNT-1];
185 adf::port<input> pin[LCNT];
186 adf::port<output> pout[1];
187
188 adf::kernel create_concat_kernel(
189 int mINP_W1,
190 int mINP_W2,
191 int mOUT_W
192 ) {
193 adf::kernel new_k = adf::kernel::create_object<CONCAT_STREAM<TT>>(H, mINP_W1, mINP_W2, mOUT_W);
194 adf::source(new_k) = "concat.cc";
195 adf::headers(new_k) = {"concat.h"};
196 adf::runtime<ratio>(new_k) = 0.6;
197 return new_k;
198 }
199
200 // separate tiles since each tile only 2 input streams
201 ConcatStreamSequentiallyGraph() {
202 static_assert(LCNT >= 2);
203
204 adf::kernel _k;
205
206 k[0] = create_concat_kernel(INP_W, INP_W, 2*INP_W);
207 adf::connect<adf::stream> (pin[0], k[0].in[0]);
208 adf::connect<adf::stream> (pin[1], k[0].in[1]);
209
210 for (int i = 1; i < LCNT-1; i++) {
211 k[i] = create_concat_kernel((i+1)*INP_W, INP_W, (i+2)*INP_W);
212 adf::connect<adf::stream> (k[i-1].out[0], k[i].in[0]);
213 adf::connect<adf::stream> (pin[i+1], k[i].in[1]);
214 }
215
216 adf::connect<adf::stream> (k[LCNT-2].out[0], pout[0]);
217
218 }
219
220};
221
222
231template <template<typename, int, int, int, int> class CONCAT,
232 typename TT, int LCNT, int H, int INP_W, int OUT_W>
233class ConcatTwoStreamGraph : public adf::graph {
234
235 public:
236 adf::kernel k[1];
237 adf::port<input> pin[2];
238 adf::port<output> pout[1];
239
241 k[0] = adf::kernel::create_object<CONCAT<TT, LCNT, H, INP_W, OUT_W>>();
242 adf::source(k[0]) = "concat.cc";
243 adf::headers(k[0]) = {"concat.h"};
244 adf::runtime<ratio>(k[0]) = 0.6;
245
246 adf::connect<adf::stream> (pin[0], k[0].in[0]);
247 adf::connect<adf::stream> (pin[1], k[0].in[1]);
248 adf::connect<adf::stream> (k[0].out[0], pout[0]);
249
250 adf::samples_per_iteration(k[0].in[0]) = H*INP_W* ((LCNT+1)/2);
251 adf::samples_per_iteration(k[0].in[1]) = H*INP_W * (LCNT/2);
252 adf::samples_per_iteration(k[0].out[0]) = H*OUT_W;
253 }
254
255};
256
257
258template <template<typename, int, int, int, int> class CONCAT,
259 typename TT, int LCNT, int H, int INP_W, int OUT_W>
260class ConcatTwiceGraph : public adf::graph {
261
262 public:
263 static constexpr int CONCAT_CNT = (LCNT+7)/8;
264 static constexpr int LCNT_REM = (LCNT % 8 == 0) ? 8 : LCNT % 8;
265 static constexpr int INNER_OUT_W = INP_W * 8;
266 adf::kernel k[CONCAT_CNT];
267 adf::kernel klast;
268 adf::port<input> pin[LCNT];
269 adf::port<output> pout[1];
270
271 ConcatTwiceGraph() {
272 static_assert(LCNT <= 64);
273 klast = adf::kernel::create_object<CONCAT<TT, CONCAT_CNT, H, INNER_OUT_W, OUT_W>>();
274 adf::source(klast) = "concat.cc";
275 adf::headers(klast) = {"concat.h"};
276 adf::runtime<ratio>(klast) = 0.6;
277
278 // intermediate concats
279 if (CONCAT_CNT > 1) { // register kernel error if CONCAT_CNT=1
280 for (int i = 0; i < CONCAT_CNT - 1; i++) {
281 k[i] = adf::kernel::create_object<CONCAT<TT, 8, H, INP_W, INNER_OUT_W>>();
282 adf::source(k[i]) = "concat.cc";
283 adf::headers(k[i]) = {"concat.h"};
284 adf::runtime<ratio>(k[i]) = 0.6;
285
286 for (int j = 0; j < 8; j++)
287 adf::connect<adf::window<H*INP_W*sizeof(TT)>> (pin[i*8+j], k[i].in[j]);
288 adf::connect<adf::stream> (k[i].out[0], klast.in[i]);
289 }
290 }
291
292 int i = CONCAT_CNT - 1;
293 // remainder intermediate concat
294 k[i] = adf::kernel::create_object<CONCAT<TT, LCNT_REM, H, INP_W, INNER_OUT_W>>();
295 adf::source(k[i]) = "concat.cc";
296 adf::headers(k[i]) = {"concat.h"};
297 adf::runtime<ratio>(k[i]) = 0.6;
298
299 for (int j = 0; j < LCNT_REM; j++)
300 adf::connect<adf::window<H*INP_W*sizeof(TT)>> (pin[i*8+j], k[i].in[j]);
301 adf::connect<adf::stream> (k[i].out[0], klast.in[i]);
302
303 adf::connect<adf::stream> (klast.out[0], pout[0]);
304 }
305
306};
310#endif // __CONCAT_GRAPH_H__
Graph wrapper for arbitrary concat kernel implementation and lanes.
Definition graph_concat.h:37
Graph wrapper for concatenating two chunked streams, inverse of SplitTwoStreamGraph.
Definition graph_concat.h:233