-
Notifications
You must be signed in to change notification settings - Fork 0
/
sample.py
128 lines (104 loc) · 4.37 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Sample from a trained model
"""
import modal
from modal import Image
# Modal application object that the functions in this file are registered on.
app = modal.App()

# Persistent Modal volume holding the training checkpoint; it is mounted into
# the sampling container at MODEL_DIR.
volume = modal.Volume.from_name("pretraining-gpt2-tinystories")
MODEL_DIR = "/vol"

# Container image: the `thr3a/cuda12.1-torch` registry image with the JAX
# sampling stack installed on top. `gpu="A100"` applies to this pip_install
# build step only; the sampling function requests its own GPU separately.
image = Image.from_registry("thr3a/cuda12.1-torch").pip_install(
    "jax[cuda12]",
    "jaxlib",
    "cloudpickle",
    "tqdm",
    "equinox",
    "python-dotenv",
    "optax",
    "numpy",
    "tiktoken",
    gpu="A100",
)
@app.function(
    gpu="t4",
    timeout=86400,  # allow a one-day timeout period
    image=image,
    mounts=[
        modal.Mount.from_local_dir(
            "./data",
            remote_path="/root/data"),
        modal.Mount.from_local_python_packages("helpers", "executables"),
    ],
    volumes={
        MODEL_DIR: volume
    },
)
def sample():
    """Print text samples generated by the checkpoint stored on the Modal volume.

    Steps:

    1. Rebuild the ``GPT`` model from the hyperparameters pickled in
       ``<out_dir>/params.pkl`` and load its weights from
       ``<out_dir>/model.eqx``.
    2. Pick a tokenizer. Use the dataset's character-level ``meta.pkl``
       (``stoi``/``itos``) when the checkpoint names a dataset and that file
       exists after running the dataset's ``prepare.py``. Otherwise fall back
       to tiktoken's GPT-2 encoding.
    3. Encode ``start``, or the contents of the file given as
       ``"FILE:<path>"``. Print ``num_samples`` completions, each followed by a
       separator line.

    Raises:
        NotImplementedError: if ``init_from`` is anything other than
            ``'resume'``. Only resuming from ``out_dir`` is implemented.
    """
    import os
    import subprocess
    import sys
    from functools import partial

    import cloudpickle
    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import tiktoken
    from tqdm import tqdm

    from executables.model import GPT

    # -----------------------------------------------------------------------------
    init_from = 'resume'  # only 'resume' (load the checkpoint from out_dir) is implemented
    out_dir = MODEL_DIR  # checkpoint directory: the Modal volume mount point
    start = """Once upon a time, there were three researchers; Sachith, Chandeepa and Yasiru working for a company named Surge Global."""  # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
    num_samples = 1  # number of samples to draw
    max_new_tokens = 500  # number of tokens generated in each sample
    temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
    top_k = 200  # retain only the top_k most likely tokens, clamp others to have 0 probability
    seed = 1337
    # -----------------------------------------------------------------------------
    if init_from != 'resume':
        raise NotImplementedError(
            f"init_from={init_from!r} is not supported; only 'resume' is implemented"
        )

    key = jax.random.PRNGKey(seed)
    # `jax.default_matmul_precision` is a context manager, so assigning to it
    # changes nothing. The process-wide default is set through the config flag.
    # NOTE(review): TF32 presumably has no effect on a T4 (pre-Ampere); confirm.
    jax.config.update("jax_default_matmul_precision", "tensorfloat32")

    # Build a model skeleton from the saved hyperparameters, then replace its
    # key-initialised leaves with the checkpointed weights.
    checkpoint_params_file = os.path.join(out_dir, "params.pkl")
    checkpoint_file = os.path.join(out_dir, "model.eqx")
    with open(checkpoint_params_file, 'rb') as f:
        checkpoint_params = cloudpickle.load(f)
    gptconf = checkpoint_params['model_args']
    model = GPT(gptconf, key=key)
    model = eqx.tree_deserialise_leaves(checkpoint_file, model)
    model = eqx.nn.inference_mode(model)

    # Look for the meta pickle in case it is available in the dataset folder.
    # Older checkpoints might not record a config/dataset at all.
    load_meta = False
    config = checkpoint_params.get('config', {})
    if 'dataset' in config:
        dataset_dir = os.path.join('data', config['dataset'])
        # Best effort: the return code is deliberately not checked. The script
        # presumably (re)creates meta.pkl. The relative path assumes the working
        # directory is /root, where ./data is mounted.
        subprocess.run([sys.executable, "prepare.py"], cwd=dataset_dir)
        meta_path = os.path.join(dataset_dir, 'meta.pkl')
        load_meta = os.path.exists(meta_path)

    # Both decoders below take a plain list of Python ints.
    if load_meta:
        print(f"Loading meta from {meta_path}...")
        with open(meta_path, 'rb') as f:
            meta = cloudpickle.load(f)
        # TODO want to make this more general to arbitrary encoder/decoder schemes
        stoi, itos = meta['stoi'], meta['itos']
        encode = lambda s: [stoi[c] for c in s]
        decode = lambda ids: ''.join(itos[i] for i in ids)
    else:
        # ok let's assume gpt-2 encodings by default
        print("No meta.pkl found, assuming GPT-2 encodings...")
        enc = tiktoken.get_encoding("gpt2")
        encode = partial(enc.encode, allowed_special={"<|endoftext|>"})
        decode = enc.decode

    # Support the documented "FILE:<path>" form of the prompt.
    if start.startswith('FILE:'):
        with open(start[len('FILE:'):], 'r', encoding='utf-8') as f:
            start = f.read()

    # Encode the beginning of the prompt and add a leading batch axis: (1, len).
    start_ids = encode(start)
    x = jnp.array(start_ids, dtype=jnp.int32)[None]

    def generate(model: GPT, tokens: jax.Array, key: jax.Array):
        """Sample a continuation for every prompt row of ``tokens`` and decode it.

        ``tokens`` is a (batch, prompt_len) int array. Returns a list holding
        one decoded string per row.
        """
        generate_fn = partial(
            model.generate,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            temperature=temperature,
        )
        if tokens.shape[0] == 1:
            sequences = [generate_fn(tokens[0], key=key)]
        else:
            # One independent key per row; vmap maps keyword arguments over axis 0.
            keys = jax.random.split(key, tokens.shape[0])
            sequences = jax.vmap(generate_fn)(tokens, key=keys)
        # Decode row by row from plain ints. This works for the single and the
        # batched case, and for both decoders.
        return [decode(jnp.asarray(seq).tolist()) for seq in sequences]

    for k in tqdm(range(num_samples), desc="samples"):
        sampling_key = jax.random.fold_in(key, k)
        for text in generate(model, x, sampling_key):
            print(text)
            print('---------------')
@app.local_entrypoint()
def main():
    """Local CLI entry point: run ``sample`` remotely on Modal and wait for it."""
    return sample.remote()