From b195554dcdcdc85997ee0e2c10330c2646c4a02d Mon Sep 17 00:00:00 2001
From: archinksagar <68829863+archinksagar@users.noreply.github.com>
Date: Sat, 14 Sep 2024 06:36:31 -0400
Subject: [PATCH] feat: Implementation of memorydb api (#8108)
---
IMPLEMENTATION_COVERAGE.md | 191 +++-----
docs/docs/services/memorydb.rst | 57 +++
moto/backend_index.py | 1 +
moto/backends.py | 4 +
moto/memorydb/__init__.py | 1 +
moto/memorydb/exceptions.py | 59 +++
moto/memorydb/models.py | 664 +++++++++++++++++++++++++++
moto/memorydb/responses.py | 229 +++++++++
moto/memorydb/urls.py | 11 +
moto/moto_server/werkzeug_app.py | 2 +
tests/test_memorydb/__init__.py | 0
tests/test_memorydb/test_memorydb.py | 661 ++++++++++++++++++++++++++
12 files changed, 1750 insertions(+), 130 deletions(-)
create mode 100644 docs/docs/services/memorydb.rst
create mode 100644 moto/memorydb/__init__.py
create mode 100644 moto/memorydb/exceptions.py
create mode 100644 moto/memorydb/models.py
create mode 100644 moto/memorydb/responses.py
create mode 100644 moto/memorydb/urls.py
create mode 100644 tests/test_memorydb/__init__.py
create mode 100644 tests/test_memorydb/test_memorydb.py
diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md
index 8b0aedcf9fa1..d5b35d29101c 100644
--- a/IMPLEMENTATION_COVERAGE.md
+++ b/IMPLEMENTATION_COVERAGE.md
@@ -296,7 +296,7 @@
## appconfig
-33% implemented
+34% implemented
- [X] create_application
- [X] create_configuration_profile
@@ -312,7 +312,6 @@
- [ ] delete_extension
- [ ] delete_extension_association
- [X] delete_hosted_configuration_version
-- [ ] get_account_settings
- [X] get_application
- [ ] get_configuration
- [X] get_configuration_profile
@@ -335,7 +334,6 @@
- [ ] stop_deployment
- [X] tag_resource
- [X] untag_resource
-- [ ] update_account_settings
- [X] update_application
- [X] update_configuration_profile
- [ ] update_deployment_strategy
@@ -364,50 +362,6 @@
- [ ] untag_resource
-## appmesh
-
-57% implemented
-
-- [ ] create_gateway_route
-- [X] create_mesh
-- [X] create_route
-- [ ] create_virtual_gateway
-- [X] create_virtual_node
-- [X] create_virtual_router
-- [ ] create_virtual_service
-- [ ] delete_gateway_route
-- [X] delete_mesh
-- [X] delete_route
-- [ ] delete_virtual_gateway
-- [X] delete_virtual_node
-- [X] delete_virtual_router
-- [ ] delete_virtual_service
-- [ ] describe_gateway_route
-- [X] describe_mesh
-- [X] describe_route
-- [ ] describe_virtual_gateway
-- [X] describe_virtual_node
-- [X] describe_virtual_router
-- [ ] describe_virtual_service
-- [ ] list_gateway_routes
-- [X] list_meshes
-- [X] list_routes
-- [X] list_tags_for_resource
-- [ ] list_virtual_gateways
-- [X] list_virtual_nodes
-- [X] list_virtual_routers
-- [ ] list_virtual_services
-- [X] tag_resource
-- [ ] untag_resource
-- [ ] update_gateway_route
-- [X] update_mesh
-- [X] update_route
-- [ ] update_virtual_gateway
-- [X] update_virtual_node
-- [X] update_virtual_router
-- [ ] update_virtual_service
-
-
## appsync
23% implemented
@@ -753,50 +707,37 @@
## bedrock
-28% implemented
+39% implemented
-- [ ] batch_delete_evaluation_job
- [ ] create_evaluation_job
- [ ] create_guardrail
- [ ] create_guardrail_version
- [ ] create_model_copy_job
- [X] create_model_customization_job
-- [ ] create_model_import_job
-- [ ] create_model_invocation_job
- [ ] create_provisioned_model_throughput
- [X] delete_custom_model
- [ ] delete_guardrail
-- [ ] delete_imported_model
- [X] delete_model_invocation_logging_configuration
- [ ] delete_provisioned_model_throughput
- [X] get_custom_model
- [ ] get_evaluation_job
- [ ] get_foundation_model
- [ ] get_guardrail
-- [ ] get_imported_model
-- [ ] get_inference_profile
- [ ] get_model_copy_job
- [X] get_model_customization_job
-- [ ] get_model_import_job
-- [ ] get_model_invocation_job
- [X] get_model_invocation_logging_configuration
- [ ] get_provisioned_model_throughput
- [X] list_custom_models
- [ ] list_evaluation_jobs
- [ ] list_foundation_models
- [ ] list_guardrails
-- [ ] list_imported_models
-- [ ] list_inference_profiles
- [ ] list_model_copy_jobs
- [X] list_model_customization_jobs
-- [ ] list_model_import_jobs
-- [ ] list_model_invocation_jobs
- [ ] list_provisioned_model_throughputs
- [X] list_tags_for_resource
- [X] put_model_invocation_logging_configuration
- [ ] stop_evaluation_job
- [X] stop_model_customization_job
-- [ ] stop_model_invocation_job
- [X] tag_resource
- [X] untag_resource
- [ ] update_guardrail
@@ -3328,7 +3269,7 @@
## elbv2
-64% implemented
+67% implemented
- [X] add_listener_certificates
- [X] add_tags
@@ -3346,7 +3287,6 @@
- [ ] delete_trust_store
- [X] deregister_targets
- [ ] describe_account_limits
-- [ ] describe_listener_attributes
- [X] describe_listener_certificates
- [X] describe_listeners
- [X] describe_load_balancer_attributes
@@ -3364,7 +3304,6 @@
- [ ] get_trust_store_ca_certificates_bundle
- [ ] get_trust_store_revocation_content
- [X] modify_listener
-- [ ] modify_listener_attributes
- [X] modify_load_balancer_attributes
- [X] modify_rule
- [X] modify_target_group
@@ -4976,7 +4915,7 @@
## lambda
-63% implemented
+65% implemented
- [ ] add_layer_version_permission
- [X] add_permission
@@ -5004,7 +4943,6 @@
- [X] get_function_concurrency
- [ ] get_function_configuration
- [X] get_function_event_invoke_config
-- [ ] get_function_recursion_config
- [X] get_function_url_config
- [X] get_layer_version
- [ ] get_layer_version_by_arn
@@ -5032,7 +4970,6 @@
- [ ] put_function_code_signing_config
- [X] put_function_concurrency
- [X] put_function_event_invoke_config
-- [ ] put_function_recursion_config
- [ ] put_provisioned_concurrency_config
- [ ] put_runtime_management_config
- [ ] remove_layer_version_permission
@@ -5053,7 +4990,7 @@
47% implemented
- [ ] associate_kms_key
-- [X] cancel_export_task
+- [ ] cancel_export_task
- [ ] create_delivery
- [X] create_export_task
- [ ] create_log_anomaly_detector
@@ -5075,7 +5012,6 @@
- [X] delete_retention_policy
- [X] delete_subscription_filter
- [ ] describe_account_policies
-- [ ] describe_configuration_templates
- [ ] describe_deliveries
- [ ] describe_delivery_destinations
- [ ] describe_delivery_sources
@@ -5126,7 +5062,6 @@
- [X] untag_log_group
- [X] untag_resource
- [ ] update_anomaly
-- [ ] update_delivery_configuration
- [ ] update_log_anomaly_detector
@@ -5165,7 +5100,7 @@
## mediaconnect
-34% implemented
+35% implemented
- [ ] add_bridge_outputs
- [ ] add_bridge_sources
@@ -5183,7 +5118,6 @@
- [ ] describe_bridge
- [X] describe_flow
- [ ] describe_flow_source_metadata
-- [ ] describe_flow_source_thumbnail
- [ ] describe_gateway
- [ ] describe_gateway_instance
- [ ] describe_offering
@@ -5382,6 +5316,50 @@
- [X] put_object
+## memorydb
+
+34% implemented
+
+- [ ] batch_update_cluster
+- [ ] copy_snapshot
+- [ ] create_acl
+- [X] create_cluster
+- [ ] create_parameter_group
+- [X] create_snapshot
+- [X] create_subnet_group
+- [ ] create_user
+- [ ] delete_acl
+- [X] delete_cluster
+- [ ] delete_parameter_group
+- [X] delete_snapshot
+- [X] delete_subnet_group
+- [ ] delete_user
+- [ ] describe_acls
+- [X] describe_clusters
+- [ ] describe_engine_versions
+- [ ] describe_events
+- [ ] describe_parameter_groups
+- [ ] describe_parameters
+- [ ] describe_reserved_nodes
+- [ ] describe_reserved_nodes_offerings
+- [ ] describe_service_updates
+- [X] describe_snapshots
+- [X] describe_subnet_groups
+- [ ] describe_users
+- [ ] failover_shard
+- [ ] list_allowed_node_type_updates
+- [X] list_tags
+- [ ] purchase_reserved_nodes_offering
+- [ ] reset_parameter_group
+- [X] tag_resource
+- [X] untag_resource
+- [ ] update_acl
+- [X] update_cluster
+- [ ] update_parameter_group
+- [ ] update_subnet_group
+- [ ] update_user
+
+
## meteringmarketplace
25% implemented
@@ -5954,7 +5932,6 @@
- [ ] update_dataset
- [ ] update_metric_attribution
- [ ] update_recommender
-- [ ] update_solution
## pinpoint
@@ -6100,32 +6077,6 @@
- [ ] synthesize_speech
-## qldb
-
-30% implemented
-
-- [ ] cancel_journal_kinesis_stream
-- [X] create_ledger
-- [X] delete_ledger
-- [ ] describe_journal_kinesis_stream
-- [ ] describe_journal_s3_export
-- [X] describe_ledger
-- [ ] export_journal_to_s3
-- [ ] get_block
-- [ ] get_digest
-- [ ] get_revision
-- [ ] list_journal_kinesis_streams_for_ledger
-- [ ] list_journal_s3_exports
-- [ ] list_journal_s3_exports_for_ledger
-- [ ] list_ledgers
-- [X] list_tags_for_resource
-- [ ] stream_journal_to_kinesis
-- [X] tag_resource
-- [ ] untag_resource
-- [X] update_ledger
-- [ ] update_ledger_permissions_mode
-
-
## quicksight
7% implemented
@@ -7124,7 +7075,7 @@
## s3
-71% implemented
+68% implemented
- [X] abort_multipart_upload
- [X] complete_multipart_upload
@@ -7149,7 +7100,7 @@
- [X] delete_object_tagging
- [X] delete_objects
- [X] delete_public_access_block
-- [X] get_bucket_accelerate_configuration
+- [ ] get_bucket_accelerate_configuration
- [X] get_bucket_acl
- [ ] get_bucket_analytics_configuration
- [X] get_bucket_cors
@@ -7188,7 +7139,7 @@
- [ ] list_bucket_metrics_configurations
- [X] list_buckets
- [ ] list_directory_buckets
-- [X] list_multipart_uploads
+- [ ] list_multipart_uploads
- [X] list_object_versions
- [X] list_objects
- [X] list_objects_v2
@@ -7212,7 +7163,7 @@
- [ ] put_bucket_request_payment
- [X] put_bucket_tagging
- [X] put_bucket_versioning
-- [X] put_bucket_website
+- [ ] put_bucket_website
- [X] put_object
- [X] put_object_acl
- [X] put_object_legal_hold
@@ -7298,7 +7249,6 @@
- [ ] list_access_grants_locations
- [ ] list_access_points
- [ ] list_access_points_for_object_lambda
-- [ ] list_caller_access_grants
- [ ] list_jobs
- [ ] list_multi_region_access_points
- [ ] list_regional_buckets
@@ -7983,7 +7933,7 @@
## shield
-25% implemented
+19% implemented
- [ ] associate_drt_log_bucket
- [ ] associate_drt_role
@@ -7991,7 +7941,7 @@
- [ ] associate_proactive_engagement_details
- [X] create_protection
- [ ] create_protection_group
-- [X] create_subscription
+- [ ] create_subscription
- [X] delete_protection
- [ ] delete_protection_group
- [ ] delete_subscription
@@ -8001,7 +7951,7 @@
- [ ] describe_emergency_contact_settings
- [X] describe_protection
- [ ] describe_protection_group
-- [X] describe_subscription
+- [ ] describe_subscription
- [ ] disable_application_layer_automatic_response
- [ ] disable_proactive_engagement
- [ ] disassociate_drt_log_bucket
@@ -8505,27 +8455,6 @@
- [ ] update_adapter
-## timestream-query
-
-40% implemented
-
-- [ ] cancel_query
-- [X] create_scheduled_query
-- [X] delete_scheduled_query
-- [ ] describe_account_settings
-- [X] describe_endpoints
-- [X] describe_scheduled_query
-- [ ] execute_scheduled_query
-- [ ] list_scheduled_queries
-- [ ] list_tags_for_resource
-- [ ] prepare_query
-- [X] query
-- [ ] tag_resource
-- [ ] untag_resource
-- [ ] update_account_settings
-- [X] update_scheduled_query
-
-
## timestream-write
78% implemented
@@ -8835,6 +8764,7 @@
- application-insights
- application-signals
- applicationcostprofiler
+- appmesh
- apprunner
- appstream
- apptest
@@ -8874,6 +8804,7 @@
- codeguru-reviewer
- codeguru-security
- codeguruprofiler
+- codestar
- codestar-connections
- codestar-notifications
- cognito-sync
@@ -8978,7 +8909,6 @@
- mediapackagev2
- mediatailor
- medical-imaging
-- memorydb
- mgh
- mgn
- migration-hub-refactor-spaces
@@ -9001,7 +8931,6 @@
- payment-cryptography-data
- pca-connector-ad
- pca-connector-scep
-- pcs
- personalize-events
- personalize-runtime
- pi
@@ -9015,6 +8944,7 @@
- qapps
- qbusiness
- qconnect
+- qldb
- qldb-session
- rbin
- redshift-serverless
@@ -9055,6 +8985,7 @@
- synthetics
- taxsettings
- timestream-influxdb
+- timestream-query
- tnb
- translate
- trustedadvisor
diff --git a/docs/docs/services/memorydb.rst b/docs/docs/services/memorydb.rst
new file mode 100644
index 000000000000..e36276f40571
--- /dev/null
+++ b/docs/docs/services/memorydb.rst
@@ -0,0 +1,57 @@
+.. _implementedservice_memorydb:
+
+.. |start-h3| raw:: html
+
+
+
+.. |end-h3| raw:: html
+
+
+
+========
+memorydb
+========
+
+.. autoclass:: moto.memorydb.models.MemoryDBBackend
+
+|start-h3| Implemented features for this service |end-h3|
+
+- [ ] batch_update_cluster
+- [ ] copy_snapshot
+- [ ] create_acl
+- [X] create_cluster
+- [ ] create_parameter_group
+- [X] create_snapshot
+- [X] create_subnet_group
+- [ ] create_user
+- [ ] delete_acl
+- [X] delete_cluster
+- [ ] delete_parameter_group
+- [X] delete_snapshot
+- [X] delete_subnet_group
+- [ ] delete_user
+- [ ] describe_acls
+- [X] describe_clusters
+- [ ] describe_engine_versions
+- [ ] describe_events
+- [ ] describe_parameter_groups
+- [ ] describe_parameters
+- [ ] describe_reserved_nodes
+- [ ] describe_reserved_nodes_offerings
+- [ ] describe_service_updates
+- [X] describe_snapshots
+- [X] describe_subnet_groups
+- [ ] describe_users
+- [ ] failover_shard
+- [ ] list_allowed_node_type_updates
+- [X] list_tags
+- [ ] purchase_reserved_nodes_offering
+- [ ] reset_parameter_group
+- [X] tag_resource
+- [X] untag_resource
+- [ ] update_acl
+- [X] update_cluster
+- [ ] update_parameter_group
+- [ ] update_subnet_group
+- [ ] update_user
+
diff --git a/moto/backend_index.py b/moto/backend_index.py
index a965b21fb111..46d10cae806c 100644
--- a/moto/backend_index.py
+++ b/moto/backend_index.py
@@ -116,6 +116,7 @@
("mediapackage", re.compile("https?://mediapackage\\.(.+)\\.amazonaws.com")),
("mediastore", re.compile("https?://mediastore\\.(.+)\\.amazonaws\\.com")),
("mediastoredata", re.compile("https?://data\\.mediastore\\.(.+)\\.amazonaws.com")),
+ ("memorydb", re.compile("https?://memory-db\\.(.+)\\.amazonaws\\.com")),
(
"meteringmarketplace",
re.compile("https?://metering.marketplace.(.+).amazonaws.com"),
diff --git a/moto/backends.py b/moto/backends.py
index 0ca6319f7e64..bb70dda5a250 100644
--- a/moto/backends.py
+++ b/moto/backends.py
@@ -92,6 +92,7 @@
from moto.mediapackage.models import MediaPackageBackend
from moto.mediastore.models import MediaStoreBackend
from moto.mediastoredata.models import MediaStoreDataBackend
+ from moto.memorydb.models import MemoryDBBackend
from moto.meteringmarketplace.models import MeteringMarketplaceBackend
from moto.moto_api._internal.models import MotoAPIBackend
from moto.mq.models import MQBackend
@@ -268,6 +269,7 @@ def get_service_from_url(url: str) -> Optional[str]:
"Literal['medialive']",
"Literal['mediapackage']",
"Literal['mediastore']",
+ "Literal['memorydb']",
"Literal['mediastore-data']",
"Literal['meteringmarketplace']",
"Literal['moto_api']",
@@ -564,6 +566,8 @@ def get_backend(
name: "Literal['mediastore-data']",
) -> "BackendDict[MediaStoreDataBackend]": ...
@overload
+def get_backend(name: "Literal['memorydb']") -> "BackendDict[MemoryDBBackend]": ...
+@overload
def get_backend(
name: "Literal['meteringmarketplace']",
) -> "BackendDict[MeteringMarketplaceBackend]": ...
diff --git a/moto/memorydb/__init__.py b/moto/memorydb/__init__.py
new file mode 100644
index 000000000000..01ab867bb9c2
--- /dev/null
+++ b/moto/memorydb/__init__.py
@@ -0,0 +1 @@
+from .models import memorydb_backends # noqa: F401
diff --git a/moto/memorydb/exceptions.py b/moto/memorydb/exceptions.py
new file mode 100644
index 000000000000..2d0c9cf4a86a
--- /dev/null
+++ b/moto/memorydb/exceptions.py
@@ -0,0 +1,59 @@
+"""Exceptions raised by the memorydb service."""
+
+from typing import List
+
+from moto.core.exceptions import JsonRESTError
+
+
+class MemoryDBClientError(JsonRESTError):
+ code = 400
+
+
+class ClusterAlreadyExistsFault(MemoryDBClientError):
+ def __init__(self, msg: str):
+ super().__init__("ClusterAlreadyExistsFault", msg)
+
+
+class InvalidSubnetError(MemoryDBClientError):
+ def __init__(self, subnet_identifier: List[str]):
+ super().__init__("InvalidSubnetError", f"Subnet {subnet_identifier} not found.")
+
+
+class SubnetGroupAlreadyExistsFault(MemoryDBClientError):
+ def __init__(self, msg: str):
+ super().__init__("SubnetGroupAlreadyExistsFault", msg)
+
+
+class ClusterNotFoundFault(MemoryDBClientError):
+ def __init__(self, msg: str):
+ super().__init__("ClusterNotFoundFault", msg)
+
+
+class SnapshotAlreadyExistsFault(MemoryDBClientError):
+ def __init__(self, msg: str):
+ super().__init__("SnapshotAlreadyExistsFault", msg)
+
+
+class SnapshotNotFoundFault(MemoryDBClientError):
+ def __init__(self, msg: str):
+ super().__init__("SnapshotNotFoundFault", msg)
+
+
+class SubnetGroupNotFoundFault(MemoryDBClientError):
+ def __init__(self, msg: str):
+ super().__init__("SubnetGroupNotFoundFault", msg)
+
+
+class TagNotFoundFault(MemoryDBClientError):
+ def __init__(self, msg: str):
+ super().__init__("TagNotFoundFault", msg)
+
+
+class InvalidParameterValueException(MemoryDBClientError):
+ def __init__(self, msg: str):
+ super().__init__("InvalidParameterValueException", msg)
+
+
+class SubnetGroupInUseFault(MemoryDBClientError):
+ def __init__(self, msg: str):
+ super().__init__("SubnetGroupInUseFault", msg)
diff --git a/moto/memorydb/models.py b/moto/memorydb/models.py
new file mode 100644
index 000000000000..d52b2f8502b6
--- /dev/null
+++ b/moto/memorydb/models.py
@@ -0,0 +1,664 @@
+"""MemoryDBBackend class with methods for supported APIs."""
+
+import copy
+import random
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+from moto.core.base_backend import BackendDict, BaseBackend
+from moto.core.common_models import BaseModel
+from moto.ec2 import ec2_backends
+from moto.utilities.tagging_service import TaggingService
+
+from .exceptions import (
+ ClusterAlreadyExistsFault,
+ ClusterNotFoundFault,
+ InvalidParameterValueException,
+ InvalidSubnetError,
+ SnapshotAlreadyExistsFault,
+ SnapshotNotFoundFault,
+ SubnetGroupAlreadyExistsFault,
+ SubnetGroupInUseFault,
+ SubnetGroupNotFoundFault,
+ TagNotFoundFault,
+)
+
+
+class MemoryDBCluster(BaseModel):
+ def __init__(
+ self,
+ cluster_name: str,
+ node_type: str,
+ parameter_group_name: str,
+ description: str,
+ num_shards: int,
+ num_replicas_per_shard: int,
+ subnet_group_name: str,
+ vpc_id: str,
+ maintenance_window: str,
+ port: int,
+ sns_topic_arn: str,
+ kms_key_id: str,
+ snapshot_arns: List[str],
+ snapshot_name: str,
+ snapshot_retention_limit: int,
+ snapshot_window: str,
+ acl_name: str,
+ engine_version: str,
+ region: str,
+ account_id: str,
+ security_group_ids: List[str],
+ auto_minor_version_upgrade: bool,
+ data_tiering: bool,
+ tls_enabled: bool,
+ ):
+ self.cluster_name = cluster_name
+ self.node_type = node_type
+ # Default is set to 'default.memorydb-redis7'.
+ self.parameter_group_name = parameter_group_name or "default.memorydb-redis7"
+ # Setting it to 'in-sync'; other options are 'active' or 'applying'.
+ self.parameter_group_status = "in-sync"
+ self.description = description
+ self.num_shards = num_shards or 1 # Default shards is set to 1
+ # Defaults to 1 (i.e. 2 nodes per shard).
+ self.num_replicas_per_shard = num_replicas_per_shard or 1
+ self.subnet_group_name = subnet_group_name
+ self.vpc_id = vpc_id
+ self.maintenance_window = maintenance_window or "wed:08:00-wed:09:00"
+ self.port = port or 6379 # Default is set to 6379
+ self.sns_topic_arn = sns_topic_arn
+ self.tls_enabled = tls_enabled or True
+ self.kms_key_id = kms_key_id
+ self.snapshot_arns = snapshot_arns
+ self.snapshot_name = snapshot_name
+ self.snapshot_retention_limit = snapshot_retention_limit or 0
+ self.snapshot_window = snapshot_window or "03:00-04:00"
+ # When tls_enabled is set to false, the acl_name must be open-access.
+ self.acl_name = "open-access" if not tls_enabled else acl_name
+ self.region = region
+ self.engine_version = engine_version
+ if engine_version == "7.0":
+ self.engine_patch_version = "7.0.7"
+ elif engine_version == "6.2":
+ self.engine_patch_version = "6.2.6"
+ else:
+ self.engine_version = "7.1" # Default is '7.1'.
+ self.engine_patch_version = "7.1.1"
+ self.auto_minor_version_upgrade = auto_minor_version_upgrade or True
+ self.data_tiering = "true" if data_tiering else "false"
+ # Real AWS starts clusters in 'creating'; moto sets 'available' immediately.
+ self.status = (
+ # Set to 'available', other options are 'creating', 'updating'.
+ "available"
+ )
+ self.pending_updates: Dict[Any, Any] = {} # TODO
+ self.shards = self.get_shard_details()
+
+ self.availability_mode = (
+ "SingleAZ" if self.num_replicas_per_shard == 0 else "MultiAZ"
+ )
+ self.cluster_endpoint = {
+ "Address": f"clustercfg.{self.cluster_name}.aoneci.memorydb.{region}.amazonaws.com",
+ "Port": self.port,
+ }
+ self.security_group_ids = security_group_ids or []
+ self.security_groups = []
+ for sg in self.security_group_ids:
+ security_group = {"SecurityGroupId": sg, "Status": "active"}
+ self.security_groups.append(security_group)
+ self.arn = f"arn:aws:memorydb:{region}:{account_id}:cluster/{self.cluster_name}"
+ self.sns_topic_status = "active" if self.sns_topic_arn else ""
+
+ def get_shard_details(self) -> List[Dict[str, Any]]:
+ shards = []
+ for i in range(self.num_shards):
+ shard_name = f"{i+1:04}"
+ num_nodes = self.num_replicas_per_shard + 1
+ nodes = []
+ azs = ["a", "b", "c", "d"]
+ for n in range(num_nodes):
+ node_name = f"{self.cluster_name}-{shard_name}-{n+1:03}"
+ node = {
+ "Name": node_name,
+ "Status": "available",
+ "AvailabilityZone": f"{self.region}{random.choice(azs)}",
+ "CreateTime": datetime.now().strftime(
+ "%Y-%m-%dT%H:%M:%S.000%f+0000"
+ ),
+ "Endpoint": {
+ "Address": f"{node_name}.{self.cluster_name}.aoneci.memorydb.{self.region}.amazonaws.com",
+ "Port": self.port,
+ },
+ }
+ nodes.append(node)
+
+ shard = {
+ "Name": shard_name,
+ # Set to 'available', other options are 'creating', 'modifying', 'deleting'.
+ "Status": "available",
+ "Slots": f"0-{str(random.randint(10000,99999))}",
+ "Nodes": nodes,
+ "NumberOfNodes": num_nodes,
+ }
+ shards.append(shard)
+ return shards
+
+ def update(
+ self,
+ description: Optional[str],
+ security_group_ids: Optional[List[str]],
+ maintenance_window: Optional[str],
+ sns_topic_arn: Optional[str],
+ sns_topic_status: Optional[str],
+ parameter_group_name: Optional[str],
+ snapshot_window: Optional[str],
+ snapshot_retention_limit: Optional[int],
+ node_type: Optional[str],
+ engine_version: Optional[str],
+ replica_configuration: Optional[Dict[str, int]],
+ shard_configuration: Optional[Dict[str, int]],
+ acl_name: Optional[str],
+ ) -> None:
+ if description is not None:
+ self.description = description
+ if security_group_ids is not None:
+ self.security_group_ids = security_group_ids
+ if maintenance_window is not None:
+ self.maintenance_window = maintenance_window
+ if sns_topic_arn is not None:
+ self.sns_topic_arn = sns_topic_arn
+ if sns_topic_status is not None:
+ self.sns_topic_status = sns_topic_status
+ if parameter_group_name is not None:
+ self.parameter_group_name = parameter_group_name
+ if snapshot_window is not None:
+ self.snapshot_window = snapshot_window
+ if snapshot_retention_limit is not None:
+ self.snapshot_retention_limit = snapshot_retention_limit
+ if node_type is not None:
+ self.node_type = node_type
+ if engine_version is not None:
+ self.engine_version = engine_version
+ if replica_configuration is not None:
+ self.num_replicas_per_shard = replica_configuration["ReplicaCount"]
+ self.shards = self.get_shard_details() # update shards and nodes
+ if shard_configuration is not None:
+ self.num_shards = shard_configuration["ShardCount"]
+ self.shards = self.get_shard_details() # update shards and nodes
+ if acl_name is not None:
+ self.acl_name = acl_name
+
+ def to_dict(self) -> Dict[str, Any]:
+ dct = {
+ "Name": self.cluster_name,
+ "Description": self.description,
+ "Status": self.status,
+ "PendingUpdates": self.pending_updates,
+ "NumberOfShards": self.num_shards,
+ "AvailabilityMode": self.availability_mode,
+ "ClusterEndpoint": self.cluster_endpoint,
+ "NodeType": self.node_type,
+ "EngineVersion": self.engine_version,
+ "EnginePatchVersion": self.engine_patch_version,
+ "ParameterGroupName": self.parameter_group_name,
+ "ParameterGroupStatus": self.parameter_group_status,
+ "SecurityGroups": self.security_groups,
+ "SubnetGroupName": self.subnet_group_name,
+ "TLSEnabled": self.tls_enabled,
+ "KmsKeyId": self.kms_key_id,
+ "ARN": self.arn,
+ "SnsTopicArn": self.sns_topic_arn,
+ "SnsTopicStatus": self.sns_topic_status,
+ "MaintenanceWindow": self.maintenance_window,
+ "SnapshotWindow": self.snapshot_window,
+ "ACLName": self.acl_name,
+ "AutoMinorVersionUpgrade": self.auto_minor_version_upgrade,
+ "DataTiering": self.data_tiering,
+ }
+ dct_items = {k: v for k, v in dct.items() if v}
+ dct_items["SnapshotRetentionLimit"] = self.snapshot_retention_limit
+ return dct_items
+
+ def to_desc_dict(self) -> Dict[str, Any]:
+ dct = self.to_dict()
+ dct["Shards"] = self.shards
+ return dct
+
+
+class MemoryDBSubnetGroup(BaseModel):
+ def __init__(
+ self,
+ region_name: str,
+ account_id: str,
+ ec2_backend: Any,
+ subnet_group_name: str,
+ description: str,
+ subnet_ids: List[str],
+ tags: Optional[List[Dict[str, str]]] = None,
+ ):
+ self.ec2_backend = ec2_backend
+ self.subnet_group_name = subnet_group_name
+ self.description = description
+ self.subnet_ids = subnet_ids
+ if not self.subnets:
+ raise InvalidSubnetError(subnet_ids)
+ self.arn = f"arn:aws:memorydb:{region_name}:{account_id}:subnetgroup/{subnet_group_name}"
+
+ @property
+ def subnets(self) -> Any: # type: ignore[misc]
+ return self.ec2_backend.describe_subnets(filters={"subnet-id": self.subnet_ids})
+
+ @property
+ def vpc_id(self) -> str:
+ return self.subnets[0].vpc_id
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "Name": self.subnet_group_name,
+ "Description": self.description,
+ "VpcId": self.vpc_id,
+ "Subnets": [
+ {
+ "Identifier": subnet.id,
+ "AvailabilityZone": {"Name": subnet.availability_zone},
+ }
+ for subnet in self.subnets
+ ],
+ "ARN": self.arn,
+ }
+
+
+class MemoryDBSnapshot(BaseModel):
+ def __init__(
+ self,
+ account_id: str,
+ region_name: str,
+ cluster: MemoryDBCluster,
+ snapshot_name: str,
+ kms_key_id: Optional[str],
+ tags: Optional[List[Dict[str, str]]],
+ source: Optional[str],
+ ):
+ self.cluster = copy.copy(cluster)
+ self.cluster_name = self.cluster.cluster_name
+ self.snapshot_name = snapshot_name
+ self.status = "available"
+ self.source = source
+ self.kms_key_id = kms_key_id if kms_key_id else cluster.kms_key_id
+ self.arn = (
+ f"arn:aws:memorydb:{region_name}:{account_id}:snapshot/{snapshot_name}"
+ )
+ self.vpc_id = self.cluster.vpc_id
+ self.shards = []
+ for i in self.cluster.shards:
+ shard = {
+ "Name": i["Name"],
+ "Configuration": {
+ "Slots": i["Slots"],
+ "ReplicaCount": self.cluster.num_replicas_per_shard,
+ },
+ "Size": "11 MB",
+ "SnapshotCreationTime": datetime.now().strftime(
+ "%Y-%m-%dT%H:%M:%S.000%f+0000"
+ ),
+ }
+ self.shards.append(shard)
+
+ def to_dict(self) -> Dict[str, Any]:
+ dct = {
+ "Name": self.snapshot_name,
+ "Status": self.status,
+ "Source": self.source,
+ "KmsKeyId": self.kms_key_id,
+ "ARN": self.arn,
+ "ClusterConfiguration": {
+ "Name": self.cluster_name,
+ "Description": self.cluster.description,
+ "NodeType": self.cluster.node_type,
+ "EngineVersion": self.cluster.engine_version,
+ "MaintenanceWindow": self.cluster.maintenance_window,
+ "TopicArn": self.cluster.sns_topic_arn,
+ "Port": self.cluster.port,
+ "ParameterGroupName": self.cluster.parameter_group_name,
+ "SubnetGroupName": self.cluster.subnet_group_name,
+ "VpcId": self.vpc_id,
+ "SnapshotRetentionLimit": self.cluster.snapshot_retention_limit,
+ "SnapshotWindow": self.cluster.snapshot_window,
+ "NumShards": self.cluster.num_shards,
+ },
+ "DataTiering": self.cluster.data_tiering,
+ }
+ return {k: v for k, v in dct.items() if v}
+
+ def to_desc_dict(self) -> Dict[str, Any]:
+ dct = self.to_dict()
+ dct["ClusterConfiguration"]["Shards"] = self.shards
+ return dct
+
+
+class MemoryDBBackend(BaseBackend):
+ """Implementation of MemoryDB APIs."""
+
+ def __init__(self, region_name: str, account_id: str):
+ super().__init__(region_name, account_id)
+
+ self.ec2_backend = ec2_backends[account_id][region_name]
+ self.clusters: Dict[str, MemoryDBCluster] = dict()
+ self.subnet_groups: Dict[str, MemoryDBSubnetGroup] = {
+ "default": MemoryDBSubnetGroup(
+ region_name,
+ account_id,
+ self.ec2_backend,
+ "default",
+ "Default MemoryDB Subnet Group",
+ self.get_default_subnets(),
+ )
+ }
+ self.snapshots: Dict[str, MemoryDBSnapshot] = dict()
+ self.tagger = TaggingService()
+
+ def get_default_subnets(self) -> List[str]:
+ default_subnets = self.ec2_backend.describe_subnets(
+ filters={"default-for-az": "true"}
+ )
+ default_subnet_ids = [i.id for i in default_subnets]
+ return default_subnet_ids
+
+ def _list_arns(self) -> List[str]:
+ return [cluster.arn for cluster in self.clusters.values()]
+
+ def create_cluster(
+ self,
+ cluster_name: str,
+ node_type: str,
+ parameter_group_name: str,
+ description: str,
+ subnet_group_name: str,
+ security_group_ids: List[str],
+ maintenance_window: str,
+ port: int,
+ sns_topic_arn: str,
+ tls_enabled: bool,
+ kms_key_id: str,
+ snapshot_arns: List[str],
+ snapshot_name: str,
+ snapshot_retention_limit: int,
+ tags: List[Dict[str, str]],
+ snapshot_window: str,
+ acl_name: str,
+ engine_version: str,
+ auto_minor_version_upgrade: bool,
+ data_tiering: bool,
+ num_shards: int,
+ num_replicas_per_shard: int,
+ ) -> MemoryDBCluster:
+ if cluster_name in self.clusters:
+ raise ClusterAlreadyExistsFault(
+ msg="Cluster with specified name already exists."
+ )
+
+ subnet_group_name = subnet_group_name or "default"
+ subnet_group = self.subnet_groups[subnet_group_name]
+ vpc_id = subnet_group.vpc_id
+ cluster = MemoryDBCluster(
+ cluster_name=cluster_name,
+ node_type=node_type,
+ parameter_group_name=parameter_group_name,
+ description=description,
+ num_shards=num_shards,
+ num_replicas_per_shard=num_replicas_per_shard,
+ subnet_group_name=subnet_group_name,
+ vpc_id=vpc_id,
+ security_group_ids=security_group_ids,
+ maintenance_window=maintenance_window,
+ port=port,
+ sns_topic_arn=sns_topic_arn,
+ tls_enabled=tls_enabled,
+ kms_key_id=kms_key_id,
+ snapshot_arns=snapshot_arns,
+ snapshot_name=snapshot_name,
+ snapshot_retention_limit=snapshot_retention_limit,
+ snapshot_window=snapshot_window,
+ acl_name=acl_name,
+ engine_version=engine_version,
+ auto_minor_version_upgrade=auto_minor_version_upgrade,
+ data_tiering=data_tiering,
+ region=self.region_name,
+ account_id=self.account_id,
+ )
+ self.clusters[cluster.cluster_name] = cluster
+ self.tag_resource(cluster.arn, tags)
+ return cluster
+
+ def create_subnet_group(
+ self,
+ subnet_group_name: str,
+ description: str,
+ subnet_ids: List[str],
+ tags: Optional[List[Dict[str, str]]] = None,
+ ) -> MemoryDBSubnetGroup:
+ if subnet_group_name in self.subnet_groups:
+ raise SubnetGroupAlreadyExistsFault(
+ msg=f"Subnet group {subnet_group_name} already exists."
+ )
+ subnet_group = MemoryDBSubnetGroup(
+ self.region_name,
+ self.account_id,
+ self.ec2_backend,
+ subnet_group_name,
+ description,
+ subnet_ids,
+ tags,
+ )
+ self.subnet_groups[subnet_group_name] = subnet_group
+ return subnet_group
+
+ def create_snapshot(
+ self,
+ cluster_name: str,
+ snapshot_name: str,
+ kms_key_id: Optional[str] = None,
+ tags: Optional[List[Dict[str, str]]] = None,
+ source: str = "manual",
+ ) -> MemoryDBSnapshot:
+ if cluster_name not in self.clusters:
+ raise ClusterNotFoundFault(msg=f"Cluster not found: {cluster_name}")
+ cluster = self.clusters[cluster_name]
+ if snapshot_name in self.snapshots:
+ raise SnapshotAlreadyExistsFault(
+ msg="Snapshot with specified name already exists."
+ )
+
+ snapshot = MemoryDBSnapshot(
+ account_id=self.account_id,
+ region_name=self.region_name,
+ cluster=cluster,
+ snapshot_name=snapshot_name,
+ kms_key_id=kms_key_id,
+ tags=tags,
+ source=source,
+ )
+ self.snapshots[snapshot_name] = snapshot
+ return snapshot
+
+ def describe_clusters(
+ self, cluster_name: Optional[str] = None
+ ) -> List[MemoryDBCluster]:
+ if cluster_name:
+ if cluster_name in self.clusters:
+ cluster = self.clusters[cluster_name]
+ return list([cluster])
+ else:
+ raise ClusterNotFoundFault(msg=f"Cluster {cluster_name} not found")
+ clusters = list(self.clusters.values())
+ return clusters
+
    def describe_snapshots(
        self,
        cluster_name: Optional[str] = None,
        snapshot_name: Optional[str] = None,
        source: Optional[str] = None,
    ) -> List[MemoryDBSnapshot]:
        """Return snapshots filtered by cluster name, snapshot name and source.

        Matching AWS behavior demonstrated by the tests: an unknown
        cluster_name alone yields an empty list, while an unknown
        snapshot_name raises SnapshotNotFoundFault.

        :param source: "automated" or "manual"; None matches both.
        :raises SnapshotNotFoundFault: snapshot_name given but no match.
        """
        # No source filter means both automated and manual snapshots match.
        sources = ["automated", "manual"] if source is None else [source]

        # Both filters given: return the single matching snapshot or raise.
        if cluster_name and snapshot_name:
            for snapshot in list(self.snapshots.values()):
                if (
                    snapshot.cluster_name == cluster_name
                    and snapshot.snapshot_name == snapshot_name
                    and snapshot.source in sources
                ):
                    return [snapshot]
            raise SnapshotNotFoundFault(
                msg=f"Snapshot with name {snapshot_name} not found"
            )

        # Cluster filter only: empty result is acceptable (no raise).
        if cluster_name:
            snapshots = [
                snapshot
                for snapshot in self.snapshots.values()
                if (snapshot.cluster_name == cluster_name)
                and (snapshot.source in sources)
            ]
            return snapshots

        # Snapshot-name filter only: a miss is an error.
        if snapshot_name:
            snapshots = [
                snapshot
                for snapshot in self.snapshots.values()
                if (snapshot.snapshot_name == snapshot_name)
                and (snapshot.source in sources)
            ]
            if snapshots:
                return snapshots
            raise SnapshotNotFoundFault(
                msg=f"Snapshot with name {snapshot_name} not found"
            )

        # No name filters: everything matching the source filter.
        snapshots = [
            snapshot
            for snapshot in self.snapshots.values()
            if snapshot.source in sources
        ]
        return snapshots
+
+ def describe_subnet_groups(
+ self, subnet_group_name: str
+ ) -> List[MemoryDBSubnetGroup]:
+ if subnet_group_name:
+ if subnet_group_name in self.subnet_groups:
+ return list([self.subnet_groups[subnet_group_name]])
+ raise SubnetGroupNotFoundFault(
+ msg=f"Subnet group {subnet_group_name} not found."
+ )
+
+ subnet_groups = list(self.subnet_groups.values())
+ return subnet_groups
+
+ def list_tags(self, resource_arn: str) -> List[Dict[str, str]]:
+ if resource_arn not in self._list_arns():
+ cluster_name = resource_arn.split("/")[-1]
+ raise ClusterNotFoundFault(f"{cluster_name} is not present")
+ return self.tagger.list_tags_for_resource(arn=resource_arn)["Tags"]
+
+ def tag_resource(
+ self, resource_arn: str, tags: List[Dict[str, str]]
+ ) -> List[Dict[str, str]]:
+ if resource_arn not in self._list_arns():
+ cluster_name = resource_arn.split("/")[-1]
+ raise ClusterNotFoundFault(f"{cluster_name} is not present")
+ self.tagger.tag_resource(resource_arn, tags)
+ return self.tagger.list_tags_for_resource(arn=resource_arn)["Tags"]
+
+ def untag_resource(
+ self, resource_arn: str, tag_keys: List[str]
+ ) -> List[Dict[str, str]]:
+ if resource_arn not in self._list_arns():
+ cluster_name = resource_arn.split("/")[-1]
+ raise ClusterNotFoundFault(f"{cluster_name} is not present")
+ list_tags = self.list_tags(resource_arn=resource_arn)
+ list_keys = [i["Key"] for i in list_tags]
+ invalid_keys = [key for key in tag_keys if key not in list_keys]
+ if invalid_keys:
+ raise TagNotFoundFault(msg=f"These tags are not present : {[invalid_keys]}")
+ self.tagger.untag_resource_using_names(resource_arn, tag_keys)
+ return self.tagger.list_tags_for_resource(arn=resource_arn)["Tags"]
+
    def update_cluster(
        self,
        cluster_name: str,
        description: Optional[str],
        security_group_ids: Optional[List[str]],
        maintenance_window: Optional[str],
        sns_topic_arn: Optional[str],
        sns_topic_status: Optional[str],
        parameter_group_name: Optional[str],
        snapshot_window: Optional[str],
        snapshot_retention_limit: Optional[int],
        node_type: Optional[str],
        engine_version: Optional[str],
        replica_configuration: Optional[Dict[str, int]],
        shard_configuration: Optional[Dict[str, int]],
        acl_name: Optional[str],
    ) -> MemoryDBCluster:
        """Apply in-place updates to an existing cluster.

        All parameters except cluster_name are optional; None means
        "leave unchanged" (interpretation delegated to cluster.update).

        :raises ClusterNotFoundFault: unknown cluster name.
        """
        if cluster_name in self.clusters:
            cluster = self.clusters[cluster_name]
            # Positional pass-through: argument order must match
            # MemoryDBCluster.update exactly.
            cluster.update(
                description,
                security_group_ids,
                maintenance_window,
                sns_topic_arn,
                sns_topic_status,
                parameter_group_name,
                snapshot_window,
                snapshot_retention_limit,
                node_type,
                engine_version,
                replica_configuration,
                shard_configuration,
                acl_name,
            )
            return cluster
        raise ClusterNotFoundFault(msg="Cluster not found.")
+
+ def delete_cluster(
+ self, cluster_name: str, final_snapshot_name: Optional[str]
+ ) -> MemoryDBCluster:
+ if cluster_name in self.clusters:
+ cluster = self.clusters[cluster_name]
+ cluster.status = "deleting"
+ if final_snapshot_name is not None: # create snapshot
+ self.create_snapshot(
+ cluster_name=cluster_name,
+ snapshot_name=final_snapshot_name,
+ source="manual",
+ )
+ return self.clusters.pop(cluster_name)
+ raise ClusterNotFoundFault(cluster_name)
+
+ def delete_snapshot(self, snapshot_name: str) -> MemoryDBSnapshot:
+ if snapshot_name in self.snapshots:
+ snapshot = self.snapshots[snapshot_name]
+ snapshot.status = "deleting"
+ return self.snapshots.pop(snapshot_name)
+ raise SnapshotNotFoundFault(snapshot_name)
+
+ def delete_subnet_group(self, subnet_group_name: str) -> MemoryDBSubnetGroup:
+ if subnet_group_name in self.subnet_groups:
+ if subnet_group_name == "default":
+ raise InvalidParameterValueException(
+ msg="default is reserved and cannot be modified."
+ )
+ if subnet_group_name in [
+ c.subnet_group_name for c in self.clusters.values()
+ ]:
+ raise SubnetGroupInUseFault(
+ msg=f"Subnet group {subnet_group_name} is currently in use by a cluster."
+ )
+ return self.subnet_groups.pop(subnet_group_name)
+ raise SubnetGroupNotFoundFault(
+ msg=f"Subnet group {subnet_group_name} not found."
+ )
+
+
+memorydb_backends = BackendDict(MemoryDBBackend, "memorydb")
diff --git a/moto/memorydb/responses.py b/moto/memorydb/responses.py
new file mode 100644
index 000000000000..4f4674255134
--- /dev/null
+++ b/moto/memorydb/responses.py
@@ -0,0 +1,229 @@
+"""Handles incoming memorydb requests, invokes methods, returns responses."""
+
+import json
+
+from moto.core.responses import BaseResponse
+
+from .models import MemoryDBBackend, memorydb_backends
+
+
class MemoryDBResponse(BaseResponse):
    """Handler for MemoryDB requests and responses.

    Each action method parses the JSON request body, delegates to the
    region-scoped backend, and serializes the result back to JSON in the
    shape botocore expects.
    """

    def __init__(self) -> None:
        super().__init__(service_name="memorydb")

    @property
    def memorydb_backend(self) -> MemoryDBBackend:
        """Return backend instance specific for this region."""
        return memorydb_backends[self.current_account][self.region]

    def create_cluster(self) -> str:
        """Handle the CreateCluster action."""
        params = json.loads(self.body)
        cluster_name = params.get("ClusterName")
        node_type = params.get("NodeType")
        parameter_group_name = params.get("ParameterGroupName")
        description = params.get("Description")
        num_shards = params.get("NumShards")
        num_replicas_per_shard = params.get("NumReplicasPerShard")
        subnet_group_name = params.get("SubnetGroupName")
        security_group_ids = params.get("SecurityGroupIds")
        maintenance_window = params.get("MaintenanceWindow")
        port = params.get("Port")
        sns_topic_arn = params.get("SnsTopicArn")
        tls_enabled = params.get("TLSEnabled")
        kms_key_id = params.get("KmsKeyId")
        snapshot_arns = params.get("SnapshotArns")
        snapshot_name = params.get("SnapshotName")
        snapshot_retention_limit = params.get("SnapshotRetentionLimit")
        tags = params.get("Tags")
        snapshot_window = params.get("SnapshotWindow")
        acl_name = params.get("ACLName")
        engine_version = params.get("EngineVersion")
        auto_minor_version_upgrade = params.get("AutoMinorVersionUpgrade")
        data_tiering = params.get("DataTiering")
        cluster = self.memorydb_backend.create_cluster(
            cluster_name=cluster_name,
            node_type=node_type,
            parameter_group_name=parameter_group_name,
            description=description,
            num_shards=num_shards,
            num_replicas_per_shard=num_replicas_per_shard,
            subnet_group_name=subnet_group_name,
            security_group_ids=security_group_ids,
            maintenance_window=maintenance_window,
            port=port,
            sns_topic_arn=sns_topic_arn,
            tls_enabled=tls_enabled,
            kms_key_id=kms_key_id,
            snapshot_arns=snapshot_arns,
            snapshot_name=snapshot_name,
            snapshot_retention_limit=snapshot_retention_limit,
            tags=tags,
            snapshot_window=snapshot_window,
            acl_name=acl_name,
            engine_version=engine_version,
            auto_minor_version_upgrade=auto_minor_version_upgrade,
            data_tiering=data_tiering,
        )
        return json.dumps(dict(Cluster=cluster.to_dict()))

    def create_subnet_group(self) -> str:
        """Handle the CreateSubnetGroup action."""
        params = json.loads(self.body)
        subnet_group_name = params.get("SubnetGroupName")
        description = params.get("Description")
        subnet_ids = params.get("SubnetIds")
        tags = params.get("Tags")
        subnet_group = self.memorydb_backend.create_subnet_group(
            subnet_group_name=subnet_group_name,
            description=description,
            subnet_ids=subnet_ids,
            tags=tags,
        )
        return json.dumps(dict(SubnetGroup=subnet_group.to_dict()))

    def create_snapshot(self) -> str:
        """Handle the CreateSnapshot action (always a manual snapshot)."""
        params = json.loads(self.body)
        cluster_name = params.get("ClusterName")
        snapshot_name = params.get("SnapshotName")
        kms_key_id = params.get("KmsKeyId")
        tags = params.get("Tags")
        snapshot = self.memorydb_backend.create_snapshot(
            cluster_name=cluster_name,
            snapshot_name=snapshot_name,
            kms_key_id=kms_key_id,
            tags=tags,
        )
        return json.dumps(dict(Snapshot=snapshot.to_dict()))

    def describe_clusters(self) -> str:
        """Handle DescribeClusters; ShowShardDetails selects the verbose form."""
        params = json.loads(self.body)
        cluster_name = params.get("ClusterName")
        show_shard_details = params.get("ShowShardDetails")
        clusters = self.memorydb_backend.describe_clusters(
            cluster_name=cluster_name,
        )
        return json.dumps(
            dict(
                Clusters=[
                    # to_desc_dict includes per-shard details; to_dict does not.
                    cluster.to_desc_dict() if show_shard_details else cluster.to_dict()
                    for cluster in clusters
                ]
            )
        )

    def describe_snapshots(self) -> str:
        """Handle DescribeSnapshots; ShowDetail selects the verbose form."""
        params = json.loads(self.body)
        cluster_name = params.get("ClusterName")
        snapshot_name = params.get("SnapshotName")
        source = params.get("Source")
        show_detail = params.get("ShowDetail")
        snapshots = self.memorydb_backend.describe_snapshots(
            cluster_name=cluster_name,
            snapshot_name=snapshot_name,
            source=source,
        )
        return json.dumps(
            dict(
                Snapshots=[
                    snapshot.to_desc_dict() if show_detail else snapshot.to_dict()
                    for snapshot in snapshots
                ]
            )
        )

    def describe_subnet_groups(self) -> str:
        """Handle the DescribeSubnetGroups action."""
        params = json.loads(self.body)
        subnet_group_name = params.get("SubnetGroupName")
        subnet_groups = self.memorydb_backend.describe_subnet_groups(
            subnet_group_name=subnet_group_name,
        )
        return json.dumps(dict(SubnetGroups=[sg.to_dict() for sg in subnet_groups]))

    def list_tags(self) -> str:
        """Handle the ListTags action."""
        params = json.loads(self.body)
        resource_arn = params.get("ResourceArn")
        tag_list = self.memorydb_backend.list_tags(
            resource_arn=resource_arn,
        )
        return json.dumps(dict(TagList=tag_list))

    def tag_resource(self) -> str:
        """Handle TagResource; returns the full resulting tag list."""
        params = json.loads(self.body)
        resource_arn = params.get("ResourceArn")
        tags = params.get("Tags")
        tag_list = self.memorydb_backend.tag_resource(
            resource_arn=resource_arn,
            tags=tags,
        )
        return json.dumps(dict(TagList=tag_list))

    def untag_resource(self) -> str:
        """Handle UntagResource; returns the remaining tag list."""
        params = json.loads(self.body)
        resource_arn = params.get("ResourceArn")
        tag_keys = params.get("TagKeys")
        tag_list = self.memorydb_backend.untag_resource(
            resource_arn=resource_arn,
            tag_keys=tag_keys,
        )
        return json.dumps(dict(TagList=tag_list))

    def update_cluster(self) -> str:
        """Handle the UpdateCluster action (absent fields arrive as None)."""
        params = json.loads(self.body)
        cluster_name = params.get("ClusterName")
        description = params.get("Description")
        security_group_ids = params.get("SecurityGroupIds")
        maintenance_window = params.get("MaintenanceWindow")
        sns_topic_arn = params.get("SnsTopicArn")
        sns_topic_status = params.get("SnsTopicStatus")
        parameter_group_name = params.get("ParameterGroupName")
        snapshot_window = params.get("SnapshotWindow")
        snapshot_retention_limit = params.get("SnapshotRetentionLimit")
        node_type = params.get("NodeType")
        engine_version = params.get("EngineVersion")
        replica_configuration = params.get("ReplicaConfiguration")
        shard_configuration = params.get("ShardConfiguration")
        acl_name = params.get("ACLName")
        cluster = self.memorydb_backend.update_cluster(
            cluster_name=cluster_name,
            description=description,
            security_group_ids=security_group_ids,
            maintenance_window=maintenance_window,
            sns_topic_arn=sns_topic_arn,
            sns_topic_status=sns_topic_status,
            parameter_group_name=parameter_group_name,
            snapshot_window=snapshot_window,
            snapshot_retention_limit=snapshot_retention_limit,
            node_type=node_type,
            engine_version=engine_version,
            replica_configuration=replica_configuration,
            shard_configuration=shard_configuration,
            acl_name=acl_name,
        )
        return json.dumps(dict(Cluster=cluster.to_dict()))

    def delete_cluster(self) -> str:
        """Handle DeleteCluster, with an optional final snapshot."""
        params = json.loads(self.body)
        cluster_name = params.get("ClusterName")
        final_snapshot_name = params.get("FinalSnapshotName")
        cluster = self.memorydb_backend.delete_cluster(
            cluster_name=cluster_name,
            final_snapshot_name=final_snapshot_name,
        )
        return json.dumps(dict(Cluster=cluster.to_dict()))

    def delete_snapshot(self) -> str:
        """Handle the DeleteSnapshot action."""
        params = json.loads(self.body)
        snapshot_name = params.get("SnapshotName")
        snapshot = self.memorydb_backend.delete_snapshot(
            snapshot_name=snapshot_name,
        )
        return json.dumps(dict(Snapshot=snapshot.to_dict()))

    def delete_subnet_group(self) -> str:
        """Handle the DeleteSubnetGroup action."""
        params = json.loads(self.body)
        subnet_group_name = params.get("SubnetGroupName")
        subnet_group = self.memorydb_backend.delete_subnet_group(
            subnet_group_name=subnet_group_name,
        )
        return json.dumps(dict(SubnetGroup=subnet_group.to_dict()))
diff --git a/moto/memorydb/urls.py b/moto/memorydb/urls.py
new file mode 100644
index 000000000000..c9118d329623
--- /dev/null
+++ b/moto/memorydb/urls.py
@@ -0,0 +1,11 @@
"""memorydb base URL and path."""

from .responses import MemoryDBResponse

# NOTE: MemoryDB's endpoint prefix is "memory-db" (with a hyphen), not
# "memorydb" — see the matching special-case in moto_server/werkzeug_app.py.
url_bases = [
    r"https?://memory-db\.(.+)\.amazonaws\.com",
]

# All actions are POSTed to the service root; dispatch routes them to the
# matching MemoryDBResponse method.
url_paths = {
    "{0}/$": MemoryDBResponse.dispatch,
}
diff --git a/moto/moto_server/werkzeug_app.py b/moto/moto_server/werkzeug_app.py
index 3449955be743..1daf656139a4 100644
--- a/moto/moto_server/werkzeug_app.py
+++ b/moto/moto_server/werkzeug_app.py
@@ -163,6 +163,8 @@ def infer_service_region_host(
host = "s3control"
elif service == "ses" and path.startswith("/v2/"):
host = "sesv2"
+ elif service == "memorydb":
+ host = f"memory-db.{region}.amazonaws.com"
else:
host = f"{service}.{region}.amazonaws.com"
diff --git a/tests/test_memorydb/__init__.py b/tests/test_memorydb/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/test_memorydb/test_memorydb.py b/tests/test_memorydb/test_memorydb.py
new file mode 100644
index 000000000000..8b4ff0c5c252
--- /dev/null
+++ b/tests/test_memorydb/test_memorydb.py
@@ -0,0 +1,661 @@
+"""Unit tests for memorydb-supported APIs."""
+
+import boto3
+import pytest
+from botocore.exceptions import ClientError
+
+from moto import mock_aws
+from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
+
+# See our Development Tips on writing tests for hints on how to write good tests:
+# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
+
+
def create_subnet_group(client, region_name):
    """Create a VPC with two subnets and return a valid subnet group."""
    ec2 = boto3.resource("ec2", region_name=region_name)
    vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
    subnet1 = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24")
    subnet2 = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.1.0/24")
    subnet_group = client.create_subnet_group(
        SubnetGroupName="my_subnet_group",
        Description="This is my subnet group",
        SubnetIds=[subnet1.id, subnet2.id],
    )
    return subnet_group


@mock_aws
def test_create_cluster():
    """A minimal CreateCluster call returns the core cluster fields."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    resp = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
    )
    cluster = resp["Cluster"]
    assert "Name" in cluster
    assert "Status" in cluster
    assert "NumberOfShards" in cluster


@mock_aws
def test_create_duplicate_cluster_fails():
    """Creating a cluster with an existing name raises ClusterAlreadyExistsFault."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    client.create_cluster(
        ClusterName="foo-bar",
        NodeType="db.t4g.small",
        ACLName="open-access",
    )
    with pytest.raises(ClientError) as ex:
        client.create_cluster(
            ClusterName="foo-bar", NodeType="db.t4g.small", ACLName="open-access"
        )
    err = ex.value.response["Error"]
    assert err["Code"] == "ClusterAlreadyExistsFault"


@mock_aws
def test_create_subnet_group():
    """CreateSubnetGroup returns the expected subnet-group fields."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    subnet_group = create_subnet_group(client, "ap-southeast-1")
    sg = subnet_group["SubnetGroup"]
    assert "Name" in sg
    assert "Description" in sg
    assert "VpcId" in sg
    assert "Subnets" in sg
    assert "ARN" in sg
+
+
@mock_aws
def test_create_cluster_with_subnet_group():
    """A cluster created with an explicit subnet group reports that group's name."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    subnet_group = create_subnet_group(client, "ap-southeast-1")
    resp = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        SubnetGroupName=subnet_group["SubnetGroup"]["Name"],
        ACLName="open-access",
    )
    # Bug fix: the original assigned the comparison result to a variable
    # instead of asserting it, so the test could never fail.
    assert resp["Cluster"]["SubnetGroupName"] == "my_subnet_group"
