Skip to content

Commit

Permalink
Very basic tests, but all passing
Browse files Browse the repository at this point in the history
  • Loading branch information
mrtolkien committed Dec 6, 2021
1 parent 7ea8106 commit e999e91
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 37 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ Environment variables:
- It being compromised compromises the security of the API
- `FASTAPI_SIMPLE_SECURITY_HIDE_DOCS`: Whether or not to hide the API key related endpoints from the documentation
- `FASTAPI_SIMPLE_SECURITY_DB_LOCATION`: Location of the local sqlite database file
- /app/sqlite.db by default
- When running the app inside Docker, use a bind mount for persistence.
- `sqlite.db` in the running directory by default
- When running the app inside Docker, use a bind mount for persistence
- `FAST_API_SIMPLE_SECURITY_AUTOMATIC_EXPIRATION`: Duration, in days, until an API key is deemed expired
- 15 days by default

Expand Down
9 changes: 8 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@
app = FastAPI()


@app.get("/unsecured")
async def unsecured_endpoint():
return {"message": "This is an unsecured endpoint"}


@app.get("/secure", dependencies=[Depends(fastapi_simple_security.api_key_security)])
async def secure_endpoint():
return {"message": "This is a secure endpoint"}


app.include_router(fastapi_simple_security.api_key_router, prefix="/auth", tags=["_auth"])
app.include_router(
fastapi_simple_security.api_key_router, prefix="/auth", tags=["_auth"]
)
61 changes: 45 additions & 16 deletions fastapi_simple_security/_sqlite_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ def __init__(self):
try:
self.db_location = os.environ["FASTAPI_SIMPLE_SECURITY_DB_LOCATION"]
except KeyError:
self.db_location = "app/sqlite.db"
self.db_location = "sqlite.db"

try:
self.expiration_limit = int(os.environ["FAST_API_SIMPLE_SECURITY_AUTOMATIC_EXPIRATION"])
self.expiration_limit = int(
os.environ["FAST_API_SIMPLE_SECURITY_AUTOMATIC_EXPIRATION"]
)
except KeyError:
self.expiration_limit = 15

Expand Down Expand Up @@ -51,7 +53,9 @@ def create_key(self, never_expire) -> str:
api_key,
1,
1 if never_expire else 0,
(datetime.utcnow() + timedelta(days=self.expiration_limit)).isoformat(timespec="seconds"),
(
datetime.utcnow() + timedelta(days=self.expiration_limit)
).isoformat(timespec="seconds"),
None,
0,
),
Expand Down Expand Up @@ -83,34 +87,47 @@ def renew_key(self, api_key: str, new_expiration_date: str) -> Optional[str]:

# Previously revoked key. Issue a text warning and reactivate it.
if response[0] == 0:
response_lines.append("This API key was revoked and has been reactivated.")
response_lines.append(
"This API key was revoked and has been reactivated."
)
# Expired key. Issue a text warning and reactivate it.
if (not response[3]) and (datetime.fromisoformat(response[2]) < datetime.utcnow()):
if (not response[3]) and (
datetime.fromisoformat(response[2]) < datetime.utcnow()
):
response_lines.append("This API key was expired and is now renewed.")

if not new_expiration_date:
parsed_expiration_date = (datetime.utcnow() + timedelta(days=self.expiration_limit)).isoformat(
timespec="seconds"
)
parsed_expiration_date = (
datetime.utcnow() + timedelta(days=self.expiration_limit)
).isoformat(timespec="seconds")
else:
try:
# We parse and re-write to the right timespec
parsed_expiration_date = datetime.fromisoformat(new_expiration_date).isoformat(timespec="seconds")
parsed_expiration_date = datetime.fromisoformat(
new_expiration_date
).isoformat(timespec="seconds")
except ValueError:
return "The expiration date could not be parsed. Please use ISO 8601."
return (
"The expiration date could not be parsed. Please use ISO 8601."
)

c.execute(
"""
UPDATE fastapi_simple_security
SET expiration_date = ?, is_active = 1
WHERE api_key = ?
""",
(parsed_expiration_date, api_key,),
(
parsed_expiration_date,
api_key,
),
)

connection.commit()

response_lines.append(f"The new expiration date for the API key is {parsed_expiration_date}")
response_lines.append(
f"The new expiration date for the API key is {parsed_expiration_date}"
)

return " ".join(response_lines)

Expand Down Expand Up @@ -162,15 +179,24 @@ def check_key(self, api_key: str) -> bool:
# Inactive
or response[0] != 1
# Expired key
or ((not response[3]) and (datetime.fromisoformat(response[2]) < datetime.utcnow()))
or (
(not response[3])
and (datetime.fromisoformat(response[2]) < datetime.utcnow())
)
):
# The key is not valid
return False
else:
# The key is valid

# We run the logging in a separate thread as writing takes some time
threading.Thread(target=self._update_usage, args=(api_key, response[1],)).start()
threading.Thread(
target=self._update_usage,
args=(
api_key,
response[1],
),
).start()

# We return directly
return True
Expand All @@ -186,7 +212,11 @@ def _update_usage(self, api_key: str, usage_count: int):
SET total_queries = ?, latest_query_date = ?
WHERE api_key = ?
""",
(usage_count + 1, datetime.utcnow().isoformat(timespec="seconds"), api_key),
(
usage_count + 1,
datetime.utcnow().isoformat(timespec="seconds"),
api_key,
),
)

connection.commit()
Expand All @@ -201,7 +231,6 @@ def get_usage_stats(self) -> List[Tuple[str, int, str, str, int]]:
with sqlite3.connect(self.db_location) as connection:
c = connection.cursor()

# TODO Add filtering somehow
c.execute(
"""
SELECT api_key, is_active, never_expire, expiration_date, latest_query_date, total_queries
Expand Down
146 changes: 131 additions & 15 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit e999e91

Please sign in to comment.