Skip to content

Commit

Permalink
Covariate collapse (#1142)
Browse files Browse the repository at this point in the history
* Setup basic verb test runner

* Replace join_text_units_to_entity_ids with subflow

* Update comments

* Replace join_text_units_to_relationship_ids subflow

* Roll in final select

* Reuse assertion util

* Small fix + format

* Format/typing

* Semver

* Format/typing

* Semver

* Revert format changes

* Fix smoke test subworkflow count

* Edit subworkflows for another smoke test

* Update test parquets for covariates

* Collapse covariate join

* Rework subtasks for per-flow customization

* Format

* Semver

* Fix smoke test
  • Loading branch information
natoverse authored Sep 16, 2024
1 parent 2de302f commit d22c0e7
Show file tree
Hide file tree
Showing 23 changed files with 126 additions and 13 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240916191422408337.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Covariate verb collapse."
}
13 changes: 5 additions & 8 deletions graphrag/index/workflows/v1/join_text_units_to_covariate_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,11 @@ def build_steps(
"""
return [
{
"verb": "select",
"args": {"columns": ["id", "text_unit_id"]},
"input": {"source": "workflow:create_final_covariates"},
},
{
"verb": "aggregate_override",
"verb": "join_text_units_to_covariate_ids",
"args": {
"groupby": ["text_unit_id"],
"aggregations": [
"select_columns": ["id", "text_unit_id"],
"aggregate_groupby": ["text_unit_id"],
"aggregate_aggregations": [
{
"column": "id",
"operation": "array_agg_distinct",
Expand All @@ -40,5 +36,6 @@ def build_steps(
},
],
},
"input": {"source": "workflow:create_final_covariates"},
},
]
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def build_steps(
"""
return [
{
"verb": "join_text_units",
"verb": "join_text_units_to_entity_ids",
"args": {
"select_columns": ["id", "text_unit_ids"],
"unroll_column": "text_unit_ids",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def build_steps(
"""
return [
{
"verb": "join_text_units",
"verb": "join_text_units_to_relationship_ids",
"args": {
"select_columns": ["id", "text_unit_ids"],
"unroll_column": "text_unit_ids",
Expand Down
8 changes: 6 additions & 2 deletions graphrag/index/workflows/v1/subflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@

"""The Indexing Engine workflows -> subflows package root."""

from .join_text_units import join_text_units
from .join_text_units_to_covariate_ids import join_text_units_to_covariate_ids
from .join_text_units_to_entity_ids import join_text_units_to_entity_ids
from .join_text_units_to_relationship_ids import join_text_units_to_relationship_ids

__all__ = [
"join_text_units",
"join_text_units_to_covariate_ids",
"join_text_units_to_entity_ids",
"join_text_units_to_relationship_ids",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""join_text_units_to_covariate_ids verb (subtask)."""

from typing import Any, cast

from datashaper.engine.verbs.verb_input import VerbInput
from datashaper.engine.verbs.verbs_mapping import verb
from datashaper.table_store.types import Table, VerbResult, create_verb_result

from graphrag.index.verbs.overrides.aggregate import aggregate_df


@verb(name="join_text_units_to_covariate_ids", treats_input_tables_as_immutable=True)
def join_text_units_to_covariate_ids(
input: VerbInput,
select_columns: list[str],
aggregate_aggregations: list[dict[str, Any]],
aggregate_groupby: list[str] | None = None,
**_kwargs: dict,
) -> VerbResult:
"""Subtask to select and unroll items using an id."""
table = input.get_input()
selected = cast(Table, table[select_columns])
aggregated = aggregate_df(selected, aggregate_aggregations, aggregate_groupby)
return create_verb_result(aggregated)
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""join_text_units_to_entity_ids verb (subtask)."""

from typing import Any, cast

from datashaper.engine.verbs.verb_input import VerbInput
from datashaper.engine.verbs.verbs_mapping import verb
from datashaper.table_store.types import Table, VerbResult, create_verb_result

from graphrag.index.verbs.overrides.aggregate import aggregate_df


@verb(name="join_text_units_to_entity_ids", treats_input_tables_as_immutable=True)
def join_text_units_to_entity_ids(
input: VerbInput,
select_columns: list[str],
unroll_column: str,
aggregate_aggregations: list[dict[str, Any]],
aggregate_groupby: list[str] | None = None,
**_kwargs: dict,
) -> VerbResult:
"""Subtask to select and unroll items using an id."""
table = input.get_input()
selected = cast(Table, table[select_columns])
unrolled = selected.explode(unroll_column).reset_index(drop=True)
aggregated = aggregate_df(unrolled, aggregate_aggregations, aggregate_groupby)
return create_verb_result(aggregated)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""join_text_units_to_relationship_ids verb (subtask)."""

from typing import Any, cast

from datashaper.engine.verbs.verb_input import VerbInput
from datashaper.engine.verbs.verbs_mapping import verb
from datashaper.table_store.types import Table, VerbResult, create_verb_result

from graphrag.index.verbs.overrides.aggregate import aggregate_df


@verb(name="join_text_units_to_relationship_ids", treats_input_tables_as_immutable=True)
def join_text_units_to_relationship_ids(
input: VerbInput,
select_columns: list[str],
unroll_column: str,
aggregate_aggregations: list[dict[str, Any]],
aggregate_groupby: list[str] | None = None,
final_select_columns: list[str] | None = None,
**_kwargs: dict,
) -> VerbResult:
"""Subtask to select and unroll items using an id."""
table = input.get_input()
selected = cast(Table, table[select_columns])
unrolled = selected.explode(unroll_column).reset_index(drop=True)
aggregated = aggregate_df(unrolled, aggregate_aggregations, aggregate_groupby)
return create_verb_result(cast(Table, aggregated[final_select_columns]))
2 changes: 1 addition & 1 deletion tests/fixtures/text/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
1,
2000
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 10
},
"create_base_entity_graph": {
Expand Down
Binary file modified tests/verbs/data/create_base_entity_graph.parquet
Binary file not shown.
Binary file modified tests/verbs/data/create_base_extracted_entities.parquet
Binary file not shown.
Binary file modified tests/verbs/data/create_final_communities.parquet
Binary file not shown.
Binary file modified tests/verbs/data/create_final_community_reports.parquet
Binary file not shown.
Binary file added tests/verbs/data/create_final_covariates.parquet
Binary file not shown.
Binary file modified tests/verbs/data/create_final_entities.parquet
Binary file not shown.
Binary file modified tests/verbs/data/create_final_nodes.parquet
Binary file not shown.
Binary file modified tests/verbs/data/create_final_relationships.parquet
Binary file not shown.
Binary file modified tests/verbs/data/create_final_text_units.parquet
Binary file not shown.
Binary file modified tests/verbs/data/create_summarized_entities.parquet
Binary file not shown.
Binary file not shown.
Binary file modified tests/verbs/data/join_text_units_to_entity_ids.parquet
Binary file not shown.
Binary file modified tests/verbs/data/join_text_units_to_relationship_ids.parquet
Binary file not shown.
22 changes: 22 additions & 0 deletions tests/verbs/test_join_text_units_to_covariate_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.workflows.v1.join_text_units_to_covariate_ids import build_steps

from .util import compare_outputs, get_workflow_output, load_expected, load_input_tables


async def test_join_text_units_to_covariate_ids():
input_tables = load_input_tables([
"workflow:create_final_covariates",
])
expected = load_expected("join_text_units_to_covariate_ids")

actual = await get_workflow_output(
input_tables,
{
"steps": build_steps({}),
},
)

compare_outputs(actual, expected)

0 comments on commit d22c0e7

Please sign in to comment.