forked from karpathy/nanoGPT
-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #144 from klei22/add_mnist_dataset
Add scripts compatible with the MNIST dataset
- Loading branch information
Showing
8 changed files
with
230 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
mnist_images/ | ||
grayscale_images/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
## MNIST Dataset and Preprocessing Scripts | ||
|
||
This directory contains scripts compatible with the MNIST dataset, which is a | ||
classic dataset of handwritten digits commonly used for training various image | ||
processing systems and machine learning models, including Language Learning | ||
Models (LLMs). | ||
|
||
### Script `get_dataset.py` | ||
|
||
The `get_dataset.py` script is used to download the MNIST dataset and save the | ||
images to a specified directory. | ||
|
||
**Usage:** | ||
```bash | ||
python3 get_dataset.py | ||
``` | ||
|
||
By default, the script will download the MNIST dataset and save the images in | ||
the `mnist_images` directory. | ||
|
||
### Script `gray.py` | ||
|
||
The `gray.py` script is used to convert images from the MNIST dataset into ASCII | ||
art, which can be used for training LLMs or for visualization purposes. | ||
|
||
**Usage:** | ||
|
||
```bash | ||
python3 gray.py --image-dir mnist_images --output-dimensions 8x8 --levels 2 --chars 01 --append-to-file | ||
``` | ||
|
||
Options: | ||
- `--image-dir` (required): Directory containing images to convert. | ||
- `--output-dir` (default: `grayscale_images`): Directory to save ASCII art. | ||
- `--output-dimensions` (default: `16x16`): Output dimensions for ASCII art, e.g., 8x8, 16x16. | ||
- `--levels` (default: 2): Number of grayscale levels, currently 2 - 9 supported. | ||
- `--chars` (optional): Custom characters for ASCII art, ordered from darkest to lightest. | ||
- `--append-to-file` (optional): Append ASCII art to a single file instead of creating separate files. | ||
- `--output-file` (default: `input.txt`): File to append ASCII art to if `--append-to-file` is used. | ||
- `--number-placement` (default: `before`): Place the type of number before or after the ASCII image in the output file. Choices are `before` or `after`. | ||
|
||
The script will process each image in the `--image-dir` directory and save the | ||
resulting ASCII art to the `--output-dir` directory. | ||
|
||
## License Information for MNIST Dataset | ||
|
||
The MNIST dataset is made available under the terms of the [Creative Commons | ||
Attribution-Share Alike 3.0 | ||
license](https://creativecommons.org/licenses/by-sa/3.0/). When using the MNIST | ||
dataset, you should attribute the source as provided by the original authors. | ||
|
||
Please note that while the MNIST dataset itself is licensed under the Creative | ||
Commons license, the `get_dataset.py` and `gray.py` scripts provided in this | ||
directory are subject to the license terms of the repository they reside in. | ||
|
||
## Citation Information for MNIST Dataset | ||
|
||
If you use the MNIST dataset in your research, please cite: | ||
|
||
```bibtex | ||
@article{lecun2010mnist, | ||
title={MNIST handwritten digit database}, | ||
author={LeCun, Yann and Cortes, Corinna and Burges, CJ}, | ||
journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist}, | ||
volume={2}, | ||
year={2010} | ||
} | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
from tqdm import tqdm | ||
from datasets import load_dataset | ||
|
||
def save_mnist_images(dataset, output_dir): | ||
"""Saves images from the MNIST dataset to a specified directory.""" | ||
os.makedirs(output_dir, exist_ok=True) | ||
for idx, item in tqdm(enumerate(dataset), total=len(dataset), desc="Saving images"): | ||
image = item['image'] # The image is already a PIL.Image object. | ||
label = item['label'] | ||
image.save(os.path.join(output_dir, f'{idx}_{label}.png')) | ||
|
||
def main(): | ||
# Load the MNIST dataset from Hugging Face | ||
dataset = load_dataset("mnist", split='train') | ||
|
||
# Specify the output directory for images | ||
output_dir = 'mnist_images' | ||
|
||
# Save images | ||
save_mnist_images(dataset, output_dir) | ||
print(f"Images saved in {output_dir}") | ||
|
||
if __name__ == "__main__": | ||
main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from PIL import Image | ||
import argparse | ||
import os | ||
import numpy as np | ||
import sys | ||
|
||
def convert_image_to_ascii(image_path, output_size, grayscale, levels, chars=None): | ||
# Load the image from the specified path and resize it | ||
img = Image.open(image_path).resize(output_size, Image.LANCZOS) | ||
|
||
if grayscale: | ||
img = img.convert("L") # Convert to grayscale | ||
if levels < 256: | ||
img = img.point(lambda p: (p * levels) // 256 * (256 // levels)) | ||
|
||
# Define default characters for different levels | ||
default_chars = "@%#*+=-:. " | ||
if chars is None: | ||
if levels == 2: | ||
chars = "@ " | ||
elif levels == 3: | ||
chars = "@. " | ||
elif levels == 4: | ||
chars = "@*- " | ||
elif levels == 5: | ||
chars = "@#+- " | ||
elif levels == 6: | ||
chars = "@#+-. " | ||
elif levels == 7: | ||
chars = "@#+-:. " | ||
elif levels == 8: | ||
chars = "@#*+-:. " | ||
elif levels == 9: | ||
chars = "@%#*+-:. " | ||
else: | ||
sys.exit(f"number of levels {levels} not supported") | ||
|
||
# Normalize the characters set based on the number of levels | ||
char_array = np.array([c for c in chars]) | ||
n_chars = len(char_array) | ||
scale_factor = 256 // levels | ||
|
||
# Convert the image to a numpy array | ||
img_np = np.array(img) | ||
|
||
# Convert each pixel to the corresponding ASCII character | ||
ascii_img = char_array[img_np // scale_factor] | ||
|
||
# Join characters to form lines and then join lines to form the full ASCII image | ||
ascii_img_lines = ["".join(row) for row in ascii_img] | ||
ascii_result = "\n".join(ascii_img_lines) | ||
|
||
return ascii_result | ||
|
||
# Usage example | ||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description='Convert images to ASCII art.') | ||
parser.add_argument('--output-dimensions', type=str, default='16x16', help='Output dimensions for ASCII art, e.g., 8x8, 16x16') | ||
parser.add_argument('--levels', type=int, default=2, help='Number of grayscale levels, currently 2 - 9 supported') | ||
parser.add_argument('--image-dir', type=str, required=True, help='Directory containing images to convert.') | ||
parser.add_argument('--output-dir', type=str, default='grayscale_images', help='Directory to save ASCII art.') | ||
parser.add_argument('--append-to-file', action='store_true', help='Append ASCII art to a single file.') | ||
parser.add_argument('--output-file', type=str, default='input.txt', help='File to append ASCII art to.') | ||
parser.add_argument('--number-placement', type=str, default='before', choices=['before', 'after'], help='Place the type of number before or after the ASCII image in the output file.') | ||
parser.add_argument('--chars', type=str, default=None, help='Custom characters for ASCII art, ordered from darkest to lightest.') | ||
args = parser.parse_args() | ||
|
||
# Parse output dimensions | ||
output_dimensions = tuple(map(int, args.output_dimensions.split('x'))) | ||
|
||
# Create output directory if it doesn't exist | ||
if not os.path.exists(args.output_dir): | ||
os.makedirs(args.output_dir) | ||
|
||
# Open the output file for appending if required | ||
output_file = None | ||
if args.append_to_file: | ||
output_file = open(args.output_file, 'a') | ||
|
||
# Process each image in the directory | ||
for image_filename in os.listdir(args.image_dir): | ||
if image_filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')): | ||
image_path = os.path.join(args.image_dir, image_filename) | ||
ascii_art = convert_image_to_ascii(image_path, output_size=output_dimensions, grayscale=True, levels=args.levels, chars=args.chars) | ||
if args.append_to_file: | ||
# Determine the number from the filename | ||
number = os.path.splitext(image_filename)[0].split('_')[-1] | ||
# Append the ASCII art to the output file with the number placement | ||
if args.number_placement == 'before': | ||
output_file.write(f'{number}\n{ascii_art}\n') | ||
else: | ||
output_file.write(f'{ascii_art}\n{number}\n') | ||
else: | ||
# Save the ASCII art to a text file | ||
output_path = os.path.join(args.output_dir, os.path.splitext(image_filename)[0] + '.txt') | ||
with open(output_path, 'w') as f: | ||
f.write(ascii_art) | ||
print(f'ASCII art saved to {output_path}') | ||
|
||
# Close the output file if it was opened | ||
if output_file: | ||
output_file.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../template/prepare.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../template/utils/meta_util.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../template/utils/txt_to_phonemes.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
[ | ||
{ | ||
"max_iters": ["30000"], | ||
"eval_iters": ["200"], | ||
"eval_interval": ["250"], | ||
"log_interval": ["10"], | ||
"n_layer": ["8"], | ||
"n_kv_group": ["4"], | ||
"n_head": ["8"], | ||
"n_embd": ["384"], | ||
"block_size": ["256"], | ||
"shared_mlp_size" : ["1"], | ||
"shared_mlp_sym" : [true], | ||
"shared_attn_size" : ["2"], | ||
"shared_attn_sym" : [true], | ||
"device": ["cuda"], | ||
"dataset": ["mnist"], | ||
"compile": [true], | ||
"use_post_ln": [false], | ||
"softmax_variant_attn": ["softmax"], | ||
"use_abs_pos_embeddings": [false], | ||
"use_rotary_embeddings": [false], | ||
"use_fire_embeddings": [true], | ||
"shared_fire_embeddings": [true], | ||
"tensorboard_run_name": ["mnist_ascii"] | ||
} | ||
] | ||
|