#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H

#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)

namespace Eigen {

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output,
                               volatile Scalar* lhs_shmem, volatile Scalar* rhs_shmem,
                               const Index m_size, const Index n_size,
                               const Index k_size) {
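  // Overview (inferred from the indexing below): each block computes a 64x64
  // tile of the output, selected by base_m/base_n. The 8x8x8 = 512 threads
  // stage 64-wide panels of both operands in shared memory and every thread
  // accumulates an 8x8 tile of partial sums in registers, which are later
  // combined across lanes with __shfl_xor. The odd row stride of 9 scalars
  // per 8-element group (72 per 64-row panel) pads the shared-memory layout,
  // apparently to keep concurrent accesses off a single bank.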
  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // shared-memory store locations for this thread's prefetched values
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // registers written by prefetchIntoRegisters below (declaration elided in
  // the extract; restored here so the macro has something to assign to)
  Scalar lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3, lhs_pf4, lhs_pf5, lhs_pf6, lhs_pf7;
  Scalar rhs_pf0, rhs_pf1, rhs_pf2, rhs_pf3, rhs_pf4, rhs_pf5, rhs_pf6, rhs_pf7;

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;
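  // Staging pattern: per iteration of the base_k loop each thread prefetches
  // eight lhs and eight rhs scalars into registers (macro below) and then
  // copies them to shared memory at the strided locations computed above.
  // The else-if ladder degrades gracefully at the ragged edge: it loads as
  // many of the eight k-columns as remain in range and leaves the rest at
  // their Scalar(0) initialization, avoiding a per-element bounds check.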
#define prefetchIntoRegisters(base_k) \
  { \
    lhs_pf0 = Scalar(0); \
    lhs_pf1 = Scalar(0); \
    lhs_pf2 = Scalar(0); \
    lhs_pf3 = Scalar(0); \
    lhs_pf4 = Scalar(0); \
    lhs_pf5 = Scalar(0); \
    lhs_pf6 = Scalar(0); \
    lhs_pf7 = Scalar(0); \
    rhs_pf0 = Scalar(0); \
    rhs_pf1 = Scalar(0); \
    rhs_pf2 = Scalar(0); \
    rhs_pf3 = Scalar(0); \
    rhs_pf4 = Scalar(0); \
    rhs_pf5 = Scalar(0); \
    rhs_pf6 = Scalar(0); \
    rhs_pf7 = Scalar(0); \
    if (!needs_edge_check || lhs_vert < m_size) { \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
      if (!needs_edge_check || lhs_horiz_7 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
      } else if (lhs_horiz_6 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
      } else if (lhs_horiz_5 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
      } else if (lhs_horiz_4 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      } else if (lhs_horiz_3 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      } else if (lhs_horiz_2 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      } else if (lhs_horiz_1 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      } else if (lhs_horiz_0 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      } \
    } \
    const Index rhs_vert = base_k + load_idx_vert; \
    if (!needs_edge_check || rhs_vert < k_size) { \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
      if (rhs_horiz_7 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
      } else if (rhs_horiz_6 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
      } else if (rhs_horiz_5 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
      } else if (rhs_horiz_4 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      } else if (rhs_horiz_3 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      } else if (rhs_horiz_2 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      } else if (rhs_horiz_1 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      } else if (rhs_horiz_0 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      } \
    } \
  }

#define writeRegToShmem(_) \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7;

  // declare and initialize the per-thread 8x8 accumulator tile
#define res(i, j) _res_##i##j
#define initResultRow(i) \
  Scalar res(i, 0) = Scalar(0); \
  Scalar res(i, 1) = Scalar(0); \
  Scalar res(i, 2) = Scalar(0); \
  Scalar res(i, 3) = Scalar(0); \
  Scalar res(i, 4) = Scalar(0); \
  Scalar res(i, 5) = Scalar(0); \
  Scalar res(i, 6) = Scalar(0); \
  Scalar res(i, 7) = Scalar(0);

  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow

  for (Index base_k = 0; base_k < k_size; base_k += 64) {
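    // One pass of this loop consumes a 64-element slice of the contraction
    // dimension: prefetch, stage to shared memory, synchronize, then
    // accumulate into the per-thread res(i, j) registers.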
    // make sure the previous iteration's shared-memory reads are done before
    // this iteration overwrites the panels
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

#undef prefetchIntoRegisters
#undef writeRegToShmem

    // wait until every thread has staged its values
    __syncthreads();

#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);

    const volatile Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const volatile Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]

#define loadData(i, j) \
    lcol(0) = lhs_element(0, j); \
    rrow(0) = rhs_element(i, 0); \
    lcol(1) = lhs_element(1, j); \
    rrow(1) = rhs_element(i, 1); \
    lcol(2) = lhs_element(2, j); \
    rrow(2) = rhs_element(i, 2); \
    lcol(3) = lhs_element(3, j); \
    rrow(3) = rhs_element(i, 3); \
    lcol(4) = lhs_element(4, j); \
    rrow(4) = rhs_element(i, 4); \
    lcol(5) = lhs_element(5, j); \
    rrow(5) = rhs_element(i, 5); \
    lcol(6) = lhs_element(6, j); \
    rrow(6) = rhs_element(i, 6); \
    lcol(7) = lhs_element(7, j); \
    rrow(7) = rhs_element(i, 7);

#define computeCol(j) \
    res(0, j) += lcol(0) * rrow(j); \
    res(1, j) += lcol(1) * rrow(j); \
    res(2, j) += lcol(2) * rrow(j); \
    res(3, j) += lcol(3) * rrow(j); \
    res(4, j) += lcol(4) * rrow(j); \
    res(5, j) += lcol(5) * rrow(j); \
    res(6, j) += lcol(6) * rrow(j); \
    res(7, j) += lcol(7) * rrow(j);

#define computePass(i) /* ... body elided: chains loadData/computeCol for pass i ... */

    // ... (computePass invocations elided) ...
  }

#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)

#define reduceRow(i, mask) \
  shuffleInc(i, 0, mask); \
  shuffleInc(i, 1, mask); \
  shuffleInc(i, 2, mask); \
  shuffleInc(i, 3, mask); \
  shuffleInc(i, 4, mask); \
  shuffleInc(i, 5, mask); \
  shuffleInc(i, 6, mask); \
  shuffleInc(i, 7, mask);

#define reduceMatrix(mask) \
  reduceRow(0, mask); \
  reduceRow(1, mask); \
  reduceRow(2, mask); \
  reduceRow(3, mask); \
  reduceRow(4, mask); \
  reduceRow(5, mask); \
  reduceRow(6, mask); \
  reduceRow(7, mask);

  // ... (reduceMatrix invocations with the appropriate masks elided) ...

#define writeResultShmem(i, j) \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);

#define writeRow(i) \
  writeResultShmem(i, 0); \
  writeResultShmem(i, 1); \
  writeResultShmem(i, 2); \
  writeResultShmem(i, 3); \
  writeResultShmem(i, 4); \
  writeResultShmem(i, 5); \
  writeResultShmem(i, 6); \
  writeResultShmem(i, 7);
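  // After the k loop, each res(i, j) holds a partial dot product spread
  // across lanes; the shuffleInc/reduceRow/reduceMatrix macros combine them
  // with butterfly (XOR-mask) warp shuffles, after which every participating
  // lane holds the full sum. A single lane per tile row (threadIdx.x == 0)
  // then stages the result through lhs_shmem, reused as scratch, before the
  // final global-memory writes.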
460 #undef writeResultShmem 463 const int max_i_write = (min)((
int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
464 const int max_j_write = (min)((
int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
466 if (threadIdx.x < max_i_write) {
467 if (max_j_write == 8) {
469 Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
470 Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
471 Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
472 Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
473 Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
474 Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
475 Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
476 Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
478 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
479 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
480 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
481 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
482 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
483 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
484 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
485 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
488 for (
int j = 0; j < max_j_write; j++) {
489 Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
490 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(512)
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ volatile Scalar lhs_shmem[72 * 64];
  __shared__ volatile Scalar rhs_shmem[72 * 64];
  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
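  // Two compile-time specializations: interior blocks, whose full 64x64 tile
  // is in range, take the unchecked fast path; only edge blocks pay for
  // per-access bounds checks.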
  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                                         const OutputMapper output, float2 lhs_shmem2[][16],
                                         float2 rhs_shmem2[][8], const Index m_size,
                                         const Index n_size, const Index k_size,
                                         const Index base_m, const Index base_n) {
  typedef float Scalar;
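  // float specialization, 16x16 thread blocks: each block computes a 64x64
  // output tile, each thread a 4x4 sub-tile held in the four float4
  // accumulators below. Operands are streamed through the float2 shared
  // memory arrays 16 k-values at a time.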
  float4 lhs_pf0, rhs_pf0;

  float4 results[4];
  for (int i = 0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }
#define prefetch_lhs(reg, row, col) \
    if (!CHECK_LHS_BOUNDARY) { \
      if (col < k_size) { \
        reg = lhs.loadPacket(row, col); \
      } \
    } else { \
      if (col < k_size) { \
        if (row + 3 < m_size) { \
          reg = lhs.loadPacket(row, col); \
        } else if (row + 2 < m_size) { \
          reg.x = lhs(row + 0, col); \
          reg.y = lhs(row + 1, col); \
          reg.z = lhs(row + 2, col); \
        } else if (row + 1 < m_size) { \
          reg.x = lhs(row + 0, col); \
          reg.y = lhs(row + 1, col); \
        } else if (row < m_size) { \
          reg.x = lhs(row + 0, col); \
        } \
      } \
    }

  Index lhs_vert = base_m + threadIdx.x * 4;
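  // Main loop over the contraction dimension, 16 values at a time; each
  // thread prefetches one float4 from each operand per iteration.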
  for (Index k = 0; k < k_size; k += 16) {
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y + k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k + (threadIdx.x % 4) * 4;
    Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    // exchange two rhs components with the lane four positions away so the
    // shared-memory stores below see the expected interleaving (the bodies
    // of the two guarded blocks were elided in the extract; restored here)
    float x1, x2;
    if ((threadIdx.x % 8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
    if ((threadIdx.x % 8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }

    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);

    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
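    // add_vals accumulates one k-slice as a 4x4 outer product: fl1/fl2 hold
    // four consecutive output rows of the lhs, fr1/fr2 four output columns
    // of the rhs, and each of the 16 products lands in its results[]
    // component.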
#define add_vals(fl1, fl2, fr1, fr2) \
    results[0].x += fl1.x * fr1.x; \
    results[0].y += fl1.y * fr1.x; \
    results[0].z += fl2.x * fr1.x; \
    results[0].w += fl2.y * fr1.x; \
    results[1].x += fl1.x * fr1.y; \
    results[1].y += fl1.y * fr1.y; \
    results[1].z += fl2.x * fr1.y; \
    results[1].w += fl2.y * fr1.y; \
    results[2].x += fl1.x * fr2.x; \
    results[2].y += fl1.y * fr2.x; \
    results[2].z += fl2.x * fr2.x; \
    results[2].w += fl2.y * fr2.x; \
    results[3].x += fl1.x * fr2.y; \
    results[3].y += fl1.y * fr2.y; \
    results[3].z += fl2.x * fr2.y; \
    results[3].w += fl2.y * fr2.y;

    // make sure the staging stores above are visible to every thread
    __syncthreads();

    for (int koff = 0; koff < 16; koff++) {
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
      float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];

      add_vals(fl1, fl2, fr1, fr2)
    }

    // before the next iteration overwrites shared memory
    __syncthreads();
  }
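  // Write the accumulated 4x4 tile to global memory. The four branches are
  // selected at compile time and only check the boundaries that can actually
  // be violated for this block.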
  Index horiz_base = threadIdx.y * 4 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                                    const OutputMapper output, float2 lhs_shmem2[][32],
                                    float2 rhs_shmem2[][8], const Index m_size,
                                    const Index n_size, const Index k_size,
                                    const Index base_m, const Index base_n) {
  typedef float Scalar;
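  // float specialization, 8x32 thread blocks: each block computes a 128x64
  // output tile, each thread a 4x8 sub-tile held in eight float4
  // accumulators, with the contraction dimension consumed 32 values per
  // iteration.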
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;

  float4 results[8];
  for (int i = 0; i < 8; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }
  Index lhs_vert = base_m + threadIdx.x * 4 + (threadIdx.y % 4) * 32;

  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);

    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y/4+k+24) < k_size) {
        lhs_pf0 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+16));
        lhs_pf3 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+24));
      } else if ((threadIdx.y/4+k+16) < k_size) {
        lhs_pf0 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+16));
      } else if ((threadIdx.y/4+k+8) < k_size) {
        lhs_pf0 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8));
      } else if ((threadIdx.y/4+k) < k_size) {
        lhs_pf0 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k));
      }
    } else {
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+16));
          lhs_pf3 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0 = lhs.loadPacket(lhs_vert, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
        }
      }
    }
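    // Each thread loads four consecutive k-values for two adjacent output
    // columns of the rhs (rhs_horiz0/rhs_horiz1), again with an else-if
    // ladder for the ragged edges.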
    Index rhs_vert = k + threadIdx.x * 4;
    Index rhs_horiz0 = threadIdx.y * 2 + base_n;
    Index rhs_horiz1 = threadIdx.y * 2 + 1 + base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.loadPacket(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.loadPacket(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (k + threadIdx.x * 4 + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else if (k + threadIdx.x * 4 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
    rhs_shmem2[threadIdx.y + 32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
    rhs_shmem2[threadIdx.y + 64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
    rhs_shmem2[threadIdx.y + 96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4) \
      results[0].x += a_feat1.x * f1.x; \
      results[1].x += a_feat1.x * f1.y; \
      results[2].x += a_feat1.x * f2.x; \
      results[3].x += a_feat1.x * f2.y; \
      results[4].x += a_feat1.x * f3.x; \
      results[5].x += a_feat1.x * f3.y; \
      results[6].x += a_feat1.x * f4.x; \
      results[7].x += a_feat1.x * f4.y; \
      results[0].y += a_feat1.y * f1.x; \
      results[1].y += a_feat1.y * f1.y; \
      results[2].y += a_feat1.y * f2.x; \
      results[3].y += a_feat1.y * f2.y; \
      results[4].y += a_feat1.y * f3.x; \
      results[5].y += a_feat1.y * f3.y; \
      results[6].y += a_feat1.y * f4.x; \
      results[7].y += a_feat1.y * f4.y; \
      results[0].z += a_feat2.x * f1.x; \
      results[1].z += a_feat2.x * f1.y; \
      results[2].z += a_feat2.x * f2.x; \
      results[3].z += a_feat2.x * f2.y; \
      results[4].z += a_feat2.x * f3.x; \
      results[5].z += a_feat2.x * f3.y; \
      results[6].z += a_feat2.x * f4.x; \
      results[7].z += a_feat2.x * f4.y; \
      results[0].w += a_feat2.y * f1.x; \
      results[1].w += a_feat2.y * f1.y; \
      results[2].w += a_feat2.y * f2.x; \
      results[3].w += a_feat2.y * f2.y; \
      results[4].w += a_feat2.y * f3.x; \
      results[5].w += a_feat2.y * f3.y; \
      results[6].w += a_feat2.y * f4.x; \
      results[7].w += a_feat2.y * f4.y;

    lhs_shmem2[threadIdx.y/4][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y/4 + 8][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
    lhs_shmem2[threadIdx.y/4 + 16][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
    lhs_shmem2[threadIdx.y/4 + 24][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);

    lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
    lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
    lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
    lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);
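    // The lhs panel is staged as 64 rows of float2 pairs: the 128 output
    // rows are split into x/y and z/w halves (rows 0-31 and 32-63 above).
    // A barrier is needed before the compute pass reads them back.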
    __syncthreads();

    for (int koff = 0; koff < 32; koff++) {
      float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
      float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];

      int start_feature = (threadIdx.y / 4) * 8;

      float2 br1 = rhs_shmem2[start_feature/2 + (koff % 4) * 32][koff/4];
      float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
      float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
      float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];

      add_vals(a3, a4, br1, br2, br3, br4)
    }

    // before the next iteration overwrites shared memory
    __syncthreads();
  }
  Index horiz_base = (threadIdx.y / 4) * 8 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                            const OutputMapper output,
                            const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[64 * 32];
  __shared__ float2 rhs_shmem[128 * 8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  typedef float2 LHS_MEM16x16[32][16];
  typedef float2 RHS_MEM16x16[64][8];
  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  bool check_rhs = (base_n + 63) >= n_size;
  bool check_lhs128 = (base_m + 127) >= m_size;
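  // check_rhs/check_lhs128 flag blocks whose 128x64 tile hangs over the
  // right or bottom edge of the output; they select the bounds-checked
  // template instantiations below.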
  if (!check_rhs) {
    if (!check_lhs128) {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                                 const OutputMapper output,
                                 const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}
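// GPU evaluator for tensor contraction expressions: any contraction assigned
// on a GpuDevice is routed through this specialization. Usage sketch (names
// and setup are assumptions, not part of this header): with a, b, c being
// Eigen::TensorMaps over device-side float buffers and gpu_dev a configured
// Eigen::GpuDevice,
//
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   c.device(gpu_dev) = a.contract(b, dims);
//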
template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Packet Packet;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;
  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
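  // For RowMajor layouts the typedefs above swap the roles of the left and
  // right operands, so the kernels can always assume ColMajor inputs.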
  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
  typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;

  static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;

  typedef DSizes<Index, NumDims> Dimensions;

  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;
  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}
  // Evaluates into the caller's buffer when one is provided (returning
  // false), otherwise allocates a device-side result buffer (returning true).
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }
  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        } else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        } else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        } else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        } else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
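  // evalTyped is instantiated for all 2^3 combinations of the contiguity and
  // reordering flags above, so the input mappers get compile-time knowledge
  // of the operands' memory layout.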
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
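    // The input mappers below present the (possibly multi-dimensional)
    // tensor operands to the kernels as plain 2D matrices indexed by
    // (row, column); the OutputMapper is a column-major m x n view of the
    // result buffer.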
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);
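    // Kernel selection (from the branches below): for float operands, small
    // problems (m or n < 768) use the 16x16-thread 64x64-tile kernel, larger
    // ones the 8x32-thread 128x64-tile kernel; all other scalar types fall
    // back to the generic 8x8x8 kernel. Shared memory is first put in
    // eight-byte-bank mode, matching the float2-wide accesses.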
    setCudaSharedMemConfig(cudaSharedMemBankSizeEightByte);
    if (internal::is_same<LhsScalar, float>::value &&
        internal::is_same<RhsScalar, float>::value) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, this->m_device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, this->m_device, lhs, rhs, output, m, n, k);
      }
    } else {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_CUDA_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, this->m_device, lhs, rhs, output, m, n, k);
    }
  }
};

} // end namespace Eigen

#endif // EIGEN_USE_GPU and __CUDACC__
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H