+
+
@mock_aws
def test_create_duplicate_subnet_group_fails():
    """A duplicate subnet-group name raises SubnetGroupAlreadyExistsFault."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    create_subnet_group(client, "ap-southeast-1")
    with pytest.raises(ClientError) as ex:
        create_subnet_group(client, "ap-southeast-1")
    err = ex.value.response["Error"]
    assert err["Code"] == "SubnetGroupAlreadyExistsFault"


@mock_aws
def test_create_invalid_subnet_group_fails():
    """Non-existent subnet ids raise InvalidSubnetError."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    with pytest.raises(ClientError) as ex:
        client.create_subnet_group(SubnetGroupName="foo-bar", SubnetIds=["foo", "bar"])
    err = ex.value.response["Error"]
    assert err["Code"] == "InvalidSubnetError"


@mock_aws
def test_create_snapshot():
    """CreateSnapshot returns the expected snapshot fields."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    subnet_group = create_subnet_group(client, "ap-southeast-1")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        Description="Test memorydb cluster",
        NodeType="db.t4g.small",
        SubnetGroupName=subnet_group["SubnetGroup"]["Name"],
        ACLName="open-access",
    )
    resp = client.create_snapshot(
        ClusterName=cluster["Cluster"]["Name"],
        SnapshotName="my-snapshot-1",
        KmsKeyId=f"arn:aws:kms:ap-southeast-1:{ACCOUNT_ID}:key/51d81fab-b138-4bd2-8a09-07fd6d37224d",
        Tags=[
            {"Key": "foo", "Value": "bar"},
        ],
    )
    snapshot = resp["Snapshot"]
    assert "Name" in snapshot
    assert "Status" in snapshot
    assert "Source" in snapshot
    assert "KmsKeyId" in snapshot
    assert "ARN" in snapshot
    assert "ClusterConfiguration" in snapshot
    assert "DataTiering" in snapshot


@mock_aws
def test_create_snapshot_with_non_existing_cluster_fails():
    """Snapshotting an unknown cluster raises ClusterNotFoundFault."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    with pytest.raises(ClientError) as ex:
        client.create_snapshot(ClusterName="foobar", SnapshotName="my-snapshot-1")
    err = ex.value.response["Error"]
    assert err["Code"] == "ClusterNotFoundFault"


