diff --git a/include/layer/abstract/param_layer.hpp b/include/layer/abstract/param_layer.hpp index 5048577f..1c496798 100644 --- a/include/layer/abstract/param_layer.hpp +++ b/include/layer/abstract/param_layer.hpp @@ -122,6 +122,8 @@ class ParamLayer : public Layer { */ void set_bias(const std::vector>>& bias) override; + std::shared_ptr> weight(int32_t index) const; + protected: std::vector>> weights_; std::vector>> bias_; diff --git a/source/layer/abstract/param_layer.cpp b/source/layer/abstract/param_layer.cpp index be003b7c..50cf846c 100644 --- a/source/layer/abstract/param_layer.cpp +++ b/source/layer/abstract/param_layer.cpp @@ -94,6 +94,11 @@ void ParamLayer::set_weights(const std::vector& weights) { } } +std::shared_ptr> ParamLayer::weight(int32_t index) const { + CHECK_LE(index, this->weights_.size()); + return this->weights_.at(index); +} + void ParamLayer::set_bias(const std::vector& bias) { size_t bias_size = 0; const size_t elem_size = bias.size();