Skip to content

Commit

Permalink
Fix torch api about getting and setting tensor attributes (#1115)
Browse files Browse the repository at this point in the history
* Fox torch api about getting/setting tensor attributes

* Release v1.23.2
  • Loading branch information
csukuangfj authored Nov 25, 2022
1 parent df515ed commit 1feafa0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ message(STATUS "Enabled languages: ${languages}")

project(k2 ${languages})

set(K2_VERSION "1.23.1")
set(K2_VERSION "1.23.2")

# ----------------- Supported build types for K2 project -----------------
set(K2_ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel)
Expand Down
19 changes: 18 additions & 1 deletion k2/torch/csrc/torch_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ std::vector<std::vector<int32_t>> BestPath(const FsaClassPtr &lattice) {
return aux_labels_vec;
}

FsaClassPtr ShortestPath(const FsaClassPtr &lattice) {
FsaClass path = ShortestPath(*lattice);
return std::make_shared<FsaClass>(path);
}

void ScaleTensorAttribute(FsaClassPtr &fsa, float scale,
const std::string &attribute) {
if (attribute == "scores") {
Expand All @@ -136,14 +141,26 @@ void ScaleTensorAttribute(FsaClassPtr &fsa, float scale,
}

torch::Tensor GetTensorAttr(FsaClassPtr &fsa, const std::string &attribute) {
if (attribute == "labels") {
return fsa->Labels();
} else if (attribute == "scores") {
return fsa->Scores();
}

K2_CHECK(fsa->HasTensorAttr(attribute))
<< "The given Fsa doesn't has the attribute : " << attribute;
return fsa->GetTensorAttr(attribute);
}

void SetTensorAttr(FsaClassPtr &fsa, const std::string &attribute,
torch::Tensor value) {
fsa->SetTensorAttr(attribute, value);
if (attribute == "labels") {
fsa->SetLabels(value);
} else if (attribute == "scores") {
fsa->SetScores(value);
} else {
fsa->SetTensorAttr(attribute, value);
}
}

// A wrapper for RnntDecodingStream which can connect the RnntDecodingStream
Expand Down
6 changes: 6 additions & 0 deletions k2/torch/csrc/torch_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,12 @@ FsaClassPtr GetLattice(torch::Tensor log_softmax_out,
*/
std::vector<std::vector<int32_t>> BestPath(const FsaClassPtr &lattice);

/* Return the best path of a lattice.
*
* Different from `BestPath`, this function returns a lattice.
*/
FsaClassPtr ShortestPath(const FsaClassPtr &lattice);

/** Scale the given attribute for a Fsa.
*
* Note: Support only float type attributes.
Expand Down

0 comments on commit 1feafa0

Please sign in to comment.