@mock_aws
def test_create_duplicate_snapshot_fails():
    """A duplicate snapshot name raises SnapshotAlreadyExistsFault."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
    )
    client.create_snapshot(
        ClusterName=cluster["Cluster"]["Name"], SnapshotName="my-snapshot-1"
    )
    with pytest.raises(ClientError) as ex:
        client.create_snapshot(
            ClusterName=cluster["Cluster"]["Name"], SnapshotName="my-snapshot-1"
        )
    err = ex.value.response["Error"]
    assert err["Code"] == "SnapshotAlreadyExistsFault"
+
+
@mock_aws
def test_describe_clusters():
    """DescribeClusters without filters lists all clusters, shard-free."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    for i in range(1, 3):
        client.create_cluster(
            ClusterName=f"test-memory-db-{i}",
            NodeType="db.t4g.small",
            ACLName="open-access",
        )
    resp = client.describe_clusters()
    assert "Clusters" in resp
    assert len(resp["Clusters"]) == 2
    # Shard details are only present when ShowShardDetails is set.
    assert "Shards" not in resp["Clusters"][0]


@mock_aws
def test_describe_clusters_with_shard_details():
    """ShowShardDetails=True includes the Shards list in the response."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    for i in range(1, 3):
        client.create_cluster(
            ClusterName=f"test-memory-db-{i}",
            NodeType="db.t4g.small",
            ACLName="open-access",
        )
    resp = client.describe_clusters(
        ClusterName="test-memory-db-1",
        ShowShardDetails=True,
    )
    assert resp["Clusters"][0]["Name"] == "test-memory-db-1"
    assert len(resp["Clusters"]) == 1
    assert "Shards" in resp["Clusters"][0]


@mock_aws
def test_describe_clusters_with_cluster_name():
    """Filtering by ClusterName returns exactly that cluster."""
    client = boto3.client("memorydb", region_name="ap-southeast-1")
    for i in range(1, 3):
        client.create_cluster(
            ClusterName=f"test-memory-db-{i}",
            NodeType="db.t4g.small",
            ACLName="open-access",
        )
    resp = client.describe_clusters(
        ClusterName="test-memory-db-1",
    )
    assert resp["Clusters"][0]["Name"] == "test-memory-db-1"
    assert len(resp["Clusters"]) == 1
+
+
@mock_aws
def test_describe_snapshots():
    """DescribeSnapshots without filters lists every snapshot."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    for i in range(1, 3):
        client.create_cluster(
            ClusterName=f"test-memory-db-{i}",
            NodeType="db.t4g.small",
            ACLName="open-access",
        )
        client.create_snapshot(
            ClusterName=f"test-memory-db-{i}", SnapshotName=f"my-snapshot-{i}"
        )
    resp = client.describe_snapshots()
    assert "Snapshots" in resp
    assert len(resp["Snapshots"]) == 2
    assert resp["Snapshots"][0]["Name"] == "my-snapshot-1"


