
Commit

feat: add extraction and saving of similar images from zip archive
Implemented logic to extract images from a zip archive that are similar to an uploaded image. The extracted matches are saved in the 'uploads' directory, making them accessible through a web URL. The process involves:
- Loading the uploaded image and extracting its features with a pre-trained ResNet50 model.
- Extracting features for each image in the zip archive the same way and comparing them to the upload's features by Euclidean distance.
- Sorting the archive images by similarity and saving the top 5 matches to 'uploads'.
- Serving the saved images to the client via the mounted /uploads route (see the client sketch below).
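
For reference, a minimal client-side sketch of how the new endpoint could be exercised once the server is running. It assumes the third-party `requests` package and a local test image named `query.jpg`, neither of which is part of this commit; the endpoint path, form field name, response keys, and the /uploads mount are taken from the diff below.

import requests  # third-party HTTP client, assumed available for this example

API_URL = "http://127.0.0.1:8000"  # uvicorn in this commit binds port 8000

# Send the query image as multipart form data; the field name must be "file"
# to match the UploadFile parameter of the /find_similar/ endpoint.
with open("query.jpg", "rb") as f:
    response = requests.post(f"{API_URL}/find_similar/", files={"file": f})
response.raise_for_status()

payload = response.json()
print(payload["filename"])            # name of the uploaded file, echoed back
for name in payload["similar_images"]:
    # Each match was saved into the 'uploads' directory and is served by the
    # StaticFiles mount, so it can be fetched or linked directly:
    print(f"{API_URL}/uploads/{name}")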
Solrikk authored Aug 21, 2024
1 parent bbbaa3d commit 3c8475e
Showing 1 changed file with 85 additions and 47 deletions.
main.py: 132 changes (85 additions & 47 deletions)
@@ -1,57 +1,95 @@
-import webbrowser
-import threading
-import hashlib
 import os

-from fastapi import FastAPI, File, UploadFile
-from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
+import zipfile
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.models import load_model
+from fastapi import FastAPI, UploadFile, File, Request
+from fastapi.responses import HTMLResponse
+from fastapi.templating import Jinja2Templates
 from fastapi.staticfiles import StaticFiles
-from image_utils import find_similar_images, s3_client, S3_BUCKET_NAME
+from io import BytesIO
+from PIL import Image, UnidentifiedImageError

 app = FastAPI()
-app.mount("/static", StaticFiles(directory="static"), name="static")

+model = load_model('resnet50_local.h5')
+templates = Jinja2Templates(directory="templates")
+app.mount("/static", StaticFiles(directory="static"), name="static")
+app.mount("/uploads", StaticFiles(directory="uploads"), name="uploads")
+
+ZIP_PATH = 'photos.zip'
+UPLOAD_FOLDER = 'uploads'
+
+os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+
+def preprocess_image(image: Image.Image):
+    image = image.resize((224, 224))
+    image_array = np.array(image)
+    if image_array.shape[-1] == 4:
+        image_array = image_array[..., :3]
+    image_array = np.expand_dims(image_array, axis=0)
+    image_array = tf.keras.applications.resnet50.preprocess_input(image_array)
+    return image_array
+
+def get_image_features(image: Image.Image):
+    preprocessed_image = preprocess_image(image)
+    features = model.predict(preprocessed_image)
+    return features
+
+def compare_images(image1_features, image2_features):
+    return np.linalg.norm(image1_features - image2_features)
+
+def get_images_from_zip():
+    with zipfile.ZipFile(ZIP_PATH, 'r') as archive:
+        image_keys = [name for name in archive.namelist() if name.lower().endswith(('.jpg', '.jpeg', '.png'))]
+    return image_keys
+
+def extract_and_save_image(archive, image_key):
+    with archive.open(image_key) as image_file:
+        image = Image.open(image_file)
+        safe_image_name = os.path.basename(image_key)
+        image_path = os.path.join(UPLOAD_FOLDER, safe_image_name)
+        image.save(image_path)
+    return safe_image_name
+
 @app.get("/", response_class=HTMLResponse)
-async def upload_form():
-    with open('templates/upload_form.html', 'r') as file:
-        html_content = file.read()
-    return HTMLResponse(content=html_content)
-
-
-@app.post("/upload/", response_class=JSONResponse)
-async def create_upload_file(file: UploadFile = File(...)):
-    contents = await file.read()
-    image_hash = hashlib.sha256(contents).hexdigest()
-    file_path = f'uploads/{image_hash}.jpg'
-    os.makedirs(os.path.dirname(file_path), exist_ok=True)
-    with open(file_path, 'wb') as f:
-        f.write(contents)
-
-    s3_client.upload_file(file_path, S3_BUCKET_NAME, f'{image_hash}.jpg')
-
-    similar_images = await find_similar_images(file_path)
-
-    response_data = {
-        "message": "File uploaded successfully",
-        "filename": image_hash,
-        "similar_images": similar_images
+async def read_root(request: Request):
+    return templates.TemplateResponse("index.html", {"request": request})
+
+@app.post("/find_similar/")
+async def find_similar_images(file: UploadFile = File(...)):
+    uploaded_image = Image.open(BytesIO(await file.read()))
+    uploaded_image_features = get_image_features(uploaded_image)
+    images = get_images_from_zip()
+    similarities = []
+
+    with zipfile.ZipFile(ZIP_PATH, 'r') as archive:
+        for image_key in images:
+            try:
+                with archive.open(image_key) as image_file:
+                    image = Image.open(image_file)
+                    image = image.convert('RGB')
+                    image_features = get_image_features(image)
+                    similarity = compare_images(uploaded_image_features, image_features)
+                    similarities.append((image_key, similarity))
+            except UnidentifiedImageError:
+                print(f"Cannot identify image file: {image_key}")
+            except Exception as e:
+                print(f"Error processing image {image_key}: {e}")
+
+    similarities.sort(key=lambda x: x[1])
+    similar_images = []
+
+    with zipfile.ZipFile(ZIP_PATH, 'r') as archive:
+        for image_key, _ in similarities[:5]:
+            saved_image_name = extract_and_save_image(archive, image_key)
+            similar_images.append(saved_image_name)
+
+    return {
+        'filename': file.filename,
+        'similar_images': similar_images
     }

-    return JSONResponse(content=response_data)
-
-
-@app.get("/uploads/{image_hash}.jpg", response_class=FileResponse)
-async def serve_uploaded_image(image_hash: str):
-    file_path = f'uploads/{image_hash}.jpg'
-    return FileResponse(file_path)
-
-
-def start_server():
-    import uvicorn
-    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
-
-
 if __name__ == "__main__":
-    threading.Timer(1.25, lambda: webbrowser.open("http://127.0.0.1:8000")).start()
-    start_server()
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
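
A note on the model file: the code loads a ResNet50 feature extractor from resnet50_local.h5, but the commit does not show how that file is produced. A plausible way to create a compatible model, assuming ImageNet-pretrained weights and Keras HDF5 saving (an assumption, not part of the repository), is:

# Hypothetical one-off script to build the feature extractor used above.
import tensorflow as tf

feature_extractor = tf.keras.applications.ResNet50(
    weights="imagenet",   # pre-trained ImageNet weights
    include_top=False,    # drop the 1000-class classification head
    pooling="avg",        # global average pooling -> one 2048-d vector per image
)
feature_extractor.save("resnet50_local.h5")  # later loaded with load_model(...)

With such a model, model.predict on a preprocessed 224x224 image returns a (1, 2048) feature array, which is what compare_images measures the Euclidean distance between.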
