From 3b11287ef354c6af5761814413678af1a6e3d399 Mon Sep 17 00:00:00 2001 From: noO0oOo0ob <38344038+noO0oOo0ob@users.noreply.github.com> Date: Mon, 20 May 2024 19:40:58 +0800 Subject: [PATCH] Feature/add enable mulitprocess (#842) * change post processing * version * add report in eventserver * fix bug * fix unit test * add enable multiprocess * fix bug * fix bug * fix e2e and bug * fix e2e * fix e2e * dockerfile * add config * remove usless code * add report * bugfix * fixbug * reportor add threadpool * fix activate_group * fix activate_group * 1.add redishook 2.add dm_group patch 3.add signal IGN in processserver * supprt python38 syncmanager * 1 * review * dockerfile * version * recover pip install image source * deepcopy in channels, origin in any * change deepcopy only in any channel * recover deepcopy --------- Co-authored-by: wujiasheng03 --- Dockerfile | 6 +- e2e_tests/conftest.py | 7 +- lyrebird/__init__.py | 9 +- lyrebird/application.py | 83 ++++- lyrebird/base_server.py | 109 ++++-- lyrebird/checker/event.py | 20 +- lyrebird/compatibility.py | 234 +++++++++++++ lyrebird/config/__init__.py | 11 + lyrebird/db/database_server.py | 12 +- lyrebird/event.py | 357 +++++++++++++++----- lyrebird/log.py | 40 ++- lyrebird/manager.py | 52 ++- lyrebird/mitm/proxy_server.py | 6 +- lyrebird/mock/dm/__init__.py | 24 ++ lyrebird/mock/extra_mock_server/__init__.py | 6 +- lyrebird/mock/extra_mock_server/server.py | 6 +- lyrebird/mock/mock_server.py | 7 +- lyrebird/notice_center.py | 6 +- lyrebird/plugins/plugin_manager.py | 11 +- lyrebird/reporter.py | 105 ++++-- lyrebird/task.py | 4 +- lyrebird/utils.py | 244 ++++++++++++- lyrebird/version.py | 2 +- requirements.txt | 1 + requirements.txt.lock | 1 + tests/conftest.py | 10 + tests/test_common_api.py | 1 + tests/test_conf_api.py | 1 + tests/test_db.py | 2 + tests/test_dm_api.py | 1 + tests/test_dm_config.py | 1 - tests/test_event.py | 8 +- tests/test_mock_server.py | 1 + tests/test_plugins.py | 1 + 34 files changed, 1181 
insertions(+), 208 deletions(-) create mode 100644 lyrebird/compatibility.py diff --git a/Dockerfile b/Dockerfile index 95808d411..47dd3a44b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,9 +13,9 @@ WORKDIR /usr/src COPY --from=nodebuilder /usr/src/lyrebird/client/ /usr/src/lyrebird/client/ RUN if [[ -n "$USE_MIRROR" ]] ; then sed -i 's/dl-cdn.alpinelinux.org/mirrors.ustc.edu.cn/g' /etc/apk/repositories ; fi \ && apk update \ - && apk add --no-cache build-base jpeg-dev zlib-dev libffi-dev openssl-dev \ - && if [[ -n "$USE_MIRROR" ]] ; then pip install --upgrade pip && pip install --no-cache-dir . facebook-wda==0.8.1 jsonschema -i https://pypi.douban.com/simple ; else pip install --upgrade pip && pip install --no-cache-dir . facebook-wda==0.8.1 jsonschema ; fi \ - && pip install werkzeug==2.2.2 mitmproxy -t /usr/local/mitmenv \ + && apk add --no-cache build-base jpeg-dev zlib-dev libffi-dev openssl-dev redis \ + && if [[ -n "$USE_MIRROR" ]] ; then pip install --upgrade pip -i https://pypi.douban.com/simple && pip install --no-cache-dir . facebook-wda==0.8.1 jsonschema redis -i https://pypi.douban.com/simple ; else pip install --upgrade pip && pip install --no-cache-dir . 
facebook-wda==0.8.1 jsonschema redis ; fi \ + && if [[ -n "$USE_MIRROR" ]] ; then pip install werkzeug==2.2.2 mitmproxy -t /usr/local/mitmenv -i https://pypi.douban.com/simple ; else pip install werkzeug==2.2.2 mitmproxy -t /usr/local/mitmenv ; fi \ && rm -rf /usr/src \ && apk del --purge build-base jpeg-dev zlib-dev libffi-dev openssl-dev diff --git a/e2e_tests/conftest.py b/e2e_tests/conftest.py index 6f444de3f..6e58c7fb2 100644 --- a/e2e_tests/conftest.py +++ b/e2e_tests/conftest.py @@ -1,4 +1,5 @@ import os +import sys import signal import pytest import time @@ -89,7 +90,7 @@ def _find_free_port(self): def start(self, checker_path=[]): - cmdline = f'python3 -m lyrebird -b -v --no-mitm --mock {self.port} --extra-mock {self.extra_mock_port}' + cmdline = f'{sys.executable} -m lyrebird -b -v --no-mitm --mock {self.port} --extra-mock {self.extra_mock_port}' for path in checker_path: cmdline = cmdline + f' --script {path}' self.lyrebird_process = subprocess.Popen(cmdline, shell=True, start_new_session=True) @@ -103,9 +104,9 @@ def start(self, checker_path=[]): def stop(self): if self.lyrebird_process: try: - os.killpg(self.lyrebird_process.pid, signal.SIGTERM) + os.killpg(self.lyrebird_process.pid, signal.SIGINT) except PermissionError: - os.kill(self.lyrebird_process.pid, signal.SIGTERM) + os.kill(self.lyrebird_process.pid, signal.SIGINT) _wait_exception(requests.get, args=[self.api_status]) self.lyrebird_process = None diff --git a/lyrebird/__init__.py b/lyrebird/__init__.py index 3ab7881bd..bc4ba2d4d 100644 --- a/lyrebird/__init__.py +++ b/lyrebird/__init__.py @@ -33,7 +33,7 @@ def emit(event, *args, **kwargs): context.application.socket_io.emit(event, *args, **kwargs) -def subscribe(channel, func, *args, **kwargs): +def subscribe(channel, func, name='', *args, **kwargs): """ 订阅信号 @@ -42,7 +42,12 @@ def subscribe(channel, func, *args, **kwargs): :param sender: 信号发送者标识 """ # context.application.event_bus.subscribe(channel, func) - 
application.server['event'].subscribe(channel, func, *args, **kwargs) + func_info = { + 'name': name, + 'channel': channel, + 'func': func + } + application.server['event'].subscribe(func_info, *args, **kwargs) def publish(channel, event, *args, **kwargs): diff --git a/lyrebird/application.py b/lyrebird/application.py index 095e1a97a..7b32f0454 100644 --- a/lyrebird/application.py +++ b/lyrebird/application.py @@ -1,5 +1,7 @@ import webbrowser +import multiprocessing +from queue import Queue from flask import jsonify from functools import reduce @@ -75,6 +77,74 @@ def start_server(): def stop_server(): for name in server: server[name].stop() + sync_manager.broadcast_to_queues(None) + +def terminate_server(): + for name in server: + server[name].terminate() + + +class SyncManager(): + def __init__(self) -> None: + global sync_namespace + self.manager = multiprocessing.Manager() + self.async_objs = { + 'manager_queues': [], + 'multiprocessing_queues': [], + 'namespace': [], + 'locks': [] + } + sync_namespace = self.get_namespace() + + def get_namespace(self): + namespace = self.manager.Namespace() + self.async_objs['namespace'].append(namespace) + return namespace + + def get_queue(self): + queue = self.manager.Queue() + self.async_objs['manager_queues'].append(queue) + return queue + + def get_thread_queue(self): + queue = Queue() + return queue + + def get_multiprocessing_queue(self): + queue = multiprocessing.Queue() + self.async_objs['multiprocessing_queues'].append(queue) + return queue + + def get_lock(self): + lock = multiprocessing.Lock() + self.async_objs['locks'].append(lock) + return lock + + def broadcast_to_queues(self, msg): + for q in self.async_objs['multiprocessing_queues']: + q.put(msg) + for q in self.async_objs['manager_queues']: + q.put(msg) + + def destory(self): + for q in self.async_objs['multiprocessing_queues']: + q.close() + del q + for q in self.async_objs['manager_queues']: + q._close() + del q + for ns in self.async_objs['namespace']: + 
del ns + for lock in self.async_objs['locks']: + del lock + self.manager.shutdown() + self.manager.join() + self.manager = None + self.async_objs = None + + +sync_manager = {} +sync_namespace = {} class ConfigProxy: @@ -89,7 +159,10 @@ def __getitem__(self, k): return _cm.config[k] def raw(self): - return _cm.config + if hasattr(_cm.config, 'raw'): + return _cm.config.raw() + else: + return _cm.config config = ConfigProxy() @@ -151,6 +224,8 @@ def status_listener(event): module_status = system.get('status') if module_status == 'READY': status_checkpoints[module] = True + else: + status_checkpoints[module] = False is_all_status_checkpoints_ok = reduce(lambda x, y: x and y, status_checkpoints.values()) if is_all_status_checkpoints_ok: @@ -165,7 +240,11 @@ def status_listener(event): webbrowser.open(f'http://localhost:{config["mock.port"]}') def process_status_listener(): - server['event'].subscribe('system', status_listener) + server['event'].subscribe({ + 'name': 'status_listener', + 'channel': 'system', + 'func': status_listener + }) def status_ready(): diff --git a/lyrebird/base_server.py b/lyrebird/base_server.py index 79863e5dc..2fd276eb1 100644 --- a/lyrebird/base_server.py +++ b/lyrebird/base_server.py @@ -2,12 +2,21 @@ Base threading server class """ +import inspect from threading import Thread -from multiprocessing import Process, Queue +from multiprocessing import Process from lyrebird import application -service_msg_queue = Queue() +def check_process_server_run_function_compatibility(function): + # Check whether the run method is an old or new version by params. 
+ if len(inspect.signature(function).parameters) == 4 and list(inspect.signature(function).parameters.keys())[0] == 'async_obj': + return True + else: + return False + + +service_msg_queue = None class ProcessServer: @@ -16,54 +25,80 @@ def __init__(self): self.running = False self.name = None self.event_thread = None + self.async_obj = {} self.args = [] self.kwargs = {} - def run(self, msg_queue, config, log_queue, *args, **kwargs): + def run(self, async_obj, config, *args, **kwargs): ''' - msg_queue - message queue for process server and main process - - #1. Send event to main process, - { - "type": "event", - "channel": "", - "content": {} - } - - #2. Send message to frontend - support channel: msgSuccess msgInfo msgError - { - "type": "ws", - "channel": "", - "content": "" - } - - config - lyrebird config dict - - log_queue - send log msg to logger process + async_obj is a dict + used to pass in all objects used for synchronization/communication between multiple processes + Usually msg_queue, config and log_queue is included + msg_queue: + message queue for process server and main process + + #1. Send event to main process, + { + "type": "event", + "channel": "", + "content": {} + } + + #2. Send message to frontend + support channel: msgSuccess msgInfo msgError + { + "type": "ws", + "channel": "", + "content": "" + } + + config: + lyrebird config dict + + log_queue: + send log msg to logger process ''' pass def start(self): if self.running: return + + from lyrebird.log import get_logger + logger = get_logger() global service_msg_queue - config = application.config.raw() + if service_msg_queue is None: + service_msg_queue = application.sync_manager.get_multiprocessing_queue() + config = application._cm.config logger_queue = application.server['log'].queue - self.server_process = Process(group=None, target=self.run, - args=[service_msg_queue, config, logger_queue, self.args], - kwargs=self.kwargs, - daemon=True) + + # run method has too many arguments. 
Merge the msg_queue, log_queue and so on into async_obj + # This code is used for compatibility with older versions of the run method in the plugin + # This code should be removed after all upgrades have been confirmed + if check_process_server_run_function_compatibility(self.run): + self.async_obj['logger_queue'] = logger_queue + self.async_obj['msg_queue'] = service_msg_queue + self.server_process = Process(group=None, target=self.run, + args=[self.async_obj, config, self.args], + kwargs=self.kwargs, + daemon=True) + else: + logger.warning(f'The run method in {type(self).__name__} is an old parameter format that will be removed in the future') + self.server_process = Process(group=None, target=self.run, + args=[service_msg_queue, config, logger_queue, self.args], + kwargs=self.kwargs, + daemon=True) self.server_process.start() self.running = True def stop(self): + self.running = False + + def terminate(self): if self.server_process: self.server_process.terminate() + self.server_process.join() self.server_process = None @@ -84,6 +119,9 @@ def start(self, *args, **kwargs): def stop(self): self.running = False # TODO terminate self.server_thread + + def terminate(self): + pass def run(self): """ @@ -100,16 +138,23 @@ def start(self, *args, **kwargs): def stop(self): pass + def terminate(self): + pass + class MultiProcessServerMessageDispatcher(ThreadServer): def run(self): global service_msg_queue + if service_msg_queue is None: + service_msg_queue = application.sync_manager.get_multiprocessing_queue() emit = application.server['mock'].socket_io.emit publish = application.server['event'].publish - while True: + while self.running: msg = service_msg_queue.get() + if msg is None: + break type = msg.get('type') if type == 'event': channel = msg.get('channel') diff --git a/lyrebird/checker/event.py b/lyrebird/checker/event.py index 39ddcdb11..9ac547004 100644 --- a/lyrebird/checker/event.py +++ b/lyrebird/checker/event.py @@ -5,32 +5,38 @@ class CheckerEventHandler: - 
def __call__(self, channel, *args, **kw): + def __call__(self, channel, process=True, *args, **kw): def func(origin_func): if not checker.scripts_tmp_storage.get(checker.TYPE_EVENT): checker.scripts_tmp_storage[checker.TYPE_EVENT] = [] checker.scripts_tmp_storage[checker.TYPE_EVENT].append({ 'name': origin_func.__name__, + 'origin': origin_func.__code__.co_filename, 'func': origin_func, - 'channel': channel + 'channel': channel, + 'process': process }) return origin_func return func def issue(self, title, message): + from lyrebird import application notice = { "title": title, "message": message } - self.check_notice(notice) + self.__class__.check_notice(notice) application.server['event'].publish('notice', notice) def publish(self, channel, message, *args, **kwargs): + from lyrebird import application if channel == 'notice': - self.check_notice(message) + self.__class__.check_notice(message) application.server['event'].publish(channel, message, *args, **kwargs) - def check_notice(self, notice): + @staticmethod + def check_notice(notice): + from lyrebird import application stack = inspect.stack() script_path = stack[2].filename script_name = script_path[script_path.rfind('/') + 1:] @@ -40,10 +46,10 @@ def check_notice(self, notice): @staticmethod def register(func_info): - application.server['event'].subscribe(func_info['channel'], func_info['func']) + application.server['event'].subscribe(func_info) @staticmethod def unregister(func_info): - application.server['event'].unsubscribe(func_info['channel'], func_info['func']) + application.server['event'].unsubscribe(func_info) event = CheckerEventHandler() diff --git a/lyrebird/compatibility.py b/lyrebird/compatibility.py new file mode 100644 index 000000000..b092c0eb1 --- /dev/null +++ b/lyrebird/compatibility.py @@ -0,0 +1,234 @@ +import sys +import platform +import importlib.util +from lyrebird import log +from functools import wraps +from inspect import signature +from multiprocessing import managers +from 
multiprocessing.managers import Namespace + +logger = log.get_logger() + + +application_white_map = { + 'config', + '_cm' +} + + +context_white_map = { + 'application.data_manager.activated_data', + 'application.data_manager.activated_group', + 'application.data_manager.async_activate_group' +} + + +decorator_compat_code = ''' +def jit(*args, **kwargs): + def decorator(func): + print(f"{func.__name__} run in a cover function") + return func + return decorator +''' + +PYTHON_MIN_VERSION = (3, 8, 0) +PYTHON_MAX_VERSION = (3, 11, float('inf')) + + +''' +Python3.8 and earlier do not support managing SyncManger managed objects +github original issue: https://github.com/python/cpython/pull/4819 +''' +orig_AutoProxy = managers.AutoProxy + +@wraps(managers.AutoProxy) +def AutoProxy(*args, incref=True, manager_owned=False, **kwargs): + # Create the autoproxy without the manager_owned flag, then + # update the flag on the generated instance. If the manager_owned flag + # is set, `incref` is disabled, so set it to False here for the same + # result. 
+ autoproxy_incref = False if manager_owned else incref + proxy = orig_AutoProxy(*args, incref=autoproxy_incref, **kwargs) + proxy._owned_by_manager = manager_owned + return proxy + + +def compat_async_manager_to_python_3_8(): + if "manager_owned" in signature(managers.AutoProxy).parameters: + return + + logger.debug("Patching multiprocessing.managers.AutoProxy to add manager_owned") + managers.AutoProxy = AutoProxy + + # re-register any types already registered to SyncManager without a custom + # proxy type, as otherwise these would all be using the old unpatched AutoProxy + SyncManager = managers.SyncManager + registry = managers.SyncManager._registry + for typeid, (callable, exposed, method_to_typeid, proxytype) in registry.items(): + if proxytype is not orig_AutoProxy: + continue + create_method = hasattr(managers.SyncManager, typeid) + SyncManager.register( + typeid, + callable=callable, + exposed=exposed, + method_to_typeid=method_to_typeid, + create_method=create_method, + ) + + +def import_compat_util(module_name:str, module_content:list): + module_spec = importlib.util.spec_from_loader(module_name, loader=None, origin='string', is_package=False) + module_obj = importlib.util.module_from_spec(module_spec) + for code in module_content: + exec(code, module_obj.__dict__) + sys.modules[module_name] = module_obj + + +def compat_redis_check(): + try: + import redis + except Exception as e: + logger.error(f'redis import failed. 
Please check that the library is installed correctly in your python environment') + return False + return True + + +def compat_python_version_check(): + version = platform.python_version_tuple() + major = int(version[0]) + minor = int(version[1]) + minor_minor = int(version[2]) + if major < PYTHON_MIN_VERSION[0] or \ + (major == PYTHON_MIN_VERSION[0] and minor < PYTHON_MIN_VERSION[1]) or \ + (major == PYTHON_MIN_VERSION[0] and minor == PYTHON_MIN_VERSION[1] and minor_minor == PYTHON_MIN_VERSION[2]): + msg = ( + 'The python version is too early. Please use Python version ' + f'{PYTHON_MIN_VERSION[0]}.{PYTHON_MIN_VERSION[1]}.{PYTHON_MIN_VERSION[2]} or later.' + ) + logger.error(msg) + return False + if major > PYTHON_MAX_VERSION[0] or \ + (major == PYTHON_MAX_VERSION[0] and minor > PYTHON_MAX_VERSION[1]) or \ + (major == PYTHON_MAX_VERSION[0] and minor == PYTHON_MAX_VERSION[1] and minor_minor > PYTHON_MAX_VERSION[2]): + msg = ( + 'python version is too high, Lyrebird is not supported,' + 'the current Lyrebird support version is ' + f'{PYTHON_MAX_VERSION[0]}.{PYTHON_MAX_VERSION[1]}.{"x" if isinstance(PYTHON_MAX_VERSION[2], float) else PYTHON_MAX_VERSION[2]}.' 
+ ) + logger.warning(msg) + return True + + +def prepare_application_for_monkey_patch() -> Namespace: + from lyrebird import application, context + namespace = application.sync_manager.get_namespace() + namespace.application = ProcessApplicationInfo(application, application_white_map) + namespace.context = ProcessApplicationInfo(context, context_white_map) + namespace.queue = application.server['event'].event_queue + return namespace + + +def monkey_patch_application(async_obj, async_funcs=None, async_values=None): + import lyrebird + from lyrebird.event import EventServer + from lyrebird import event + + msg_queue = async_obj['msg_queue'] + process_namespace = async_obj['process_namespace'] + + lyrebird.application = process_namespace.application + lyrebird.application.config = process_namespace.application._cm.config + lyrebird.context = process_namespace.context + + monkey_patch_datamanager() + + if async_funcs: + checker_event_server = EventServer(True) + checker_event_server.event_queue = msg_queue + lyrebird.application['server'] = ProcessApplicationInfo() + lyrebird.application.server['event'] = checker_event_server + checker_event_server.__class__.publish = async_funcs['publish'] + event.__class__.publish = async_funcs['publish'] + event.__class__.issue = async_funcs['issue'] + + +def monkey_patch_datamanager(): + from lyrebird import context + context.application.data_manager.activate_group = context.application.data_manager.async_activate_group + + +def monkey_patch_publish(channel, message, publish_queue, *args, **kwargs): + from lyrebird.event import EventServer + from lyrebird.checker.event import CheckerEventHandler + if channel == 'notice': + CheckerEventHandler.check_notice(message) + + event_id, channel, message = EventServer.get_publish_message(channel, message) + publish_queue.put((event_id, channel, message, args, kwargs)) + + +def monkey_patch_issue(title, message, publish_queue, *args, **kwargs): + from lyrebird.event import EventServer + from 
lyrebird.checker.event import CheckerEventHandler + notice = { + "title": title, + "message": message + } + CheckerEventHandler.check_notice(notice) + + event_id, channel, message = EventServer.get_publish_message('notice', notice) + publish_queue.put((event_id, channel, message, args, kwargs)) + + +class ProcessApplicationInfo(dict): + + def __init__(self, data=None, white_map={}): + super().__init__() + for path in white_map: + value = self._get_value_from_path(data, path) + if value is not None: + self._set_value_to_path(path, value) + + def _get_value_from_path(self, data, path): + keys = path.split('.') + value = data + for key in keys: + value = getattr(value, key) + if value is None: + return None + return value + + def _set_value_to_path(self, path, value): + keys = path.split('.') + current_dict = self + for key in keys[:-1]: + if key not in current_dict: + current_dict[key] = ProcessApplicationInfo() + current_dict = current_dict[key] + current_dict[keys[-1]] = value + + + def __getattr__(self, item): + value = self + for key in item.split('.'): + value = value.get(key) + if value is None: + break + return value + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) + + def __setattr__(self, key, value): + self[key] = value + + def __delattr__(self, item): + del self[item] + + +compat_python_version_check() +compat_async_manager_to_python_3_8() diff --git a/lyrebird/config/__init__.py b/lyrebird/config/__init__.py index 90216dcc1..278bf573e 100644 --- a/lyrebird/config/__init__.py +++ b/lyrebird/config/__init__.py @@ -8,6 +8,7 @@ from lyrebird import log as nlog from lyrebird import application +from lyrebird.utils import RedisDict from lyrebird.config.diff_mode import SettingDiffMode from lyrebird.config.checker_switch import SettingCheckerSwitch @@ -63,6 +64,16 @@ def __init__(self, conf_path_list=None, custom_conf=None): if custom_conf: self.update_conf_custom(custom_conf) + if 
self.config.get('enable_multiprocess', False): + try: + self.config = RedisDict(data=self.config, + host=self.config.get('redis_host', '127.0.0.1'), + port=self.config.get('redis_port', 6379), + db=self.config.get('redis_db', 0)) + except Exception as e: + self.config['enable_multiprocess'] = False + logger.error(f'Start enable multiprocess failed, Redis connection error:\n{e}') + self.initialize_personal_config() def update_conf_source(self, path): diff --git a/lyrebird/db/database_server.py b/lyrebird/db/database_server.py index 20a10a2d4..7a3db972e 100644 --- a/lyrebird/db/database_server.py +++ b/lyrebird/db/database_server.py @@ -4,7 +4,6 @@ import traceback import time import copy -from queue import Queue from pathlib import Path from lyrebird import application from lyrebird import log @@ -59,10 +58,15 @@ def __init__(self, path=None): logger.warning("Restarting will delete the broken database by default, historical events in inspector-pro will be lost, please be careful.") # init queue - self.storage_queue = Queue() + self.storage_queue = application.sync_manager.get_queue() # subscribe all channel - application.server['event'].subscribe('any', self.event_receiver) + application.server['event'].subscribe({ + 'name': 'event_receiver', + 'origin': self, + 'channel': 'any', + 'func': self.event_receiver + }) def auto_alter_tables(self, engine): metadata = MetaData() @@ -153,6 +157,8 @@ def run(self): while self.running: try: event = self.storage_queue.get() + if event is None: + break session.add(event) session.commit() context.emit('db_action', 'add event log') diff --git a/lyrebird/event.py b/lyrebird/event.py index 16fdf56a5..ac53ff88d 100644 --- a/lyrebird/event.py +++ b/lyrebird/event.py @@ -4,20 +4,33 @@ Worked as a backgrund thread Run events handler and background task worker """ -from queue import Queue -from concurrent.futures import ThreadPoolExecutor -import traceback -import inspect +import os +import imp +import sys +import copy import uuid 
import time import copy -from lyrebird.base_server import ThreadServer +import types +import signal +import pickle +import inspect +import importlib +import functools +import traceback +import setuptools +from concurrent.futures import ThreadPoolExecutor +from lyrebird.base_server import ThreadServer, ProcessServer +from lyrebird.compatibility import prepare_application_for_monkey_patch, monkey_patch_application, monkey_patch_issue, monkey_patch_publish from lyrebird import application from lyrebird.mock import context -from lyrebird.log import get_logger +from lyrebird import log +from pathlib import Path -logger = get_logger() +logger = log.get_logger() +# only report the checker which duration more the 5s +LYREBIRD_METRICS_REPORT_DURSTION = 5000 class InvalidMessage(Exception): @@ -34,25 +47,175 @@ def __init__(self, event_id, channel, message): self.channel = channel self.message = message + def __getstate__(self): + return pickle.dumps({ + 'event_id':self.id, + 'channel':self.channel, + 'message':self.message + }) -class EventServer(ThreadServer): + def __setstate__(self, state): + data = pickle.loads(state) + self.id = data['event_id'] + self.channel = data['channel'] + self.message = data['message'] + + +def import_func_module(path): + path = os.path.dirname(path) + packages = setuptools.find_packages(path) + for pkg in packages: + manifest_file = Path(path)/pkg/'manifest.py' + if not manifest_file.exists(): + continue + if pkg in sys.modules: + continue + sys.path.append(str(Path(path)/pkg)) + imp.load_package(pkg, Path(path)/pkg) + + +def import_func_from_file(filepath, func_name): + name = os.path.basename(filepath)[:-3] + if name in sys.modules: + module = sys.modules[name] + else: + # 从文件加载模块 + spec = importlib.util.spec_from_file_location(name, filepath) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + sys.modules[name] = module + return getattr(module, func_name) + + +def get_func_from_obj(obj, method_name): + 
return getattr(obj, method_name) + + +def get_callback_func(func_ori, func_name): + if isinstance(func_ori, str): + return import_func_from_file(func_ori, func_name) + elif isinstance(func_ori, object): + return get_func_from_obj(func_ori, func_name) + else: + logger.error(f'The source type of method {func_name} is invalid, exception method source: {func_ori}') + + +def callback_func_run_statistic(callback_fn, args, kwargs, report_info): + from lyrebird import application + event_start_time = time.time() + callback_fn(*args, **kwargs) + event_end_time = time.time() + event_duration = (event_end_time - event_start_time) * 1000 + # Report the operation of Event + # Prevent loop reporting, and only long time event(more than 5s) are reported + if event_duration < LYREBIRD_METRICS_REPORT_DURSTION: + return + if report_info['channel'] == 'lyrebird_metrics': + return + if not application.config.get('event.lyrebird_metrics_report', True): + return + application.server['event'].publish('lyrebird_metrics', { + 'sender': 'EventServer', + 'action': 'broadcast_handler', + 'duration': event_duration, + 'trace_info': str(report_info['trace_info']) + }) + + +class CustomExecuteServer(ProcessServer): + def __init__(self): + super().__init__() + self.event_thread_executor = None + + def run(self, async_obj, config, *args, **kwargs): + self.event_thread_executor = ThreadPoolExecutor(max_workers=async_obj['max_thread_workers']) + + signal.signal(signal.SIGINT, signal.SIG_IGN) + + for _, path in async_obj['plugins']: + import_func_module(path) + + log_queue = async_obj['logger_queue'] + process_queue = async_obj['process_queue'] + publish_queue = async_obj['publish_queue'] + + # monkey_patch is performed on the context content of the process to ensure + # that functions of Lyrebird can still be used in the process. 
+ + async_funcs = {} + async_funcs['publish'] = functools.partial(monkey_patch_publish, publish_queue=publish_queue) + async_funcs['issue'] = functools.partial(monkey_patch_issue, publish_queue=publish_queue) + monkey_patch_application(async_obj, async_funcs) + + log.init(config, log_queue) + self.running = True + + while self.running: + try: + msg = process_queue.get() + if not msg: + break + func_ori, func_name, callback_args, callback_kwargs, info = msg + callback_fn = get_callback_func(func_ori, func_name) + self.event_thread_executor.submit(callback_func_run_statistic, callback_fn, callback_args, callback_kwargs, info) + except Exception: + traceback.print_exc() + + def start(self): + plugins = application.server['plugin'].plugins + plugins = [(p_name, plugin.location) for p_name, plugin in plugins.items()] + self.async_obj['plugins'] = plugins + self.async_obj['max_thread_workers'] = application.config.get('event.multiprocess.thread_max_worker', 1) + super().start() + +class PublishServer(ThreadServer): def __init__(self): super().__init__() - self.event_queue = Queue() + self.publish_msg_queue = application.sync_manager.get_multiprocessing_queue() + + def run(self): + while self.running: + try: + msg = self.publish_msg_queue.get() + if not msg: + break + event_id, channel, message, args, kwargs = msg + application.server['event'].publish(channel, message, event_id=event_id, *args, **kwargs) + except Exception: + traceback.print_exc() + + +class EventServer(ThreadServer): + + async_starting = False + + def __init__(self, no_start = False): + super().__init__() self.state = {} self.pubsub_channels = {} # channel name is 'any'. 
Linstening on all channel self.any_channel = [] - self.broadcast_executor = ThreadPoolExecutor(thread_name_prefix='event-broadcast-') - self.only_report_channel = application.config.get('event.only_report_channel', []) - self.lyrebird_metrics_report = application.config.get('event.lyrebird_metrics_report', True) - - def broadcast_handler(self, callback_fn, event, args, kwargs): + self.process_executor_queue = None + self.event_queue = None + self.broadcast_executor = None + self.process_executor = None + self.publish_server = None + self.only_report_channel = None + if not no_start: + self.only_report_channel = application.config.get('event.only_report_channel', []) + self.process_executor_queue = application.sync_manager.get_multiprocessing_queue() + self.event_queue = application.sync_manager.get_queue() + self.broadcast_executor = ThreadPoolExecutor(thread_name_prefix='event-broadcast-') + self.process_executor = CustomExecuteServer() + self.publish_server = PublishServer() + + def broadcast_handler(self, func_info, event, args, kwargs, process_queue=None): """ """ - event_start_time = time.time() + callback_fn = func_info.get('func') + is_process = func_info.get('process') # Check func_sig = inspect.signature(callback_fn) @@ -60,6 +223,9 @@ def broadcast_handler(self, callback_fn, event, args, kwargs): if len(func_parameters) < 1 or func_parameters[0].default != inspect._empty: logger.error(f'Event callback function [{callback_fn.__name__}] need a argument for receiving event object') return + + # get enable multiprocess channel list + multiprocess_channel_list = application.config.get('event.multiprocess.channels', []) # Append event content to args callback_args = [] @@ -73,55 +239,68 @@ def broadcast_handler(self, callback_fn, event, args, kwargs): callback_kwargs['channel'] = event.channel if 'event_id' in func_sig.parameters: callback_kwargs['event_id'] = event.id + # add report info + info = dict() + info['trace_info'] = { + 'channel': event.channel, + 
'event_id': event.id, + 'callback_fn': callback_fn.__name__, + 'callback_kwargs': str(callback_kwargs) + } + info['channel'] = event.channel # Execute callback function try: - callback_fn(*callback_args, **callback_kwargs) + if EventServer.async_starting and is_process and isinstance(callback_fn, types.FunctionType) and event.channel in multiprocess_channel_list: + process_queue.put(( + func_info.get('origin'), + func_info.get('name'), + callback_args, + callback_kwargs, + info + )) + else: + callback_func_run_statistic(callback_fn, callback_args, callback_kwargs, info) except Exception: logger.error(f'Event callback function [{callback_fn.__name__}] error. {traceback.format_exc()}') - finally: - # Report the operation of Event - # Prevent loop reporting, and only time-consuming event(more than 1ms) are reported - if event.channel == 'lyrebird_metrics': - return - if not self.lyrebird_metrics_report: - return - event_end_time = time.time() - event_duration = (event_end_time - event_start_time) * 1000 - if event_duration > 1: - trace_info = { - 'channel': event.channel, - 'callback_fn': callback_fn.__name__, - 'callback_args': str(callback_args), - 'callback_kwargs': str(callback_kwargs) - } - self.publish('lyrebird_metrics', { - 'sender': 'EventServer', - 'action': 'broadcast_handler', - 'duration': event_duration, - 'trace_info': str(trace_info) - }) def run(self): while self.running: try: e = self.event_queue.get() + if not e: + break # Deep copy event for async event system e = copy.deepcopy(e) callback_fn_list = self.pubsub_channels.get(e.channel) if callback_fn_list: for callback_fn, args, kwargs in callback_fn_list: - self.broadcast_executor.submit(self.broadcast_handler, callback_fn, e, args, kwargs) + self.broadcast_executor.submit(self.broadcast_handler, callback_fn, e, args, kwargs, self.process_executor_queue) for callback_fn, args, kwargs in self.any_channel: self.broadcast_executor.submit(self.broadcast_handler, callback_fn, e, args, kwargs) except 
Exception: # empty event traceback.print_exc() + def async_start(self): + if not self.publish_server.running: + self.publish_server.start() + self.process_namespace = prepare_application_for_monkey_patch() + self.process_executor.async_obj['process_queue'] = self.process_executor_queue + self.process_executor.async_obj['process_namespace'] = self.process_namespace + self.process_executor.async_obj['publish_queue'] = self.publish_server.publish_msg_queue + self.process_executor.async_obj['eventserver'] = EventServer + self.process_executor.start() + EventServer.async_starting = True + def stop(self): - super().stop() self.publish('system', {'name': 'event.stop'}) + time.sleep(1) + super().stop() + self.process_executor.stop() + self.publish_server.stop() - def _check_message_format(self, message): + @staticmethod + def _check_message_format(message): """ Check if the message content is valid. Such as: 'message' value must be a string. @@ -132,7 +311,37 @@ def _check_message_format(self, message): if not isinstance(message_value, str): raise InvalidMessage('Value of key "message" must be a string.') - def publish(self, channel, message, state=False, *args, **kwargs): + @staticmethod + def get_publish_message(channel, message, event_id=None): + if not event_id: + # Make event id + event_id = str(uuid.uuid4()) + + # Make sure event is dict + if not isinstance(message, dict): + # Plugins send a array list as message, then set this message to raw property + _msg = {'raw': message} + message = _msg + + EventServer._check_message_format(message) + + message['channel'] = channel + message['id'] = event_id + message['timestamp'] = round(time.time(), 3) + + # Add event sender + stack = inspect.stack() + script_path = stack[2].filename + script_name = script_path[script_path.rfind('/') + 1:] + function_name = stack[2].function + sender_dict = { + "file": script_name, + "function": function_name + } + message['sender'] = sender_dict + return (event_id, channel, message) + + def 
publish(self, channel, message, state=False, event_id=None, *args, **kwargs): """ publish message @@ -145,31 +354,7 @@ def publish(self, channel, message, state=False, *args, **kwargs): if state is true, message will be kept as state """ - # Make event id - event_id = str(uuid.uuid4()) - - # Make sure event is dict - if not isinstance(message, dict): - # Plugins send a array list as message, then set this message to raw property - _msg = {'raw': message} - message = _msg - - self._check_message_format(message) - - message['channel'] = channel - message['id'] = event_id - message['timestamp'] = round(time.time(), 3) - - # Add event sender - stack = inspect.stack() - script_path = stack[2].filename - script_name = script_path[script_path.rfind('/') + 1:] - function_name = stack[2].function - sender_dict = { - "file": script_name, - "function": function_name - } - message['sender'] = sender_dict + event_id, channel, message = EventServer.get_publish_message(channel, message, event_id) if channel in self.pubsub_channels or channel not in self.only_report_channel: self.event_queue.put(Event(event_id, channel, message)) @@ -194,7 +379,7 @@ def publish(self, channel, message, state=False, *args, **kwargs): logger.debug(f'channel={channel} state={state}\nmessage:\n-----------\n{message}\n-----------\n') - def subscribe(self, channel, callback_fn, *args, **kwargs): + def subscribe(self, func_info, *args, **kwargs): """ Subscribe channel with a callback function That function will be called when a new message was published into it's channel @@ -202,25 +387,29 @@ def subscribe(self, channel, callback_fn, *args, **kwargs): callback function kwargs: channel=None receive channel name """ + channel = func_info['channel'] + if 'process' not in func_info: + func_info['process'] = True if channel == 'any': - self.any_channel.append([callback_fn, args, kwargs]) + self.any_channel.append([func_info, args, kwargs]) else: callback_fn_list = self.pubsub_channels.setdefault(channel, []) - 
callback_fn_list.append([callback_fn, args, kwargs]) + callback_fn_list.append([func_info, args, kwargs]) - def unsubscribe(self, channel, target_callback_fn, *args, **kwargs): + def unsubscribe(self, target_func_info, *args, **kwargs): """ Unsubscribe callback function from channel """ + channel = target_func_info['channel'] if channel == 'any': - for any_channel_fn, *_ in self.any_channel: - if target_callback_fn == any_channel_fn: - self.any_channel.remove([target_callback_fn, *_]) + for any_info, *_ in self.any_channel: + if target_func_info['func'] == any_info['func']: + self.any_channel.remove([target_func_info, *_]) else: callback_fn_list = self.pubsub_channels.get(channel) - for callback_fn, *_ in callback_fn_list: - if target_callback_fn == callback_fn: - callback_fn_list.remove([target_callback_fn, *_]) + for callback_fn_info, *_ in callback_fn_list: + if target_func_info['func'] == callback_fn_info['func']: + callback_fn_list.remove([target_func_info, *_]) class CustomEventReceiver: @@ -247,11 +436,19 @@ def func(origin_func): def register(self, event_bus): for listener in self.listeners: - event_bus.subscribe(listener['channel'], listener['func']) + event_bus.subscribe({ + 'name': 'CustomEventReceiver', + 'channel': listener['channel'], + 'func': listener['func'] + }) def unregister(self, event_bus): for listener in self.listeners: - event_bus.unsubscribe(listener['channel'], listener['func']) + event_bus.unsubscribe({ + 'name': 'CustomEventReceiver', + 'channel': listener['channel'], + 'func': listener['func'] + }) def publish(self, channel, message, *args, **kwargs): application.server['event'].publish(channel, message, *args, **kwargs) @@ -259,6 +456,6 @@ def publish(self, channel, message, *args, **kwargs): def issue(self, title, message): notice = { "title": title, - "message": message + "message": f'[CustomEventReceiver]: {message}' } application.server['event'].publish('notice', notice) diff --git a/lyrebird/log.py b/lyrebird/log.py index 
a5c55f717..3ebeaf367 100644 --- a/lyrebird/log.py +++ b/lyrebird/log.py @@ -1,10 +1,11 @@ import logging +from lyrebird import application from .base_server import ProcessServer -from multiprocessing import Queue, Lock from logging.handlers import TimedRotatingFileHandler from colorama import Fore, Style, Back from collections import namedtuple from pathlib import Path +import signal import os DEFAULT_LOG_PATH = '~/.lyrebird/lyrebird.log' LOGGER_INITED = False @@ -21,6 +22,7 @@ ) process = None +queue_handler = None def colorit(message, levelname): @@ -79,6 +81,7 @@ def check_path(path): def init(config, log_queue = None): global LOGGER_INITED + global queue_handler if LOGGER_INITED: return @@ -115,18 +118,18 @@ class LogServer(ProcessServer): def __init__(self): super().__init__() - self.queue = Queue() - self.log_process_lock = Lock() + self.queue = application.sync_manager.get_multiprocessing_queue() def __new__(cls, *args, **kwargs): if not cls._instance: cls._instance = super().__new__(cls, *args, **kwargs) return cls._instance - def run(self, msg_queue, config, log_queue, *args, **kwargs): - if not self.log_process_lock.acquire(timeout=10): - return - + def run(self, async_obj, config, *args, **kwargs): + log_queue = async_obj['logger_queue'] + + signal.signal(signal.SIGINT, signal.SIG_IGN) + logging.addLevelName(60, 'NOTICE') stream_handler = make_stream_handler() @@ -154,15 +157,32 @@ def run(self, msg_queue, config, log_queue, *args, **kwargs): if log_path and not check_path(log_path): lyrebird_logger.warning(f'Illegal log path: {log_path}, log file path have changed to the default path: {DEFAULT_LOG_PATH}') + + self.running = True - while True: + while self.running: try: log = log_queue.get() + if log is None: + break logger = logging.getLogger(log.name) logger.handle(log) except KeyboardInterrupt: - self.log_process_lock.release() - break + break + + def stop(self): + super().stop() + self.queue = None + logging.shutdown() + for _logger_name in 
['lyrebird', 'socketio', 'engineio', 'mock', 'werkzeug', 'flask']: + logger = logging.getLogger(_logger_name) + for handler in logger.handlers[:]: + logger.removeHandler(handler) + logger.setLevel(logging.CRITICAL) + + def terminate(self): + super().terminate() + logging.shutdown() def get_logger(): diff --git a/lyrebird/manager.py b/lyrebird/manager.py index 17e811371..c1423e058 100644 --- a/lyrebird/manager.py +++ b/lyrebird/manager.py @@ -26,12 +26,12 @@ from lyrebird.task import BackgroundTaskServer from lyrebird.base_server import MultiProcessServerMessageDispatcher from lyrebird.log import LogServer +from lyrebird.utils import RedisDict, RedisManager +from lyrebird.compatibility import compat_redis_check from lyrebird import utils - logger = log.get_logger() - def main(): """ Command line main entry """ @@ -79,6 +79,10 @@ def main(): parser.add_argument('--database', dest='database', help='Set a database path. Default is "~/.lyrebird/lyrebird.db"') parser.add_argument('--es', dest='extra_string', action='append', nargs=2, help='Set a custom config') parser.add_argument('--no-mitm', dest='no_mitm', action='store_true', help='Start without mitmproxy on 4272') + parser.add_argument('--enable-multiprocess', dest='enable_multiprocess', action='store_true', help='change event based on multithread to multiprocess(rely on redis)') + parser.add_argument('--redis-port', dest='redis_port', type=int, help='specifies the redis service port currently in use, default is 6379') + parser.add_argument('--redis-ip', dest='redis_ip', help='specifies the redis service ip currently in use, default is localhost') + parser.add_argument('--redis-db', dest='redis_db', help='specifies the redis service db currently in use, default is 0') subparser = parser.add_subparsers(dest='sub_command') @@ -93,9 +97,24 @@ def main(): Path('~/.lyrebird').expanduser().mkdir(parents=True, exist_ok=True) - custom_conf = {es[0]: es[1] for es in args.extra_string} if args.extra_string else None + 
custom_conf = {es[0]: es[1] for es in args.extra_string} if args.extra_string else {} + + # Parameters set directly through the redis command have a higher priority than those set through --es + if args.redis_ip: + custom_conf['redis_ip'] = args.redis_ip + if args.redis_port: + custom_conf['redis_port'] = args.redis_port + if args.redis_db: + custom_conf['redis_db'] = args.redis_db + if args.enable_multiprocess and compat_redis_check(): + custom_conf['enable_multiprocess'] = True + else: + custom_conf['enable_multiprocess'] = False + application._cm = ConfigManager(conf_path_list=args.config, custom_conf=custom_conf) + application.sync_manager = application.SyncManager() + # init logger for main process application._cm.config['verbose'] = args.verbose application._cm.config['log'] = args.log @@ -105,12 +124,12 @@ def main(): # Add exception hook def process_excepthook(exc_type, exc_value, tb): - logger.error(traceback.format_tb(tb)) + print(traceback.format_tb(tb)) sys.excepthook = process_excepthook def thread_excepthook(args): - logger.error(f'Thread except {args}') - logger.error("".join(traceback.format_tb(args[2]))) + print(f'Thread except {args}') + print("".join(traceback.format_tb(args[2]))) # add threading excepthook after python3.8 if hasattr(threading, 'excepthook'): threading.excepthook = thread_excepthook @@ -168,7 +187,8 @@ def run(args: argparse.Namespace): # show current config contents print_lyrebird_info() - config_str = json.dumps(application._cm.config, ensure_ascii=False, indent=4) + config_dict = application._cm.config.raw() if isinstance(application._cm.config, RedisDict) else application._cm.config + config_str = json.dumps(config_dict, ensure_ascii=False, indent=4) logger.warning(f'Lyrebird start with config:\n{config_str}') # Main server @@ -200,15 +220,16 @@ def run(args: argparse.Namespace): # Mock mush init after other servers application.server['mock'] = LyrebirdMockServer() + # int statistics reporter + application.server['reporter'] 
= reporter.Reporter() + application.reporter = application.server['reporter'] + # handle progress message application.process_status_listener() # Start server without mock server, mock server must start after all blueprint is done application.start_server_without_mock_and_log() - - # int statistics reporter - application.reporter = reporter.Reporter() - reporter.start() + # activate notice center application.notice = NoticeCenter() @@ -229,6 +250,9 @@ def run(args: argparse.Namespace): if args.script: application.server['checker'].load_scripts(args.script) + if application.config.get('enable_multiprocess', False): + application.server['event'].async_start() + # Start server without mock server, mock server must start after all blueprint is done application.start_mock_server() @@ -240,10 +264,12 @@ def run(args: argparse.Namespace): # stop event handler def signal_handler(signum, frame): - reporter.stop() application.stop_server() + application.terminate_server() + application.sync_manager.destory() + RedisManager.destory() threading.Event().set() - logger.warning('!!!Ctrl-C pressed. Lyrebird stop!!!') + print('!!!Ctrl-C pressed. 
Lyrebird stop!!!') os._exit(0) signal.signal(signal.SIGINT, signal_handler) diff --git a/lyrebird/mitm/proxy_server.py b/lyrebird/mitm/proxy_server.py index fc526d0a6..e7f75ca48 100644 --- a/lyrebird/mitm/proxy_server.py +++ b/lyrebird/mitm/proxy_server.py @@ -5,6 +5,7 @@ import json import requests import time +import signal from lyrebird.base_server import ProcessServer from lyrebird.mitm.mitm_installer import init_mitm """ @@ -103,7 +104,10 @@ def publish_init_status(self, queue, status): } }) - def run(self, msg_queue, config, log_queue, *args, **kwargs): + def run(self, async_obj, config, *args, **kwargs): + signal.signal(signal.SIGINT, signal.SIG_IGN) + log_queue = async_obj['logger_queue'] + msg_queue = async_obj['msg_queue'] # Init logger log.init(config, log_queue) logger = log.get_logger() diff --git a/lyrebird/mock/dm/__init__.py b/lyrebird/mock/dm/__init__.py index 467c8df54..25ee04d1f 100644 --- a/lyrebird/mock/dm/__init__.py +++ b/lyrebird/mock/dm/__init__.py @@ -10,6 +10,7 @@ from collections import OrderedDict from lyrebird import utils, application from lyrebird.log import get_logger +from lyrebird.utils import RedisDict from lyrebird.application import config from lyrebird.mock import context from lyrebird.mock.dm.match import MatchRules @@ -57,6 +58,17 @@ def __init__(self): self.tree = [] self.open_nodes = [] + if config.get('enable_multiprocess', False): + try: + self.async_activate_group = RedisDict(host=config.get('redis_host', '127.0.0.1'), + port=config.get('redis_port', 6379), + db=config.get('redis_db', 0)) + self.activate = dm_asyncio_activate_decorator(self, self.activate) + self.deactivate = dm_asyncio_activate_decorator(self, self.deactivate) + except Exception as e: + config['enable_multiprocess'] = False + logger.error(f'Start enable multiprocess failed, Redis connection error: {e}') + @property def snapshot_workspace(self): if not self._snapshot_workspace: @@ -1481,6 +1493,18 @@ def save_data(self, data): self.reload() +# 
----------------- +# decorator +# ----------------- + +def dm_asyncio_activate_decorator(self, func): + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + self.async_activate_group.clear() + self.async_activate_group.update(self.activated_group) + return result + return wrapper + # ----------------- # Exceptions # ----------------- diff --git a/lyrebird/mock/extra_mock_server/__init__.py b/lyrebird/mock/extra_mock_server/__init__.py index f821ec705..ce1119642 100644 --- a/lyrebird/mock/extra_mock_server/__init__.py +++ b/lyrebird/mock/extra_mock_server/__init__.py @@ -1,12 +1,16 @@ from lyrebird.log import get_logger from .server import serve, publish_init_status from lyrebird.base_server import ProcessServer +import signal logger = get_logger() class ExtraMockServer(ProcessServer): - def run(self, msg_queue, config, log_queue, *args, **kwargs): + def run(self, async_obj, config, *args, **kwargs): + signal.signal(signal.SIGINT, signal.SIG_IGN) + log_queue = async_obj['logger_queue'] + msg_queue = async_obj['msg_queue'] publish_init_status(msg_queue, 'READY') serve(msg_queue, config, log_queue, *args, **kwargs) diff --git a/lyrebird/mock/extra_mock_server/server.py b/lyrebird/mock/extra_mock_server/server.py index 20597cd1c..65e1d19e7 100644 --- a/lyrebird/mock/extra_mock_server/server.py +++ b/lyrebird/mock/extra_mock_server/server.py @@ -15,6 +15,7 @@ logger = None logger_queue = None lb_config = {} +semaphore = None def is_filtered(context: LyrebirdProxyContext): @@ -103,7 +104,8 @@ async def req_handler(request: web.Request): proxy_ctx = LyrebirdProxyContext.parse(request, lb_config) if is_filtered(proxy_ctx): # forward to lyrebird - return await forward(proxy_ctx) + async with semaphore: + return await forward(proxy_ctx) else: # proxy return await proxy(proxy_ctx) @@ -127,6 +129,8 @@ def init_app(config): async def _run_app(config): + global semaphore + semaphore = asyncio.Semaphore(10) app = init_app(config) port = config.get('extra.mock.port') 
diff --git a/lyrebird/mock/mock_server.py b/lyrebird/mock/mock_server.py index 4ace8e05b..767c9585d 100644 --- a/lyrebird/mock/mock_server.py +++ b/lyrebird/mock/mock_server.py @@ -113,8 +113,11 @@ def stop(self): """ super().stop() + + def terminate(self): + super().terminate() try: self.socket_io.stop() - except Exception: + except Exception as e: pass - _logger.warning('CoreServer shutdown') + print('CoreServer shutdown') diff --git a/lyrebird/notice_center.py b/lyrebird/notice_center.py index 7971fe086..34095dd77 100644 --- a/lyrebird/notice_center.py +++ b/lyrebird/notice_center.py @@ -16,7 +16,11 @@ def __init__(self): self.notice_hashmap = {} self.notice_list = [] self.not_remind_list = [] - application.server['event'].subscribe('notice', self.new_notice) + application.server['event'].subscribe({ + 'name': 'new_notice', + 'channel': 'notice', + 'func': self.new_notice + }) self.load_history_notice() def storage_notice(self, storage_date): diff --git a/lyrebird/plugins/plugin_manager.py b/lyrebird/plugins/plugin_manager.py index 27d0c6718..69e2098d8 100644 --- a/lyrebird/plugins/plugin_manager.py +++ b/lyrebird/plugins/plugin_manager.py @@ -84,9 +84,14 @@ def print_plugin_api(response): # Subscribe event linstener event_service = application.server['event'] for event_option in plugin.manifest.event: - channel = event_option[0] - callback_func = event_option[1] - event_service.subscribe(channel, callback_func) + func_info = { + 'channel': event_option[0], + 'func': event_option[1], + 'name': event_option[1].__name__, + 'origin': event_option[1].__code__.co_filename, + 'process': event_option[2] if len(event_option)>2 and isinstance(event_option[2], bool) else False + } + event_service.subscribe(func_info) # Subscribe handler on request for handler in plugin.manifest.on_request: diff --git a/lyrebird/reporter.py b/lyrebird/reporter.py index 33af588ae..5713a7f4b 100644 --- a/lyrebird/reporter.py +++ b/lyrebird/reporter.py @@ -5,24 +5,31 @@ from importlib import 
machinery import traceback import datetime +import signal +from lyrebird.base_server import ProcessServer +from concurrent.futures import ThreadPoolExecutor +from lyrebird import application +from lyrebird.compatibility import prepare_application_for_monkey_patch, monkey_patch_application logger = get_logger() - -class Reporter: +class Reporter(ProcessServer): def __init__(self): + super().__init__() self.scripts = [] - workspace = application.config.get('reporter.workspace') - if not workspace: + self.workspace = application.config.get('reporter.workspace') + self.report_queue = application.sync_manager.get_multiprocessing_queue() + if not self.workspace: logger.debug(f'reporter.workspace not set.') - else: - self._read_reporter(workspace) + elif not application.config.get('enable_multiprocess', False): + self.scripts = self._read_reporter(self.workspace) logger.debug(f'Load statistics scripts {self.scripts}') def _read_reporter(self, workspace): target_dir = Path(workspace) + scripts = [] if not target_dir.exists(): logger.error('Reporter workspace not found') for report_script_file in target_dir.iterdir(): @@ -47,42 +54,76 @@ def _read_reporter(self, workspace): if not callable(_script_module.report): logger.warning(f'Skip report script: report method not callable, {report_script_file}') continue - self.scripts.append(_script_module.report) + scripts.append(_script_module.report) + return scripts + + def start(self): + if not application.config.get('enable_multiprocess', False): + return + self.process_namespace = prepare_application_for_monkey_patch() + self.async_obj['report_queue'] = self.report_queue + self.async_obj['workspace'] = self.workspace + self.async_obj['process_namespace'] = self.process_namespace + super().start() - def report(self, data): - task_manager = application.server.get('task') + def run(self, async_obj, config, *args, **kwargs): + + signal.signal(signal.SIGINT, signal.SIG_IGN) + + workspace = async_obj['workspace'] + reportor_queue = 
async_obj['report_queue'] + + monkey_patch_application(async_obj) + scripts = self._read_reporter(workspace) - def send_report(): - new_data = deepcopy(data) - for script in self.scripts: - try: - script(new_data) - except Exception: - logger.error(f'Send report failed:\n{traceback.format_exc()}') - task_manager.add_task('send-report', send_report) + self.thread_executor = ThreadPoolExecutor(max_workers=10) + self.running = True -last_page = None -last_page_in_time = None -lyrebird_start_time = None + while self.running: + try: + data = reportor_queue.get() + if not data: + break + new_data = deepcopy(data) + for script in scripts: + try: + self.thread_executor.submit(script, new_data) + except Exception: + print(f'Send report failed:\n{traceback.format_exc()}') + except Exception: + logger.error(f'Reporter run error:\n{traceback.format_exc()}') + + def report(self, data): + if self.running: + self.report_queue.put(data) + else: + task_manager = application.server.get('task') + + def send_report(): + new_data = deepcopy(data) + for script in self.scripts: + try: + script(new_data) + except Exception: + logger.error(f'Send report failed:\n{traceback.format_exc()}') + task_manager.add_task('send-report', send_report) def _page_out(): - global last_page - global last_page_in_time - if last_page and last_page_in_time: - duration = (datetime.datetime.now() - last_page_in_time).total_seconds() + if hasattr(application.sync_namespace,'last_page') and hasattr(application.sync_namespace,'last_page_in_time'): + duration = (datetime.datetime.now() - application.sync_namespace.last_page_in_time).total_seconds() application.server['event'].publish('system', { 'system': { - 'action': 'page.out', 'page': last_page, 'duration': duration + 'action': 'page.out', 'page': application.sync_namespace.last_page, 'duration': duration } }) # TODO remove below application.reporter.report({ 'action': 'page.out', - 'page': last_page, + 'page': application.sync_namespace.last_page, 'duration': 
duration }) @@ -90,9 +131,6 @@ def _page_out(): def page_in(name): _page_out() - global last_page - global last_page_in_time - application.server['event'].publish('system', { 'system': {'action': 'page.in', 'page': name} }) @@ -103,13 +141,12 @@ def page_in(name): 'page': name }) - last_page = name - last_page_in_time = datetime.datetime.now() + application.sync_namespace.last_page = name + application.sync_namespace.last_page_in_time = datetime.datetime.now() def start(): - global lyrebird_start_time - lyrebird_start_time = datetime.datetime.now() + application.sync_namespace.lyrebird_start_time = datetime.datetime.now() application.server['event'].publish('system', { 'system': {'action': 'start'} }) @@ -122,7 +159,7 @@ def start(): def stop(): _page_out() - duration = (datetime.datetime.now() - lyrebird_start_time).total_seconds() + duration = (datetime.datetime.now() - application.sync_namespace.lyrebird_start_time).total_seconds() application.server['event'].publish('system', { 'system': { 'action': 'stop', diff --git a/lyrebird/task.py b/lyrebird/task.py index 766825f23..a83de459e 100644 --- a/lyrebird/task.py +++ b/lyrebird/task.py @@ -36,13 +36,13 @@ class BackgroundTaskServer(ThreadServer): def __init__(self): super().__init__() self.tasks = [] - self.cmds = Queue() + self.cmds = application.sync_manager.get_queue() self.executor = ThreadPoolExecutor(thread_name_prefix='bg-') def run(self): while self.running: cmd = self.cmds.get() - if cmd == 'stop': + if cmd is None or cmd == 'stop': break elif cmd == 'clear': dead_tasks = [] diff --git a/lyrebird/utils.py b/lyrebird/utils.py index f238b8d45..bdd4cb2b0 100644 --- a/lyrebird/utils.py +++ b/lyrebird/utils.py @@ -3,6 +3,9 @@ import json import math import time +import uuid +import redis +import pickle import socket import tarfile import requests @@ -10,6 +13,7 @@ import netifaces import traceback from pathlib import Path +from copy import deepcopy from jinja2 import Template, StrictUndefined from 
jinja2.exceptions import UndefinedError, TemplateSyntaxError from contextlib import closing @@ -19,6 +23,7 @@ logger = get_logger() +REDIS_EXPIRE_TIME = 60*60*24 def convert_size(size_bytes): if size_bytes == 0: @@ -340,10 +345,21 @@ class CaseInsensitiveDict(dict): & will be treated as the same key, only one exists in this dict. ''' - def __init__(self, raw_dict): + def __init__(self, raw_dict=None): self.__key_map = {} - for k, v in raw_dict.items(): - self.__setitem__(k, v) + if raw_dict: + for k, v in raw_dict.items(): + self.__setitem__(k, v) + + def __getstate__(self): + return { + 'key_map': self.__key_map, + 'data': dict(self) + } + + def __setstate__(self, state): + self.__key_map = state['key_map'] + self.update(state['data']) def __get_real_key(self, key): return self.__key_map.get(key.lower(), key) @@ -401,6 +417,9 @@ def update(self, __m=None, **kwargs) -> None: for k, v in kwargs.items(): self.__setitem__(k, v) + def __reduce__(self): + return (self.__class__, (dict(self),)) + class HookedDict(dict): ''' @@ -425,6 +444,9 @@ def __setitem__(self, __k, __v) -> None: __v = HookedDict(__v) return super(HookedDict, self).__setitem__(__k, __v) + def __reduce__(self): + return (self.__class__, (dict(self),)) + class TargetMatch: @@ -479,3 +501,219 @@ def json(self): elif isinstance(prop_obj, datetime.datetime): prop_collection[prop] = prop_obj.timestamp() return prop_collection + + +class RedisManager: + + redis_dicts = set() + + @staticmethod + def put(obj): + RedisManager.redis_dicts.add(obj) + + @staticmethod + def destory(): + for i in RedisManager.redis_dicts: + i.destory() + RedisManager.redis_dicts.clear() + + @staticmethod + def serialize(): + return pickle.dumps(RedisManager.redis_dicts) + + @staticmethod + def deserialize(data): + RedisManager.redis_dicts = pickle.loads(data) + + +class RedisData: + + host = 'localhost' + port = 6379 + db = 0 + + def __init__(self, host=None, port=None, db=None, param_uuid=None): + if not host: + host = 
RedisData.host + if not port: + port = RedisData.port + if not db: + db = RedisData.db + self.port = port + self.host = host + self.db = db + if not param_uuid: + self.uuid = str(uuid.uuid4()) + else: + self.uuid = param_uuid + self.redis = redis.Redis(host=self.host, port=self.port, db=self.db) + RedisManager.put(self) + + def destory(self): + self.redis.delete(self.uuid) + self.redis.close() + + + def __getstate__(self): + return pickle.dumps({ + 'uuid':self.uuid, + 'port':self.port, + 'host':self.host, + 'db':self.db + }) + + def __setstate__(self, state): + data = pickle.loads(state) + self.port = data['port'] + self.host = data['host'] + self.db = data['db'] + self.uuid = data['uuid'] + self.redis = redis.Redis(host=self.host, port=self.port, db=self.db) + + +class RedisDict(RedisData): + + def __init__(self, host=None, port=None, db=None, param_uuid=None, data={}): + super().__init__(host, port, db, param_uuid) + for k in data.keys(): + self[k] = data[k] + + def __getitem__(self, key): + value = self.redis.hget(self.uuid, key) + if value is None: + raise KeyError(key) + value = json.loads(value.decode()) + return _hook_value(self, key, value) + + def __setitem__(self, key, value): + value = json.dumps(value, ensure_ascii=False) + self.redis.hset(self.uuid, key, value) + self.redis.expire(self.uuid, REDIS_EXPIRE_TIME) + + def __delitem__(self, key): + if not self.redis.hexists(self.uuid, key): + raise KeyError(key) + self.redis.hdel(self.uuid, key) + self.redis.expire(self.uuid, REDIS_EXPIRE_TIME) + + def __contains__(self, key): + return self.redis.hexists(self.uuid, key) + + def keys(self): + return [key.decode() for key in self.redis.hkeys(self.uuid)] + + def values(self): + return [json.loads(value.decode()) for value in self.redis.hgetall(self.uuid).values()] + + def items(self): + return [(key.decode(), json.loads(value.decode())) for key, value in self.redis.hgetall(self.uuid).items()] + + def get(self, key, default=None): + value = 
self.redis.hget(self.uuid, key) + if value is None: + return default + return _hook_value(self, key, json.loads(value.decode())) + + def update(self, data): + for key, value in data.items(): + self[key] = value + self.redis.expire(self.uuid, REDIS_EXPIRE_TIME) + + def clear(self): + self.redis.delete(self.uuid) + + def raw(self): + return {key.decode(): json.loads(value.decode()) for key, value in self.redis.hgetall(self.uuid).items()} + + def __len__(self): + return len(self.redis.hkeys(self.uuid)) + + def __repr__(self): + return repr(dict(self.items())) + + def __deepcopy__(self, memo): + return self.raw() + + +def _hook_value(parent, key, value): + if isinstance(value, dict): + return RedisHookedDict(parent, key, value) + elif isinstance(value, list): + return RedisHookedList(parent, key, value) + elif isinstance(value, set): + return RedisHookedSet(parent, key, value) + else: + return value + + +class RedisHook: + def __init__(self, parent, key): + self.parent = parent + self.key = key + + +class RedisHookedDict(RedisHook, dict): + def __init__(self, parent, key, value): + RedisHook.__init__(self, parent, key) + dict.__init__(self, value) + + def get(self, key, default=None): + res = dict.get(self, key, default) + return _hook_value(self, key, res) + + def __getitem__(self, key): + return _hook_value(self, key, dict.__getitem__(self, key)) + + def __setitem__(self, key, value): + dict.__setitem__(self, key, value) + self.parent[self.key] = self + + def __delitem__(self, key): + dict.__delitem__(self, key) + self.parent[self.key] = self + + def update(self, *args, **kwargs): + dict.update(self, *args, **kwargs) + self.parent[self.key] = self + + def __deepcopy__(self, memo): + return deepcopy(dict(self), memo) + +class RedisHookedList(RedisHook, list): + def __init__(self, parent, key, value): + list.__init__(self, value) + RedisHook.__init__(self, parent, key) + + def __getitem__(self, index): + return _hook_value(self, index, list.__getitem__(self, index)) + 
+ def __setitem__(self, index, value): + list.__setitem__(self, index, value) + self.parent[self.key] = self + + def __delitem__(self, index): + list.__delitem__(self, index) + self.parent[self.key] = self + + def append(self, value): + list.append(self, value) + self.parent[self.key] = self + + def __deepcopy__(self, memo): + return deepcopy(list(self), memo) + +class RedisHookedSet(RedisHook, set): + def __init__(self, parent, key, value): + set.__init__(self, value) + RedisHook.__init__(self, parent, key) + + def add(self, value): + set.add(self, value) + self.parent[self.key] = self + + def remove(self, value): + set.remove(self, value) + self.parent[self.key] = self + + def __deepcopy__(self, memo): + return deepcopy(set(self), memo) diff --git a/lyrebird/version.py b/lyrebird/version.py index 811c2dafb..d6a31cc9f 100644 --- a/lyrebird/version.py +++ b/lyrebird/version.py @@ -1,3 +1,3 @@ -IVERSION = (2, 26, 4) +IVERSION = (3, 0, 0) VERSION = ".".join(str(i) for i in IVERSION) LYREBIRD = "Lyrebird " + VERSION diff --git a/requirements.txt b/requirements.txt index 4008f48ae..b3b2f0732 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ aiohttp==3.8.3 netifaces==0.11.0 jsonschema==4.17.0 Flask-Cors==4.0.0 +redis==4.6.0 diff --git a/requirements.txt.lock b/requirements.txt.lock index 6e1ef3f49..b0dc31f35 100644 --- a/requirements.txt.lock +++ b/requirements.txt.lock @@ -43,3 +43,4 @@ urllib3==1.24.3 Werkzeug==2.2.2 yarl==1.8.1 zipp==3.10.0 +redis==4.6.0 diff --git a/tests/conftest.py b/tests/conftest.py index dda89ace9..cfb3f794c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import pytest from lyrebird import application from lyrebird.mock import context +from lyrebird.application import SyncManager SERVER_NAMES = ['event', 'log', 'mock', 'task', 'checker', 'db', 'plugin'] @@ -37,3 +38,12 @@ def setup_and_teardown_environment(): application.on_response_upstream = bak_on_response_upstream 
context.application.socket_io = bak_socketio context.application.data_manager = bak_dm + + +@pytest.fixture(scope='function', autouse=True) +def init_sync_manager(): + application.sync_manager = SyncManager() + yield + application.sync_manager.broadcast_to_queues(None) + application.sync_manager.destory() + application.sync_manager = None diff --git a/tests/test_common_api.py b/tests/test_common_api.py index b01728efb..6079ee1bd 100644 --- a/tests/test_common_api.py +++ b/tests/test_common_api.py @@ -24,6 +24,7 @@ def client(): server = LyrebirdMockServer() with server.app.test_client() as client: yield client + server.terminate() def test_render_api_without_json(client): resp = client.put('/api/render') diff --git a/tests/test_conf_api.py b/tests/test_conf_api.py index a08a165a8..485c40eeb 100644 --- a/tests/test_conf_api.py +++ b/tests/test_conf_api.py @@ -28,6 +28,7 @@ def client(): server = LyrebirdMockServer() with server.app.test_client() as client: yield client + server.terminate() def test_patch_conf_api_with_no_param(client): diff --git a/tests/test_db.py b/tests/test_db.py index 2bd6a913c..bfbefcdb2 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -28,6 +28,8 @@ def event_server(tmpdir): application.server['event'] = server yield server server.stop() + application.sync_manager.broadcast_to_queues(None) + server.terminate() @pytest.fixture diff --git a/tests/test_dm_api.py b/tests/test_dm_api.py index c692517a7..c80fc5881 100644 --- a/tests/test_dm_api.py +++ b/tests/test_dm_api.py @@ -77,6 +77,7 @@ def client(root, tmpdir): _dm.set_root(root) with server.app.test_client() as client: yield client + server.terminate() del server diff --git a/tests/test_dm_config.py b/tests/test_dm_config.py index 8fbcb8f82..38f11b2e5 100644 --- a/tests/test_dm_config.py +++ b/tests/test_dm_config.py @@ -6,7 +6,6 @@ import lyrebird from .utils import FakeSocketio, FakeEvnetServer from lyrebird.mock import dm -from lyrebird.event import EventServer from lyrebird.config 
import ConfigManager from lyrebird.checker import LyrebirdCheckerServer from lyrebird.config import CONFIG_TREE_SHOW_CONFIG diff --git a/tests/test_event.py b/tests/test_event.py index bdc2b3adf..df80b3404 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -39,6 +39,8 @@ def event_server(): lyrebird.application.server['event'] = server yield server server.stop() + application.sync_manager.broadcast_to_queues(None) + server.terminate() @pytest.fixture @@ -52,7 +54,7 @@ def test_event(callback_tester, event_server, task_server): cb_tester = CallbackTester() - event_server.subscribe('Test', cb_tester.callback) + event_server.subscribe({'channel':'Test', 'func':cb_tester.callback}) assert event_server.pubsub_channels.get('Test') @@ -72,7 +74,7 @@ def test_event_default_information(callback_tester, event_server, task_server): cb_tester = CallbackTester() - event_server.subscribe('Test', cb_tester.callback) + event_server.subscribe({'channel':'Test', 'func':cb_tester.callback}) assert event_server.pubsub_channels.get('Test') @@ -91,7 +93,7 @@ def test_event_default_information_with_sender(callback_tester, event_server, ta cb_tester = CallbackTester() - event_server.subscribe('Test', cb_tester.callback) + event_server.subscribe({'channel':'Test', 'func':cb_tester.callback}) assert event_server.pubsub_channels.get('Test') diff --git a/tests/test_mock_server.py b/tests/test_mock_server.py index 9770085e1..4b820a1f9 100644 --- a/tests/test_mock_server.py +++ b/tests/test_mock_server.py @@ -39,6 +39,7 @@ def client(): server.app.testing = True with server.app.test_client() as client: yield client + server.terminate() @pytest.fixture diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 99a4d1fc9..dd198c118 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -122,6 +122,7 @@ def mock_server(): application.server['task'] = FakeBackgroundTaskServer() application.server['mock'] = server yield server + server.terminate() del server