// Accessor named in_shape(): const, returns std::vector<shape3d>, and
// overrides a base-class virtual (`override`).
// NOTE(review): the body of this function is not visible here.
85 std::vector<shape3d>
in_shape()
const override {
// Tail of an override whose last parameter is `out_data`; it dispatches on
// slice_type_ (slicing along samples vs. along channels).
// NOTE(review): the function name, its leading parameters and the case
// bodies are not visible here; presumably this is the forward pass that
// calls slice_data_forward / slice_channels_forward — confirm.
94 std::vector<tensor_t*>&
out_data)
override {
95 switch (slice_type_) {
96 case slice_type::slice_samples:
99 case slice_type::slice_channels:
// Tail of an override whose visible parameters end with `out_data` and
// `in_grad`; presumably the backward pass (its head is not visible here).
108 const std::vector<tensor_t*>&
out_data,
110 std::vector<tensor_t*>&
in_grad)
override {
// in_data and out_data are explicitly marked as unused by this override.
111 CNN_UNREFERENCED_PARAMETER(
in_data);
112 CNN_UNREFERENCED_PARAMETER(
out_data);
// Dispatch on the slicing axis.
// NOTE(review): the case bodies are not visible; presumably they call
// slice_data_backward / slice_channels_backward — confirm.
114 switch (slice_type_) {
115 case slice_type::slice_samples:
118 case slice_type::slice_channels:
// Cereal load-and-construct hook for slice_layer. It reads the entries
// "in_size", "slice_type" and "num_outputs" from the archive; these are the
// same keys that serialize() writes.
// NOTE(review): the declarations of the locals `in_shape` / `num_outputs`
// and the call that passes the values to `construct` are not visible here;
// presumably construct(in_shape, slice_type, num_outputs) — confirm.
126 template <
class Archive>
127 static void load_and_construct(
Archive &
ar, cereal::construct<slice_layer> & construct) {
// This local reuses the enum's name; from here to the end of the scope the
// identifier `slice_type` refers to the variable, not the type.
129 slice_type slice_type;
132 ar(cereal::make_nvp(
"in_size",
in_shape), cereal::make_nvp(
"slice_type", slice_type), cereal::make_nvp(
"num_outputs",
num_outputs));
136 template <
class Archive>
137 void serialize(Archive & ar) {
138 layer::serialize_prolog(ar);
139 ar(cereal::make_nvp(
"in_size", in_shape_), cereal::make_nvp(
"slice_type", slice_type_), cereal::make_nvp(
"num_outputs", num_outputs_));
142 void slice_data_forward(
const tensor_t& in_data,
143 std::vector<tensor_t*>& out_data) {
144 const vec_t* in = &in_data[0];
146 for (serial_size_t i = 0; i < num_outputs_; i++) {
147 tensor_t& out = *out_data[i];
149 std::copy(in, in + slice_size_[i], &out[0]);
151 in += slice_size_[i];
// Sample-axis backward pass, the mirror of slice_data_forward: for each
// output i in order, the first slice_size_[i] rows of *out_grad[i] are
// copied into consecutive rows of in_grad, continuing where output i-1
// stopped.
// NOTE(review): the `in_grad` parameter line is not visible here; from its
// use below it is a non-const indexable container whose elements are vec_t.
155 void slice_data_backward(std::vector<tensor_t*>& out_grad,
157 vec_t* in = &in_grad[0];
159 for (serial_size_t i = 0; i < num_outputs_; i++) {
160 tensor_t& out = *out_grad[i];
162 std::copy(&out[0], &out[0] + slice_size_[i], in);
// Advance the write cursor past the rows just filled.
164 in += slice_size_[i];
168 void slice_channels_forward(
const tensor_t& in_data,
169 std::vector<tensor_t*>& out_data) {
170 serial_size_t num_samples =
static_cast<serial_size_t
>(in_data.size());
171 serial_size_t channel_idx = 0;
172 serial_size_t spatial_dim = in_shape_.area();
174 for (serial_size_t i = 0; i < num_outputs_; i++) {
175 for (serial_size_t s = 0; s < num_samples; s++) {
176 float_t *out = &(*out_data[i])[s][0];
177 const float_t *in = &in_data[s][0] + channel_idx*spatial_dim;
179 std::copy(in, in + slice_size_[i] * spatial_dim, out);
181 channel_idx += slice_size_[i];
// Channel-axis backward pass, the mirror of slice_channels_forward: for each
// output i and each sample s, the first slice_size_[i] * in_shape_.area()
// values of (*out_grad[i])[s] are written into in_grad[s], starting at the
// channel offset accumulated from outputs 0..i-1.
// NOTE(review): the `in_grad` parameter line is not visible here; from its
// use below it is a non-const, per-sample indexable tensor.
185 void slice_channels_backward(std::vector<tensor_t*>& out_grad,
187 serial_size_t num_samples =
static_cast<serial_size_t
>(in_grad.size());
// First channel (in input coordinates) that output i maps back to.
188 serial_size_t channel_idx = 0;
// Values per channel.
189 serial_size_t spatial_dim = in_shape_.area();
191 for (serial_size_t i = 0; i < num_outputs_; i++) {
192 for (serial_size_t s = 0; s < num_samples; s++) {
193 const float_t *out = &(*out_grad[i])[s][0];
194 float_t *in = &in_grad[s][0] + channel_idx*spatial_dim;
196 std::copy(out, out + slice_size_[i] * spatial_dim, in);
// Next output continues after the channels just written.
198 channel_idx += slice_size_[i];
202 void set_sample_count(serial_size_t sample_count)
override {
203 if (slice_type_ == slice_type::slice_samples) {
204 if (num_outputs_ == 0)
205 throw nn_error(
"num_outputs must be positive integer");
207 serial_size_t sample_per_out = sample_count / num_outputs_;
209 slice_size_.resize(num_outputs_, sample_per_out);
210 slice_size_.back() = sample_count - (sample_per_out*(num_outputs_-1));
212 Base::set_sample_count(sample_count);
// Chooses how the output shapes are derived, based on the slicing axis.
// NOTE(review): the enclosing function header and the body of the
// slice_samples case are not visible here; presumably that case calls
// set_shape_data() — confirm.
216 switch (slice_type_) {
217 case slice_type::slice_samples:
220 case slice_type::slice_channels:
221 set_shape_channels();
// Any other slice_type is rejected. NOTE(review): presumably this sits under
// a `default:` label that is not visible here.
224 throw nn_not_implemented_error();
228 void set_shape_data() {
229 out_shapes_.resize(num_outputs_, in_shape_);
// Output shapes for channel-axis slicing: each output gets
// in_shape_.depth_ / num_outputs_ channels, the last output also takes the
// remainder. Each output's channel count is appended to slice_size_, and a
// (width_, height_, ch) shape is appended to out_shapes_.
// NOTE(review): num_outputs_ == 0 would divide by zero on the next line; the
// positivity check in set_sample_count only runs for slice_samples.
232 void set_shape_channels() {
233 serial_size_t channel_per_out = in_shape_.depth_ / num_outputs_;
// NOTE(review): entries are appended with push_back below; unless
// slice_size_ / out_shapes_ are cleared beforehand (not visible here), a
// second call would leave extra entries — confirm.
236 for (serial_size_t i = 0; i < num_outputs_; i++) {
237 serial_size_t ch = channel_per_out;
239 if (i == num_outputs_ - 1) {
// Sanity check: the remainder computed next cannot underflow.
240 assert(in_shape_.depth_ >= i * channel_per_out);
241 ch = in_shape_.depth_ - i * channel_per_out;
244 slice_size_.push_back(ch);
245 out_shapes_.push_back(shape3d(in_shape_.width_, in_shape_.height_, ch));
// Axis to slice along; the switches in this class handle slice_samples and
// slice_channels.
250 slice_type slice_type_;
// Number of output tensors produced by the slice.
251 serial_size_t num_outputs_;
// One shape per output; filled by set_shape_data / set_shape_channels.
252 std::vector<shape3d> out_shapes_;
// Per-output slice extent: a sample count for slice_samples (set in
// set_sample_count) or a channel count for slice_channels (set in
// set_shape_channels).
253 std::vector<serial_size_t> slice_size_;