tiny_dnn 1.0.0
A header-only, dependency-free deep learning framework in C++11
Loading...
Searching...
No Matches
conv_params.h
1/*
2 Copyright (c) 2016, Taiga Nomi, Edgar Riba
3 All rights reserved.
4
5 Redistribution and use in source and binary forms, with or without
6 modification, are permitted provided that the following conditions are met:
7 * Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9 * Redistributions in binary form must reproduce the above copyright
10 notice, this list of conditions and the following disclaimer in the
11 documentation and/or other materials provided with the distribution.
12 * Neither the name of the <organization> nor the
13 names of its contributors may be used to endorse or promote products
14 derived from this software without specific prior written permission.
15
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
17 EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
20 DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*/
27#pragma once
28
#include <utility>

#include "params.h"
30
31namespace tiny_dnn {
32namespace core {
33
// Per-sample pointers to input vectors. NOTE(review): presumably each entry
// points either at the caller's input or at a padded copy held in
// prev_out_buf_ — TODO confirm against the layer that fills this storage.
std::vector<const vec_t*> prev_out_padded_;
// Per-sample owned buffers; presumably backing storage for padded inputs —
// TODO confirm.
std::vector<vec_t> prev_out_buf_;
// Per-sample delta buffers in the padded layout; presumably later reduced to
// the unpadded shape (e.g. via Conv2dPadding::copy_and_unpad_delta) — TODO
// confirm.
std::vector<vec_t> prev_delta_padded_;
38};
39
41 connection_table() : rows_(0), cols_(0) {}
42 connection_table(const bool *ar, serial_size_t rows, serial_size_t cols)
43 : connected_(rows * cols), rows_(rows), cols_(cols) {
44 std::copy(ar, ar + rows * cols, connected_.begin());
45 }
46 connection_table(serial_size_t ngroups, serial_size_t rows, serial_size_t cols)
47 : connected_(rows * cols, false), rows_(rows), cols_(cols) {
48 if (rows % ngroups || cols % ngroups) {
49 throw nn_error("invalid group size");
50 }
51
52 serial_size_t row_group = rows / ngroups;
53 serial_size_t col_group = cols / ngroups;
54
55 serial_size_t idx = 0;
56
57 for (serial_size_t g = 0; g < ngroups; g++) {
58 for (serial_size_t r = 0; r < row_group; r++) {
59 for (serial_size_t c = 0; c < col_group; c++) {
60 idx = (r + g * row_group) * cols_ + c + g * col_group;
61 connected_[idx] = true;
62 }
63 }
64 }
65 }
66
67 bool is_connected(serial_size_t x, serial_size_t y) const {
68 return is_empty() ? true : connected_[y * cols_ + x];
69 }
70
71 bool is_empty() const {
72 return rows_ == 0 && cols_ == 0;
73 }
74
75 template <typename Archive>
76 void serialize(Archive & ar) {
77 ar(cereal::make_nvp("rows", rows_), cereal::make_nvp("cols", cols_));
78
79 if (is_empty()) {
80 ar(cereal::make_nvp("connection", std::string("all")));
81 }
82 else {
83 ar(cereal::make_nvp("connection", connected_));
84 }
85 }
86
// Connection flags stored row-major: entry (row y, column x) lives at
// index y * cols_ + x. Left empty by the default constructor.
std::deque<bool> connected_;
// Number of rows; rows_ == 0 together with cols_ == 0 marks an empty table.
serial_size_t rows_;
// Number of columns (also the row stride into connected_).
serial_size_t cols_;
90};
91
92class conv_params : public Params {
93 public:
// Shape of the input after padding. Conv2dPadding sizes padded buffers with
// in_padded.size() and indexes them through in_padded.get_index().
index3d<serial_size_t> in_padded;
// Printed as "has_bias"; presumably whether the convolution adds a bias
// term — TODO confirm against the layer implementation.
bool has_bias;
// Padding mode; with padding::valid, Conv2dPadding performs no padding or
// unpadding at all.
padding pad_type;
// Presumably the horizontal stride of the convolution — TODO confirm.
serial_size_t w_stride;
// Presumably the vertical stride of the convolution — TODO confirm.
serial_size_t h_stride;
103
104 friend std::ostream& operator<<(std::ostream &o,
105 const core::conv_params& param) {
106 o << "in: " << param.in << "\n";
107 o << "out: " << param.out << "\n";
108 o << "in_padded: " << param.in_padded << "\n";
109 o << "weight: " << param.weight << "\n";
110 o << "has_bias: " << param.has_bias << "\n";
111 o << "w_stride: " << param.w_stride << "\n";
112 o << "h_stride: " << param.h_stride << "\n";
113 return o;
114 }
115};
116
117inline conv_params Params::conv() const {
118 return *(static_cast<const conv_params*>(this));
119}
120
122 public:
123 Conv2dPadding() {}
124 Conv2dPadding(const conv_params& params) : params_(params) {}
125
126 /* Applies padding to an input tensor given the convolution parameters
127 *
128 * @param in The input tensor
129 * @param out The output tensor with padding applied
130 */
131 void copy_and_pad_input(const tensor_t& in, tensor_t& out) {
132 if (params_.pad_type == padding::valid) {
133 return;
134 }
135
136 tensor_t buf(in.size());
137
138 for_i(true, buf.size(), [&](int sample) {
139 // alloc temporary buffer.
140 buf[sample].resize(params_.in_padded.size());
141
142 // make padded version in order to avoid corner-case in fprop/bprop
143 for (serial_size_t c = 0; c < params_.in.depth_; c++) {
144 float_t* pimg = &buf[sample][params_.in_padded.get_index(
145 params_.weight.width_ / 2,
146 params_.weight.height_ / 2, c)];
147 const float_t* pin = &in[sample][params_.in.get_index(0, 0, c)];
148
149 for (serial_size_t y = 0; y < params_.in.height_; y++) {
150 std::copy(pin, pin + params_.in.width_, pimg);
151 pin += params_.in.width_;
152 pimg += params_.in_padded.width_;
153 }
154 }
155 });
156
157 // shrink buffer to output
158 out = buf;
159 }
160
161 /* Applies unpadding to an input tensor given the convolution parameters
162 *
163 * @param in The input tensor
164 * @param out The output tensor with unpadding applied
165 */
166 void copy_and_unpad_delta(const tensor_t& delta, tensor_t& delta_unpadded) {
167 if (params_.pad_type == padding::valid) {
168 return;
169 }
170
171 tensor_t buf(delta.size());
172
173 for_i(true, buf.size(), [&](int sample) {
174 // alloc temporary buffer.
175 buf[sample].resize(params_.in.size());
176
177 for (serial_size_t c = 0; c < params_.in.depth_; c++) {
178 const float_t *pin =
179 &delta[sample][params_.in_padded.get_index(
180 params_.weight.width_ / 2,
181 params_.weight.height_ / 2, c)];
182 float_t *pdst = &buf[sample][params_.in.get_index(0, 0, c)];
183
184 for (serial_size_t y = 0; y < params_.in.height_; y++) {
185 std::copy(pin, pin + params_.in.width_, pdst);
186 pdst += params_.in.width_;
187 pin += params_.in_padded.width_;
188 }
189 }
190 });
191
192 // shrink buffer to output
193 delta_unpadded = buf;
194 }
195
private:
  // Copy of the convolution parameters given at construction; supplies every
  // shape (in, in_padded, weight) and the pad_type used by the padding and
  // unpadding routines.
  conv_params params_;
198};
199
200} // namespace core
201} // namespace tiny_dnn
Definition conv_params.h:121
Definition params.h:37
Definition conv_params.h:92
Simple image utility class.
Definition image.h:94
Exception class for tiny-dnn errors.
Definition nn_error.h:37
Definition conv_params.h:40