Skip to content

Commit

Permalink
request headers derives from dict
Browse files Browse the repository at this point in the history
  • Loading branch information
vdusek committed Sep 24, 2024
1 parent 414ac66 commit 649eaa5
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 93 deletions.
75 changes: 3 additions & 72 deletions src/crawlee/_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

from __future__ import annotations

from _collections_abc import dict_items, dict_keys
from collections.abc import Iterator, Mapping, MutableMapping
from collections.abc import Iterator, MutableMapping
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Annotated, Any, cast, overload
from typing import Annotated, Any, cast

from pydantic import (
BaseModel,
Expand All @@ -17,12 +16,11 @@
JsonValue,
PlainSerializer,
PlainValidator,
RootModel,
TypeAdapter,
)
from typing_extensions import Self

from crawlee._types import EnqueueStrategy, HttpMethod, HttpPayload, HttpQueryParams
from crawlee._types import EnqueueStrategy, HttpHeaders, HttpMethod, HttpPayload, HttpQueryParams
from crawlee._utils.requests import compute_unique_key, unique_key_to_request_id
from crawlee._utils.urls import extract_query_params, validate_http_url

Expand Down Expand Up @@ -98,73 +96,6 @@ def __len__(self) -> int:
user_data_adapter = TypeAdapter(UserData)


class HttpHeaders(RootModel):
    """An immutable mapping for HTTP headers that ensures case-insensitivity for header names."""

    def __init__(self, headers: Mapping[str, str] | None = None) -> None:
        """Create a new instance.

        Args:
            headers: A mapping of header names to values.
        """
        # Normalize keys to lowercase and fix the order by sorting, so that
        # logically equal header sets always compare and serialize identically.
        headers = headers or {}
        headers = {k.lower(): v for k, v in headers.items()}
        self._headers = dict(sorted(headers.items()))

    @property
    def __dict__(self) -> dict[str, str]:
        """Return the headers as a dictionary.

        Implemented because of the `BaseModel.__iter__` implementation, which
        iterates over `__dict__`.
        """
        return dict(self._headers)

    @__dict__.setter
    def __dict__(self, value: dict[str, str]) -> None:
        """Set the headers from a dictionary, normalizing names to lowercase."""
        self._headers = {k.lower(): v for k, v in value.items()}

    def __len__(self) -> int:
        """Return the number of headers."""
        return len(self._headers)

    def __repr__(self) -> str:
        """Return a string representation of the object."""
        return f'{self.__class__.__name__}({self._headers})'

    def __getitem__(self, key: str) -> str:
        """Get the value of a header by its name, case-insensitive."""
        return self._headers[key.lower()]

    def __setitem__(self, key: str, value: str) -> None:
        """Prevent setting a header, as the object is immutable."""
        raise TypeError(f'{self.__class__.__name__} is immutable')

    def __delitem__(self, key: str) -> None:
        """Prevent deleting a header, as the object is immutable."""
        raise TypeError(f'{self.__class__.__name__} is immutable')

    def keys(self) -> dict_keys[str, str]:
        """Return a view of the (lowercase) header names."""
        return self._headers.keys()

    def items(self) -> dict_items[str, str]:
        """Return a view of the header (name, value) pairs."""
        return self._headers.items()

    @overload
    def get(self, key: str) -> str | None: ...

    @overload
    def get(self, key: str, default: str) -> str: ...

    @overload
    def get(self, key: str, default: None) -> None: ...

    def get(self, key: str, default: str | None = None) -> str | None:
        """Returns the value of the header if it exists, otherwise returns the default.

        The lookup is case-insensitive, consistent with `__getitem__`.
        """
        # Bug fix: stored keys are normalized to lowercase, so the lookup key
        # must be lowercased too; previously a mixed-case name always missed.
        return self._headers.get(key.lower(), default)


class BaseRequestData(BaseModel):
"""Data needed to create a new crawling request."""

