Skip to content

Commit

Permalink
Merge pull request #56 from mynhardtburger/bug-fix-register_model_con…
Browse files Browse the repository at this point in the history
…nection

Bugfix - Don't register connection for local TGIS instances
  • Loading branch information
gabe-l-hart authored Jun 7, 2024
2 parents ddbccd3 + 30d5330 commit e66e014
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
23 changes: 23 additions & 0 deletions caikit_tgis_backend/tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def register_model_connection(
"""
Register a remote model connection.
If a local TGIS instance is maintained, do nothing.
If the model connection is already registered, do nothing.
Otherwise create and register the model connection using the TGISBackend's
Expand All @@ -198,7 +200,20 @@ def register_model_connection(
If `fill_with_defaults == True`, missing keys in `conn_cfg` will be populated
with defaults from the TGISBackend's config connection.
"""
# Don't attempt registering a remote model if running local TGIS instance
if self.local_tgis:
log.debug(
"<TGB99277346D> Running a local TGIS instance... won't register a "
"remote model connection"
)
return

if model_id in self._model_connections:
log.debug(
"<TGB08621956D> remote model connection for model %s already exists... "
"nothing to register",
model_id,
)
return # Model connection exists --> do nothing

# Craft new connection config
Expand All @@ -211,6 +226,10 @@ def register_model_connection(
new_conn_cfg.update(conn_cfg)

# Create model connection
error.value_check(
"<TGB17891341E>", new_conn_cfg, "TGISConnection config is empty"
)

model_conn = TGISConnection.from_config(model_id, new_conn_cfg)

error.value_check("<TGB81270235E>", model_conn is not None)
Expand All @@ -219,6 +238,10 @@ def register_model_connection(
if self._test_connections:
model_conn = self._test_connection(model_conn)
if model_conn is not None:
log.debug(
"<TGB16640078D> Registering new remote model connection for %s",
model_id,
)
self._safely_update_state(model_id, model_conn, new_conn_cfg)

def get_client(self, model_id: str) -> generation_pb2_grpc.GenerationServiceStub:
Expand Down
18 changes: 18 additions & 0 deletions tests/test_tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,24 @@ def test_tgis_backend_register_model_connection(
assert tgis_be._base_connection_cfg == backup_base_cfg


def test_tgis_backend_register_model_connection_local():
tgis_be = TGISBackend()

# Confirm marked as local TGIS instance with no base connection config
assert tgis_be.local_tgis
assert not tgis_be._base_connection_cfg
assert not tgis_be._model_connections
assert not tgis_be._remote_models_cfg

# Register action should do nothing
tgis_be.register_model_connection("should do nothing")

# Confirm nothing was done
assert not tgis_be._base_connection_cfg
assert not tgis_be._model_connections
assert not tgis_be._remote_models_cfg


## Failure Tests ###############################################################


Expand Down

0 comments on commit e66e014

Please sign in to comment.