-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
60 lines (39 loc) · 1.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from utils import use_model, load_latest_checkpoint
from transformer import transformer
from nlp_utils import TextTokenizing
import argparse
from hyperparameters import NUM_LAYERS, D_MODEL, NUM_HEADS, DFF, DROPOUT
def main(args):
    """Run an interactive console chat session with the transformer chatbot.

    Args:
        args: Parsed command-line arguments (an ``argparse.Namespace``) with
            ``tokenizer`` -- path handed to ``TextTokenizing.load_tokenizer`` --
            and ``checkpoint`` -- directory handed to ``load_latest_checkpoint``.

    Repeatedly reads a line from stdin, prints the model's answer prefixed
    with ``AI:``, and returns once stdin is closed (EOF / Ctrl-D) or the user
    presses Ctrl-C at the prompt.
    """
    # The checkpoint location comes from --checkpoint. The previous code read
    # args.tokenizer here, so --checkpoint was silently ignored.
    checkpoint_dir = args.checkpoint
    latest_checkpoint = load_latest_checkpoint(checkpoint_dir)
    # NOTE(review): latest_checkpoint is never applied to `model` below, so the
    # chatbot may run with freshly initialized weights unless the helpers
    # restore them internally -- confirm, and wire it in (e.g. via the model's
    # weight-loading API) if needed.

    text_tokenizing = TextTokenizing()
    tokenizer = text_tokenizing.load_tokenizer(args.tokenizer)
    # presumably tokens() describes the tokenizer just loaded -- TODO confirm
    vocab_size, start_token, end_token = text_tokenizing.tokens()

    model = transformer(
        vocab_size=vocab_size,
        num_layers=NUM_LAYERS,
        dff=DFF,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
    )
    prediction = use_model(
        model=model,
        tokenizer=tokenizer,
        START_TOKEN=start_token,
        END_TOKEN=end_token,
        MAX_LENGTH=50,  # upper bound on the generated answer length
    )

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly instead of exiting with a traceback.
            print()
            break
        machine_answer = prediction.predict(user_input)
        print(f"AI: {machine_answer}")
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the interactive chatbot.

    Returns:
        A parser accepting ``--tokenizer`` (default ``"tokenizer"``) and
        ``--checkpoint`` (default ``"training"``).
    """
    # This script chats with a trained model; it does not train one.
    parser = argparse.ArgumentParser(
        description="Human Like Chatbot - interactive chat"
    )
    parser.add_argument(
        '--tokenizer',
        default="tokenizer",
        help="Directory of the .subword tokenized words file."
    )
    parser.add_argument(
        '--checkpoint',
        default="training",
        help="Directory where checkpoint file was stored"
    )
    return parser


if __name__ == '__main__':
    # Parse command-line arguments and start the chat session.
    main(build_arg_parser().parse_args())