@mock_aws
def test_describe_snapshots_with_cluster_name():
    """Filtering by ClusterName returns only that cluster's snapshots."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    for i in range(1, 3):
        client.create_cluster(
            ClusterName=f"test-memory-db-{i}",
            NodeType="db.t4g.small",
            ACLName="open-access",
        )
        client.create_snapshot(
            ClusterName=f"test-memory-db-{i}", SnapshotName=f"my-snapshot-{i}"
        )
    resp = client.describe_snapshots(ClusterName="test-memory-db-2")
    assert len(resp["Snapshots"]) == 1
    assert resp["Snapshots"][0]["ClusterConfiguration"]["Name"] == "test-memory-db-2"
    assert "Shards" not in resp["Snapshots"][0]["ClusterConfiguration"]


@mock_aws
def test_describe_snapshots_with_shard_details():
    """ShowDetail=True includes shard details in ClusterConfiguration."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    for i in range(1, 3):
        client.create_cluster(
            ClusterName=f"test-memory-db-{i}",
            NodeType="db.t4g.small",
            ACLName="open-access",
        )
        client.create_snapshot(
            ClusterName=f"test-memory-db-{i}", SnapshotName=f"my-snapshot-{i}"
        )
    resp = client.describe_snapshots(ClusterName="test-memory-db-2", ShowDetail=True)
    assert len(resp["Snapshots"]) == 1
    assert resp["Snapshots"][0]["ClusterConfiguration"]["Name"] == "test-memory-db-2"
    assert "Shards" in resp["Snapshots"][0]["ClusterConfiguration"]


