Skip to content

Commit

Permalink
Merge pull request #1 from veoco/master
Browse files Browse the repository at this point in the history
使用 aiosqlite 驱动异步化数据库操作
  • Loading branch information
vastsa authored Dec 11, 2022
2 parents ea64751 + b7315bb commit d25a921
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 50 deletions.
18 changes: 12 additions & 6 deletions database.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import datetime

from sqlalchemy import create_engine, DateTime
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Boolean, Column, Integer, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio.session import AsyncSession


engine = create_async_engine("sqlite+aiosqlite:///database.db")

engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False})
Base = declarative_base()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


async def get_session():
async with AsyncSession(engine, expire_on_commit=False) as s:
yield s


class Codes(Base):
__tablename__ = 'codes'
__tablename__ = "codes"
id = Column(Integer, primary_key=True, index=True)
code = Column(String(10), unique=True, index=True)
key = Column(String(30), unique=True, index=True)
Expand Down
85 changes: 44 additions & 41 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@
import os
import uuid
import threading
import random

from fastapi import FastAPI, Depends, UploadFile, Form, File
from sqlalchemy import or_
from sqlalchemy.orm import Session
from starlette.requests import Request
from starlette.responses import HTMLResponse, FileResponse
import random

from starlette.staticfiles import StaticFiles

import database
from database import engine, SessionLocal, Base
from sqlalchemy import or_, select, update, delete, create_engine
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio.session import AsyncSession

from database import engine, get_session, Base, Codes


engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False})
Base.metadata.create_all(bind=engine)

app = FastAPI()
if not os.path.exists('./static'):
os.makedirs('./static')
Expand Down Expand Up @@ -58,17 +63,9 @@ def delete_file(files):
os.remove('.' + file['text'])


def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()


def get_code(db: Session = Depends(get_db)):
async def get_code(s: AsyncSession):
code = random.randint(10000, 99999)
while db.query(database.Codes).filter(database.Codes.code == code).first():
while (await s.execute(select(Codes.id).where(Codes.code == code))).scalar():
code = random.randint(10000, 99999)
return str(code)

Expand All @@ -94,21 +91,23 @@ async def admin():


@app.post(f'/{admin_address}')
async def admin_post(request: Request, db: Session = Depends(get_db)):
async def admin_post(request: Request, s: AsyncSession = Depends(get_session)):
if request.headers.get('pwd') == admin_password:
codes = db.query(database.Codes).all()
query = select(Codes)
codes = (await s.execute(query)).scalars().all()
return {'code': 200, 'msg': '查询成功', 'data': codes}
else:
return {'code': 404, 'msg': '密码错误'}


@app.delete(f'/{admin_address}')
async def admin_delete(request: Request, code: str, db: Session = Depends(get_db)):
async def admin_delete(request: Request, code: str, s: AsyncSession = Depends(get_session)):
if request.headers.get('pwd') == admin_password:
file = db.query(database.Codes).filter(database.Codes.code == code).first()
query = select(Codes).where(Codes.code == code)
file = (await s.execute(query)).scalars().first()
threading.Thread(target=delete_file, args=([{'type': file.type, 'text': file.text}],)).start()
db.delete(file)
db.commit()
await s.delete(file)
await s.commit()
return {'code': 200, 'msg': '删除成功'}
else:
return {'code': 404, 'msg': '密码错误'}
Expand Down Expand Up @@ -150,22 +149,26 @@ async def get_file(code: str, db: Session = Depends(get_db)):


@app.post('/')
async def index(request: Request, code: str, db: Session = Depends(get_db)):
async def index(request: Request, code: str, s: AsyncSession = Depends(get_session)):
ip = request.client.host
if not check_ip(ip):
return {'code': 404, 'msg': '错误次数过多,请稍后再试'}
info = db.query(database.Codes).filter(database.Codes.code == code).first()
query = select(Codes).where(Codes.code == code)
info = (await s.execute(query)).scalars().first()
if not info:
return {'code': 404, 'msg': f'取件码错误,错误{error_count - ip_error(ip)}次将被禁止10分钟'}
if info.exp_time < datetime.datetime.now() or info.count == 0:
threading.Thread(target=delete_file, args=([{'type': info.type, 'text': info.text}],)).start()
db.delete(info)
db.commit()
await s.delete(info)
await s.commit()
return {'code': 404, 'msg': '取件码已过期,请联系寄件人'}
info.count -= 1
db.commit()
count = info.count - 1
query = update(Codes).where(Codes.id == info.id).values(count=count)
await s.execute(query)
await s.commit()
if info.type != 'text':
info.text = f'/select?code={code}'

return {
'code': 200,
'msg': '取件成功,请点击"取"查看',
Expand All @@ -175,17 +178,17 @@ async def index(request: Request, code: str, db: Session = Depends(get_db)):

@app.post('/share')
async def share(text: str = Form(default=None), style: str = Form(default='2'), value: int = Form(default=1),
file: UploadFile = File(default=None), db: Session = Depends(get_db)):
exps = db.query(database.Codes).filter(
or_(
database.Codes.exp_time < datetime.datetime.now(),
database.Codes.count == 0
)
)
threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps.all()],)).start()
exps.delete()
db.commit()
code = get_code(db)
file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)):
query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0))
exps = (await s.execute(query)).scalars().all()
threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps],)).start()

exps_ids = [exp.id for exp in exps]
query = delete(Codes).where(Codes.id.in_(exps_ids))
await s.execute(query)
await s.commit()

code = await get_code(s)
if style == '2':
if value > 7:
return {'code': 404, 'msg': '最大有效天数为7天'}
Expand All @@ -206,7 +209,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
return {'code': 404, 'msg': '文件过大'}
else:
size, _text, _type, name = len(text), text, 'text', '文本分享'
info = database.Codes(
info = Codes(
code=code,
text=_text,
size=size,
Expand All @@ -216,8 +219,8 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
exp_time=exp_time,
key=key
)
db.add(info)
db.commit()
s.add(info)
await s.commit()
return {
'code': 200,
'msg': '分享成功,请点击文件箱查看取件码',
Expand Down
5 changes: 2 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
fastapi==0.88.0
python-multipart==0.0.5
fastapi[all]==0.88.0
aiosqlite==0.17.0
SQLAlchemy==1.4.44
uvicorn==0.20.0

0 comments on commit d25a921

Please sign in to comment.