From b9eb59b89750bb65684c43ff46197267d9d153b2 Mon Sep 17 00:00:00 2001
From: prometheus <57171759+NLGithubWP@users.noreply.github.com>
Date: Sat, 15 Jun 2024 12:29:36 +0800
Subject: [PATCH] Add the MNIST benchmark dataset for testing model selection

---
 examples/cnn_ms/data/mnist.py | 99 +++++++++++++++++++++++++++++++++++
 1 file changed, 99 insertions(+)
 create mode 100644 examples/cnn_ms/data/mnist.py

diff --git a/examples/cnn_ms/data/mnist.py b/examples/cnn_ms/data/mnist.py
new file mode 100644
index 000000000..071d45d20
--- /dev/null
+++ b/examples/cnn_ms/data/mnist.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import numpy as np
+import os
+import sys
+import gzip
+import codecs
+
+
+def check_dataset_exist(dirpath):
+    if not os.path.exists(dirpath):
+        print(
+            'The MNIST dataset does not exist. Please download it first with "python data/download_mnist.py".'
+        )
+        sys.exit(0)
+    return dirpath
+
+
+def load_dataset():
+    # TODO: update these paths to point at the local copies created by
+    # data/download_mnist.py
+    train_x_path = '/tmp/train-images-idx3-ubyte.gz'
+    train_y_path = '/tmp/train-labels-idx1-ubyte.gz'
+    valid_x_path = '/tmp/t10k-images-idx3-ubyte.gz'
+    valid_y_path = '/tmp/t10k-labels-idx1-ubyte.gz'
+
+    train_x = read_image_file(check_dataset_exist(train_x_path)).astype(
+        np.float32)
+    train_y = read_label_file(check_dataset_exist(train_y_path)).astype(
+        np.float32)
+    valid_x = read_image_file(check_dataset_exist(valid_x_path)).astype(
+        np.float32)
+    valid_y = read_label_file(check_dataset_exist(valid_y_path)).astype(
+        np.float32)
+    return train_x, train_y, valid_x, valid_y
+
+
+def read_label_file(path):
+    # IDX label file: 4-byte magic number (2049), 4-byte item count,
+    # then one uint8 label per item.
+    with gzip.open(path, 'rb') as f:
+        data = f.read()
+        assert get_int(data[:4]) == 2049
+        length = get_int(data[4:8])
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape(length)
+        return parsed
+
+
+def get_int(b):
+    # Interpret a big-endian byte string as an integer.
+    return int(codecs.encode(b, 'hex'), 16)
+
+
+def read_image_file(path):
+    # IDX image file: 4-byte magic number (2051), item count, row count,
+    # column count, then uint8 pixel values in row-major order.
+    with gzip.open(path, 'rb') as f:
+        data = f.read()
+        assert get_int(data[:4]) == 2051
+        length = get_int(data[4:8])
+        num_rows = get_int(data[8:12])
+        num_cols = get_int(data[12:16])
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(
+            (length, 1, num_rows, num_cols))
+        return parsed
+
+
+def normalize(train_x, val_x):
+    # Scale pixel values from [0, 255] to [0, 1].
+    train_x /= 255
+    val_x /= 255
+    return train_x, val_x
+
+
+def load():
+    train_x, train_y, val_x, val_y = load_dataset()
+    train_x, val_x = normalize(train_x, val_x)
+    train_x = train_x.astype(np.float32)
+    val_x = val_x.astype(np.float32)
+    train_y = train_y.astype(np.int32)
+    val_y = val_y.astype(np.int32)
+    return train_x, train_y, val_x, val_y
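
A minimal usage sketch of the new module, for reference only (not part of the patch): it assumes the four gzipped IDX files have already been fetched to /tmp, e.g. with python data/download_mnist.py, and that the module is importable as data.mnist from the examples/cnn_ms directory; the shapes in the comments are the standard MNIST sizes.

    import numpy as np

    from data import mnist  # assumed import path; adjust to how cnn_ms imports it

    train_x, train_y, val_x, val_y = mnist.load()
    print(train_x.shape, train_x.dtype)  # (60000, 1, 28, 28) float32, values in [0, 1]
    print(train_y.shape, train_y.dtype)  # (60000,) int32 labels 0-9
    print(val_x.shape)                   # (10000, 1, 28, 28)
    print(np.unique(val_y))              # [0 1 2 3 4 5 6 7 8 9]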