@mock_aws
def test_describe_snapshots_with_snapshot_name():
    """Filtering by SnapshotName returns only the matching snapshot."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    for i in range(1, 3):
        client.create_cluster(
            ClusterName=f"test-memory-db-{i}",
            NodeType="db.t4g.small",
            ACLName="open-access",
        )
        client.create_snapshot(
            ClusterName=f"test-memory-db-{i}", SnapshotName=f"my-snapshot-{i}"
        )
    resp = client.describe_snapshots(
        SnapshotName="my-snapshot-1",
    )
    assert len(resp["Snapshots"]) == 1
    assert resp["Snapshots"][0]["Name"] == "my-snapshot-1"


@mock_aws
def test_describe_snapshots_with_snapshot_and_cluster():
    """Both filters together narrow to a single snapshot."""
    client = boto3.client("memorydb", region_name="eu-west-1")

    client.create_cluster(
        ClusterName="test-memory-db", NodeType="db.t4g.small", ACLName="open-access"
    )
    for i in range(1, 3):
        client.create_snapshot(
            ClusterName="test-memory-db", SnapshotName=f"my-snapshot-{i}"
        )
    resp = client.describe_snapshots(
        ClusterName="test-memory-db",
        SnapshotName="my-snapshot-1",
    )
    assert len(resp["Snapshots"]) == 1
    assert resp["Snapshots"][0]["Name"] == "my-snapshot-1"


@mock_aws
def test_describe_snapshots_with_invalid_cluster():
    """An unknown cluster name yields an empty list, not an error."""
    client = boto3.client("memorydb", region_name="eu-west-1")

    resp = client.describe_snapshots(
        ClusterName="foobar",
    )
    assert len(resp["Snapshots"]) == 0


@mock_aws
def test_describe_snapshots_invalid_snapshot_fails():
    """An unknown snapshot name raises SnapshotNotFoundFault."""
    client = boto3.client("memorydb", region_name="eu-west-1")

    with pytest.raises(ClientError) as ex:
        client.describe_snapshots(SnapshotName="foobar")
    err = ex.value.response["Error"]
    assert err["Code"] == "SnapshotNotFoundFault"


@mock_aws
def test_describe_snapshots_with_cluster_and_invalid_snapshot_fails():
    """A valid cluster but unknown snapshot name still raises SnapshotNotFoundFault."""
    client = boto3.client("memorydb", region_name="eu-west-1")

    client.create_cluster(
        ClusterName="test-memory-db", NodeType="db.t4g.small", ACLName="open-access"
    )
    client.create_snapshot(ClusterName="test-memory-db", SnapshotName="my-snapshot")

    with pytest.raises(ClientError) as ex:
        client.describe_snapshots(ClusterName="test-memory-db", SnapshotName="foobar")
    err = ex.value.response["Error"]
    assert err["Code"] == "SnapshotNotFoundFault"
+
+
@mock_aws
def test_describe_subnet_groups():
    """DescribeSubnetGroups lists created groups plus the default group."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    ec2 = boto3.resource("ec2", region_name="eu-west-1")
    vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
    subnet1 = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24")
    subnet2 = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.1.0/24")
    for i in range(1, 3):
        client.create_subnet_group(
            SubnetGroupName=f"my_subnet_group-{i}",
            Description="This is my subnet group",
            SubnetIds=[subnet1.id, subnet2.id],
        )
    resp = client.describe_subnet_groups()
    assert "SubnetGroups" in resp
    assert len(resp["SubnetGroups"]) == 3  # Including default subnet group