Expand Down
46 changes: 44 additions & 2 deletions src/crawlee/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
if TYPE_CHECKING:
import logging
import re
from collections.abc import Coroutine, Sequence
from collections.abc import Coroutine, Iterator, Sequence

from crawlee import Glob
from crawlee._request import BaseRequestData, HttpHeaders, Request
from crawlee._request import BaseRequestData, Request
from crawlee.base_storage_client._models import DatasetItemsListPage
from crawlee.http_clients import HttpResponse
from crawlee.proxy_configuration import ProxyInfo
Expand Down Expand Up @@ -221,3 +221,45 @@ async def add_requests(
) -> None:
"""Track a call to the `add_requests` context helper."""
self.add_requests_calls.append(AddRequestsFunctionCall(requests=requests, **kwargs))


class HttpHeaders(dict[str, str]):
    """An immutable mapping for HTTP headers that ensures case-insensitivity for header names.

    Header names are normalized to lowercase on construction and stored in a
    fixed (sorted) order. Mutation attempts raise `TypeError`.
    """

    def __init__(self, headers: dict[str, str] | None = None) -> None:
        """Create a new instance.

        Args:
            headers: A mapping of header names to values.
        """
        # Normalize keys to lowercase and fix the order by sorting, so that
        # logically equal header sets always compare and serialize identically.
        normalized = dict(sorted((k.lower(), v) for k, v in (headers or {}).items()))

        # Populate the underlying dict storage. Without this, the inherited
        # dict methods (`keys`, `items`, `get`, equality, `**`-unpacking as
        # used in `_combine_headers`) would all operate on an empty dict,
        # because the data would live only in `self._headers`.
        super().__init__(normalized)
        self._headers = normalized

    def __iter__(self) -> Iterator[str]:
        """Return an iterator over the (lowercase) header names."""
        return iter(self._headers)

    def __len__(self) -> int:
        """Return the number of headers."""
        return len(self._headers)

    def __repr__(self) -> str:
        """Return a string representation of the object."""
        return f'{self._headers}'

    def __getitem__(self, key: str) -> str:
        """Get the value of a header by its name, case-insensitive."""
        return self._headers[key.lower()]

    def __contains__(self, key: object) -> bool:
        """Check whether a header exists, case-insensitive."""
        return isinstance(key, str) and key.lower() in self._headers

    def get(self, key: str, default: str | None = None) -> str | None:
        """Return the header value if it exists, otherwise the default (case-insensitive)."""
        return self._headers.get(key.lower(), default)

    def __setitem__(self, key: str, value: str) -> None:
        """Prevent setting a header, as the object is immutable."""
        raise TypeError(f'{self.__class__.__name__} is immutable')

    def __delitem__(self, key: str) -> None:
        """Prevent deleting a header, as the object is immutable."""
        raise TypeError(f'{self.__class__.__name__} is immutable')
4 changes: 2 additions & 2 deletions src/crawlee/basic_crawler/_basic_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from crawlee._autoscaling.snapshotter import Snapshotter
from crawlee._autoscaling.system_status import SystemStatus
from crawlee._log_config import configure_logger, get_configured_log_level
from crawlee._request import BaseRequestData, HttpHeaders, Request, RequestState
from crawlee._types import BasicCrawlingContext, RequestHandlerRunResult, SendRequestFunction
from crawlee._request import BaseRequestData, Request, RequestState
from crawlee._types import BasicCrawlingContext, HttpHeaders, RequestHandlerRunResult, SendRequestFunction
from crawlee._utils.byte_size import ByteSize
from crawlee._utils.http import is_status_code_client_error
from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute
Expand Down
6 changes: 1 addition & 5 deletions src/crawlee/fingerprint_suite/_header_generator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
from __future__ import annotations

import random
from typing import TYPE_CHECKING

from crawlee.fingerprint_suite._consts import COMMON_ACCEPT, COMMON_ACCEPT_LANGUAGE, USER_AGENT_POOL

if TYPE_CHECKING:
from collections.abc import Mapping


