Skip to content

Commit

Permalink
Merge pull request #400 from ksahlin/fix-bindings
Browse files Browse the repository at this point in the history
Fix StrobemerIndex.find() in Python bindings
  • Loading branch information
marcelm committed Mar 1, 2024
2 parents d420610 + 4141d21 commit 79cdd96
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ requires = [
"ninja; platform_system!='Windows'",
"nanobind>=0.2.0",
]

[tool.pytest.ini_options]
testpaths = ["tests"]
11 changes: 5 additions & 6 deletions src/python/strobealign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ NB_MODULE(strobealign_extension, m_) {
.def_ro("syncmer", &IndexParameters::syncmer)
.def_ro("randstrobe", &IndexParameters::randstrobe)
;
nb::class_<RefRandstrobe>(m, "RefRandstrobeWithHash", "Randstrobe on a reference")
nb::class_<RefRandstrobe>(m, "RefRandstrobe", "Randstrobe on a reference")
.def_ro("position", &RefRandstrobe::position)
.def_ro("hash", &RefRandstrobe::hash)
.def_prop_ro("reference_index", &RefRandstrobe::reference_index)
.def_prop_ro("strobe2_offset", &RefRandstrobe::strobe2_offset)
;
Expand All @@ -119,11 +120,9 @@ NB_MODULE(strobealign_extension, m_) {
.def("find", [](const StrobemerIndex& index, uint64_t key) -> std::vector<RefRandstrobe> {
std::vector<RefRandstrobe> v;
auto position = index.find(key);
if (position != index.end()) {
/*while (index.randstrobes[position].hash == key) {
v.push_back(index.randstrobes[position]);
position++;
}*/
while (position != index.end() && index.get_hash(position) == key) {
v.push_back(index.get_randstrobe(position));
position++;
}
return v;
})
Expand Down
21 changes: 21 additions & 0 deletions tests/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,24 @@ def test_indexing_and_nams_finding():
ref_aligned = ref[nam.ref_start:nam.ref_end]
query_aligned = query[nam.query_start:nam.query_end]
score = nam.score


def test_index_find():
refs = strobealign.References.from_fasta("tests/phix.fasta")
index_parameters = strobealign.IndexParameters.from_read_length(100)
index = strobealign.StrobemerIndex(refs, index_parameters)
index.populate()

query = "TGCGTTTATGGTACGCTGGACTTTGTGGGATACCCTCGCTTTCCTGCTCCTGTTGAGTTTATTGCTGCCG"
query_randstrobes = strobealign.randstrobes_query(query, index_parameters)
assert query_randstrobes
# First randstrobe must be found
assert index.find(query_randstrobes[0].hash)

n = 0
for qr in query_randstrobes:
for rs in index.find(qr.hash):
n += 1
assert rs.hash == qr.hash
# Ensure the for loop did test something
assert n > 1

0 comments on commit 79cdd96

Please sign in to comment.