35#include <unordered_set>
37#include "tiny_dnn/util/util.h"
38#include "tiny_dnn/util/product.h"
39#include "tiny_dnn/util/image.h"
40#include "tiny_dnn/util/weight_init.h"
41#include "tiny_dnn/optimizers/optimizer.h"
43#include "tiny_dnn/activations/activation_function.h"
// Forward declarations so the aliases below are self-contained; they are
// redundant (but harmless) if these classes are already declared earlier.
class node;
class edge;
class layer;

// Raw pointer to a graph node (the alias itself expresses no ownership).
using nodeptr_t = node*;
// Shared handle to an edge; nodes store their edges through this type.
using edgeptr_t = std::shared_ptr<edge>;
// Raw pointer to a layer (the alias itself expresses no ownership).
using layerptr_t = layer*;
59class node :
public std::enable_shared_from_this<node> {
61 node(serial_size_t in_size, serial_size_t out_size)
62 : prev_(in_size), next_(out_size) {}
65 const std::vector<edgeptr_t>& prev()
const {
return prev_; }
66 const std::vector<edgeptr_t>& next()
const {
return next_; }
68 serial_size_t prev_port(
const edge&
e)
const {
69 auto it = std::find_if(prev_.begin(), prev_.end(),
70 [&](edgeptr_t
ep) { return ep.get() == &e; });
71 return (serial_size_t)std::distance(prev_.begin(),
it);
74 serial_size_t next_port(
const edge&
e)
const {
75 auto it = std::find_if(next_.begin(), next_.end(),
76 [&](edgeptr_t
ep) { return ep.get() == &e; });
77 return (serial_size_t)std::distance(next_.begin(),
it);
80 std::vector<node*> prev_nodes()
const;
81 std::vector<node*> next_nodes()
const;
88 mutable std::vector<edgeptr_t> prev_;
89 mutable std::vector<edgeptr_t> next_;
100 data_({vec_t(shape.size())}),
101 grad_({vec_t(shape.size())}),
// Merges the per-sample gradient vectors held in grad_ into *dst.
// *dst is first resized to grad_[0].size() and zero-filled, so this
// requires grad_ to hold at least one entry (grad_[0] is read unconditionally).
// NOTE(review): the loop header that declares `sample` and the closing
// braces are missing from this listing; presumably it iterates over every
// entry of grad_ and vectorize::reduce adds grad_[sample] element-wise into
// *dst -- confirm against the full file and vectorize::reduce.
104 void merge_grads(vec_t *
dst) {
105 dst->resize(grad_[0].size());
106 std::fill(
dst->begin(),
dst->end(),
static_cast<float_t>(0));
// fold one sample's gradient into the accumulator *dst
110 vectorize::reduce<float_t>(&grad_[
sample][0],
dst->size(), &(*
dst)[0]);
120 tensor_t* get_data() {
124 const tensor_t* get_data()
const {
128 tensor_t* get_gradient() {
132 const tensor_t* get_gradient()
const {
136 const std::vector<node*>& next()
const {
return next_; }
137 node* prev() {
return prev_; }
138 const node* prev()
const {
return prev_; }
140 const shape3d& shape()
const {
return shape_; }
141 vector_type vtype()
const {
return vtype_; }
142 void add_next_node(
node* next) { next_.push_back(next); }
// Downstream nodes appended by add_next_node(); exposed read-only by next().
150 std::vector<node*> next_;
153inline std::vector<node*> node::prev_nodes()
const {
154 std::set<node*>
sets;
155 for (
auto&
e : prev_) {
156 if (
e &&
e->prev())
sets.insert(
e->prev());
158 return std::vector<node*>(sets.begin(), sets.end());
161inline std::vector<node*> node::next_nodes()
const {
162 std::set<node*> sets;
163 for (
auto& e : next_) {
166 sets.insert(n.begin(), n.end());
169 return std::vector<node*>(sets.begin(), sets.end());
175 nodes_.push_back(
l1); nodes_.push_back(
l2);
177 std::vector<T> nodes_;
186node_tuple<std::shared_ptr<T>> operator , (std::shared_ptr<T> l1, std::shared_ptr<T> l2) {
187 return node_tuple<std::shared_ptr<T>>(l1, l2);
191node_tuple<std::shared_ptr<T>> operator , (node_tuple<std::shared_ptr<T>> lhs, std::shared_ptr<T>& rhs) {
192 lhs.nodes_.push_back(rhs);
197node_tuple<T*> operator , (node_tuple<T*> lhs, T& rhs) {
198 lhs.nodes_.push_back(&rhs);
/// Connects @p lhs to @p rhs via the two-argument connect() call and
/// returns @p rhs so that `a << b << c` chains left to right.
template <typename T, typename U>
inline std::shared_ptr<U>& operator << (std::shared_ptr<T>& lhs,
                                        std::shared_ptr<U>& rhs) {
    T* const head = lhs.get();
    U* const tail = rhs.get();
    connect(head, tail);
    return rhs;
}
209template <
typename T,
typename U>
210inline U& operator << (
const node_tuple<T>& lhs, U& rhs) {
211 for (serial_size_t i = 0; i < static_cast<serial_size_t>(lhs.nodes_.size()); i++) {
212 connect(&*lhs.nodes_[i], &*rhs, 0, i);
217template <
typename T,
typename U>
218inline node_tuple<T>& operator << (U& lhs,
const node_tuple<T>& rhs) {
219 for (serial_size_t i = 0; i < static_cast<serial_size_t>(rhs.nodes_.size()); i++) {
220 connect(&*lhs, &*rhs.nodes_[i], i, 0);
225template <
typename T,
typename U>
226inline U& operator << (
const node_tuple<T*>& lhs, U& rhs) {
227 for (serial_size_t i = 0; i < static_cast<serial_size_t>(lhs.nodes_.size()); i++) {
228 connect(lhs.nodes_[i], &rhs, 0, i);
233template <
typename T,
typename U>
234inline node_tuple<T*>& operator << (U& lhs,
const node_tuple<T*>& rhs) {
235 for (serial_size_t i = 0; i < static_cast<serial_size_t>(rhs.nodes_.size()); i++) {
236 connect(&lhs, rhs.nodes_[i], i, 0);
Class containing input/output data
Definition node.h:95
Simple image utility class.
Definition image.h:94
Base class for all kinds of NN layers
Definition layer.h:62
Base class for all kinds of tiny-dnn data
Definition node.h:59