Skip to content

Commit

Permalink
增加返回权重的方法
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Feb 10, 2024
1 parent e6a99ac commit bc6605e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/layer/abstract/param_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ class ParamLayer : public Layer<float> {
*/
void set_bias(const std::vector<std::shared_ptr<Tensor<float>>>& bias) override;

std::shared_ptr<Tensor<float>> weight(int32_t index) const;

protected:
std::vector<std::shared_ptr<Tensor<float>>> weights_;
std::vector<std::shared_ptr<Tensor<float>>> bias_;
Expand Down
5 changes: 5 additions & 0 deletions source/layer/abstract/param_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ void ParamLayer::set_weights(const std::vector<float>& weights) {
}
}

std::shared_ptr<Tensor<float>> ParamLayer::weight(int32_t index) const {
CHECK_LE(index, this->weights_.size());
return this->weights_.at(index);
}

void ParamLayer::set_bias(const std::vector<float>& bias) {
size_t bias_size = 0;
const size_t elem_size = bias.size();
Expand Down

0 comments on commit bc6605e

Please sign in to comment.