-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Prediction pipeline & app module added
- Loading branch information
1 parent
4c483c0
commit 7b08b41
Showing
6 changed files
with
471 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from flask import Flask, render_template, jsonify, request | ||
from flask_cors import CORS, cross_origin | ||
from wasteDetection.pipeline.training_pipeline import TrainingPipeline | ||
from wasteDetection.utils import read_yaml, create_directories, decodeImage | ||
from wasteDetection.constants import * | ||
from wasteDetection.pipeline.prediction_pipeline import PredictionPipeline | ||
import os | ||
|
||
|
||
app = Flask(__name__) | ||
CORS(app) | ||
|
||
@app.route('/') | ||
def home(): | ||
return render_template('index.html') | ||
|
||
|
||
@app.route('/train') | ||
def trainRoute(): | ||
training_pipeline = TrainingPipeline() | ||
training_pipeline.train() | ||
return 'Model trained successfully' | ||
|
||
|
||
@app.route('/predict', methods=['POST']) | ||
@cross_origin() | ||
def predictRoute(): | ||
try: | ||
config = read_yaml(CONFIG_FILE_PATH) | ||
prediction_config = config.prediction | ||
create_directories([prediction_config.root_dir]) | ||
|
||
file_name = os.path.join(prediction_config.root_dir, 'inputImage.jpg') | ||
waste_detection = PredictionPipeline() | ||
|
||
image = request.json['image'] | ||
decodeImage(image, file_name) | ||
result = waste_detection.predict(file_name) | ||
return jsonify(result) | ||
|
||
except Exception as e: | ||
raise e | ||
|
||
@app.route('/live', methods=['GET']) | ||
@cross_origin() | ||
def predictLive(): | ||
try: | ||
waste_detection = PredictionPipeline() | ||
waste_detection.livePredict() | ||
|
||
except Exception as e: | ||
raise e | ||
|
||
if __name__ == '__main__': | ||
app.run(host='0.0.0.0', port=8080) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import os | ||
import shutil | ||
from pathlib import Path | ||
from wasteDetection.constants import * | ||
from wasteDetection.utils import read_yaml, encodeImageIntoBase64 | ||
|
||
|
||
class PredictionPipeline: | ||
|
||
def __init__(self): | ||
self.config = read_yaml(CONFIG_FILE_PATH) | ||
self.params = read_yaml(PARAMS_FILE_PATH) | ||
self.repo_name = self.config.model_training.yolo_model_gitgub_url.split('/')[-1].split('.')[0] | ||
|
||
|
||
def __get_required_files(self): | ||
self.best_weight_path = self.config.model_training.best_weight_file_path | ||
self.required_files = [self.best_weight_path, self.imgpath] | ||
|
||
for require_file_path in self.required_files: | ||
destination_file_path = os.path.join(self.repo_name, os.path.basename(require_file_path)) | ||
shutil.copy2(require_file_path, destination_file_path) | ||
|
||
def __remove_required_files(self): | ||
os.system(f"rm -rf {self.repo_name}/runs") | ||
for require_file_path in self.required_files: | ||
file_path = os.path.join(self.repo_name, os.path.basename(require_file_path)) | ||
if os.path.exists(file_path): | ||
os.remove(file_path) | ||
|
||
def predict(self, imgpath): | ||
|
||
self.imgpath = imgpath | ||
self.__get_required_files() | ||
|
||
os.system(f"cd {self.repo_name}/ && python detect.py --weights {os.path.basename(self.best_weight_path)} --img {self.params.IMAGE_SIZE} --conf {self.params.MIN_CONFIDENCE_SCORE} --source {os.path.basename(self.imgpath)}") | ||
|
||
detected_image_path = Path(f"{self.repo_name}/runs/detect/exp/{os.path.basename(self.imgpath)}") | ||
opencodedbase64 = encodeImageIntoBase64(detected_image_path) | ||
result = {"image": opencodedbase64.decode('utf-8')} | ||
predicted_img_path = os.path.join(self.config.prediction.root_dir, 'predicted.jpg') | ||
shutil.copy2(detected_image_path, predicted_img_path) | ||
|
||
self.__remove_required_files() | ||
|
||
return result | ||
|
||
def livePredict(self): | ||
self.best_weight_path = self.config.model_training.best_weight_file_path | ||
destination_weight_path = os.path.join(self.repo_name, os.path.basename(self.best_weight_path)) | ||
shutil.copy2(self.best_weight_path, destination_weight_path) | ||
|
||
os.system(f"cd {self.repo_name}/ && python detect.py --weights {os.path.basename(self.best_weight_path)} --img {self.params.IMAGE_SIZE} --conf {self.params.MIN_CONFIDENCE_SCORE} --source 0") | ||
|
||
os.system(f"rm -rf {self.repo_name}/runs") | ||
if os.path.exists(destination_weight_path): | ||
os.remove(destination_weight_path) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.