Skip to content

Commit

Permalink
DynamoDB: scan() now supports parallelization using the Segment/TotalSegments parameters (#8303)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Nov 10, 2024
1 parent 1c140bc commit f1e48a9
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 14 deletions.
2 changes: 2 additions & 0 deletions moto/dynamodb/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def scan(
index_name: str,
consistent_read: bool,
projection_expression: Optional[List[List[str]]],
segments: Union[Tuple[None, None], Tuple[int, int]],
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
table = self.get_table(table_name)

Expand All @@ -421,6 +422,7 @@ def scan(
index_name,
consistent_read,
projection_expression,
segments=segments,
)

def update_item(
Expand Down
31 changes: 30 additions & 1 deletion moto/dynamodb/models/dynamo_type.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import copy
from decimal import Decimal
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from botocore.utils import merge_dicts
Expand All @@ -12,6 +12,7 @@
IncorrectDataType,
ItemSizeTooLarge,
)
from moto.utilities.utils import md5_hash

from .utilities import bytesize, find_nested_key

Expand Down Expand Up @@ -455,3 +456,31 @@ def project(self, projection_expressions: List[List[str]]) -> "Item":
# We need to convert that into DynamoDB dictionary ({'M': {'key': {'S': 'value'}}})
attrs=serializer.serialize(result)["M"],
)

def is_within_segment(
    self, segments: Union[Tuple[None, None], Tuple[int, int]]
) -> bool:
    """
    Decide whether this item belongs to the requested parallel-scan segment.

    :param segments: a (segment, total_segments) pair.
        (None, None) => the user requested the entire table, so every item matches.
        (x, y)       => the user requested segment x out of y.
    :return: True if the item should be part of the scan result for this segment.

    Segment membership is computed from the value of the hash key only.
    """
    segment, total_segments = segments
    # Compare the individual values rather than the pair as a whole, so that
    # any two-element sequence (tuple or list) is accepted for "no segmentation".
    if segment is None and total_segments is None:
        return True

    # Creates a reproducible hash number for this item (between 0 and 255, inclusive)
    # Note that we can't use the builtin hash() method, as that is not deterministic between executions
    #
    # Using a hash based on the hash key ensures parity with how AWS seems to behave:
    #  - Items are not divided equally between segments
    #  - Items always fall in the same segment, regardless of how often you call `scan()`
    #  - Items with the same hash key but different range keys always fall in the same segment
    #  - Items with different hash keys may be part of different segments
    #
    key_value = self.hash_key.value
    if isinstance(key_value, (bytes, bytearray)):
        # Already raw bytes: hash as-is instead of failing on a missing .encode()
        key_bytes = bytes(key_value)
    else:
        # For str values this is identical to value.encode("utf8")
        key_bytes = str(key_value).encode("utf8")
    item_hash = md5_hash(key_bytes).digest()[0]

    # Modulo maps the hash onto a segment number between 0 and (total_segments - 1).
    # Because only a single byte is used, at most 256 distinct segments can ever hold items.
    item_segment = item_hash % total_segments
    return segment == item_segment
6 changes: 5 additions & 1 deletion moto/dynamodb/models/table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from moto.core.common_models import BaseModel, CloudFormationModel
from moto.core.utils import unix_time, unix_time_millis, utcnow
Expand Down Expand Up @@ -897,6 +897,7 @@ def scan(
index_name: Optional[str] = None,
consistent_read: bool = False,
projection_expression: Optional[List[List[str]]] = None,
segments: Union[Tuple[None, None], Tuple[int, int]] = (None, None),
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
results: List[Item] = []
result_size = 0
Expand Down Expand Up @@ -942,6 +943,9 @@ def scan(
last_evaluated_key = None
processing_previous_page = exclusive_start_key is not None
for item in items:
if not item.is_within_segment(segments):
continue

# Cycle through the previous page of results
# When we encounter our start key, we know we've reached the end of the previous page
if processing_previous_page:
Expand Down
31 changes: 25 additions & 6 deletions moto/dynamodb/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,24 @@ def scan(self) -> str:
limit = self.body.get("Limit")
index_name = self.body.get("IndexName")
consistent_read = self.body.get("ConsistentRead", False)
segment = self.body.get("Segment")
total_segments = self.body.get("TotalSegments")
if segment is not None and total_segments is None:
raise MockValidationException(
"The TotalSegments parameter is required but was not present in the request when Segment parameter is present"
)
if total_segments is not None and segment is None:
raise MockValidationException(
"The Segment parameter is required but was not present in the request when parameter TotalSegments is present"
)
if (
segment is not None
and total_segments is not None
and segment >= total_segments
):
raise MockValidationException(
f"The Segment parameter is zero-based and must be less than parameter TotalSegments: Segment: {segment} is not less than TotalSegments: {total_segments}"
)

projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names
Expand All @@ -840,12 +858,13 @@ def scan(self) -> str:
filters,
limit,
exclusive_start_key,
filter_expression,
expression_attribute_names,
expression_attribute_values,
index_name,
consistent_read,
projection_expressions,
filter_expression=filter_expression,
expr_names=expression_attribute_names,
expr_values=expression_attribute_values,
index_name=index_name,
consistent_read=consistent_read,
projection_expression=projection_expressions,
segments=(segment, total_segments),
)
except ValueError as err:
raise MockValidationException(f"Bad Filter Expression: {err}")
Expand Down
30 changes: 24 additions & 6 deletions tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,40 @@

class BaseTest:
@classmethod
def setup_class(cls):
def setup_class(cls, add_range=False):
if not allow_aws_request():
cls.mock = mock_aws()
cls.mock.start()
cls.client = boto3.client("dynamodb", region_name="us-east-1")
cls.table_name = "T" + str(uuid4())[0:6]
cls.has_range_key = add_range

dynamodb = boto3.resource("dynamodb", region_name="us-east-1")

# Create the DynamoDB table.
schema = [{"AttributeName": "pk", "KeyType": "HASH"}]
defs = [{"AttributeName": "pk", "AttributeType": "S"}]
if add_range:
schema.append({"AttributeName": "rk", "KeyType": "RANGE"})
defs.append({"AttributeName": "rk", "AttributeType": "S"})
dynamodb.create_table(
TableName=cls.table_name,
KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}],
KeySchema=schema,
AttributeDefinitions=defs,
ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
)
waiter = cls.client.get_waiter("table_exists")
waiter.wait(TableName=cls.table_name)
cls.table = dynamodb.Table(cls.table_name)
cls.table.put_item(
Item={"pk": "the-key", "subject": "123", "body": "some test msg"}
)

