65 if (OpKernel::device_ !=
nullptr) {
66 auto params = OpKernel::params_->conv();
67 init_libdnn(OpKernel::device_, params);
75 const tensor_t&
W =
context.input(1);
76 const tensor_t& bias =
context.input(2);
87 CLCudaAPI::Context
ctx = OpKernel::device_->context();
88 CLCudaAPI::Queue
queue = OpKernel::device_->queue();
90 for (serial_size_t
i = 0;
i <
in_data.size(); ++
i) {
98 W[0].begin(),
W[0].end());
101 bias[0].begin(), bias[0].end());
166 std::vector<float_t> out(
out_data[
i].size(), 0);
179 std::copy(std::begin(out), std::end(out), std::begin(
out_data[
i]));
183 throw nn_error(
"TinyDNN was not compiled with LibDNN support.");
195 return reinterpret_cast<const float_t*
>(
202 assert(device !=
nullptr);
205 greentea::device::setupViennaCLContext(device->deviceId(),
206 device->context()(), device->device()(), device->queue()());
209 std::make_shared<greentea::device>(
215 greentea::Backend::BACKEND_OpenCL
217 greentea::Backend::BACKEND_CUDA
219 greentea::Backend::BACKEND_CPU
227 greentea::LibDNNConfig
config;
// Padding deltas between the padded and unpadded input extents. They are kept
// as int32_t because `pad` below is a std::vector<int32_t> brace-initialised
// from dy/2 and dx/2; floating-point operands there would be a narrowing
// conversion inside list-initialisation. Integer division yields the same
// truncated half for non-negative deltas.
// NOTE(review): assumes in_padded extents are >= in extents - confirm.
233 const int32_t dy = static_cast<int32_t>(params.in_padded.height_ - params.in.height_);
234 const int32_t dx = static_cast<int32_t>(params.in_padded.width_ - params.in.width_);
236 std::vector<int32_t> in_shape = {
243 std::vector<int32_t> out_shape = {
250 std::vector<int32_t>
kernel = {
251 params.weight.height_,
255 std::vector<int32_t>
pad = {
dy/2,
dx/2 };
257 std::vector<int32_t>
stride = {
262 std::vector<int32_t>
dilation = { 1, 1 };
264 config.in_shape = in_shape;
265 config.out_shape = out_shape;
272 config.bias_term = params.has_bias;
275 config.fast_unsafe_math =
false;
277 config.weights_backward =
false;
279 config.bias_backward =
false;
283 if (std::is_same<float_t, float>::value ||
284 dev_ptr_->CheckCapability(
"cl_khr_int64_base_atomics")) {
285 config.wgalgo = greentea::LIBDNN_CONVOLUTION_WG_ALGO_ATOMIC;
286 config.bwalgo = greentea::LIBDNN_CONVOLUTION_BW_ALGO_COL2IM_ATOMIC;
288 config.wgalgo = greentea::LIBDNN_CONVOLUTION_WG_ALGO_DIRECT;
289 config.bwalgo = greentea::LIBDNN_CONVOLUTION_BW_ALGO_IM2COL;
299 std::shared_ptr<greentea::device>
dev_ptr_;
300 std::shared_ptr<greentea::LibDNNConv<float_t> >
kernel_;
Definition op_kernel.h:72