87 typedef std::vector<layerptr_t>::iterator iterator;
88 typedef std::vector<layerptr_t>::const_iterator const_iterator;
110 for (
auto l : nodes_) {
119 for (
auto l : nodes_) {
125 for (
auto l : nodes_) {
130 size_t size()
const {
return nodes_.size(); }
131 iterator begin() {
return nodes_.begin(); }
132 iterator end() {
return nodes_.end(); }
133 const_iterator begin()
const {
return nodes_.begin(); }
134 const_iterator end()
const {
return nodes_.end(); }
135 layer* operator[] (
size_t index) {
return nodes_[index]; }
136 const layer* operator[] (
size_t index)
const {
return nodes_[index]; }
137 serial_size_t in_data_size()
const {
return nodes_.front()->in_data_size(); }
138 serial_size_t out_data_size()
const {
return nodes_.back()->out_data_size(); }
// Typed read-only access to the node at `index`: the stored pointer is
// dynamic_cast to `const T*`. An nn_error("failed to cast") is thrown in this
// function - presumably when the cast yields null (the guarding condition is
// not among the lines visible here; confirm).
// NOTE(review): nodes_[index] is not range-checked in the visible lines.
140 template <
typename T>
141 const T& at(
size_t index)
const {
142 const T* v =
dynamic_cast<const T*
>(nodes_[index]);
144 throw nn_error(
"failed to cast");
// Typed mutable access to the node at `index`: the stored pointer is
// dynamic_cast to `T*`. An nn_error("failed to cast") is thrown in this
// function - presumably when the cast yields null (the guarding condition is
// not among the lines visible here; confirm).
// NOTE(review): nodes_[index] is not range-checked in the visible lines.
147 template <
typename T>
148 T& at(
size_t index) {
149 T* v =
dynamic_cast<T*
>(nodes_[index]);
151 throw nn_error(
"failed to cast");
// Returns the `first` member of the last node's out_value_range().
// `out_channel` is only handed to CNN_UNREFERENCED_PARAMETER and does not
// influence the result in the visible lines.
// NOTE(review): nodes_.back() is called without an emptiness check here -
// presumably callers guarantee at least one node; confirm.
155 virtual float_t target_value_min(
int out_channel = 0)
const {
156 CNN_UNREFERENCED_PARAMETER(out_channel);
157 return nodes_.back()->out_value_range().first;
// Returns the `second` member of the last node's out_value_range().
// `out_channel` is only handed to CNN_UNREFERENCED_PARAMETER and does not
// influence the result in the visible lines.
// NOTE(review): nodes_.back() is called without an emptiness check here -
// presumably callers guarantee at least one node; confirm.
160 virtual float_t target_value_max(
int out_channel = 0)
const {
161 CNN_UNREFERENCED_PARAMETER(out_channel);
162 return nodes_.back()->out_value_range().second;
165 void save(std::ostream& os)
const {
166 for (
auto& l : nodes_) {
171 void load(std::istream& is) {
173 for (
auto& l : nodes_) {
178 virtual void load(
const std::vector<float_t>& vec) {
181 for (
auto& l : nodes_) {
// Converts `num` class labels from `t` into target vectors appended to `*vec`.
// For each label, an element is emplaced from (outdim, target_value_min())
// - presumably a vector of outdim entries all equal to the minimum target
// value (depends on vec_t's constructor; confirm) - and the entry at index
// t[i] is then overwritten with target_value_max().
// outdim is out_data_size(). The range check on each label is an assert, so
// it is only active in builds where assert is enabled.
186 void label2vec(
const label_t* t, serial_size_t num, std::vector<vec_t> *vec)
const {
187 serial_size_t outdim = out_data_size();
190 for (serial_size_t i = 0; i < num; i++) {
191 assert(t[i] < outdim);
192 vec->emplace_back(outdim, target_value_min());
193 vec->back()[t[i]] = target_value_max();
// Declaration only: const member template taking an output archive by
// reference. The definition is not part of the lines visible here.
197 template <
typename OutputArchive>
198 void save_model(OutputArchive & oa)
const;
// Declaration only: non-const member template taking an input archive by
// reference. The definition is not part of the lines visible here.
200 template <
typename InputArchive>
201 void load_model(InputArchive & ia);
204 template <
typename OutputArchive>
205 void save_weights(OutputArchive & oa)
const {
206 for (
auto n : nodes_) {
211 template <
typename InputArchive>
212 void load_weights(InputArchive & ia) {
213 for (
auto n : nodes_) {
// Appends a node passed by forwarding reference. The call is forwarded to
// push_back_impl with a tag of type
// std::is_rvalue_reference<decltype(node)>::type, so an rvalue argument
// selects the std::true_type overload and an lvalue argument selects the
// std::false_type overload.
219 template <
typename T>
220 void push_back(T&& node) {
221 push_back_impl(std::forward<T>(node),
222 typename std::is_rvalue_reference<
decltype(node)>::type());
// Appends a node supplied as std::shared_ptr. The shared pointer is stored
// in own_nodes_, and the raw pointer obtained from that stored entry via
// get() is appended to nodes_.
225 template <
typename T>
226 void push_back(std::shared_ptr<T> node) {
227 own_nodes_.push_back(node);
228 nodes_.push_back(own_nodes_.back().get());
// Transposes the two outer indices of `input`: `output` is sized
// [channel_count][sample_count] and filled so that
// output[channel][sample] = input[sample][channel].
// sample_count is input.size(); channel_count is input[0].size(). Each
// sample is asserted to hold exactly channel_count entries.
// NOTE(review): input[0] is read without checking that `input` is non-empty.
234 std::vector<tensor_t> reorder_for_layerwise_processing(
const std::vector<tensor_t>& input) {
235 const serial_size_t sample_count =
static_cast<serial_size_t
>(input.size());
236 const serial_size_t channel_count =
static_cast<serial_size_t
>(input[0].size());
239 std::vector<tensor_t> output(channel_count, tensor_t(sample_count));
241 for (serial_size_t sample = 0; sample < sample_count; ++sample) {
242 assert(input[sample].size() == channel_count);
243 for (serial_size_t channel = 0; channel < channel_count; ++channel) {
244 output[channel][sample] = input[sample][channel];
// std::true_type overload (chosen by push_back for rvalue arguments).
// A new object of type std::remove_reference<T>::type is created with
// std::make_shared from the forwarded `node`, stored in own_nodes_, and its
// raw pointer is appended to nodes_.
251 template <
typename T>
252 void push_back_impl(T&& node, std::true_type) {
253 own_nodes_.push_back(std::make_shared<
254 typename std::remove_reference<T>::type>(std::forward<T>(node)));
255 nodes_.push_back(own_nodes_.back().get());
// std::false_type overload (chosen by push_back for lvalue arguments).
// Only the address of the caller's object is appended to nodes_; nothing is
// added to own_nodes_.
// NOTE(review): since just the address is kept, the referenced object
// presumably has to outlive this container - confirm with callers.
258 template <
typename T>
259 void push_back_impl(T&& node, std::false_type) {
260 nodes_.push_back(&node);
// Nodes held through std::shared_ptr; filled by the shared_ptr push_back and
// by the std::true_type push_back_impl overload.
264 std::vector<std::shared_ptr<layer>> own_nodes_;
// Every appended node as a layerptr_t, including those also stored in
// own_nodes_ (their raw pointer is pushed here as well).
266 std::vector<layerptr_t> nodes_;
281 for (
auto l = nodes_.rbegin();
l != nodes_.rend();
l++) {
// Override of forward(). In the lines visible here the function iterates
// over nodes_ (the per-node work is not shown), then copies the output() of
// the last node and returns it after passing it through normalize_out().
// How `first` is consumed is not visible in this excerpt.
286 std::vector<tensor_t>
forward(
const std::vector<tensor_t>&
first)
override {
293 for (
auto l : nodes_) {
297 const std::vector<tensor_t> out = nodes_.back()->output();
299 return normalize_out(out);
302 template <
typename T>
304 push_back(std::forward<T>(
layer));
306 if (nodes_.size() != 1) {
307 auto head = nodes_[nodes_.size()-2];
308 auto tail = nodes_[nodes_.size()-1];
310 auto out =
head->outputs();
311 auto in =
tail->inputs();
313 check_connectivity();
// Walks adjacent node pairs (nodes_[i], nodes_[i+1]) and compares the first
// element of the former's outputs() with the first element of the latter's
// inputs(). What happens on a mismatch is not visible in this excerpt.
// NOTE(review): nodes_.size() is unsigned, so `nodes_.size() - 1` wraps to a
// huge value when nodes_ is empty and the loop would index out of range;
// `i + 1 < nodes_.size()` would avoid that.
316 void check_connectivity() {
317 for (serial_size_t i = 0; i < nodes_.size() - 1; i++) {
318 auto out = nodes_[i]->outputs();
319 auto in = nodes_[i+1]->inputs();
321 if (out[0] != in[0]) {
// Connects every node to its successor in nodes_ by calling
// connect(head, tail, 0, 0) for each adjacent pair. The archive argument
// `ia` is not read in the visible lines.
// NOTE(review): nodes_.size() is unsigned, so `nodes_.size() - 1` wraps to a
// huge value when nodes_ is empty and the loop would index out of range;
// `i + 1 < nodes_.size()` would avoid that.
327 template <
typename InputArchive>
328 void load_connections(InputArchive& ia) {
329 for (serial_size_t i = 0; i < nodes_.size() - 1; i++) {
330 auto head = nodes_[i];
331 auto tail = nodes_[i + 1];
332 connect(head, tail, 0, 0);
336 template <
typename OutputArchive>
337 void save_connections(OutputArchive& )
const { }
// Re-indexes `out` so that the sample index comes first: the result has
// out[0].size() entries, each created as tensor_t(1), and
// normalized_output[sample][0] is set to out[0][sample]. Only out[0] is read.
// NOTE(review): out[0] is accessed without checking that `out` is non-empty.
342 std::vector<tensor_t> normalize_out(
const std::vector<tensor_t>& out)
345 std::vector<tensor_t> normalized_output;
347 const size_t sample_count = out[0].size();
348 normalized_output.resize(sample_count, tensor_t(1));
350 for (
size_t sample = 0; sample < sample_count; ++sample) {
351 normalized_output[sample][0] = out[0][sample];
354 return normalized_output;
369 throw nn_error(
"input size mismatch");
379 for (
auto l = nodes_.rbegin();
l != nodes_.rend();
l++) {
389 throw nn_error(
"input size mismatch");
399 for (
auto l : nodes_) {
// Builds the graph from the given input and output layers and records them
// in input_layers_ / output_layers_.
// Visible steps: a work list `input_nodes` is seeded from `input`; for each
// successor of the current node, the edge coming from the current node is
// flagged in removed_edge[successor] (one uint8_t flag per predecessor, found
// with find_index), and the successor is pushed onto the work list once the
// std::all_of test over its flags passes.
// NOTE(review): this looks like a topological ordering into `sorted`, but the
// loop header, the all_of predicate and the use of `sorted` are not among the
// visible lines - confirm against the full source.
405 void construct(
const std::vector<layerptr_t>& input,
406 const std::vector<layerptr_t>& output) {
407 std::vector<layerptr_t>
sorted;
408 std::vector<nodeptr_t>
input_nodes(input.begin(), input.end());
409 std::unordered_map<node*, std::vector<uint8_t>>
removed_edge;
417 std::vector<node*> next =
curr->next_nodes();
419 for (
size_t i = 0;
i < next.size();
i++) {
420 if (!next[
i])
continue;
424 std::vector<uint8_t>(next[
i]->prev_nodes().size(), 0);
427 std::vector<uint8_t>& removed = removed_edge[next[i]];
428 removed[find_index(next[i]->prev_nodes(), curr)] = 1;
430 if (std::all_of(removed.begin(), removed.end(), [](uint8_t x) {
432 input_nodes.push_back(next[i]);
437 for (
auto& n : sorted) {
441 input_layers_ = input;
442 output_layers_ = output;
// Serializable description of a graph's wiring: a list of
// (head, tail, head_index, tail_index) tuples plus the ids of the input and
// output nodes.
450 struct _graph_connection {
// Records the connection unless an identical tuple is already stored.
451 void add_connection(serial_size_t head, serial_size_t tail, serial_size_t head_index, serial_size_t tail_index) {
452 if (!is_connected(head, tail, head_index, tail_index)) {
453 connections.emplace_back(head, tail, head_index, tail_index);
// True when `connections` already contains exactly this tuple (linear
// search with std::find).
457 bool is_connected(serial_size_t head, serial_size_t tail, serial_size_t head_index, serial_size_t tail_index)
const {
458 return std::find(connections.begin(),
460 std::make_tuple(head, tail, head_index, tail_index)) != connections.end();
// Passes the three data members to the archive as name-value pairs.
463 template <
typename Archive>
464 void serialize(Archive & ar) {
465 ar(CEREAL_NVP(connections), CEREAL_NVP(in_nodes), CEREAL_NVP(out_nodes));
// One tuple per connection: (head, tail, head_index, tail_index).
468 std::vector<std::tuple<serial_size_t, serial_size_t, serial_size_t, serial_size_t>> connections;
// Ids of the graph's input nodes and output nodes.
469 std::vector<serial_size_t> in_nodes, out_nodes;
// Writes the graph wiring to the archive under the name "graph".
// A _graph_connection is filled with:
//   - the ids of input_layers_ and output_layers_, looked up in node2id;
//   - for every edge visited by graph_traverse starting at each input layer,
//     one connection per node in e.next(): (id of e.prev(), id of that node,
//     e.prev()->next_port(e), node->prev_port(e)).
// add_connection skips tuples that were already recorded.
// The loop over nodes_ presumably assigns ids into node2id using idx; its
// body is not among the visible lines.
472 template <
typename OutputArchive>
473 void save_connections(OutputArchive& oa)
const {
474 _graph_connection gc;
475 std::unordered_map<node*, serial_size_t> node2id;
476 serial_size_t idx = 0;
478 for (
auto n : nodes_) {
481 for (
auto l : input_layers_) {
482 gc.in_nodes.push_back(node2id[l]);
484 for (
auto l : output_layers_) {
485 gc.out_nodes.push_back(node2id[l]);
488 for (
auto l : input_layers_) {
489 graph_traverse(l, [=](layer& l) {}, [&](edge& e) {
490 auto next = e.next();
491 serial_size_t head_index = e.prev()->next_port(e);
493 for (
auto n : next) {
494 serial_size_t tail_index = n->prev_port(e);
495 gc.add_connection(node2id[e.prev()], node2id[n], head_index, tail_index);
500 oa(cereal::make_nvp(
"graph", gc));
// Reads a _graph_connection from the archive under the name "graph" and
// rebuilds the wiring: each stored tuple is unpacked and passed to
// connect(nodes_[head], nodes_[tail], head_index, tail_index); the nodes
// listed in in_nodes / out_nodes are appended to input_layers_ /
// output_layers_.
// NOTE(review): indices read from the archive are used to index nodes_
// without a range check in the visible lines.
503 template <
typename InputArchive>
504 void load_connections(InputArchive& ia) {
505 _graph_connection gc;
506 ia(cereal::make_nvp(
"graph", gc));
508 for (
auto c : gc.connections) {
509 serial_size_t head, tail, head_index, tail_index;
510 std::tie(head, tail, head_index, tail_index) = c;
511 connect(nodes_[head], nodes_[tail], head_index, tail_index);
513 for (
auto in : gc.in_nodes) {
514 input_layers_.push_back(nodes_[in]);
516 for (
auto out : gc.out_nodes) {
517 output_layers_.push_back(nodes_[out]);
// Combines the outputs of all output layers into `merged`, indexed as
// merged[sample][output_channel] = out[0][sample], where `out` is the
// output() of output_layers_[output_channel].
// `merged` is sized on the first channel to sample_count entries of
// tensor_t(output_channel_count); an assert checks that the sample count
// matches merged.size().
// NOTE(review): out[0] is read without checking that `out` is non-empty.
522 std::vector<tensor_t> merge_outs() {
523 std::vector<tensor_t> merged;
524 serial_size_t output_channel_count =
static_cast<serial_size_t
>(output_layers_.size());
525 for (serial_size_t output_channel = 0; output_channel < output_channel_count; ++output_channel) {
526 std::vector<tensor_t> out = output_layers_[output_channel]->output();
528 serial_size_t sample_count =
static_cast<serial_size_t
>(out[0].size());
529 if (output_channel == 0) {
530 assert(merged.empty());
531 merged.resize(sample_count, tensor_t(output_channel_count));
534 assert(merged.size() == sample_count);
536 for (serial_size_t sample = 0; sample < sample_count; ++sample) {
537 merged[sample][output_channel] = out[0][sample];
// Linear search: returns the index of the first element of `nodes` that
// equals `target` viewed as a node* (static_cast<node*>(&*target)).
// nn_error("invalid connection") is thrown in this function - presumably
// when no element matches (the loop's closing lines are not visible here).
// The declaration of `target` is also outside the visible lines.
543 serial_size_t find_index(
const std::vector<node*>& nodes,
545 for (serial_size_t i = 0; i < nodes.size(); i++) {
546 if (nodes[i] ==
static_cast<node*
>(&*target))
return i;
548 throw nn_error(
"invalid connection");
// Layers acting as graph inputs; assigned in construct() and appended to by
// load_connections().
550 std::vector<layerptr_t> input_layers_;
// Layers acting as graph outputs; assigned in construct() and appended to by
// load_connections().
551 std::vector<layerptr_t> output_layers_;