onnx2versal
Loading...
Searching...
No Matches
gemm.h
1#ifndef GEMM_H_
2#define GEMM_H_
3
4#include <adf.h>
5#include <assert.h>
6
7
// GemmReluScalarMKNKStream -- per this listing's own summary: "Scalar implementation
// for MK*NK, streams input, outputs, weights, stores bias".
// NOTE(review): the listing numbering skips original lines 30 and 36, so the class
// declaration line and the constructor-name line are not shown in this extract.
// Template parameters M, K, N are the dimensions used in the MxK / NxK / MxN
// annotations below; IS_RELU presumably enables a ReLU on the output -- confirm
// against the filter() definition.
29template <int M, int K, int N, int IS_RELU>
31 private:
// Reference to a caller-provided array of N floats, bound in the constructor.
32 alignas(32) float (&bias)[N];
// Member buffer of K floats; presumably holds one input row -- confirm in filter().
33 alignas(32) float in_row[K];
34
35 public:
// Constructor parameter list: binds the bias reference to b; the body is empty.
37 float (&b)[N]
38 ): bias(b) {};
39
// Kernel entry point; only declared here, the definition is not part of this listing.
40 void filter(
41 input_stream<float>* in, // MxK
42 input_stream<float>* weight, // NxK
43 output_window<float>* out // MxN
44 );
// Invokes REGISTER_FUNCTION on filter and REGISTER_PARAMETER on bias.
45 static void registerKernelClass() {
46 REGISTER_FUNCTION(GemmReluScalarMKNKStream::filter);
47 REGISTER_PARAMETER(bias);
48 };
49};
50
51
56// xA^T + b as per torch.nn.Linear
// GemmReluScalarMKNK -- per this listing's own summary: "Scalar implementation for
// MK*NK, stores weights and biases".
// NOTE(review): the listing numbering skips original lines 58 and 65, so the class
// declaration line and the constructor-name line are not shown in this extract.
// Template parameters M, K, N are the dimensions used in the annotations below;
// IS_RELU presumably enables a ReLU on the output -- confirm in filter().
57template <int M, int K, int N, int IS_RELU>
59
60 private:
// References to caller-provided arrays: N*K weight floats and N bias floats.
61 alignas(32) float (&weights)[N*K]; // NxK
62 alignas(32) float (&bias)[N]; // N
63
64 public:
// Constructor parameter list: binds both references; the body is empty.
66 float (&w)[N*K],
67 float (&b)[N]
68 ): weights(w), bias(b) {};
69
// Kernel entry point; only declared here, the definition is not part of this listing.
70 void filter(
71 input_window<float>* in, // MxK
72 output_window<float>* out // MxN
73 );
74
// Invokes REGISTER_FUNCTION on filter and REGISTER_PARAMETER on weights and bias.
75 static void registerKernelClass() {
76 REGISTER_FUNCTION(GemmReluScalarMKNK::filter);
77 REGISTER_PARAMETER(weights);
78 REGISTER_PARAMETER(bias);
79 };
80};
81
82
// GemmReluScalarMKKN -- per this listing's own summary: "Scalar implementation for
// MK*KN, stores weights and biases".
// NOTE(review): the listing numbering skips original lines 88 and 95, so the class
// declaration line and the constructor-name line are not shown in this extract.
// Template parameters M, K, N are the dimensions used in the annotations below;
// IS_RELU presumably enables a ReLU on the output -- confirm in filter().
87template <int M, int K, int N, int IS_RELU>
89
90 private:
// References to caller-provided arrays: K*N weight floats and N bias floats.
91 alignas(32) float (&weights)[K*N]; // KxN
92 alignas(32) float (&bias)[N]; // N
93
94 public:
// Constructor parameter list: binds both references; the body is empty.
// The parameter is spelled [N*K] while the member is [K*N]; both have the same
// element count, so they name the same array type.
96 float (&w)[N*K],
97 float (&b)[N]
98 ): weights(w), bias(b) {};
99
// Kernel entry point; only declared here, the definition is not part of this listing.
100 void filter(
101 input_window<float>* in, // MxK
102 output_window<float>* out // MxN
103 );
104
// Invokes REGISTER_FUNCTION on filter and REGISTER_PARAMETER on weights and bias.
105 static void registerKernelClass() {
106 REGISTER_FUNCTION(GemmReluScalarMKKN::filter);
107 REGISTER_PARAMETER(weights);
108 REGISTER_PARAMETER(bias);
109 };
110};
111
112
// GemmReluMKKNTwoAccsStream -- per this listing's own summary: "Vector implementation
// for MK*KN, streams input, outputs, weights, stores bias, requires (2*K + N)*4 <...".
// NOTE(review): the listing numbering skips original lines 119 and 126, so the class
// declaration line and the constructor-name line are not shown in this extract.
// Template parameters M, K, N are the dimensions used in the annotations below;
// IS_RELU presumably enables a ReLU on the output -- confirm in filter().
118template <int M, int K, int N, int IS_RELU>
120 private:
// Reference to a caller-provided array of N floats, bound in the constructor.
121 alignas(32) float (&bias)[N];
// Member buffers: 2*K floats and N floats. Their names suggest input-row and
// output-row staging -- confirm usage in filter().
122 alignas(32) float in_row[2*K];
123 alignas(32) float out_row[N];
124
125 public:
// Constructor parameter list: binds the bias reference to b; the body is empty.
127 float (&b)[N]
128 ): bias(b) {};
129
// Kernel entry point; only declared here, the definition is not part of this listing.
// NOTE(review): weight is annotated NxK below, but the listing summarises this
// class as MK*KN -- confirm the expected weight layout.
130 void filter(
131 input_stream<float>* in, // MxK
132 input_stream<float>* weight, // NxK
133 output_window<float>* out // MxN
134 );
// Compile-time checks, then REGISTER_FUNCTION on filter and REGISTER_PARAMETER on bias.
135 static void registerKernelClass() {
// (2*K + N) matches the combined element count of in_row and out_row; the factor 4
// is presumably sizeof(float), making 24576 a byte budget -- confirm.
136 static_assert((2*K + N)*4 <= 24576);
// N must be a multiple of 8.
137 static_assert(N%8 == 0);
138 REGISTER_FUNCTION(GemmReluMKKNTwoAccsStream::filter);
139 REGISTER_PARAMETER(bias);
140 };
141};
142
143
// GemmReluMKKNStream -- per this listing's own summary: "Vector implementation for
// MK*KN, streams input, outputs, weights, stores bias, requires (4*K + 3*N)*4...".
// NOTE(review): the listing numbering skips original lines 150 and 157, so the class
// declaration line and the constructor-name line are not shown in this extract.
// Template parameters M, K, N are the dimensions used in the annotations below;
// IS_RELU presumably enables a ReLU on the output -- confirm in filter().
149template <int M, int K, int N, int IS_RELU>
151 private:
// Reference to a caller-provided array of N floats, bound in the constructor.
152 alignas(32) float (&bias)[N];
// Member buffers: 4*K floats and 3*N floats. Their names suggest input-row and
// output-row staging -- confirm usage in filter().
153 alignas(32) float in_row[4*K];
154 alignas(32) float out_row[3*N];
155
156 public:
// Constructor parameter list: binds the bias reference to b; the body is empty.
158 float (&b)[N]
159 ): bias(b) {};
160
// Kernel entry point; only declared here, the definition is not part of this listing.
// NOTE(review): weight is annotated NxK below, but the listing summarises this
// class as MK*KN -- confirm the expected weight layout.
161 void filter(
162 input_stream<float>* in, // MxK
163 input_stream<float>* weight, // NxK
164 output_window<float>* out // MxN
165 );
// Compile-time checks, then REGISTER_FUNCTION on filter and REGISTER_PARAMETER on bias.
166 static void registerKernelClass() {
// (4*K + 3*N) matches the combined element count of in_row and out_row; the factor 4
// is presumably sizeof(float), making 24576 a byte budget -- confirm.
167 static_assert((4*K + 3*N)*4 <= 24576);
// N must be a multiple of 8.
168 static_assert(N%8 == 0);
169 REGISTER_FUNCTION(GemmReluMKKNStream::filter);
170 REGISTER_PARAMETER(bias);
171 };
172};
173
174
// GemmReluMKKN -- per this listing's own summary: "Vector implementation for MK*KN,
// stores weights and biases, requires K%4=0, N%8=0".
// NOTE(review): the listing numbering skips original lines 181 and 191, so the class
// declaration line and the constructor-name line are not shown in this extract.
// Template parameters M, K, N are the dimensions used in the annotations below;
// IS_RELU presumably enables a ReLU on the output -- confirm in filter().
180template <int M, int K, int N, int IS_RELU>
182
183 private:
// References to caller-provided arrays: K*N weight floats and N bias floats.
184 alignas(32) float (&weights)[K*N]; // KxN
185 alignas(32) float (&bias)[N]; // N
186
// K_REM8 is K modulo 8; RUN_LASTCHUNK is 1 when K is not a multiple of 8, else 0.
// Presumably these select a trailing partial chunk in filter() -- confirm there.
187 static constexpr int K_REM8 = K%8;
188 static constexpr int RUN_LASTCHUNK = K_REM8 > 0;
189
190 public:
// Constructor parameter list: binds both references; the body is empty.
// The parameter is spelled [N*K] while the member is [K*N]; both have the same
// element count, so they name the same array type.
192 float (&w)[N*K],
193 float (&b)[N]
194 ): weights(w), bias(b) {};
195
// Kernel entry point; only declared here, the definition is not part of this listing.
196 void filter(
197 input_window<float>* in, // MxK
198 output_window<float>* out // MxN
199 );
200
// Compile-time check, then REGISTER_FUNCTION on filter and REGISTER_PARAMETER on
// weights and bias.
201 static void registerKernelClass() {
// K must be a multiple of 4 and N a multiple of 8.
202 static_assert(K%4==0 && N%8==0);
203 REGISTER_FUNCTION(GemmReluMKKN::filter);
204 REGISTER_PARAMETER(weights);
205 REGISTER_PARAMETER(bias);
206 };
207};
211#endif // GEMM_H_
Vector implementation for MK*KN, streams input, outputs, weights, stores bias, requires (4*K + 3*N)*4...
Definition gemm.h:150
Vector implementation for MK*KN, streams input, outputs, weights, stores bias, requires (2*K + N)*4 <...
Definition gemm.h:119
Vector implementation for MK*KN, stores weights and biases, requires K%4=0, N%8=0 GemmReluMKKN<7,...
Definition gemm.h:181
Scalar implementation for MK*KN, stores weights and biases, GemmReluScalarMKKN<7,36,...
Definition gemm.h:88
Scalar implementation for MK*NK, streams input, outputs, weights, stores bias GemmReluScalarMKNKStrea...
Definition gemm.h:30
Scalar implementation for MK*NK, stores weights and biases, Running GemmReluScalarMKNK<7,...
Definition gemm.h:58