Skip to content

Commit

Permalink
feat(auth): implement refresh tokens and enhance login process
Browse files Browse the repository at this point in the history
- Introduced `TokenPair` model for managing access and refresh tokens.
- Added `refresh_token` method to generate new tokens upon expiration.
- Updated login/logout processes to handle token pairs and enhanced security.
- Replaced `get_remote_ip` with `extract_remote_ip` for clarity.
- Added endpoint for refreshing tokens to maintain user session continuity.
  • Loading branch information
laipz8200 committed Oct 11, 2024
1 parent 4256149 commit 49197f1
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 42 deletions.
3 changes: 3 additions & 0 deletions api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300

# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

Expand Down
2 changes: 1 addition & 1 deletion api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def load_user_from_request(request_from_flask_login):
decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id")

logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
if logged_in_account:
contexts.tenant_id.set(logged_in_account.current_tenant_id)
return logged_in_account
Expand Down
15 changes: 10 additions & 5 deletions api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ class WorkflowConfig(BaseSettings):
)


class OAuthConfig(BaseSettings):
class AuthConfig(BaseSettings):
"""
Configuration for OAuth authentication
Configuration for authentication and OAuth
"""

OAUTH_REDIRECT_PATH: str = Field(
Expand All @@ -390,7 +390,7 @@ class OAuthConfig(BaseSettings):
)

GITHUB_CLIENT_ID: Optional[str] = Field(
description="GitHub OAuth client secret",
description="GitHub OAuth client ID",
default=None,
)

Expand All @@ -409,6 +409,11 @@ class OAuthConfig(BaseSettings):
default=None,
)

ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
description="Expiration time for access tokens in minutes",
default=60,
)


