1#ifndef __CONCAT_GRAPH_H__
2#define __CONCAT_GRAPH_H__
// Template header and body fragments of a graph that wraps one CONCAT kernel.
// CONCAT is a template-template parameter; it is instantiated below as
// CONCAT<TT, LCNT, H, INP_W, OUT_W>.
// NOTE(review): the class declaration line, the declaration of `k`, the
// constructor header and the closing braces are absent from this listing, and
// the leading numerals (35, 36, 41, ...) look like original line numbers fused
// into the text by the extraction.
35template <
template<
typename,
int,
int,
int,
int>
class CONCAT,
36 typename TT,
int LCNT,
int H,
int INP_W,
int OUT_W>
// Graph ports: one input port per lane (LCNT of them) and a single output port.
41 adf::port<input> pin[LCNT];
42 adf::port<output> pout[1];
// Compile-time limit: at most 8 lanes.
45 static_assert(LCNT <= 8);
// Create the kernel object, then attach its source file, header list and
// runtime ratio (0.6).
46 k[0] = adf::kernel::create_object<CONCAT<TT, LCNT, H, INP_W, OUT_W>>();
47 adf::source(k[0]) =
"concat.cc";
48 adf::headers(k[0]) = {
"concat.h"};
49 adf::runtime<ratio>(k[0]) = 0.6;
// Connect pin[i] to kernel input i through a window sized
// H*INP_W*sizeof(TT), and apply adf::single_buffer to that kernel input.
51 for (
int i = 0; i < LCNT; i++) {
52 adf::connect<adf::window<H*INP_W*
sizeof(TT)>> (pin[i], k[0].in[i]);
53 adf::single_buffer(k[0].in[i]);
// The kernel output leaves the graph as a stream; H*OUT_W samples are declared
// per iteration on that output.
57 adf::connect<adf::stream> (k[0].out[0], pout[0]);
58 adf::samples_per_iteration(k[0].out[0]) = H*OUT_W;
// ConcatStreamGraph: concatenates LCNT input streams using a tree of two-input
// CONCAT_STREAM kernels. CONCAT_STREAM is instantiated as
// CONCAT_STREAM<TT, H, mINP_W1, mINP_W2, mOUT_W>.
// NOTE(review): this listing omits several original lines (access specifiers,
// the declaration of `_k`, any push_back into k1/k2/k, the helper's return
// statement, the constructor header, the first `if` condition and all closing
// braces); the leading numerals are original line numbers fused into the text.
64template <
template<
typename,
int,
int,
int,
int>
class CONCAT_STREAM,
65 typename TT,
int LCNT,
int H,
int INP_W,
int OUT_W>
66class ConcatStreamGraph :
public adf::graph {
// Half the lane count (integer division).
69 static constexpr int L1_LCNT = LCNT / 2;
// Kernel containers. k1 and k2 are indexed below as the first-level and
// second-level kernels; how they are filled is not visible in this listing.
72 adf::vector<adf::kernel> k1;
73 adf::vector<adf::kernel> k2;
74 adf::vector<adf::kernel> k;
// Graph ports: LCNT inputs, one output.
76 adf::port<input> pin[LCNT];
77 adf::port<output> pout[1];
// Helper: creates one two-input concat kernel for the given input/output
// widths and attaches source file, header list and runtime ratio (0.6).
79 template<
int mINP_W1,
int mINP_W2,
int mOUT_W>
80 adf::kernel create_concat_kernel() {
81 adf::kernel new_k = adf::kernel::create_object<CONCAT_STREAM<TT, H, mINP_W1, mINP_W2, mOUT_W>>();
82 adf::source(new_k) =
"concat.cc";
83 adf::headers(new_k) = {
"concat.h"};
84 adf::runtime<ratio>(new_k) = 0.6;
// Supported lane counts are 2..8.
90 static_assert(LCNT <= 8 && LCNT > 1);
// Level 1: pair up adjacent inputs (pin[i], pin[i+1]) into a kernel producing
// width 2*INP_W. An odd trailing input is left for the per-LCNT branches below.
93 for (
int i = 0; i < LCNT-1; i+=2) {
94 _k = create_concat_kernel<INP_W, INP_W, 2*INP_W>();
96 adf::connect<adf::stream> (pin[i], _k.in[0]);
97 adf::connect<adf::stream> (pin[i+1], _k.in[1]);
98 adf::samples_per_iteration(_k.in[0]) = H*INP_W;
99 adf::samples_per_iteration(_k.in[1]) = H*INP_W;
// Constrain the kernel's stack location to the kernel's own location.
100 adf::location<adf::stack> (_k) = adf::location<adf::kernel>(_k);
// Level 2: pair up level-1 outputs (k1[i/2], k1[i/2+1]) into a kernel
// producing width 4*INP_W; runs once per complete group of four inputs.
103 for (
int i = 0; i < LCNT-3; i+=4) {
104 _k = create_concat_kernel<2*INP_W, 2*INP_W, 4*INP_W>();
106 adf::connect<adf::stream> (k1[i/2].out[0], _k.in[0]);
107 adf::connect<adf::stream> (k1[i/2+1].out[0], _k.in[1]);
108 adf::samples_per_iteration(_k.in[0]) = H*2*INP_W;
109 adf::samples_per_iteration(_k.in[1]) = H*2*INP_W;
110 adf::location<adf::stack> (_k) = adf::location<adf::kernel>(_k);
// Final stage, selected by LCNT. The condition of this first branch is not in
// the listing (presumably LCNT == 2 — TODO confirm): the single level-1 kernel
// drives the output directly.
116 adf::connect<adf::stream> (k1[0].out[0], pout[0]);
117 adf::samples_per_iteration(k1[0].out[0]) = H*OUT_W;
// LCNT == 3: level-1 pair (2*INP_W) + last input (INP_W) -> OUT_W.
119 else if (LCNT == 3) {
120 _k = create_concat_kernel<2*INP_W, INP_W, OUT_W>();
122 adf::connect<adf::stream> (k1[0].out[0], _k.in[0]);
123 adf::connect<adf::stream> (pin[LCNT-1], _k.in[1]);
125 adf::connect<adf::stream> (_k.out[0], pout[0]);
126 adf::samples_per_iteration(_k.out[0]) = H*OUT_W;
// LCNT == 4: the single level-2 kernel drives the output directly.
128 else if (LCNT == 4) {
129 adf::connect<adf::stream> (k2[0].out[0], pout[0]);
130 adf::samples_per_iteration(k2[0].out[0]) = H*OUT_W;
// LCNT == 5: level-2 group (4*INP_W) + last input (INP_W) -> OUT_W.
132 else if (LCNT == 5) {
133 _k = create_concat_kernel<4*INP_W, INP_W, OUT_W>();
135 adf::connect<adf::stream> (k2[0].out[0], _k.in[0]);
136 adf::connect<adf::stream> (pin[LCNT-1], _k.in[1]);
138 adf::connect<adf::stream> (_k.out[0], pout[0]);
139 adf::samples_per_iteration(_k.out[0]) = H*OUT_W;
// LCNT == 6: level-2 group (4*INP_W) + third level-1 pair (2*INP_W) -> OUT_W.
141 else if (LCNT == 6) {
142 _k = create_concat_kernel<4*INP_W, 2*INP_W, OUT_W>();
144 adf::connect<adf::stream> (k2[0].out[0], _k.in[0]);
145 adf::connect<adf::stream> (k1[2].out[0], _k.in[1]);
147 adf::connect<adf::stream> (_k.out[0], pout[0]);
148 adf::samples_per_iteration(_k.out[0]) = H*OUT_W;
// LCNT == 7: third level-1 pair + last input are merged by _k1, then the
// level-2 group and _k1 are merged by _k2.
// NOTE(review): _k1 joins a 2*INP_W stream with an INP_W stream but is
// declared with output width 4*INP_W, and _k2 is given 4*INP_W as its second
// input width; the arithmetic sum would be 3*INP_W. Confirm whether
// CONCAT_STREAM intentionally pads here.
150 else if (LCNT == 7) {
151 adf::kernel _k1 = create_concat_kernel<2*INP_W, INP_W, 4*INP_W>();
153 adf::connect<adf::stream> (k1[2].out[0], _k1.in[0]);
154 adf::connect<adf::stream> (pin[LCNT-1], _k1.in[1]);
156 adf::kernel _k2 = create_concat_kernel<4*INP_W, 4*INP_W, OUT_W>();
158 adf::connect<adf::stream> (k2[0].out[0], _k2.in[0]);
159 adf::connect<adf::stream> (_k1.out[0], _k2.in[1]);
161 adf::connect<adf::stream> (_k2.out[0], pout[0]);
162 adf::samples_per_iteration(_k2.out[0]) = H*OUT_W;
// LCNT == 8: the two level-2 groups (4*INP_W each) -> OUT_W.
164 else if (LCNT == 8) {
165 _k = create_concat_kernel<4*INP_W, 4*INP_W, OUT_W>();
167 adf::connect<adf::stream> (k2[0].out[0], _k.in[0]);
168 adf::connect<adf::stream> (k2[1].out[0], _k.in[1]);
170 adf::connect<adf::stream> (_k.out[0], pout[0]);
171 adf::samples_per_iteration(_k.out[0]) = H*OUT_W;
// ConcatStreamSequentiallyGraph: concatenates LCNT input streams with a linear
// chain of LCNT-1 two-input kernels. Unlike the compile-time-width variant,
// CONCAT_STREAM here takes only the element type as a template argument and
// receives H and the widths as constructor arguments.
// NOTE(review): the helper's parameter list and return statement, the access
// specifiers and all closing braces are absent from this listing; the leading
// numerals are original line numbers fused into the text. OUT_W is not used
// in any line visible here.
179template <
template<
typename>
class CONCAT_STREAM,
180 typename TT,
int LCNT,
int H,
int INP_W,
int OUT_W>
181class ConcatStreamSequentiallyGraph :
public adf::graph {
// One kernel per chain stage, LCNT input ports, one output port.
184 adf::kernel k[LCNT-1];
185 adf::port<input> pin[LCNT];
186 adf::port<output> pout[1];
// Helper: creates one concat kernel for the given runtime widths and attaches
// source file, header list and runtime ratio (0.6).
188 adf::kernel create_concat_kernel(
193 adf::kernel new_k = adf::kernel::create_object<CONCAT_STREAM<TT>>(H, mINP_W1, mINP_W2, mOUT_W);
194 adf::source(new_k) =
"concat.cc";
195 adf::headers(new_k) = {
"concat.h"};
196 adf::runtime<ratio>(new_k) = 0.6;
201 ConcatStreamSequentiallyGraph() {
// At least two lanes are required, so k[] has at least one element.
202 static_assert(LCNT >= 2);
// Stage 0 joins the first two inputs into a stream of width 2*INP_W.
206 k[0] = create_concat_kernel(INP_W, INP_W, 2*INP_W);
207 adf::connect<adf::stream> (pin[0], k[0].in[0]);
208 adf::connect<adf::stream> (pin[1], k[0].in[1]);
// Stage i appends input i+1 (width INP_W) to the running result of stage i-1
// (width (i+1)*INP_W), producing width (i+2)*INP_W.
210 for (
int i = 1; i < LCNT-1; i++) {
211 k[i] = create_concat_kernel((i+1)*INP_W, INP_W, (i+2)*INP_W);
212 adf::connect<adf::stream> (k[i-1].out[0], k[i].in[0]);
213 adf::connect<adf::stream> (pin[i+1], k[i].in[1]);
// The last stage drives the graph output.
216 adf::connect<adf::stream> (k[LCNT-2].out[0], pout[0]);
// Template header and body fragments of a graph that feeds one CONCAT kernel
// from exactly two input streams and emits one output stream. CONCAT is
// instantiated as CONCAT<TT, LCNT, H, INP_W, OUT_W>.
// NOTE(review): the class declaration line, the declaration of `k`, the
// constructor header and the closing braces are absent from this listing; the
// leading numerals are original line numbers fused into the text.
231template <
template<
typename,
int,
int,
int,
int>
class CONCAT,
232 typename TT,
int LCNT,
int H,
int INP_W,
int OUT_W>
// Two input ports regardless of LCNT, and a single output port.
237 adf::port<input> pin[2];
238 adf::port<output> pout[1];
// Create the kernel object, then attach its source file, header list and
// runtime ratio (0.6).
241 k[0] = adf::kernel::create_object<CONCAT<TT, LCNT, H, INP_W, OUT_W>>();
242 adf::source(k[0]) =
"concat.cc";
243 adf::headers(k[0]) = {
"concat.h"};
244 adf::runtime<ratio>(k[0]) = 0.6;
// All three connections are streams.
246 adf::connect<adf::stream> (pin[0], k[0].in[0]);
247 adf::connect<adf::stream> (pin[1], k[0].in[1]);
248 adf::connect<adf::stream> (k[0].out[0], pout[0]);
// Per-iteration sample counts: the first input is sized for (LCNT+1)/2 lanes
// (rounded up), the second for LCNT/2 lanes (rounded down), each lane being
// H*INP_W samples; the output is H*OUT_W samples.
250 adf::samples_per_iteration(k[0].in[0]) = H*INP_W* ((LCNT+1)/2);
251 adf::samples_per_iteration(k[0].in[1]) = H*INP_W * (LCNT/2);
252 adf::samples_per_iteration(k[0].out[0]) = H*OUT_W;
// ConcatTwiceGraph: two-level concat for up to 64 lanes. Inputs are split into
// groups of at most 8 lanes, each group is concatenated by its own CONCAT
// kernel, and a final kernel (klast) concatenates the group outputs.
// NOTE(review): the declaration of `klast`, the access specifiers, the
// constructor header and all closing braces are absent from this listing; the
// leading numerals are original line numbers fused into the text.
258template <
template<
typename,
int,
int,
int,
int>
class CONCAT,
259 typename TT,
int LCNT,
int H,
int INP_W,
int OUT_W>
260class ConcatTwiceGraph :
public adf::graph {
// Number of first-level groups: LCNT / 8 rounded up.
263 static constexpr int CONCAT_CNT = (LCNT+7)/8;
// Lane count of the last group: 8 when LCNT is a multiple of 8, otherwise the
// remainder.
264 static constexpr int LCNT_REM = (LCNT % 8 == 0) ? 8 : LCNT % 8;
// Output width used for every first-level kernel (eight lanes of INP_W).
265 static constexpr int INNER_OUT_W = INP_W * 8;
// One first-level kernel per group.
266 adf::kernel k[CONCAT_CNT];
// Graph ports: LCNT inputs, one output.
268 adf::port<input> pin[LCNT];
269 adf::port<output> pout[1];
// With at most 8 lanes per group and CONCAT_CNT fed into one kernel, this
// bounds CONCAT_CNT to 8.
272 static_assert(LCNT <= 64);
// Second-level kernel: CONCAT_CNT lanes of width INNER_OUT_W -> OUT_W.
273 klast = adf::kernel::create_object<CONCAT<TT, CONCAT_CNT, H, INNER_OUT_W, OUT_W>>();
274 adf::source(klast) =
"concat.cc";
275 adf::headers(klast) = {
"concat.h"};
276 adf::runtime<ratio>(klast) = 0.6;
// Full groups: every group except the last has exactly 8 lanes. The loop
// bound alone already yields zero iterations when CONCAT_CNT == 1.
279 if (CONCAT_CNT > 1) {
280 for (
int i = 0; i < CONCAT_CNT - 1; i++) {
281 k[i] = adf::kernel::create_object<CONCAT<TT, 8, H, INP_W, INNER_OUT_W>>();
282 adf::source(k[i]) =
"concat.cc";
283 adf::headers(k[i]) = {
"concat.h"};
284 adf::runtime<ratio>(k[i]) = 0.6;
// Inputs i*8 .. i*8+7 enter group i through windows of H*INP_W*sizeof(TT);
// the group output is streamed into lane i of klast.
286 for (
int j = 0; j < 8; j++)
287 adf::connect<adf::window<H*INP_W*
sizeof(TT)>> (pin[i*8+j], k[i].in[j]);
288 adf::connect<adf::stream> (k[i].out[0], klast.in[i]);
// Last group: LCNT_REM lanes.
// NOTE(review): its output width is still INNER_OUT_W (8*INP_W) even when
// LCNT_REM < 8 — presumably the kernel pads; confirm against CONCAT.
292 int i = CONCAT_CNT - 1;
294 k[i] = adf::kernel::create_object<CONCAT<TT, LCNT_REM, H, INP_W, INNER_OUT_W>>();
295 adf::source(k[i]) =
"concat.cc";
296 adf::headers(k[i]) = {
"concat.h"};
297 adf::runtime<ratio>(k[i]) = 0.6;
299 for (
int j = 0; j < LCNT_REM; j++)
300 adf::connect<adf::window<H*INP_W*
sizeof(TT)>> (pin[i*8+j], k[i].in[j]);
301 adf::connect<adf::stream> (k[i].out[0], klast.in[i]);
// The second-level kernel drives the graph output.
303 adf::connect<adf::stream> (klast.out[0], pout[0]);
Graph wrapper for arbitrary concat kernel implementation and lanes.
Definition graph_concat.h:37
Graph wrapper for concatenating two chunked streams, inverse of SplitTwoStreamGraph.
Definition graph_concat.h:233