From dcfca38d9c7b1ff12aca71432b342d09a40f73f6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jacob=20Odg=C3=A5rd=20T=C3=B8rring?=
Date: Wed, 24 Jan 2024 12:18:11 +0100
Subject: [PATCH] Updates GEMM with proper GFLOPs setup

---
 .../kernelbackend/kerneltuner_runner.py      |  23 +-
 batbench/benchmarks/GEMM/GEMM-CAFF.json      |  46 +--
 batbench/benchmarks/GEMM/gemm.cu             | 299 +++----------
 batbench/benchmarks/GEMM/gemm.py             |  12 +-
 batbench/config_space/cuda_problem.py        |   8 +-
 batbench/manager/manager.py                  |   7 +-
 batbench/result/dataset.py                   |  10 +-
 .../kerneltuner_runner/kerneltuner_runner.py |  12 +-
 8 files changed, 119 insertions(+), 298 deletions(-)

diff --git a/batbench/backends/kernelbackend/kerneltuner_runner.py b/batbench/backends/kernelbackend/kerneltuner_runner.py
index e4b72d3..b206621 100644
--- a/batbench/backends/kernelbackend/kerneltuner_runner.py
+++ b/batbench/backends/kernelbackend/kerneltuner_runner.py
@@ -24,12 +24,13 @@ class KernelBackend:
     DEFAULT_OBJECTIVE = TIME

     def __init__(self, spec, config_space, args: Arguments,
-                 cuda_backend="Cupy", metrics=None, objective=DEFAULT_OBJECTIVE):
+                 cuda_backend="Cupy", metrics=None):
         self.spec = spec
         self.config_space = config_space
         self.kernel_spec = self.spec["KernelSpecification"]
+        self.objective = self.spec['General'].get('Objective', self.DEFAULT_OBJECTIVE)
+        self.minimize = self.spec['General'].get('Minimize', True)
         self.metrics = metrics
-        self.objective = objective
         self.args = args
         self.function_args = self.args.get_function_args()

@@ -148,11 +149,17 @@ def evaluate_gridsize(self, gridsizes, dimension):
     def extract_param_names(self, gridsize):
         return [node.id for node in ast.walk(ast.parse(gridsize))
                 if isinstance(node, ast.Name)]
+
     def wrap_variables_in_gridsize(self, gridsize, paramnames):
         for paramname in paramnames:
-            # prevents multiple occurrences and avoids matching substrings
-            if not re.search(f"\b{paramname}\b", gridsize):
-                gridsize = gridsize.replace(paramname, f"p['{paramname}']")
+            # Using a regular expression to ensure that whole words are matched
+            pattern = r'\b' + re.escape(paramname) + r'\b'
+            replacement = f"p['{paramname}']"
+
+            # Check if the parameter name is already wrapped
+            wrapped_pattern = f"p\\['{paramname}'\\]"
+            if not re.search(wrapped_pattern, gridsize):
+                gridsize = re.sub(pattern, replacement, gridsize)
         return gridsize

     def validate_problemsize_length(self, problemsizes, gridsizes):
@@ -164,6 +171,7 @@ def update_invalid_result(self, result, msg, error=None):
         result.validity = msg
         result.correctness = 0
         result.runtimes = [0]
+        result.objective = 10000 if self.minimize else 0
         if error:
             result.error = error
         return result
@@ -172,7 +180,9 @@ def update_result(self, result, kt_result):

         result.runtimes = [t/1000 for t in kt_result["times"]]
         result.runtime = sum(result.runtimes)
-        result.objective = kt_result[self.objective]/1000
+        result.objective = kt_result[self.objective]
+        if self.objective == self.TIME:
+            result.objective /= 1000
         result.compile_time = kt_result["compile_time"]/1000
         #result.time = kt_result["verification_time"]
         #result.time = kt_result["benchmark_time"]
@@ -189,6 +199,7 @@ def run_reference(self, tuning_config):
                           self.opts["compiler_options"], None, self.opts["block_size_names"],
                           self.opts["quiet"], None)
         answer_list = [None] * len(res)
+
         for key in self.args.output_args:
             idx = self.args.args[key]["index"]
             self.args.add_reference_value(key, res[idx])
