-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
137 lines (108 loc) · 4.48 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import ntpath
import os
import shutil
from os import getenv
from typing import List, Optional
from uuid import uuid4
import datetime
import cv2
import numpy as np
from fastapi import Depends, FastAPI, UploadFile, File, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.responses import FileResponse
import database
import authUtils
from ImageSearch import ImageSearch
from dotenv import load_dotenv
# Load settings from a local .env file (python-dotenv) at import time, before
# any handler reads them through getenv (e.g. ACCESS_TOKEN_EXPIRE_MINUTES).
load_dotenv()
# Application object; every route decorator in this module registers on it.
app = FastAPI()
@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange a username/password form for a bearer access token.

    Credentials are checked with ``authUtils.authenticate_user``. The token
    lifetime, in minutes, is read from the ``ACCESS_TOKEN_EXPIRE_MINUTES``
    environment variable.

    Returns:
        A dict with the keys ``access_token`` and ``token_type``.

    Raises:
        HTTPException: 400 when authentication yields no user.
    """
    account = authUtils.authenticate_user(form_data.username, form_data.password)
    if account is None:
        raise HTTPException(status_code=400, detail="Incorrect username or password")

    lifetime_minutes = int(getenv("ACCESS_TOKEN_EXPIRE_MINUTES"))
    lifetime = datetime.timedelta(minutes=lifetime_minutes)
    token = authUtils.create_access_token(account.username, expires_delta=lifetime)
    return {"access_token": token, "token_type": "bearer"}
def __save_image(path_prefix, received_image):
    """Persist an uploaded file under ``path_prefix`` with a random name.

    The stored name is a fresh ``uuid4`` hex string followed by the extension
    of the client-supplied filename, so the client never chooses the on-disk
    name.

    Args:
        path_prefix: Destination prefix including its trailing separator
            (callers in this module pass ``'public/'`` or ``'private/'``).
        received_image: Upload object exposing ``filename`` and a readable
            binary ``file`` attribute.

    Returns:
        A ``(filename, image_path)`` tuple, where ``image_path`` is
        ``path_prefix + filename``.
    """
    # An upload can arrive without a filename; treat that as "no extension"
    # rather than letting os.path.splitext raise TypeError on None.
    original_name = received_image.filename or ""
    file_ext = os.path.splitext(original_name)[1]
    filename = uuid4().hex + file_ext
    image_path = path_prefix + filename
    # Write-only binary mode is sufficient: the file is never read back here.
    with open(image_path, "wb") as local_image:
        shutil.copyfileobj(received_image.file, local_image)
    return filename, image_path
@app.put("/upload")
async def upload_image(visibility: str, current_user: authUtils.User = Depends(authUtils.get_user_from_token),
                       files: List[UploadFile] = File(...)):
    """Store one or more uploaded images for the authenticated user.

    Each file is written under ``private/`` or ``public/`` depending on
    ``visibility`` and recorded in the database against the current user.
    Public uploads are additionally handed to the search index in one batch.

    Returns:
        A dict whose ``filenames`` entry lists the stored file names.

    Raises:
        HTTPException: 400 when ``visibility`` is neither ``'public'`` nor
            ``'private'``.
    """
    if visibility not in ("private", "public"):
        raise HTTPException(status_code=400, detail="Visibility can only be 'public' or 'private'")

    is_private = visibility == "private"
    path_prefix = 'private/' if is_private else 'public/'

    stored_paths = []
    for upload in files:
        stored_name, stored_path = __save_image(path_prefix, upload)
        database.create_image(stored_name, current_user.username, is_private)
        stored_paths.append(stored_path)

    # Private images are never added to the search index.
    if not is_private:
        image_search.batch_index_images(stored_paths)

    return {"filenames": [ntpath.basename(stored_path) for stored_path in stored_paths]}
@app.delete("/delete", status_code=204)
async def delete_image(image_name: str, current_user: authUtils.User = Depends(authUtils.get_user_from_token)):
    """Delete an image owned by the authenticated user.

    Removes the database record and the file on disk. For public images the
    corresponding search-index entry is removed as well.

    Args:
        image_name: Stored file name of the image, as returned by the upload
            endpoint.
        current_user: User resolved from the request's token.

    Raises:
        HTTPException: 404 when no record exists for ``image_name``; 401 when
            the record belongs to a different user.
    """
    image_record = database.get_image(image_name)
    if image_record is None:
        raise HTTPException(status_code=404, detail="Image does not exist")
    if image_record.username != current_user.username:
        raise HTTPException(status_code=401, detail="User not authorized")

    if image_record.private:
        path_prefix = 'private/'
    else:
        path_prefix = 'public/'
        # Only public images are ever indexed (at upload time and at startup),
        # so the index entry must be dropped in this branch. Doing it for
        # private images instead would leave deleted public files searchable.
        image_search.delete_image(path_prefix + image_name)

    database.delete_image(image_name)
    os.remove(path_prefix + image_name)
    return
@app.get("/image")
async def retrieve_image(image_name: str,
                         current_user: Optional[authUtils.User] = Depends(authUtils.optional_get_user_from_token)):
    """Serve a stored image by name.

    Public images are returned to anyone. Private images are returned only
    when the request carries a user and that user owns the image.

    Raises:
        HTTPException: 404 when no record exists for ``image_name``; 401 when
            a private image is requested without a user or by a user other
            than its owner.
    """
    record = database.get_image(image_name)
    if record is None:
        raise HTTPException(status_code=404, detail="Image does not exist")

    # Public images need no further checks.
    if not record.private:
        return FileResponse('public/' + image_name)

    if current_user is None:
        raise HTTPException(status_code=401, detail="Not authenticated")
    if record.username != current_user.username:
        raise HTTPException(status_code=401, detail="User not authorized")
    return FileResponse('private/' + image_name)
@app.post("/search")
async def search_image(file: UploadFile = File(...)):
    """Search the image index using an uploaded query image.

    The upload is decoded with OpenCV and the decoded image is passed to
    ``image_search.search``; whatever that call returns is sent back as the
    response.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as
            an image.
    """
    raw_bytes = file.file.read()
    if not raw_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # np.frombuffer is the supported replacement for the deprecated binary
    # mode of np.fromstring.
    encoded = np.frombuffer(raw_bytes, dtype=np.uint8)
    img = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
    # imdecode signals invalid or truncated data by returning None.
    if img is None:
        raise HTTPException(status_code=400, detail="Uploaded file is not a decodable image")

    return image_search.search(img)
@app.post("/signup", status_code=204)
async def sign_up(username: str, password: str):
    """Register a new user account.

    The password is hashed with ``authUtils.hash_password`` before it is
    handed to ``database.create_user``.

    Raises:
        HTTPException: 409 when ``database.create_user`` reports failure.
    """
    password_hash = authUtils.hash_password(password)
    created = database.create_user(username, password_hash)
    if created:
        return
    raise HTTPException(status_code=409, detail="Username already taken")
def make_dirs():
    """Create the ``public`` and ``private`` storage folders if missing.

    Paths are relative to the current working directory. Calling this again
    when the folders already exist is a no-op.

    Raises:
        FileExistsError: If one of the names exists but is not a directory.
    """
    for folder in ('public', 'private'):
        # exist_ok tolerates an existing directory but still reports a
        # regular file occupying the name, which would break later file I/O.
        os.makedirs(folder, exist_ok=True)
# Import-time startup: build the search object, make sure the storage folders
# exist, then index every file already present under public/.
image_search = ImageSearch(kmeans_clusters=10)
make_dirs()
image_search.batch_index_images(["public/" + entry for entry in os.listdir("public")])