1 #ifndef ROCKY_ETNA_LINEAR
2 #define ROCKY_ETNA_LINEAR
/// Selects whether a layer carries a bias term.
///   bias    - the layer's parameter buffer holds the weight matrix followed by
///             one bias row that is added to every output row.
///   no_bias - the parameter buffer holds the weight matrix only.
/// Kept as an unscoped enum so both `bias` and `opt::bias` spellings keep working.
enum opt {bias, no_bias};
// Template parameters of the single dense layer. The class-head line that would
// follow this template<...> is not present in this file — NOTE(review): restore it.
//   T_e        - scalar type of weights, inputs and outputs.
//   T_in_num   - number of input rows handled per feed() call (rows of In_/Out_).
//   T_in_dim   - features per input row.
//   T_out_dim  - features per output row.
//   T_opt_bias - opt::bias: a bias row is stored right after the weights and
//                added to every output row; opt::no_bias: weights only.
16 template<
typename T_e,
int T_in_num,
int T_in_dim,
int T_out_dim,
17 opt T_opt_bias=opt::bias>
// Number of scalars in the weight matrix (T_in_dim x T_out_dim), evaluated at
// compile time. NOTE(review): the closing brace of this function is missing here.
20 static constexpr
int deduce_num_params_weights(){
21 return T_in_dim * T_out_dim;
// Number of scalars in the bias vector, evaluated at compile time.
// NOTE(review): only the `if constexpr` test is visible; the return values are
// missing. Presumably T_out_dim for opt::bias (feed() maps a 1 x T_out_dim bias
// row) and 0 for opt::no_bias — TODO confirm.
23 static constexpr
int deduce_num_params_bias(){
24 if constexpr(T_opt_bias == opt::bias)
// Total scalars this layer consumes from its parameter buffer: the weight matrix
// followed by the (optional) bias row. Callers use this to step to the next layer.
// NOTE(review): the closing brace of this function is missing here.
29 static constexpr
int deduce_num_params(){
30 return deduce_num_params_weights() + deduce_num_params_bias();
// Applies the layer to a batch, operating directly on caller-owned buffers via
// Eigen::Map (no copies).
//   layer_mem_ptr - parameters: T_in_dim x T_out_dim weights (row-major), followed,
//                   when T_opt_bias == opt::bias, by T_out_dim bias values.
//   in_mem_ptr    - input, T_in_num x T_in_dim, row-major.
//   out_mem_ptr   - output, T_in_num x T_out_dim, row-major; written in place.
// NOTE(review): the line(s) that fill Out_ from In_ and W_ (presumably
// Out_ = In_ * W_) are missing from this file — confirm.
// NOTE(review): Eigen static-asserts on RowMajor storage for compile-time column
// vectors, so T_in_dim == 1 or T_out_dim == 1 (with more than one row) may fail to
// compile; for a single column the layouts coincide, so ColMajor could be chosen
// in that case — verify.
35 void feed(T_e* layer_mem_ptr, T_e* in_mem_ptr, T_e* out_mem_ptr){
36 Eigen::Map<Eigen::Matrix<T_e, T_in_dim, T_out_dim, Eigen::RowMajor>> W_(layer_mem_ptr);
37 Eigen::Map<Eigen::Matrix<T_e, T_in_num, T_in_dim, Eigen::RowMajor>> In_(in_mem_ptr);
38 Eigen::Map<Eigen::Matrix<T_e, T_in_num, T_out_dim, Eigen::RowMajor>> Out_(out_mem_ptr);
// The bias row sits immediately after the weight matrix in the parameter buffer
// and is broadcast-added to every output row.
41 if constexpr (T_opt_bias == opt::bias){
42 Eigen::Map<Eigen::Matrix<T_e, 1, T_out_dim, Eigen::RowMajor>> Bias_(layer_mem_ptr + T_in_dim * T_out_dim);
43 Out_.rowwise() += Bias_;
// Template parameters of the multi-layer stack (input layer, T_layers_num hidden
// layers, output layer). The class-head line and the member layer declarations
// (l_in, l_hidden, l_out) are not present in this file — NOTE(review): restore them.
//   T_e          - scalar type.
//   T_layers_num - number of hidden-to-hidden layers applied by feed().
//   T_in_num     - rows per batch.
//   T_in_dim     - input features per row.
//   T_out_dim    - output features per row.
//   T_hidden_dim - width of the intermediate activations (scratch buffers are
//                  T_in_num * T_hidden_dim elements).
//   T_opt_bias   - bias option; presumably forwarded to every layer — TODO confirm.
48 template<
typename T_e,
int T_layers_num,
49 int T_in_num,
int T_in_dim,
50 int T_out_dim,
int T_hidden_dim,
51 opt T_opt_bias=opt::bias>
// Parameter count of the input layer. NOTE(review): the body is missing from this
// file; presumably it forwards to the T_in_dim -> T_hidden_dim layer's
// deduce_num_params() — TODO confirm.
54 static constexpr
int deduce_num_params_in(){
// Parameter count of ONE hidden layer (deduce_num_params() multiplies it by
// T_layers_num). NOTE(review): the body is missing from this file; presumably a
// T_hidden_dim -> T_hidden_dim layer — TODO confirm.
57 static constexpr
int deduce_num_params_hidden(){
// Parameter count of the output layer. NOTE(review): the body is missing from this
// file; presumably a T_hidden_dim -> T_out_dim layer — TODO confirm.
60 static constexpr
int deduce_num_params_out(){
// Total parameter count of the stack: input layer + T_layers_num hidden layers +
// output layer. This is the size of the buffer feed() walks through.
// NOTE(review): the closing brace of this function is missing here.
63 static constexpr
int deduce_num_params(){
64 return T_layers_num * deduce_num_params_hidden() + deduce_num_params_in() + deduce_num_params_out();
// Runs the whole stack: input layer -> T_layers_num hidden layers -> output layer.
//   layer_mem_ptr - all layer parameters packed back to back, in that order.
//   in_mem_ptr    - batch input, handed to l_in.
//   out_mem_ptr   - batch output, written by l_out.
// Declarations of l_in/l_hidden/l_out and src/dest are on lines missing from this
// file — NOTE(review): confirm them.
75 void feed(T_e* layer_mem_ptr, T_e* in_mem_ptr, T_e* out_mem_ptr){
// Two scratch buffers for the hidden activations, allocated with new[] on every call.
// NOTE(review): no matching delete[] is visible here; if it is absent, each call
// leaks both buffers (std::vector<T_e> or std::unique_ptr<T_e[]> would avoid that).
81 T_e* H1_ =
new T_e[T_in_num * T_hidden_dim];
82 T_e* H2_ =
new T_e[T_in_num * T_hidden_dim];
// Input layer writes its activations into H1_.
84 l_in.
feed(layer_mem_ptr, in_mem_ptr, H1_);
// offset = start of the next layer's parameters inside layer_mem_ptr.
87 int offset = l_in.deduce_num_params();
// Hidden layers alternate between the two buffers: even iterations read H1_ and
// write H2_, odd iterations read H2_ and write H1_.
88 for (
int hidden=0; hidden<T_layers_num; hidden++){
89 if (hidden % 2 == 0){ src = H1_; dest = H2_;}
90 else{ src = H2_; dest = H1_;}
91 l_hidden.
feed(layer_mem_ptr + offset, src, dest);
92 offset += l_hidden.deduce_num_params();
// After an even number of hidden layers (including zero) the latest activations
// are in H1_; after an odd number they are in H2_.
// NOTE(review): the `else` between the two l_out.feed calls is not visible here;
// without it the H2_ call would run unconditionally and overwrite out_mem_ptr.
95 if constexpr (T_layers_num % 2 == 0)
96 l_out.
feed(layer_mem_ptr + offset, H1_, out_mem_ptr);
98 l_out.
feed(layer_mem_ptr + offset, H2_, out_mem_ptr);