diff --git a/tests/unit/test_tasks.py b/tests/unit/test_tasks.py index 09520132364e..7d2b68da441a 100644 --- a/tests/unit/test_tasks.py +++ b/tests/unit/test_tasks.py @@ -593,6 +593,23 @@ def test_make_celery_app(): "redis://127.0.0.1:6379/10", {}, ), + ( + Environment.production, + True, + None, + ( + "rediss://user:pass@redis.example.com:6379/10" + "?socket_timeout=5&irreleveant=0" + "&ssl_cert_reqs=required&ssl_ca_certs=/p/a/t/h/cacert.pem" + ), + ( + "rediss://user:pass@redis.example.com:6379/10" + "?ssl_cert_reqs=required&ssl_ca_certs=/p/a/t/h/cacert.pem" + ), + { + "socket_timeout": 5, + }, + ), ], ) def test_includeme( diff --git a/warehouse/tasks.py b/warehouse/tasks.py index 830906cb88a3..4c1056ac2e48 100644 --- a/warehouse/tasks.py +++ b/warehouse/tasks.py @@ -222,6 +222,31 @@ def includeme(config): "tcp_keepalive": True, } + if broker_url.startswith("redis"): + parsed_url = urllib.parse.urlparse( # noqa: WH001, going to urlunparse this + broker_url + ) + parsed_query = urllib.parse.parse_qs(parsed_url.query) + + celery_transport_options = { + "socket_timeout": int, + } + + for key, value in parsed_query.copy().items(): + if key.startswith("ssl_"): + continue + else: + if key in celery_transport_options: + broker_transport_options[key] = celery_transport_options[key]( + value[0] + ) + del parsed_query[key] + + parsed_url = parsed_url._replace( + query=urllib.parse.urlencode(parsed_query, doseq=True, safe="/") + ) + broker_url = urllib.parse.urlunparse(parsed_url) + config.registry["celery.app"] = celery.Celery( "warehouse", autofinalize=False, set_as_current=False )