diff --git a/batbench/benchmarks/GEMM/GEMM-CAFF.json b/batbench/benchmarks/GEMM/GEMM-CAFF.json
index 9810750..9ce4f74 100644
--- a/batbench/benchmarks/GEMM/GEMM-CAFF.json
+++ b/batbench/benchmarks/GEMM/GEMM-CAFF.json
@@ -1,7 +1,9 @@
 {
     "General": {
         "BenchmarkName": "GEMM",
-        "OutputFormat": "JSON"
+        "OutputFormat": "JSON",
+        "Objective": "GFLOPs",
+        "Minimize": false
     },
     "ConfigurationSpace": {
         "TuningParameters": [
@@ -68,25 +70,25 @@
             {
                 "Name": "STRM",
                 "Type": "int",
-                "Values": "[0, 1]",
+                "Values": "[0]",
                 "Default": 0
             },
             {
                 "Name": "STRN",
                 "Type": "int",
-                "Values": "[0, 1]",
+                "Values": "[0]",
                 "Default": 0
             },
             {
                 "Name": "SA",
                 "Type": "int",
-                "Values": "[0]",
+                "Values": "[0, 1]",
                 "Default": 0
             },
             {
                 "Name": "SB",
                 "Type": "int",
-                "Values": "[0]",
+                "Values": "[0, 1]",
                 "Default": 0
             },
             {
@@ -102,20 +104,20 @@
                 "Parameters": ["KWG", "KWI"]
             },
             {
-                "Expression": "(MWG % (MDIMC * VWM)) == 0",
-                "Parameters": ["MWG", "MDIMC", "VWM"]
+                "Expression": "(MWG % MDIMC) == 0",
+                "Parameters": ["MWG", "MDIMC"]
             },
             {
-                "Expression": "(NWG % (NDIMC * VWN)) == 0",
-                "Parameters": ["NWG", "NDIMC", "VWN"]
+                "Expression": "(NWG % NDIMC) == 0",
+                "Parameters": ["NWG", "NDIMC"]
             },
             {
-                "Expression": "(MWG % (MDIMA * VWM)) == 0",
-                "Parameters": ["MWG", "MDIMA", "VWM"]
+                "Expression": "(MWG % MDIMA) == 0",
+                "Parameters": ["MWG", "MDIMA"]
             },
             {
-                "Expression": "(NWG % (NDIMB * VWN)) == 0",
-                "Parameters": ["NWG", "NDIMB", "VWN"]
+                "Expression": "(NWG % NDIMB) == 0",
+                "Parameters": ["NWG", "NDIMB"]
             },
             {
                 "Expression": "(KWG % ((MDIMC * NDIMC) // MDIMA)) == 0",
@@ -145,11 +147,11 @@
                 "Z": "1"
             },
             "GlobalSize": {
-                "X": "(16384 * MDIMC) // MWG",
-                "Y": "(16384 * NDIMC) // NWG",
+                "X": "4096 // MWG",
+                "Y": "4096 // NWG",
                 "Z": "1"
             },
-            "SharedMemory": 16384,
+            "SharedMemory": 49152,
             "Stream": null,
             "Arguments": [
                 {
                     "Name": "kSizeM",
                     "Type": "int32",
                     "MemoryType": "Scalar",
                     "AccessType": "ReadOnly",
-                    "FillValue": 16384
+                    "FillValue": 4096
                 },
                 {
                     "Name": "kSizeN",
                     "Type": "int32",
                     "MemoryType": "Scalar",
                     "AccessType": "ReadOnly",
-                    "FillValue": 16384
+                    "FillValue": 4096
                 },
                 {
                     "Name": "kSizeK",
                     "Type": "int32",
                     "MemoryType": "Scalar",
                     "AccessType": "ReadOnly",
-                    "FillValue": 16384
+                    "FillValue": 4096
                 },
                 {
                     "Name": "agm",
                     "Type": "float",
                     "MemoryType": "Vector",
                     "AccessType": "ReadOnly",
-                    "Size": 16384,
+                    "Size": 16777216,
                     "FillType": "Random",
                     "FillValue": 1.0
                 },
                 {
                     "Name": "bgm",
                     "Type": "float",
                     "MemoryType": "Vector",
                     "AccessType": "ReadOnly",
-                    "Size": 16384,
+                    "Size": 16777216,
                     "FillType": "Random",
                     "FillValue": 1.0
                 },
                 {
                     "Name": "cgm",
                     "Type": "float",
                     "MemoryType": "Vector",
                     "AccessType": "WriteOnly",
-                    "Size": 16384,
+                    "Size": 16777216,
                     "FillType": "Constant",
                     "FillValue": 0.0,
                     "Output": 1
diff --git a/batbench/benchmarks/GEMM/gemm.cu b/batbench/benchmarks/GEMM/gemm.cu
index b6a9047..56dd565 100644
--- a/batbench/benchmarks/GEMM/gemm.cu
+++ b/batbench/benchmarks/GEMM/gemm.cu
@@ -82,159 +82,16 @@
 // Settings
 #define USE_VECTOR_MAD 1    // Don't unroll the vector MAD computation
 #define USE_CL_MAD 0        // Uses the non-IEEE754 compliant OpenCL mad() (if above is 0)
-
-// =================================================================================================
-
-// Data-type: single or double precision
-#if PRECISION == 32
-  typedef float real;
-  typedef float2 real2;
-  typedef float4 real4;
-  #define ZERO 0.0f
-#elif PRECISION == 64
-  #if __OPENCL_VERSION__ <= CL_VERSION_1_1 // This the default on OpenCL 1.2 or higher
-    #pragma OPENCL EXTENSION cl_khr_fp64: enable
-  #endif
-  typedef double real;
-  typedef double2 real2;
-  typedef double4 real4;
-  #define ZERO 0.0
-#endif
-
-// =================================================================================================
-
-// Data-widths in dimension M
-#if VWM == 1
-  typedef real realM;
-#elif VWM == 2
-  typedef real2 realM;
-#elif VWM == 4
-  typedef real4 realM;
-#endif
-
-// Data-widths in dimension N
-#if VWN == 1
-  typedef real realN;
-#elif VWN == 2
-  typedef real2 realN;
-#elif VWN == 4
-  typedef real4 realN;
-#endif
+#define ZERO 0.0f

 extern "C" { // Needed by CUPY for Python-based tuners
-
-inline __device__ float2 make_float2(float s)
-{
-    return make_float2(s, s);
-}
-
-inline __device__ float4 make_float4(float s)
-{
-    return make_float4(s, s, s, s);
-}
-
-/*
-
-inline __device__ float2 operator+(float2 a, float2 b)
-{
-    return make_float2(a.x + b.x, a.y + b.y);
-}
-
-inline __device__ float4 operator+(float4 a, float4 b)
-{
-    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
-}
-
-inline __device__ float2 operator+(float2 a, float b)
-{
-    return make_float2(a.x + b, a.y + b);
-}
-
-inline __device__ float4 operator+(float4 a, float b)
-{
-    return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
-}
-
-inline __host__ __device__ void operator+=(float2 &a, float2 b)
-{
-    a.x += b.x;
-    a.y += b.y;
-}
-
-inline __host__ __device__ void operator+=(float4 &a, float4 b)
-{
-    a.x += b.x;
-    a.y += b.y;
-    a.z += b.z;
-    a.w += b.w;
-}
-
-inline __device__ float2 operator-(float2 a, float2 b)
-{
-    return make_float2(a.x - b.x, a.y - b.y);
-}
-
-inline __device__ float4 operator-(float4 a, float4 b)
-{
-    return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
-}
-
-inline __device__ float2 operator-(float2 a, float b)
-{
-    return make_float2(a.x - b, a.y - b);
-}
-
-inline __device__ float4 operator-(float4 a, float b)
-{
-    return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
-}
-
-inline __device__ float2 operator*(float2 a, float2 b)
-{
-    return make_float2(a.x * b.x, a.y * b.y);
-}
-
-inline __device__ float4 operator*(float4 a, float4 b)
-{
-    return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
-}
-
-inline __device__ float2 operator*(float2 a, float b)
-{
-    return make_float2(a.x * b, a.y * b);
-}
-
-inline __device__ float4 operator*(float4 a, float b)
-{
-    return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
-}
-
-inline __device__ float2 operator*(float b, float2 a)
-{
-    return make_float2(b * a.x, b * a.y);
-}
-
-inline __device__ float4 operator*(float b, float4 a)
-{
-    return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
-}
-
-inline __device__ float2 rsqrtf(float2 x){
-    return make_float2(rsqrtf(x.x), rsqrtf(x.y));
-}
-
-inline __device__ float4 rsqrtf(float4 x){
-    return make_float4(rsqrtf(x.x), rsqrtf(x.y), rsqrtf(x.z), rsqrtf(x.w));
-}
-
-*/

 // =================================================================================================

 // Caches global off-chip memory into local (shared) memory on-chip. This function is specific for
 // caching the A input matrix.
 #if SA == 1
-inline __device__ void GlobalToLocalA(const realM* __restrict__ agm, realM* alm,
+inline __device__ void GlobalToLocalA(const float* __restrict__ agm, float* alm,
                                       const int kSizeM, const int tid, const int kwg) {
   const int la0 = tid % MDIMA;
   const int la1 = tid / MDIMA;
@@ -264,7 +121,7 @@ inline __device__ void GlobalToLocalA(const realM* __restrict__ agm, realM* alm,

 // Same as above, but now for the B input matrix
 #if SB == 1
-inline __device__ void GlobalToLocalB(const realN* __restrict__ bgm, realN* blm,
+inline __device__ void GlobalToLocalB(const float* __restrict__ bgm, float* blm,
                                       const int kSizeN, const int tid, const int kwg) {
   const int lb0 = tid % NDIMB;
   const int lb1 = tid / NDIMB;
@@ -297,7 +154,7 @@ inline __device__ void GlobalToLocalB(const realN* __restrict__ bgm, realN* blm

 // Caches global off-chip memory directly into per-thread private memory (registers). This function
 // is specific for caching the A input matrix.
 #if SA == 0
-inline __device__ void GlobalToPrivateA(const realM* __restrict__ agm, realM apm[MWI/VWM],
+inline __device__ void GlobalToPrivateA(const float* __restrict__ agm, float apm[MWI/VWM],
                                         const int kSizeM, const int idk, const int kwg) {
   #pragma unroll
   for (int mi=0; mi
 None:
-        super().__init__("GEMM", get_spec_by_name("GEMM"))
+        metrics = OrderedDict()
+        spec = get_spec_by_name("GEMM")
+        matrix_bytes = 1
+        for i in range(0, 3):
+            matrix_bytes *= spec["KernelSpecification"]["Arguments"][i]["FillValue"]
+
+        metrics["GFLOPs"] = lambda p : (matrix_bytes/1e9) / (p["time"] / 1e3)
+
+        super().__init__("GEMM", spec, metrics=metrics)
         self.runner.run_reference(self.config_space.default_config)
diff --git a/batbench/config_space/cuda_problem.py b/batbench/config_space/cuda_problem.py
index 118945e..84f26ce 100644
--- a/batbench/config_space/cuda_problem.py
+++ b/batbench/config_space/cuda_problem.py
@@ -43,12 +43,13 @@ class CUDAProblem(Problem):
     """
     def __init__(self, kernel_name: str, spec: Optional[Dict[str, Any]] = None,
                  run_settings: Optional[Dict[str, Any]] = None,
-                 cuda_backend="Cupy", runner="KT") -> None:
+                 cuda_backend="Cupy", runner="KT", metrics=None) -> None:
         super().__init__()
         self._kernel_name = kernel_name
         self._program = CUDAProgram(kernel_name)
         self._language = "CUDA"
         self.cuda_backend = cuda_backend
+        self.metrics = metrics
         self.spec = spec if spec is not None else {}
         self.spec["BenchmarkConfig"] = { "iterations": 10 }

@@ -56,8 +57,9 @@ def __init__(self, kernel_name: str, spec: Optional[Dict[str, Any]] = None,
         self._config_space = ConfigSpace(self.spec["ConfigurationSpace"])
         self.args = ArgHandler(self.spec).populate_args()
         if runner == "KT":
-            self.runner = KernelBackend(self.spec, self.config_space,
-                                        self.args, cuda_backend=self.cuda_backend)
+            self.runner = KernelBackend(self.spec, self.config_space,
+                                        self.args, cuda_backend=self.cuda_backend,
+                                        metrics=metrics)
         else:
             self.runner = CudaKernelRunner(self.spec, self.config_space)
         self.run_settings = run_settings if run_settings is not None else {}
diff --git a/batbench/manager/manager.py b/batbench/manager/manager.py
index f669ba2..e0d0a69 100644
--- a/batbench/manager/manager.py
+++ b/batbench/manager/manager.py
@@ -51,7 +51,10 @@ def __init__(self, args):
         self.problem = benchmark_map[args.benchmark](experiment_settings)
         self.config_space = self.problem.config_space
         self.budget_trials = experiment_settings["Budget"][0]["BudgetValue"]
-        self.dataset = Dataset(experiment_settings, args.benchmark)
+        print(f"General: {self.problem.spec['General']}")
+        self.objective = self.problem.spec['General']['Objective']
+        self.minimize = self.problem.spec['General']['Minimize']
+        self.dataset = Dataset(experiment_settings, args.benchmark, self.objective, self.minimize)
         self.trial = 0
         self.total_time = 0

@@ -71,6 +74,7 @@ def upload(root_results_path):
         Zenodo(datasets).upload()

     def finished(self):
+        self.cleanup = False
         if self.cleanup:
             self.dataset.delete_files()

@@ -83,6 +87,7 @@ def run(self, tuning_config, result):
             raise KeyboardInterrupt
         if list(tuning_config.values()) not in self.problem.config_space:
             result.validity = "KnownConstraintsViolated"
+            result.objective = 10000 if self.minimize else 0
             result.correctness = 0.0
         else:
             result = self.problem.run(tuning_config, result)
diff --git a/batbench/result/dataset.py b/batbench/result/dataset.py
index 0170d15..24e6c5b 100644
--- a/batbench/result/dataset.py
+++ b/batbench/result/dataset.py
@@ -20,11 +20,13 @@ class Dataset:

-    def __init__(self, experiment_settings, benchmark_name):
+    def __init__(self, experiment_settings, benchmark_name, objective, minimize):
         self.files = []
         self.cache_df = pd.DataFrame({})
         self.writes = 0
         self.write_interval = 10
+        self.objective = objective
+        self.minimize = minimize
         self.root_path = "./results"
         self.input_zip = "input-data.zip"

@@ -151,8 +153,8 @@ def get_best(self):
         #df = pd.read_hdf(self.cache_results_path, "Results")
         df = pd.read_csv(self.cache_results_path)
         df.reset_index(drop=True, inplace=True)
-        min_index = df['objective'].idxmin()
-        best_row = df.loc[min_index]
+        best_index = df['objective'].idxmin() if self.minimize else df['objective'].idxmax()
+        best_row = df.loc[best_index]
         return best_row

     def add_result(self, result):
@@ -188,7 +190,7 @@ def final_write_data(self, df=None):
         #df_iter = df if df is not None else pd.read_hdf(self.cache_results_path, "Results")
         df_iter = df if df is not None else pd.read_csv(self.cache_results_path)
         df_iter.reset_index(drop=True, inplace=True)
-        print(df_iter)
+        print(df_iter, self.output_results_path)
         if self.output_format == "csv":
             df_iter.to_csv(self.output_results_path, mode='w')
         elif self.output_format == "json":
diff --git a/batbench/tuners/kerneltuner_runner/kerneltuner_runner.py b/batbench/tuners/kerneltuner_runner/kerneltuner_runner.py
index d722643..5320b8f 100644
--- a/batbench/tuners/kerneltuner_runner/kerneltuner_runner.py
+++ b/batbench/tuners/kerneltuner_runner/kerneltuner_runner.py
@@ -15,6 +15,9 @@ def main(self, args):
         self.prog_args = args
         self.manager = Manager(args)
         self.f_evals = self.manager.budget_trials
+        DEFAULT_OBJECTIVE = "time"
+        self.objective = self.manager.problem.spec['General'].get('Objective', DEFAULT_OBJECTIVE)
+        self.minimize = self.manager.problem.spec['General'].get('Minimize', True)
         strategy_options = {"max_fevals": self.f_evals}
         return self.tune(args.gpu_name, strategy_options=strategy_options)

@@ -78,7 +81,7 @@ def convert_results(self, cache):

             result = Result(config=new_conf)

-            if isinstance(kt_result["time"], ErrorConfig):
+            if isinstance(kt_result[self.objective], ErrorConfig):
                 results.append(self.invalid_result(result, "Compile exception"))
                 continue

             result.runtimes = [t/unit for t in kt_result["times"]]
             result.runtime = sum(result.runtimes)
-            result.objective = kt_result["time"]/unit
+            result.objective = kt_result[self.objective]
+            if self.objective == "time":
+                result.objective /= 1000
             result.compile_time = kt_result["compile_time"]/unit
             result.framework_time = kt_result["framework_time"]/unit
             result.algorithm_time = kt_result["strategy_time"]/unit
kt_result["strategy_time"]/unit @@ -171,7 +176,8 @@ def run_tune(self, gpu_name, strategy, strategy_options, verbose, quiet, simulat compiler_options=compiler_options, strategy=strategy, strategy_options=strategy_options, - simulation_mode=simulation_mode) + simulation_mode=simulation_mode, + metrics=self.manager.problem.metrics) def tune(self,