tiny_dnn 1.0.0
A header-only, dependency-free deep learning framework in C++11
conv2d_op_avx.h
/*
    Copyright (c) 2016, Taiga Nomi, Edgar Riba
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:
    * Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.
    * Neither the name of the <organization> nor the
    names of its contributors may be used to endorse or promote products
    derived from this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
    EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
    DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
    LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
    ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once

#include <vector>
#include "tiny_dnn/core/params/conv_params.h"
#include "tiny_dnn/core/kernels/conv2d_op_internal.h"

#ifdef CNN_USE_AVX
#include "tiny_dnn/core/kernels/avx_kernel_common.h"
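// The common AVX header appears to supply the small SIMD helpers used below:
// the madd* multiply-add wrappers, the hsum256_ps / hsum2x256_ps / hsum256_pd
// horizontal sums, and leftShift<> lane shifts.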
#endif

namespace tiny_dnn {
namespace kernels {

#ifdef CNN_USE_AVX

// float ver
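// Single-precision 2D convolution specialized for 5x5 kernels. The input is
// expected in the padded layout described by params.in_padded; each output
// feature map accumulates the 5x5 dot products of every input channel that the
// connection table tbl marks as connected. layer_parallelize is unused here;
// the caller parallelizes over samples.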
template <typename Allocator>
void avx_conv2d_5x5_kernel(const core::conv_params& params,
                           const std::vector<float, Allocator>& in,
                           const std::vector<float, Allocator>& W,
                           const std::vector<float, Allocator>& bias,
                           std::vector<float, Allocator>& a,
                           const bool layer_parallelize) {
  assert(params.weight.height_ == 5 && params.weight.width_ == 5);

  auto& out = params.out;
  auto& in_padded = params.in_padded;
  auto& tbl = params.tbl;
  auto w_stride = params.w_stride;

  const serial_size_t out_area = out.area();
  serial_size_t oidx = 0;
  float bias_scale = params.has_bias ? 1.0f : 0.0f;
  const serial_size_t stride = params.h_stride * in_padded.width_;
  const serial_size_t inarea = in_padded.area();

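  // Integer mask for _mm256_maskload_ps: only the first five lanes are
  // enabled, so a masked load fetches exactly one 5-float row.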
  static const __m256i imask = _mm256_setr_epi32(-1, -1, -1, -1, -1, 0, 0, 0);
  // static const __m256 mask = _mm256_castsi256_ps(_mm256_setr_epi32(-1, -1, -1, -1, -1, 0, 0, 0));

  const __m128 y_bias_scale = _mm_set_ss(bias_scale);
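  // Special case: a 1x1 output map means the 5x5 window covers the whole
  // padded input channel, so the 25-element dot product of every connected
  // channel can be reduced in one pass.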
  if (out.height_ == 1 && out.width_ == 1) {
    const float* pw = (const float*)&W[0];
    for (serial_size_t o = 0; o < out.depth_; ++o) {
      __m256 sum0 = _mm256_setzero_ps();
      __m256 sum1 = _mm256_setzero_ps();
      __m256 sum2 = _mm256_setzero_ps();
      __m128 sum3 = _mm_setzero_ps();
      const float* pi = (const float*)&in[0];
      for (serial_size_t inc = 0; inc < params.in.depth_; ++inc, pw += 25, pi += inarea) {
        if (!tbl.is_connected(o, inc)) {
          continue;
        }
        __m256 w0 = _mm256_loadu_ps(pw + 0);
        __m256 w1 = _mm256_loadu_ps(pw + 8);
        __m256 w2 = _mm256_loadu_ps(pw + 16);
        __m256 i0 = _mm256_loadu_ps(pi + 0);
        __m256 i1 = _mm256_loadu_ps(pi + 8);
        __m256 i2 = _mm256_loadu_ps(pi + 16);
        __m128 w3 = _mm_load_ss(pw + 24);
        __m128 i3 = _mm_load_ss(pi + 24);
        __m256 tmp0 = _mm256_mul_ps(w0, i0);
        __m256 tmp1 = _mm256_mul_ps(w1, i1);
        __m256 tmp2 = _mm256_mul_ps(w2, i2);
        __m128 tmp3 = _mm_mul_ps(w3, i3);
        sum0 = _mm256_add_ps(tmp0, sum0);
        sum1 = _mm256_add_ps(tmp1, sum1);
        sum2 = _mm256_add_ps(tmp2, sum2);
        sum3 = _mm_add_ps(tmp3, sum3);
      }
      __m256 sum = _mm256_add_ps(_mm256_add_ps(sum0, sum1), sum2);
      __m128 b = _mm_load_ss(&bias[o]);
      __m128 hsum = hsum256_ps(sum);
      b = madd128_ss(b, y_bias_scale, sum3);
      _mm_store_ss(&a[o], _mm_add_ss(hsum, b));
    }
  } else {
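    // General case: initialize each output map with its bias, then add the
    // contribution of every connected input channel.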
    const serial_size_t nblocks = out.width_ / 4;
    for (serial_size_t o = 0; o < out.depth_; ++o, oidx += out_area) {
      float* pa = &a[oidx];
      // init to bias value
      float b = bias[o] * bias_scale;
      {
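        // Fill the output map with b: a scalar head until the offset is a
        // multiple of eight floats (so the aligned stores below are safe for a
        // 32-byte aligned buffer), then 16 floats per iteration via two
        // aligned stores, then a scalar tail.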
        size_t headSize = 0;
        __m256 b2 = _mm256_set1_ps(b);
        if (oidx & 7) {
          headSize = 8 - (oidx & 7);
          assert(headSize < out_area);
          for (size_t i = 0; i < headSize; ++i) {
            _mm_store_ss(&pa[i], _mm256_castps256_ps128(b2));
          }
        }
        size_t cnt = (out_area - headSize) / 16;
        float* pa2 = pa + headSize;
        for (size_t i = 0; i < cnt; ++i) {
          _mm256_store_ps(&pa2[i*16+0], b2);
          _mm256_store_ps(&pa2[i*16+8], b2);
        }
        for (size_t i = headSize + cnt*16; i < out_area; ++i) {
          pa[i] = b;
        }
      }
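      // Accumulate the 5x5 contribution of every connected input channel into
      // this output map.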
      for (serial_size_t inc = 0; inc < params.in.depth_; ++inc) {
        if (!tbl.is_connected(o, inc)) continue;

        const float* pw = (const float*)&W[25 * (params.in.depth_ * o + inc)];
        const float* pi = (const float*)&in[in_padded.get_index(0, 0, inc)];

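        // Load the five masked 5-tap kernel rows and build copies of each row
        // shifted by one, two and three floats (leftShift<4/8/12>); with unit
        // stride these let one 8-float input load feed four adjacent output
        // pixels.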
        __m256 w0a = _mm256_maskload_ps(pw + 0, imask);
        __m256 w1a = _mm256_maskload_ps(pw + 5, imask);
        __m256 w2a = _mm256_maskload_ps(pw + 10, imask);
        __m256 w3a = _mm256_maskload_ps(pw + 15, imask);
        __m256 w4a = _mm256_maskload_ps(pw + 20, imask);
        __m256 w0b = leftShift<4>(w0a);
        __m256 w1b = leftShift<4>(w1a);
        __m256 w2b = leftShift<4>(w2a);
        __m256 w3b = leftShift<4>(w3a);
        __m256 w4b = leftShift<4>(w4a);
        __m256 w0c = leftShift<8>(w0a);
        __m256 w1c = leftShift<8>(w1a);
        __m256 w2c = leftShift<8>(w2a);
        __m256 w3c = leftShift<8>(w3a);
        __m256 w4c = leftShift<8>(w4a);
        __m256 w0d = leftShift<12>(w0a);
        __m256 w1d = leftShift<12>(w1a);
        __m256 w2d = leftShift<12>(w2a);
        __m256 w3d = leftShift<12>(w3a);
        __m256 w4d = leftShift<12>(w4a);
        float* ppa = pa;
        for (serial_size_t y = 0; y < out.height_; y++) {
          const float* pi0 = (pi + y * stride);
          const float* pi1 = pi0 + 1 * in_padded.width_;
          const float* pi2 = pi0 + 2 * in_padded.width_;
          const float* pi3 = pi0 + 3 * in_padded.width_;
          const float* pi4 = pi0 + 4 * in_padded.width_;
          serial_size_t x = 0;
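          // Unit-stride fast path: four adjacent output pixels share each
          // 8-float row load and are computed together.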
          if (w_stride == 1) {
            __m256 dst0, dst1, dst2, dst3;
            float* ppa2 = ppa;
            for (size_t i = 0; i < nblocks; ++i) {
              __m256 i0 = _mm256_loadu_ps(pi0);
              __m256 i1 = _mm256_loadu_ps(pi1);
              __m256 i2 = _mm256_loadu_ps(pi2);
              __m256 i3 = _mm256_loadu_ps(pi3);
              __m256 i4 = _mm256_loadu_ps(pi4);
              __m128 sum = _mm_loadu_ps(ppa2);
              dst0 = _mm256_mul_ps(w0a, i0);
              dst1 = _mm256_mul_ps(w0b, i0);
              dst2 = _mm256_mul_ps(w0c, i0);
              dst3 = _mm256_mul_ps(w0d, i0);
              dst0 = madd256_ps(w1a, i1, dst0);
              dst1 = madd256_ps(w1b, i1, dst1);
              dst2 = madd256_ps(w1c, i1, dst2);
              dst3 = madd256_ps(w1d, i1, dst3);
              dst0 = madd256_ps(w2a, i2, dst0);
              dst1 = madd256_ps(w2b, i2, dst1);
              dst2 = madd256_ps(w2c, i2, dst2);
              dst3 = madd256_ps(w2d, i2, dst3);
              dst0 = madd256_ps(w3a, i3, dst0);
              dst1 = madd256_ps(w3b, i3, dst1);
              dst2 = madd256_ps(w3c, i3, dst2);
              dst3 = madd256_ps(w3d, i3, dst3);
              dst0 = madd256_ps(w4a, i4, dst0);
              dst1 = madd256_ps(w4b, i4, dst1);
              __m128 hsum01 = hsum2x256_ps(dst0, dst1);
              dst2 = madd256_ps(w4c, i4, dst2);
              dst3 = madd256_ps(w4d, i4, dst3);
              __m128 hsum23 = hsum2x256_ps(dst2, dst3);
              __m128 sum2 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(hsum01), _mm_castps_pd(hsum23)));
              sum = _mm_add_ps(sum, sum2);
              _mm_storeu_ps(ppa2, sum);
              pi0 += 4;
              pi1 += 4;
              pi2 += 4;
              pi3 += 4;
              pi4 += 4;
              ppa2 += 4;
            }
            x = nblocks * 4;
          }
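          // Remaining pixels (and the only path when w_stride != 1): one
          // output pixel per iteration.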
          for (; x < out.width_; ++x) {
            __m128 sum = _mm_load_ss(&ppa[x]);
            __m256 i0 = _mm256_loadu_ps(pi0);
            __m256 i1 = _mm256_loadu_ps(pi1);
            __m256 i2 = _mm256_loadu_ps(pi2);
            __m256 i3 = _mm256_loadu_ps(pi3);
            __m256 i4 = _mm256_maskload_ps(pi4, imask);
            __m256 sum0 = _mm256_mul_ps(w0a, i0);
            __m256 sum1 = _mm256_mul_ps(w1a, i1);
            sum0 = madd256_ps(w2a, i2, sum0);
            sum1 = madd256_ps(w3a, i3, sum1);
            sum0 = madd256_ps(w4a, i4, sum0);
            sum0 = _mm256_add_ps(sum0, sum1);
            _mm_store_ss(&ppa[x], _mm_add_ss(sum, hsum256_ps(sum0)));
            // printf("%d %d %d %f\n", inc, y, x, ppa[x]);
            pi0 += w_stride;
            pi1 += w_stride;
            pi2 += w_stride;
            pi3 += w_stride;
            pi4 += w_stride;
          } // x loop
          ppa += out.width_;
        } // y loop
      } // in depth loop
    } // out depth loop
  } // else
} // avx_conv2d_5x5_kernel float ver

// double ver
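// Double-precision counterpart of the kernel above. The algorithm is the same,
// but vectors hold four doubles, so each 5-element row is handled as a 4-wide
// load plus a scalar, and output pixels are computed one at a time.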
template <typename Allocator>
void avx_conv2d_5x5_kernel(const core::conv_params& params,
                           const std::vector<double, Allocator>& in,
                           const std::vector<double, Allocator>& W,
                           const std::vector<double, Allocator>& bias,
                           std::vector<double, Allocator>& a,
                           const bool layer_parallelize) {
  assert(params.weight.height_ == 5 && params.weight.width_ == 5);

  auto& out = params.out;
  auto& in_padded = params.in_padded;
  auto& tbl = params.tbl;
  auto w_stride = params.w_stride;

  const size_t out_area = out.area();
  double bias_scale = params.has_bias ? 1.0 : 0.0;
  const __m128d y_bias_scale = _mm_set_sd(bias_scale);
  serial_size_t oidx = 0;

  const size_t in_stride = params.h_stride * in_padded.width_;
  const size_t in_padded_area = in_padded.area();

  if (out.height_ == 1 && out.width_ == 1) {
    const double* pw = &W[0];
    for (size_t o = 0; o < out.depth_; ++o) {
      __m256d sum0 = _mm256_setzero_pd();
      __m256d sum1 = _mm256_setzero_pd();
      __m256d sum2 = _mm256_setzero_pd();
      __m256d sum3 = _mm256_setzero_pd();
      __m256d sum4 = _mm256_setzero_pd();
      __m256d sum5 = _mm256_setzero_pd();
      __m128d sum6 = _mm_setzero_pd();
      size_t inidx = 0;
      for (serial_size_t inc = 0; inc < params.in.depth_; ++inc, pw += 25, inidx += in_padded_area) {
        if (!tbl.is_connected(o, inc)) {
          continue;
        }
        __m256d w0 = _mm256_loadu_pd(pw + 0);
        __m256d w1 = _mm256_loadu_pd(pw + 4);
        __m256d w2 = _mm256_loadu_pd(pw + 8);
        __m256d w3 = _mm256_loadu_pd(pw + 12);
        __m256d w4 = _mm256_loadu_pd(pw + 16);
        __m256d w5 = _mm256_loadu_pd(pw + 20);
        __m128d w6 = _mm_load_sd(pw + 24);
        const double* pi = (const double*)&in[inidx];
        __m256d i0 = _mm256_loadu_pd(pi + 0);
        __m256d i1 = _mm256_loadu_pd(pi + 4);
        __m256d i2 = _mm256_loadu_pd(pi + 8);
        __m256d i3 = _mm256_loadu_pd(pi + 12);
        __m256d i4 = _mm256_loadu_pd(pi + 16);
        __m256d i5 = _mm256_loadu_pd(pi + 20);
        __m128d i6 = _mm_load_sd(pi + 24);
        __m256d tmp0 = _mm256_mul_pd(w0, i0);
        __m256d tmp1 = _mm256_mul_pd(w1, i1);
        __m256d tmp2 = _mm256_mul_pd(w2, i2);
        __m256d tmp3 = _mm256_mul_pd(w3, i3);
        __m256d tmp4 = _mm256_mul_pd(w4, i4);
        __m256d tmp5 = _mm256_mul_pd(w5, i5);
        __m128d tmp6 = _mm_mul_pd(w6, i6);
        sum0 = _mm256_add_pd(tmp0, sum0);
        sum1 = _mm256_add_pd(tmp1, sum1);
        sum2 = _mm256_add_pd(tmp2, sum2);
        sum3 = _mm256_add_pd(tmp3, sum3);
        sum4 = _mm256_add_pd(tmp4, sum4);
        sum5 = _mm256_add_pd(tmp5, sum5);
        sum6 = _mm_add_pd(tmp6, sum6);
      }
      sum0 = _mm256_add_pd(sum0, sum1);
      sum2 = _mm256_add_pd(sum2, sum3);
      sum4 = _mm256_add_pd(sum4, sum5);
      sum0 = _mm256_add_pd(sum0, sum2);
      __m256d sum = _mm256_add_pd(sum0, sum4);
      __m128d b = _mm_load_sd(&bias[o]);
      __m128d hsum = hsum256_pd(sum);
      b = madd128_sd(b, y_bias_scale, sum6);
      _mm_store_sd(&a[o], _mm_add_sd(hsum, b));
    }
  } else {
    for (serial_size_t o = 0; o < out.depth_; ++o, oidx += out_area) {
      double* pa = &a[oidx];
      double b = bias[o] * bias_scale;
      {
        size_t headSize = 0;
        __m256d b2 = _mm256_set1_pd(b);
        if (oidx & 3) {
          headSize = 4 - (oidx & 3);
          assert(headSize < out_area);
          for (size_t i = 0; i < headSize; ++i) {
            _mm_store_sd(&pa[i], _mm256_castpd256_pd128(b2));
          }
        }
        size_t cnt = (out_area - headSize) / 8;
        double* pa2 = pa + headSize;
        for (size_t i = 0; i < cnt; ++i) {
          _mm256_store_pd(&pa2[i*8+0], b2);
          _mm256_store_pd(&pa2[i*8+4], b2);
        }
        for (size_t i = headSize + cnt*8; i < out_area; ++i) {
          _mm_store_sd(&pa[i], _mm256_castpd256_pd128(b2));
        }
      }

      for (serial_size_t inc = 0; inc < params.in.depth_; ++inc) {
        if (!tbl.is_connected(o, inc)) continue;

        const double* pw = (const double*)&W[25 * (params.in.depth_ * o + inc)];
        const double* pi = &in[in_padded.get_index(0, 0, inc)];

        __m256d w0a = _mm256_loadu_pd(pw + 0);
        __m128d w0b = _mm_load_sd(pw + 4);
        __m256d w1a = _mm256_loadu_pd(pw + 5);
        __m128d w1b = _mm_load_sd(pw + 9);
        __m256d w2a = _mm256_loadu_pd(pw + 10);
        __m128d w2b = _mm_load_sd(pw + 14);
        __m256d w3a = _mm256_loadu_pd(pw + 15);
        __m128d w3b = _mm_load_sd(pw + 19);
        __m256d w4a = _mm256_loadu_pd(pw + 20);
        __m128d w4b = _mm_load_sd(pw + 24);

        double* ppa = pa;
        for (serial_size_t y = 0; y < out.height_; ++y, pi += in_stride, ppa += out.width_) {
          const double* pi0 = pi + 0 * in_padded.width_;
          const double* pi1 = pi + 1 * in_padded.width_;
          const double* pi2 = pi + 2 * in_padded.width_;
          const double* pi3 = pi + 3 * in_padded.width_;
          const double* pi4 = pi + 4 * in_padded.width_;
          for (serial_size_t x = 0; x < out.width_; ++x) {
            __m128d sum = _mm_load_sd(&ppa[x]);
            __m256d i0a = _mm256_loadu_pd(pi0);
            __m128d i0b = _mm_load_sd(pi0 + 4);
            __m256d i1a = _mm256_loadu_pd(pi1);
            __m128d i1b = _mm_load_sd(pi1 + 4);
            __m256d i2a = _mm256_loadu_pd(pi2);
            __m128d i2b = _mm_load_sd(pi2 + 4);
            __m256d i3a = _mm256_loadu_pd(pi3);
            __m128d i3b = _mm_load_sd(pi3 + 4);
            __m256d i4a = _mm256_loadu_pd(pi4);
            __m128d i4b = _mm_load_sd(pi4 + 4);
            __m256d sum_a = _mm256_mul_pd(w0a, i0a);
            __m128d sum_b = _mm_mul_sd(w0b, i0b);
            sum_a = madd256_pd(w1a, i1a, sum_a);
            sum_b = madd128_pd(w1b, i1b, sum_b);
            sum_a = madd256_pd(w2a, i2a, sum_a);
            sum_b = madd128_pd(w2b, i2b, sum_b);
            sum_a = madd256_pd(w3a, i3a, sum_a);
            sum_b = madd128_pd(w3b, i3b, sum_b);
            sum_a = madd256_pd(w4a, i4a, sum_a);
            sum_b = madd128_pd(w4b, i4b, sum_b);
            __m128d sum_c = hsum256_pd(sum_a);
            sum = _mm_add_sd(sum, sum_b);
            _mm_store_sd(&ppa[x], _mm_add_sd(sum, sum_c));
            pi0 += w_stride;
            pi1 += w_stride;
            pi2 += w_stride;
            pi3 += w_stride;
            pi4 += w_stride;
          } // x loop
        } // y loop
      } // in depth loop
    } // out depth loop
  } // else
} // avx_conv2d_5x5_kernel double ver

#endif // CNN_USE_AVX

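// Dispatcher: when compiled with CNN_USE_AVX and the filter is 5x5, run the
// specialized AVX kernel above for each sample in the batch; otherwise fall
// back to the generic conv2d_op_internal implementation.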
inline void conv2d_op_avx(const tensor_t& in_data,
                          const vec_t& W,
                          const vec_t& bias,
                          tensor_t& out_data,
                          const core::conv_params& params,
                          const bool layer_parallelize) {
#ifdef CNN_USE_AVX
  if (params.weight.height_ == 5 && params.weight.width_ == 5) {
    // @todo consider better parallelization
    for_i(layer_parallelize, in_data.size(), [&](int i) {
      avx_conv2d_5x5_kernel(params, in_data[i], W, bias, out_data[i], layer_parallelize);
    });
    return;
  }
#endif
  conv2d_op_internal(in_data, W, bias, out_data, params, layer_parallelize);
}

} // namespace kernels
} // namespace tiny_dnn