
Commit
MLCOMPUTE-1497 | add methods to get total driver memory including overhead (#146)

* MLCOMPUTE-1209 | add methods to get total driver memory including overhead

* fix tests

* MLCOMPUTE-1209 | bump up version

---------

Co-authored-by: Sameer Sharma <sameersharma@yelp.com>
CaptainSame and Sameer Sharma authored Aug 6, 2024
1 parent 9719604 commit d507c51
Showing 4 changed files with 143 additions and 1 deletion.
4 changes: 4 additions & 0 deletions service_configuration_lib/spark_config.py
@@ -414,6 +414,10 @@ def _filter_user_spark_opts(user_spark_opts: Mapping[str, str]) -> MutableMapping
    }


def get_total_driver_memory_mb(spark_conf: Dict[str, str]) -> int:
    return int(utils.get_spark_driver_memory_mb(spark_conf) + utils.get_spark_driver_memory_overhead_mb(spark_conf))


class SparkConfBuilder:

    def __init__(self):
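The new spark_config helper simply sums the resolved driver memory and its overhead, both in MB, and truncates to an integer. A minimal usage sketch (editor's illustration with hypothetical values, assuming service_configuration_lib is importable):

from service_configuration_lib import spark_config

# Hypothetical config; in practice spark.driver.memory comes from srv-configs defaults.
conf = {
    'spark.driver.memory': '4g',                 # resolves to 4096 MB
    'spark.driver.memoryOverheadFactor': '0.1',  # overhead = 4096 * 0.1 = 409.6 MB
}

print(spark_config.get_total_driver_memory_mb(conf))  # int(4096 + 409.6) == 4505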
60 changes: 60 additions & 0 deletions service_configuration_lib/utils.py
@@ -11,10 +11,12 @@
from socket import SO_REUSEADDR
from socket import socket
from socket import SOL_SOCKET
from typing import Dict
from typing import Mapping
from typing import Tuple

import yaml
from typing_extensions import Literal

DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml'
POD_TEMPLATE_PATH = '/nail/tmp/spark-pt-{file_uuid}.yaml'
@@ -24,6 +26,11 @@
EPHEMERAL_PORT_START = 49152
EPHEMERAL_PORT_END = 65535

MEM_MULTIPLIER = {'k': 1024, 'm': 1024**2, 'g': 1024**3, 't': 1024**4}
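# Editor's note (not part of the commit): e.g. '1g' = 1 * 1024**3 bytes = 1024 MB.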

SPARK_DRIVER_MEM_DEFAULT_MB = 2048
SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT = 0.1


log = logging.Logger(__name__)
log.setLevel(logging.INFO)
@@ -148,3 +155,56 @@ def get_runtime_env() -> str:
    # we could also just crash or return None, but this seems a little easier to find
    # should we somehow run into this at Yelp
    return 'unknown'


def get_spark_memory_in_unit(mem: str, unit: Literal['k', 'm', 'g', 't']) -> float:
    """
    Converts Spark memory to the desired unit.
    mem is the same format as JVM memory strings: a bare number (bytes) or a number followed by 'k', 'm', 'g' or 't'.
    unit can be 'k', 'm', 'g' or 't'.
    Returns memory as a float converted to the desired unit.
    """
    try:
        memory_bytes = float(mem)
    except ValueError:
        try:
            memory_bytes = float(mem[:-1]) * MEM_MULTIPLIER[mem[-1]]
        except (ValueError, IndexError):
            print(f'Unable to parse memory value {mem}.')
            raise
    memory_unit = memory_bytes / MEM_MULTIPLIER[unit]
    return round(memory_unit, 5)
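
# Editor's worked examples (not part of the commit):
#   get_spark_memory_in_unit('1.5g', 'm')      -> 1.5 * 1024**3 / 1024**2 = 1536.0
#   get_spark_memory_in_unit('138412032', 'm') -> 138412032 / 1024**2 = 132.0
#   get_spark_memory_in_unit('foo', 'm')       -> prints an error and re-raises ValueError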


def get_spark_driver_memory_mb(spark_conf: Dict[str, str]) -> float:
    """
    Returns the Spark driver memory in MB.
    """
    # spark_conf is expected to have "spark.driver.memory" since it is a mandatory default from srv-configs.
    driver_mem = spark_conf['spark.driver.memory']
    try:
        return get_spark_memory_in_unit(str(driver_mem), 'm')
    except (ValueError, IndexError):
        return SPARK_DRIVER_MEM_DEFAULT_MB


def get_spark_driver_memory_overhead_mb(spark_conf: Dict[str, str]) -> float:
    """
    Returns the Spark driver memory overhead in MB.
    """
    # Use spark.driver.memoryOverhead if it is set.
    try:
        driver_mem_overhead = spark_conf['spark.driver.memoryOverhead']
        try:
            # spark.driver.memoryOverhead default unit is MB
            driver_mem_overhead_mb = float(driver_mem_overhead)
        except ValueError:
            driver_mem_overhead_mb = get_spark_memory_in_unit(str(driver_mem_overhead), 'm')
    # Otherwise, calculate it from spark.driver.memory and spark.driver.memoryOverheadFactor.
    except Exception:
        driver_mem_mb = get_spark_driver_memory_mb(spark_conf)
        driver_mem_overhead_factor = float(
            spark_conf.get('spark.driver.memoryOverheadFactor', SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT),
        )
        driver_mem_overhead_mb = driver_mem_mb * driver_mem_overhead_factor
    return round(driver_mem_overhead_mb, 5)
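Taken together, the resolution order for overhead is: use an explicit spark.driver.memoryOverhead if present (a bare number is treated as MB), otherwise multiply the resolved driver memory by spark.driver.memoryOverheadFactor (default 0.1). A short sketch of both paths (editor's illustration, hypothetical configs):

from service_configuration_lib import utils

# Explicit overhead wins; '512' is read as 512 MB.
utils.get_spark_driver_memory_overhead_mb({'spark.driver.memoryOverhead': '512'})  # 512.0

# No explicit overhead: fall back to memory * factor, here 4096 * 0.1.
utils.get_spark_driver_memory_overhead_mb({'spark.driver.memory': '4g'})  # 409.6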
2 changes: 1 addition & 1 deletion setup.py
@@ -17,7 +17,7 @@

setup(
    name='service-configuration-lib',
-    version='2.18.19',
+    version='2.18.20',
    provides=['service_configuration_lib'],
    description='Start, stop, and inspect Yelp SOA services',
    url='https://github.com/Yelp/service_configuration_lib',
78 changes: 78 additions & 0 deletions tests/utils_test.py
@@ -2,11 +2,13 @@
from socket import SO_REUSEADDR
from socket import socket as Socket
from socket import SOL_SOCKET
from typing import cast
from unittest import mock
from unittest.mock import mock_open
from unittest.mock import patch

import pytest
from typing_extensions import Literal

from service_configuration_lib import utils
from service_configuration_lib.utils import ephemeral_port_reserve_range
@@ -74,6 +76,82 @@ def test_generate_pod_template_path(hex_value):
    assert utils.generate_pod_template_path() == f'/nail/tmp/spark-pt-{hex_value}.yaml'


@pytest.mark.parametrize(
    'mem_str,unit_str,expected_mem',
    (
        ('13425m', 'm', 13425),  # Simple case
        ('138412032', 'm', 132),  # Bytes to MB
        ('65536k', 'g', 0.0625),  # KB to GB
        ('1t', 'g', 1024),  # TB to GB
        ('1.5g', 'm', 1536),  # GB to MB with decimal
        ('2048k', 'm', 2),  # KB to MB
        ('0.5g', 'k', 524288),  # GB to KB
        ('32768m', 't', 0.03125),  # MB to TB
        ('1.5t', 'm', 1572864),  # TB to MB with decimal
    ),
)
def test_get_spark_memory_in_unit(mem_str, unit_str, expected_mem):
    assert expected_mem == utils.get_spark_memory_in_unit(mem_str, cast(Literal['k', 'm', 'g', 't'], unit_str))


@pytest.mark.parametrize(
    'mem_str,unit_str',
    [
        ('invalid', 'm'),
        ('1024mb', 'g'),
    ],
)
def test_get_spark_memory_in_unit_exceptions(mem_str, unit_str):
    with pytest.raises((ValueError, IndexError)):
        utils.get_spark_memory_in_unit(mem_str, cast(Literal['k', 'm', 'g', 't'], unit_str))


@pytest.mark.parametrize(
    'spark_conf,expected_mem',
    [
        ({'spark.driver.memory': '13425m'}, 13425),  # Simple case
        ({'spark.driver.memory': '138412032'}, 132),  # Bytes to MB
        ({'spark.driver.memory': '65536k'}, 64),  # KB to MB
        ({'spark.driver.memory': '1g'}, 1024),  # GB to MB
        ({'spark.driver.memory': 'invalid'}, utils.SPARK_DRIVER_MEM_DEFAULT_MB),  # Invalid case
        ({'spark.driver.memory': '1.5g'}, 1536),  # GB to MB with decimal
        ({'spark.driver.memory': '2048k'}, 2),  # KB to MB
        ({'spark.driver.memory': '0.5t'}, 524288),  # TB to MB
        ({'spark.driver.memory': '1024m'}, 1024),  # MB to MB
        ({'spark.driver.memory': '1.5t'}, 1572864),  # TB to MB with decimal
    ],
)
def test_get_spark_driver_memory_mb(spark_conf, expected_mem):
    assert expected_mem == utils.get_spark_driver_memory_mb(spark_conf)


@pytest.mark.parametrize(
    'spark_conf,expected_mem_overhead',
    [
        ({'spark.driver.memoryOverhead': '1024'}, 1024),  # Simple case
        ({'spark.driver.memoryOverhead': '1g'}, 1024),  # GB to MB
        ({'spark.driver.memory': '10240m', 'spark.driver.memoryOverheadFactor': '0.2'}, 2048),  # Custom overhead factor
        ({'spark.driver.memory': '10240m'}, 1024),  # Using default overhead factor
        (
            {'spark.driver.memory': 'invalid'},
            utils.SPARK_DRIVER_MEM_DEFAULT_MB * utils.SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT,
        ),  # Invalid case
        ({'spark.driver.memoryOverhead': '1.5g'}, 1536),  # GB to MB with decimal
        ({'spark.driver.memory': '2048k', 'spark.driver.memoryOverheadFactor': '0.05'}, 0.1),  # KB to MB with custom factor
        ({'spark.driver.memory': '0.5t', 'spark.driver.memoryOverheadFactor': '0.15'}, 78643.2),  # TB to MB with custom factor
        ({'spark.driver.memory': '1024m', 'spark.driver.memoryOverheadFactor': '0.25'}, 256),  # MB to MB with custom factor
        ({'spark.driver.memory': '1.5t', 'spark.driver.memoryOverheadFactor': '0.05'}, 78643.2),  # TB to MB with custom factor
    ],
)
def test_get_spark_driver_memory_overhead_mb(spark_conf, expected_mem_overhead):
    assert expected_mem_overhead == utils.get_spark_driver_memory_overhead_mb(spark_conf)


@pytest.fixture
def mock_runtimeenv():
    with patch('builtins.open', mock_open(read_data=MOCK_ENV_NAME)) as m: