diff --git a/lyrebird/mitm/mitm_script.py b/lyrebird/mitm/mitm_script.py index c17afae82..cca6b311f 100644 --- a/lyrebird/mitm/mitm_script.py +++ b/lyrebird/mitm/mitm_script.py @@ -55,8 +55,6 @@ def to_mock_server(flow: http.HTTPFlow): flow.request.headers['Lyrebird-Client-Address'] = address flow.request.headers['Mitmproxy-Proxy'] = address - flow.request.headers['Proxy-Raw-Headers'] = json.dumps({name: flow.request.headers[name] - for name in flow.request.headers if name.lower() not in ('host', 'proxy-raw-headers')}, ensure_ascii=False) def request(flow: http.HTTPFlow): diff --git a/lyrebird/mock/context.py b/lyrebird/mock/context.py index 377c98a93..273db5b86 100644 --- a/lyrebird/mock/context.py +++ b/lyrebird/mock/context.py @@ -155,6 +155,24 @@ def get_and_update_selected_filter_by_name(self, filter_name): EMIT_INTERVAL = 1 last_emit_time = {} + +""" +request header may casuse exception: +- Lyrebird internal header. eg. proxy-raw-headers +- The header should change each forwarding. eg. Host +""" +INTERNAL_PROXY_HEADERS = ( + 'lyrebird-client-address', + 'proxy-raw-headers' +) +SINGLE_FORWARD_IDENTIFY_HEADERS = ( + 'cache-control', + 'host', + 'transfer-encoding' +) +LYREBIRD_UNPROXY_HEADERS = INTERNAL_PROXY_HEADERS + SINGLE_FORWARD_IDENTIFY_HEADERS + + application = Application() def make_ok_response(**kwargs): diff --git a/lyrebird/mock/extra_mock_server/server.py b/lyrebird/mock/extra_mock_server/server.py index 5be10b4d9..f2b04bb32 100644 --- a/lyrebird/mock/extra_mock_server/server.py +++ b/lyrebird/mock/extra_mock_server/server.py @@ -11,6 +11,7 @@ from multidict import CIMultiDict from lyrebird.mock.extra_mock_server.lyrebird_proxy_protocol import LyrebirdProxyContext +from lyrebird.mock.context import SINGLE_FORWARD_IDENTIFY_HEADERS from lyrebird import log logger = None @@ -26,7 +27,6 @@ def is_filtered(context: LyrebirdProxyContext): allow list like ''' - global lb_config filters = lb_config.get('proxy.filters', []) for _filter in filters: if re.search(_filter, context.origin_url): @@ -34,17 +34,6 @@ def is_filtered(context: LyrebirdProxyContext): return False -def make_raw_headers_line(request: web.Request): - raw_headers = {} - for k, v in request.raw_headers: - raw_header_name = k.decode() - raw_header_value = v.decode() - if raw_header_name.lower() in ['cache-control', 'host', 'transfer-encoding', 'proxy-raw-headers']: - continue - raw_headers[raw_header_name] = raw_header_value - return json.dumps(raw_headers, ensure_ascii=False) - - def upgrade_request_report(context: LyrebirdProxyContext): if not context.request.headers.get('upgrade'): return @@ -55,16 +44,12 @@ def upgrade_request_report(context: LyrebirdProxyContext): def make_request_headers(context: LyrebirdProxyContext, is_proxy): - headers = {k: v for k, v in context.request.headers.items() if k.lower() not in [ - 'cache-control', 'host', 'transfer-encoding']} + headers = {k: v for k, v in context.request.headers.items() + if k.lower() not in SINGLE_FORWARD_IDENTIFY_HEADERS} if is_proxy: - if 'Proxy-Raw-Headers' in context.request.headers: - del headers['Proxy-Raw-Headers'] if 'Lyrebird-Client-Address' in context.request.headers: del headers['Lyrebird-Client-Address'] else: - if 'Proxy-Raw-Headers' not in context.request.headers: - headers['Proxy-Raw-Headers'] = make_raw_headers_line(context.request) if 'Lyrebird-Client-Address' not in context.request.headers: headers['Lyrebird-Client-Address'] = context.request.remote return headers diff --git a/lyrebird/mock/handlers/handler_context.py b/lyrebird/mock/handlers/handler_context.py index 182be4bcd..d7b68c8ce 100644 --- a/lyrebird/mock/handlers/handler_context.py +++ b/lyrebird/mock/handlers/handler_context.py @@ -8,7 +8,9 @@ from lyrebird import utils from lyrebird import application from lyrebird.log import get_logger +from lyrebird.utils import CaseInsensitiveDict from lyrebird.mock.blueprints.apis.bandwidth import config +from lyrebird.mock.context import LYREBIRD_UNPROXY_HEADERS from urllib.parse import urlparse, unquote from .http_data_helper import DataHelper from .http_header_helper import HeadersHelper @@ -66,8 +68,14 @@ def _parse_request(self): raw_headers = None # Read raw headers + # Proxy-Raw-Headers will be removed in future if 'Proxy-Raw-Headers' in self.request.headers: raw_headers = json.loads(self.request.headers['Proxy-Raw-Headers']) + elif '_raw_header' in self.request.environ: + raw_headers = CaseInsensitiveDict(self.request.environ['_raw_header']) + for key in LYREBIRD_UNPROXY_HEADERS: + if key in raw_headers: + del raw_headers[key] # parse path request_info = self._read_origin_request_info_from_url() @@ -208,7 +216,7 @@ def get_request_headers(self): headers = {} unproxy_headers = application.config.get('proxy.ignored_headers', {}) for name, value in self.flow['request']['headers'].items(): - if not value or name in ['Cache-Control', 'Host', 'Transfer-Encoding']: + if not value or name.lower() in LYREBIRD_UNPROXY_HEADERS: continue if name in unproxy_headers and unproxy_headers[name] in value: continue diff --git a/lyrebird/mock/handlers/http_data_helper/__init__.py b/lyrebird/mock/handlers/http_data_helper/__init__.py index 570fd8e3d..84b4247b7 100644 --- a/lyrebird/mock/handlers/http_data_helper/__init__.py +++ b/lyrebird/mock/handlers/http_data_helper/__init__.py @@ -1,6 +1,7 @@ from collections import OrderedDict from . import content_encoding, content_type from lyrebird.utils import CaseInsensitiveDict +from lyrebird.mock.context import LYREBIRD_UNPROXY_HEADERS import json origin2flow_handlers = OrderedDict({ @@ -25,10 +26,15 @@ def origin2flow(origin_obj, output=None, chain=None): if not _data: return - # Read raw headers, support the request from extra mock 9999 port if 'Proxy-Raw-Headers' in origin_obj.headers: _origin_headers = json.loads(origin_obj.headers['Proxy-Raw-Headers']) raw_headers = CaseInsensitiveDict(_origin_headers) + # Read raw headers, support the request from extra mock 9999 port + elif hasattr(origin_obj, 'environ') and '_raw_header' in origin_obj.environ: + raw_headers = CaseInsensitiveDict(origin_obj.environ['_raw_header']) + for key in LYREBIRD_UNPROXY_HEADERS: + if key in raw_headers: + del raw_headers[key] else: raw_headers = origin_obj.headers @@ -69,10 +75,15 @@ def origin2string(origin_obj, output=None): if not _data: return - # Read raw headers, support the request from extra mock 9999 port if 'Proxy-Raw-Headers' in origin_obj.headers: _origin_headers = json.loads(origin_obj.headers['Proxy-Raw-Headers']) raw_headers = CaseInsensitiveDict(_origin_headers) + # Read raw headers, support the request from extra mock 9999 port + elif hasattr(origin_obj, 'environ') and '_raw_header' in origin_obj.environ: + raw_headers = CaseInsensitiveDict(origin_obj.environ['_raw_header']) + for key in LYREBIRD_UNPROXY_HEADERS: + if key in raw_headers: + del raw_headers[key] else: raw_headers = origin_obj.headers diff --git a/lyrebird/mock/handlers/http_header_helper/__init__.py b/lyrebird/mock/handlers/http_header_helper/__init__.py index 4fb9fd247..47a9e9092 100644 --- a/lyrebird/mock/handlers/http_header_helper/__init__.py +++ b/lyrebird/mock/handlers/http_header_helper/__init__.py @@ -1,6 +1,7 @@ from collections import OrderedDict from .content_length import ContentLengthHandler from ..duplicate_header_key_handler import DuplicateHeaderKeyHandler +from lyrebird.utils import CaseInsensitiveDict origin2flow_handlers = OrderedDict({ }) @@ -41,4 +42,3 @@ def flow2origin(flow_obj, output=None, chain=None): output.headers = _headers else: return _headers - diff --git a/lyrebird/mock/mock_server.py b/lyrebird/mock/mock_server.py index 767c9585d..41aa2fa76 100644 --- a/lyrebird/mock/mock_server.py +++ b/lyrebird/mock/mock_server.py @@ -5,12 +5,14 @@ from .blueprints.ui import ui from .blueprints.core import core from flask_socketio import SocketIO +from werkzeug.serving import WSGIRequestHandler from ..version import VERSION from lyrebird.base_server import ThreadServer from lyrebird import application from lyrebird import log import sys import traceback +import functools """ Mock server @@ -121,3 +123,22 @@ def terminate(self): except Exception as e: pass print('CoreServer shutdown') + + +def monkey_patch_wsgi_request_handler(): + """ + environ of Werkzeug lost some Header capitalization information when processing request headers. + Although this is in compliance with RFC-7230, it may cause exceptions in some non-standard scenarios, + so it is recorded here to achieve compatibility. + """ + original_make_environ = WSGIRequestHandler.make_environ + + @functools.wraps(original_make_environ) + def patched_make_environ(self): + environ = original_make_environ(self) + environ['_raw_header'] = self.headers._headers + return environ + + WSGIRequestHandler.make_environ = patched_make_environ + +monkey_patch_wsgi_request_handler() diff --git a/lyrebird/utils.py b/lyrebird/utils.py index 9e70afdfb..bc9f744d1 100644 --- a/lyrebird/utils.py +++ b/lyrebird/utils.py @@ -347,10 +347,17 @@ class CaseInsensitiveDict(dict): def __init__(self, raw_dict=None): self.__key_map = {} - if raw_dict: + if not raw_dict: + return + if raw_dict and isinstance(raw_dict, dict): for k, v in raw_dict.items(): self.__setitem__(k, v) - + elif raw_dict and isinstance(raw_dict, list): + for k, v in raw_dict: + self.__setitem__(k, v) + else: + raise TypeError("Unexpected type for CaseInsensitiveDict") + def __getstate__(self): return { 'key_map': self.__key_map, diff --git a/setup.py b/setup.py index e957d9eda..a6764e45d 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,7 @@ import runpy import os +import sys +from packaging import version from setuptools import setup, find_packages here = os.path.abspath(os.path.dirname(__file__)) @@ -9,10 +11,30 @@ )["VERSION"] -def read_requirements(name): - with open(os.path.join(here, name), encoding='utf-8') as f: - require_str = f.read() - return require_str.split() +def read_requirements(file_path): + with open(file_path, encoding='utf-8') as f: + return [ + line.strip().split(';')[0].strip() + for line in f + if line.strip() and not line.startswith('#') and ( + '; python_version' not in line or + check_version_condition(line.split(';')[1].strip()) + ) + ] + +def check_version_condition(condition): + if not condition.startswith('python_version'): + return True + op, ver = condition.split(' ', 1)[1].split(' ') + current_ver = version.parse('.'.join(map(str, sys.version_info[:2]))) + ver = version.parse(ver.strip('"')) + return { + '>': current_ver > ver, + '>=': current_ver >= ver, + '<': current_ver < ver, + '<=': current_ver <= ver, + '==': current_ver == ver + }.get(op, False) with open(os.path.join(here, 'README.md'), encoding='utf-8') as f: