-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
49 lines (39 loc) · 1.51 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import streamlit as st
import torch
import os
from GPTLanguageModelClass import *
# Unpack the shared hyperparameter bundle (exported by the wildcard import
# from GPTLanguageModelClass) into module-level aliases used below.
_hp = hyperparams
block_size, batch_size = _hp.block_size, _hp.batch_size
max_iters, learning_rate = _hp.max_iters, _hp.learning_rate
eval_every, dropout = _hp.eval_every, _hp.dropout
n_embd, n_head, n_layer = _hp.n_embd, _hp.n_head, _hp.n_layer
device = _hp.device
# Page header.
st.title('LLM from scratch Demo')
st.write(f"Using device: {device}")

# The character vocabulary file is produced by extract.py; refuse to run without it.
if not os.path.exists("./vocab.txt"):
    raise Exception("Please run extract.py first")

with open("./vocab.txt", 'r', encoding='utf-8') as f:
    corpus = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(corpus))
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = dict(enumerate(chars))

def encode(s):
    """Map a string to its list of vocabulary ids (KeyError on unknown chars)."""
    return [string_to_int[ch] for ch in s]

def decode(ids):
    """Map a list of vocabulary ids back to a string."""
    return ''.join(int_to_string[i] for i in ids)
model_pickle_path = './model.pt'

# Mirror the vocab.txt check above: fail with an actionable message instead of
# a raw FileNotFoundError when training has not produced a checkpoint yet.
if not os.path.exists(model_pickle_path):
    raise Exception("Model checkpoint './model.pt' not found; please train the model first")

st.write('loading model parameters...')
with open(model_pickle_path, 'rb') as f:
    # NOTE(security): torch.load unpickles arbitrary objects -- only load
    # checkpoints you trust. weights_only=False is required to restore the
    # full pickled model object (torch >= 2.6 defaults to weights_only=True,
    # which would reject this checkpoint format).
    model = torch.load(f, map_location=device, weights_only=False)
st.write('model loaded successfully!')

# Prompt length is capped so at least one token of context window remains free.
prompt = st.text_area('Prompt:', value='', height=100, max_chars=block_size - 1, key='prompt')
if prompt:
    # encode() raises KeyError on characters missing from vocab.txt; drop them
    # with a visible warning instead of crashing the app.
    unknown = sorted({ch for ch in prompt if ch not in string_to_int})
    if unknown:
        st.warning(f"Ignoring characters not in the model vocabulary: {''.join(unknown)}")
        prompt = ''.join(ch for ch in prompt if ch in string_to_int)
    if prompt:
        context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
        # Fill the remainder of the context window; guaranteed >= 1 by max_chars,
        # max() kept as a defensive floor.
        max_new_tokens = max(block_size - len(prompt), 1)
        generated_chars = decode(model.generate(context.unsqueeze(0), max_new_tokens=max_new_tokens)[0].tolist())
        st.write('Generated text:')
        st.write(generated_chars)