-
Notifications
You must be signed in to change notification settings - Fork 213
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #205 from Peart-Guy/main
Add new feature of PyTorch based Chatbot
- Loading branch information
Showing
16 changed files
with
749 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
## **PyTorch Based AI Chatbot** | ||
|
||
|
||
|
||
|
||
### 🎯 **Goal** | ||
|
||
Python Project - This ChatBot allows programmers to attach it to any website where normal users can use it. | ||
|
||
Modules Used: | ||
1. Pytorch | ||
2. NLTK | ||
3. NumPy | ||
4. Random | ||
5. JSON | ||
6. Flask | ||
|
||
|
||
# Chatbot Deployment with Flask and JavaScript | ||
How to deploy: | ||
- Deploy within Flask app with jinja2 template | ||
|
||
## Initial Setup: | ||
``` | ||
$ cd chatbot | ||
$ py -m venv venv | ||
$ . venv/scripts/activate | ||
``` | ||
Install dependencies | ||
``` | ||
$ (venv) pip install Flask torch torchvision nltk | ||
``` | ||
Install nltk package | ||
``` | ||
$ (venv) python | ||
>>> import nltk | ||
>>> nltk.download('punkt') | ||
``` | ||
Modify `intents.json` with different intents and responses for your Chatbot as per your need | ||
|
||
Run | ||
``` | ||
$ (venv) python train.py | ||
``` | ||
This will dump data.pth file. And then run | ||
the following command to test it in the console. | ||
|
||
``` | ||
Run the chatbot with frontend in cmd or editor | ||
``` | ||
``` | ||
$python app.py | ||
``` | ||
|
||
### ✒️ **Your Signature** | ||
|
||
`Ankan Mukhopadhyay` | ||
[GitHub Profile](https://github.com/Peart-Guy) | [LinkedIn](https://www.linkedin.com/in/ankan-mukhopadhyay-06baa4315/) | ||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from flask import Flask , render_template , request , jsonify | ||
|
||
from chat import get_response | ||
|
||
app = Flask(__name__) | ||
|
||
@app.get("/") | ||
def index_get(): | ||
return render_template("base.html") | ||
|
||
@app.post("/predict") | ||
def predict(): | ||
text = request.get_json().get("message") | ||
#checking if message is valid | ||
response = get_response(text) | ||
message = {"answer": response} | ||
|
||
return jsonify(message) | ||
|
||
if __name__ == "__main__" : | ||
app.run(debug=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import random | ||
import json | ||
|
||
import torch | ||
|
||
from model import NeuralNet | ||
from nltk_utils import bag_of_words, tokenize | ||
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
with open('intents.json', 'r') as json_data: | ||
intents = json.load(json_data) | ||
|
||
FILE = "data.pth" | ||
data = torch.load(FILE) | ||
|
||
input_size = data["input_size"] | ||
hidden_size = data["hidden_size"] | ||
output_size = data["output_size"] | ||
all_words = data['all_words'] | ||
tags = data['tags'] | ||
model_state = data["model_state"] | ||
|
||
model = NeuralNet(input_size, hidden_size, output_size).to(device) | ||
model.load_state_dict(model_state) | ||
model.eval() | ||
|
||
bot_name = "Ankan" | ||
|
||
def get_response(msg): | ||
sentence = tokenize(msg) | ||
X = bag_of_words(sentence, all_words) | ||
X = X.reshape(1, X.shape[0]) | ||
X = torch.from_numpy(X).to(device) | ||
|
||
output = model(X) | ||
_, predicted = torch.max(output, dim=1) | ||
|
||
tag = tags[predicted.item()] | ||
|
||
probs = torch.softmax(output, dim=1) | ||
prob = probs[0][predicted.item()] | ||
if prob.item() > 0.75: | ||
for intent in intents['intents']: | ||
if tag == intent["tag"]: | ||
return random.choice(intent['responses']) | ||
|
||
return "I do not understand..." | ||
|
||
|
||
if __name__ == "__main__": | ||
print("Let's chat! (type 'quit' to exit)") | ||
while True: | ||
# sentence = "do you use credit cards?" | ||
sentence = input("You: ") | ||
if sentence == "quit": | ||
break | ||
|
||
resp = get_response(sentence) | ||
print(resp) | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
{ | ||
"intents": [ | ||
{ | ||
"tag": "greeting", | ||
"patterns": [ | ||
"Hi", | ||
"Hey", | ||
"How are you", | ||
"Is anyone there?", | ||
"Hello", | ||
"Good day" | ||
], | ||
"responses": [ | ||
"Hey :-)", | ||
"Hello, thanks for visiting", | ||
"Hi there, what can I do for you?", | ||
"Hi there, how can I help?" | ||
] | ||
}, | ||
{ | ||
"tag": "goodbye", | ||
"patterns": ["Bye", "See you later", "Goodbye"], | ||
"responses": [ | ||
"See you later, thanks for visiting", | ||
"Have a nice day", | ||
"Bye! Come back again soon." | ||
] | ||
}, | ||
{ | ||
"tag": "thanks", | ||
"patterns": ["Thanks", "Thank you", "That's helpful", "Thank's a lot!"], | ||
"responses": ["Happy to help!", "Any time!", "My pleasure"] | ||
}, | ||
{ | ||
"tag": "items", | ||
"patterns": [ | ||
"Which services do you provide?", | ||
"What kinds of service are there?", | ||
"What do you do?" | ||
], | ||
"responses": [ | ||
"We provide services." | ||
|
||
] | ||
}, | ||
{ | ||
"tag": "payments", | ||
"patterns": [ | ||
"Do you take credit cards?", | ||
"Do you accept Mastercard?", | ||
"Can I pay with UPI?", | ||
"Are you cash only?" | ||
], | ||
"responses": [ | ||
"We accept VISA, Mastercard and UPI", | ||
"We accept most major credit cards, and UPI" | ||
] | ||
}, | ||
{ | ||
"tag": "Sorry", | ||
"patterns": [ | ||
"You were wrong", | ||
"This is not right", | ||
"I don't Understand" | ||
], | ||
"responses": [ | ||
"Sorry for Inconveninience. I am still in Learning phase", | ||
"Please submit the feedback , I will surely learn from my mistakes." | ||
] | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class NeuralNet(nn.Module): | ||
def __init__(self, input_size, hidden_size, num_classes): | ||
super(NeuralNet, self).__init__() | ||
self.l1 = nn.Linear(input_size, hidden_size) | ||
self.l2 = nn.Linear(hidden_size, hidden_size) | ||
self.l3 = nn.Linear(hidden_size, num_classes) | ||
self.relu = nn.ReLU() | ||
|
||
def forward(self, x): | ||
out = self.l1(x) | ||
out = self.relu(out) | ||
out = self.l2(out) | ||
out = self.relu(out) | ||
out = self.l3(out) | ||
# no activation and no softmax at the end | ||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import numpy as np | ||
import nltk | ||
# nltk.download('punkt') | ||
from nltk.stem.porter import PorterStemmer | ||
stemmer = PorterStemmer() | ||
|
||
|
||
# def tokenize(sentence): | ||
# """ | ||
# split sentence into array of words/tokens | ||
# a token can be a word or punctuation character, or number | ||
# """ | ||
# return nltk.word_tokenize(sentence) | ||
|
||
def tokenize(sentence): | ||
if sentence is None or sentence.strip() == "": | ||
# If the input sentence is None or empty, return an empty list or handle accordingly | ||
return [] | ||
return nltk.word_tokenize(sentence) | ||
|
||
|
||
def stem(word): | ||
""" | ||
stemming = find the root form of the word | ||
examples: | ||
words = ["organize", "organizes", "organizing"] | ||
words = [stem(w) for w in words] | ||
-> ["organ", "organ", "organ"] | ||
""" | ||
return stemmer.stem(word.lower()) | ||
|
||
|
||
def bag_of_words(tokenized_sentence, words): | ||
""" | ||
return bag of words array: | ||
1 for each known word that exists in the sentence, 0 otherwise | ||
example: | ||
sentence = ["hello", "how", "are", "you"] | ||
words = ["hi", "hello", "I", "you", "bye", "thank", "cool"] | ||
bog = [ 0 , 1 , 0 , 1 , 0 , 0 , 0] | ||
""" | ||
# stem each word | ||
sentence_words = [stem(word) for word in tokenized_sentence] | ||
# initialize bag with 0 for each word | ||
bag = np.zeros(len(words), dtype=np.float32) | ||
for idx, w in enumerate(words): | ||
if w in sentence_words: | ||
bag[idx] = 1 | ||
|
||
return bag |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
class Chatbox { | ||
constructor() { | ||
this.args = { | ||
openButton: document.querySelector('.chatbox__button'), | ||
chatBox: document.querySelector('.chatbox__support'), | ||
sendButton: document.querySelector('.send__button') | ||
} | ||
|
||
this.state = false; | ||
this.messages = []; | ||
} | ||
|
||
display() { | ||
const {openButton, chatBox, sendButton} = this.args; | ||
|
||
openButton.addEventListener('click', () => this.toggleState(chatBox)) | ||
|
||
sendButton.addEventListener('click', () => this.onSendButton(chatBox)) | ||
|
||
const node = chatBox.querySelector('input'); | ||
node.addEventListener("keyup", ({key}) => { | ||
if (key === "Enter") { | ||
this.onSendButton(chatBox) | ||
} | ||
}) | ||
} | ||
|
||
toggleState(chatbox) { | ||
this.state = !this.state; | ||
|
||
// show or hides the box | ||
if(this.state) { | ||
chatbox.classList.add('chatbox--active') | ||
} else { | ||
chatbox.classList.remove('chatbox--active') | ||
} | ||
} | ||
|
||
onSendButton(chatbox) { | ||
var textField = chatbox.querySelector('input'); | ||
let text1 = textField.value | ||
if (text1 === "") { | ||
return; | ||
} | ||
|
||
let msg1 = { name: "User", message: text1 } | ||
this.messages.push(msg1); | ||
|
||
fetch('http://127.0.0.1:5000/predict', { | ||
method: 'POST', | ||
body: JSON.stringify({ message: text1 }), | ||
mode: 'cors', | ||
headers: { | ||
'Content-Type': 'application/json' | ||
}, | ||
}) | ||
.then(r => r.json()) | ||
.then(r => { | ||
let msg2 = { name: "Ankan", message: r.answer }; | ||
this.messages.push(msg2); | ||
this.updateChatText(chatbox) | ||
textField.value = '' | ||
|
||
}).catch((error) => { | ||
console.error('Error:', error); | ||
this.updateChatText(chatbox) | ||
textField.value = '' | ||
}); | ||
} | ||
|
||
updateChatText(chatbox) { | ||
var html = ''; | ||
this.messages.slice().reverse().forEach(function(item, index) { | ||
if (item.name === "Ankan") | ||
{ | ||
html += '<div class="messages__item messages__item--visitor">' + item.message + '</div>' | ||
} | ||
else | ||
{ | ||
html += '<div class="messages__item messages__item--operator">' + item.message + '</div>' | ||
} | ||
}); | ||
|
||
const chatmessage = chatbox.querySelector('.chatbox__messages'); | ||
chatmessage.innerHTML = html; | ||
} | ||
} | ||
|
||
|
||
const chatbox = new Chatbox(); | ||
chatbox.display(); |
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.