TensorAssign.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H
12 
13 namespace Eigen {
14 
23 namespace internal {
24 template<typename LhsXprType, typename RhsXprType>
25 struct traits<TensorAssignOp<LhsXprType, RhsXprType> >
26 {
27  typedef typename LhsXprType::Scalar Scalar;
28  typedef typename internal::packet_traits<Scalar>::type Packet;
29  typedef typename traits<LhsXprType>::StorageKind StorageKind;
30  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
31  typename traits<RhsXprType>::Index>::type Index;
32  typedef typename LhsXprType::Nested LhsNested;
33  typedef typename RhsXprType::Nested RhsNested;
34  typedef typename remove_reference<LhsNested>::type _LhsNested;
35  typedef typename remove_reference<RhsNested>::type _RhsNested;
36  static const std::size_t NumDimensions = internal::traits<LhsXprType>::NumDimensions;
37  static const int Layout = internal::traits<LhsXprType>::Layout;
38 
39  enum {
40  Flags = 0,
41  };
42 };
43 
44 template<typename LhsXprType, typename RhsXprType>
45 struct eval<TensorAssignOp<LhsXprType, RhsXprType>, Eigen::Dense>
46 {
47  typedef const TensorAssignOp<LhsXprType, RhsXprType>& type;
48 };
49 
50 template<typename LhsXprType, typename RhsXprType>
51 struct nested<TensorAssignOp<LhsXprType, RhsXprType>, 1, typename eval<TensorAssignOp<LhsXprType, RhsXprType> >::type>
52 {
53  typedef TensorAssignOp<LhsXprType, RhsXprType> type;
54 };
55 
56 } // end namespace internal
57 
58 
59 
60 template<typename LhsXprType, typename RhsXprType>
61 class TensorAssignOp : public TensorBase<TensorAssignOp<LhsXprType, RhsXprType> >
62 {
63  public:
64  typedef typename Eigen::internal::traits<TensorAssignOp>::Scalar Scalar;
65  typedef typename Eigen::internal::traits<TensorAssignOp>::Packet Packet;
66  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
67  typedef typename LhsXprType::CoeffReturnType CoeffReturnType;
68  typedef typename LhsXprType::PacketReturnType PacketReturnType;
69  typedef typename Eigen::internal::nested<TensorAssignOp>::type Nested;
70  typedef typename Eigen::internal::traits<TensorAssignOp>::StorageKind StorageKind;
71  typedef typename Eigen::internal::traits<TensorAssignOp>::Index Index;
72 
73  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorAssignOp(LhsXprType& lhs, const RhsXprType& rhs)
74  : m_lhs_xpr(lhs), m_rhs_xpr(rhs) {}
75 
77  EIGEN_DEVICE_FUNC
78  typename internal::remove_all<typename LhsXprType::Nested>::type&
79  lhsExpression() const { return *((typename internal::remove_all<typename LhsXprType::Nested>::type*)&m_lhs_xpr); }
80 
81  EIGEN_DEVICE_FUNC
82  const typename internal::remove_all<typename RhsXprType::Nested>::type&
83  rhsExpression() const { return m_rhs_xpr; }
84 
85  protected:
86  typename internal::remove_all<typename LhsXprType::Nested>::type& m_lhs_xpr;
87  const typename internal::remove_all<typename RhsXprType::Nested>::type& m_rhs_xpr;
88 };
89 
90 
91 template<typename LeftArgType, typename RightArgType, typename Device>
92 struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device>
93 {
94  typedef TensorAssignOp<LeftArgType, RightArgType> XprType;
95 
96  enum {
97  IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
98  PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
99  Layout = TensorEvaluator<LeftArgType, Device>::Layout,
100  };
101 
102  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
103  m_leftImpl(op.lhsExpression(), device),
104  m_rightImpl(op.rhsExpression(), device)
105  {
106  EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
107  }
108 
109  typedef typename XprType::Index Index;
110  typedef typename XprType::Scalar Scalar;
111  typedef typename XprType::CoeffReturnType CoeffReturnType;
112  typedef typename XprType::PacketReturnType PacketReturnType;
113  typedef typename TensorEvaluator<RightArgType, Device>::Dimensions Dimensions;
114 
115  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
116  {
117  // The dimensions of the lhs and the rhs tensors should be equal to prevent
118  // overflows and ensure the result is fully initialized.
119  // TODO: use left impl instead if right impl dimensions are known at compile time.
120  return m_rightImpl.dimensions();
121  }
122 
123  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
124  eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
125  m_leftImpl.evalSubExprsIfNeeded(NULL);
126  // If the lhs provides raw access to its storage area (i.e. if m_leftImpl.data() returns a non
127  // null value), attempt to evaluate the rhs expression in place. Returns true iff in place
128  // evaluation isn't supported and the caller still needs to manually assign the values generated
129  // by the rhs to the lhs.
130  return m_rightImpl.evalSubExprsIfNeeded(m_leftImpl.data());
131  }
132  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
133  m_leftImpl.cleanup();
134  m_rightImpl.cleanup();
135  }
136 
137  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalScalar(Index i) {
138  m_leftImpl.coeffRef(i) = m_rightImpl.coeff(i);
139  }
140  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalPacket(Index i) {
141  const int LhsStoreMode = TensorEvaluator<LeftArgType, Device>::IsAligned ? Aligned : Unaligned;
142  const int RhsLoadMode = TensorEvaluator<RightArgType, Device>::IsAligned ? Aligned : Unaligned;
143  m_leftImpl.template writePacket<LhsStoreMode>(i, m_rightImpl.template packet<RhsLoadMode>(i));
144  }
145  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
146  {
147  return m_leftImpl.coeff(index);
148  }
149  template<int LoadMode>
150  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
151  {
152  return m_leftImpl.template packet<LoadMode>(index);
153  }
154 
155  private:
156  TensorEvaluator<LeftArgType, Device> m_leftImpl;
157  TensorEvaluator<RightArgType, Device> m_rightImpl;
158 };
159 
160 }
161 
162 
163 #endif // EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H
Namespace containing all symbols from the Eigen library.
Definition: CXX11Meta.h:13