onnx2versal
Loading...
Searching...
No Matches
qgemm.h
1#ifndef QGEMM_H_
2#define QGEMM_H_
3
4#include <adf.h>
5#include <assert.h>
6
7
30template <typename TT, typename TTPARAM, int M, int K, int N>
32
33 private:
34 alignas(32) TTPARAM (&weights)[N*K]; // KxN (256x120)
35 alignas(32) int32_t (&bias)[N]; // N (120)
36 float x_scale;
37 float w_scale;
38 float y_scale;
39 TT x_zero;
40 TTPARAM w_zero;
41 TT y_zero;
42
43 float scale;
44 alignas(32) TT in_row[K];
45
46 public:
48 TTPARAM (&w)[K*N],
49 int32_t (&b)[N],
50 float x_scale,
51 float w_scale,
52 float y_scale,
53 TT x_zero,
54 TTPARAM w_zero,
55 TT y_zero
56 ): weights(w), bias(b), x_scale(x_scale), w_scale(w_scale), y_scale(y_scale), x_zero(x_zero), w_zero(w_zero), y_zero(y_zero) {
57 scale = x_scale*w_scale/y_scale;
58 };
59
60 void filter(
61 input_stream<TT>* in, // MxK (1x256)
62 output_stream<TT>* out // MxN (1x120)
63 );
64
65 static void registerKernelClass() {
66 static_assert((std::is_same<TT, int8_t>::value) || (std::is_same<TT, uint8_t>::value));
67 REGISTER_FUNCTION(QgemmScalar::filter);
68 REGISTER_PARAMETER(weights);
69 REGISTER_PARAMETER(bias);
70 };
71};
72
73
78template <typename TT, typename TTPARAM, int M, int K, int N>
79class Qgemm {
80
81 private:
82 alignas(32) TTPARAM (&weights)[N*K]; // KxN (256x120)
83 alignas(32) int32_t (&bias)[N]; // N (120)
84 float x_scale;
85 float w_scale;
86 float y_scale;
87 TT x_zero;
88 TTPARAM w_zero;
89 TT y_zero;
90
91 // precomputation
92 int scalebits;
93 int16_t scale;
94 int32_t shift;
95
96 alignas(32) TT in_row[K];
97
98 public:
99 Qgemm (
100 TTPARAM (&w)[K*N],
101 int32_t (&b)[N],
102 float x_scale,
103 float w_scale,
104 float y_scale,
105 TT x_zero,
106 TTPARAM w_zero,
107 TT y_zero
108 );
109
110 void filter(
111 input_stream<TT>* in, // MxK
112 output_stream<TT>* out // MxN
113 );
114
115 static void registerKernelClass() {
116 static_assert((std::is_same<TT, int8_t>::value) || (std::is_same<TT, uint8_t>::value));
117 static_assert(K % 16 == 0);
118 static_assert(N % 16 == 0);
119 REGISTER_FUNCTION(Qgemm::filter);
120 REGISTER_PARAMETER(weights);
121 REGISTER_PARAMETER(bias);
122 };
123};
124
125
130template <typename TT, typename TTPARAM, int M, int K, int N>
132
133 private:
134 alignas(32) int32_t (&bias)[N]; // N (120)
135 float x_scale;
136 float w_scale;
137 float y_scale;
138 TT x_zero;
139 TTPARAM w_zero;
140 TT y_zero;
141
142 // precomputation
143 int scalebits;
144 int16_t scale;
145 int32_t shift;
146
147 alignas(32) TT in_row[K];
148
149 public:
151 int32_t (&b)[N],
152 float x_scale,
153 float w_scale,
154 float y_scale,
155 TT x_zero,
156 TTPARAM w_zero,
157 TT y_zero
158 );
159
160 void filter(
161 input_stream<TT>* in, // MxK
162 input_stream<TTPARAM>* weight, // KxN
163 output_stream<TT>* out // MxN
164 );
165
166 static void registerKernelClass() {
167 static_assert((std::is_same<TT, int8_t>::value) || (std::is_same<TT, uint8_t>::value));
168 static_assert((std::is_same<TTPARAM, int8_t>::value) || (std::is_same<TTPARAM, uint8_t>::value));
169 static_assert(K % 16 == 0);
170 static_assert(N % 16 == 0);
171 REGISTER_FUNCTION(QgemmStream::filter);
172 REGISTER_PARAMETER(bias);
173 };
174};
178#endif // QGEMM_H_
Scalar implementation for MK*KN, stores weights and biases, QgemmScalar<a,a,1,80,32> takes 17340 cycl...
Definition qgemm.h:31
Vector implementation for MK*KN, streams weights and stores biases, requires K%16=0 and N%16=0 Qgemm<a,...
Definition qgemm.h:131
Vector implementation for MK*KN, stores weights and biases, requires N%16=0 Qgemm<a,...
Definition qgemm.h:79
void filter(input_stream< TT > *in, output_stream< TT > *out)
Definition qgemm.cc:117