    serial_size_t in_data_size() const {
        return sumif(in_shape(),
                     [&](serial_size_t i) { return in_type_[i] == vector_type::data; },
                     [](const shape3d& s) { return s.size(); });
    }

    serial_size_t out_data_size() const {
        return sumif(out_shape(),
                     [&](serial_size_t i) { return out_type_[i] == vector_type::data; },
                     [](const shape3d& s) { return s.size(); });
    }
    std::vector<shape3d> in_data_shape() {
        return filter(in_shape(), [&](size_t i) {
            return in_type_[i] == vector_type::data;
        });
    }

    std::vector<shape3d> out_data_shape() {
        return filter(out_shape(), [&](size_t i) {
            return out_type_[i] == vector_type::data;
        });
    }
    ///! @deprecated use in_data_size() instead
    serial_size_t in_size() const {
        return in_data_size();
    }

    ///! @deprecated use out_data_size() instead
    serial_size_t out_size() const {
        return out_data_size();
    }
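    // Note (illustrative): a layer's input channels mix "data" tensors with
    // trainable parameters. For a hypothetical fully-connected layer with
    // in_type_ = { data, weight, bias }, in_shape() reports all three shapes,
    // while in_data_size() / in_data_shape() count only the "data" entry.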
    std::vector<const vec_t*> weights() const {
        std::vector<const vec_t*> v;
        for (serial_size_t i = 0; i < in_channels_; i++) {
            if (is_trainable_weight(in_type_[i])) {
                v.push_back(get_weight_data(i));
            }
        }
        return v;
    }

    std::vector<vec_t*> weights() {
        std::vector<vec_t*> v;
        for (serial_size_t i = 0; i < in_channels_; i++) {
            if (is_trainable_weight(in_type_[i])) {
                v.push_back(get_weight_data(i));
            }
        }
        return v;
    }
    std::vector<tensor_t*> weights_grads() {
        std::vector<tensor_t*> v;
        for (serial_size_t i = 0; i < in_channels_; i++) {
            if (is_trainable_weight(in_type_[i])) {
                v.push_back(ith_in_node(i)->get_gradient());
            }
        }
        return v;
    }
    std::vector<edgeptr_t> inputs() {
        std::vector<edgeptr_t> nodes;
        for (serial_size_t i = 0; i < in_channels_; i++) {
            nodes.push_back(ith_in_node(i));
        }
        return nodes;
    }

    std::vector<edgeptr_t> outputs() {
        std::vector<edgeptr_t> nodes;
        for (serial_size_t i = 0; i < out_channels_; i++) {
            nodes.push_back(ith_out_node(i));
        }
        return nodes;
    }
    std::vector<edgeptr_t> outputs() const {
        std::vector<edgeptr_t> nodes;
        for (serial_size_t i = 0; i < out_channels_; i++) {
            nodes.push_back(const_cast<layerptr_t>(this)->ith_out_node(i));
        }
        return nodes;
    }
    void set_out_grads(const std::vector<tensor_t>& grad) {
        serial_size_t j = 0;
        for (serial_size_t i = 0; i < out_channels_; i++) {
            if (out_type_[i] != vector_type::data) continue;
            assert(j < grad.size());
            *ith_out_node(i)->get_gradient() = grad[j++];
        }
    }

    void set_in_data(const std::vector<tensor_t>& data) {
        serial_size_t j = 0;
        for (serial_size_t i = 0; i < in_channels_; i++) {
            if (in_type_[i] != vector_type::data) continue;
            assert(j < data.size());
            *ith_in_node(i)->get_data() = data[j++];
        }
    }
    std::vector<tensor_t> output() const {
        std::vector<tensor_t> out;
        for (serial_size_t i = 0; i < out_channels_; i++) {
            if (out_type_[i] == vector_type::data) {
                out.push_back(*(const_cast<layerptr_t>(this))
                                   ->ith_out_node(i)->get_data());
            }
        }
        return out;
    }
    std::vector<vector_type> in_types() const { return in_type_; }

    std::vector<vector_type> out_types() const { return out_type_; }

    void set_trainable(bool trainable) { trainable_ = trainable; }

    bool trainable() const { return trainable_; }
    // output value range; override if the layer is used as an output layer
    virtual std::pair<float_t, float_t> out_value_range() const {
        return { float_t(0.0), float_t(1.0) };
    }
    template <typename WeightInit>
    layer& weight_init(const WeightInit& f) {
        weight_init_ = std::make_shared<WeightInit>(f);
        return *this;
    }

    template <typename BiasInit>
    layer& bias_init(const BiasInit& f) {
        bias_init_ = std::make_shared<BiasInit>(f);
        return *this;
    }

    template <typename WeightInit>
    layer& weight_init(std::shared_ptr<WeightInit> f) {
        weight_init_ = f;
        return *this;
    }

    template <typename BiasInit>
    layer& bias_init(std::shared_ptr<BiasInit> f) {
        bias_init_ = f;
        return *this;
    }
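    // Usage sketch (illustrative, assuming the weight_init::xavier and
    // weight_init::constant initializers provided elsewhere in the library;
    // "fc" is a placeholder layer instance). Both setters return *this,
    // so they can be chained:
    //
    //   fc.weight_init(weight_init::xavier())
    //     .bias_init(weight_init::constant(0.0));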
    template <typename Archive>
    void serialize(Archive& ar) {
        auto all_weights = weights();
        for (auto weight : all_weights) {
            ar(*weight);
        }
        initialized_ = true;
    }

    virtual void save(std::ostream& os) const {
        auto all_weights = weights();
        for (auto& weight : all_weights) {
            for (auto w : *weight) os << w << " ";
        }
    }
    virtual void load(std::istream& is) {
        auto all_weights = weights();
        for (auto& weight : all_weights) {
            for (auto& w : *weight) is >> w;
        }
        initialized_ = true;
    }

    virtual void load(const std::vector<float_t>& src, int& idx) {
        auto all_weights = weights();
        for (auto& weight : all_weights) {
            for (auto& w : *weight) w = src[idx++];
        }
        initialized_ = true;
    }
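    // Usage sketch (illustrative): weights round-trip through any std::iostream,
    // either via save()/load() directly or via the stream operators defined near
    // the end of this header. The file name is hypothetical.
    //
    //   std::ofstream ofs("weights.txt");
    //   ofs << my_layer;   // operator<< calls my_layer.save(ofs)
    //
    //   std::ifstream ifs("weights.txt");
    //   ifs >> my_layer;   // operator>> calls my_layer.load(ifs)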
    ///< visualize the latest output of this layer as an image
    virtual image<> output_to_image(size_t channel = 0) const {
        const vec_t* output = &(*(outputs()[channel]->get_data()))[0];
        return vec2image<unsigned char>(*output, out_shape()[channel]);
    }
    // pure virtual: compute out_data from in_data
    virtual void forward_propagation(const std::vector<tensor_t*>& in_data,
                                     std::vector<tensor_t*>&       out_data) = 0;

    // pure virtual: compute in_grad from out_grad
    virtual void back_propagation(const std::vector<tensor_t*>& in_data,
                                  const std::vector<tensor_t*>& out_data,
                                  std::vector<tensor_t*>&       out_grad,
                                  std::vector<tensor_t*>&       in_grad) = 0;
    virtual void set_context(net_phase ctx) {
        CNN_UNREFERENCED_PARAMETER(ctx);
    }
    std::vector<tensor_t> forward(const std::vector<tensor_t>& input) {
        set_in_data(input);
        forward();
        return output();
    }

    std::vector<tensor_t> backward(const std::vector<tensor_t>& out_grads) {
        set_out_grads(out_grads);
        backward();
        return map_<tensor_t>(inputs(), [](edgeptr_t e) {
            return *e->get_gradient();
        });
    }
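    // Usage sketch (illustrative, names are placeholders): driving one layer in
    // isolation, as the unit tests do. Each element of the outer vector is one
    // input channel; each tensor_t holds one entry per sample in the batch.
    //
    //   fully_connected_layer<tan_h> fc(3, 2);
    //   std::vector<tensor_t> out = fc.forward(in_batch);
    //   std::vector<tensor_t> dx  = fc.backward(grad_batch);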
    void forward() {
        std::vector<tensor_t*> in_data, out_data;

        // organize input vectors from storage
        for (serial_size_t i = 0; i < in_channels_; i++) {
            in_data.push_back(ith_in_node(i)->get_data());
        }

        // resize per-sample storage to hold every sample in the batch
        set_sample_count(static_cast<serial_size_t>(in_data[0]->size()));

        for (serial_size_t i = 0; i < out_channels_; i++) {
            out_data.push_back(ith_out_node(i)->get_data());
            ith_out_node(i)->clear_grads();
        }

        forward_propagation(in_data, out_data);
    }
    void backward() {
        std::vector<tensor_t*> in_data, out_data, in_grad, out_grad;

        // gather data and gradient vectors from storage
        for (serial_size_t i = 0; i < in_channels_; i++) {
            in_data.push_back(ith_in_node(i)->get_data());
            in_grad.push_back(ith_in_node(i)->get_gradient());
        }
        for (serial_size_t i = 0; i < out_channels_; i++) {
            out_data.push_back(ith_out_node(i)->get_data());
            out_grad.push_back(ith_out_node(i)->get_gradient());
        }

        back_propagation(in_data, out_data, out_grad, in_grad);
    }
    // allocate output edges and (optionally) reset the weights
    void setup(bool reset_weight) {
        if (in_shape().size() != in_channels_ ||
            out_shape().size() != out_channels_) {
            throw nn_error("Connection mismatch at setup layer");
        }

        // allocate output edges that are not yet shared with a following layer
        for (size_t i = 0; i < out_channels_; i++) {
            if (!next_[i]) {
                next_[i] = std::make_shared<edge>(this, out_shape()[i], out_type_[i]);
            }
        }

        if (reset_weight || !initialized_) init_weight();
    }
    void init_weight() {
        // fill trainable weight/bias vectors using the configured initializers
        for (serial_size_t i = 0; i < static_cast<serial_size_t>(in_type_.size()); i++) {
            switch (in_type_[i]) {
                case vector_type::weight:
                    weight_init_->fill(get_weight_data(i),
                                       fan_in_size(), fan_out_size());
                    break;
                case vector_type::bias:
                    bias_init_->fill(get_weight_data(i),
                                     fan_in_size(), fan_out_size());
                    break;
                default:
                    break;
            }
        }
        initialized_ = true;
    }
    void clear_grads() {
        for (serial_size_t i = 0; i < static_cast<serial_size_t>(in_type_.size()); i++) {
            ith_in_node(i)->clear_grads();
        }
    }
    void update_weight(optimizer *o, serial_size_t batch_size) {
        float_t rcp_batch_size = float_t(1) / float_t(batch_size);
        vec_t diff;

        for (serial_size_t i = 0; i < static_cast<serial_size_t>(in_type_.size()); i++) {
            if (trainable() && is_trainable_weight(in_type_[i])) {
                vec_t& target = *get_weight_data(i);
                ith_in_node(i)->merge_grads(&diff);
                // average the accumulated gradient over the mini-batch
                std::transform(diff.begin(), diff.end(),
                               diff.begin(), [&](float_t x) {
                                   return x * rcp_batch_size; });
                // parallelize only when the target is large enough to amortize
                // the cost of spawning threads
                bool parallelize = (target.size() >= 512);
                o->update(diff, target, parallelize);
            }
        }
        clear_grads();
    }
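    // In other words, for each trainable vector W the optimizer receives
    //   diff = (1 / batch_size) * sum over samples of dLoss/dW,
    // i.e. the gradient accumulated over the mini-batch and averaged, before
    // applying its own update rule.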
    bool has_same_weights(const layer& rhs, float_t eps) const {
        auto w1 = weights();
        auto w2 = rhs.weights();
        if (w1.size() != w2.size()) return false;

        for (size_t i = 0; i < w1.size(); i++) {
            if (w1[i]->size() != w2[i]->size()) return false;

            for (size_t j = 0; j < w1[i]->size(); j++) {
                if (std::abs(w1[i]->at(j) - w2[i]->at(j)) > eps) return false;
            }
        }
        return true;
    }
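    // Usage sketch (illustrative): verifying that a save/load round trip
    // preserves weights, e.g. in a test. The stream and tolerance are hypothetical.
    //
    //   std::stringstream ss;
    //   layer_a.save(ss);
    //   layer_b.load(ss);
    //   assert(layer_a.has_same_weights(layer_b, float_t(1e-6)));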
    virtual void set_sample_count(serial_size_t sample_count) {
        // grow per-sample storage to hold the whole batch; weight/bias tensors
        // are shared across samples and keep their size
        auto resize = [sample_count](tensor_t* tensor) {
            tensor->resize(sample_count, (*tensor)[0]);
        };

        for (serial_size_t i = 0; i < in_channels_; i++) {
            if (!is_trainable_weight(in_type_[i])) {
                resize(ith_in_node(i)->get_data());
            }
            resize(ith_in_node(i)->get_gradient());
        }

        for (serial_size_t i = 0; i < out_channels_; i++) {
            if (!is_trainable_weight(out_type_[i])) {
                resize(ith_out_node(i)->get_data());
            }
            resize(ith_out_node(i)->get_gradient());
        }
    }
    template <typename InputArchive>
    static std::shared_ptr<layer> load_layer(InputArchive& ia);

    template <typename OutputArchive>
    static void save_layer(OutputArchive& oa, const layer& l);

    template <class Archive>
    void serialize_prolog(Archive& ar);
    std::shared_ptr<weight_init::function> weight_init_;

    std::shared_ptr<weight_init::function> bias_init_;
    void alloc_input(serial_size_t i) const {
        // the incoming edge created here has no previous node yet
        prev_[i] = std::make_shared<edge>(nullptr, in_shape()[i], in_type_[i]);
    }

    void alloc_output(serial_size_t i) const {
        // the outgoing edge has the current layer as its previous node
        next_[i] = std::make_shared<edge>((layer*)this,
                                          out_shape()[i], out_type_[i]);
    }
    edgeptr_t ith_in_node(serial_size_t i) {
        // create the edge lazily if it does not exist yet
        if (!prev_[i]) alloc_input(i);
        return prev()[i];
    }

    edgeptr_t ith_out_node(serial_size_t i) {
        // create the edge lazily if it does not exist yet
        if (!next_[i]) alloc_output(i);
        return next()[i];
    }
    vec_t* get_weight_data(serial_size_t i) {
        assert(is_trainable_weight(in_type_[i]));
        return &(*(ith_in_node(i)->get_data()))[0];
    }

    const vec_t* get_weight_data(serial_size_t i) const {
        assert(is_trainable_weight(in_type_[i]));
        return &(*(const_cast<layerptr_t>(this)->ith_in_node(i)->get_data()))[0];
    }
inline void connect(layerptr_t head,
                    layerptr_t tail,
                    serial_size_t head_index = 0,
                    serial_size_t tail_index = 0) {
    auto out_shape = head->out_shape()[head_index];
    auto in_shape  = tail->in_shape()[tail_index];

    head->setup(false);

    if (out_shape.size() != in_shape.size()) {
        connection_mismatch(*head, *tail);
    }

    if (!head->next_[head_index]) {
        throw nn_error("output edge must not be null");
    }

    tail->prev_[tail_index] = head->next_[head_index];
    tail->prev_[tail_index]->add_next_node(tail);
}
inline layer& operator << (layer& lhs, layer& rhs) {
    connect(&lhs, &rhs);
    return rhs;
}
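// Usage sketch (illustrative, layer types are placeholders): the operator above
// wires layers in sequence, each link validated by connect().
//
//   fully_connected_layer<tan_h>   fc1(28 * 28, 300);
//   fully_connected_layer<softmax> fc2(300, 10);
//   fc1 << fc2;   // connects fc1's output edge to fc2's first input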
template <typename Char, typename CharTraits>
std::basic_ostream<Char, CharTraits>& operator << (
    std::basic_ostream<Char, CharTraits>& os, const layer& v) {
    v.save(os);
    return os;
}
template <typename Char, typename CharTraits>
std::basic_istream<Char, CharTraits>& operator >> (
    std::basic_istream<Char, CharTraits>& is, layer& v) {
    v.load(is);
    return is;
}
inline void connection_mismatch(const layer& from, const layer& to) {
    std::ostringstream os;

    os << "output size of Nth layer must be equal to input of (N+1)th layer\n";

    os << "layerN:   " << std::setw(12) << from.layer_type() << " in:"
       << from.in_data_size() << "("
       << from.in_shape() << "), " << "out:"
       << from.out_data_size() << "("
       << from.out_shape() << ")\n";

    os << "layerN+1: " << std::setw(12) << to.layer_type() << " in:"
       << to.in_data_size() << "("
       << to.in_shape() << "), " << "out:"
       << to.out_data_size() << "("
       << to.out_shape() << ")\n";

    os << from.out_data_size() << " != " << to.in_data_size() << std::endl;
    std::string detail_info = os.str();

    throw nn_error("layer dimension mismatch!" + detail_info);
}
inline void data_mismatch(const layer& layer, const vec_t& data) {
    std::ostringstream os;

    os << "data dimension:    " << data.size() << "\n";
    os << "network dimension: " << layer.in_data_size() << "("
       << layer.layer_type() << ":"
       << layer.in_shape() << ")\n";

    std::string detail_info = os.str();

    throw nn_error("input dimension mismatch!" + detail_info);
}
inline void pooling_size_mismatch(serial_size_t in_width,
                                  serial_size_t in_height,
                                  serial_size_t pooling_size_x,
                                  serial_size_t pooling_size_y) {
    std::ostringstream os;

    os << "WxH:" << in_width << "x" << in_height << std::endl;
    os << "pooling-size:" << pooling_size_x << "x" << pooling_size_y << std::endl;

    std::string detail_info = os.str();

    throw nn_error("width/height not multiple of pooling size" + detail_info);
}
template <typename T, typename U>
void graph_traverse(layer *root_node, T&& node_callback, U&& edge_callback) {
    std::unordered_set<layer*> visited;
    std::queue<layer*> S;

    S.push(root_node);

    while (!S.empty()) {
        layer *curr = S.front();
        S.pop();
        // skip nodes that were enqueued more than once before being processed
        if (!visited.insert(curr).second) continue;

        node_callback(*curr);

        auto edges = curr->next();
        for (auto e : edges) {
            if (e != nullptr) edge_callback(*e);
        }

        auto prev = curr->prev_nodes();
        for (auto p : prev) {
            layer* l = dynamic_cast<layer*>(p);
            if (visited.find(l) == visited.end()) {
                S.push(l);
            }
        }

        auto next = curr->next_nodes();
        for (auto n : next) {
            layer* l = dynamic_cast<layer*>(n);
            if (visited.find(l) == visited.end()) {
                S.push(l);
            }
        }
    }
}
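// Usage sketch (illustrative): dumping every layer and edge reachable from a
// root layer; node_callback receives layer&, edge_callback receives edge&.
//
//   graph_traverse(&root,
//                  [](layer& l) { std::cout << l.layer_type() << std::endl; },
//                  [](edge&  e) { std::cout << "edge" << std::endl; });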