33#include <cereal/archives/json.hpp>
34#include <cereal/types/memory.hpp>
35#include "tiny_dnn/util/nn_error.h"
36#include "tiny_dnn/util/macro.h"
37#include "tiny_dnn/layers/layers.h"
41template <
typename InputArchive>
44 void register_loader(
const std::string& name, std::function<std::shared_ptr<layer>(
InputArchive&)>
func) {
45 loaders_[name] = [=](
void*
ar) {
51 void register_type(
const std::string& name) {
52 type_names_[
typeid(
T)] = name;
58 if (loaders_.find(
layer_name) == loaders_.end()) {
59 throw nn_error(
"Failed to generate layer. Generator for " +
layer_name +
" is not found.\n"
60 "Please use CNN_REGISTER_LAYER_DESERIALIZER macro to register appropriate generator");
63 return loaders_[
layer_name](
reinterpret_cast<void*
>(&
ar));
66 const std::string& type_name(std::type_index index)
const {
67 if (type_names_.find(index) == type_names_.end()) {
68 throw nn_error(
"Typename is not registered");
70 return type_names_.at(index);
79 void check_if_enabled()
const {
80#ifdef CNN_NO_SERIALIZATION
82 "You are using load functions, but deserialization function is disabled in current configuration.\n\n"
83 "You need to undef CNN_NO_SERIALIZATION to enable these functions.\n"
84 "If you are using cmake, you can use -DUSE_SERIALIZER=ON option.\n\n");
89 std::map<std::string, std::function<std::shared_ptr<layer>(
void*)>> loaders_;
91 std::map<std::type_index, std::string> type_names_;
// Core registration step shared by the CNN_REGISTER_LAYER* macros.
// Records `layer_type` in both lookup tables under the key `layer_name`:
//   - type_names_: typeid(layer_type) -> layer_name (via register_type)
//   - loaders_:    layer_name -> load_layer_impl<layer_type> (via register_loader)
// The two insertions touch different maps, so their order does not matter.
// NOTE: expands to two bare statements (no do/while(0) wrapper); use it only
// as a standalone statement, never as the body of an unbraced if/else.
#define CNN_REGISTER_LAYER_BODY(layer_type, layer_name) \
  register_type<layer_type>(layer_name);                \
  register_loader(layer_name, load_layer_impl<layer_type>);
// Registers `layer_type` under the stringized `layer_name` token,
// e.g. CNN_REGISTER_LAYER(foo_layer, foo) uses the key "foo".
#define CNN_REGISTER_LAYER(layer_type, layer_name) \
  CNN_REGISTER_LAYER_BODY(layer_type, #layer_name)
// Registers the single instantiation layer_type<activation::activation_type>.
// Its key is assembled from adjacent string literals, so for example
// (conv_layer, relu, conv) is registered as "conv<relu>".
#define CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, activation_type, \
                                           layer_name)                  \
  CNN_REGISTER_LAYER_BODY(layer_type<activation::activation_type>,      \
                          #layer_name "<" #activation_type ">")
// Registers one instantiation of the templated `layer_type` for each
// activation listed below. Any layer/activation pair missing from this list
// cannot be looked up later: load_layer throws nn_error ("Generator for ...
// is not found") and type_name throws nn_error ("Typename is not registered").
// The final expansion has no trailing ';', so the caller supplies it.
#define CNN_REGISTER_LAYER_WITH_ACTIVATIONS(layer_type, layer_name)       \
  CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, tan_h, layer_name);      \
  CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, softmax, layer_name);    \
  CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, identity, layer_name);   \
  CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, sigmoid, layer_name);    \
  CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, relu, layer_name);       \
  CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, leaky_relu, layer_name); \
  CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, elu, layer_name);        \
  CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, tan_hp1m2, layer_name)
116#include "serialization_layer_list.h"
119#undef CNN_REGISTER_LAYER_BODY
120#undef CNN_REGISTER_LAYER
121#undef CNN_REGISTER_LAYER_WITH_ACTIVATION
122#undef CNN_REGISTER_LAYER_WITH_ACTIVATIONS
126template <
typename InputArchive>
130 using ST =
typename std::aligned_storage<
sizeof(
T), CNN_ALIGNOF(
T)>::type;
132 std::unique_ptr<ST>
bn(
new ST());
134 cereal::memory_detail::LoadAndConstructLoadWrapper<InputArchive, T>
wrapper(
reinterpret_cast<T*
>(
bn.get()));
136 wrapper.CEREAL_SERIALIZE_FUNCTION_NAME(
ia);
138 std::shared_ptr<layer>
t;
139 t.reset(
reinterpret_cast<T*
>(
bn.get()));
146void start_loading_layer(T & ar) {}
149void finish_loading_layer(T & ar) {}
151inline void start_loading_layer(cereal::JSONInputArchive & ia) { ia.startNode(); }
153inline void finish_loading_layer(cereal::JSONInputArchive & ia) { ia.finishNode(); }
158template <
typename InputArchive>
160 start_loading_layer(
ia);
163 ia(cereal::make_nvp(
"type",
p));
166 finish_loading_layer(
ia);
Definition deserialization_helper.h:42
Simple image utility class.
Definition image.h:94
static std::shared_ptr< layer > load_layer(InputArchive &ia)
generate layer from cereal's Archive
Definition deserialization_helper.h:159
error exception class for tiny-dnn
Definition nn_error.h:37