#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

namespace Eigen {
namespace internal {
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous>
class SimpleTensorContractionMapper {
 public:
  SimpleTensorContractionMapper(const Tensor& tensor,
                                const nocontract_t& nocontract_strides,
                                const nocontract_t& ij_strides,
                                const contract_t& contract_strides,
                                const contract_t& k_strides) :
      m_tensor(tensor),
      m_nocontract_strides(nocontract_strides),
      m_ij_strides(ij_strides),
      m_contract_strides(contract_strides),
      m_k_strides(k_strides) { }
  EIGEN_STRONG_INLINE void prefetch(Index) { }
  EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    return operator()(row, 0);
  }
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }
  EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }
    Index contract_val = left ? col : row;
    for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
      const Index idx = contract_val / m_k_strides[i];
      linidx += idx * m_contract_strides[i];
      contract_val -= idx * m_k_strides[i];
    }
    if (array_size<contract_t>::value > 0) {
      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }
    return linidx;
  }
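  // Worked example (illustrative, not part of the original source): for a
  // plain column-major 4x3 matrix used as the LHS and contracted over its
  // second dimension, m_nocontract_strides[0] == 1 and
  // m_contract_strides[0] == 4, so computeIndex(row, col) resolves to
  // row + 4 * col, the usual column-major linear index.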
  EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const {
    const bool left = (side == Lhs);
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx0 = nocontract_val[0] / m_ij_strides[i];
      const Index idx1 = nocontract_val[1] / m_ij_strides[i];
      linidx[0] += idx0 * m_nocontract_strides[i];
      linidx[1] += idx1 * m_nocontract_strides[i];
      nocontract_val[0] -= idx0 * m_ij_strides[i];
      nocontract_val[1] -= idx1 * m_ij_strides[i];
    }
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }
    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
      const Index idx0 = contract_val[0] / m_k_strides[i];
      const Index idx1 = contract_val[1] / m_k_strides[i];
      linidx[0] += idx0 * m_contract_strides[i];
      linidx[1] += idx1 * m_contract_strides[i];
      contract_val[0] -= idx0 * m_k_strides[i];
      contract_val[1] -= idx1 * m_k_strides[i];
    }
    if (side == Rhs && inner_dim_contiguous) {
      eigen_assert(m_contract_strides[0] == 1);
      linidx[0] += contract_val[0];
      linidx[1] += contract_val[1];
    } else {
      linidx[0] += contract_val[0] * m_contract_strides[0];
      linidx[1] += contract_val[1] * m_contract_strides[0];
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }
  // No element is assumed pre-aligned, and the mapper is unit-stride.
  EIGEN_STRONG_INLINE Index firstAligned(Index size) const {
    return size;
  }
  EIGEN_STRONG_INLINE Index stride() const {
    return 1;
  }
 protected:
  const Tensor m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};
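// Usage sketch (illustrative only; the evaluator type and stride arrays named
// here are hypothetical): the contraction kernels view an arbitrary-rank
// tensor as a 2-D matrix whose rows span the non-contracted dimensions and
// whose columns span the contracted ones.
//
//   SimpleTensorContractionMapper<float, Index, Lhs, LhsEvaluator,
//                                 nocontract_t, contract_t,
//                                 /*packet_size=*/4,
//                                 /*inner_dim_contiguous=*/true>
//       mapper(lhs_eval, nocontract_strides, ij_strides,
//              contract_strides, k_strides);
//   float v = mapper(row, col);  // element (row, col) of the virtual matrix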
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment>
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> ParentMapper;
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;
  EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    // the whole method makes a column-major assumption, and the current code
    // requires the packet size to be a multiple of 2
    EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);

    if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
      const Index index = this->computeIndex(i, j);
      eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
      return this->m_tensor.template packet<Alignment>(index);
    }

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    const Index first = indexPair.first;
    const Index last = indexPair.second;

    // Packet reads from the left-hand side are always possible, because its
    // vertical dimension is never contracting. On the right-hand side the
    // contracting dimensions may have been shuffled first.
    if (Tensor::PacketAccess &&
        (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
        (last - first) == (packet_size - 1)) {
      return this->m_tensor.template packet<Alignment>(first);
    }

    // Fall back to a scalar gather into an aligned buffer.
    EIGEN_ALIGN_MAX Scalar data[packet_size];

    data[0] = this->m_tensor.coeff(first);
    for (Index k = 1; k < packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[packet_size - 1] = this->m_tensor.coeff(last);

    return pload<Packet>(data);
  }
  EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
    const Index half_packet_size = unpacket_traits<HalfPacket>::size;
    if (half_packet_size == packet_size) {
      return loadPacket(i, j);
    }
    EIGEN_ALIGN_MAX Scalar data[half_packet_size];
    for (Index k = 0; k < half_packet_size; k++) {
      data[k] = operator()(i + k, j);
    }
    return pload<HalfPacket>(data);
  }
};
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment>
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>
  : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> ParentMapper;
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  typedef typename packet_traits<Scalar>::type Packet;
  EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    EIGEN_ALIGN_MAX Scalar data[1];
    data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
    return pload<typename packet_traits<Scalar>::type>(data);
  }
  EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
    return loadPacket(i, j);
  }
};
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t, int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper;
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t, int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper {
 public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
  typedef Self LinearMapper;
  EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    return m_base_mapper(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
    return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
    return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
    return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
    m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
  }
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
    EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
    return loadPacket(i);
  }
  template <typename Packet>
  EIGEN_DEVICE_FUNC bool aligned(Index) const {
    return false;
  }

 private:
  const ParentMapper& m_base_mapper;
  const Index m_vert_offset;
  const Index m_horiz_offset;
};
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t, int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper
  : public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {

 public:
  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
  typedef SubMapper VectorMapper;
  EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
                                                 const nocontract_t& nocontract_strides,
                                                 const nocontract_t& ij_strides,
                                                 const contract_t& contract_strides,
                                                 const contract_t& k_strides)
      : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    return VectorMapper(*this, i, j);
  }
};
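// Illustrative note (not in the original source): a sub-mapper is the parent
// mapper shifted by a fixed (row, column) offset, so after
//   sub = mapper.getSubMapper(i2, k2);
// sub(i, j) reads the same element as mapper(i2 + i, k2 + j). The GEMM
// packing kernels rely on this to walk rectangular blocks of the virtual
// matrix without re-deriving tensor indices.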
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
                                                  typename RhsXprType::Scalar>::ret Scalar;
  typedef typename internal::packet_traits<Scalar>::type Packet;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;

  // Contracting an LDims-tensor with an RDims-tensor over ContractDims index
  // pairs yields LDims + RDims - 2 * ContractDims dimensions, clamped to at
  // least 1 by max_n_1.
  static const int NumDimensions = max_n_1<traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value>::size;
  static const int Layout = traits<LhsXprType>::Layout;
};
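// Worked example (illustrative): contracting a rank-3 tensor with a rank-2
// tensor over one index pair gives NumDimensions = 3 + 2 - 2*1 = 3, while
// contracting two matrices over both of their dimensions gives
// 2 + 2 - 2*2 = 0, which max_n_1 clamps to 1 so the scalar result is still
// representable as a tensor of size 1.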
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
};
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
{
  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
};
template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef Device_ Device;

  // Result rank, computed as for traits<TensorContractionOp> above.
  static const int NumDimensions = max_n_1<traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value>::size;
};

}  // end namespace internal
template<typename Indices, typename LhsXprType, typename RhsXprType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
 public:
  typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
  typedef typename Eigen::internal::traits<TensorContractionOp>::Packet Packet;
  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                  typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
  typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType,
                                                  typename RhsXprType::PacketReturnType>::ret PacketReturnType;
  typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
      const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}
  const Indices& indices() const { return m_indices; }

  const typename internal::remove_all<typename LhsXprType::Nested>::type&
  lhsExpression() const { return m_lhs_xpr; }

  const typename internal::remove_all<typename RhsXprType::Nested>::type&
  rhsExpression() const { return m_rhs_xpr; }

 protected:
  typename LhsXprType::Nested m_lhs_xpr;
  typename RhsXprType::Nested m_rhs_xpr;
  const Indices m_indices;
};
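// Usage sketch (illustrative; the tensors named here are hypothetical). A
// contraction expression is normally built via TensorBase::contract, which
// returns a TensorContractionOp:
//
//   Eigen::Tensor<float, 3> a(4, 3, 5);
//   Eigen::Tensor<float, 2> b(3, 6);
//   a.setRandom();
//   b.setRandom();
//   // Contract dimension 1 of a (size 3) against dimension 0 of b (size 3).
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
//   Eigen::Tensor<float, 3> c = a.contract(b, dims);  // c is 4 x 5 x 6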
template<typename Derived>
struct TensorContractionEvaluatorBase
{
  typedef typename internal::traits<Derived>::Indices Indices;
  typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
  typedef typename internal::traits<Derived>::RightArgType RightArgType;
  typedef typename internal::traits<Derived>::Device Device;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Packet Packet;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;
  enum {
    PacketAccess = (internal::packet_traits<Scalar>::size > 1),
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };
  // Most of the code below assumes ColMajor inputs. For RowMajor layouts the
  // evaluator swaps the roles of the LHS and the RHS (and flips the
  // contraction indices accordingly in the constructor).
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;
  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
  typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;

  typedef DSizes<Index, NumDims> Dimensions;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device),
        m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);
    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // no change to the dimensions or the contraction pairs for ColMajor
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // RowMajor: reverse the dimensions and flip the contraction pairs,
      // consistent with the LHS/RHS swap performed above.
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[i].first;
      }
    }
    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims-1; ++i) {
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims-1; ++i) {
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    }
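    // Worked example (illustrative): for a column-major 4 x 3 x 5 tensor the
    // loop above yields strides {1, 4, 12}, so element (i, j, k) lives at
    // linear index i + 4*j + 12*k.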
    // Compute the mapping from tensor dimensions to the rows ("i") of the
    // virtual matrix, writing m_dimensions, m_left_nocontract_strides,
    // m_i_strides and m_i_size.
    m_i_strides[0] = 1;
    m_j_strides[0] = 1;
    m_k_strides[0] = 1;

    m_lhs_inner_dim_contiguous = true;
    int dim_idx = 0;
    unsigned int nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // find if we are contracting on index i of the left tensor
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        // add the dimension size to the output dimensions
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx+1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }
    // Do the same for the columns ("j") of the virtual matrix, writing
    // m_dimensions, m_right_nocontract_strides, m_j_strides and m_j_size.
    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      // find if we are contracting on index i of the right tensor
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx+1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }
    // Initialize the strides of the contracted ("k") dimensions, checking
    // that the contracted axes agree in size and recording whether the rhs
    // inner dimension stays contiguous and in order.
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        m_k_strides[i+1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i-1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }
    // Scalar case: all dimensions are contracted away, so represent the
    // result as a 1-d tensor of size 1.
    if (LDims + RDims == 2 * ContractDims) {
      m_dimensions[0] = 1;
    }

    // If the layout is RowMajor, reverse the output dimensions to undo the
    // LHS/RHS swap performed above.
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        numext::swap(m_dimensions[i], m_dimensions[j]);
      }
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }
  // Dispatch on the three runtime flags so that evalProduct is instantiated
  // with compile-time constants.
  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;
    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::packet_traits<LhsScalar>::size;
    const Index rhs_packet_size = internal::packet_traits<RhsScalar>::size;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;
    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // zero out the result buffer (which must be of size at least rows * sizeof(Scalar))
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);
  }
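  // Note (illustrative): evalGemv is selected by the derived evaluator's
  // evalProduct when the output has a single column (m_j_size == 1); the
  // contraction then reduces to an (m_i_size x m_k_size) matrix times an
  // m_k_size vector.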
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();

    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
    return internal::ploadt<Packet, LoadMode>(m_result + index);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
 protected:
  // Prevent assignment
  TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&);
  Dimensions m_dimensions;

  contract_t m_k_strides;
  contract_t m_left_contracting_strides;
  contract_t m_right_contracting_strides;

  bool m_lhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_reordered;

  left_nocontract_t m_i_strides;
  right_nocontract_t m_j_strides;
  left_nocontract_t m_left_nocontract_strides;
  right_nocontract_t m_right_nocontract_strides;

  Index m_i_size;
  Index m_j_size;
  Index m_k_size;

  TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
  TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
  const Device& m_device;
  Scalar* m_result;
};
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
    public TensorContractionEvaluatorBase<
      TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Packet Packet;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;
  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
  typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;

  static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;

  typedef DSizes<Index, NumDims> Dimensions;
  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) { }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
  }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
    // columns in the left side, rows in the right side
    const Index k = this->m_k_size;

    // rows in the left side
    const Index m = this->m_i_size;

    // columns in the right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

    const Index nr = Traits::nr;
    const Index mr = Traits::mr;

    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

    const Index lhs_packet_size = internal::packet_traits<LhsScalar>::size;
    const Index rhs_packet_size = internal::packet_traits<RhsScalar>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // declare the GEBP packing and kernel structs
    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;

    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;
    // initialize the data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    typedef typename internal::gemm_blocking_space<ColMajor, LhsScalar, RhsScalar, Dynamic, Dynamic, Dynamic> BlockingType;

    // compute block sizes (which depend on the number of threads and the cache sizes)
    BlockingType blocking(m, n, k, 1, true);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;

    LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
    RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));
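    // Illustrative note: with, say, m = n = k = 1000 and blocking sizes
    // mc = 256 and kc = 512, the loops below sweep the output in 256-row
    // tiles; each packed LHS panel in blockA is reused for every RHS block
    // packed into blockB, which is what makes the blocked traversal
    // cache-friendly.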
    for(Index i2=0; i2<m; i2+=mc)
    {
      const Index actual_mc = numext::mini(i2+mc,m)-i2;
      for (Index k2 = 0; k2 < k; k2 += kc) {
        // make sure we don't overshoot the right edge of the left matrix, then pack the vertical panel
        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);

        // series of horizontal blocks
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // make sure we don't overshoot the right edge of the right matrix, then pack the block
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);

          // call gebp (the matrix kernel)
          gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, 1.0, -1, -1, 0, 0);
        }
      }
    }

    this->m_device.deallocate(blockA);
    this->m_device.deallocate(blockB);
  }
};

} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H