28#include "tiny_dnn/util/util.h"
29#include "tiny_dnn/layers/layer.h"
66 out_shape_ = in_shapes_.front();
67 for (
size_t i = 1; i < in_shapes_.size(); i++) {
68 if (in_shapes_[i].area() != out_shape_.area())
69 throw nn_error(
"each input shapes to concat must have same WxH size");
70 out_shape_.depth_ += in_shapes_[i].depth_;
78 std::vector<shape3d>
in_shape()
const override {
87 std::vector<tensor_t*>&
out_data)
override {
93 for (serial_size_t
i = 0;
i < in_shapes_.size();
i++) {
95 serial_size_t
dim = in_shapes_[
i].size();
102 const std::vector<tensor_t*>&
out_data,
104 std::vector<tensor_t*>&
in_grad)
override {
105 CNN_UNREFERENCED_PARAMETER(
in_data);
106 CNN_UNREFERENCED_PARAMETER(
out_data);
113 for (serial_size_t
i = 0;
i < in_shapes_.size();
i++) {
114 serial_size_t
dim = in_shapes_[
i].size();
122 template <
class Archive>
123 static void load_and_construct(
Archive &
ar, cereal::construct<concat_layer> & construct) {
130 template <
class Archive>
131 void serialize(Archive & ar) {
132 layer::serialize_prolog(ar);
137 std::vector<shape3d> in_shapes_;
Concatenates the outputs of N layers along the depth dimension.
Definition concat_layer.h:44
void back_propagation(const std::vector< tensor_t * > &in_data, const std::vector< tensor_t * > &out_data, std::vector< tensor_t * > &out_grad, std::vector< tensor_t * > &in_grad) override
Returns the delta of the previous layer ($\delta = \frac{dE}{da}$, where $a = wx$ in a fully-connected layer).
Definition concat_layer.h:101
std::vector< shape3d > in_shape() const override
array of input shapes (width x height x depth)
Definition concat_layer.h:78
std::string layer_type() const override
Name of the layer; should be unique for each concrete class.
Definition concat_layer.h:74
std::vector< shape3d > out_shape() const override
array of output shapes (width x height x depth)
Definition concat_layer.h:82
void forward_propagation(const std::vector< tensor_t * > &in_data, std::vector< tensor_t * > &out_data) override
Definition concat_layer.h:86
concat_layer(const std::vector< shape3d > &in_shapes)
Definition concat_layer.h:49
concat_layer(serial_size_t num_args, serial_size_t ndim)
Definition concat_layer.h:59
Simple image utility class.
Definition image.h:94
Base class of all kinds of NN layers.
Definition layer.h:62