Skip to content

Commit

Permalink
Merge pull request #144 from klei22/add_mnist_dataset
Browse files Browse the repository at this point in the history
Add scripts compatible with the MNIST dataset
  • Loading branch information
gkielian authored Apr 15, 2024
2 parents 8d02e6b + 7bc7110 commit b42b874
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 0 deletions.
2 changes: 2 additions & 0 deletions data/mnist/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mnist_images/
grayscale_images/
69 changes: 69 additions & 0 deletions data/mnist/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
## MNIST Dataset and Preprocessing Scripts

This directory contains scripts compatible with the MNIST dataset, which is a
classic dataset of handwritten digits commonly used for training various image
processing systems and machine learning models, including Large Language
Models (LLMs).

### Script `get_dataset.py`

The `get_dataset.py` script is used to download the MNIST training split and
save the images as `<index>_<label>.png` files in a local directory.

**Usage:**
```bash
python3 get_dataset.py
```

By default, the script will download the MNIST dataset and save the images in
the `mnist_images` directory.

### Script `gray.py`

The `gray.py` script is used to convert images from the MNIST dataset into ASCII
art, which can be used for training LLMs or for visualization purposes.

**Usage:**

```bash
python3 gray.py --image-dir mnist_images --output-dimensions 8x8 --levels 2 --chars 01 --append-to-file
```

Options:
- `--image-dir` (required): Directory containing images to convert.
- `--output-dir` (default: `grayscale_images`): Directory to save ASCII art.
- `--output-dimensions` (default: `16x16`): Output dimensions for ASCII art, e.g., 8x8, 16x16.
- `--levels` (default: 2): Number of grayscale levels, currently 2 - 9 supported.
- `--chars` (optional): Custom characters for ASCII art, ordered from darkest to lightest.
- `--append-to-file` (optional): Append ASCII art to a single file instead of creating separate files.
- `--output-file` (default: `input.txt`): File to append ASCII art to if `--append-to-file` is used.
- `--number-placement` (default: `before`): Place the type of number before or after the ASCII image in the output file. Choices are `before` or `after`.

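The script will process each image in the `--image-dir` directory and save the
resulting ASCII art to the `--output-dir` directory, one `.txt` file per image.
When `--append-to-file` is given, the ASCII art is instead appended to
`--output-file`, with the label taken from each image's filename written
before or after it according to `--number-placement`.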

## License Information for MNIST Dataset

The MNIST dataset is made available under the terms of the [Creative Commons
Attribution-Share Alike 3.0
license](https://creativecommons.org/licenses/by-sa/3.0/). When using the MNIST
dataset, you should attribute the source as provided by the original authors.

Please note that while the MNIST dataset itself is licensed under the Creative
Commons license, the `get_dataset.py` and `gray.py` scripts provided in this
directory are subject to the license terms of the repository they reside in.

## Citation Information for MNIST Dataset

If you use the MNIST dataset in your research, please cite:

```bibtex
@article{lecun2010mnist,
title={MNIST handwritten digit database},
author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
volume={2},
year={2010}
}
```

26 changes: 26 additions & 0 deletions data/mnist/get_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
from tqdm import tqdm
from datasets import load_dataset

def save_mnist_images(dataset, output_dir):
    """Write each dataset item to ``output_dir`` as ``<index>_<label>.png``.

    The directory is created when it does not exist yet. Every item is read
    through its ``'image'`` and ``'label'`` keys and the image is written
    with its own ``save`` method, while a progress bar tracks the loop.
    """
    os.makedirs(output_dir, exist_ok=True)
    progress = tqdm(enumerate(dataset), total=len(dataset), desc="Saving images")
    for position, record in progress:
        # The stored image object knows how to serialise itself.
        picture, digit = record['image'], record['label']
        destination = os.path.join(output_dir, f'{position}_{digit}.png')
        picture.save(destination)

def main():
    """Fetch the MNIST training split and dump its images to disk."""
    # Directory that receives the exported images.
    target_dir = 'mnist_images'

    # Load the MNIST dataset from Hugging Face.
    train_split = load_dataset("mnist", split='train')

    save_mnist_images(train_split, target_dir)
    print(f"Images saved in {target_dir}")


if __name__ == "__main__":
    main()

102 changes: 102 additions & 0 deletions data/mnist/gray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from PIL import Image
import argparse
import os
import numpy as np
import sys

# Built-in character ramps keyed by the number of grey levels. Index 0 is used
# for the darkest pixels and the last character for the brightest ones.
_DEFAULT_CHARS_BY_LEVELS = {
    2: "@ ",
    3: "@. ",
    4: "@*- ",
    5: "@#+- ",
    6: "@#+-. ",
    7: "@#+-:. ",
    8: "@#*+-:. ",
    9: "@%#*+-:. ",
}


def convert_image_to_ascii(image_path, output_size, grayscale, levels, chars=None):
    """Render an image file as a block of ASCII characters.

    The image is resized to ``output_size``, optionally converted to 8-bit
    grayscale, and every pixel is mapped onto one of ``levels`` equally sized
    brightness bands. Band ``i`` is drawn with ``chars[i]``.

    Args:
        image_path: Path of the image to load.
        output_size: Size tuple handed unchanged to ``Image.resize`` (Pillow
            expects ``(width, height)``).
        grayscale: When true, the image is converted to mode ``"L"`` before
            it is quantised. When false, the image must already be
            single-channel with 8-bit pixel values.
        levels: Number of brightness bands. Without ``chars`` only 2-9 are
            supported; with ``chars`` any value from 1 to 256 is accepted.
        chars: Optional characters ordered from darkest to lightest. At least
            ``levels`` characters are required; extra characters are ignored.

    Returns:
        The ASCII art as a single string with rows separated by ``"\\n"``.

    Raises:
        SystemExit: If ``chars`` is omitted and ``levels`` has no built-in
            character set (kept so the command line prints a short message).
        ValueError: If ``levels`` is outside 1-256, ``chars`` is shorter than
            ``levels``, or the pixel data is not two-dimensional.
    """
    # Validate everything up front so bad arguments fail before any file I/O.
    if chars is None:
        chars = _DEFAULT_CHARS_BY_LEVELS.get(levels)
        if chars is None:
            sys.exit(f"number of levels {levels} not supported")
    elif not 1 <= levels <= 256:
        raise ValueError(f"levels must be between 1 and 256, got {levels}")
    if len(chars) < levels:
        raise ValueError(
            f"{levels} levels need at least {levels} characters, "
            f"got {len(chars)}: {chars!r}"
        )

    # The context manager releases the file handle as soon as the resized
    # copy has been produced.
    with Image.open(image_path) as source:
        img = source.resize(output_size, Image.LANCZOS)
    if grayscale:
        img = img.convert("L")

    pixels = np.asarray(img)
    if pixels.ndim != 2:
        raise ValueError(
            "expected a single-channel image; pass grayscale=True to convert "
            f"(pixel array has shape {pixels.shape})"
        )

    # Map 0..255 onto 0..levels-1. Widening first avoids uint8 overflow in the
    # multiplication, and this form never yields an index equal to `levels`,
    # even when `levels` does not divide 256 evenly.
    level_index = (pixels.astype(np.int64) * levels) // 256

    palette = np.array(list(chars))
    return "\n".join("".join(row) for row in palette[level_index])

# Usage example
# File extensions (lower-case) that are treated as convertible images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')


def _parse_dimensions(text):
    """Parse a string such as ``16x16`` into a tuple of two positive ints."""
    parts = text.lower().split('x')
    if len(parts) != 2:
        raise ValueError("expected two numbers separated by 'x', e.g. 16x16")
    first, second = (int(part) for part in parts)
    if first <= 0 or second <= 0:
        raise ValueError("both dimensions must be positive")
    return first, second


def main(argv=None):
    """Convert every image in a directory to ASCII art.

    Args:
        argv: Optional list of command-line arguments; ``None`` makes
            argparse read ``sys.argv``.

    Each image is either written to its own ``.txt`` file in ``--output-dir``
    or, with ``--append-to-file``, appended to ``--output-file`` together
    with the label taken from the part of the filename after the last ``_``.
    """
    parser = argparse.ArgumentParser(description='Convert images to ASCII art.')
    parser.add_argument('--output-dimensions', type=str, default='16x16', help='Output dimensions for ASCII art, e.g., 8x8, 16x16')
    parser.add_argument('--levels', type=int, default=2, help='Number of grayscale levels, currently 2 - 9 supported')
    parser.add_argument('--image-dir', type=str, required=True, help='Directory containing images to convert.')
    parser.add_argument('--output-dir', type=str, default='grayscale_images', help='Directory to save ASCII art.')
    parser.add_argument('--append-to-file', action='store_true', help='Append ASCII art to a single file.')
    parser.add_argument('--output-file', type=str, default='input.txt', help='File to append ASCII art to.')
    parser.add_argument('--number-placement', type=str, default='before', choices=['before', 'after'], help='Place the type of number before or after the ASCII image in the output file.')
    parser.add_argument('--chars', type=str, default=None, help='Custom characters for ASCII art, ordered from darkest to lightest.')
    args = parser.parse_args(argv)

    # Reject malformed sizes here with a usage message instead of letting
    # them surface later as an obscure resize error.
    try:
        output_dimensions = _parse_dimensions(args.output_dimensions)
    except ValueError as err:
        parser.error(f'invalid --output-dimensions {args.output_dimensions!r}: {err}')

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.output_dir, exist_ok=True)

    # Explicit UTF-8 so custom --chars outside ASCII can always be written.
    output_file = None
    if args.append_to_file:
        output_file = open(args.output_file, 'a', encoding='utf-8')

    try:
        # os.listdir returns names in arbitrary order; sorting makes the
        # appended output reproducible between runs.
        for image_filename in sorted(os.listdir(args.image_dir)):
            if not image_filename.lower().endswith(_IMAGE_EXTENSIONS):
                continue

            image_path = os.path.join(args.image_dir, image_filename)
            ascii_art = convert_image_to_ascii(image_path, output_size=output_dimensions, grayscale=True, levels=args.levels, chars=args.chars)
            stem = os.path.splitext(image_filename)[0]

            if output_file is not None:
                # The label is whatever follows the last underscore in the name.
                number = stem.split('_')[-1]
                if args.number_placement == 'before':
                    output_file.write(f'{number}\n{ascii_art}\n')
                else:
                    output_file.write(f'{ascii_art}\n{number}\n')
            else:
                output_path = os.path.join(args.output_dir, stem + '.txt')
                with open(output_path, 'w', encoding='utf-8') as f:
                    f.write(ascii_art)
                print(f'ASCII art saved to {output_path}')
    finally:
        # Close the shared output file even if a conversion fails midway.
        if output_file is not None:
            output_file.close()


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions data/mnist/prepare.py
1 change: 1 addition & 0 deletions data/mnist/utils/meta_util.py
1 change: 1 addition & 0 deletions data/mnist/utils/txt_to_phonemes.sh
28 changes: 28 additions & 0 deletions explorations/mnist.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
[
{
"max_iters": ["30000"],
"eval_iters": ["200"],
"eval_interval": ["250"],
"log_interval": ["10"],
"n_layer": ["8"],
"n_kv_group": ["4"],
"n_head": ["8"],
"n_embd": ["384"],
"block_size": ["256"],
"shared_mlp_size" : ["1"],
"shared_mlp_sym" : [true],
"shared_attn_size" : ["2"],
"shared_attn_sym" : [true],
"device": ["cuda"],
"dataset": ["mnist"],
"compile": [true],
"use_post_ln": [false],
"softmax_variant_attn": ["softmax"],
"use_abs_pos_embeddings": [false],
"use_rotary_embeddings": [false],
"use_fire_embeddings": [true],
"shared_fire_embeddings": [true],
"tensorboard_run_name": ["mnist_ascii"]
}
]

0 comments on commit b42b874

Please sign in to comment.