-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #42 from Handoni/develop
v0.1.1-demo
- Loading branch information
Showing
34 changed files
with
4,717 additions
and
427 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from fastapi import APIRouter, HTTPException, Depends | ||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | ||
from api.schemas.user import UserCreate, User, Token | ||
from services.user_service import create_user, get_user_by_email, authenticate_user | ||
from utils.jwt_handler import create_access_token, decode_access_token | ||
import jwt | ||
|
||
router = APIRouter() | ||
|
||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") | ||
|
||
@router.post("/users/", response_model=User) | ||
def register_user(user: UserCreate): | ||
existing_user = get_user_by_email(user.email) | ||
if existing_user: | ||
raise HTTPException(status_code=400, detail="Email already registered") | ||
|
||
try: | ||
user_record = create_user(user) | ||
return User(id=user_record['id'], email=user_record['email'], nickname=user_record['nickname'], sex=user_record['sex'], age=user_record['age']) | ||
except Exception as e: | ||
raise HTTPException(status_code=400, detail=str(e)) | ||
|
||
@router.post("/token", response_model=Token) | ||
def login_user(form_data: OAuth2PasswordRequestForm = Depends()): | ||
user = authenticate_user(form_data.username, form_data.password) | ||
if not user: | ||
raise HTTPException(status_code=400, detail="Invalid email or password") | ||
|
||
access_token = create_access_token(data={"sub": user['email']}) | ||
return {"access_token": access_token, "token_type": "bearer"} | ||
|
||
@router.get("/users/me/", response_model=User) | ||
def read_users_me(token: str = Depends(oauth2_scheme)): | ||
credentials_exception = HTTPException( | ||
status_code=401, | ||
detail="Could not validate credentials", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
try: | ||
email = decode_access_token(token) | ||
user = get_user_by_email(email) | ||
if user is None: | ||
raise credentials_exception | ||
return User(id=user['id'], email=user['email'], sex=user['sex'], age=user['age']) | ||
except jwt.PyJWTError: | ||
raise credentials_exception |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from pydantic import BaseModel, EmailStr, validator | ||
|
||
class Token(BaseModel): | ||
access_token: str | ||
token_type: str | ||
|
||
class TokenData(BaseModel): | ||
email: str | None = None | ||
|
||
class UserCreate(BaseModel): | ||
email: EmailStr | ||
nickname: str | ||
password: str | ||
sex: str | ||
age: int | ||
|
||
@validator("sex") | ||
def validate_sex(cls, v): | ||
if v not in ['male', 'female']: | ||
raise ValueError('Sex field must be either "male" or "female".') | ||
return v | ||
|
||
class User(BaseModel): | ||
id: str | ||
nickname: str | ||
email: EmailStr | ||
sex: str | ||
age: int | ||
|
||
class Config: | ||
from_attributes = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import pandas as pd | ||
import csv | ||
|
||
def clean(symptom): | ||
# 증상이 None이거나 NaN인 경우 처리 | ||
if pd.isna(symptom): | ||
return None | ||
parts = symptom.split('^') | ||
cleaned_parts = [part.split('_')[1] if '_' in part else part for part in parts] | ||
return '^'.join(cleaned_parts) | ||
|
||
def process_csv(input_file, output_file): | ||
try: | ||
df = pd.read_csv(input_file) | ||
except FileNotFoundError: | ||
print(f"파일을 찾을 수 없습니다: {input_file}") | ||
return | ||
|
||
# 'Disease' 필드에서 UMLS 코드 제거 | ||
df['Disease'] = df['Disease'].apply(clean) | ||
|
||
# 'Symptom' 필드에서 UMLS 코드 제거 | ||
df['Symptom'] = df['Symptom'].apply(clean) | ||
|
||
# 'Count of Disease Occurrence' 열을 Int64로 변환하여 NaN 값이 유지되도록 함 | ||
df['Count of Disease Occurrence'] = df['Count of Disease Occurrence'].astype('Int64') | ||
|
||
# 결과를 새로운 CSV 파일로 저장 | ||
df.to_csv(output_file, index=False, quoting=csv.QUOTE_ALL) | ||
|
||
# 파일 경로 설정 | ||
input_file = "C:/Users/이상윤/Documents/coding/Apayo/backend/app/data/raw_data_2.csv" | ||
output_file = "C:/Users/이상윤/Documents/coding/Apayo/backend/app/data/output.csv" | ||
|
||
# 함수 실행 | ||
process_csv(input_file, output_file) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from openai import OpenAI | ||
import pandas as pd | ||
import numpy as np | ||
from pymongo import MongoClient, ASCENDING, DESCENDING | ||
from core.config import get_settings | ||
|
||
settings = get_settings() | ||
GPT_API_KEY = settings.gpt_api_key | ||
client = OpenAI(api_key=GPT_API_KEY) | ||
|
||
mongo_client = MongoClient(settings.mongo_uri) | ||
db = mongo_client['disease_embedding_db'] | ||
|
||
def get_embedding(text): | ||
return client.embeddings.create(input=[text], model='text-embedding-3-small').data[0].embedding | ||
|
||
def get_last_saved_disease(): | ||
disease = db.diseases_embeddings.find_one(sort=[("id", DESCENDING)]) | ||
if disease: | ||
return disease["id"] | ||
return None | ||
|
||
def create_embedding_data(): | ||
df = pd.read_csv('C:/Users/이상윤/Documents/coding/Apayo/backend/app/data/output.csv') | ||
df.fillna(method='ffill', inplace=True) | ||
df.applymap(lambda x: x.replace('\xa0','').replace('\xa9','') if type(x) == str else x) | ||
|
||
# last_saved_disease = get_last_saved_disease() | ||
# start_saving = last_saved_disease is None | ||
# print(f"Last saved disease: {last_saved_disease}") | ||
start_saving = False | ||
|
||
# 질병별로 그룹화 및 처리 | ||
for disease, group in df.groupby("Disease"): | ||
if not start_saving: | ||
if disease == 'obesity morbid': | ||
start_saving = True | ||
continue | ||
print(f"Processing {disease}") | ||
symptoms = group["Symptom"].tolist() | ||
disease_embedding = {"embedding": get_embedding(disease)} | ||
|
||
# 질병 문서 생성 | ||
disease_data = { | ||
"_id": disease, | ||
"embedding": disease_embedding | ||
} | ||
db.diseases_embeddings.insert_one(disease_data) | ||
|
||
# 각 증상을 서브컬렉션에 추가 | ||
for symptom in symptoms: | ||
# 증상 값이 유효한지 확인 | ||
symptom = symptom.strip() | ||
if symptom: | ||
symptom_embedding = {"embedding": get_embedding(symptom)} | ||
symptom_data = { | ||
"disease_id": disease, | ||
"symptom": symptom, | ||
"embedding": symptom_embedding | ||
} | ||
db.symptoms.insert_one(symptom_data) | ||
|
||
# if __name__ == "__main__": | ||
# create_embedding_data() |
Oops, something went wrong.