-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add extraction and saving of similar images from zip archive
Implemented logic to extract similar images from a zip archive based on their similarity to the uploaded image. The extracted images are saved in the 'uploads' directory, making them accessible through a web URL. The process involves: - Loading the uploaded image and extracting its features. - Comparing these features with images in the zip archive using a pre-trained model. - Sorting the images by similarity and saving the top 5 matches. - Serving the saved images to the client.
- Loading branch information
Showing
1 changed file
with
85 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,95 @@ | ||
import webbrowser | ||
import threading | ||
import hashlib | ||
import os | ||
|
||
from fastapi import FastAPI, File, UploadFile | ||
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse | ||
import zipfile | ||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow.keras.models import load_model | ||
from fastapi import FastAPI, UploadFile, File, Request | ||
from fastapi.responses import HTMLResponse | ||
from fastapi.templating import Jinja2Templates | ||
from fastapi.staticfiles import StaticFiles | ||
from image_utils import find_similar_images, s3_client, S3_BUCKET_NAME | ||
from io import BytesIO | ||
from PIL import Image, UnidentifiedImageError | ||
|
||
# Application object; every route and static mount below is registered on it.
app = FastAPI()

# Keras model loaded once at import time and shared by all requests.
model = load_model('resnet50_local.h5')
templates = Jinja2Templates(directory="templates")

# Archive that is searched for similar images, and the folder matches are
# written to.
ZIP_PATH = 'photos.zip'
UPLOAD_FOLDER = 'uploads'

# Create the folder BEFORE mounting it: StaticFiles verifies that its
# directory exists when it is constructed, so mounting first fails on a
# fresh checkout where 'uploads' has not been created yet.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Each path is mounted exactly once ("/static" appeared twice before).
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/uploads", StaticFiles(directory="uploads"), name="uploads")
|
||
def preprocess_image(image: Image.Image):
    """Turn a PIL image into a batched array for the ResNet50 preprocessing step.

    The image is converted to 3-channel RGB, resized to 224x224, given a
    leading batch axis and passed through
    ``tf.keras.applications.resnet50.preprocess_input``.

    Args:
        image: Source image in any PIL mode.

    Returns:
        The preprocessed array of shape ``(1, 224, 224, 3)``.
    """
    # convert('RGB') covers every mode, not only RGBA: grayscale and palette
    # images would otherwise produce a 2-D array, and CMYK would be
    # mis-sliced as if its fourth channel were alpha.
    rgb_image = image.convert('RGB').resize((224, 224))
    image_array = np.expand_dims(np.array(rgb_image), axis=0)
    return tf.keras.applications.resnet50.preprocess_input(image_array)
|
||
def get_image_features(image: Image.Image):
    """Run the module-level model on one image and return its output.

    Args:
        image: PIL image; it is passed through ``preprocess_image`` first.

    Returns:
        Whatever ``model.predict`` returns for the single-image batch.
    """
    preprocessed_image = preprocess_image(image)
    # verbose=0: this is called once per archive entry, and the default
    # per-call progress bar would flood the server log.
    return model.predict(preprocessed_image, verbose=0)
|
||
def compare_images(image1_features, image2_features):
    """Return the Euclidean distance between two feature arrays.

    A smaller value means the inputs are closer; identical inputs give 0.

    Args:
        image1_features: Array-like of numbers.
        image2_features: Array-like broadcastable against the first.

    Returns:
        The norm of the element-wise difference, as computed by
        ``np.linalg.norm`` with its default order.
    """
    # Cast to float before subtracting so that plain lists are accepted and
    # unsigned-integer inputs cannot wrap around (uint8 0 - 1 would be 255).
    first = np.asarray(image1_features, dtype=np.float64)
    second = np.asarray(image2_features, dtype=np.float64)
    return np.linalg.norm(first - second)
