Skip to content

Commit

Permalink
Merge pull request caikit#336 from evaline-ju/int-threshold
Browse files Browse the repository at this point in the history
🥅 Allow int thresholds
  • Loading branch information
evaline-ju authored Mar 13, 2024
2 parents 01526aa + f2f8380 commit f55b082
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
At this time this module is only designed for inference"""

# Standard
from typing import Iterable, List, Optional
from typing import Iterable, List, Optional, Union
import os

# First Party
Expand Down Expand Up @@ -122,22 +122,24 @@ def __init__(

@TokenClassificationTask.taskmethod()
def run(
self, text: str, threshold: Optional[float] = None
self, text: str, threshold: Optional[Union[float, int]] = None
) -> TokenClassificationResults:
"""Run classification on text split into spans. Returns results
based on score threshold for labels that are to be outputted
Args:
text: str
Document to run classification on
threshold: float
threshold: float | int
(Optional) Threshold based on which to return score results
Returns:
TokenClassificationResults
"""
error.type_check("<NLP82129006E>", str, text=text)
error.type_check("<NLP01414077E>", float, allow_none=True, threshold=threshold)
error.type_check(
"<NLP01414077E>", float, int, allow_none=True, threshold=threshold
)

if threshold is None:
threshold = self.default_threshold
Expand Down Expand Up @@ -189,7 +191,7 @@ def run(

@TokenClassificationTask.taskmethod(input_streaming=True, output_streaming=True)
def run_bidi_stream(
self, text_stream: Iterable[str], threshold: Optional[float] = None
self, text_stream: Iterable[str], threshold: Optional[Union[float, int]] = None
) -> Iterable[TokenClassificationStreamResult]:
"""Run bi-directional streaming inferencing for this module.
Run classification on text split into spans. Returns results
Expand All @@ -198,13 +200,15 @@ def run_bidi_stream(
Args:
text_stream: Iterable[str]
Text stream to run classification on
threshold: float
threshold: float | int
(Optional) Threshold based on which to return score results
Returns:
Iterable[TokenClassificationStreamResult]
"""
error.type_check("<NLP96166348E>", float, allow_none=True, threshold=threshold)
error.type_check(
"<NLP96166348E>", float, int, allow_none=True, threshold=threshold
)
# TODO: For optimization implement window based approach.
if threshold is None:
threshold = self.default_threshold
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ def test_bootstrap_run_with_threshold():
) # 4 (all) results over 0.0 expected


def test_bootstrap_run_with_int_threshold():
"""Check if we can bootstrap span classification models with overriden int threshold"""
token_classification_result = BOOTSTRAPPED_MODEL.run(DOCUMENT, threshold=0)
assert isinstance(token_classification_result, TokenClassificationResults)
assert (
len(token_classification_result.results) == 4
) # 4 (all) results over 0 expected


def test_bootstrap_run_with_optional_labels_to_output():
"""Check if we can run span classification models with labels_to_output"""
model = FilteredSpanClassification.bootstrap(
Expand Down

0 comments on commit f55b082

Please sign in to comment.