From bc6605e6019b909f456b8388d97ab8d3dc4767f2 Mon Sep 17 00:00:00 2001 From: zjhellofss Date: Sat, 10 Feb 2024 12:25:53 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E8=BF=94=E5=9B=9E=E6=9D=83?= =?UTF-8?q?=E9=87=8D=E7=9A=84=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/layer/abstract/param_layer.hpp | 2 ++ source/layer/abstract/param_layer.cpp | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/include/layer/abstract/param_layer.hpp b/include/layer/abstract/param_layer.hpp index 5048577f..1c496798 100644 --- a/include/layer/abstract/param_layer.hpp +++ b/include/layer/abstract/param_layer.hpp @@ -122,6 +122,8 @@ class ParamLayer : public Layer { */ void set_bias(const std::vector>>& bias) override; + std::shared_ptr> weight(int32_t index) const; + protected: std::vector>> weights_; std::vector>> bias_; diff --git a/source/layer/abstract/param_layer.cpp b/source/layer/abstract/param_layer.cpp index be003b7c..50cf846c 100644 --- a/source/layer/abstract/param_layer.cpp +++ b/source/layer/abstract/param_layer.cpp @@ -94,6 +94,11 @@ void ParamLayer::set_weights(const std::vector& weights) { } } +std::shared_ptr> ParamLayer::weight(int32_t index) const { + CHECK_LE(index, this->weights_.size()); + return this->weights_.at(index); +} + void ParamLayer::set_bias(const std::vector& bias) { size_t bias_size = 0; const size_t elem_size = bias.size();