@mock_aws
def test_describe_subnet_groups_with_subnet_group_name():
    """Filtering by SubnetGroupName returns exactly that group."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    ec2 = boto3.resource("ec2", region_name="eu-west-1")
    vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
    subnet1 = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24")
    subnet2 = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.1.0/24")
    for i in range(1, 3):
        client.create_subnet_group(
            SubnetGroupName=f"my_subnet_group-{i}",
            Description="This is my subnet group",
            SubnetIds=[subnet1.id, subnet2.id],
        )
    resp = client.describe_subnet_groups(SubnetGroupName="my_subnet_group-1")
    assert len(resp["SubnetGroups"]) == 1
    assert resp["SubnetGroups"][0]["Name"] == "my_subnet_group-1"


@mock_aws
def test_describe_subnet_groups_invalid_subnetgroupname_fails():
    """An unknown subnet group name raises SubnetGroupNotFoundFault."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    with pytest.raises(ClientError) as ex:
        client.describe_subnet_groups(SubnetGroupName="foobar")
    err = ex.value.response["Error"]
    assert err["Code"] == "SubnetGroupNotFoundFault"
+
+
@mock_aws
def test_list_tags():
    """Tags supplied at cluster creation are returned by ListTags."""
    client = boto3.client("memorydb", region_name="us-east-2")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
        Tags=[
            {"Key": "foo", "Value": "bar"},
        ],
    )
    resp = client.list_tags(ResourceArn=cluster["Cluster"]["ARN"])
    assert "TagList" in resp
    assert len(resp["TagList"]) == 1
    assert "foo" in resp["TagList"][0]["Key"]
    assert "bar" in resp["TagList"][0]["Value"]


