Skip to content

Commit

Permalink
fix(search): Apply decay relevance for remaining search types
Browse files: browse the repository at this point in the history
- Refactored logic for applying build_decay_relevance_score
- Added tests for all search types
  • Loading branch information
albertisfu committed Dec 21, 2024
1 parent 09d80ff commit ee66af8
Show file tree
Hide file tree
Showing 8 changed files with 723 additions and 273 deletions.
206 changes: 108 additions & 98 deletions cl/lib/elasticsearch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ApiPositionMapping,
BasePositionMapping,
CleanData,
EsJoinQueries,
EsMainQueries,
ESRangeQueryParams,
)
Expand Down Expand Up @@ -1089,19 +1090,16 @@ def combine_plain_filters_and_queries(

def get_match_all_query(
cd: CleanData,
search_query: Search,
api_version: Literal["v3", "v4"] | None = None,
child_highlighting: bool = True,
) -> Search:
) -> Query:
"""Build and return a match-all query for each type of document.
:param cd: The query CleanedData
:param search_query: Elasticsearch DSL Search object
:param api_version: Optional, the request API version.
:param child_highlighting: Whether highlighting should be enabled in child docs.
:return: The modified Search object based on the given conditions.
:return: The Match All Query object.
"""

_, query_hits_limit = get_child_top_hits_limit(
cd, cd["type"], api_version=api_version
)
Expand All @@ -1125,9 +1123,6 @@ def get_match_all_query(
final_match_all_query = Q(
"bool", should=q_should, minimum_should_match=1
)
final_match_all_query = apply_custom_score_to_main_query(
cd, final_match_all_query, api_version
)
case SEARCH_TYPES.RECAP | SEARCH_TYPES.DOCKETS:
# Match all query for RECAP and Dockets, it'll return dockets
# with child documents and also empty dockets.
Expand All @@ -1149,9 +1144,6 @@ def get_match_all_query(
should=[match_all_child_query, match_all_parent_query],
minimum_should_match=1,
)
final_match_all_query = apply_custom_score_to_main_query(
cd, final_match_all_query, api_version, boost_mode="replace"
)
case SEARCH_TYPES.OPINION:
# Only return Opinion clusters.
match_all_child_query = build_has_child_query(
Expand All @@ -1169,18 +1161,12 @@ def get_match_all_query(
final_match_all_query = Q(
"bool", should=q_should, minimum_should_match=1
)
final_match_all_query = apply_custom_score_to_main_query(
cd, final_match_all_query, api_version, boost_mode="replace"
)
case _:
# No string_query or filters in plain search types like OA and
# Parentheticals. Use a match_all query.
match_all_query = Q("match_all")
final_match_all_query = apply_custom_score_to_main_query(
cd, match_all_query, api_version, boost_mode="replace"
)
final_match_all_query = Q("match_all")

return search_query.query(final_match_all_query)
return final_match_all_query


def build_es_base_query(
Expand All @@ -1207,10 +1193,13 @@ def build_es_base_query(

main_query = None
string_query = None
child_docs_query = None
child_query = None
parent_query = None
filters = []
plain_doc = False
join_queries = None
has_text_query = False
match_all_query = False
match cd["type"]:
case SEARCH_TYPES.PARENTHETICAL:
filters = build_es_plain_filters(cd)
Expand Down Expand Up @@ -1253,14 +1242,12 @@ def build_es_base_query(
],
)
)
main_query, child_docs_query, parent_query = (
build_full_join_es_queries(
cd,
child_query_fields,
parent_query_fields,
child_highlighting=child_highlighting,
api_version=api_version,
)
join_queries = build_full_join_es_queries(
cd,
child_query_fields,
parent_query_fields,
child_highlighting=child_highlighting,
api_version=api_version,
)

case (
Expand All @@ -1286,15 +1273,13 @@ def build_es_base_query(
],
)
)
main_query, child_docs_query, parent_query = (
build_full_join_es_queries(
cd,
child_query_fields,
parent_query_fields,
child_highlighting=child_highlighting,
api_version=api_version,
alerts=alerts,
)
join_queries = build_full_join_es_queries(
cd,
child_query_fields,
parent_query_fields,
child_highlighting=child_highlighting,
api_version=api_version,
alerts=alerts,
)

case SEARCH_TYPES.OPINION:
Expand All @@ -1306,20 +1291,19 @@ def build_es_base_query(
mlt_query = async_to_sync(build_more_like_this_query)(
cluster_pks
)
main_query, child_docs_query, parent_query = (
build_full_join_es_queries(
cd,
{"opinion": []},
[],
mlt_query,
child_highlighting=True,
api_version=api_version,
)
join_queries = build_full_join_es_queries(
cd,
{"opinion": []},
[],
mlt_query,
child_highlighting=True,
api_version=api_version,
)
return EsMainQueries(
search_query=search_query.query(main_query),
parent_query=parent_query,
child_query=child_docs_query,
search_query=search_query.query(join_queries.main_query),
boost_mode="multiply",
parent_query=join_queries.parent_query,
child_query=join_queries.child_query,
)

opinion_search_fields = SEARCH_OPINION_QUERY_FIELDS
Expand All @@ -1346,53 +1330,48 @@ def build_es_base_query(
],
)
)
main_query, child_docs_query, parent_query = (
build_full_join_es_queries(
cd,
child_query_fields,
parent_query_fields,
mlt_query,
child_highlighting=child_highlighting,
api_version=api_version,
alerts=alerts,
)
join_queries = build_full_join_es_queries(
cd,
child_query_fields,
parent_query_fields,
mlt_query,
child_highlighting=child_highlighting,
api_version=api_version,
alerts=alerts,
)

if join_queries is not None:
main_query = join_queries.main_query
parent_query = join_queries.parent_query
child_query = join_queries.child_query
has_text_query = join_queries.has_text_query

if not any([filters, string_query, main_query]):
# No filters, string_query or main_query provided by the user, return a
# match_all query
match_all_query = get_match_all_query(
cd, search_query, api_version, child_highlighting
)

return EsMainQueries(
search_query=match_all_query,
parent_query=parent_query,
child_query=child_docs_query,
)
main_query = get_match_all_query(cd, api_version, child_highlighting)
match_all_query = True

boost_mode = "multiply"
if plain_doc:
boost_mode = "multiply" if has_text_query else "replace"
if plain_doc and not match_all_query:
# Combine the filters and string query for plain documents like Oral
# arguments and parentheticals
main_query = combine_plain_filters_and_queries(
cd, filters, string_query, api_version
)
if not string_query:
boost_mode = "replace"
else:
main_query_dict = main_query.to_dict()
contain_query_string = "should" in main_query_dict["bool"]["should"][1]["bool"] and "query_string" in main_query_dict["bool"]["should"][1]["bool"]["should"][0]
if not contain_query_string:
boost_mode = "replace"

boost_mode = "multiply" if string_query else "replace"

main_query = apply_custom_score_to_main_query(cd, main_query, api_version, boost_mode=boost_mode)
# Apply a custom function score to the main query, useful for cursor pagination
# in the V4 API and for date decay relevance.
main_query = apply_custom_score_to_main_query(
cd, main_query, api_version, boost_mode=boost_mode
)

return EsMainQueries(
search_query=search_query.query(main_query),
boost_mode=boost_mode,
parent_query=parent_query,
child_query=child_docs_query,
child_query=child_query,
)


Expand Down Expand Up @@ -2223,7 +2202,6 @@ def fetch_es_results(

# Execute the ES main query + count queries in a single request.
multi_search = MultiSearch()
print("MAin query: ", main_query.to_dict())
multi_search = multi_search.add(main_query).add(main_doc_count_query)
if child_total_query:
multi_search = multi_search.add(child_total_query)
Expand Down Expand Up @@ -2593,12 +2571,28 @@ def apply_custom_score_to_main_query(
else False
)

valid_decay_relevance_types = {
SEARCH_TYPES.OPINION: ["dateFiled"],
SEARCH_TYPES.RECAP: ["dateFiled"],
SEARCH_TYPES.DOCKETS: ["dateFiled"],
SEARCH_TYPES.RECAP_DOCUMENT: ["dateFiled"],
SEARCH_TYPES.ORAL_ARGUMENT: ["dateArgued"],
valid_decay_relevance_types: dict[str, dict[str, str | int | float]] = {
SEARCH_TYPES.OPINION: {
"field": "dateFiled",
"scale": 50,
"decay": 0.5,
},
SEARCH_TYPES.RECAP: {"field": "dateFiled", "scale": 50, "decay": 0.5},
SEARCH_TYPES.DOCKETS: {
"field": "dateFiled",
"scale": 50,
"decay": 0.5,
},
SEARCH_TYPES.RECAP_DOCUMENT: {
"field": "dateFiled",
"scale": 50,
"decay": 0.5,
},
SEARCH_TYPES.ORAL_ARGUMENT: {
"field": "dateArgued",
"scale": 50,
"decay": 0.5,
},
}
main_order_by = cd.get("order_by", "")
if is_valid_custom_score_field and api_version == "v4":
Expand All @@ -2615,9 +2609,11 @@ def apply_custom_score_to_main_query(
main_order_by == "score desc"
and cd["type"] in valid_decay_relevance_types
):
date_field = valid_decay_relevance_types[cd["type"]][0]
date_field = str(valid_decay_relevance_types[cd["type"]]["field"])
scale = int(valid_decay_relevance_types[cd["type"]]["scale"])
decay = float(valid_decay_relevance_types[cd["type"]]["decay"])
query = build_decay_relevance_score(
query, date_field, scale=10, decay=0.5, boost_mode=boost_mode
query, date_field, scale=scale, decay=decay, boost_mode=boost_mode
)
return query

Expand All @@ -2630,7 +2626,7 @@ def build_full_join_es_queries(
child_highlighting: bool = True,
api_version: Literal["v3", "v4"] | None = None,
alerts: bool = False,
) -> tuple[QueryString | list, QueryString | None, QueryString | None]:
) -> EsJoinQueries:
"""Build a complete Elasticsearch query with both parent and child document
conditions.
Expand All @@ -2646,6 +2642,7 @@ def build_full_join_es_queries(
"""

q_should = []
has_text_query = False
match cd["type"]:
case (
SEARCH_TYPES.RECAP
Expand Down Expand Up @@ -2775,6 +2772,7 @@ def build_full_join_es_queries(
string_query = build_fulltext_query(
parent_query_fields, cd.get("q", ""), only_queries=True
)
has_text_query = True if string_query else False

# If child filters are set, add a has_child query as a filter to the
# parent query to exclude results without matching children.
Expand Down Expand Up @@ -2822,15 +2820,21 @@ def build_full_join_es_queries(
q_should.append(parent_query)

if not q_should:
return [], child_docs_query, parent_query
return EsJoinQueries(
main_query=[],
parent_query=parent_query,
child_query=child_docs_query,
has_text_query=has_text_query,
)

return (
Q(
return EsJoinQueries(
main_query=Q(
"bool",
should=q_should,
),
child_docs_query,
parent_query,
parent_query=parent_query,
child_query=child_docs_query,
has_text_query=has_text_query,
)


Expand Down Expand Up @@ -3091,11 +3095,14 @@ def do_es_api_query(
# and sorting are set.
# Note that in V3 Case Law Search, opinions are collapsed by cluster_id
# meaning that only one result per cluster is shown.
s = build_child_docs_query(
child_docs_query = build_child_docs_query(
child_docs_query,
cd=cd,
)
main_query = search_query.query(s)
main_query = apply_custom_score_to_main_query(
cd, child_docs_query, api_version, boost_mode=es_queries.boost_mode
)
main_query = search_query.query(main_query)
highlight_options, fields_to_exclude = build_highlights_dict(
highlighting_fields, hl_tag
)
Expand Down Expand Up @@ -3138,7 +3145,10 @@ def do_es_api_query(
# field exclusion are set.

s = apply_custom_score_to_main_query(
cd, child_docs_query, api_version
cd,
child_docs_query,
api_version,
boost_mode=es_queries.boost_mode,
)
main_query = search_query.query(s)
highlight_options, fields_to_exclude = build_highlights_dict(
Expand Down
9 changes: 9 additions & 0 deletions cl/lib/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,19 @@ def get_db_to_dataclass_map(self):
@dataclass
class EsMainQueries:
    """Container for the queries produced by build_es_base_query.

    Groups the Search object that carries the main query together with the
    boost mode chosen for it and the optional parent/child queries, so
    callers receive a single object.
    """

    # Search object with the main query already attached to it.
    search_query: Search
    # Boost mode passed to apply_custom_score_to_main_query for the main
    # query ("multiply" or "replace" in build_es_base_query). Kept here so a
    # caller can reuse the same mode when it applies the custom score to
    # another query.
    boost_mode: str
    # Parent-document query; None when no join queries were built.
    parent_query: QueryString | None = None
    # Child-document query; None when no join queries were built.
    child_query: QueryString | None = None


@dataclass
class EsJoinQueries:
    """Container for the queries produced by build_full_join_es_queries.

    Replaces the positional tuple that function used to return, so the
    individual queries are accessed by name.
    """

    # Combined "bool" query built from the should clauses, or an empty list
    # when no should clause was produced.
    main_query: QueryString | list
    # Query targeting the parent documents, if one was built.
    parent_query: QueryString | None
    # Query targeting the child documents, if one was built.
    child_query: QueryString | None
    # True when a non-empty full-text string query was built; callers use it
    # to choose the boost mode for the main query.
    has_text_query: bool


@dataclass
class ApiPositionMapping(BasePositionMapping):
position_type_dict: defaultdict[int, list[str]] = field(
Expand Down
Loading

0 comments on commit ee66af8

Please sign in to comment.