|
||
def get_images_from_zip(zip_path=None):
    """List the image entries of a zip archive.

    Args:
        zip_path: Archive to inspect. Defaults to the module-level
            ``ZIP_PATH`` when omitted, which keeps existing zero-argument
            callers working.

    Returns:
        Entry names, in archive order, whose lower-cased name ends with
        ``.jpg``, ``.jpeg`` or ``.png``.
    """
    if zip_path is None:
        zip_path = ZIP_PATH
    with zipfile.ZipFile(zip_path, 'r') as archive:
        return [
            name
            for name in archive.namelist()
            if name.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
|
||
def extract_and_save_image(archive, image_key, upload_folder=None):
    """Copy one archive entry into the upload folder and return its file name.

    Args:
        archive: An open ``zipfile.ZipFile``.
        image_key: Name of the entry inside ``archive``.
        upload_folder: Destination directory. Defaults to the module-level
            ``UPLOAD_FOLDER`` when omitted.

    Returns:
        The base name the entry was saved under (directory components of
        ``image_key`` are dropped, so the file always lands directly inside
        the destination folder).
    """
    # Local import so this block does not depend on a new file-level import.
    import shutil

    if upload_folder is None:
        upload_folder = UPLOAD_FOLDER
    os.makedirs(upload_folder, exist_ok=True)

    safe_image_name = os.path.basename(image_key)
    image_path = os.path.join(upload_folder, safe_image_name)
    # Stream the stored bytes as-is instead of decoding and re-encoding them:
    # re-saving a JPEG loses quality and can fail for some mode/format pairs.
    with archive.open(image_key) as image_file, open(image_path, 'wb') as target:
        shutil.copyfileobj(image_file, target)
    return safe_image_name
|
||
@app.get("/", response_class=HTMLResponse)
async def upload_form():
    """Serve the contents of ``templates/upload_form.html`` as the index page."""
    # Explicit encoding: without it the file is decoded with the platform's
    # locale default, which differs between machines.
    with open('templates/upload_form.html', 'r', encoding='utf-8') as file:
        html_content = file.read()
    return HTMLResponse(content=html_content)
|
||
|
||
@app.post("/upload/", response_class=JSONResponse) | ||
async def create_upload_file(file: UploadFile = File(...)): | ||
contents = await file.read() | ||
image_hash = hashlib.sha256(contents).hexdigest() | ||
file_path = f'uploads/{image_hash}.jpg' | ||
os.makedirs(os.path.dirname(file_path), exist_ok=True) | ||
with open(file_path, 'wb') as f: | ||
f.write(contents) | ||
|
||
s3_client.upload_file(file_path, S3_BUCKET_NAME, f'{image_hash}.jpg') | ||
|
||
similar_images = await find_similar_images(file_path) | ||
|
||
response_data = { | ||
"message": "File uploaded successfully", | ||
"filename": image_hash, | ||
"similar_images": similar_images | ||
# NOTE(review): no route decorator sits directly above this function in the
# visible source - confirm how it is registered on the app.
async def read_root(request: Request):
    """Render the ``index.html`` template with the request in its context."""
    return templates.TemplateResponse("index.html", {"request": request})
|
||
@app.post("/find_similar/")
async def find_similar_images(file: UploadFile = File(...)):
    """Find the archive images closest to an uploaded image.

    The upload's features are compared against every image entry of
    ``ZIP_PATH``; the five entries with the smallest distance are copied into
    the upload folder via ``extract_and_save_image``.

    Args:
        file: The uploaded image.

    Returns:
        A dict with the upload's original ``filename`` and ``similar_images``,
        the saved file names ordered from most to least similar.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # Local import: HTTPException is not among the names the file imports
    # from fastapi at module level.
    from fastapi import HTTPException

    try:
        # Normalise to RGB exactly like the archive images below, so both
        # sides of the comparison go through the model in the same form.
        uploaded_image = Image.open(BytesIO(await file.read())).convert('RGB')
    except OSError as exc:  # includes UnidentifiedImageError
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image",
        ) from exc

    # NOTE(review): feature extraction is synchronous work inside an async
    # handler - consider moving it off the event loop.
    uploaded_image_features = get_image_features(uploaded_image)
    images = get_images_from_zip()
    similarities = []

    # One archive handle serves both the scan and the extraction.
    with zipfile.ZipFile(ZIP_PATH, 'r') as archive:
        for image_key in images:
            try:
                with archive.open(image_key) as image_file:
                    # convert() reads the pixel data, so the image stays
                    # usable after the entry is closed.
                    image = Image.open(image_file).convert('RGB')
                image_features = get_image_features(image)
                similarity = compare_images(uploaded_image_features, image_features)
                similarities.append((image_key, similarity))
            except UnidentifiedImageError:
                print(f"Cannot identify image file: {image_key}")
            except Exception as e:
                # Deliberately best-effort: one bad entry must not abort the
                # whole search.
                print(f"Error processing image {image_key}: {e}")

        # Smallest distance first.
        similarities.sort(key=lambda x: x[1])
        similar_images = [
            extract_and_save_image(archive, image_key)
            for image_key, _ in similarities[:5]
        ]

    return {
        'filename': file.filename,
        'similar_images': similar_images
    }
|
||
return JSONResponse(content=response_data) | ||
|
||
|
||
@app.get("/uploads/{image_hash}.jpg", response_class=FileResponse)
async def serve_uploaded_image(image_hash: str):
    """Return the stored ``uploads/<image_hash>.jpg`` file.

    Args:
        image_hash: File stem taken from the request path.

    Raises:
        HTTPException: 404 when no such file exists.
    """
    # Local import: HTTPException is not among the names the file imports
    # from fastapi at module level.
    from fastapi import HTTPException

    file_path = f'uploads/{image_hash}.jpg'
    # Answer a missing file with a proper 404 instead of letting the
    # response fail while it is being sent.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(file_path)
|
||
|
||
def start_server():
    """Start uvicorn for ``main:app`` on 0.0.0.0:8000 with auto-reload enabled."""
    # Imported here so uvicorn is only needed when the server is launched.
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||
|
||
if __name__ == "__main__":
    # Open the browser a moment after launch so the server has time to bind.
    threading.Timer(1.25, lambda: webbrowser.open("http://127.0.0.1:8000")).start()
    start_server()
    # NOTE(review): two server launches appear back to back here; presumably
    # the second only runs once the first has returned - confirm which one
    # is meant to stay.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)