Skip to content

Commit

Permalink
Merge pull request #798 from GoogleCloudPlatform/aws-regional-network…
Browse files Browse the repository at this point in the history
…-in-spec

Create only one AwsRegionalNetwork per region per benchmark run.
  • Loading branch information
skschneider committed Jan 13, 2016
2 parents 3c3c956 + 36fe39f commit 29ce7e2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 15 deletions.
12 changes: 11 additions & 1 deletion perfkitbenchmarker/benchmark_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,17 @@ def Prepare(self):

def Provision(self):
"""Prepares the VMs and networks necessary for the benchmark to run."""
vm_util.RunThreaded(lambda net: net.Create(), self.networks.values())
# Sort networks into a guaranteed order of creation based on dict key.
# There is a finite limit on the number of threads that are created to
# provision networks. Until support is added to provision resources in an
# order based on dependencies, this key ordering can be used to avoid
# deadlock by placing dependent networks later and their dependencies
# earlier. As an example, AWS stores both per-region and per-zone objects
# in this dict, and each per-zone object depends on a corresponding
# per-region object, so the per-region objects are given keys that come
# first when sorted.
networks = [self.networks[key] for key in sorted(self.networks.iterkeys())]
vm_util.RunThreaded(lambda net: net.Create(), networks)

if self.vms:
vm_util.RunThreaded(self.PrepareVm, self.vms)
Expand Down
47 changes: 33 additions & 14 deletions perfkitbenchmarker/providers/aws/aws_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import threading
import uuid

from perfkitbenchmarker import context
from perfkitbenchmarker import errors
from perfkitbenchmarker import flags
from perfkitbenchmarker import network
from perfkitbenchmarker import resource
Expand All @@ -36,6 +38,10 @@
FLAGS = flags.FLAGS


REGION = 'region'
ZONE = 'zone'


class AwsFirewall(network.BaseFirewall):
"""An object representing the AWS Firewall."""

Expand Down Expand Up @@ -407,22 +413,18 @@ def _Exists(self):
return len(placement_groups) > 0


class AwsRegionalNetwork(network.BaseNetwork):
class _AwsRegionalNetwork(network.BaseNetwork):
"""Object representing regional components of an AWS network.
This class maintains a singleton-per-region; acquire instances via
AwsRegionalNetwork.GetForRegion.
The benchmark spec contains one instance of this class per region, which an
AwsNetwork may retrieve or create via _AwsRegionalNetwork.GetForRegion.
Attributes:
region: string. The AWS region.
vpc: an AwsVpc instance.
internet_gateway: an AwsInternetGateway instance.
route_table: an AwsRouteTable instance. The default route table.
"""
# Map from region to AwsRegionalNetwork
_network_pool = {}
# Lock protecting _network_pool
_network_pool_lock = threading.Lock()

def __init__(self, region):
self.region = region
Expand All @@ -434,7 +436,7 @@ def __init__(self, region):
# Locks to ensure that a single thread creates / deletes the instance.
self._create_lock = threading.Lock()

# Tracks the number of AwsNetworks using this AwsRegionalNetwork.
# Tracks the number of AwsNetworks using this _AwsRegionalNetwork.
# Incremented by Create(); decremented by Delete();
# When a Delete() call decrements _reference_count to 0, the RegionalNetwork
# is destroyed.
Expand All @@ -443,15 +445,27 @@ def __init__(self, region):

@classmethod
def GetForRegion(cls, region):
"""Gets the AwsRegionalNetwork for a given AWS region.
"""Retrieves or creates an _AwsRegionalNetwork.
Args:
region: str. A Region name.
region: string. AWS region name.
Returns:
The AwsRegionalNetwork for 'region'.
_AwsRegionalNetwork. If an _AwsRegionalNetwork for the same region already
exists in the benchmark spec, that instance is returned. Otherwise, a new
_AwsRegionalNetwork is created and returned.
"""
with cls._network_pool_lock:
return cls._network_pool.setdefault(region, cls(region))
benchmark_spec = context.GetThreadBenchmarkSpec()
if benchmark_spec is None:
raise errors.Error('GetNetwork called in a thread without a '
'BenchmarkSpec.')
key = cls.CLOUD, REGION, region
# Because this method is only called from the AwsNetwork constructor, which
# is only called from AwsNetwork.GetNetwork, we already hold the
# benchmark_spec.networks_lock.
if key not in benchmark_spec.networks:
benchmark_spec.networks[key] = cls(region)
return benchmark_spec.networks[key]

def Create(self):
"""Creates the network."""
Expand Down Expand Up @@ -512,7 +526,7 @@ def __init__(self, spec):
"""
super(AwsNetwork, self).__init__(spec)
self.region = util.GetRegionFromZone(spec.zone)
self.regional_network = AwsRegionalNetwork.GetForRegion(self.region)
self.regional_network = _AwsRegionalNetwork.GetForRegion(self.region)
self.subnet = None
self.placement_group = AwsPlacementGroup(self.region)

Expand All @@ -533,3 +547,8 @@ def Delete(self):
self.subnet.Delete()
self.placement_group.Delete()
self.regional_network.Delete()

@classmethod
def _GetKeyFromNetworkSpec(cls, spec):
"""Returns a key used to register Network instances."""
return (cls.CLOUD, ZONE, spec.zone)

0 comments on commit 29ce7e2

Please sign in to comment.