Skip to content

Commit

Permalink
Merge pull request #958 from pierce314159/957_mimic_re
Browse files Browse the repository at this point in the history
Closes #957: Mimic `re` library functionality applied to SegStrings
  • Loading branch information
reuster986 authored Nov 2, 2021
2 parents 6318846 + 87ca702 commit 8256fe3
Show file tree
Hide file tree
Showing 14 changed files with 1,296 additions and 261 deletions.
2 changes: 2 additions & 0 deletions ENVIRONMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ These env vars can be used to configure your build of Arkouda when running `make
- ARKOUDA_PRINT_PASSES_FILE : Setting this adds `--print-passes-file <file>` to the Chapel compiler flags and writes
the associated "pass timing" output to the specified file. This is mainly used in the nightly testing infrastructure.
- CHPL_DEBUG_FLAGS : We add `--print-passes` automatically, but you can add additional flags here.
- REGEX_MAX_CAPTURES : Set this to an integer to change the maximum number of capture groups accessible using `Match.group`
(defaults to 20).

#### Dependency Paths
Most folks install anaconda and link to these libraries through Makefile.paths instructions. If you have an alternative
Expand Down
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ ifdef ARKOUDA_PRINT_PASSES_FILE
PRINT_PASSES_FLAGS := --print-passes-file $(ARKOUDA_PRINT_PASSES_FILE)
endif

# Optional override: when REGEX_MAX_CAPTURES is set, build the compiler flag
# -sregexMaxCaptures=<value>; the flag variable is expanded on the $(CHPL)
# command line of the main build rule and is empty otherwise.
ifdef REGEX_MAX_CAPTURES
REGEX_MAX_CAPTURES_FLAG = -sregexMaxCaptures=$(REGEX_MAX_CAPTURES)
endif

ARKOUDA_SOURCES = $(shell find $(ARKOUDA_SOURCE_DIR)/ -type f -name '*.chpl')
ARKOUDA_MAIN_SOURCE := $(ARKOUDA_SOURCE_DIR)/$(ARKOUDA_MAIN_MODULE).chpl

Expand All @@ -212,7 +216,7 @@ else
endif

$(ARKOUDA_MAIN_MODULE): check-deps $(ARKOUDA_SOURCES) $(ARKOUDA_MAKEFILES)
$(CHPL) $(CHPL_DEBUG_FLAGS) $(PRINT_PASSES_FLAGS) $(CHPL_FLAGS_WITH_VERSION) $(ARKOUDA_MAIN_SOURCE) $(ARKOUDA_COMPAT_MODULES) -o $@
$(CHPL) $(CHPL_DEBUG_FLAGS) $(PRINT_PASSES_FLAGS) $(REGEX_MAX_CAPTURES_FLAG) $(CHPL_FLAGS_WITH_VERSION) $(ARKOUDA_MAIN_SOURCE) $(ARKOUDA_COMPAT_MODULES) -o $@

CLEAN_TARGETS += arkouda-clean
.PHONY: arkouda-clean
Expand Down
4 changes: 3 additions & 1 deletion arkouda/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
pdarrayIterThresh = pdarrayIterThreshDefVal
maxTransferBytesDefVal = 2**30
maxTransferBytes = maxTransferBytesDefVal
regexMaxCaptures: int = -1

logger = getArkoudaLogger(name='Arkouda Client')
clientLogger = getArkoudaLogger(name='Arkouda User Logger', logFormat='%(message)s')
Expand Down Expand Up @@ -91,7 +92,7 @@ def connect(server : str="localhost", port : int=5555, timeout : int=0,
On success, prints the connected address, as seen by the server. If called
with an existing connection, the socket will be re-initialized.
"""
global context, socket, pspStr, connected, serverConfig, verbose, username, token
global context, socket, pspStr, connected, serverConfig, verbose, username, token, regexMaxCaptures

logger.debug("ZMQ version: {}".format(zmq.zmq_version()))

Expand Down Expand Up @@ -147,6 +148,7 @@ def connect(server : str="localhost", port : int=5555, timeout : int=0,
'this may cause some commands to fail or behave ' +
'incorrectly! Updating arkouda is strongly recommended.').\
format(__version__, serverConfig['arkoudaVersion']), RuntimeWarning)
regexMaxCaptures = serverConfig['regexMaxCaptures'] # type:ignore
clientLogger.info(return_message)

def _parse_url(url : str) -> Tuple[str,int,Optional[str]]:
Expand Down
231 changes: 231 additions & 0 deletions arkouda/match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
from typing import cast
from arkouda.client import generic_msg
from arkouda.pdarrayclass import pdarray, create_pdarray
import json
from enum import Enum

class MatchType(Enum):
    """Kind of regex operation a Match object was produced by.

    The member name is what ``Match.match_type()`` returns, and
    ``Match.group`` uses the member to pick which server reply arrays to read.
    """

    SEARCH = 1
    MATCH = 2
    FULLMATCH = 3

class Match:
    """
    Result of applying a regular expression to every element of a segmented
    string array, modelled on the match objects of Python's ``re`` library.

    The object keeps the per-element match flags and the start/length of each
    match, plus the server-side names of the parent string's offsets and bytes
    arrays so that the matched substrings (or capture groups) can be
    materialised on demand.
    """

    def __init__(self, matched: pdarray, starts: pdarray, lengths: pdarray, indices: pdarray,
                 parent_bytes_name: str, parent_offsets_name: str, match_type: MatchType, pattern: str):
        """
        Parameters
        ----------
        matched : pdarray, bool
            True for elements that match, False otherwise
        starts : pdarray, int64
            The start positions of matches
        lengths : pdarray, int64
            The lengths of matches; end positions are computed as
            ``starts + lengths``
        indices : pdarray, int64
            For an element position, the position of its entry in ``starts``
            (and the derived ends)
        parent_bytes_name : str
            Name of the parent's bytes array, as sent to the server
        parent_offsets_name : str
            Name of the parent's offsets array, as sent to the server
        match_type : MatchType
            Which kind of regex operation produced this object
        pattern : str
            The regular expression; stored as ``self.re``
        """
        self._objtype = type(self).__name__
        self._parent_bytes_name = parent_bytes_name
        self._parent_offsets_name = parent_offsets_name
        self._match_type = match_type
        self._matched = matched
        self._starts = starts
        self._lengths = lengths
        self._ends = starts + lengths
        self._indices = indices
        self._parent_obj: object = None
        self.re = pattern

    def __str__(self):
        """
        Summarise the per-element results. Arrays larger than the client's
        ``pdarrayIterThresh`` show only the first and last three elements.
        """
        # Imported at call time so the current client setting is used.
        from arkouda.client import pdarrayIterThresh
        size = self._matched.size
        if size <= pdarrayIterThresh:
            shown = [self[i] for i in range(size)]
        else:
            shown = [self[i] for i in range(3)]
            shown.append('... ')
            shown.extend(self[i] for i in range(size - 3, size))
        return f"<ak.{self._objtype} object: {'; '.join(shown)}>"

    def __getitem__(self, item):
        """
        Describe the result for one element: the match flag and, when the
        element matched, the ``(start, end)`` span of the match.
        """
        # Index each array once and reuse the value instead of repeating the
        # same lookup for every placeholder of the format string.
        is_match = self._matched[item]
        if not is_match:
            return f"matched={is_match}"
        pos = self._indices[item]
        return f"matched={is_match}, span=({self._starts[pos]}, {self._ends[pos]})"

    def __repr__(self):
        return self.__str__()

    def matched(self) -> pdarray:
        """
        Returns a boolean array indicating whether each element matched

        Returns
        -------
        pdarray, bool
            True for elements that match, False otherwise

        Examples
        --------
        >>> strings = ak.array(['1_2___', '____', '3', '__4___5____6___7', ''])
        >>> strings.search('_+').matched()
        array([True True False True False])
        """
        return self._matched

    def start(self) -> pdarray:
        """
        Returns the starts of matches

        Returns
        -------
        pdarray, int64
            The start positions of matches

        Examples
        --------
        >>> strings = ak.array(['1_2___', '____', '3', '__4___5____6___7', ''])
        >>> strings.search('_+').start()
        array([1 0 0])
        """
        return self._starts

    def end(self) -> pdarray:
        """
        Returns the ends of matches

        Returns
        -------
        pdarray, int64
            The end positions of matches

        Examples
        --------
        >>> strings = ak.array(['1_2___', '____', '3', '__4___5____6___7', ''])
        >>> strings.search('_+').end()
        array([2 4 2])
        """
        return self._ends

    def match_type(self) -> str:
        """
        Returns the type of the Match object

        Returns
        -------
        str
            MatchType of the Match object

        Examples
        --------
        >>> strings = ak.array(['1_2___', '____', '3', '__4___5____6___7', ''])
        >>> strings.search('_+').match_type()
        'SEARCH'
        """
        return self._match_type.name

    def _find_all(self, matched: pdarray, starts: pdarray, lengths: pdarray,
                  indices: pdarray, return_origins: bool):
        """
        Send a ``segmentedFindAll`` request for the given location arrays and
        wrap the reply.

        Shared by ``find_matches`` (cached full-match locations) and ``group``
        (freshly computed capture-group locations).

        Returns
        -------
        Strings
            Built from the first two '+'-separated fields of the reply
        pdarray (only if return_origins is True)
            Built from the third field of the reply
        """
        # Local import, as in the rest of this class, rather than at module
        # level.
        from arkouda.strings import Strings
        args = "{} {} {} {} {} {} {} {}".format(self._objtype,
                                                self._parent_offsets_name,
                                                self._parent_bytes_name,
                                                matched.name,
                                                starts.name,
                                                lengths.name,
                                                indices.name,
                                                return_origins)
        repMsg = cast(str, generic_msg(cmd="segmentedFindAll", args=args))
        if return_origins:
            arrays = repMsg.split('+', maxsplit=2)
            return Strings(arrays[0], arrays[1]), create_pdarray(arrays[2])
        arrays = repMsg.split('+', maxsplit=1)
        return Strings(arrays[0], arrays[1])

    def find_matches(self, return_match_origins: bool = False):
        """
        Return all matches as a new Strings object

        Parameters
        ----------
        return_match_origins: bool
            If True, return a pdarray containing the index of the original string each pattern match is from

        Returns
        -------
        Strings
            Strings object containing only matches
        pdarray, int64 (optional)
            The index of the original string each pattern match is from

        Raises
        ------
        RuntimeError
            Raised if there is a server-side error thrown

        Examples
        --------
        >>> strings = ak.array(['1_2___', '____', '3', '__4___5____6___7', ''])
        >>> strings.search('_+').find_matches(return_match_origins=True)
        (array(['_', '____', '__']), array([0 1 3]))
        """
        return self._find_all(self._matched, self._starts, self._lengths,
                              self._indices, return_match_origins)

    def group(self, group_num: int = 0, return_group_origins: bool = False):
        """
        Returns a new Strings containing the capture group corresponding to group_num. For the default, group_num=0, return the full match

        Parameters
        ----------
        group_num: int
            The index of the capture group to be returned
        return_group_origins: bool
            If True, return a pdarray containing the index of the original string each capture group is from

        Returns
        -------
        Strings
            Strings object containing only the capture groups corresponding to group_num
        pdarray, int64 (optional)
            The index of the original string each group is from

        Raises
        ------
        ValueError
            Raised if group_num is negative, if group_num exceeds the
            regexMaxCaptures value held by the client, or if this object's
            match type is not a MatchType member

        Examples
        --------
        >>> strings = ak.array(["Isaac Newton, physicist", '<--calculus-->', 'Gottfried Leibniz, mathematician'])
        >>> m = strings.search("(\\w+) (\\w+)")
        >>> m.group()
        array(['Isaac Newton', 'Gottfried Leibniz'])
        >>> m.group(1)
        array(['Isaac', 'Gottfried'])
        >>> m.group(2, return_group_origins=True)
        (array(['Newton', 'Leibniz']), array([0 2]))
        """
        # Imported at call time: the client module rebinds this value when a
        # connection is made, so a module-level import would be stale.
        from arkouda.client import regexMaxCaptures
        if group_num < 0:
            raise ValueError("group_num cannot be negative")
        if group_num > regexMaxCaptures:
            max_capture_flag = f'-e REGEX_MAX_CAPTURES={group_num}'
            e = (f"group_num={group_num} > regexMaxCaptures={regexMaxCaptures}. "
                 f"To run group({group_num}), recompile the server with flag '{max_capture_flag}'")
            raise ValueError(e)

        # Reply keys are "<prefix>Bool" / "<prefix>Ind"; choose the prefix (and
        # reject an unknown match type) before doing any server work.
        if self._match_type == MatchType.SEARCH:
            key_prefix = "Search"
        elif self._match_type == MatchType.MATCH:
            key_prefix = "Match"
        elif self._match_type == MatchType.FULLMATCH:
            key_prefix = "FullMatch"
        else:
            raise ValueError(f"{self._match_type} is not a MatchType")

        # We don't cache the locations of groups, find the location info and call findAll
        args = "{} {} {} {} {}".format(self._objtype,
                                       self._parent_offsets_name,
                                       self._parent_bytes_name,
                                       group_num,
                                       json.dumps([self.re]))
        repMsg = cast(str, generic_msg(cmd="segmentedFindLoc", args=args))
        created_map = json.loads(repMsg)
        global_starts = create_pdarray(created_map["Starts"])
        global_lengths = create_pdarray(created_map["Lens"])
        global_indices = create_pdarray(created_map["Indices"])
        matched = create_pdarray(created_map[key_prefix + "Bool"])
        indices = create_pdarray(created_map[key_prefix + "Ind"])

        # Keep only the location entries that belong to matched elements.
        matched_positions = global_indices[matched]
        starts = global_starts[matched_positions]
        lengths = global_lengths[matched_positions]
        return self._find_all(matched, starts, lengths, indices, return_group_origins)
Loading

0 comments on commit 8256fe3

Please sign in to comment.