def setup_method(self):
# Empty table between runs
items = self.table.scan()["Items"]
for item in items:
if self.has_range_key:
self.table.delete_item(Key={"pk": item["pk"], "rk": item["rk"]})
else:
self.table.delete_item(Key={"pk": item["pk"]})

@classmethod
def teardown_class(cls):
Expand Down Expand Up @@ -1296,6 +1308,12 @@ def test_query_with_missing_expression_attribute():

@pytest.mark.aws_verified
class TestReturnValuesOnConditionCheckFailure(BaseTest):
def setup_method(self):
super().setup_method()
self.table.put_item(
Item={"pk": "the-key", "subject": "123", "body": "some test msg"}
)

def test_put_item_does_not_return_old_item(self):
with pytest.raises(ClientError) as exc:
self.table.put_item(
Expand Down
125 changes: 125 additions & 0 deletions tests/test_dynamodb/test_dynamodb_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from botocore.exceptions import ClientError

from moto import mock_aws
from tests.test_dynamodb.exceptions.test_dynamodb_exceptions import BaseTest

from . import dynamodb_aws_verified

Expand Down Expand Up @@ -729,3 +730,127 @@ def test_scan_with_scanfilter(self):
"Items"
]
assert items == [{"partitionKey": "pk-1"}]


@pytest.mark.aws_verified
class TestParallelScan(BaseTest):
    """
    Verifies scan() with the Segment/TotalSegments parameters.

    The table is created with both a hash key (pk) and a range key (rk), so
    that several items can share one hash key.
    """

    @classmethod
    def setup_class(cls, add_range=True):
        # A classmethod (like the base implementation) rather than a
        # staticmethod that happens to receive the class positionally.
        # The signature mirrors BaseTest.setup_class; only the default differs,
        # as these tests need a range key.
        super().setup_class(add_range=add_range)

    def _scan_every_segment(self, total_segments=3, **scan_kwargs):
        """Scan each segment once; return the raw responses ordered by segment."""
        return [
            self.table.scan(
                Segment=segment, TotalSegments=total_segments, **scan_kwargs
            )
            for segment in range(total_segments)
        ]

    def test_segment_only(self):
        """Segment without TotalSegments is rejected."""
        with pytest.raises(ClientError) as exc:
            self.table.scan(Segment=1)
        err = exc.value.response["Error"]
        assert err["Code"] == "ValidationException"
        assert (
            err["Message"]
            == "The TotalSegments parameter is required but was not present in the request when Segment parameter is present"
        )

    def test_total_segments_only(self):
        """TotalSegments without Segment is rejected."""
        with pytest.raises(ClientError) as exc:
            self.table.scan(TotalSegments=1)
        err = exc.value.response["Error"]
        assert err["Code"] == "ValidationException"
        assert (
            err["Message"]
            == "The Segment parameter is required but was not present in the request when parameter TotalSegments is present"
        )

    def test_parallelize_all_different_hash_keys(self):
        """Ten items with ten distinct hash keys: the segments together hold all of them."""
        for i in range(10):
            self.table.put_item(Item={"pk": f"item{i}", "rk": "sth"})

        responses = self._scan_every_segment()

        assert sum(len(resp["Items"]) for resp in responses) == 10

    def test_parallelize_different_hash_key_per_segment(self):
        """Three hash keys with four range keys each: the segments together hold all twelve."""
        for i in range(3):
            for j in range(4):
                self.table.put_item(Item={"pk": f"item{i}", "rk": f"rk{j}"})

        responses = self._scan_every_segment()

        assert sum(len(resp["Items"]) for resp in responses) == 12

    def test_scan_using_filter_expression(self):
        """Filtering on one hash key still yields every matching item across the segments."""
        # All matching items share a single hash key.
        # AWS seems to return all of that data from one single segment.
        for i in range(10):
            self.table.put_item(Item={"pk": "item", "rk": f"range{i}"})
        for i in range(10):
            self.table.put_item(Item={"pk": "n/a", "rk": f"range{i}"})
        for i in range(20, 10, -1):
            self.table.put_item(Item={"pk": "item", "rk": f"range{i}"})

        responses = self._scan_every_segment(
            FilterExpression=Attr("pk").eq("item")
        )

        assert sum(len(resp["Items"]) for resp in responses) == 20

    def test_scan_single_hash_key(self):
        """Twenty items under one hash key: the segments together hold all of them."""
        # AWS seems to return all of this data from one single segment.
        for i in range(10):
            self.table.put_item(Item={"pk": "item", "rk": f"range{i}"})
        for i in range(20, 10, -1):
            self.table.put_item(Item={"pk": "item", "rk": f"range{i}"})

        responses = self._scan_every_segment()

        assert sum(len(resp["Items"]) for resp in responses) == 20

    def test_pagination(self):
        """A limited first page per segment plus one follow-up page returns every item."""
        for i in range(50):
            self.table.put_item(Item={"pk": "item", "rk": f"range{i}"})

        first_pages = self._scan_every_segment(Limit=10)

        first_pass = sum(len(page["Items"]) for page in first_pages)
        assert first_pass <= 30

        # Continue every segment that reported more data; no Limit this time,
        # so the remainder of that segment is expected in a single response.
        second_pass = 0
        for segment, page in enumerate(first_pages):
            if "LastEvaluatedKey" in page:
                resp = self.table.scan(
                    Segment=segment,
                    TotalSegments=3,
                    ExclusiveStartKey=page["LastEvaluatedKey"],
                )
                second_pass += len(resp["Items"])

        assert first_pass + second_pass == 50

    def test_segment_larger_than_total_segments(self):
        """Segment is zero-based, so Segment == TotalSegments is already out of range."""
        with pytest.raises(ClientError) as exc:
            self.table.scan(Segment=3, TotalSegments=3)
        err = exc.value.response["Error"]
        assert err["Code"] == "ValidationException"
        assert (
            err["Message"]
            == "The Segment parameter is zero-based and must be less than parameter TotalSegments: Segment: 3 is not less than TotalSegments: 3"
        )

0 comments on commit f1e48a9

Please sign in to comment.