Skip to content

Commit

Permalink
Introduce DBID resolution context
Browse files Browse the repository at this point in the history
Summary: Create a context manager for resolving DBIDs to allow resolving a DBID more than once. This will be useful when writing multiple runs to the database, since it means that objects shared between runs will still be written out once for each run saved.

Reviewed By: fahndrich

Differential Revision: D67112789

fbshipit-source-id: 1f6438394bc8eb17bb0681a6950f59c2e9f9377a
  • Loading branch information
scottblaha authored and facebook-github-bot committed Dec 13, 2024
1 parent b1bb482 commit f829638
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
32 changes: 31 additions & 1 deletion sapp/db_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,22 @@

from __future__ import annotations

import contextlib
import logging
from collections import namedtuple
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union
from typing import (
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
Type,
Union,
)

from munch import Munch
from sqlalchemy import Column, exc, inspect, String, tuple_, types
Expand Down Expand Up @@ -152,6 +165,23 @@ def load_dialect_impl(self, dialect: Dialect):
return self.impl


@contextlib.contextmanager
def dbid_resolution_context() -> Generator[None, None, None]:
"""Track the resolution of DBIDs and unresolve them on exiting the context."""

def resolve(self: DBID, id: Union[int, None, DBID], is_new: bool = True) -> DBID:
dbids.append((self, getattr(self, "_id", None), getattr(self, "is_new", None)))
return old_resolve(self, id, is_new)

dbids: List[Tuple[DBID, Union[int, None, DBID], bool]] = []
old_resolve: Callable[[DBID, Union[int, None, DBID], bool], DBID] = DBID.resolve
DBID.resolve = resolve # pyre-ignore[8] Pyre doesn't like patching methods
yield
DBID.resolve = old_resolve
for dbid, id, is_new in reversed(dbids):
dbid.resolve(id, is_new)


class PrepareMixin:
@classmethod
def prepare(
Expand Down
19 changes: 18 additions & 1 deletion sapp/tests/db_support_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from unittest import TestCase

from ..db_support import DBID
from ..db_support import DBID, dbid_resolution_context


class DBSupportTest(TestCase):
Expand Down Expand Up @@ -38,3 +38,20 @@ def test_dbid_reassign_after_resolved(self) -> None:
def test_dbid_resolved_to_none(self) -> None:
primary_key = DBID()
self.assertEqual(None, primary_key.resolved())

def test_dbid_resolution_context(self) -> None:
primary_key = DBID()
foreign_key = DBID(primary_key)
with dbid_resolution_context():
primary_key.resolve(1)
self.assertEqual(primary_key.resolved(), 1)
self.assertEqual(foreign_key.resolved(), 1)
self.assertIsNone(primary_key.resolved())
self.assertIsNone(foreign_key.resolved())
primary_key.resolve(2)
with dbid_resolution_context():
primary_key.resolve(3)
self.assertEqual(primary_key.resolved(), 3)
self.assertEqual(foreign_key.resolved(), 3)
self.assertEqual(primary_key.resolved(), 2)
self.assertEqual(foreign_key.resolved(), 2)

0 comments on commit f829638

Please sign in to comment.