class HeaderGenerator:
"""Generates common headers for HTTP requests."""

def get_common_headers(self) -> Mapping[str, str]:
def get_common_headers(self) -> dict[str, str]:
"""Get common headers for HTTP requests.
We do not modify the 'Accept-Encoding', 'Connection' and other headers. They should be included and handled
Expand Down
3 changes: 1 addition & 2 deletions src/crawlee/http_clients/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
if TYPE_CHECKING:
from collections.abc import Iterable

from crawlee._request import HttpHeaders
from crawlee._types import HttpMethod, HttpPayload, HttpQueryParams
from crawlee._types import HttpHeaders, HttpMethod, HttpPayload, HttpQueryParams
from crawlee.base_storage_client._models import Request
from crawlee.proxy_configuration import ProxyInfo
from crawlee.sessions import Session
Expand Down
8 changes: 4 additions & 4 deletions src/crawlee/http_clients/_httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import httpx
from typing_extensions import override

from crawlee._request import HttpHeaders
from crawlee._types import HttpHeaders
from crawlee._utils.blocked import ROTATE_PROXY_ERRORS
from crawlee.errors import ProxyError
from crawlee.fingerprint_suite import HeaderGenerator
Expand Down Expand Up @@ -130,7 +130,7 @@ async def crawl(
http_request = client.build_request(
url=request.url,
method=request.method,
headers=dict(headers) if headers else None,
headers=headers,
params=request.query_params,
data=request.payload,
cookies=session.cookies if session else None,
Expand Down Expand Up @@ -177,7 +177,7 @@ async def send_request(
http_request = client.build_request(
url=url,
method=method,
headers=dict(headers) if headers else None,
headers=headers,
params=query_params,
data=payload,
extensions={'crawlee_session': session if self._persist_cookies_per_session else None},
Expand Down Expand Up @@ -230,7 +230,7 @@ def _combine_headers(self, explicit_headers: HttpHeaders | None) -> HttpHeaders
headers = HttpHeaders(common_headers)

if explicit_headers:
headers = HttpHeaders({**dict(headers), **dict(headers)})
headers = HttpHeaders({**headers, **explicit_headers})

return headers if headers else None

Expand Down
7 changes: 3 additions & 4 deletions src/crawlee/http_clients/curl_impersonate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@

from curl_cffi.requests import Response

from crawlee._request import HttpHeaders
from crawlee._types import HttpMethod, HttpPayload, HttpQueryParams
from crawlee._types import HttpHeaders, HttpMethod, HttpPayload, HttpQueryParams
from crawlee.base_storage_client._models import Request
from crawlee.proxy_configuration import ProxyInfo
from crawlee.sessions import Session
Expand Down Expand Up @@ -119,7 +118,7 @@ async def crawl(
response = await client.request(
url=request.url,
method=request.method.upper(), # type: ignore # curl-cffi requires uppercase method
headers=dict(request.headers) if request.headers else None,
headers=request.headers,
params=request.query_params,
data=request.payload,
cookies=session.cookies if session else None,
Expand Down Expand Up @@ -164,7 +163,7 @@ async def send_request(
response = await client.request(
url=url,
method=method.upper(), # type: ignore # curl-cffi requires uppercase method
headers=dict(headers) if headers else None,
headers=headers,
params=query_params,
data=payload,
cookies=session.cookies if session else None,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/basic_crawler/test_basic_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import pytest

from crawlee import ConcurrencySettings, EnqueueStrategy, Glob
from crawlee._request import BaseRequestData, HttpHeaders, Request
from crawlee._types import AddRequestsKwargs, BasicCrawlingContext
from crawlee._request import BaseRequestData, Request
from crawlee._types import AddRequestsKwargs, BasicCrawlingContext, HttpHeaders
from crawlee.basic_crawler import BasicCrawler
from crawlee.configuration import Configuration
from crawlee.errors import SessionError, UserDefinedErrorHandlerError
Expand Down

0 comments on commit 649eaa5

Please sign in to comment.