Skip to content

Commit

Permalink
Continuing flask-support for flask-sqlalchemy 3
Browse files Browse the repository at this point in the history
  • Loading branch information
Harshit Gupta committed Aug 27, 2023
1 parent 74f37e2 commit f35377e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 83 deletions.
121 changes: 42 additions & 79 deletions flask_appbuilder/models/sqla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
import logging
import re

from flask_sqlalchemy import (
_QueryProperty,
DefaultMeta,
get_state,
SessionBase,
SignallingSession,
SQLAlchemy,
)
from sqlalchemy import orm
try:
from flask_sqlalchemy import DefaultMeta, Model as FSQLAlchemyModel
except ImportError:
from flask_sqlalchemy.model import ( # noqa
DefaultMeta,
Model as FSQLAlchemyModel,
)
from flask_sqlalchemy import SQLAlchemy


try:
from sqlalchemy.ext.declarative import as_declarative
from sqlalchemy.ext.declarative import as_declarative, declarative_base
except ImportError:
from sqlalchemy.ext.declarative.api import as_declarative
from sqlalchemy.ext.declarative.api import as_declarative # noqa

try:
from sqlalchemy.orm.util import identity_key # noqa
Expand All @@ -29,71 +29,6 @@
_camelcase_re = re.compile(r"([A-Z]+)(?=[a-z0-9])")


class CustomSignallingSession(SignallingSession):
"""
Custom Signaling Session to support SQLALchemy>=1.4 with flask-sqlalchemy 2.X
https://github.com/pallets/flask-sqlalchemy/issues/953
"""

def get_bind(self, mapper=None, *args, **kwargs):
"""Return the engine or connection for a given model or
table, using the ``__bind_key__`` if it is set.
Patch from https://github.com/pallets/flask-sqlalchemy/pull/1001
"""
# mapper is None if someone tries to just get a connection
if mapper is not None:
try:
# SA >= 1.3
persist_selectable = mapper.persist_selectable
except AttributeError:
# SA < 1.3
persist_selectable = mapper.mapped_table
info = getattr(persist_selectable, "info", {})
bind_key = info.get("bind_key")
if bind_key is not None:
state = get_state(self.app)
return state.db.get_engine(self.app, bind=bind_key)
return SessionBase.get_bind(self, mapper, *args, **kwargs)


class SQLA(SQLAlchemy):
"""
This is a child class of flask_SQLAlchemy
It's purpose is to override the declarative base of the original
package. So that it is bound to F.A.B. Model class allowing the dev
to be in the same namespace of the security tables (and others)
and can use AuditMixin class alike.
Use it and configure it just like flask_SQLAlchemy
"""

def make_declarative_base(self, model, metadata=None):
base = Model
base.query = _QueryProperty(self)
return base

def get_tables_for_bind(self, bind=None):
"""Returns a list of all tables relevant for a bind."""
result = []
tables = Model.metadata.tables
for key in tables:
if tables[key].info.get("bind_key") == bind:
result.append(tables[key])
return result

def create_session(self, options):
"""
Custom Session factory to support SQLALchemy>=1.4 with flask-sqlalchemy 2.X
https://github.com/pallets/flask-sqlalchemy/issues/953
:param options: dict of keyword arguments passed to session class
"""

return orm.sessionmaker(class_=CustomSignallingSession, db=self, **options)


class ModelDeclarativeMeta(DefaultMeta):
"""
Base Model declarative meta for all Models definitions.
Expand All @@ -102,8 +37,7 @@ class ModelDeclarativeMeta(DefaultMeta):
"""


@as_declarative(name="Model", metaclass=ModelDeclarativeMeta)
class Model(object):
class BaseModel:
"""
Use this class has the base for your models,
it will define your table names automatically
Expand Down Expand Up @@ -132,7 +66,36 @@ def to_json(self):
return result


Model = declarative_base(cls=BaseModel, metaclass=ModelDeclarativeMeta, name="Model")


class SQLA(SQLAlchemy):
"""
This is a child class of flask_SQLAlchemy
It's purpose is to override the declarative base of the original
package. So that it is bound to F.A.B. Model class allowing the dev
to be in the same namespace of the security tables (and others)
and can use AuditMixin class alike.
Use it and configure it just like flask_SQLAlchemy
"""

def __init__(self, *args, **kwargs):
model_class = kwargs.pop("model_class", Model)

super().__init__(*args, model_class=model_class, **kwargs)

def get_tables_for_bind(self, bind=None):
"""Returns a list of all tables relevant for a bind."""
result = []
tables = Model.metadata.tables
for key in tables:
if tables[key].info.get("bind_key") == bind:
result.append(tables[key])
return result


"""
This is for retro compatibility
"""
Base = Model
Base = Model
5 changes: 3 additions & 2 deletions flask_appbuilder/security/sqla/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import datetime

from flask import g
from flask_appbuilder import Model
from flask_appbuilder._compat import as_unicode

from sqlalchemy import (
Boolean,
Column,
Expand All @@ -15,8 +18,6 @@
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import backref, relationship

from ... import Model
from ..._compat import as_unicode

_dont_audit = False

Expand Down
6 changes: 4 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ flask-limiter==3.2.0
# via Flask-AppBuilder (setup.py)
flask-login==0.6.0
# via Flask-AppBuilder (setup.py)
flask-sqlalchemy==2.5.1
flask-sqlalchemy==3.0.2
# via Flask-AppBuilder (setup.py)
flask-wtf==1.0.1
# via Flask-AppBuilder (setup.py)
greenlet==1.1.2
# via sqlalchemy
idna==3.3
# via email-validator
importlib-metadata==6.6.0
Expand Down Expand Up @@ -97,7 +99,7 @@ pytz==2021.1
# via
# babel
# flask-babel
pyyaml==5.4.1
pyyaml==5.3.1
# via apispec
rich==12.6.0
# via flask-limiter
Expand Down

0 comments on commit f35377e

Please sign in to comment.