38template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
39 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
43 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
44 static constexpr int C_PER_M = C / GROUP;
45 alignas(32)
float (&weights)[M*KH*KW*C];
46 alignas(32)
float (&bias)[M];
50 float (&w)[M*KH*KW*C],
52 ): weights(w), bias(b) {};
55 input_window<float>* in,
56 output_window<float>* out
59 static void registerKernelClass() {
60 static_assert(C % GROUP == 0);
61 REGISTER_FUNCTION(ConvReluScalar::filter);
62 REGISTER_PARAMETER(weights);
63 REGISTER_PARAMETER(bias);
75template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
76 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
80 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
81 alignas(32)
float (&weights)[M*C*KH*8];
82 alignas(32)
float (&bias)[M];
88 ): weights(w), bias(b) {};
91 input_window<float>* in,
92 output_window<float>* out
95 static void registerKernelClass() {
96 static_assert(GROUP == 1);
99 static_assert(INP_W%4==0);
100 static_assert(OUT_W_PAD%8==0);
101 static_assert(STEP_H == 1 && STEP_W == 1);
102 REGISTER_FUNCTION(Conv5x5on8Relu::filter);
103 REGISTER_PARAMETER(weights);
104 REGISTER_PARAMETER(bias);
116template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
117 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
121 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
122 static constexpr int CKK_ROW_SIZE = C*((KH*KW+3)/4*4);
124 alignas(32)
float (&weights)[M*CKK_ROW_SIZE];
125 alignas(32)
float (&bias)[M];
129 float (&w)[M*CKK_ROW_SIZE],
131 ): weights(w), bias(b) {};
134 input_window<float>* in,
135 output_window<float>* out
138 static void registerKernelClass() {
139 static_assert(GROUP == 1);
140 static_assert(KW<=4);
141 static_assert(INP_W%4==0);
142 static_assert(OUT_W_PAD%8==0);
143 static_assert(STEP_H == 1 && STEP_W == 1);
144 REGISTER_FUNCTION(ConvHx4Relu::filter);
145 REGISTER_PARAMETER(weights);
146 REGISTER_PARAMETER(bias);
157template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
158 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
162 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
164 alignas(32)
float (&weights)[M*C];
165 alignas(32)
float (&bias)[M];
171 ): weights(w), bias(b) {};
174 input_window<float>* in,
175 output_window<float>* out
178 static void registerKernelClass() {
179 static_assert(GROUP == 1);
180 static_assert(KH==1);
181 static_assert(KW==1);
182 static_assert(INP_W%4==0);
183 static_assert(OUT_W_PAD%8==0 && STEP_W==1 || OUT_W_PAD%4==0 && STEP_W==2);
184 static_assert(STEP_H == 1 || STEP_H == 2);
185 static_assert(STEP_W == 1 || STEP_W == 2);
186 REGISTER_FUNCTION(Conv1x1Relu::filter);
187 REGISTER_PARAMETER(weights);
188 REGISTER_PARAMETER(bias);
200template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
201 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
205 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
206 static constexpr int C_PER_M = C / GROUP;
207 static constexpr int CKK_ROW_SIZE = C_PER_M*KH*KW;
208 alignas(32)
float (&bias)[M];
209 alignas(32)
float ckk_row[CKK_ROW_SIZE];
217 input_window<float>* in,
218 input_stream<float>* weights,
219 output_stream<float>* out
222 static void registerKernelClass() {
223 static_assert(C % GROUP == 0);
224 REGISTER_FUNCTION(ConvReluScalarStream::filter);
225 REGISTER_PARAMETER(bias);
236template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
237 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
241 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
242 static constexpr int C_PER_M = C / GROUP;
243 static constexpr int CKK_ROW_SIZE = C_PER_M*KH*8;
244 alignas(32)
float (&bias)[M];
245 alignas(32)
float ckk_row[CKK_ROW_SIZE];
246 alignas(32)
float width_row[OUT_W_PAD];
254 input_window<float>* in,
255 input_stream<float>* weights,
256 output_stream<float>* out
259 static void registerKernelClass() {
260 static_assert(C % GROUP == 0);
261 static_assert(KW<=8);
262 static_assert(INP_W%4==0);
263 static_assert(OUT_W_PAD%8==0);
264 static_assert(STEP_H == 1 && STEP_W == 1);
265 REGISTER_FUNCTION(ConvHx8ReluStream::filter);
266 REGISTER_PARAMETER(bias);
279template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
280 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
284 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
285 static constexpr int C_PER_M = C / GROUP;
286 static constexpr int CKK_ROW_SIZE = C_PER_M*((KH*KW+3)/4*4);
287 static constexpr unsigned int X_OFFSET = (STEP_W == 1) ? 0x76543210 : ((STEP_W == 2) ? 0x00006420 : 0x0000c840);
288 static constexpr int W_LOOP_STEP = (STEP_W == 1) ? 8 : 4;
289 static constexpr int W_LOOP_IN_STEP = (STEP_W != 4) ? 8 : 16;
291 alignas(32)
float (&bias)[M];
292 alignas(32)
float ckk_row[CKK_ROW_SIZE];
300 input_window<float>* in,
301 input_stream<float>* weights,
302 output_stream<float>* out
305 static void registerKernelClass() {
306 static_assert(KW<=4);
307 static_assert(INP_W%4==0);
308 static_assert(OUT_W_PAD%8==0 && STEP_W==1 || OUT_W_PAD%4==0 && STEP_W==2 || OUT_W_PAD%4==0 && STEP_W == 4);
309 static_assert(STEP_H == 1 || STEP_H == 2 || STEP_H == 4);
310 static_assert(STEP_W == 1 || STEP_W == 2 || STEP_W == 4);
311 REGISTER_FUNCTION(ConvHx4ReluStream::filter);
312 REGISTER_PARAMETER(bias);
323template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
324 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
328 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
329 static constexpr int CKK_ROW_SIZE = C*((KH*KW+3)/4*4);
330 static constexpr unsigned int X_OFFSET = 0x76543210;
332 alignas(32)
float (&bias)[M];
333 alignas(32)
float ckk_row[CKK_ROW_SIZE];
334 alignas(32)
float out_row[OUT_W_PAD];
342 input_window<float>* in,
343 input_stream<float>* weights,
344 output_stream<float>* out
347 static void registerKernelClass() {
348 static_assert(GROUP == 1);
349 static_assert(KH==3);
350 static_assert(KW==3);
351 static_assert(INP_W%4==0);
352 static_assert(OUT_W_PAD%8==0);
353 static_assert(STEP_H == 1);
354 static_assert(STEP_W == 1);
355 REGISTER_FUNCTION(ConvHx4ReluStreamMultiRow::filter);
356 REGISTER_PARAMETER(bias);
366template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
367 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
371 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
372 static constexpr int C_PER_M = C / GROUP;
373 static constexpr int CKK_ROW_SIZE = C_PER_M*((KH*KW+3)/4*4);
375 alignas(32)
float (&bias)[M];
376 alignas(32)
float ckk_row[CKK_ROW_SIZE];
384 input_window<float>* in,
385 input_stream<float>* weights,
386 output_stream<float>* out
389 static void registerKernelClass() {
390 static_assert(KW<=4);
391 static_assert(INP_W%4==0);
392 static_assert(OUT_W_PAD == 4);
393 static_assert(STEP_H == 1);
394 static_assert(STEP_W == 1);
395 REGISTER_FUNCTION(ConvHx4Out4ReluStream::filter);
396 REGISTER_PARAMETER(bias);
407template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
408 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
412 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
413 static constexpr int CKK_ROW_SIZE = (C+3)/4*4;
415 alignas(32)
float (&bias)[M];
416 alignas(32)
float ckk_row[CKK_ROW_SIZE];
424 input_window<float>* in,
425 input_stream<float>* weights,
426 output_stream<float>* out
429 static void registerKernelClass() {
430 static_assert(GROUP == 1);
431 static_assert(KH==1);
432 static_assert(KW==1);
433 static_assert(INP_W%4==0);
434 static_assert(OUT_W_PAD%8==0 && STEP_W==1 || OUT_W_PAD%4==0 && STEP_W==2);
435 static_assert(STEP_H == 1 || STEP_H == 2);
436 static_assert(STEP_W == 1 || STEP_W == 2);
437 REGISTER_FUNCTION(Conv1x1ReluStream::filter);
438 REGISTER_PARAMETER(bias);
448template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
449 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
453 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
454 static constexpr int C_PER_M = C / GROUP;
455 static constexpr int CKK_ROW_SIZE = (C_PER_M+3)/4*4;
457 alignas(32)
float (&bias)[M];
458 alignas(32)
float ckk_row[CKK_ROW_SIZE];
466 input_window<float>* in,
467 input_stream<float>* weights,
468 output_stream<float>* out
471 static void registerKernelClass() {
472 static_assert(KW<=4);
473 static_assert(INP_W%4==0);
474 static_assert(OUT_W_PAD == 4);
475 static_assert(STEP_H == 1);
476 static_assert(STEP_W == 1);
477 REGISTER_FUNCTION(Conv1x1Out4ReluStream::filter);
478 REGISTER_PARAMETER(bias);
488template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
489 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
493 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
494 static constexpr int C_PER_M = C / GROUP;
495 static constexpr int CKK_ROW_SIZE = C_PER_M*((KH*KW+3)/4*4);
496 static constexpr int INP_SIZE = B*C*INP_H*INP_W;
498 static constexpr unsigned int X_OFFSET = (STEP_W == 1) ? 0x76543210 : ((STEP_W == 2) ? 0x00006420 : 0x0000c840);
499 static constexpr int W_LOOP_STEP = (STEP_W == 1) ? 8 : 4;
500 static constexpr int W_LOOP_IN_STEP = (STEP_W != 4) ? 8 : 16;
502 alignas(32)
float (&bias)[M];
503 alignas(32)
float ckk_row[CKK_ROW_SIZE];
504 alignas(32)
float in[INP_SIZE];
512 input_pktstream* in_s,
513 input_stream<float>* weights,
514 output_stream<float>* out
517 static void registerKernelClass() {
518 static_assert(KW<=4);
519 static_assert(INP_W%4==0);
520 static_assert(OUT_W_PAD%8==0 && STEP_W==1 || OUT_W_PAD%4==0 && STEP_W==2 || OUT_W_PAD%4==0 && STEP_W == 4);
521 static_assert(STEP_H == 1 || STEP_H == 2 || STEP_H == 4);
522 static_assert(STEP_W == 1 || STEP_W == 2 || STEP_W == 4);
523 REGISTER_FUNCTION(ConvHx4ReluPktStream::filter);
524 REGISTER_PARAMETER(bias);
534template <
int INP_H,
int INP_W,
int OUT_W,
int OUT_W_PAD,
int STEP_H,
int STEP_W,
535 int B,
int C,
int M,
int KH,
int KW,
int GROUP,
int IS_RELU>
539 static constexpr int OUT_H = (INP_H - KH) / STEP_H + 1;
540 static constexpr int CKK_ROW_SIZE = (C+3)/4*4;
541 static constexpr int INP_SIZE = B*C*INP_H*INP_W;
543 alignas(32)
float (&bias)[M];
544 alignas(32)
float ckk_row[CKK_ROW_SIZE];
545 alignas(32)
float in[INP_SIZE];
553 input_pktstream* in_s,
554 input_stream<float>* weights,
555 output_stream<float>* out
558 static void registerKernelClass() {
559 static_assert(GROUP == 1);
560 static_assert(KH==1);
561 static_assert(KW==1);
562 static_assert(INP_W%4==0);
563 static_assert(OUT_W_PAD%8==0 && STEP_W==1 || OUT_W_PAD%4==0 && STEP_W==2);
564 static_assert(STEP_H == 1 || STEP_H == 2);
565 static_assert(STEP_W == 1 || STEP_W == 2);
566 REGISTER_FUNCTION(Conv1x1ReluPktStream::filter);
567 REGISTER_PARAMETER(bias);
Vector stream implementation for OUT_W == 4 < 8, stores biases, requires KH==KW==1,...
Definition conv.h:450
Vector stream implementation for BCHW, stores biases, requires KH==KW==1, INP_W%4==0,...
Definition conv.h:536
Vector stream implementation for BCHW, stores biases, requires KH==KW==1, INP_W%4==0,...
Definition conv.h:409
Vector stream implementation for BCHW, stores weights and biases, requires KH==KW==1,...
Definition conv.h:159
Vector implementation for 5x5 BCHW, stores weights and biases, requires KH==KW==5,...
Definition conv.h:77
Vector stream implementation for OUT_W == 4 < 8, stores biases, requires KW<=4, INP_W%4==0,...
Definition conv.h:368
Vector stream implementation for BCHW, stores biases, requires KW<=4, INP_W%4==0, OUT_W_PAD%(8|4)==0,...
Definition conv.h:490
Vector stream implementation for BCHW, stores biases, requires KH==KW==3, INP_W%4==0,...
Definition conv.h:325
Vector stream implementation for BCHW, stores biases, requires KW<=4, INP_W%4==0, OUT_W_PAD%(8|4)==0,...
Definition conv.h:281
Vector implementation for 3x3 BCHW, stores weights and biases, requires KW<=4, INP_W%4==0,...
Definition conv.h:118
Scalar stream implementation for BCHW, stores biases, requires C%GROUP==0, ConvHx8ReluStream<28,...
Definition conv.h:238
Scalar stream implementation for BCHW, stores biases, requires C%GROUP==0, ConvReluScalarStream<26,...
Definition conv.h:202
Scalar implementation for BCHW, stores weights and biases, requires C%GROUP==0, ConvReluScalar<28,...
Definition conv.h:40