class ModerationConfig(BaseSettings):
"""
Expand Down Expand Up @@ -664,6 +669,7 @@ class LoginConfig(BaseSettings):
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
DataSetConfig,
Expand All @@ -678,14 +684,13 @@ class FeatureConfig(
MailConfig,
ModelLoadBalanceConfig,
ModerationConfig,
OAuthConfig,
PositionConfig,
RagEtlConfig,
SecurityConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
LoginConfig,
# hosted services config
HostedServiceConfig,
Expand Down
27 changes: 20 additions & 7 deletions api/controllers/console/auth/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from controllers.console.error import EmailSendIpLimitError, NotAllowedCreateWorkspace, NotAllowedRegister
from controllers.console.setup import setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, get_remote_ip
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
from models.account import Account
from services.account_service import AccountService, RegisterService, TenantService
Expand Down Expand Up @@ -74,17 +74,16 @@ def post(self):
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
}

token = AccountService.login(account, ip_address=get_remote_ip(request))
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token}
return {"result": "success", "data": token_pair.model_dump()}


class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
token = request.headers.get("Authorization", "").split(" ")[1]
AccountService.logout(account=account, token=token)
AccountService.logout(account=account)
flask_login.logout_user()
return {"result": "success"}

Expand Down Expand Up @@ -122,7 +121,7 @@ def post(self):
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()

ip_address = get_remote_ip(request)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()

Expand Down Expand Up @@ -187,13 +186,27 @@ def post(self):
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
token = AccountService.login(account, ip_address=get_remote_ip(request))
token = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token}


class RefreshTokenApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("refresh_token", type=str, required=True, location="json")
args = parser.parse_args()

try:
new_token_pair = AccountService.refresh_token(args["refresh_token"])
return {"result": "success", "data": new_token_pair.model_dump()}
except Exception as e:
return {"result": "fail", "data": str(e)}, 401


api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(RefreshTokenApi, "/refresh-token")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
11 changes: 8 additions & 3 deletions api/controllers/console/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from constants.languages import languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import get_remote_ip
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
Expand Down Expand Up @@ -120,9 +120,14 @@ def get(self, provider: str):
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)

token = AccountService.login(account, ip_address=get_remote_ip(request))
token_pair = AccountService.login(
account=account,
ip_address=extract_remote_ip(request),
)

return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
return redirect(
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
)


def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
Expand Down
4 changes: 2 additions & 2 deletions api/controllers/console/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from flask_restful import Resource, reqparse

from configs import dify_config
from libs.helper import StrLen, email, get_remote_ip
from libs.helper import StrLen, email, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup
from services.account_service import RegisterService, TenantService
Expand Down Expand Up @@ -46,7 +46,7 @@ def post(self):

# setup
RegisterService.setup(
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
)

return {"result": "success"}, 201
Expand Down
2 changes: 1 addition & 1 deletion api/libs/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def generate_string(n):
return result


def get_remote_ip(request) -> str:
def extract_remote_ip(request) -> str:
if request.headers.get("CF-Connecting-IP"):
return request.headers.get("Cf-Connecting-Ip")
elif request.headers.getlist("X-Forwarded-For"):
Expand Down
102 changes: 80 additions & 22 deletions api/services/account_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from hashlib import sha256
from typing import Any, Optional

from pydantic import BaseModel
from sqlalchemy import func
from werkzeug.exceptions import Unauthorized

Expand Down Expand Up @@ -53,13 +54,43 @@
from tasks.mail_reset_password_task import send_reset_password_mail_task


class TokenPair(BaseModel):
access_token: str
refresh_token: str


REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30)


class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
)
LOGIN_MAX_ERROR_LIMITS = 5

@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
return f"{REFRESH_TOKEN_PREFIX}{refresh_token}"

@staticmethod
def _get_account_refresh_token_key(account_id: str) -> str:
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"

@staticmethod
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
redis_client.setex(
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
)

@staticmethod
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))

@staticmethod
def load_user(user_id: str) -> None | Account:
account = Account.query.filter_by(id=user_id).first()
Expand All @@ -69,9 +100,7 @@ def load_user(user_id: str) -> None | Account:
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
raise Unauthorized("Account is banned or closed.")

current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(
account_id=account.id, current=True
).first()
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
if current_tenant:
account.current_tenant_id = current_tenant.tenant_id
else:
Expand All @@ -92,10 +121,13 @@ def load_user(user_id: str) -> None | Account:
return account

@staticmethod
def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
def get_account_jwt_token(account: Account) -> str:
payload = {
"user_id": account.id,
"exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
"exp": int(
datetime.now(timezone.utc).timestamp()
+ timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES).total_seconds()
),
"iss": dify_config.EDITION,
"sub": "Console API Passport",
}
Expand Down Expand Up @@ -254,33 +286,56 @@ def update_account(account, **kwargs):
return account

@staticmethod
def update_last_login(account: Account, *, ip_address: str) -> None:
def update_login_info(account: Account, *, ip_address: str) -> None:
"""Update last login time and ip"""
account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
account.last_login_ip = ip_address
db.session.add(account)
db.session.commit()

@staticmethod
def login(account: Account, *, ip_address: Optional[str] = None):
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
if ip_address:
AccountService.update_last_login(account, ip_address=ip_address)
exp = timedelta(days=30)
AccountService.update_login_info(account=account, ip_address=ip_address)
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
db.session.commit()
token = AccountService.get_account_jwt_token(account, exp=exp)
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds()))
return token

access_token = AccountService.get_account_jwt_token(account=account)
refresh_token = _generate_refresh_token()

AccountService._store_refresh_token(refresh_token, account.id)

return TokenPair(access_token=access_token, refresh_token=refresh_token)

@staticmethod
def logout(*, account: Account, token: str):
redis_client.delete(_get_login_cache_key(account_id=account.id, token=token))
def logout(*, account: Account) -> None:
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
if refresh_token:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)

@staticmethod
def load_logged_in_account(*, account_id: str, token: str):
if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)):
return None
def refresh_token(refresh_token: str) -> TokenPair:
# Verify the refresh token
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
if not account_id:
raise ValueError("Invalid refresh token")

account = AccountService.load_user(account_id.decode("utf-8"))
if not account:
raise ValueError("Invalid account")

# Generate new access token and refresh token
new_access_token = AccountService.get_account_jwt_token(account)
new_refresh_token = _generate_refresh_token()

AccountService._delete_refresh_token(refresh_token, account.id)
AccountService._store_refresh_token(new_refresh_token, account.id)

return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)

@staticmethod
def load_logged_in_account(*, account_id: str):
return AccountService.load_user(account_id)

@classmethod
Expand Down Expand Up @@ -421,10 +476,6 @@ def is_email_send_ip_limit(ip_address: str):
return False


def _get_login_cache_key(*, account_id: str, token: str):
return f"account_login:{account_id}:{token}"


class TenantService:
@staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False) -> Tenant:
Expand Down Expand Up @@ -822,7 +873,9 @@ def revoke_token(cls, workspace_id: str, email: str, token: str):
redis_client.delete(cls._get_invitation_token_key(token))

@classmethod
def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[dict[str, Any]]:
def get_invitation_if_token_valid(
cls, workspace_id: str | None, email: str, token: str
) -> Optional[dict[str, Any]]:
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
Expand Down Expand Up @@ -883,3 +936,8 @@ def _get_invitation_by_token(

invitation = json.loads(data)
return invitation


def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length)
return token
2 changes: 1 addition & 1 deletion api/services/feature_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,4 @@ def _fulfill_params_from_enterprise(cls, features):
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
features.enable_email_code_login = enterprise_info["enable_email_code_login"]
features.enable_email_password_login = enterprise_info["enable_email_password_login"]
features.enable_email_password_login = enterprise_info["enable_email_password_login"]
3 changes: 3 additions & 0 deletions docker/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ MIGRATION_ENABLED=true
# The default value is 300 seconds.
FILES_ACCESS_TIMEOUT=300

# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
APP_MAX_ACTIVE_REQUESTS=0

Expand Down
Loading

0 comments on commit 49197f1

Please sign in to comment.