-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
456 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
#!/usr/bin/env python | ||
# Multiply a n by m matrix and a m by k matrix | ||
import random | ||
from typing import List | ||
|
||
[l, n, m] = [18, 15, 16] | ||
|
||
def generate_mat(m: int, n: int, min_value: int = -8, max_value: int = 8) -> List[List[int]]: | ||
""" | ||
Generates an m x n matrix filled with random floating-point numbers. | ||
:param m: Number of rows | ||
:param n: Number of columns | ||
:param min_value: Minimum value for random floats (inclusive) | ||
:param max_value: Maximum value for random floats (inclusive) | ||
:return: m x n matrix with random floating-point numbers | ||
""" | ||
return [[random.randint(min_value, max_value) for _ in range(n)] for _ in range(m)] | ||
|
||
|
||
indent = 2 | ||
indent_str = " " * indent | ||
|
||
def format_mat(symbol: str, m: int, n: int, mat: List[List[int]]) -> str: | ||
ret = "" | ||
for row in range(m): | ||
for col in range(n): | ||
ret += indent_str if col == 0 else "; " | ||
ret += f"{symbol}[{row}][{col}] = {mat[row][col]}" | ||
ret += ";\n" | ||
|
||
return ret | ||
|
||
|
||
operate_mat = \ | ||
f""" | ||
let a = gen_arr({l}, {n}); | ||
let b = gen_arr({n}, {m}); | ||
let c = gen_arr({l}, {m}); | ||
""" | ||
|
||
mat_a = generate_mat(l, n) | ||
operate_mat += "\n" + format_mat("a", l, n, mat_a) | ||
|
||
mat_b = generate_mat(n, m) | ||
operate_mat += "\n" + format_mat("b", n, m, mat_b) | ||
|
||
operate_mat += "\n" + indent_str + f"let _ = matmul({l},{n},{m},a,b,c);" | ||
|
||
|
||
operate_mat += "\n" + indent_str + f"let _ = matshow({l}, {n}, a);" | ||
operate_mat += "\n" + indent_str + f"let _ = print_endline();" | ||
operate_mat += "\n" + indent_str + f"let _ = matshow({n}, {m}, b);" | ||
operate_mat += "\n" + indent_str + f"let _ = print_endline();" | ||
operate_mat += "\n" + indent_str + f"let _ = matshow({l}, {m}, c);" | ||
operate_mat += "\n" + indent_str + f"()" | ||
operate_mat += "\n" | ||
|
||
generated = """ | ||
fn matshow(m: Int, n: Int, mat: Array[Array[Int]]) -> Unit { | ||
fn loop1(i: Int) -> Unit { | ||
if i <= m - 1 { | ||
fn loop2(j: Int) -> Unit { | ||
if j <= n - 1 { | ||
let _ = print_int(mat[i][j]); | ||
let _ = print_char(32); | ||
loop2(j+1) | ||
} else { | ||
print_endline() | ||
} | ||
}; | ||
let _ = loop2(0); | ||
loop1(i+1) | ||
} else { | ||
() | ||
} | ||
}; | ||
loop1(0) | ||
}; | ||
fn matmul(l: Int, m: Int, n: Int, a: Array[Array[Int]], b: Array[Array[Int]], c: Array[Array[Int]]) -> Unit { | ||
fn loop1(i: Int) -> Unit { | ||
if 0 <= i { | ||
fn loop2(j: Int) -> Unit { | ||
if 0 <= j { | ||
fn loop3(k: Int) -> Unit { | ||
if 0 <= k { | ||
c[i][j] = c[i][j] + a[i][k] * b[k][j]; | ||
loop3(k - 1) | ||
} else { | ||
() | ||
} | ||
}; | ||
let _ = loop3(m - 1); | ||
loop2(j - 1) | ||
} else { | ||
() | ||
} | ||
}; | ||
let _ = loop2(n - 1); | ||
loop1(i - 1) | ||
} else { | ||
() | ||
} | ||
}; | ||
loop1(l - 1) | ||
}; | ||
fn main { | ||
let dummy = Array::make(0, 0); | ||
fn gen_arr(m: Int, n: Int) -> Array[Array[Int]] { | ||
let mat = Array::make(m, dummy); | ||
fn init_arr(i: Int) -> Unit { | ||
if 0 <= i { | ||
mat[i] = Array::make(n, 0); | ||
init_arr(i - 1) | ||
} else { | ||
() | ||
} | ||
}; | ||
let _ = init_arr(m - 1); | ||
mat | ||
}; | ||
""" \ | ||
+ operate_mat + \ | ||
""" | ||
};""" | ||
|
||
source_path = "../test_src/matmul_gen_int.mbt" | ||
ans_path = "../test_src/matmul_gen_int.ans" | ||
|
||
with open(source_path, 'w') as source_file: | ||
source_file.write(generated) | ||
|
||
def matrix_multiply(matrix_a, matrix_b): | ||
# Get the dimensions of the matrices | ||
rows_a = len(matrix_a) | ||
cols_a = len(matrix_a[0]) | ||
rows_b = len(matrix_b) | ||
cols_b = len(matrix_b[0]) | ||
|
||
# Check if multiplication is possible | ||
if cols_a != rows_b: | ||
raise ValueError("Cannot multiply: The number of columns in matrix A must equal the number of rows in matrix B.") | ||
|
||
# Initialize the result matrix with zeros | ||
result = [[0 for _ in range(cols_b)] for _ in range(rows_a)] | ||
|
||
# Perform multiplication | ||
for i in range(rows_a): | ||
for j in range(cols_b): | ||
for k in range(cols_a): | ||
result[i][j] += matrix_a[i][k] * matrix_b[k][j] | ||
|
||
return result | ||
|
||
mat_c = matrix_multiply(mat_a, mat_b) | ||
|
||
with open(ans_path, 'w') as ans_file: | ||
for row in mat_a: | ||
for ele in row: | ||
ans_file.write(f'{int(ele)} ') | ||
ans_file.write('\n') | ||
ans_file.write('\n') | ||
|
||
for row in mat_b: | ||
for ele in row: | ||
ans_file.write(f'{int(ele)} ') | ||
ans_file.write('\n') | ||
ans_file.write('\n') | ||
|
||
for row in mat_c: | ||
for ele in row: | ||
ans_file.write(f'{int(ele)} ') | ||
ans_file.write('\n') | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,32 @@ | ||
-67 24 124 106 -53 -82 -114 82 -165 15 | ||
-30 -93 -94 -36 -44 23 -56 58 -15 29 | ||
-85 -91 171 67 132 5 22 116 -112 12 | ||
152 4 -76 29 -39 -40 -139 -28 -75 25 | ||
-49 137 208 21 72 46 28 -1 -39 -2 | ||
104 -9 -213 41 -71 -39 -22 -181 67 34 | ||
45 122 -27 -54 -18 -15 102 101 -86 -141 | ||
-145 8 -90 -46 -180 -57 190 48 71 -47 | ||
117 -184 -129 -50 179 83 47 -42 40 42 | ||
-38 -99 57 140 124 88 -91 35 57 153 | ||
6 7 7 6 5 6 8 8 2 4 | ||
8 2 5 2 9 6 0 2 5 9 | ||
2 7 8 8 3 9 7 2 3 2 | ||
1 1 9 6 3 4 4 7 6 9 | ||
8 3 2 9 6 9 2 2 7 6 | ||
8 6 5 6 1 2 9 0 2 9 | ||
3 1 4 7 1 2 1 7 9 0 | ||
9 9 2 2 2 1 9 0 3 5 | ||
7 9 2 4 5 1 7 10 7 7 | ||
1 8 4 9 7 6 5 7 2 3 | ||
|
||
2 7 1 7 3 5 9 8 7 8 | ||
8 4 4 1 3 9 6 5 9 2 | ||
5 1 5 5 8 4 0 3 3 2 | ||
6 6 5 8 8 8 9 0 0 3 | ||
2 1 6 8 2 4 7 2 8 0 | ||
2 4 5 0 4 8 9 1 2 2 | ||
8 3 0 9 4 1 2 1 2 5 | ||
9 5 8 1 7 3 4 7 7 3 | ||
9 0 1 1 1 6 8 1 1 6 | ||
4 5 7 9 2 2 5 9 8 8 | ||
|
||
391 277 307 340 327 348 380 281 342 255 | ||
264 222 274 311 217 296 370 262 319 250 | ||
335 227 250 286 288 343 350 188 239 200 | ||
354 210 291 302 282 277 311 226 254 233 | ||
312 258 264 323 254 366 451 235 280 269 | ||
317 247 207 351 241 281 329 253 277 277 | ||
289 156 180 173 207 243 281 144 160 177 | ||
295 220 168 284 193 267 306 240 283 248 | ||
430 279 288 332 281 338 399 319 379 300 | ||
365 247 301 300 297 340 371 225 294 197 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.