diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index a2162aa0ba..aaed2e528a 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -331,10 +331,14 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 38 +LIBPATCH = 39 PYDEPS = ["ops>=2.0.0"] +# Starting from what LIBPATCH number to apply legacy solutions +# v0.17 was the last version without secrets +LEGACY_SUPPORT_FROM = 17 + logger = logging.getLogger(__name__) Diff = namedtuple("Diff", "added changed deleted") @@ -351,36 +355,16 @@ def _on_topic_requested(self, event: TopicRequestedEvent): GROUP_MAPPING_FIELD = "secret_group_mapping" GROUP_SEPARATOR = "@" +MODEL_ERRORS = { + "not_leader": "this unit is not the leader", + "no_label_and_uri": "ERROR either URI or label should be used for getting an owned secret but not both", + "owner_no_refresh": "ERROR secret owner cannot use --refresh", +} -class SecretGroup(str): - """Secret groups specific type.""" - - -class SecretGroupsAggregate(str): - """Secret groups with option to extend with additional constants.""" - - def __init__(self): - self.USER = SecretGroup("user") - self.TLS = SecretGroup("tls") - self.EXTRA = SecretGroup("extra") - - def __setattr__(self, name, value): - """Setting internal constants.""" - if name in self.__dict__: - raise RuntimeError("Can't set constant!") - else: - super().__setattr__(name, SecretGroup(value)) - - def groups(self) -> list: - """Return the list of stored SecretGroups.""" - return list(self.__dict__.values()) - - def get_group(self, group: str) -> Optional[SecretGroup]: - """If the input str translates to a group name, return that.""" - return SecretGroup(group) if group in self.groups() else None - -SECRET_GROUPS = SecretGroupsAggregate() +############################################################################## +# Exceptions +############################################################################## class DataInterfacesError(Exception): @@ -407,6 +391,15 @@ class IllegalOperationError(DataInterfacesError): """To be used when an operation is not allowed to be performed.""" +############################################################################## +# Global helpers / utilities +############################################################################## + +############################################################################## +# Databag handling and comparison methods +############################################################################## + + def get_encoded_dict( relation: Relation, member: Union[Unit, Application], field: str ) -> Optional[Dict[str, str]]: @@ -482,6 +475,11 @@ def diff(event: RelationChangedEvent, bucket: Optional[Union[Unit, Application]] return Diff(added, changed, deleted) +############################################################################## +# Module decorators +############################################################################## + + def leader_only(f): """Decorator to ensure that only leader can perform given operation.""" @@ -536,6 +534,36 @@ def wrapper(self, *args, **kwargs): return wrapper +def legacy_apply_from_version(version: int) -> Callable: + """Decorator to decide whether to apply a legacy function or not. 
+ + Based on LEGACY_SUPPORT_FROM module variable value, the importer charm may only want + to apply legacy solutions starting from a specific LIBPATCH. + + NOTE: All 'legacy' functions have to be defined and called in a way that they return `None`. + This results in cleaner and more secure execution flows in case the function may be disabled. + This requirement implicitly means that legacy functions change the internal state strictly, + don't return information. + """ + + def decorator(f: Callable[..., None]): + """Signature is ensuring None return value.""" + f.legacy_version = version + + def wrapper(self, *args, **kwargs) -> None: + if version >= LEGACY_SUPPORT_FROM: + return f(self, *args, **kwargs) + + return wrapper + + return decorator + + +############################################################################## +# Helper classes +############################################################################## + + class Scope(Enum): """Peer relations scope.""" @@ -543,9 +571,35 @@ class Scope(Enum): UNIT = "unit" -################################################################################ -# Secrets internal caching -################################################################################ +class SecretGroup(str): + """Secret groups specific type.""" + + +class SecretGroupsAggregate(str): + """Secret groups with option to extend with additional constants.""" + + def __init__(self): + self.USER = SecretGroup("user") + self.TLS = SecretGroup("tls") + self.EXTRA = SecretGroup("extra") + + def __setattr__(self, name, value): + """Setting internal constants.""" + if name in self.__dict__: + raise RuntimeError("Can't set constant!") + else: + super().__setattr__(name, SecretGroup(value)) + + def groups(self) -> list: + """Return the list of stored SecretGroups.""" + return list(self.__dict__.values()) + + def get_group(self, group: str) -> Optional[SecretGroup]: + """If the input str translates to a group name, return that.""" + return SecretGroup(group) if group in self.groups() else None + + +SECRET_GROUPS = SecretGroupsAggregate() class CachedSecret: @@ -554,6 +608,8 @@ class CachedSecret: The data structure is precisely re-using/simulating as in the actual Secret Storage """ + KNOWN_MODEL_ERRORS = [MODEL_ERRORS["no_label_and_uri"], MODEL_ERRORS["owner_no_refresh"]] + def __init__( self, model: Model, @@ -571,6 +627,95 @@ def __init__( self.legacy_labels = legacy_labels self.current_label = None + @property + def meta(self) -> Optional[Secret]: + """Getting cached secret meta-information.""" + if not self._secret_meta: + if not (self._secret_uri or self.label): + return + + try: + self._secret_meta = self._model.get_secret(label=self.label) + except SecretNotFoundError: + # Falling back to seeking for potential legacy labels + self._legacy_compat_find_secret_by_old_label() + + # If still not found, to be checked by URI, to be labelled with the proposed label + if not self._secret_meta and self._secret_uri: + self._secret_meta = self._model.get_secret(id=self._secret_uri, label=self.label) + return self._secret_meta + + ########################################################################## + # Backwards compatibility / Upgrades + ########################################################################## + # These functions are used to keep backwards compatibility on rolling upgrades + # Policy: + # All data is kept intact until the first write operation. (This allows a minimal + # grace period during which rollbacks are fully safe. For more info see the spec.) 
+    # All data involves:
+    #   - databag contents
+    #   - secrets content
+    #   - secret labels (!!!)
+    # Legacy functions must return None, and leave an equally consistent state whether
+    # they are executed or skipped (as a sufficiently new execution environment may
+    # not require them)
+
+    # Compatibility
+
+    @legacy_apply_from_version(34)
+    def _legacy_compat_find_secret_by_old_label(self) -> None:
+        """Compatibility function to find a secret by a legacy label.
+
+        This functionality is typically needed when secret labels changed over an upgrade.
+        Until the first write operation, we need to maintain data as it was, including keeping
+        the old secret label. To keep track of the old label currently used to access
+        the secret, an additional 'current_label' field is defined.
+        """
+        for label in self.legacy_labels:
+            try:
+                self._secret_meta = self._model.get_secret(label=label)
+            except SecretNotFoundError:
+                pass
+            else:
+                if label != self.label:
+                    self.current_label = label
+                return
+
+    # Migrations
+
+    @legacy_apply_from_version(34)
+    def _legacy_migration_to_new_label_if_needed(self) -> None:
+        """Helper function to re-create the secret with a different label.
+
+        Juju does not provide a way to change secret labels.
+        Thus, whenever moving from a secrets version that involves secret label changes,
+        we "re-create" the existing secret, and attach the new label to the new
+        secret, to be used from then on.
+
+        Note: we replace the old secret with a new one "in place", as we can't
+        easily switch the containing SecretCache structure to point to a new secret.
+        Instead we are changing the 'self' (CachedSecret) object to point to the
+        new instance.
+        """
+        if not self.current_label or not (self.meta and self._secret_meta):
+            return
+
+        # Create a new secret with the new label
+        content = self._secret_meta.get_content()
+        self._secret_uri = None
+
+        # It would be nice to be able to check whether we are the owners of the secret...
+ try: + self._secret_meta = self.add_secret(content, label=self.label) + except ModelError as err: + if MODEL_ERRORS["not_leader"] not in str(err): + raise + self.current_label = None + + ########################################################################## + # Public functions + ########################################################################## + def add_secret( self, content: Dict[str, str], @@ -593,28 +738,6 @@ def add_secret( self._secret_meta = secret return self._secret_meta - @property - def meta(self) -> Optional[Secret]: - """Getting cached secret meta-information.""" - if not self._secret_meta: - if not (self._secret_uri or self.label): - return - - for label in [self.label] + self.legacy_labels: - try: - self._secret_meta = self._model.get_secret(label=label) - except SecretNotFoundError: - pass - else: - if label != self.label: - self.current_label = label - break - - # If still not found, to be checked by URI, to be labelled with the proposed label - if not self._secret_meta and self._secret_uri: - self._secret_meta = self._model.get_secret(id=self._secret_uri, label=self.label) - return self._secret_meta - def get_content(self) -> Dict[str, str]: """Getting cached secret content.""" if not self._secret_content: @@ -624,35 +747,14 @@ def get_content(self) -> Dict[str, str]: except (ValueError, ModelError) as err: # https://bugs.launchpad.net/juju/+bug/2042596 # Only triggered when 'refresh' is set - known_model_errors = [ - "ERROR either URI or label should be used for getting an owned secret but not both", - "ERROR secret owner cannot use --refresh", - ] if isinstance(err, ModelError) and not any( - msg in str(err) for msg in known_model_errors + msg in str(err) for msg in self.KNOWN_MODEL_ERRORS ): raise # Due to: ValueError: Secret owner cannot use refresh=True self._secret_content = self.meta.get_content() return self._secret_content - def _move_to_new_label_if_needed(self): - """Helper function to re-create the secret with a different label.""" - if not self.current_label or not (self.meta and self._secret_meta): - return - - # Create a new secret with the new label - content = self._secret_meta.get_content() - self._secret_uri = None - - # I wish we could just check if we are the owners of the secret... - try: - self._secret_meta = self.add_secret(content, label=self.label) - except ModelError as err: - if "this unit is not the leader" not in str(err): - raise - self.current_label = None - def set_content(self, content: Dict[str, str]) -> None: """Setting cached secret content.""" if not self.meta: @@ -663,7 +765,7 @@ def set_content(self, content: Dict[str, str]) -> None: return if content: - self._move_to_new_label_if_needed() + self._legacy_migration_to_new_label_if_needed() self.meta.set_content(content) self._secret_content = content else: @@ -926,6 +1028,23 @@ def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" raise NotImplementedError + # Optional overrides + + def _legacy_apply_on_fetch(self) -> None: + """This function should provide a list of compatibility functions to be applied when fetching (legacy) data.""" + pass + + def _legacy_apply_on_update(self, fields: List[str]) -> None: + """This function should provide a list of compatibility functions to be applied when writing data. + + Since data may be at a legacy version, migration may be mandatory. 
+ """ + pass + + def _legacy_apply_on_delete(self, fields: List[str]) -> None: + """This function should provide a list of compatibility functions to be applied when deleting (legacy) data.""" + pass + # Internal helper methods @staticmethod @@ -1178,6 +1297,16 @@ def get_relation(self, relation_name, relation_id) -> Relation: return relation + def get_secret_uri(self, relation: Relation, group: SecretGroup) -> Optional[str]: + """Get the secret URI for the corresponding group.""" + secret_field = self._generate_secret_field_name(group) + return relation.data[self.component].get(secret_field) + + def set_secret_uri(self, relation: Relation, group: SecretGroup, secret_uri: str) -> None: + """Set the secret URI for the corresponding group.""" + secret_field = self._generate_secret_field_name(group) + relation.data[self.component][secret_field] = secret_uri + def fetch_relation_data( self, relation_ids: Optional[List[int]] = None, @@ -1194,6 +1323,8 @@ def fetch_relation_data( a dict of the values stored in the relation data bag for all relation instances (indexed by the relation ID). """ + self._legacy_apply_on_fetch() + if not relation_name: relation_name = self.relation_name @@ -1232,6 +1363,8 @@ def fetch_my_relation_data( NOTE: Since only the leader can read the relation's 'this_app'-side Application databag, the functionality is limited to leaders """ + self._legacy_apply_on_fetch() + if not relation_name: relation_name = self.relation_name @@ -1263,6 +1396,8 @@ def fetch_my_relation_field( @leader_only def update_relation_data(self, relation_id: int, data: dict) -> None: """Update the data within the relation.""" + self._legacy_apply_on_update(list(data.keys())) + relation_name = self.relation_name relation = self.get_relation(relation_name, relation_id) return self._update_relation_data(relation, data) @@ -1270,6 +1405,8 @@ def update_relation_data(self, relation_id: int, data: dict) -> None: @leader_only def delete_relation_data(self, relation_id: int, fields: List[str]) -> None: """Remove field from the relation.""" + self._legacy_apply_on_delete(fields) + relation_name = self.relation_name relation = self.get_relation(relation_name, relation_id) return self._delete_relation_data(relation, fields) @@ -1336,8 +1473,7 @@ def _add_relation_secret( uri_to_databag=True, ) -> bool: """Add a new Juju Secret that will be registered in the relation databag.""" - secret_field = self._generate_secret_field_name(group_mapping) - if uri_to_databag and relation.data[self.component].get(secret_field): + if uri_to_databag and self.get_secret_uri(relation, group_mapping): logging.error("Secret for relation %s already exists, not adding again", relation.id) return False @@ -1348,7 +1484,7 @@ def _add_relation_secret( # According to lint we may not have a Secret ID if uri_to_databag and secret.meta and secret.meta.id: - relation.data[self.component][secret_field] = secret.meta.id + self.set_secret_uri(relation, group_mapping, secret.meta.id) # Return the content that was added return True @@ -1449,8 +1585,7 @@ def _get_relation_secret( if not relation: return - secret_field = self._generate_secret_field_name(group_mapping) - if secret_uri := relation.data[self.local_app].get(secret_field): + if secret_uri := self.get_secret_uri(relation, group_mapping): return self.secrets.get(label, secret_uri) def _fetch_specific_relation_data( @@ -1603,11 +1738,10 @@ def _register_secrets_to_relation(self, relation: Relation, params_name_list: Li for group in SECRET_GROUPS.groups(): secret_field = 
self._generate_secret_field_name(group) - if secret_field in params_name_list: - if secret_uri := relation.data[relation.app].get(secret_field): - self._register_secret_to_relation( - relation.name, relation.id, secret_uri, group - ) + if secret_field in params_name_list and ( + secret_uri := self.get_secret_uri(relation, group) + ): + self._register_secret_to_relation(relation.name, relation.id, secret_uri, group) def _is_resource_created_for_relation(self, relation: Relation) -> bool: if not relation.app: @@ -1618,6 +1752,17 @@ def _is_resource_created_for_relation(self, relation: Relation) -> bool: ) return bool(data.get("username")) and bool(data.get("password")) + # Public functions + + def get_secret_uri(self, relation: Relation, group: SecretGroup) -> Optional[str]: + """Getting relation secret URI for the corresponding Secret Group.""" + secret_field = self._generate_secret_field_name(group) + return relation.data[relation.app].get(secret_field) + + def set_secret_uri(self, relation: Relation, group: SecretGroup, uri: str) -> None: + """Setting relation secret URI is not possible for a Requirer.""" + raise NotImplementedError("Requirer can not change the relation secret URI.") + def is_resource_created(self, relation_id: Optional[int] = None) -> bool: """Check if the resource has been created. @@ -1768,7 +1913,6 @@ def __init__( secret_field_name: Optional[str] = None, deleted_label: Optional[str] = None, ): - """Manager of base client relations.""" RequirerData.__init__( self, model, @@ -1779,6 +1923,11 @@ def __init__( self.secret_field_name = secret_field_name if secret_field_name else self.SECRET_FIELD_NAME self.deleted_label = deleted_label self._secret_label_map = {} + + # Legacy information holders + self._legacy_labels = [] + self._legacy_secret_uri = None + # Secrets that are being dynamically added within the scope of this event handler run self._new_secrets = [] self._additional_secret_group_mapping = additional_secret_group_mapping @@ -1853,10 +2002,12 @@ def set_secret( value: The string value of the secret group_mapping: The name of the "secret group", in case the field is to be added to an existing secret """ + self._legacy_apply_on_update([field]) + full_field = self._field_to_internal_name(field, group_mapping) if self.secrets_enabled and full_field not in self.current_secret_fields: self._new_secrets.append(full_field) - if self._no_group_with_databag(field, full_field): + if self.valid_field_pattern(field, full_field): self.update_relation_data(relation_id, {full_field: value}) # Unlike for set_secret(), there's no harm using this operation with static secrets @@ -1869,6 +2020,8 @@ def get_secret( group_mapping: Optional[SecretGroup] = None, ) -> Optional[str]: """Public interface method to fetch secrets only.""" + self._legacy_apply_on_fetch() + full_field = self._field_to_internal_name(field, group_mapping) if ( self.secrets_enabled @@ -1876,7 +2029,7 @@ def get_secret( and field not in self.current_secret_fields ): return - if self._no_group_with_databag(field, full_field): + if self.valid_field_pattern(field, full_field): return self.fetch_my_relation_field(relation_id, full_field) @dynamic_secrets_only @@ -1887,14 +2040,19 @@ def delete_secret( group_mapping: Optional[SecretGroup] = None, ) -> Optional[str]: """Public interface method to delete secrets only.""" + self._legacy_apply_on_delete([field]) + full_field = self._field_to_internal_name(field, group_mapping) if self.secrets_enabled and full_field not in self.current_secret_fields: 
logger.warning(f"Secret {field} from group {group_mapping} was not found") return - if self._no_group_with_databag(field, full_field): + + if self.valid_field_pattern(field, full_field): self.delete_relation_data(relation_id, [full_field]) + ########################################################################## # Helpers + ########################################################################## @staticmethod def _field_to_internal_name(field: str, group: Optional[SecretGroup]) -> str: @@ -1936,10 +2094,69 @@ def _content_for_secret_group( if k in self.secret_fields } - # Backwards compatibility + def valid_field_pattern(self, field: str, full_field: str) -> bool: + """Check that no secret group is attempted to be used together without secrets being enabled. + + Secrets groups are impossible to use with versions that are not yet supporting secrets. + """ + if not self.secrets_enabled and full_field != field: + logger.error( + f"Can't access {full_field}: no secrets available (i.e. no secret groups either)." + ) + return False + return True + + ########################################################################## + # Backwards compatibility / Upgrades + ########################################################################## + # These functions are used to keep backwards compatibility on upgrades + # Policy: + # All data is kept intact until the first write operation. (This allows a minimal + # grace period during which rollbacks are fully safe. For more info see spec.) + # All data involves: + # - databag + # - secrets content + # - secret labels (!!!) + # Legacy functions must return None, and leave an equally consistent state whether + # they are executed or skipped (as a high enough versioned execution environment may + # not require so) + + # Full legacy stack for each operation + + def _legacy_apply_on_fetch(self) -> None: + """All legacy functions to be applied on fetch.""" + relation = self._model.relations[self.relation_name][0] + self._legacy_compat_generate_prev_labels() + self._legacy_compat_secret_uri_from_databag(relation) + + def _legacy_apply_on_update(self, fields) -> None: + """All legacy functions to be applied on update.""" + relation = self._model.relations[self.relation_name][0] + self._legacy_compat_generate_prev_labels() + self._legacy_compat_secret_uri_from_databag(relation) + self._legacy_migration_remove_secret_from_databag(relation, fields) + self._legacy_migration_remove_secret_field_name_from_databag(relation) + + def _legacy_apply_on_delete(self, fields) -> None: + """All legacy functions to be applied on delete.""" + relation = self._model.relations[self.relation_name][0] + self._legacy_compat_generate_prev_labels() + self._legacy_compat_secret_uri_from_databag(relation) + self._legacy_compat_check_deleted_label(relation, fields) + + # Compatibility + + @legacy_apply_from_version(18) + def _legacy_compat_check_deleted_label(self, relation, fields) -> None: + """Helper function for legacy behavior. + + As long as https://bugs.launchpad.net/juju/+bug/2028094 wasn't fixed, + we did not delete fields but rather kept them in the secret with a string value + expressing invalidity. This function is maintainnig that behavior when needed. 
+ """ + if not self.deleted_label: + return - def _check_deleted_label(self, relation, fields) -> None: - """Helper function for legacy behavior.""" current_data = self.fetch_my_relation_data([relation.id], fields) if current_data is not None: # Check if the secret we wanna delete actually exists @@ -1952,7 +2169,43 @@ def _check_deleted_label(self, relation, fields) -> None: ", ".join(non_existent), ) - def _remove_secret_from_databag(self, relation, fields: List[str]) -> None: + @legacy_apply_from_version(18) + def _legacy_compat_secret_uri_from_databag(self, relation) -> None: + """Fetching the secret URI from the databag, in case stored there.""" + self._legacy_secret_uri = relation.data[self.component].get( + self._generate_secret_field_name(), None + ) + + @legacy_apply_from_version(34) + def _legacy_compat_generate_prev_labels(self) -> None: + """Generator for legacy secret label names, for backwards compatibility. + + Secret label is part of the data that MUST be maintained across rolling upgrades. + In case there may be a change on a secret label, the old label must be recognized + after upgrades, and left intact until the first write operation -- when we roll over + to the new label. + + This function keeps "memory" of previously used secret labels. + NOTE: Return value takes decorator into account -- all 'legacy' functions may return `None` + + v0.34 (rev69): Fixing issue https://github.com/canonical/data-platform-libs/issues/155 + meant moving from '.' (i.e. 'mysql.app', 'mysql.unit') + to labels '..' (like 'peer.mysql.app') + """ + if self._legacy_labels: + return + + result = [] + members = [self._model.app.name] + if self.scope: + members.append(self.scope.value) + result.append(f"{'.'.join(members)}") + self._legacy_labels = result + + # Migration + + @legacy_apply_from_version(18) + def _legacy_migration_remove_secret_from_databag(self, relation, fields: List[str]) -> None: """For Rolling Upgrades -- when moving from databag to secrets usage. Practically what happens here is to remove stuff from the databag that is @@ -1966,10 +2219,16 @@ def _remove_secret_from_databag(self, relation, fields: List[str]) -> None: if self._fetch_relation_data_without_secrets(self.component, relation, [field]): self._delete_relation_data_without_secrets(self.component, relation, [field]) - def _remove_secret_field_name_from_databag(self, relation) -> None: + @legacy_apply_from_version(18) + def _legacy_migration_remove_secret_field_name_from_databag(self, relation) -> None: """Making sure that the old databag URI is gone. This action should not be executed more than once. + + There was a phase (before moving secrets usage to libs) when charms saved the peer + secret URI to the databag, and used this URI from then on to retrieve their secret. + When upgrading to charm versions using this library, we need to add a label to the + secret and access it via label from than on, and remove the old traces from the databag. 
""" # Nothing to do if 'internal-secret' is not in the databag if not (relation.data[self.component].get(self._generate_secret_field_name())): @@ -1985,25 +2244,9 @@ def _remove_secret_field_name_from_databag(self, relation) -> None: # Databag reference to the secret URI can be removed, now that it's labelled relation.data[self.component].pop(self._generate_secret_field_name(), None) - def _previous_labels(self) -> List[str]: - """Generator for legacy secret label names, for backwards compatibility.""" - result = [] - members = [self._model.app.name] - if self.scope: - members.append(self.scope.value) - result.append(f"{'.'.join(members)}") - return result - - def _no_group_with_databag(self, field: str, full_field: str) -> bool: - """Check that no secret group is attempted to be used together with databag.""" - if not self.secrets_enabled and full_field != field: - logger.error( - f"Can't access {full_field}: no secrets available (i.e. no secret groups either)." - ) - return False - return True - + ########################################################################## # Event handlers + ########################################################################## def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" @@ -2013,7 +2256,9 @@ def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: """Event emitted when the secret has changed.""" pass + ########################################################################## # Overrides of Relation Data handling functions + ########################################################################## def _generate_secret_label( self, relation_name: str, relation_id: int, group_mapping: SecretGroup @@ -2050,13 +2295,14 @@ def _get_relation_secret( return label = self._generate_secret_label(relation_name, relation_id, group_mapping) - secret_uri = relation.data[self.component].get(self._generate_secret_field_name(), None) # URI or legacy label is only to applied when moving single legacy secret to a (new) label if group_mapping == SECRET_GROUPS.EXTRA: # Fetching the secret with fallback to URI (in case label is not yet known) # Label would we "stuck" on the secret in case it is found - return self.secrets.get(label, secret_uri, legacy_labels=self._previous_labels()) + return self.secrets.get( + label, self._legacy_secret_uri, legacy_labels=self._legacy_labels + ) return self.secrets.get(label) def _get_group_secret_contents( @@ -2086,7 +2332,6 @@ def _fetch_my_specific_relation_data( @either_static_or_dynamic_secrets def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" - self._remove_secret_from_databag(relation, list(data.keys())) _, normal_fields = self._process_secret_fields( relation, self.secret_fields, @@ -2095,7 +2340,6 @@ def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> Non data=data, uri_to_databag=False, ) - self._remove_secret_field_name_from_databag(relation) normal_content = {k: v for k, v in data.items() if k in normal_fields} self._update_relation_data_without_secrets(self.component, relation, normal_content) @@ -2104,8 +2348,6 @@ def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> Non def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: """Delete data available (directily or indirectly -- i.e. 
secrets) from the relation for owner/this_app.""" if self.secret_fields and self.deleted_label: - # Legacy, backwards compatibility - self._check_deleted_label(relation, fields) _, normal_fields = self._process_secret_fields( relation, @@ -2141,7 +2383,9 @@ def fetch_relation_field( "fetch_my_relation_data() and fetch_my_relation_field()" ) + ########################################################################## # Public functions -- inherited + ########################################################################## fetch_my_relation_data = Data.fetch_my_relation_data fetch_my_relation_field = Data.fetch_my_relation_field diff --git a/lib/charms/grafana_agent/v0/cos_agent.py b/lib/charms/grafana_agent/v0/cos_agent.py index 870ba62a17..582b70c079 100644 --- a/lib/charms/grafana_agent/v0/cos_agent.py +++ b/lib/charms/grafana_agent/v0/cos_agent.py @@ -206,19 +206,34 @@ def __init__(self, *args): ``` """ +import enum import json import logging +import socket from collections import namedtuple from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + List, + Literal, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) import pydantic from cosl import GrafanaDashboard, JujuTopology from cosl.rules import AlertRules from ops.charm import RelationChangedEvent from ops.framework import EventBase, EventSource, Object, ObjectEvents -from ops.model import Relation +from ops.model import ModelError, Relation from ops.testing import CharmType if TYPE_CHECKING: @@ -234,9 +249,9 @@ class _MetricsEndpointDict(TypedDict): LIBID = "dc15fa84cef84ce58155fb84f6c6213a" LIBAPI = 0 -LIBPATCH = 8 +LIBPATCH = 10 -PYDEPS = ["cosl", "pydantic < 2"] +PYDEPS = ["cosl", "pydantic"] DEFAULT_RELATION_NAME = "cos-agent" DEFAULT_PEER_RELATION_NAME = "peers" @@ -249,7 +264,207 @@ class _MetricsEndpointDict(TypedDict): SnapEndpoint = namedtuple("SnapEndpoint", "owner, name") -class CosAgentProviderUnitData(pydantic.BaseModel): +# Note: MutableMapping is imported from the typing module and not collections.abc +# because subscripting collections.abc.MutableMapping was added in python 3.9, but +# most of our charms are based on 20.04, which has python 3.8. 
+ +_RawDatabag = MutableMapping[str, str] + + +class TransportProtocolType(str, enum.Enum): + """Receiver Type.""" + + http = "http" + grpc = "grpc" + + +receiver_protocol_to_transport_protocol = { + "zipkin": TransportProtocolType.http, + "kafka": TransportProtocolType.http, + "tempo_http": TransportProtocolType.http, + "tempo_grpc": TransportProtocolType.grpc, + "otlp_grpc": TransportProtocolType.grpc, + "otlp_http": TransportProtocolType.http, + "jaeger_thrift_http": TransportProtocolType.http, +} + +_tracing_receivers_ports = { + # OTLP receiver: see + # https://github.com/open-telemetry/opentelemetry-collector/tree/v0.96.0/receiver/otlpreceiver + "otlp_http": 4318, + "otlp_grpc": 4317, + # Jaeger receiver: see + # https://github.com/open-telemetry/opentelemetry-collector-contrib/tree/v0.96.0/receiver/jaegerreceiver + "jaeger_grpc": 14250, + "jaeger_thrift_http": 14268, + # Zipkin receiver: see + # https://github.com/open-telemetry/opentelemetry-collector-contrib/tree/v0.96.0/receiver/zipkinreceiver + "zipkin": 9411, +} + +ReceiverProtocol = Literal["otlp_grpc", "otlp_http", "zipkin", "jaeger_thrift_http", "jaeger_grpc"] + + +class TracingError(Exception): + """Base class for custom errors raised by tracing.""" + + +class NotReadyError(TracingError): + """Raised by the provider wrapper if a requirer hasn't published the required data (yet).""" + + +class ProtocolNotRequestedError(TracingError): + """Raised if the user attempts to obtain an endpoint for a protocol it did not request.""" + + +class DataValidationError(TracingError): + """Raised when data validation fails on IPU relation data.""" + + +class AmbiguousRelationUsageError(TracingError): + """Raised when one wrongly assumes that there can only be one relation on an endpoint.""" + + +# TODO we want to eventually use `DatabagModel` from cosl but it likely needs a move to common package first +if int(pydantic.version.VERSION.split(".")[0]) < 2: # type: ignore + + class DatabagModel(pydantic.BaseModel): # type: ignore + """Base databag model.""" + + class Config: + """Pydantic config.""" + + # ignore any extra fields in the databag + extra = "ignore" + """Ignore any extra fields in the databag.""" + allow_population_by_field_name = True + """Allow instantiating this class by field name (instead of forcing alias).""" + + _NEST_UNDER = None + + @classmethod + def load(cls, databag: MutableMapping): + """Load this model from a Juju databag.""" + if cls._NEST_UNDER: + return cls.parse_obj(json.loads(databag[cls._NEST_UNDER])) + + try: + data = { + k: json.loads(v) + for k, v in databag.items() + # Don't attempt to parse model-external values + if k in {f.alias for f in cls.__fields__.values()} + } + except json.JSONDecodeError as e: + msg = f"invalid databag contents: expecting json. {databag}" + logger.error(msg) + raise DataValidationError(msg) from e + + try: + return cls.parse_raw(json.dumps(data)) # type: ignore + except pydantic.ValidationError as e: + msg = f"failed to validate databag: {databag}" + logger.debug(msg, exc_info=True) + raise DataValidationError(msg) from e + + def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True): + """Write the contents of this model to Juju databag. + + :param databag: the databag to write the data to. + :param clear: ensure the databag is cleared before writing it. 
+ """ + if clear and databag: + databag.clear() + + if databag is None: + databag = {} + + if self._NEST_UNDER: + databag[self._NEST_UNDER] = self.json(by_alias=True) + return databag + + dct = self.dict() + for key, field in self.__fields__.items(): # type: ignore + value = dct[key] + databag[field.alias or key] = json.dumps(value) + + return databag + +else: + from pydantic import ConfigDict + + class DatabagModel(pydantic.BaseModel): + """Base databag model.""" + + model_config = ConfigDict( + # ignore any extra fields in the databag + extra="ignore", + # Allow instantiating this class by field name (instead of forcing alias). + populate_by_name=True, + # Custom config key: whether to nest the whole datastructure (as json) + # under a field or spread it out at the toplevel. + _NEST_UNDER=None, # type: ignore + arbitrary_types_allowed=True, + ) + """Pydantic config.""" + + @classmethod + def load(cls, databag: MutableMapping): + """Load this model from a Juju databag.""" + nest_under = cls.model_config.get("_NEST_UNDER") # type: ignore + if nest_under: + return cls.model_validate(json.loads(databag[nest_under])) # type: ignore + + try: + data = { + k: json.loads(v) + for k, v in databag.items() + # Don't attempt to parse model-external values + if k in {(f.alias or n) for n, f in cls.__fields__.items()} + } + except json.JSONDecodeError as e: + msg = f"invalid databag contents: expecting json. {databag}" + logger.error(msg) + raise DataValidationError(msg) from e + + try: + return cls.model_validate_json(json.dumps(data)) # type: ignore + except pydantic.ValidationError as e: + msg = f"failed to validate databag: {databag}" + logger.debug(msg, exc_info=True) + raise DataValidationError(msg) from e + + def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True): + """Write the contents of this model to Juju databag. + + :param databag: the databag to write the data to. + :param clear: ensure the databag is cleared before writing it. + """ + if clear and databag: + databag.clear() + + if databag is None: + databag = {} + nest_under = self.model_config.get("_NEST_UNDER") + if nest_under: + databag[nest_under] = self.model_dump_json( # type: ignore + by_alias=True, + # skip keys whose values are default + exclude_defaults=True, + ) + return databag + + dct = self.model_dump() # type: ignore + for key, field in self.model_fields.items(): # type: ignore + value = dct[key] + if value == field.default: + continue + databag[field.alias or key] = json.dumps(value) + + return databag + + +class CosAgentProviderUnitData(DatabagModel): """Unit databag model for `cos-agent` relation.""" # The following entries are the same for all units of the same principal. @@ -267,13 +482,16 @@ class CosAgentProviderUnitData(pydantic.BaseModel): metrics_scrape_jobs: List[Dict] log_slots: List[str] + # Requested tracing protocols. + tracing_protocols: Optional[List[str]] = None + # when this whole datastructure is dumped into a databag, it will be nested under this key. # while not strictly necessary (we could have it 'flattened out' into the databag), # this simplifies working with the model. 
KEY: ClassVar[str] = "config" -class CosAgentPeersUnitData(pydantic.BaseModel): +class CosAgentPeersUnitData(DatabagModel): """Unit databag model for `peers` cos-agent machine charm peer relation.""" # We need the principal unit name and relation metadata to be able to render identifiers @@ -304,6 +522,83 @@ def app_name(self) -> str: return self.unit_name.split("/")[0] +if int(pydantic.version.VERSION.split(".")[0]) < 2: # type: ignore + + class ProtocolType(pydantic.BaseModel): # type: ignore + """Protocol Type.""" + + class Config: + """Pydantic config.""" + + use_enum_values = True + """Allow serializing enum values.""" + + name: str = pydantic.Field( + ..., + description="Receiver protocol name. What protocols are supported (and what they are called) " + "may differ per provider.", + examples=["otlp_grpc", "otlp_http", "tempo_http"], + ) + + type: TransportProtocolType = pydantic.Field( + ..., + description="The transport protocol used by this receiver.", + examples=["http", "grpc"], + ) + +else: + + class ProtocolType(pydantic.BaseModel): + """Protocol Type.""" + + model_config = pydantic.ConfigDict( + # Allow serializing enum values. + use_enum_values=True + ) + """Pydantic config.""" + + name: str = pydantic.Field( + ..., + description="Receiver protocol name. What protocols are supported (and what they are called) " + "may differ per provider.", + examples=["otlp_grpc", "otlp_http", "tempo_http"], + ) + + type: TransportProtocolType = pydantic.Field( + ..., + description="The transport protocol used by this receiver.", + examples=["http", "grpc"], + ) + + +class Receiver(pydantic.BaseModel): + """Specification of an active receiver.""" + + protocol: ProtocolType = pydantic.Field(..., description="Receiver protocol name and type.") + url: str = pydantic.Field( + ..., + description="""URL at which the receiver is reachable. If there's an ingress, it would be the external URL. + Otherwise, it would be the service's fqdn or internal IP. + If the protocol type is grpc, the url will not contain a scheme.""", + examples=[ + "http://traefik_address:2331", + "https://traefik_address:2331", + "http://tempo_public_ip:2331", + "https://tempo_public_ip:2331", + "tempo_public_ip:2331", + ], + ) + + +class CosAgentRequirerUnitData(DatabagModel): # noqa: D101 + """Application databag model for the COS-agent requirer.""" + + receivers: List[Receiver] = pydantic.Field( + ..., + description="List of all receivers enabled on the tracing provider.", + ) + + class COSAgentProvider(Object): """Integration endpoint wrapper for the provider side of the cos_agent interface.""" @@ -318,6 +613,7 @@ def __init__( log_slots: Optional[List[str]] = None, dashboard_dirs: Optional[List[str]] = None, refresh_events: Optional[List] = None, + tracing_protocols: Optional[List[str]] = None, *, scrape_configs: Optional[Union[List[dict], Callable]] = None, ): @@ -336,6 +632,7 @@ def __init__( in the form ["snap-name:slot", ...]. dashboard_dirs: Directory where the dashboards are stored. refresh_events: List of events on which to refresh relation data. + tracing_protocols: List of protocols that the charm will be using for sending traces. scrape_configs: List of standard scrape_configs dicts or a callable that returns the list in case the configs need to be generated dynamically. The contents of this list will be merged with the contents of `metrics_endpoints`. 
@@ -353,6 +650,8 @@ def __init__(
         self._log_slots = log_slots or []
         self._dashboard_dirs = dashboard_dirs
         self._refresh_events = refresh_events or [self._charm.on.config_changed]
+        self._tracing_protocols = tracing_protocols
+        self._is_single_endpoint = charm.meta.relations[relation_name].limit == 1
 
         events = self._charm.on[relation_name]
         self.framework.observe(events.relation_joined, self._on_refresh)
@@ -377,6 +676,7 @@ def _on_refresh(self, event):
                         dashboards=self._dashboards,
                         metrics_scrape_jobs=self._scrape_jobs,
                         log_slots=self._log_slots,
+                        tracing_protocols=self._tracing_protocols,
                     )
                     relation.data[self._charm.unit][data.KEY] = data.json()
                 except (
@@ -441,6 +741,103 @@ def _dashboards(self) -> List[GrafanaDashboard]:
             dashboards.append(dashboard)
         return dashboards
 
+    @property
+    def relations(self) -> List[Relation]:
+        """The tracing relations associated with this endpoint."""
+        return self._charm.model.relations[self._relation_name]
+
+    @property
+    def _relation(self) -> Optional[Relation]:
+        """If this wraps a single endpoint, the relation bound to it, if any."""
+        if not self._is_single_endpoint:
+            objname = type(self).__name__
+            raise AmbiguousRelationUsageError(
+                f"This {objname} wraps a {self._relation_name} endpoint that has "
+                "limit != 1. We can't determine what relation, of the possibly many, you are "
+                f"referring to. Please pass a relation instance while calling {objname}, "
+                "or set limit=1 in the charm metadata."
+            )
+        relations = self.relations
+        return relations[0] if relations else None
+
+    def is_ready(self, relation: Optional[Relation] = None):
+        """Is this endpoint ready?"""
+        relation = relation or self._relation
+        if not relation:
+            logger.debug(f"no relation on {self._relation_name !r}: tracing not ready")
+            return False
+        if relation.data is None:
+            logger.error(f"relation data is None for {relation}")
+            return False
+        if not relation.app:
+            logger.error(f"{relation} event received but there is no relation.app")
+            return False
+        try:
+            unit = next(iter(relation.units), None)
+            if not unit:
+                return False
+            databag = dict(relation.data[unit])
+            CosAgentRequirerUnitData.load(databag)
+
+        except (json.JSONDecodeError, pydantic.ValidationError, DataValidationError):
+            logger.info(f"failed validating relation data for {relation}")
+            return False
+        return True
+
+    def get_all_endpoints(
+        self, relation: Optional[Relation] = None
+    ) -> Optional[CosAgentRequirerUnitData]:
+        """Unmarshalled relation data."""
+        relation = relation or self._relation
+        if not relation or not self.is_ready(relation):
+            return None
+        unit = next(iter(relation.units), None)
+        if not unit:
+            return None
+        return CosAgentRequirerUnitData.load(relation.data[unit])  # type: ignore
+
+    def _get_tracing_endpoint(
+        self, relation: Optional[Relation], protocol: ReceiverProtocol
+    ) -> Optional[str]:
+        unit_data = self.get_all_endpoints(relation)
+        if not unit_data:
+            return None
+        receivers: List[Receiver] = [i for i in unit_data.receivers if i.protocol.name == protocol]
+        if not receivers:
+            logger.error(f"no receiver found with protocol={protocol!r}")
+            return None
+        if len(receivers) > 1:
+            logger.error(
+                f"too many receivers with protocol={protocol!r}; expected at most one. Found: {receivers}"
+            )
+            return None
+
+        receiver = receivers[0]
+        return receiver.url
+
+    def get_tracing_endpoint(
+        self, protocol: ReceiverProtocol, relation: Optional[Relation] = None
+    ) -> Optional[str]:
+        """Receiver endpoint for the given protocol."""
+        endpoint = self._get_tracing_endpoint(relation or self._relation, protocol=protocol)
+        if not endpoint:
+            requested_protocols = set()
+            relations = [relation] if relation else self.relations
+            for relation in relations:
+                try:
+                    databag = CosAgentProviderUnitData.load(relation.data[self._charm.unit])
+                except DataValidationError:
+                    continue
+
+                if databag.tracing_protocols:
+                    requested_protocols.update(databag.tracing_protocols)
+
+            if protocol not in requested_protocols:
+                raise ProtocolNotRequestedError(protocol, relation)
+
+            return None
+        return endpoint
+
 
 class COSAgentDataChanged(EventBase):
     """Event emitted by `COSAgentRequirer` when relation data changes."""
@@ -554,6 +951,12 @@ def _on_relation_data_changed(self, event: RelationChangedEvent):
         if not (provider_data := self._validated_provider_data(raw)):
             return
 
+        # write enabled receivers to cos-agent relation
+        try:
+            self.update_tracing_receivers()
+        except ModelError:
+            raise
+
         # Copy data from the cos_agent relation to the peer relation, so the leader could
         # follow up.
         # Save the originating unit name, so it could be used for topology later on by the leader.
@@ -574,6 +977,37 @@ def _on_relation_data_changed(self, event: RelationChangedEvent):
         # need to emit `on.data_changed`), so we're emitting `on.data_changed` either way.
         self.on.data_changed.emit()  # pyright: ignore
 
+    def update_tracing_receivers(self):
+        """Updates the list of exposed tracing receivers in all relations."""
+        try:
+            for relation in self._charm.model.relations[self._relation_name]:
+                CosAgentRequirerUnitData(
+                    receivers=[
+                        Receiver(
+                            url=f"{self._get_tracing_receiver_url(protocol)}",
+                            protocol=ProtocolType(
+                                name=protocol,
+                                type=receiver_protocol_to_transport_protocol[protocol],
+                            ),
+                        )
+                        for protocol in self.requested_tracing_protocols()
+                    ],
+                ).dump(relation.data[self._charm.unit])
+
+        except ModelError as e:
+            # args are bytes
+            msg = e.args[0]
+            if isinstance(msg, bytes):
+                if msg.startswith(
+                    b"ERROR cannot read relation application settings: permission denied"
+                ):
+                    logger.error(
+                        f"encountered error {e} while attempting to update_relation_data. "
+                        f"The relation must be gone."
+ ) + return + raise + def _validated_provider_data(self, raw) -> Optional[CosAgentProviderUnitData]: try: return CosAgentProviderUnitData(**json.loads(raw)) @@ -586,6 +1020,55 @@ def trigger_refresh(self, _): # FIXME: Figure out what we should do here self.on.data_changed.emit() # pyright: ignore + def _get_requested_protocols(self, relation: Relation): + # Coherence check + units = relation.units + if len(units) > 1: + # should never happen + raise ValueError( + f"unexpected error: subordinate relation {relation} " + f"should have exactly one unit" + ) + + unit = next(iter(units), None) + + if not unit: + return None + + if not (raw := relation.data[unit].get(CosAgentProviderUnitData.KEY)): + return None + + if not (provider_data := self._validated_provider_data(raw)): + return None + + return provider_data.tracing_protocols + + def requested_tracing_protocols(self): + """All receiver protocols that have been requested by our related apps.""" + requested_protocols = set() + for relation in self._charm.model.relations[self._relation_name]: + try: + protocols = self._get_requested_protocols(relation) + except NotReadyError: + continue + if protocols: + requested_protocols.update(protocols) + return requested_protocols + + def _get_tracing_receiver_url(self, protocol: str): + scheme = "http" + try: + if self._charm.cert.enabled: # type: ignore + scheme = "https" + # not only Grafana Agent can implement cos_agent. If the charm doesn't have the `cert` attribute + # using our cert_handler, it won't have the `enabled` parameter. In this case, we pass and assume http. + except AttributeError: + pass + # the assumption is that a subordinate charm will always be accessible to its principal charm under its fqdn + if receiver_protocol_to_transport_protocol[protocol] == TransportProtocolType.grpc: + return f"{socket.getfqdn()}:{_tracing_receivers_ports[protocol]}" + return f"{scheme}://{socket.getfqdn()}:{_tracing_receivers_ports[protocol]}" + @property def _remote_data(self) -> List[Tuple[CosAgentProviderUnitData, JujuTopology]]: """Return a list of remote data from each of the related units. @@ -721,8 +1204,18 @@ def metrics_jobs(self) -> List[Dict]: @property def snap_log_endpoints(self) -> List[SnapEndpoint]: """Fetch logging endpoints exposed by related snaps.""" + endpoints = [] + endpoints_with_topology = self.snap_log_endpoints_with_topology + for endpoint, _ in endpoints_with_topology: + endpoints.append(endpoint) + + return endpoints + + @property + def snap_log_endpoints_with_topology(self) -> List[Tuple[SnapEndpoint, JujuTopology]]: + """Fetch logging endpoints and charm topology for each related snap.""" plugs = [] - for data, _ in self._remote_data: + for data, topology in self._remote_data: targets = data.log_slots if targets: for target in targets: @@ -733,15 +1226,16 @@ def snap_log_endpoints(self) -> List[SnapEndpoint]: "endpoints; this should not happen." ) else: - plugs.append(target) + plugs.append((target, topology)) endpoints = [] - for plug in plugs: + for plug, topology in plugs: if ":" not in plug: logger.error(f"invalid plug definition received: {plug}. 
Ignoring...") else: endpoint = SnapEndpoint(*plug.split(":")) - endpoints.append(endpoint) + endpoints.append((endpoint, topology)) + return endpoints @property @@ -804,3 +1298,67 @@ def dashboards(self) -> List[Dict[str, str]]: ) return dashboards + + +def charm_tracing_config( + endpoint_requirer: COSAgentProvider, cert_path: Optional[Union[Path, str]] +) -> Tuple[Optional[str], Optional[str]]: + """Utility function to determine the charm_tracing config you will likely want. + + If no endpoint is provided: + disable charm tracing. + If https endpoint is provided but cert_path is not found on disk: + disable charm tracing. + If https endpoint is provided and cert_path is None: + ERROR + Else: + proceed with charm tracing (with or without tls, as appropriate) + + Usage: + If you are using charm_tracing >= v1.9: + >>> from lib.charms.tempo_k8s.v1.charm_tracing import trace_charm + >>> from lib.charms.tempo_k8s.v0.cos_agent import charm_tracing_config + >>> @trace_charm(tracing_endpoint="my_endpoint", cert_path="cert_path") + >>> class MyCharm(...): + >>> _cert_path = "/path/to/cert/on/charm/container.crt" + >>> def __init__(self, ...): + >>> self.cos_agent = COSAgentProvider(...) + >>> self.my_endpoint, self.cert_path = charm_tracing_config( + ... self.cos_agent, self._cert_path) + + If you are using charm_tracing < v1.9: + >>> from lib.charms.tempo_k8s.v1.charm_tracing import trace_charm + >>> from lib.charms.tempo_k8s.v2.tracing import charm_tracing_config + >>> @trace_charm(tracing_endpoint="my_endpoint", cert_path="cert_path") + >>> class MyCharm(...): + >>> _cert_path = "/path/to/cert/on/charm/container.crt" + >>> def __init__(self, ...): + >>> self.cos_agent = COSAgentProvider(...) + >>> self.my_endpoint, self.cert_path = charm_tracing_config( + ... self.cos_agent, self._cert_path) + >>> @property + >>> def my_endpoint(self): + >>> return self._my_endpoint + >>> @property + >>> def cert_path(self): + >>> return self._cert_path + + """ + if not endpoint_requirer.is_ready(): + return None, None + + endpoint = endpoint_requirer.get_tracing_endpoint("otlp_http") + if not endpoint: + return None, None + + is_https = endpoint.startswith("https://") + + if is_https: + if cert_path is None: + raise TracingError("Cannot send traces to an https endpoint without a certificate.") + if not Path(cert_path).exists(): + # if endpoint is https BUT we don't have a server_cert yet: + # disable charm tracing until we do to prevent tls errors + return None, None + return endpoint, str(cert_path) + return endpoint, None diff --git a/lib/charms/postgresql_k8s/v0/postgresql.py b/lib/charms/postgresql_k8s/v0/postgresql.py index f7d361b5a9..c3412d36df 100644 --- a/lib/charms/postgresql_k8s/v0/postgresql.py +++ b/lib/charms/postgresql_k8s/v0/postgresql.py @@ -36,7 +36,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 33 +LIBPATCH = 34 INVALID_EXTRA_USER_ROLE_BLOCKING_MESSAGE = "invalid role(s) for extra user roles" @@ -425,14 +425,20 @@ def get_postgresql_timezones(self) -> Set[str]: timezones = cursor.fetchall() return {timezone[0] for timezone in timezones} - def get_postgresql_version(self) -> str: + def get_postgresql_version(self, current_host=True) -> str: """Returns the PostgreSQL version. Returns: PostgreSQL version number. 
""" + if current_host: + host = self.current_host + else: + host = None try: - with self._connect_to_database() as connection, connection.cursor() as cursor: + with self._connect_to_database( + database_host=host + ) as connection, connection.cursor() as cursor: cursor.execute("SELECT version();") # Split to get only the version number. return cursor.fetchone()[0].split(" ")[1] diff --git a/lib/charms/tempo_k8s/v1/charm_tracing.py b/lib/charms/tempo_k8s/v1/charm_tracing.py index ebe022e00d..fa926539bc 100644 --- a/lib/charms/tempo_k8s/v1/charm_tracing.py +++ b/lib/charms/tempo_k8s/v1/charm_tracing.py @@ -172,6 +172,59 @@ def my_tracing_endpoint(self) -> Optional[str]: provide an *absolute* path to the certificate file instead. """ + +def _remove_stale_otel_sdk_packages(): + """Hack to remove stale opentelemetry sdk packages from the charm's python venv. + + See https://github.com/canonical/grafana-agent-operator/issues/146 and + https://bugs.launchpad.net/juju/+bug/2058335 for more context. This patch can be removed after + this juju issue is resolved and sufficient time has passed to expect most users of this library + have migrated to the patched version of juju. When this patch is removed, un-ignore rule E402 for this file in the pyproject.toml (see setting + [tool.ruff.lint.per-file-ignores] in pyproject.toml). + + This only has an effect if executed on an upgrade-charm event. + """ + # all imports are local to keep this function standalone, side-effect-free, and easy to revert later + import os + + if os.getenv("JUJU_DISPATCH_PATH") != "hooks/upgrade-charm": + return + + import logging + import shutil + from collections import defaultdict + + from importlib_metadata import distributions + + otel_logger = logging.getLogger("charm_tracing_otel_patcher") + otel_logger.debug("Applying _remove_stale_otel_sdk_packages patch on charm upgrade") + # group by name all distributions starting with "opentelemetry_" + otel_distributions = defaultdict(list) + for distribution in distributions(): + name = distribution._normalized_name # type: ignore + if name.startswith("opentelemetry_"): + otel_distributions[name].append(distribution) + + otel_logger.debug(f"Found {len(otel_distributions)} opentelemetry distributions") + + # If we have multiple distributions with the same name, remove any that have 0 associated files + for name, distributions_ in otel_distributions.items(): + if len(distributions_) <= 1: + continue + + otel_logger.debug(f"Package {name} has multiple ({len(distributions_)}) distributions.") + for distribution in distributions_: + if not distribution.files: # Not None or empty list + path = distribution._path # type: ignore + otel_logger.info(f"Removing empty distribution of {name} at {path}.") + shutil.rmtree(path) + + otel_logger.debug("Successfully applied _remove_stale_otel_sdk_packages patch. 
") + + +_remove_stale_otel_sdk_packages() + + import functools import inspect import logging @@ -197,14 +250,15 @@ def my_tracing_endpoint(self) -> Optional[str]: from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import Span, TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.trace import INVALID_SPAN, Tracer -from opentelemetry.trace import get_current_span as otlp_get_current_span from opentelemetry.trace import ( + INVALID_SPAN, + Tracer, get_tracer, get_tracer_provider, set_span_in_context, set_tracer_provider, ) +from opentelemetry.trace import get_current_span as otlp_get_current_span from ops.charm import CharmBase from ops.framework import Framework @@ -217,7 +271,7 @@ def my_tracing_endpoint(self) -> Optional[str]: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 11 +LIBPATCH = 14 PYDEPS = ["opentelemetry-exporter-otlp-proto-http==1.21.0"] @@ -391,6 +445,9 @@ def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): _service_name = service_name or f"{self.app.name}-charm" unit_name = self.unit.name + # apply hacky patch to remove stale opentelemetry sdk packages on upgrade-charm. + # it could be trouble if someone ever decides to implement their own tracer parallel to + # ours and before the charm has inited. We assume they won't. resource = Resource.create( attributes={ "service.name": _service_name, @@ -612,38 +669,58 @@ def trace_type(cls: _T) -> _T: dev_logger.info(f"skipping {method} (dunder)") continue - new_method = trace_method(method) - if isinstance(inspect.getattr_static(cls, method.__name__), staticmethod): + # the span title in the general case should be: + # method call: MyCharmWrappedMethods.b + # if the method has a name (functools.wrapped or regular method), let + # _trace_callable use its default algorithm to determine what name to give the span. + trace_method_name = None + try: + qualname_c0 = method.__qualname__.split(".")[0] + if not hasattr(cls, method.__name__): + # if the callable doesn't have a __name__ (probably a decorated method), + # it probably has a bad qualname too (such as my_decorator..wrapper) which is not + # great for finding out what the trace is about. So we use the method name instead and + # add a reference to the decorator name. Result: + # method call: @my_decorator(MyCharmWrappedMethods.b) + trace_method_name = f"@{qualname_c0}({cls.__name__}.{name})" + except Exception: # noqa: failsafe + pass + + new_method = trace_method(method, name=trace_method_name) + + if isinstance(inspect.getattr_static(cls, name), staticmethod): new_method = staticmethod(new_method) setattr(cls, name, new_method) return cls -def trace_method(method: _F) -> _F: +def trace_method(method: _F, name: Optional[str] = None) -> _F: """Trace this method. A span will be opened when this method is called and closed when it returns. """ - return _trace_callable(method, "method") + return _trace_callable(method, "method", name=name) -def trace_function(function: _F) -> _F: +def trace_function(function: _F, name: Optional[str] = None) -> _F: """Trace this function. A span will be opened when this function is called and closed when it returns. 
""" - return _trace_callable(function, "function") + return _trace_callable(function, "function", name=name) -def _trace_callable(callable: _F, qualifier: str) -> _F: +def _trace_callable(callable: _F, qualifier: str, name: Optional[str] = None) -> _F: dev_logger.info(f"instrumenting {callable}") # sig = inspect.signature(callable) @functools.wraps(callable) def wrapped_function(*args, **kwargs): # type: ignore - name = getattr(callable, "__qualname__", getattr(callable, "__name__", str(callable))) - with _span(f"{qualifier} call: {name}"): # type: ignore + name_ = name or getattr( + callable, "__qualname__", getattr(callable, "__name__", str(callable)) + ) + with _span(f"{qualifier} call: {name_}"): # type: ignore return callable(*args, **kwargs) # type: ignore # wrapped_function.__signature__ = sig diff --git a/src/relations/db.py b/src/relations/db.py index d3944d76b3..778c52cfd3 100644 --- a/src/relations/db.py +++ b/src/relations/db.py @@ -319,14 +319,14 @@ def update_endpoints(self, relation: Relation = None) -> None: if len(relations) == 0: return + postgresql_version = None try: postgresql_version = self.charm.postgresql.get_postgresql_version() - except PostgreSQLGetPostgreSQLVersionError as e: - logger.exception(e) - self.charm.unit.status = BlockedStatus( - f"Failed to retrieve the PostgreSQL version to initialise/update {self.relation_name} relation" + except PostgreSQLGetPostgreSQLVersionError: + logger.exception( + "Failed to retrieve the PostgreSQL version to initialise/update %s relation" + % self.relation_name ) - return # List the replicas endpoints. replicas_endpoint = list(self.charm.members_ips - {self.charm.primary_endpoint}) @@ -383,7 +383,6 @@ def update_endpoints(self, relation: Relation = None) -> None: "port": DATABASE_PORT, "user": user, "schema_user": user, - "version": postgresql_version, "password": password, "schema_password": password, "database": database, @@ -392,6 +391,8 @@ def update_endpoints(self, relation: Relation = None) -> None: "state": self._get_state(), "extensions": ",".join(required_extensions), } + if postgresql_version: + data["version"] = postgresql_version # Set the data only in the unit databag. unit_relation_databag.update(data) diff --git a/tests/unit/test_db.py b/tests/unit/test_db.py index b518757567..9c29f47d3b 100644 --- a/tests/unit/test_db.py +++ b/tests/unit/test_db.py @@ -426,6 +426,9 @@ def test_update_endpoints_with_relation(harness): with ( patch.object(PostgresqlOperatorCharm, "postgresql", Mock()) as postgresql_mock, patch("charm.Patroni.get_primary") as _get_primary, + patch( + "relations.db.logger", + ) as _logger, patch( "charm.PostgresqlOperatorCharm.members_ips", new_callable=PropertyMock, @@ -442,9 +445,9 @@ def test_update_endpoints_with_relation(harness): # Set some side effects to test multiple situations. postgresql_mock.get_postgresql_version = PropertyMock( side_effect=[ - PostgreSQLGetPostgreSQLVersionError, POSTGRESQL_VERSION, POSTGRESQL_VERSION, + PostgreSQLGetPostgreSQLVersionError, ] ) @@ -475,11 +478,6 @@ def test_update_endpoints_with_relation(harness): harness.charm.set_secret("app", user, password) harness.charm.set_secret("app", f"{user}-database", DATABASE) - # BlockedStatus due to a PostgreSQLGetPostgreSQLVersionError. - harness.charm.legacy_db_relation.update_endpoints(relation) - assert isinstance(harness.model.unit.status, BlockedStatus) - assert harness.get_relation_data(rel_id, harness.charm.unit.name) == {} - # Test with both a primary and a replica. 
diff --git a/tests/unit/test_db.py b/tests/unit/test_db.py
index b518757567..9c29f47d3b 100644
--- a/tests/unit/test_db.py
+++ b/tests/unit/test_db.py
@@ -426,6 +426,9 @@ def test_update_endpoints_with_relation(harness):
     with (
         patch.object(PostgresqlOperatorCharm, "postgresql", Mock()) as postgresql_mock,
         patch("charm.Patroni.get_primary") as _get_primary,
+        patch(
+            "relations.db.logger",
+        ) as _logger,
         patch(
             "charm.PostgresqlOperatorCharm.members_ips",
             new_callable=PropertyMock,
@@ -442,9 +445,9 @@ def test_update_endpoints_with_relation(harness):
         # Set some side effects to test multiple situations.
         postgresql_mock.get_postgresql_version = PropertyMock(
             side_effect=[
-                PostgreSQLGetPostgreSQLVersionError,
                 POSTGRESQL_VERSION,
                 POSTGRESQL_VERSION,
+                PostgreSQLGetPostgreSQLVersionError,
             ]
         )
 
@@ -475,11 +478,6 @@ def test_update_endpoints_with_relation(harness):
         harness.charm.set_secret("app", user, password)
         harness.charm.set_secret("app", f"{user}-database", DATABASE)
 
-        # BlockedStatus due to a PostgreSQLGetPostgreSQLVersionError.
-        harness.charm.legacy_db_relation.update_endpoints(relation)
-        assert isinstance(harness.model.unit.status, BlockedStatus)
-        assert harness.get_relation_data(rel_id, harness.charm.unit.name) == {}
-
         # Test with both a primary and a replica.
         # Update the endpoints with the event and check that it updated only
         # the right relation databags (the app and unit databags from the event).
@@ -536,12 +534,21 @@ def test_update_endpoints_with_relation(harness):
             and standbys + user == unit_relation_data["standbys"]
         )
 
+        # The version is not updated due to a PostgreSQLGetPostgreSQLVersionError.
+        harness.charm.legacy_db_relation.update_endpoints()
+        _logger.exception.assert_called_once_with(
+            "Failed to retrieve the PostgreSQL version to initialise/update db relation"
+        )
+
 
 @patch_network_get(private_address="1.1.1.1")
 def test_update_endpoints_without_relation(harness):
     with (
         patch.object(PostgresqlOperatorCharm, "postgresql", Mock()) as postgresql_mock,
         patch("charm.Patroni.get_primary") as _get_primary,
+        patch(
+            "relations.db.logger",
+        ) as _logger,
         patch(
             "charm.PostgresqlOperatorCharm.members_ips",
             new_callable=PropertyMock,
@@ -558,9 +565,9 @@ def test_update_endpoints_without_relation(harness):
         # Set some side effects to test multiple situations.
         postgresql_mock.get_postgresql_version = PropertyMock(
             side_effect=[
-                PostgreSQLGetPostgreSQLVersionError,
                 POSTGRESQL_VERSION,
                 POSTGRESQL_VERSION,
+                PostgreSQLGetPostgreSQLVersionError,
             ]
         )
         _get_primary.return_value = harness.charm.unit.name
@@ -588,11 +595,6 @@ def test_update_endpoints_without_relation(harness):
         harness.charm.set_secret("app", user, password)
         harness.charm.set_secret("app", f"{user}-database", DATABASE)
 
-        # BlockedStatus due to a PostgreSQLGetPostgreSQLVersionError.
-        harness.charm.legacy_db_relation.update_endpoints()
-        assert isinstance(harness.model.unit.status, BlockedStatus)
-        assert harness.get_relation_data(rel_id, harness.charm.unit.name) == {}
-
         # Test with both a primary and a replica.
         # Update the endpoints and check that all relations' databags are updated.
         harness.charm.legacy_db_relation.update_endpoints()
@@ -622,6 +624,12 @@ def test_update_endpoints_without_relation(harness):
             and standbys + user == unit_relation_data["standbys"]
         )
 
+        # The version is not updated due to a PostgreSQLGetPostgreSQLVersionError.
+        harness.charm.legacy_db_relation.update_endpoints()
+        _logger.exception.assert_called_once_with(
+            "Failed to retrieve the PostgreSQL version to initialise/update db relation"
+        )
+
 
 @patch_network_get(private_address="1.1.1.1")
 def test_get_allowed_units(harness):