Skip to content

Commit

Permalink
# Refactored image utilities to use classification utils and added support for multiple categories
Browse files Browse the repository at this point in the history

- Separated classification logic into classification_utils.py.
- Updated image downloading and comparison logic.
- Added support to classify and filter images from Architecture, Aviation, and Backgrounds folders.
- Included necessary imports and fixed missing cosine_similarity import.
  • Loading branch information
Solrikk authored Jul 17, 2024
1 parent 0467f04 commit 762f1ec
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions image_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import aiohttp
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing import image
import boto3
from botocore.client import Config
from concurrent.futures import ThreadPoolExecutor
import asyncio
from sklearn.metrics.pairwise import cosine_similarity
from classification_utils import is_architecture_image, is_aviation_image, is_background_image, extract_features_batch
from sklearn.metrics.pairwise import cosine_similarity # New import

S3_BUCKET_NAME = 'YOUR_S3_BUCKET_NAME'
S3_REGION = 'YOUR_S3_REGION'
Expand All @@ -25,9 +22,6 @@
config=Config(region_name=S3_REGION, signature_version='s3v4'),
)

# ResNet50 loaded with ImageNet weights, without its top classification
# layers (include_top=False) and with average pooling on the output
# (pooling='avg').
base_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
# Model built from the base network's own input and output; this is the
# object extract_features_batch calls .predict() on.
model = Model(inputs=base_model.input, outputs=base_model.output)


async def download_image_from_s3(file_name):
try:
Expand All @@ -43,16 +37,6 @@ async def download_image_from_s3(file_name):
return None


def extract_features_batch(img_list):
    """Run a batch of images through the module-level ``model``.

    ``None`` entries in ``img_list`` are dropped. Every remaining image is
    resized to 224x224 with ``cv2.resize``, the results are stacked into a
    single array, passed through ``preprocess_input`` and then handed to
    ``model.predict`` with a batch size of 32.

    Returns:
        Whatever ``model.predict`` returns for the batch, or an empty
        ``np.array([])`` when ``img_list`` contains no usable image.
    """
    resized_frames = [
        cv2.resize(frame, (224, 224))
        for frame in img_list
        if frame is not None
    ]
    # Nothing to predict on: hand back an empty array instead of calling
    # the model with an empty batch.
    if not resized_frames:
        return np.array([])

    # NOTE(review): callers in this file load images with cv2.imread, which
    # yields BGR channel order, while ResNet50's preprocess_input is
    # documented as taking RGB input - confirm whether a conversion is wanted.
    batch = preprocess_input(np.array(resized_frames))
    return model.predict(batch, batch_size=32)


async def compare_images(file_name, target_features, loop):
try:
current_image = await download_image_from_s3(file_name)
Expand All @@ -70,16 +54,19 @@ async def compare_images(file_name, target_features, loop):
return (0, "")


async def list_s3_images():
async def list_s3_images(folder=''):
try:
images = []
continuation_token = None
while True:
if continuation_token:
response = s3_client.list_objects_v2(
Bucket=S3_BUCKET_NAME, ContinuationToken=continuation_token)
Bucket=S3_BUCKET_NAME,
ContinuationToken=continuation_token,
Prefix=folder)
else:
response = s3_client.list_objects_v2(Bucket=S3_BUCKET_NAME)
response = s3_client.list_objects_v2(Bucket=S3_BUCKET_NAME,
Prefix=folder)
images.extend([
item['Key'] for item in response.get('Contents', [])
if item['Key'].lower().endswith(('.jpg', '.jpeg', '.png'))
Expand All @@ -96,13 +83,26 @@ async def list_s3_images():

async def find_similar_images(file_path):
loop = asyncio.get_event_loop()
is_architecture = await is_architecture_image(file_path)
is_aviation = await is_aviation_image(file_path)
is_background = await is_background_image(file_path)

if is_architecture:
folder = "Architecture/"
elif is_aviation:
folder = "Aviation/"
elif is_background:
folder = "Backgrounds/"
else:
folder = ''

target_image = cv2.imread(file_path)
if target_image is None:
raise ValueError(
f"Failed to read target image from file path: {file_path}")
target_features = extract_features_batch([target_image])

file_names = await list_s3_images()
file_names = await list_s3_images(folder=folder)

tasks = [
compare_images(file_name, target_features, loop)
Expand Down

0 comments on commit 762f1ec

Please sign in to comment.