Skip to content

Commit

Permalink
Collapse final communities workflow (#1150)
Browse files Browse the repository at this point in the history
* Collapse create_final_communities

* Semver

* Spellcheck

* Clean up filtering

* Add space in title

* Format

* Cleanup imports and format

* Spruce up the tests

* Update dictionary.txt

* Spellcheck

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
  • Loading branch information
natoverse and AlonsoGuevara authored Sep 18, 2024
1 parent a473265 commit aa5b426
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 159 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240917213220301479.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create_final_communities."
}
3 changes: 3 additions & 0 deletions dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ isna
getcwd
fillna
noqa
dtypes

# Azure
abfs
Expand Down Expand Up @@ -97,6 +98,8 @@ retryer
agenerate
aembed
dedupe
dropna
dtypes

# LLM Terms
AOAI
Expand Down
28 changes: 25 additions & 3 deletions graphrag/index/verbs/graph/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,32 @@ def unpack_graph(
column: <column name> # The name of the column containing the graph, should be a graphml graph
```
"""
input_df = input.get_input()
output_df = unpack_graph_df(
cast(pd.DataFrame, input_df),
callbacks,
column,
type,
copy,
embeddings_column,
kwargs=kwargs,
)
return TableContainer(table=output_df)


def unpack_graph_df(
input_df: pd.DataFrame,
callbacks: VerbCallbacks,
column: str,
type: str, # noqa A002
copy: list[str] | None = None,
embeddings_column: str = "embeddings",
**kwargs,
) -> pd.DataFrame:
"""Unpack nodes or edges from a graphml graph, into a list of nodes or edges."""
if copy is None:
copy = default_copy
input_df = input.get_input()

num_total = len(input_df)
result = []
copy = [col for col in copy if col in input_df.columns]
Expand All @@ -64,8 +87,7 @@ def unpack_graph(
)
])

output_df = pd.DataFrame(result)
return TableContainer(table=output_df)
return pd.DataFrame(result)


def _run_unpack(
Expand Down
149 changes: 1 addition & 148 deletions graphrag/index/workflows/v1/create_final_communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,154 +19,7 @@ def build_steps(
"""
return [
{
"id": "graph_nodes",
"verb": "unpack_graph",
"args": {
"column": "clustered_graph",
"type": "nodes",
},
"verb": "create_final_communities",
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"id": "graph_edges",
"verb": "unpack_graph",
"args": {
"column": "clustered_graph",
"type": "edges",
},
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"id": "source_clusters",
"verb": "join",
"args": {
"on": ["label", "source"],
},
"input": {"source": "graph_nodes", "others": ["graph_edges"]},
},
{
"id": "target_clusters",
"verb": "join",
"args": {
"on": ["label", "target"],
},
"input": {"source": "graph_nodes", "others": ["graph_edges"]},
},
{
"id": "concatenated_clusters",
"verb": "concat",
"input": {
"source": "source_clusters",
"others": ["target_clusters"],
},
},
{
"id": "combined_clusters",
"verb": "filter",
"args": {
# level_1 is the left side of the join
# level_2 is the right side of the join
"column": "level_1",
"criteria": [
{"type": "column", "operator": "equals", "value": "level_2"}
],
},
"input": {"source": "concatenated_clusters"},
},
{
"id": "cluster_relationships",
"verb": "aggregate_override",
"args": {
"groupby": [
"cluster",
"level_1", # level_1 is the left side of the join
],
"aggregations": [
{
"column": "id_2", # this is the id of the edge from the join steps above
"to": "relationship_ids",
"operation": "array_agg_distinct",
},
{
"column": "source_id_1",
"to": "text_unit_ids",
"operation": "array_agg_distinct",
},
],
},
"input": {"source": "combined_clusters"},
},
{
"id": "all_clusters",
"verb": "aggregate_override",
"args": {
"groupby": ["cluster", "level"],
"aggregations": [{"column": "cluster", "to": "id", "operation": "any"}],
},
"input": {"source": "graph_nodes"},
},
{
"verb": "join",
"args": {
"on": ["id", "cluster"],
},
"input": {"source": "all_clusters", "others": ["cluster_relationships"]},
},
{
"verb": "filter",
"args": {
# level is the left side of the join
# level_1 is the right side of the join
"column": "level",
"criteria": [
{"type": "column", "operator": "equals", "value": "level_1"}
],
},
},
*create_community_title_wf,
{
# TODO: Rodrigo says "raw_community" is temporary
"verb": "copy",
"args": {
"column": "id",
"to": "raw_community",
},
},
{
"verb": "select",
"args": {
"columns": [
"id",
"title",
"level",
"raw_community",
"relationship_ids",
"text_unit_ids",
],
},
},
]


create_community_title_wf = [
# Hack to string concat "Community " + id
{
"verb": "fill",
"args": {
"to": "__temp",
"value": "Community ",
},
},
{
"verb": "merge",
"args": {
"columns": [
"__temp",
"id",
],
"to": "title",
"strategy": "concat",
"preserveSource": True,
},
},
]
2 changes: 2 additions & 0 deletions graphrag/index/workflows/v1/subflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

"""The Indexing Engine workflows -> subflows package root."""

from .create_final_communities import create_final_communities
from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding

__all__ = [
"create_final_communities",
"create_final_text_units_pre_embedding",
]
107 changes: 107 additions & 0 deletions graphrag/index/workflows/v1/subflows/create_final_communities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform final communities."""

from typing import cast

import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.verbs.overrides.aggregate import aggregate_df


@verb(name="create_final_communities", treats_input_tables_as_immutable=True)
def create_final_communities(
    input: VerbInput,
    callbacks: VerbCallbacks,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform final communities."""
    base_graph = cast(pd.DataFrame, input.get_input())

    # Explode the packed graphml column into flat node and edge tables.
    nodes = unpack_graph_df(base_graph, callbacks, "clustered_graph", "nodes")
    edges = unpack_graph_df(base_graph, callbacks, "clustered_graph", "edges")

    # Attach every edge to the cluster rows of each of its endpoints.
    # pandas suffixes overlapping columns: node columns get _x, edge columns _y.
    by_source = nodes.merge(edges, left_on="label", right_on="source", how="inner")
    by_target = nodes.merge(edges, left_on="label", right_on="target", how="inner")
    endpoint_rows = pd.concat([by_source, by_target], ignore_index=True)

    # Keep only rows whose node-side level (level_x) equals the edge-side
    # level (level_y): edges that live within the same hierarchy level.
    same_level = endpoint_rows["level_x"] == endpoint_rows["level_y"]
    combined = endpoint_rows[same_level].reset_index(drop=True)

    # Per (cluster, level): collect the distinct edge ids and text-unit ids.
    cluster_relationships = aggregate_df(
        cast(Table, combined),
        aggregations=[
            {
                # id_y is the edge id contributed by the joins above
                "column": "id_y",
                "to": "relationship_ids",
                "operation": "array_agg_distinct",
            },
            {
                "column": "source_id_x",
                "to": "text_unit_ids",
                "operation": "array_agg_distinct",
            },
        ],
        groupby=[
            "cluster",
            "level_x",  # node-side level of the join
        ],
    )

    # One row per (cluster, level); "any" just picks the cluster value as id.
    all_clusters = aggregate_df(
        nodes,
        aggregations=[{"column": "cluster", "to": "id", "operation": "any"}],
        groupby=["cluster", "level"],
    )

    communities = all_clusters.merge(
        cluster_relationships,
        left_on="id",
        right_on="cluster",
        how="inner",
    )

    # Align the node-table level with the aggregated join level (level_x).
    communities = communities[
        communities["level"] == communities["level_x"]
    ].reset_index(drop=True)

    # Human-readable title, e.g. "Community 4".
    communities["title"] = "Community " + communities["id"].astype(str)

    final = communities[
        [
            "id",
            "title",
            "level",
            "relationship_ids",
            "text_unit_ids",
        ]
    ]
    return create_verb_result(cast(Table, final))
2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
1,
2000
],
"subworkflows": 14,
"subworkflows": 1,
"max_runtime": 10
},
"create_final_community_reports": {
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/text/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
1,
2000
],
"subworkflows": 14,
"subworkflows": 1,
"max_runtime": 10
},
"create_final_community_reports": {
Expand Down
35 changes: 35 additions & 0 deletions tests/verbs/test_create_final_communities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.workflows.v1.create_final_communities import (
build_steps,
workflow_name,
)

from .util import (
compare_outputs,
get_workflow_output,
load_expected,
load_input_tables,
)


async def test_create_final_communities():
    """Run the collapsed workflow end-to-end and diff it against the fixture."""
    tables = load_input_tables(["workflow:create_base_entity_graph"])
    reference = load_expected(workflow_name)

    result = await get_workflow_output(tables, {"steps": build_steps({})})

    # raw_community was dropped from the output, so only compare the
    # columns that survive the collapse
    surviving_columns = ["id", "title", "level", "relationship_ids", "text_unit_ids"]
    compare_outputs(result, reference, surviving_columns)
Loading

0 comments on commit aa5b426

Please sign in to comment.