Skip to content

Commit

Permalink
Merge pull request #46 from tutorcruncher/async-webhooks
Browse files Browse the repository at this point in the history
Move webhook sending to async
  • Loading branch information
HenryTraill authored Oct 29, 2024
2 parents 09f893e + b7a2e25 commit b48327e
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 92 deletions.
1 change: 1 addition & 0 deletions chronos/pydantic_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,4 @@ class RequestData(BaseModel):
response_headers: str = '{"Message": "No response from endpoint"}'
response_body: str = '{"Message": "No response from endpoint"}'
status_code: int = 999
successful_response: str = False
103 changes: 61 additions & 42 deletions chronos/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import hashlib
import hmac
import json
Expand All @@ -7,6 +8,7 @@
from celery.app import Celery
from fastapi import APIRouter
from fastapi_utilities import repeat_at
from httpx import AsyncClient
from sqlalchemy import delete
from sqlmodel import Session, col, select

Expand All @@ -22,9 +24,10 @@
celery_app.conf.broker_connection_retry_on_startup = True


def webhook_request(url: str, *, method: str = 'POST', webhook_sig: str, data: dict = None):
async def webhook_request(client: AsyncClient, url: str, *, webhook_sig: str, data: dict = None):
"""
Send a request to TutorCruncher
:param client
:param url: The endpoint supplied by clients when creating an integration in TC2
:param method: We should always be sending POST requests as we are sending data to the endpoints
:param webhook_sig: The signature generated by hashing the payload with the shared key
Expand All @@ -39,10 +42,10 @@ def webhook_request(url: str, *, method: str = 'POST', webhook_sig: str, data: d
'webhook-signature': webhook_sig,
}
logfire.debug('TutorCruncher request to url: {url=}: {data=}', url=url, data=data)
with logfire.span('{method=} {url!r}', url=url, method=method):
with logfire.span('{method=} {url!r}', url=url, method='POST'):
r = None
try:
r = session.request(method=method, url=url, json=data, headers=headers, timeout=4)
r = await client.post(url=url, json=data, headers=headers, timeout=4)
except requests.exceptions.HTTPError as httperr:
app_logger.info('HTTP error sending webhook to %s: %s', url, httperr)
except requests.exceptions.ConnectionError as conerr:
Expand All @@ -52,18 +55,15 @@ def webhook_request(url: str, *, method: str = 'POST', webhook_sig: str, data: d
except requests.exceptions.RequestException as rerr:
app_logger.info('Request error sending webhook to %s: %s', url, rerr)
else:
app_logger.info('Request method=%s url=%s status_code=%s', method, url, r.status_code, extra={'data': data})
app_logger.info('Request method=%s url=%s status_code=%s', 'POST', url, r.status_code, extra={'data': data})

request_data = RequestData(request_headers=json.dumps(headers), request_body=json.dumps(data))
if r is None:
webhook_was_received = False
else:
if r is not None:
request_data.response_headers = json.dumps(dict(r.headers))
request_data.response_body = json.dumps(r.content.decode())
request_data.status_code = r.status_code
webhook_was_received = True

return request_data, webhook_was_received
request_data.successful_response = True
return request_data


acceptable_url_schemes = ('http', 'https', 'ftp', 'ftps')
Expand All @@ -83,28 +83,11 @@ def get_qlength():
return qlength


@celery_app.task
def task_send_webhooks(
payload: str,
url_extension: str = None,
):
"""
Send the webhook to the relevant endpoints
"""
loaded_payload = json.loads(payload)
loaded_payload['_request_time'] = loaded_payload.pop('request_time')
branch_id = loaded_payload['events'][0]['branch']

qlength = get_qlength()
app_logger.info('Starting send webhook task for branch %s. qlength=%s.', branch_id, qlength)
if qlength > 100:
app_logger.error('Queue is too long. Check workers and speeds.')

async def _async_post_webhooks(endpoints, url_extension, payload):
webhook_logs = []
total_success, total_failed = 0, 0
with Session(engine) as db:
# Get all the endpoints for the branch
endpoints_query = select(WebhookEndpoint).where(WebhookEndpoint.branch_id == branch_id, WebhookEndpoint.active)
endpoints = db.exec(endpoints_query).all()
async with AsyncClient() as client:
tasks = []
for endpoint in endpoints:
# Check if the webhook URL is valid
if not endpoint.webhook_url.startswith(acceptable_url_schemes):
Expand All @@ -122,9 +105,16 @@ def task_send_webhooks(
if url_extension:
url += f'/{url_extension}'
# Send the Webhook to the endpoint
response, webhook_sent = webhook_request(url, webhook_sig=sig_hex, data=loaded_payload)

if not webhook_sent:
loaded_payload = json.loads(payload)
task = asyncio.ensure_future(webhook_request(client, url, webhook_sig=sig_hex, data=loaded_payload))
tasks.append(task)
webhook_responses = await asyncio.gather(*tasks, return_exceptions=True)
for response in webhook_responses:
if not isinstance(response, RequestData):
app_logger.info('No response from endpoint %s: %s. %s', endpoint.id, endpoint.webhook_url, response)
continue
elif not response.successful_response:
app_logger.info('No response from endpoint %s: %s', endpoint.id, endpoint.webhook_url)

if response.status_code in {200, 201, 202, 204}:
Expand All @@ -135,16 +125,45 @@ def task_send_webhooks(
total_failed += 1

# Log the response
webhooklog = WebhookLog(
webhook_endpoint_id=endpoint.id,
request_headers=response.request_headers,
request_body=response.request_body,
response_headers=response.response_headers,
response_body=response.response_body,
status=status,
status_code=response.status_code,
webhook_logs.append(
WebhookLog(
webhook_endpoint_id=endpoint.id,
request_headers=response.request_headers,
request_body=response.request_body,
response_headers=response.response_headers,
response_body=response.response_body,
status=status,
status_code=response.status_code,
)
)
db.add(webhooklog)
return webhook_logs, total_success, total_failed


@celery_app.task
def task_send_webhooks(
payload: str,
url_extension: str = None,
):
"""
Send the webhook to the relevant endpoints
"""
loaded_payload = json.loads(payload)
loaded_payload['_request_time'] = loaded_payload.pop('request_time')
branch_id = loaded_payload['events'][0]['branch']

qlength = get_qlength()
app_logger.info('Starting send webhook task for branch %s. qlength=%s.', branch_id, qlength)
if qlength > 100:
app_logger.error('Queue is too long. Check workers and speeds.')

with Session(engine) as db:
# Get all the endpoints for the branch
endpoints_query = select(WebhookEndpoint).where(WebhookEndpoint.branch_id == branch_id, WebhookEndpoint.active)
endpoints = db.exec(endpoints_query).all()

webhook_logs, total_success, total_failed = asyncio.run(_async_post_webhooks(endpoints, url_extension, payload))
for webhook_log in webhook_logs:
db.add(webhook_log)
db.commit()
app_logger.info(
'%s Webhooks sent for branch %s. Total Sent: %s. Total failed: %s',
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ setuptools~=75.2.0
SQLAlchemy~=2.0.36
sqlmodel~=0.0.22
starlette~=0.41.2
uvicorn~=0.32.0
uvicorn~=0.32.0
respx~=0.21.1
13 changes: 4 additions & 9 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

from requests import Request, Response
from httpx import Response
from requests import Request

from chronos.main import app
from chronos.sql_models import WebhookEndpoint, WebhookLog
Expand Down Expand Up @@ -116,10 +117,7 @@ def get_successful_response(payload, headers, **kwargs) -> Response:
request = Request()
request.headers = headers
request.body = json.dumps(payload).encode()
response = Response()
response.request = request
response.status_code = 200
response._content = json.dumps(response_dict).encode()
response = Response(status_code=200, request=request, content=json.dumps(response_dict).encode())
return response


Expand All @@ -130,8 +128,5 @@ def get_failed_response(payload, headers, **kwargs) -> Response:
request = Request()
request.headers = headers
request.body = json.dumps(payload).encode()
response = Response()
response.request = request
response.status_code = 409
response._content = json.dumps(response_dict).encode()
response = Response(status_code=409, request=request, content=json.dumps(response_dict).encode())
return response
Loading

0 comments on commit b48327e

Please sign in to comment.