-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
85 lines (65 loc) · 1.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright 2024 Samsung Electronics Co., Ltd. All Rights Reserved
import json
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Union, List
def parse_hostfile(file_path: str) -> Dict[int, Dict[str, Union[str, int]]]:
    """Parse a hostfile into per-node address and device-count entries.

    Each non-blank line is expected to look like ``<ip> slots=<count>``,
    e.g. ``10.0.0.1 slots=8``. A bare count (``10.0.0.1 8``) is also
    accepted. Blank lines are skipped and do not consume a node index.

    Args:
        file_path: Path to the hostfile.

    Returns:
        Mapping from 0-based node index (in file order) to a dict with the
        keys ``"ip"`` (str) and ``"num_device"`` (int).

    Raises:
        ValueError: If a line lacks the device-count field or the count is
            not an integer.
    """
    hostfile_info: Dict[int, Dict[str, Union[str, int]]] = {}
    with open(file_path, 'rt') as hostfile:
        for line_no, line in enumerate(hostfile, start=1):
            # split() with no argument tolerates tabs, repeated spaces and
            # the trailing newline; an empty result means a blank line.
            tokens = line.split()
            if not tokens:
                continue
            if len(tokens) < 2:
                raise ValueError(
                    f"{file_path}:{line_no}: expected '<ip> slots=<count>', "
                    f"got {line.strip()!r}"
                )
            ip, slots_field = tokens[0], tokens[1]
            # Take everything after '=' rather than a fixed character
            # position, so multi-digit counts such as "slots=16" parse.
            _, sep, count = slots_field.partition('=')
            num_device = int(count if sep else slots_field)
            # Node indices are consecutive over the non-blank lines.
            hostfile_info[len(hostfile_info)] = {"ip": ip, "num_device": num_device}
    return hostfile_info
def parse_nodefile(file_path: str) -> Dict[str, Dict[str, Union[str, int]]]:
    """Load a JSON node file and return the decoded object.

    Args:
        file_path: Path to a JSON document.

    Returns:
        Whatever ``json`` decodes from the file; callers expect a mapping of
        name -> per-node attribute dict (not validated here).
    """
    with open(file_path, 'r') as nodefile:
        return json.load(nodefile)
def factor(N: int, upper: Union[int, None] = None, lower: Union[int, None] = None) -> List[int]:
    """Return the positive divisors of ``N`` within ``[lower, upper]``, ascending.

    Args:
        N: The integer to factor. Non-positive values yield no divisors
            unless ``upper`` is given explicitly.
        upper: Largest candidate divisor, inclusive. Defaults to ``N``.
        lower: Smallest candidate divisor, inclusive. ``None`` (or any value
            below 1) means start from 1.

    Returns:
        Divisors ``i`` with ``N % i == 0`` in ascending order.
    """
    if upper is None:
        upper = N
    # Start the scan at the lower bound instead of filtering afterwards;
    # candidates are never below 1.
    start = 1 if lower is None else max(1, lower)
    return [i for i in range(start, upper + 1) if N % i == 0]
class DeviceType(Enum):
    """Supported accelerator models; each member's value is its lowercase name."""

    A100 = "a100"
    V100 = "v100"
    P100 = "p100"
    T4 = "t4"

    @staticmethod
    def from_string(s: str) -> 'DeviceType':
        """Look up a member by name, case-insensitively (e.g. ``"a100"``).

        Args:
            s: Device name such as ``"a100"`` or ``"T4"``.

        Returns:
            The matching ``DeviceType`` member.

        Raises:
            ValueError: If ``s`` does not name a member.
        """
        try:
            return DeviceType[s.upper()]
        except KeyError:
            # Report the offending input; the KeyError adds nothing useful.
            raise ValueError(f"Unknown device type: {s!r}") from None
@dataclass
class ResourceConfig:
    """Plain data container describing the hardware resources for a job.

    Fields are positional in the order declared below; no validation is
    performed on them.
    """

    device_type: DeviceType  # accelerator model (presumably homogeneous across nodes — confirm)
    inter_bw: int  # presumably inter-node bandwidth; units not shown here — TODO confirm
    intra_bw: int  # presumably intra-node (device-to-device) bandwidth; units TODO confirm
    num_nodes: int  # number of nodes
    num_devices: int  # presumably devices per node — verify against callers
    total_devices: int  # presumably num_nodes * num_devices; not enforced here
    device_memory: int  # per-device memory; units not shown here — TODO confirm
@dataclass
class ModelConfig:
    """Plain data container holding a model's architecture hyper-parameters.

    Fields are positional in the order declared below; no validation is
    performed on them.
    """

    num_layers: int  # number of layers
    hidden_size: int  # hidden dimension
    sequence_length: int  # input sequence length
    vocab_size: int  # vocabulary size
    attention_head_size: int  # per-head dimension or head count — TODO confirm against callers
    model_name: str  # human-readable model identifier
@dataclass
class GPUNode:
    """Plain data container pairing a device type with a device count.

    Presumably describes a single node in the cluster — confirm with callers.
    """

    device_type: DeviceType  # accelerator model of the devices
    num_devices: int  # number of devices (presumably on this node)