Skip to content

Commit

Permalink
Implemented strong scaling benchmark for mat_mul.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Com1t committed Sep 14, 2023
1 parent 59f8e72 commit 1241422
Showing 1 changed file with 59 additions and 17 deletions.
76 changes: 59 additions & 17 deletions samples/mpi/mat_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,14 @@
import time


def matrix_mul(comm_world, a, b):
# check if matrix multiplication is valid
if a.shape[1] != b.shape[0]:
raise ValueError("A, B matrix dimension mismatched!")

def matrix_mul(comm_world, dim_1, dim_2):
# comm init
comm_rank = comm_world.Get_rank()
comm_size = comm_world.Get_size()

a_mat = np.array(a + comm_rank, dtype=np.float32)
b_mat = np.array(b + comm_rank, dtype=np.float32)
c_mat = np.zeros((a_mat.shape[0], b_mat.shape[1]), dtype=np.float32)
a_mat = np.full((dim_1, dim_2), 1 + comm_rank, dtype=np.float32)
b_mat = np.full((dim_2, dim_1), 1 + comm_rank, dtype=np.float32)
c_mat = np.zeros((dim_1, dim_1), dtype=np.float32)

@dace.program
def dist_mat_mult(a_mat: dace.float32[a_mat.shape[0], a_mat.shape[1]],
Expand Down Expand Up @@ -83,22 +79,15 @@ def dist_mat_mult(a_mat: dace.float32[a_mat.shape[0], a_mat.shape[1]],
return c_mat, time_con


if __name__ == "__main__":
comm_world = MPI.COMM_WORLD
comm_rank = comm_world.Get_rank()
comm_size = comm_world.Get_size()

def weak_scaling(comm_world, comm_rank, comm_size):
grid_dim = int(np.floor(np.sqrt(comm_size)))
grid_i = comm_rank // grid_dim
grid_j = comm_rank % grid_dim

dim_1 = 1024
dim_2 = 1024

a = np.ones((dim_1, dim_2), dtype=np.float32)
b = np.ones((dim_2, dim_1), dtype=np.float32)

c_mat, time_con = matrix_mul(comm_world, a, b)
c_mat, time_con = matrix_mul(comm_world, dim_1, dim_2)
# print(comm_rank, c_mat)
# print(comm_rank, "matrix_mul time:", time_con)

Expand All @@ -121,3 +110,56 @@ def dist_mat_mult(a_mat: dace.float32[a_mat.shape[0], a_mat.shape[1]],

# print("Result correctness:", np.allclose(c_mat, c_np[grid_i * dim_1:(grid_i+1) * dim_1, grid_j* dim_2:(grid_j+1) * dim_2]))
assert(np.allclose(c_mat, c_np[grid_i * dim_1:(grid_i+1) * dim_1, grid_j* dim_2:(grid_j+1) * dim_2]))


def strong_scaling(comm_world, comm_rank, comm_size, total_dim=8192):
    """Run the strong-scaling benchmark on a square process grid.

    The global problem size is fixed at ``total_dim`` (rounded up to a
    multiple of ``comm_size``) and split evenly across a
    ``grid_dim x grid_dim`` grid, so each rank's local block shrinks as
    more processes are added.

    Args:
        comm_world: MPI communicator forwarded to ``matrix_mul``.
        comm_rank: Rank of the calling process. Only needed by the
            disabled validation below; kept for signature parity with
            ``weak_scaling``.
        comm_size: Number of processes in ``comm_world``. The caller is
            expected to have checked that this is a perfect square.
        total_dim: Edge length of the (square) global matrices before
            padding. Defaults to 8192.

    Returns:
        The ``(c_mat, time_con)`` pair produced by ``matrix_mul``.
    """
    grid_dim = int(np.floor(np.sqrt(comm_size)))

    # Pad the global size up to the next multiple of comm_size. When
    # comm_size == grid_dim ** 2 this guarantees the padded size is evenly
    # divisible by grid_dim, so every rank gets a block of the same shape.
    dim_1 = total_dim
    dim_2 = total_dim
    remainder = total_dim % comm_size
    if remainder > 0:
        dim_1 += comm_size - remainder
        dim_2 += comm_size - remainder

    local_dim_1 = dim_1 // grid_dim
    local_dim_2 = dim_2 // grid_dim

    # matrix_mul builds its own per-rank operands from the dimensions, so no
    # input arrays are allocated here.
    c_mat, time_con = matrix_mul(comm_world, local_dim_1, local_dim_2)
    # print(comm_rank, c_mat)
    # print(comm_rank, "matrix_mul time:", time_con)

    # Validation (disabled). NOTE(review): presumably switched off because
    # every rank would have to build and multiply the full global matrices,
    # which is too expensive at benchmark sizes - confirm before re-enabling.
    #
    # grid_i = comm_rank // grid_dim
    # grid_j = comm_rank % grid_dim
    #
    # whole_a = np.ones((local_dim_1 * grid_dim, local_dim_2 * grid_dim), dtype=np.float32)
    # for i in range(grid_dim):
    #     for j in range(grid_dim):
    #         whole_a[i * local_dim_1:(i+1) * local_dim_1, j * local_dim_2:(j+1) * local_dim_2] += i * grid_dim + j
    #
    # whole_b = np.ones((local_dim_2 * grid_dim, local_dim_1 * grid_dim), dtype=np.float32)
    # for i in range(grid_dim):
    #     for j in range(grid_dim):
    #         whole_b[i * local_dim_2:(i+1) * local_dim_2, j * local_dim_1:(j+1) * local_dim_1] += i * grid_dim + j
    #
    # c_np = np.matmul(whole_a, whole_b)
    # assert(np.allclose(c_mat, c_np[grid_i * local_dim_1:(grid_i+1) * local_dim_1, grid_j* local_dim_2:(grid_j+1) * local_dim_2]))

    return c_mat, time_con

if __name__ == "__main__":
    # Script entry point: query the MPI world, verify the process count
    # forms a square grid, then launch the selected benchmark.
    comm_world = MPI.COMM_WORLD
    comm_size = comm_world.Get_size()
    comm_rank = comm_world.Get_rank()

    # Integer side length of the largest square grid that fits comm_size.
    grid_dim = int(np.floor(np.sqrt(comm_size)))

    # Both benchmarks lay the ranks out as grid_dim x grid_dim, so any
    # process count that is not a perfect square is rejected up front.
    if grid_dim * grid_dim != comm_size:
        raise ValueError("Please run this test with a square number of processes.")

    # Swap the active call below to run the weak-scaling variant instead.
    # weak_scaling(comm_world, comm_rank, comm_size)
    strong_scaling(comm_world, comm_rank, comm_size)

0 comments on commit 1241422

Please sign in to comment.