Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/implement-refresh-tokens #9233

Merged
merged 2 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300

# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

Expand Down
2 changes: 1 addition & 1 deletion api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def load_user_from_request(request_from_flask_login):
decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id")

logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
if logged_in_account:
contexts.tenant_id.set(logged_in_account.current_tenant_id)
return logged_in_account
Expand Down
15 changes: 10 additions & 5 deletions api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ class WorkflowConfig(BaseSettings):
)


class OAuthConfig(BaseSettings):
class AuthConfig(BaseSettings):
"""
Configuration for OAuth authentication
Configuration for authentication and OAuth
"""

OAUTH_REDIRECT_PATH: str = Field(
Expand All @@ -371,7 +371,7 @@ class OAuthConfig(BaseSettings):
)

GITHUB_CLIENT_ID: Optional[str] = Field(
description="GitHub OAuth client secret",
description="GitHub OAuth client ID",
default=None,
)

Expand All @@ -390,6 +390,11 @@ class OAuthConfig(BaseSettings):
default=None,
)

ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
description="Expiration time for access tokens in minutes",
default=60,
)


class ModerationConfig(BaseSettings):
"""
Expand Down Expand Up @@ -607,6 +612,7 @@ def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
DataSetConfig,
Expand All @@ -621,14 +627,13 @@ class FeatureConfig(
MailConfig,
ModelLoadBalanceConfig,
ModerationConfig,
OAuthConfig,
PositionConfig,
RagEtlConfig,
SecurityConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,
Expand Down
23 changes: 18 additions & 5 deletions api/controllers/console/auth/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import services
from controllers.console import api
from controllers.console.setup import setup_required
from libs.helper import email, get_remote_ip
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
from models.account import Account
from services.account_service import AccountService, TenantService
Expand Down Expand Up @@ -40,17 +40,16 @@ def post(self):
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
}

token = AccountService.login(account, ip_address=get_remote_ip(request))
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))

return {"result": "success", "data": token}
return {"result": "success", "data": token_pair.model_dump()}


class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
token = request.headers.get("Authorization", "").split(" ")[1]
AccountService.logout(account=account, token=token)
AccountService.logout(account=account)
flask_login.logout_user()
return {"result": "success"}

Expand Down Expand Up @@ -106,5 +105,19 @@ def get(self):
return {"result": "success"}


class RefreshTokenApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("refresh_token", type=str, required=True, location="json")
args = parser.parse_args()

try:
new_token_pair = AccountService.refresh_token(args["refresh_token"])
return {"result": "success", "data": new_token_pair.model_dump()}
except Exception as e:
return {"result": "fail", "data": str(e)}, 401


api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(RefreshTokenApi, "/refresh-token")
11 changes: 8 additions & 3 deletions api/controllers/console/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from configs import dify_config
from constants.languages import languages
from extensions.ext_database import db
from libs.helper import get_remote_ip
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
Expand Down Expand Up @@ -81,9 +81,14 @@ def get(self, provider: str):

TenantService.create_owner_tenant_if_not_exist(account)

token = AccountService.login(account, ip_address=get_remote_ip(request))
token_pair = AccountService.login(
account=account,
ip_address=extract_remote_ip(request),
)

return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
return redirect(
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
)


def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
Expand Down
4 changes: 2 additions & 2 deletions api/controllers/console/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from flask_restful import Resource, reqparse

from configs import dify_config
from libs.helper import StrLen, email, get_remote_ip
from libs.helper import StrLen, email, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup
from services.account_service import RegisterService, TenantService
Expand Down Expand Up @@ -46,7 +46,7 @@ def post(self):

# setup
RegisterService.setup(
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
)

return {"result": "success"}, 201
Expand Down
2 changes: 1 addition & 1 deletion api/libs/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def generate_string(n):
return result


def get_remote_ip(request) -> str:
def extract_remote_ip(request) -> str:
if request.headers.get("CF-Connecting-IP"):
return request.headers.get("Cf-Connecting-Ip")
elif request.headers.getlist("X-Forwarded-For"):
Expand Down
97 changes: 76 additions & 21 deletions api/services/account_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from hashlib import sha256
from typing import Any, Optional

from pydantic import BaseModel
from sqlalchemy import func
from werkzeug.exceptions import Unauthorized

Expand Down Expand Up @@ -49,9 +50,39 @@
from tasks.mail_reset_password_task import send_reset_password_mail_task


class TokenPair(BaseModel):
access_token: str
refresh_token: str


REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30)


class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)

@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
return f"{REFRESH_TOKEN_PREFIX}{refresh_token}"

@staticmethod
def _get_account_refresh_token_key(account_id: str) -> str:
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"

@staticmethod
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
redis_client.setex(
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
)

@staticmethod
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))

@staticmethod
def load_user(user_id: str) -> None | Account:
account = Account.query.filter_by(id=user_id).first()
Expand All @@ -61,9 +92,7 @@ def load_user(user_id: str) -> None | Account:
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
raise Unauthorized("Account is banned or closed.")

current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(
account_id=account.id, current=True
).first()
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
if current_tenant:
account.current_tenant_id = current_tenant.tenant_id
else:
Expand All @@ -84,10 +113,12 @@ def load_user(user_id: str) -> None | Account:
return account

@staticmethod
def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
def get_account_jwt_token(account: Account) -> str:
exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
exp = int(exp_dt.timestamp())
payload = {
"user_id": account.id,
"exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
"exp": exp,
"iss": dify_config.EDITION,
"sub": "Console API Passport",
}
Expand Down Expand Up @@ -213,30 +244,53 @@ def update_account(account, **kwargs):
return account

@staticmethod
def update_last_login(account: Account, *, ip_address: str) -> None:
def update_login_info(account: Account, *, ip_address: str) -> None:
"""Update last login time and ip"""
account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
account.last_login_ip = ip_address
db.session.add(account)
db.session.commit()

@staticmethod
def login(account: Account, *, ip_address: Optional[str] = None):
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
if ip_address:
AccountService.update_last_login(account, ip_address=ip_address)
exp = timedelta(days=30)
token = AccountService.get_account_jwt_token(account, exp=exp)
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds()))
return token
AccountService.update_login_info(account=account, ip_address=ip_address)

access_token = AccountService.get_account_jwt_token(account=account)
refresh_token = _generate_refresh_token()

AccountService._store_refresh_token(refresh_token, account.id)

return TokenPair(access_token=access_token, refresh_token=refresh_token)

@staticmethod
def logout(*, account: Account, token: str):
redis_client.delete(_get_login_cache_key(account_id=account.id, token=token))
def logout(*, account: Account) -> None:
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
if refresh_token:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)

@staticmethod
def load_logged_in_account(*, account_id: str, token: str):
if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)):
return None
def refresh_token(refresh_token: str) -> TokenPair:
# Verify the refresh token
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
if not account_id:
raise ValueError("Invalid refresh token")

account = AccountService.load_user(account_id.decode("utf-8"))
if not account:
raise ValueError("Invalid account")

# Generate new access token and refresh token
new_access_token = AccountService.get_account_jwt_token(account)
new_refresh_token = _generate_refresh_token()

AccountService._delete_refresh_token(refresh_token, account.id)
AccountService._store_refresh_token(new_refresh_token, account.id)

return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)

@staticmethod
def load_logged_in_account(*, account_id: str):
return AccountService.load_user(account_id)

@classmethod
Expand All @@ -258,10 +312,6 @@ def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "reset_password")


def _get_login_cache_key(*, account_id: str, token: str):
return f"account_login:{account_id}:{token}"


class TenantService:
@staticmethod
def create_tenant(name: str) -> Tenant:
Expand Down Expand Up @@ -698,3 +748,8 @@ def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) ->

invitation = json.loads(data)
return invitation


def _generate_refresh_token(length: int = 64):
laipz8200 marked this conversation as resolved.
Show resolved Hide resolved
token = secrets.token_hex(length)
return token
3 changes: 3 additions & 0 deletions docker/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ MIGRATION_ENABLED=true
# The default value is 300 seconds.
FILES_ACCESS_TIMEOUT=300

# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
APP_MAX_ACTIVE_REQUESTS=0

Expand Down
1 change: 1 addition & 0 deletions docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ x-shared-env: &shared-api-worker-env
REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-}
REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-}
REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1}
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
Expand Down