// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

namespace Eigen {

namespace internal {

enum {
  Rhs = 0,
  Lhs = 1,
};

/*
 * Implementation of the Eigen blas_data_mapper class for tensors.
 */
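// Example (illustrative only; the concrete numbers below are not used by the
// code): these mappers expose a tensor to the matrix kernels as a 2D array.
// Contracting a column-major rank-3 tensor A of dimensions (4, 3, 5) over its
// middle dimension presents A to the Lhs mapper as a 20x3 matrix: the
// non-contracting dimensions (4, 5) are flattened into 20 rows and the
// contracting dimension of size 3 becomes the columns. computeIndex(row, col)
// undoes that flattening, e.g. for row = 7, col = 2:
//   i0 = 7 % 4 = 3,  i2 = 7 / 4 = 1,  i1 = 2
//   linear index = 3*1 + 2*4 + 1*12 = 23  (column-major strides of A: 1, 4, 12)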
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous>
class SimpleTensorContractionMapper {
 public:
  EIGEN_DEVICE_FUNC
  SimpleTensorContractionMapper(const Tensor& tensor,
                                const nocontract_t& nocontract_strides,
                                const nocontract_t& ij_strides,
                                const contract_t& contract_strides,
                                const contract_t& k_strides) :
      m_tensor(tensor),
      m_nocontract_strides(nocontract_strides),
      m_ij_strides(ij_strides),
      m_contract_strides(contract_strides),
      m_k_strides(k_strides) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    Index contract_val = left ? col : row;
    for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
      const Index idx = contract_val / m_k_strides[i];
      linidx += idx * m_contract_strides[i];
      contract_val -= idx * m_k_strides[i];
    }

    if (array_size<contract_t>::value > 0) {
      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const {
    const bool left = (side == Lhs);
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx0 = nocontract_val[0] / m_ij_strides[i];
      const Index idx1 = nocontract_val[1] / m_ij_strides[i];
      linidx[0] += idx0 * m_nocontract_strides[i];
      linidx[1] += idx1 * m_nocontract_strides[i];
      nocontract_val[0] -= idx0 * m_ij_strides[i];
      nocontract_val[1] -= idx1 * m_ij_strides[i];
    }
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
      const Index idx0 = contract_val[0] / m_k_strides[i];
      const Index idx1 = contract_val[1] / m_k_strides[i];
      linidx[0] += idx0 * m_contract_strides[i];
      linidx[1] += idx1 * m_contract_strides[i];
      contract_val[0] -= idx0 * m_k_strides[i];
      contract_val[1] -= idx1 * m_k_strides[i];
    }

    if (side == Rhs && inner_dim_contiguous) {
      eigen_assert(m_contract_strides[0] == 1);
      linidx[0] += contract_val[0];
      linidx[1] += contract_val[1];
    } else {
      linidx[0] += contract_val[0] * m_contract_strides[0];
      linidx[1] += contract_val[1] * m_contract_strides[0];
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }

  Index firstAligned(Index size) const {
    return size;
  }
  Index stride() const {
    return 1;
  }

 protected:
  const Tensor m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};


template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment>
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> ParentMapper;

  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    // current code assumes packet size must be a multiple of 2
    EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);

    if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
      const Index index = this->computeIndex(i, j);
      eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
      return this->m_tensor.template packet<Alignment>(index);
    }

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    const Index first = indexPair.first;
    const Index last = indexPair.second;

    // We can always do optimized packet reads from left hand side right now, because
    // the vertical matrix dimension on the left hand side is never contracting.
    // On the right hand side we need to check if the contracting dimensions may have
    // been shuffled first.
    if (Tensor::PacketAccess &&
        (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
        (last - first) == (packet_size - 1)) {

      return this->m_tensor.template packet<Alignment>(first);
    }

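    // Fallback: the packet does not map to contiguous memory in the tensor,
    // so gather the packet_size coefficients individually (two per
    // computeIndexPair call) into an aligned scratch buffer and load from it.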
    EIGEN_ALIGN_MAX Scalar data[packet_size];

    data[0] = this->m_tensor.coeff(first);
    for (Index k = 1; k < packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[packet_size - 1] = this->m_tensor.coeff(last);

    return pload<Packet>(data);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    const Index half_packet_size = unpacket_traits<HalfPacket>::size;
    if (half_packet_size == packet_size) {
      return loadPacket(i, j);
    }
    EIGEN_ALIGN_MAX Scalar data[half_packet_size];
    for (Index k = 0; k < half_packet_size; k++) {
      data[k] = operator()(i + k, j);
    }
    return pload<HalfPacket>(data);
  }
};


template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment>
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> ParentMapper;

  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  typedef typename packet_traits<Scalar>::type Packet;
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    EIGEN_ALIGN_MAX Scalar data[1];
    data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
    return pload<typename packet_traits<Scalar>::type>(data);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
    return loadPacket(i, j);
  }
};

template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper;

template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper {
 public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
  typedef Self LinearMapper;

  EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    return m_base_mapper(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
    return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
    return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
    return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
    m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
  }

  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
    EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
    return loadPacket(i);
  }

  template <typename Packet>
  EIGEN_DEVICE_FUNC bool aligned(Index) const {
    return false;
  }

 private:
  const ParentMapper& m_base_mapper;
  const Index m_vert_offset;
  const Index m_horiz_offset;
};


template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper
  : public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {

 public:
  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
  typedef SubMapper VectorMapper;

  EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
                                                 const nocontract_t& nocontract_strides,
                                                 const nocontract_t& ij_strides,
                                                 const contract_t& contract_strides,
                                                 const contract_t& k_strides)
      : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    return VectorMapper(*this, i, j);
  }
};


template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
                                                  typename RhsXprType::Scalar>::ret Scalar;
  typedef typename internal::packet_traits<Scalar>::type Packet;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;

  // From NumDims below.
  static const int NumDimensions = max_n_1<traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value>::size;
  static const int Layout = traits<LhsXprType>::Layout;

  enum {
    Flags = 0,
  };
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
{
  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
};

template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef Device_ Device;

  // From NumDims below.
  static const int NumDimensions = max_n_1<traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value>::size;
};

}  // end namespace internal

template<typename Indices, typename LhsXprType, typename RhsXprType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
 public:
  typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
  typedef typename Eigen::internal::traits<TensorContractionOp>::Packet Packet;
  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                  typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
  typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType,
                                                  typename RhsXprType::PacketReturnType>::ret PacketReturnType;
  typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
      const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}

  EIGEN_DEVICE_FUNC
  const Indices& indices() const { return m_indices; }

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename LhsXprType::Nested>::type&
  lhsExpression() const { return m_lhs_xpr; }

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename RhsXprType::Nested>::type&
  rhsExpression() const { return m_rhs_xpr; }

 protected:
  typename LhsXprType::Nested m_lhs_xpr;
  typename RhsXprType::Nested m_rhs_xpr;
  const Indices m_indices;
};
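// Illustrative usage (not part of this file): a TensorContractionOp is
// normally created through the TensorBase::contract() helper rather than
// constructed directly. For instance, assuming the usual Tensor module
// headers:
//
//   Eigen::Tensor<float, 3> A(4, 3, 5);
//   Eigen::Tensor<float, 2> B(3, 2);
//   A.setRandom();
//   B.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {{Eigen::IndexPair<int>(1, 0)}};
//   // Contract dimension 1 of A against dimension 0 of B; the result is 4x5x2.
//   Eigen::Tensor<float, 3> C = A.contract(B, dims);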


template<typename Derived>
struct TensorContractionEvaluatorBase
{
  typedef typename internal::traits<Derived>::Indices Indices;
  typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
  typedef typename internal::traits<Derived>::RightArgType RightArgType;
  typedef typename internal::traits<Derived>::Device Device;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Packet Packet;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;

  enum {
    IsAligned = true,
    PacketAccess = (internal::packet_traits<Scalar>::size > 1),
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
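  // The swap is justified by the identity C = A * B  <=>  C^T = B^T * A^T:
  // RowMajor data reinterpreted as ColMajor is exactly the transpose, so
  // evaluating B^T * A^T with the ColMajor code paths produces C laid out in
  // RowMajor order, which is what the caller expects.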

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;
  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
  typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;

  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device),
        m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);


    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // For ColMajor, we keep using the existing dimensions
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      // We keep the pairs of contracting indices.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // For RowMajor, we need to reverse the existing dimensions
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      // We need to flip all the pairs of contracting indices as well as
      // reversing the dimensions.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[i].first;
      }
    }

    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims-1; ++i) {
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims-1; ++i) {
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    }

    m_i_strides[0] = 1;
    m_j_strides[0] = 1;
    if (ContractDims) {
      m_k_strides[0] = 1;
    }

    m_i_size = 1;
    m_j_size = 1;
    m_k_size = 1;

    // To compute the dimension, we simply concatenate the non-contracting
    // dimensions of the left and then the right tensor. Additionally, we also
    // compute the strides corresponding to the left non-contracting
    // dimensions and right non-contracting dimensions.
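    // For example (illustrative numbers only), contracting A(4, 3, 5) with
    // B(3, 2) over the index pair (1, 0) yields output dimensions (4, 5, 2),
    // with m_i_size = 4 * 5 = 20, m_j_size = 2 and m_k_size = 3, i.e. the
    // problem is handed to the kernels as a 20x3 times 3x2 matrix product.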
    m_lhs_inner_dim_contiguous = true;
    int dim_idx = 0;
    unsigned int nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // find if we are contracting on index i of left tensor
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        // add dimension size to output dimensions
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx+1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }

    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      bool contracting = false;
      // find if we are contracting on index i of right tensor
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx+1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }

    // Now compute the strides corresponding to the contracting dimensions. We
    // assumed above that non-contracting axes are represented in the same order
    // in the matrix as they are in the tensor. This is not the case for
    // contracting axes. As the contracting axes must be of the same size in
    // each tensor, we'll only look at the first tensor here.
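    // Continuing the A(4, 3, 5) * B(3, 2) example above, the single
    // contraction pair (1, 0) gives m_left_contracting_strides[0] = 4 (the
    // column-major stride of dimension 1 in A) and
    // m_right_contracting_strides[0] = 1 (dimension 0 of B), so the rhs inner
    // dimension is contiguous and not reordered.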
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        m_k_strides[i+1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i-1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }

    // Scalar case. We represent the result as a 1d tensor of size 1.
    if (LDims + RDims == 2 * ContractDims) {
      m_dimensions[0] = 1;
    }

    // If the layout is RowMajor, we need to reverse the m_dimensions
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        numext::swap(m_dimensions[i], m_dimensions[j]);
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::packet_traits<LhsScalar>::size;
    const Index rhs_packet_size = internal::packet_traits<RhsScalar>::size;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // zero out the result buffer (which must be of size at least rows * sizeof(Scalar))
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();

    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
    return internal::ploadt<Packet, LoadMode>(m_result + index);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 protected:
  // Prevent assignment
  TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&);
  Dimensions m_dimensions;

  contract_t m_k_strides;
  contract_t m_left_contracting_strides;
  contract_t m_right_contracting_strides;

  bool m_lhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_reordered;

  left_nocontract_t m_i_strides;
  right_nocontract_t m_j_strides;
  left_nocontract_t m_left_nocontract_strides;
  right_nocontract_t m_right_nocontract_strides;

  Index m_i_size;
  Index m_j_size;
  Index m_k_size;

  TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
  TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
  const Device& m_device;
  Scalar* m_result;
};


// evaluator for default device
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
    public TensorContractionEvaluatorBase<
      TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Packet Packet;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
  typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;

  static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;

  // Could we use NumDimensions here?
  typedef DSizes<Index, NumDims> Dimensions;


  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) { }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    // define mr, nr, and all of my data mapper types
    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

    const Index nr = Traits::nr;
    const Index mr = Traits::mr;

    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

    const Index lhs_packet_size = internal::packet_traits<LhsScalar>::size;
    const Index rhs_packet_size = internal::packet_traits<RhsScalar>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // Declare GEBP packing and kernel structs
    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;

    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    typedef typename internal::gemm_blocking_space<ColMajor, LhsScalar, RhsScalar, Dynamic, Dynamic, Dynamic> BlockingType;

    // Sizes of the blocks to load in cache. See the Goto paper for details.
    BlockingType blocking(m, n, k, 1, true);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());
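    // The loops below follow the classic GEBP structure: for each mc x kc
    // panel of the (logical) lhs matrix and each kc x nc panel of the rhs,
    // pack the panel into the contiguous blockA / blockB buffers and let the
    // gebp kernel accumulate the partial products into the output mapper.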
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;

    LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
    RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));

    for (Index i2 = 0; i2 < m; i2 += mc) {
      const Index actual_mc = numext::mini(i2+mc,m)-i2;
      for (Index k2 = 0; k2 < k; k2 += kc) {
        // make sure we don't overshoot right edge of left matrix, then pack vertical panel
        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);

        // series of horizontal blocks
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // make sure we don't overshoot right edge of right matrix, then pack block
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);

          // call gebp (matrix kernel)
          // The parameters here are copied from Eigen's GEMM implementation
          gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, 1.0, -1, -1, 0, 0);
        }
      }
    }

    this->m_device.deallocate(blockA);
    this->m_device.deallocate(blockB);
  }
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H