// TensorArgMax.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
//                    Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H

namespace Eigen {
namespace internal {

24 template<typename XprType>
25 struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
26 {
27  typedef traits<XprType> XprTraits;
28  typedef typename XprTraits::StorageKind StorageKind;
29  typedef typename XprTraits::Index Index;
30  typedef Tuple<Index, typename XprTraits::Scalar> Scalar;
31  typedef typename XprType::Nested Nested;
32  typedef typename remove_reference<Nested>::type _Nested;
33  static const int NumDimensions = XprTraits::NumDimensions;
34  static const int Layout = XprTraits::Layout;
35 };
36 
37 template<typename XprType>
38 struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
39 {
40  typedef const TensorIndexTupleOp<XprType>& type;
41 };
42 
43 template<typename XprType>
44 struct nested<TensorIndexTupleOp<XprType>, 1,
45  typename eval<TensorIndexTupleOp<XprType> >::type>
46 {
47  typedef TensorIndexTupleOp<XprType> type;
48 };
49 
}  // end namespace internal

52 template<typename XprType>
53 class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors>
54 {
55  public:
56  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar;
57  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
58  typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested;
59  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind;
60  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index;
61  typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType;
62 
63  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr)
64  : m_xpr(expr) {}
65 
66  EIGEN_DEVICE_FUNC
67  const typename internal::remove_all<typename XprType::Nested>::type&
68  expression() const { return m_xpr; }
69 
70  protected:
71  typename XprType::Nested m_xpr;
72 };
73 
74 // Eval as rvalue
75 template<typename ArgType, typename Device>
76 struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
77 {
78  typedef TensorIndexTupleOp<ArgType> XprType;
79  typedef typename XprType::Index Index;
80  typedef typename XprType::Scalar Scalar;
81  typedef typename XprType::CoeffReturnType CoeffReturnType;
82 
83  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
84  static const int NumDims = internal::array_size<Dimensions>::value;
85 
86  enum {
87  IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
88  PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
89  BlockAccess = false,
90  Layout = TensorEvaluator<ArgType, Device>::Layout,
91  CoordAccess = false, // to be implemented
92  };
93 
94  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
95  : m_impl(op.expression(), device) { }
96 
97  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
98  return m_impl.dimensions();
99  }
100 
101  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
102  m_impl.evalSubExprsIfNeeded(NULL);
103  return true;
104  }
105  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
106  m_impl.cleanup();
107  }
108 
109  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
110  {
111  return CoeffReturnType(index, m_impl.coeff(index));
112  }
113 
114  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
115 
116  protected:
117  TensorEvaluator<ArgType, Device> m_impl;
118 };
119 
namespace internal {

128 template<typename ReduceOp, typename Dims, typename XprType>
129 struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
130 {
131  typedef traits<XprType> XprTraits;
132  typedef typename XprTraits::StorageKind StorageKind;
133  typedef typename XprTraits::Index Index;
134  typedef Index Scalar;
135  typedef typename XprType::Nested Nested;
136  typedef typename remove_reference<Nested>::type _Nested;
137  static const int NumDimensions = XprTraits::NumDimensions;
138  static const int Layout = XprTraits::Layout;
139 };
140 
141 template<typename ReduceOp, typename Dims, typename XprType>
142 struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
143 {
144  typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>& type;
145 };
146 
147 template<typename ReduceOp, typename Dims, typename XprType>
148 struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1,
149  typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type>
150 {
151  typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type;
152 };
153 
}  // end namespace internal

156 template<typename ReduceOp, typename Dims, typename XprType>
157 class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
158 {
159  public:
160  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar;
161  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
162  typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested;
163  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind;
164  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index;
165  typedef Index CoeffReturnType;
166 
167  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
168  const ReduceOp& reduce_op,
169  const int return_dim,
170  const Dims& reduce_dims)
171  : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
172 
173  EIGEN_DEVICE_FUNC
174  const typename internal::remove_all<typename XprType::Nested>::type&
175  expression() const { return m_xpr; }
176 
177  EIGEN_DEVICE_FUNC
178  const ReduceOp& reduce_op() const { return m_reduce_op; }
179 
180  EIGEN_DEVICE_FUNC
181  const Dims& reduce_dims() const { return m_reduce_dims; }
182 
183  EIGEN_DEVICE_FUNC
184  int return_dim() const { return m_return_dim; }
185 
186  protected:
187  typename XprType::Nested m_xpr;
188  const ReduceOp m_reduce_op;
189  const int m_return_dim;
190  const Dims m_reduce_dims;
191 };
192 
193 // Eval as rvalue
194 template<typename ReduceOp, typename Dims, typename ArgType, typename Device>
195 struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
196 {
197  typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType;
198  typedef typename XprType::Index Index;
199  typedef typename XprType::Scalar Scalar;
200  typedef typename XprType::CoeffReturnType CoeffReturnType;
201  typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType;
202  typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions;
203  typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType> , Device>::Dimensions InputDimensions;
204  static const int NumDims = internal::array_size<InputDimensions>::value;
205  typedef array<Index, NumDims> StrideDims;
206 
207  enum {
208  IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
209  PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
210  BlockAccess = false,
211  Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
212  CoordAccess = false, // to be implemented
213  };
214 
215  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
216  : m_orig_impl(op.expression(), device),
217  m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
218  m_return_dim(op.return_dim()) {
219 
220  gen_strides(m_orig_impl.dimensions(), m_strides);
221  if (Layout == static_cast<int>(ColMajor)) {
222  const Index total_size = internal::array_prod(m_orig_impl.dimensions());
223  m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
224  } else {
225  const Index total_size = internal::array_prod(m_orig_impl.dimensions());
226  m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
227  }
228  m_stride_div = m_strides[m_return_dim];
229  }
230 
231  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
232  return m_impl.dimensions();
233  }
234 
235  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
236  m_impl.evalSubExprsIfNeeded(NULL);
237  return true;
238  }
239  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
240  m_impl.cleanup();
241  }
242 
243  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
244  const TupleType v = m_impl.coeff(index);
245  return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
246  }
247 
248  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
249 
250  private:
251  EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
252  if (m_return_dim < 0) {
253  return; // Won't be using the strides.
254  }
255  eigen_assert(m_return_dim < NumDims &&
256  "Asking to convert index to a dimension outside of the rank");
257 
258  // Calculate m_stride_div and m_stride_mod, which are used to
259  // calculate the value of an index w.r.t. the m_return_dim.
260  if (Layout == static_cast<int>(ColMajor)) {
261  strides[0] = 1;
262  for (int i = 1; i < NumDims; ++i) {
263  strides[i] = strides[i-1] * dims[i-1];
264  }
265  } else {
266  strides[NumDims-1] = 1;
267  for (int i = NumDims - 2; i >= 0; --i) {
268  strides[i] = strides[i+1] * dims[i+1];
269  }
270  }
271  }
272 
273  protected:
274  TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
275  TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
276  const int m_return_dim;
277  StrideDims m_strides;
278  Index m_stride_mod;
279  Index m_stride_div;
280 };
281 
}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H

// NOTE: the `Eigen` namespace (containing all Eigen symbols) is opened in
// CXX11Meta.h; this cross-reference came from the generated documentation.