29template <
int M,
int K,
int N,
int IS_RELU>
32 alignas(32)
float (&bias)[N];
33 alignas(32)
float in_row[K];
41 input_stream<float>* in,
42 input_stream<float>* weight,
43 output_window<float>* out
45 static void registerKernelClass() {
46 REGISTER_FUNCTION(GemmReluScalarMKNKStream::filter);
47 REGISTER_PARAMETER(bias);
57template <
int M,
int K,
int N,
int IS_RELU>
61 alignas(32)
float (&weights)[N*K];
62 alignas(32)
float (&bias)[N];
68 ): weights(w), bias(b) {};
71 input_window<float>* in,
72 output_window<float>* out
75 static void registerKernelClass() {
76 REGISTER_FUNCTION(GemmReluScalarMKNK::filter);
77 REGISTER_PARAMETER(weights);
78 REGISTER_PARAMETER(bias);
87template <
int M,
int K,
int N,
int IS_RELU>
91 alignas(32)
float (&weights)[K*N];
92 alignas(32)
float (&bias)[N];
98 ): weights(w), bias(b) {};
101 input_window<float>* in,
102 output_window<float>* out
105 static void registerKernelClass() {
106 REGISTER_FUNCTION(GemmReluScalarMKKN::filter);
107 REGISTER_PARAMETER(weights);
108 REGISTER_PARAMETER(bias);
118template <
int M,
int K,
int N,
int IS_RELU>
121 alignas(32)
float (&bias)[N];
122 alignas(32)
float in_row[2*K];
123 alignas(32)
float out_row[N];
131 input_stream<float>* in,
132 input_stream<float>* weight,
133 output_window<float>* out
135 static void registerKernelClass() {
136 static_assert((2*K + N)*4 <= 24576);
137 static_assert(N%8 == 0);
138 REGISTER_FUNCTION(GemmReluMKKNTwoAccsStream::filter);
139 REGISTER_PARAMETER(bias);
149template <
int M,
int K,
int N,
int IS_RELU>
152 alignas(32)
float (&bias)[N];
153 alignas(32)
float in_row[4*K];
154 alignas(32)
float out_row[3*N];
162 input_stream<float>* in,
163 input_stream<float>* weight,
164 output_window<float>* out
166 static void registerKernelClass() {
167 static_assert((4*K + 3*N)*4 <= 24576);
168 static_assert(N%8 == 0);
169 REGISTER_FUNCTION(GemmReluMKKNStream::filter);
170 REGISTER_PARAMETER(bias);
180template <
int M,
int K,
int N,
int IS_RELU>
184 alignas(32)
float (&weights)[K*N];
185 alignas(32)
float (&bias)[N];
187 static constexpr int K_REM8 = K%8;
188 static constexpr int RUN_LASTCHUNK = K_REM8 > 0;
194 ): weights(w), bias(b) {};
197 input_window<float>* in,
198 output_window<float>* out
201 static void registerKernelClass() {
202 static_assert(K%4==0 && N%8==0);
203 REGISTER_FUNCTION(GemmReluMKKN::filter);
204 REGISTER_PARAMETER(weights);
205 REGISTER_PARAMETER(bias);
Vector implementation for MK*KN, streams input, outputs, weights, stores bias, requires (4*K + 3*N)*4...
Definition gemm.h:150
Vector implementation for MK*KN, streams input, outputs, weights, stores bias, requires (2*K + N)*4 <...
Definition gemm.h:119
Vector implementation for MK*KN, stores weights and biases, requires K%4=0, N%8=0 GemmReluMKKN<7,...
Definition gemm.h:181
Scalar implementation for MK*KN, stores weights and biases, GemmReluScalarMKKN<7,36,...
Definition gemm.h:88
Scalar implementation for MK*NK, streams input, outputs, weights, stores bias GemmReluScalarMKNKStrea...
Definition gemm.h:30
Scalar implementation for MK*NK, stores weights and biases, Running GemmReluScalarMKNK<7,...
Definition gemm.h:58