Skip to content

Commit

Permalink
Fix break segments and overlap (#58)
Browse files Browse the repository at this point in the history
* Fix the issue of dropping cuts

* fix overlap, still has some problems

* refactor is_overlap

* Fix overlap

* release v0.10
  • Loading branch information
pkufool authored Jan 10, 2024
1 parent 98aad14 commit 6847d91
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 95 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
project(textsearch)

set(TS_VERSION "0.9")
set(TS_VERSION "0.10")

set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
Expand Down
10 changes: 6 additions & 4 deletions examples/libriheavy/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,14 @@ def get_params() -> AttributeDict:
# you can find the docs in textsearch/match.py#align_queries
"num_close_matches": 2,
"segment_length": 5000,
"reference_length_difference": 0.1,
"reference_length_difference": 0.4,
"min_matched_query_ratio": 0.33,
# parameters for splitting aligned queries
# you can find the docs in textsearch/match.py#split_aligned_queries
"preceding_context_length": 1000,
"timestamp_position": "current",
"silence_length_to_break": 0.45,
"overlap_ratio": 0.4,
"min_duration": 2,
"max_duration": 30,
"expected_duration": (5, 20),
Expand Down Expand Up @@ -188,6 +189,7 @@ def load_data(
books.append(book)

if not transcripts:
logging.warning(f"No transcripts found.")
return {}

logging.debug(f"Worker[{worker_index}] loading cuts and books done.")
Expand Down Expand Up @@ -321,6 +323,7 @@ def split(
preceding_context_length=params.preceding_context_length,
timestamp_position=params.timestamp_position,
silence_length_to_break=params.silence_length_to_break,
overlap_ratio=params.overlap_ratio,
min_duration=params.min_duration,
max_duration=params.max_duration,
expected_duration=params.expected_duration,
Expand Down Expand Up @@ -457,9 +460,7 @@ def main():
batch_cuts = []
logging.info(f"Start processing...")
for i, cut in enumerate(raw_cuts):
if len(batch_cuts) < params.batch_size:
batch_cuts.append(cut)
else:
if len(batch_cuts) >= params.batch_size:
process_one_batch(
params,
batch_cuts=batch_cuts,
Expand All @@ -469,6 +470,7 @@ def main():
)
batch_cuts = []
logging.info(f"Number of cuts have been loaded is {i}")
batch_cuts.append(cut)
if len(batch_cuts):
process_one_batch(
params,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "fasttextsearch"
version = "0.9"
version = "0.10"
authors = [
{ name="Next-gen Kaldi development team", email="wkang.pku@gmail.com" },
]
Expand Down
1 change: 1 addition & 0 deletions textsearch/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ endfunction()
if(TS_ENABLE_TESTS)
set(test_srcs
test_find_close_matches.py
test_is_overlap.py
test_levenshtein_distance.py
test_match.py
test_row_ids_to_row_splits.py
Expand Down
80 changes: 80 additions & 0 deletions textsearch/python/tests/test_is_overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
#
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To run this single test, use
#
# ctest --verbose -R is_overlap_test_py

import unittest

from textsearch.utils import is_overlap


class TestOverlap(unittest.TestCase):
    """Tests for ``textsearch.utils.is_overlap`` segment de-overlapping."""

    def test_is_overlap(self):
        # Candidate segments as [start, end] ranges (seconds), deliberately
        # containing nested, partially overlapping and disjoint intervals.
        candidates = [
            [20, 30],
            [15, 25],
            [10, 21.1],
            [1, 10],
            [60, 70],
            [65, 73],
            [68.5, 85],
            [25, 35],
            [45, 55],
            [20, 25],
            [21, 25],
            [34.5, 46.5],
            [35, 46.1],
            [25, 35],
            [26, 34],
            [44, 70.5],
        ]
        # Fix: the original annotated these with typing.List / typing.Tuple
        # without importing them, which raises NameError at runtime.  Use the
        # builtin generics (PEP 585, Python >= 3.9) instead.
        selected_ranges: list[tuple[float, float]] = []
        selected_indexes: list[int] = []
        segments = []
        # Indexes (into `segments`) of previously accepted segments that a
        # later candidate displaced; they are removed after the loop.
        overlapped_segments = []
        for r in candidates:
            status, index = is_overlap(
                selected_ranges,
                selected_indexes,
                query=(r[0], r[1]),
                segment_index=len(segments),
                overlap_ratio=0.1,
            )
            if status:
                # Overlap detected: a non-None index means the candidate
                # replaces the previously selected segment at that index;
                # otherwise the candidate itself is dropped.
                if index is not None:
                    overlapped_segments.append(index)
                    segments.append(r)
            else:
                segments.append(r)
        # Pop from the highest index first so earlier pops do not shift the
        # positions of segments still to be removed.
        for index in sorted(overlapped_segments, reverse=True):
            segments.pop(index)
        expected_segments = [
            [10, 21.1],
            [1, 10],
            [68.5, 85],
            [25, 35],
            [21, 25],
            [35, 46.1],
        ]
        assert segments == expected_segments


# Allow running this test file directly (outside of ctest).
if __name__ == "__main__":
    unittest.main()
129 changes: 63 additions & 66 deletions textsearch/python/textsearch/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,44 +145,6 @@ def _break_query(
# [(query_start, query_end, target_start, target_end)]
segments: List[Tuple[int, int, int, int]] = []

def add_segments(
query_start,
query_end,
target_start,
target_end,
segment_length,
segments,
):
num_chunk = (query_end - query_start) // segment_length
if num_chunk > 0:
for i in range(num_chunk):
real_target_end = (
target_start + segment_length
if target_start + segment_length < target_end
else target_end
)
segments.append(
(
query_start,
query_start + segment_length,
target_start,
real_target_end,
)
)
query_start += segment_length
target_start += segment_length
# if the remaining part is smaller than segment_length // 4, we will
# append it to the last segment rather than creating a new segment.
if segments and query_end - query_start < segment_length // 4:
segments[-1] = (
segments[-1][0],
query_end,
segments[-1][2],
target_end,
)
else:
segments.append((query_start, query_end, target_start, target_end))

target_doc_id = sourced_text.doc[matched_points[max_item[0]][1]]
target_base = sourced_text.doc_splits[target_doc_id]
next_target_base = sourced_text.doc_splits[target_doc_id + 1]
Expand All @@ -207,7 +169,15 @@ def add_segments(
for ind in range(max_item[0], max_item[1]):
if matched_points[ind][0] - prev_break_point[0] > segment_length:
if ind == max_item[0]:
continue
segments.append(
(
prev_break_point[0],
matched_points[ind][0],
prev_break_point[1],
matched_points[ind][1],
)
)
prev_break_point = matched_points[ind]
else:
query_start = prev_break_point[0]
query_end = matched_points[ind - 1][0]
Expand All @@ -217,17 +187,16 @@ def add_segments(
ratio = (target_end - target_start) / (query_end - query_start)
half = reference_length_difference / 2
if ratio < 1 - half or ratio > 1 + half:
logging.debug(
f"Invalid ratio for segment: "
f"{query_start, query_end, target_start, target_end}"
)
continue

prev_break_point = (query_end, target_end)
add_segments(
query_start,
query_end,
target_start,
target_end,
segment_length,
segments,
segments.append(
(query_start, query_end, target_start, target_end)
)
prev_break_point = (query_end, target_end)

query_start, target_start = prev_break_point
query_end = next_query_base
Expand All @@ -248,14 +217,7 @@ def add_segments(
else:
segments.append((query_start, query_end, target_start, target_end))
else:
add_segments(
query_start,
query_end,
target_start,
target_end,
segment_length,
segments,
)
segments.append((query_start, query_end, target_start, target_end))
return segments


Expand Down Expand Up @@ -470,15 +432,22 @@ def align_queries(
# in sourced_text
matched_points = get_longest_increasing_pairs(seq1, seq2)

if len(matched_points) == 0:
continue

# In the algorithm of `find_close_matches`,
# `sourced_text.binary_text.size - 1` means no close_matches
trim_pos = len(matched_points) - 1
while matched_points[trim_pos][1] == sourced_text.binary_text.size - 1:
trim_pos -= 1
matched_points = matched_points[0:trim_pos]
if len(matched_points) != 0:
trim_pos = len(matched_points) - 1
while (
matched_points[trim_pos][1] == sourced_text.binary_text.size - 1
):
trim_pos -= 1
matched_points = matched_points[0:trim_pos]

if len(matched_points) == 0:
logging.warning(
f"Skipping query {q}, no matched points between query and target"
f"in close_matches."
)
continue

# The following code guarantees the matched points are in the same
# reference document. We will choose the reference document that matches
Expand Down Expand Up @@ -988,6 +957,7 @@ def _split_into_segments(
preceding_context_length: int = 1000,
timestamp_position: str = "middle", # previous, middle, current
silence_length_to_break: float = 0.6, # in second
overlap_ratio: float = 0.35, # percentage
min_duration: float = 2, # in second
max_duration: float = 30, # in second
expected_duration: Tuple[float, float] = (5, 20), # in second
Expand Down Expand Up @@ -1024,6 +994,10 @@ def _split_into_segments(
preceding or succeeding silence length greater than this value, we will
add it as a possible breaking point.
Caution: Only be used when there are no punctuations in target_source.
overlap_ratio:
The ratio of overlapping part to the query or existing segments. If the
ratio is greater than `overlap_ratio` we will drop the query or existing
segment.
min_duration:
The minimum duration (in second) allowed for a segment.
max_duration:
Expand Down Expand Up @@ -1079,13 +1053,28 @@ def _split_into_segments(
# Handle the overlapping
# Caution: Don't modify selected_ranges; it will be manipulated in
# `is_overlap` and will always be kept sorted.
selected_ranges: List[Tuple[int, int]] = []
# Don't modify selected_indexes either; it will be manipulated in `is_overlap`
# according to selected_ranges.
selected_ranges: List[Tuple[float, float]] = []
selected_indexes: List[int] = []
segments = []
overlapped_segments = []
for r in candidates:
if not is_overlap(
selected_ranges, query=(r[0], r[1]), overlap_ratio=0.5
):
status, index = is_overlap(
selected_ranges,
selected_indexes,
query=(aligns[r[0]]["hyp_time"], aligns[r[1]]["hyp_time"]),
segment_index=len(segments),
overlap_ratio=overlap_ratio,
)
if status:
if index is not None:
overlapped_segments.append(index)
segments.append(r)
else:
segments.append(r)
for index in sorted(overlapped_segments, reverse=True):
segments.pop(index)

results = []

Expand Down Expand Up @@ -1217,6 +1206,7 @@ def _split_helper(
preceding_context_length: int,
timestamp_position: str,
silence_length_to_break: float,
overlap_ratio: float,
min_duration: float,
max_duration: float,
expected_duration: Tuple[float, float],
Expand All @@ -1233,6 +1223,7 @@ def _split_helper(
preceding_context_length=preceding_context_length,
timestamp_position=timestamp_position,
silence_length_to_break=silence_length_to_break,
overlap_ratio=overlap_ratio,
min_duration=min_duration,
max_duration=max_duration,
expected_duration=expected_duration,
Expand All @@ -1250,6 +1241,7 @@ def split_aligned_queries(
preceding_context_length: int = 1000,
timestamp_position: str = "current", # previous, middle, current
silence_length_to_break: float = 0.6, # in second
overlap_ratio: float = 0.35,
min_duration: float = 2, # in second
max_duration: float = 30, # in second
expected_duration: Tuple[float, float] = (5, 20), # in second
Expand Down Expand Up @@ -1288,6 +1280,10 @@ def split_aligned_queries(
preceding or succeeding silence length greater than this value, we will
add it as a possible breaking point.
Caution: Only be used when there are no punctuations in target_source.
overlap_ratio:
The ratio of overlapping part to the query or existing segments. If the
ratio is greater than `overlap_ratio` we will drop the query or existing
segment.
min_duration:
The minimum duration (in second) allowed for a segment.
max_duration:
Expand Down Expand Up @@ -1342,6 +1338,7 @@ def split_aligned_queries(
preceding_context_length,
timestamp_position,
silence_length_to_break,
overlap_ratio,
min_duration,
max_duration,
expected_duration,
Expand Down
Loading

0 comments on commit 6847d91

Please sign in to comment.