@mock_aws
def test_list_tags_invalid_cluster_fails():
    """ListTags on an unknown ARN raises ClusterNotFoundFault."""
    client = boto3.client("memorydb", region_name="us-east-2")
    with pytest.raises(ClientError) as ex:
        client.list_tags(
            ResourceArn=f"arn:aws:memorydb:us-east-1:{ACCOUNT_ID}:cluster/foobar",
        )
    err = ex.value.response["Error"]
    assert err["Code"] == "ClusterNotFoundFault"


@mock_aws
def test_tag_resource():
    """TagResource appends new tags and returns the combined list."""
    client = boto3.client("memorydb", region_name="us-east-2")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
        Tags=[
            {"Key": "key1", "Value": "value1"},
        ],
    )
    resp = client.tag_resource(
        ResourceArn=cluster["Cluster"]["ARN"],
        Tags=[
            {"Key": "key2", "Value": "value2"},
        ],
    )
    assert "TagList" in resp
    assert len(resp["TagList"]) == 2
    assert "key2" in resp["TagList"][1]["Key"]
    assert "value2" in resp["TagList"][1]["Value"]


@mock_aws
def test_tag_resource_invalid_cluster_fails():
    """TagResource on an unknown ARN raises ClusterNotFoundFault."""
    client = boto3.client("memorydb", region_name="us-east-2")
    with pytest.raises(ClientError) as ex:
        client.tag_resource(
            ResourceArn=f"arn:aws:memorydb:us-east-1:{ACCOUNT_ID}:cluster/foobar",
            Tags=[{"Key": "key2", "Value": "value2"}],
        )
    err = ex.value.response["Error"]
    assert err["Code"] == "ClusterNotFoundFault"


