diff --git a/emmaa/subscription/notifications.py b/emmaa/subscription/notifications.py index 07dae306f..efa8517f2 100644 --- a/emmaa/subscription/notifications.py +++ b/emmaa/subscription/notifications.py @@ -296,7 +296,8 @@ def get_model_deltas(model_name, date, model_stats, test_stats_by_corpus): model_name : str A name of the model to get the updates for. date : str - A date for which the updates should be generated. + A date for which the updates should be generated. The + format should be "YYYY-MM-DD". model_stats : dict A dictionary containing the stats for the given model. test_stats_by_corpus : dict @@ -308,18 +309,24 @@ def get_model_deltas(model_name, date, model_stats, test_stats_by_corpus): A dictionary containing the deltas for the given model and test corpora. """ - deltas = {} - deltas['model_name'] = model_name - deltas['date'] = date # Model deltas stmts_delta = model_stats['model_delta']['statements_hashes_delta'] paper_delta = model_stats['paper_delta']['raw_paper_ids_delta'] new_papers = len(paper_delta['added']) - deltas['stmts_delta'] = stmts_delta - deltas['new_papers'] = new_papers + # Test deltas - deltas['tests'] = {} + deltas = { + 'model_name': model_name, + 'date': date, + 'stmts_delta': stmts_delta, + 'new_papers': new_papers, + 'tests': {} + } for test_corpus, test_stats in test_stats_by_corpus.items(): + if test_stats is None: + logger.info(f"No test stats for {test_corpus}") + continue + test_deltas = {} test_name = None test_data = test_stats['test_round_summary'].get('test_data') @@ -392,7 +399,7 @@ def get_all_update_messages(deltas, is_tweet=False): return msg_dicts -def tweet_deltas(deltas, twitter_cred): +def tweet_deltas(deltas, twitter_cred, verbose=False): """Tweet the model updates. Parameters @@ -403,12 +410,19 @@ def tweet_deltas(deltas, twitter_cred): twitter_cred : dict A dictionary containing consumer_token, consumer_secret, access_token, and access_secret for a model Twitter account. + verbose : bool + If True, the return from `tweepy.Client.create_tweet` will be printed """ msgs = get_all_update_messages(deltas, is_tweet=True) for msg in msgs: - update_status(msg['message'], twitter_cred) + res = update_status(msg['message'], twitter_cred) + if verbose: + print(res) time.sleep(1) - logger.info('Done tweeting') + if msgs: + logger.info(f'Done tweeting {len(msgs)} messages') + else: + logger.info('No tweets to send') def make_model_html_email(msg_dicts, email, domain='emmaa.indra.bio'): @@ -428,7 +442,8 @@ def get_all_stats(model_name, test_corpora, date): test_corpora : list[str] A list of test corpora names to get the test updates for. date : str - A date for which the updates should be generated. + A date for which the updates should be generated. The + format should be "YYYY-MM-DD". Returns ------- @@ -443,7 +458,9 @@ def get_all_stats(model_name, test_corpora, date): test_stats, _ = get_model_stats(model_name, 'test', tests=test_corpus, date=date) if not test_stats: - logger.info(f'Could not find test stats for {test_corpus}') + logger.info( + f"Could not find test stats for {test_corpus} for date {date}" + ) test_stats_by_corpus[test_corpus] = test_stats return model_stats, test_stats_by_corpus diff --git a/emmaa/tests/test_reactome_prior.py b/emmaa/tests/test_reactome_prior.py index 2b6d08606..233700d6a 100644 --- a/emmaa/tests/test_reactome_prior.py +++ b/emmaa/tests/test_reactome_prior.py @@ -29,7 +29,7 @@ def test_get_pathways_containing_genes(): # Check if function returns a reasonable number of pathways assert len(KRAS_pathways) > 3 # Signaling Downstream of RAS mutants - assert 'R-HSA-9649948.1' in KRAS_pathways + assert 'R-HSA-9649948.2' in KRAS_pathways def test_get_genes_contained_in_pathway(): diff --git a/emmaa/util.py b/emmaa/util.py index 4de362151..e415c6d87 100644 --- a/emmaa/util.py +++ b/emmaa/util.py @@ -348,36 +348,76 @@ class NotAClassName(Exception): pass -def get_credentials(key): - client = boto3.client('ssm') +def get_credentials( + key: str, profile_name: str = None, cred_type: str = "oauth1_0a" +): + """Get twitter credentials from AWS SSM + + Parameters + ---------- + key : str + The initial key to the credentials in SSM. The full key will be + /twitter/{key}/{par} where par is determined by the type of + credentials. + profile_name : str + The name of the AWS profile to use. If None (default), the default + profile will be used. + cred_type : str + The type of credentials to get. Choices are "oauth1_0a" and "bearer". + Default: "oauth1_0a". Bearer uses OAuth 2.0. + + Returns + ------- + dict + A dictionary with the requested credentials. + """ + if profile_name is not None: + client = boto3.session.Session( + profile_name=profile_name).client('ssm', region_name='us-east-1') + else: + client = boto3.client('ssm') + params = ['app_id'] + if cred_type == 'oauth1_0a': + params += ['consumer_token', 'consumer_secret', + 'access_token', 'access_secret'] + elif cred_type == 'bearer': + params += ['bearer_token'] + else: + raise ValueError(f"Unknown credential type: {cred_type}. Must be one " + f"of oath1_0a or bearer.") auth_dict = {} - for par in ['consumer_token', 'consumer_secret', 'access_token', - 'access_secret']: + for par in params: name = f'/twitter/{key}/{par}' try: response = client.get_parameter(Name=name, WithDecryption=True) val = response['Parameter']['Value'] auth_dict[par] = val except Exception as e: - print(e) + logger.exception(e) break return auth_dict -def get_oauth_dict(auth_dict): - oauth = tweepy.OAuthHandler(auth_dict.get('consumer_token'), - auth_dict.get('consumer_secret')) - oauth.set_access_token(auth_dict.get('access_token'), - auth_dict.get('access_secret')) - return oauth - - def update_status(msg, twitter_cred): - twitter_auth = get_oauth_dict(twitter_cred) - if twitter_auth is None: - return - twitter_api = tweepy.API(twitter_auth) - twitter_api.update_status(msg) + if 'consumer_secret' in twitter_cred and 'access_secret' in twitter_cred: + twitter_client = tweepy.Client( + consumer_key=twitter_cred['consumer_token'], + consumer_secret=twitter_cred['consumer_secret'], + access_token=twitter_cred['access_token'], + access_token_secret=twitter_cred['access_secret'] + ) + user_auth = True + elif 'bearer_token' in twitter_cred: + twitter_client = tweepy.Client( + bearer_token=twitter_cred['bearer_token'] + ) + user_auth = False + else: + raise ValueError('Missing credentials') + + # Set user_auth=True when authenticating with consumer key/secret pair + # and access token/secret pair, and False when using bearer token + return twitter_client.create_tweet(text=msg, user_auth=user_auth) def _make_delta_msg(model_name, msg_type, delta, date, mc_type=None,