Skip to content

Commit

Permalink
Prediction pipeline & app module added
Browse files Browse the repository at this point in the history
  • Loading branch information
utpalpaul108 committed Nov 23, 2023
1 parent 4c483c0 commit 7b08b41
Show file tree
Hide file tree
Showing 6 changed files with 471 additions and 3 deletions.
55 changes: 55 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from flask import Flask, render_template, jsonify, request
from flask_cors import CORS, cross_origin
from wasteDetection.pipeline.training_pipeline import TrainingPipeline
from wasteDetection.utils import read_yaml, create_directories, decodeImage
from wasteDetection.constants import *
from wasteDetection.pipeline.prediction_pipeline import PredictionPipeline
import os


app = Flask(__name__)
CORS(app)

@app.route('/')
def home():
return render_template('index.html')


@app.route('/train')
def trainRoute():
training_pipeline = TrainingPipeline()
training_pipeline.train()
return 'Model trained successfully'


@app.route('/predict', methods=['POST'])
@cross_origin()
def predictRoute():
try:
config = read_yaml(CONFIG_FILE_PATH)
prediction_config = config.prediction
create_directories([prediction_config.root_dir])

file_name = os.path.join(prediction_config.root_dir, 'inputImage.jpg')
waste_detection = PredictionPipeline()

image = request.json['image']
decodeImage(image, file_name)
result = waste_detection.predict(file_name)
return jsonify(result)

except Exception as e:
raise e

@app.route('/live', methods=['GET'])
@cross_origin()
def predictLive():
try:
waste_detection = PredictionPipeline()
waste_detection.livePredict()

except Exception as e:
raise e

if __name__ == '__main__':
app.run(host='0.0.0.0', port=8080)
8 changes: 7 additions & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,10 @@ data_validation:
# Model Training
model_training:
root_dir: artifacts/model_training
yolo_model_gitgub_url: https://github.com/ultralytics/yolov5.git
yolo_model_gitgub_url: https://github.com/ultralytics/yolov5.git
best_weight_file_path: artifacts/model_training/best.pt


# Prediction
prediction:
root_dir: artifacts/prediction
3 changes: 2 additions & 1 deletion params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ REQUIRED_FILES: ["train", "valid", "data.yaml"] # For yoloV5
EPOCHS: 1
BATCH_SIZE: 16
PRETRAINED_MODEL_NAME: "yolov5s"
IMAGE_SIZE: 416
IMAGE_SIZE: 416
MIN_CONFIDENCE_SCORE: 0.5
58 changes: 58 additions & 0 deletions src/wasteDetection/pipeline/prediction_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import shutil
from pathlib import Path
from wasteDetection.constants import *
from wasteDetection.utils import read_yaml, encodeImageIntoBase64


class PredictionPipeline:

def __init__(self):
self.config = read_yaml(CONFIG_FILE_PATH)
self.params = read_yaml(PARAMS_FILE_PATH)
self.repo_name = self.config.model_training.yolo_model_gitgub_url.split('/')[-1].split('.')[0]


def __get_required_files(self):
self.best_weight_path = self.config.model_training.best_weight_file_path
self.required_files = [self.best_weight_path, self.imgpath]

for require_file_path in self.required_files:
destination_file_path = os.path.join(self.repo_name, os.path.basename(require_file_path))
shutil.copy2(require_file_path, destination_file_path)

def __remove_required_files(self):
os.system(f"rm -rf {self.repo_name}/runs")
for require_file_path in self.required_files:
file_path = os.path.join(self.repo_name, os.path.basename(require_file_path))
if os.path.exists(file_path):
os.remove(file_path)

def predict(self, imgpath):

self.imgpath = imgpath
self.__get_required_files()

os.system(f"cd {self.repo_name}/ && python detect.py --weights {os.path.basename(self.best_weight_path)} --img {self.params.IMAGE_SIZE} --conf {self.params.MIN_CONFIDENCE_SCORE} --source {os.path.basename(self.imgpath)}")

detected_image_path = Path(f"{self.repo_name}/runs/detect/exp/{os.path.basename(self.imgpath)}")
opencodedbase64 = encodeImageIntoBase64(detected_image_path)
result = {"image": opencodedbase64.decode('utf-8')}
predicted_img_path = os.path.join(self.config.prediction.root_dir, 'predicted.jpg')
shutil.copy2(detected_image_path, predicted_img_path)

self.__remove_required_files()

return result

def livePredict(self):
self.best_weight_path = self.config.model_training.best_weight_file_path
destination_weight_path = os.path.join(self.repo_name, os.path.basename(self.best_weight_path))
shutil.copy2(self.best_weight_path, destination_weight_path)

os.system(f"cd {self.repo_name}/ && python detect.py --weights {os.path.basename(self.best_weight_path)} --img {self.params.IMAGE_SIZE} --conf {self.params.MIN_CONFIDENCE_SCORE} --source 0")

os.system(f"rm -rf {self.repo_name}/runs")
if os.path.exists(destination_weight_path):
os.remove(destination_weight_path)

2 changes: 1 addition & 1 deletion src/wasteDetection/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def decodeImage(img_str: str, file_name: str):


@ensure_annotations
def encodeImageIntoBase64(img_path: Path)->str:
def encodeImageIntoBase64(img_path: Path):

"""return image as base64 encoded string
Expand Down
Loading

0 comments on commit 7b08b41

Please sign in to comment.