@mock_aws
def test_untag_resource():
    """UntagResource removes the named keys and returns the remainder."""
    client = boto3.client("memorydb", region_name="us-east-2")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
        Tags=[{"Key": "key1", "Value": "value1"}, {"Key": "key2", "Value": "value2"}],
    )
    resp = client.untag_resource(
        ResourceArn=cluster["Cluster"]["ARN"],
        TagKeys=[
            "key1",
        ],
    )
    assert "TagList" in resp
    assert len(resp["TagList"]) == 1
    assert "key2" in resp["TagList"][0]["Key"]
    assert "value2" in resp["TagList"][0]["Value"]


@mock_aws
def test_untag_resource_invalid_cluster_fails():
    """UntagResource on an unknown ARN raises ClusterNotFoundFault."""
    client = boto3.client("memorydb", region_name="us-east-2")
    with pytest.raises(ClientError) as ex:
        client.untag_resource(
            ResourceArn=f"arn:aws:memorydb:us-east-1:{ACCOUNT_ID}:cluster/foobar",
            TagKeys=["key1"],
        )
    err = ex.value.response["Error"]
    assert err["Code"] == "ClusterNotFoundFault"


@mock_aws
def test_untag_resource_invalid_keys_fails():
    """Removing keys that are not set raises TagNotFoundFault."""
    client = boto3.client("memorydb", region_name="us-east-2")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
        Tags=[{"Key": "key1", "Value": "value1"}, {"Key": "key2", "Value": "value2"}],
    )
    with pytest.raises(ClientError) as ex:
        client.untag_resource(
            ResourceArn=cluster["Cluster"]["ARN"], TagKeys=["key3", "key4"]
        )
    err = ex.value.response["Error"]
    assert err["Code"] == "TagNotFoundFault"
+
+
@mock_aws
def test_update_cluster_replica_count():
    """ReplicaConfiguration changes the per-shard node count (replicas + 1)."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
    )
    # Default is 1 replica per shard -> 2 nodes (primary + replica).
    desc_before_update = client.describe_clusters(ShowShardDetails=True)
    assert desc_before_update["Clusters"][0]["Shards"][0]["NumberOfNodes"] == 2
    client.update_cluster(
        ClusterName=cluster["Cluster"]["Name"],
        Description="Good cluster",
        MaintenanceWindow="thu:23:00-thu:01:30",
        ReplicaConfiguration={"ReplicaCount": 2},
    )
    desc_after_update = client.describe_clusters(ShowShardDetails=True)
    cluster_after_update = desc_after_update["Clusters"][0]
    assert cluster_after_update["Description"] == "Good cluster"
    assert cluster_after_update["MaintenanceWindow"] == "thu:23:00-thu:01:30"
    assert cluster_after_update["Shards"][0]["NumberOfNodes"] == 3


@mock_aws
def test_update_cluster_shards():
    """ShardConfiguration changes the cluster's shard count."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
    )
    desc_before_update = client.describe_clusters(ShowShardDetails=True)
    assert desc_before_update["Clusters"][0]["NumberOfShards"] == 1
    client.update_cluster(
        ClusterName=cluster["Cluster"]["Name"],
        ShardConfiguration={"ShardCount": 2},
    )
    desc_after_update = client.describe_clusters(ShowShardDetails=True)
    assert desc_after_update["Clusters"][0]["NumberOfShards"] == 2


@mock_aws
def test_update_invalid_cluster_fails():
    """UpdateCluster on an unknown cluster raises ClusterNotFoundFault."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    with pytest.raises(ClientError) as ex:
        client.update_cluster(
            ClusterName="foobar",
            Description="Good cluster",
        )
    err = ex.value.response["Error"]
    assert err["Code"] == "ClusterNotFoundFault"
+
+
@mock_aws
def test_delete_cluster():
    """DeleteCluster removes the cluster and returns its final description."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
    )
    desc_resp_before = client.describe_clusters()
    assert len(desc_resp_before["Clusters"]) == 1
    resp = client.delete_cluster(
        ClusterName=cluster["Cluster"]["Name"],
    )
    assert resp["Cluster"]["Name"] == cluster["Cluster"]["Name"]
    desc_resp_after = client.describe_clusters()
    assert len(desc_resp_after["Clusters"]) == 0


@mock_aws
def test_delete_cluster_with_snapshot():
    """FinalSnapshotName causes a snapshot to be taken before deletion."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
    )
    desc_resp_before = client.describe_snapshots()
    assert len(desc_resp_before["Snapshots"]) == 0
    resp = client.delete_cluster(
        ClusterName=cluster["Cluster"]["Name"],
        FinalSnapshotName="test-memory-db-snapshot",
    )
    assert resp["Cluster"]["Name"] == cluster["Cluster"]["Name"]
    desc_resp_after = client.describe_snapshots()
    assert len(desc_resp_after["Snapshots"]) == 1
    assert desc_resp_after["Snapshots"][0]["Name"] == "test-memory-db-snapshot"


@mock_aws
def test_delete_invalid_cluster_fails():
    """DeleteCluster on an unknown cluster raises ClusterNotFoundFault."""
    client = boto3.client("memorydb", region_name="eu-west-1")
    with pytest.raises(ClientError) as ex:
        client.delete_cluster(
            ClusterName="foobar",
        )
    err = ex.value.response["Error"]
    assert err["Code"] == "ClusterNotFoundFault"


@mock_aws
def test_delete_snapshot():
    """DeleteSnapshot removes the snapshot and returns its description."""
    client = boto3.client("memorydb", region_name="us-east-2")
    cluster = client.create_cluster(
        ClusterName="test-memory-db",
        NodeType="db.t4g.small",
        ACLName="open-access",
    )
    snapshot = client.create_snapshot(
        ClusterName=cluster["Cluster"]["Name"],
        SnapshotName="my-snapshot-1",
    )
    desc_resp_before = client.describe_snapshots()
    assert len(desc_resp_before["Snapshots"]) == 1
    resp = client.delete_snapshot(SnapshotName=snapshot["Snapshot"]["Name"])
    assert "Snapshot" in resp
    desc_resp_after = client.describe_snapshots()
    assert len(desc_resp_after["Snapshots"]) == 0


@mock_aws
def test_delete_invalid_snapshot_fails():
    """DeleteSnapshot on an unknown snapshot raises SnapshotNotFoundFault."""
    client = boto3.client("memorydb", region_name="us-east-2")
    with pytest.raises(ClientError) as ex:
        client.delete_snapshot(SnapshotName="foobar")
    err = ex.value.response["Error"]
    assert err["Code"] == "SnapshotNotFoundFault"
+
+
+@mock_aws
+def test_delete_subnet_group():
+ client = boto3.client("memorydb", region_name="us-east-2")
+ subnet_group = create_subnet_group(client, "us-east-2")
+ sg = subnet_group["SubnetGroup"]
+ response = client.describe_subnet_groups()
+ assert len(response["SubnetGroups"]) == 2
+ resp = client.delete_subnet_group(SubnetGroupName=sg["Name"])
+ assert "SubnetGroup" in resp
+ response = client.describe_subnet_groups()
+ assert len(response["SubnetGroups"]) == 1 # default subnet group
+
+
+@mock_aws
+def test_delete_subnet_group_default_fails():
+ client = boto3.client("memorydb", region_name="us-east-2")
+
+ with pytest.raises(ClientError) as ex:
+ client.delete_subnet_group(SubnetGroupName="default")
+ err = ex.value.response["Error"]
+ assert err["Code"] == "InvalidParameterValueException"
+
+
+@mock_aws
+def test_delete_subnet_group_in_use_fails():
+    # A subnet group referenced by a live cluster cannot be deleted:
+    # the backend must raise SubnetGroupInUseFault.
+    client = boto3.client("memorydb", region_name="us-east-2")
+    subnet_group = create_subnet_group(client, "us-east-2")
+    # Create a cluster attached to the custom subnet group so it is "in use".
+    client.create_cluster(
+        ClusterName="test-memory-db",
+        NodeType="db.t4g.small",
+        SubnetGroupName=subnet_group["SubnetGroup"]["Name"],
+        ACLName="open-access",
+    )
+    with pytest.raises(ClientError) as ex:
+        client.delete_subnet_group(SubnetGroupName=subnet_group["SubnetGroup"]["Name"])
+    err = ex.value.response["Error"]
+    assert err["Code"] == "SubnetGroupInUseFault"