diff --git a/elastalert/config.py b/elastalert/config.py
index d1c27f44a..8a9f73995 100644
--- a/elastalert/config.py
+++ b/elastalert/config.py
@@ -100,6 +100,7 @@ def get_module(module_name):
     """ Loads a module and returns a specific object.
     module_name should 'module.file.object'.
     Returns object or raises EAException on error. """
+    sys.path.append(os.getcwd())
     try:
         module_path, module_class = module_name.rsplit('.', 1)
         base_module = __import__(module_path, globals(), locals(), [module_class])
@@ -195,6 +196,8 @@ def load_options(rule, conf, filename, args=None):
         rule['query_delay'] = datetime.timedelta(**rule['query_delay'])
     if 'buffer_time' in rule:
         rule['buffer_time'] = datetime.timedelta(**rule['buffer_time'])
+    if 'run_every' in rule:
+        rule['run_every'] = datetime.timedelta(**rule['run_every'])
     if 'bucket_interval' in rule:
         rule['bucket_interval_timedelta'] = datetime.timedelta(**rule['bucket_interval'])
     if 'exponential_realert' in rule:
diff --git a/elastalert/create_index.py b/elastalert/create_index.py
index b12ee7e5e..953c3a1ee 100644
--- a/elastalert/create_index.py
+++ b/elastalert/create_index.py
@@ -57,6 +57,8 @@ def main():
         filename = 'config.yaml'
     elif os.path.isfile(args.config):
         filename = args.config
+    elif os.path.isfile('../config.yaml'):
+        filename = '../config.yaml'
     else:
         filename = ''
diff --git a/elastalert/elastalert.py b/elastalert/elastalert.py
index 29cc87568..e889161cd 100755
--- a/elastalert/elastalert.py
+++ b/elastalert/elastalert.py
@@ -5,8 +5,10 @@
 import json
 import logging
 import os
+import random
 import signal
 import sys
+import threading
 import time
 import timeit
 import traceback
@@ -17,8 +19,10 @@
 import dateutil.tz
 import kibana
+import pytz
 import yaml
 from alerts import DebugAlerter
+from apscheduler.schedulers.background import BackgroundScheduler
 from config import get_rule_hashes
 from config import load_configuration
 from config import load_rules
@@ -63,6 +67,8 @@ class ElastAlerter():
     should not be passed directly from a configuration file, but must be populated
     by config.py:load_rules instead. """
+    thread_data = threading.local()
+
     def parse_args(self, args):
         parser = argparse.ArgumentParser()
         parser.add_argument(
@@ -128,6 +134,7 @@ def __init__(self, args):
             tracer.addHandler(logging.FileHandler(self.args.es_debug_trace))
 
         self.conf = load_rules(self.args)
+        print len(self.conf['rules']), 'rules loaded'
         self.max_query_size = self.conf['max_query_size']
         self.scroll_keepalive = self.conf['scroll_keepalive']
         self.rules = self.conf['rules']
@@ -140,18 +147,15 @@ def __init__(self, args):
         self.from_addr = self.conf.get('from_addr', 'ElastAlert')
         self.smtp_host = self.conf.get('smtp_host', 'localhost')
         self.max_aggregation = self.conf.get('max_aggregation', 10000)
-        self.alerts_sent = 0
-        self.cumulative_hits = 0
-        self.num_hits = 0
-        self.num_dupes = 0
-        self.current_es = None
-        self.current_es_addr = None
         self.buffer_time = self.conf['buffer_time']
         self.silence_cache = {}
         self.rule_hashes = get_rule_hashes(self.conf, self.args.rule)
         self.starttime = self.args.start
         self.disabled_rules = []
         self.replace_dots_in_field_names = self.conf.get('replace_dots_in_field_names', False)
+        self.thread_data.num_hits = 0
+        self.thread_data.num_dupes = 0
+        self.scheduler = BackgroundScheduler()
         self.string_multi_field_name = self.conf.get('string_multi_field_name', False)
         self.add_metadata_alert = self.conf.get('add_metadata_alert', False)
@@ -299,7 +303,8 @@ def get_index_start(self, index, timestamp_field='@timestamp'):
         """
         query = {'sort': {timestamp_field: {'order': 'asc'}}}
         try:
-            res = self.current_es.search(index=index, size=1, body=query, _source_include=[timestamp_field], ignore_unavailable=True)
+            res = self.thread_data.current_es.search(index=index, size=1, body=query,
+                                                     _source_include=[timestamp_field], ignore_unavailable=True)
         except ElasticsearchException as e:
             self.handle_error("Elasticsearch query error: %s" % (e), {'index': index, 'query': query})
             return '1969-12-30T00:00:00Z'
@@ -381,9 +386,9 @@ def get_hits(self, rule, starttime, endtime, index, scroll=False):
         try:
             if scroll:
-                res = self.current_es.scroll(scroll_id=rule['scroll_id'], scroll=scroll_keepalive)
+                res = self.thread_data.current_es.scroll(scroll_id=rule['scroll_id'], scroll=scroll_keepalive)
             else:
-                res = self.current_es.search(
+                res = self.thread_data.current_es.search(
                     scroll=scroll_keepalive,
                     index=index,
                     size=rule.get('max_query_size', self.max_query_size),
@@ -391,7 +396,7 @@ def get_hits(self, rule, starttime, endtime, index, scroll=False):
                     ignore_unavailable=True,
                     **extra_args
                 )
-            self.total_hits = int(res['hits']['total'])
+            self.thread_data.total_hits = int(res['hits']['total'])
 
             if len(res.get('_shards', {}).get('failures', [])) > 0:
                 try:
@@ -411,16 +416,16 @@ def get_hits(self, rule, starttime, endtime, index, scroll=False):
             self.handle_error('Error running query: %s' % (e), {'rule': rule['name'], 'query': query})
             return None
         hits = res['hits']['hits']
-        self.num_hits += len(hits)
+        self.thread_data.num_hits += len(hits)
         lt = rule.get('use_local_time')
         status_log = "Queried rule %s from %s to %s: %s / %s hits" % (
             rule['name'],
             pretty_ts(starttime, lt),
             pretty_ts(endtime, lt),
-            self.num_hits,
+            self.thread_data.num_hits,
             len(hits)
         )
-        if self.total_hits > rule.get('max_query_size', self.max_query_size):
+        if self.thread_data.total_hits > rule.get('max_query_size', self.max_query_size):
             elastalert_logger.info("%s (scrolling..)" % status_log)
             rule['scroll_id'] = res['_scroll_id']
         else:
@@ -454,7 +459,7 @@ def get_hits_count(self, rule, starttime, endtime, index):
         )
 
         try:
-            res = self.current_es.count(index=index, doc_type=rule['doc_type'], body=query, ignore_unavailable=True)
+            res = self.thread_data.current_es.count(index=index, doc_type=rule['doc_type'], body=query, ignore_unavailable=True)
         except ElasticsearchException as e:
             # Elasticsearch sometimes gives us GIGANTIC error messages
             # (so big that they will fill the entire terminal buffer)
@@ -463,7 +468,7 @@ def get_hits_count(self, rule, starttime, endtime, index):
             self.handle_error('Error running count query: %s' % (e), {'rule': rule['name'], 'query': query})
             return None
 
-        self.num_hits += res['count']
+        self.thread_data.num_hits += res['count']
         lt = rule.get('use_local_time')
         elastalert_logger.info(
             "Queried rule %s from %s to %s: %s hits" % (rule['name'], pretty_ts(starttime, lt), pretty_ts(endtime, lt), res['count'])
         )
@@ -509,7 +514,7 @@ def get_hits_terms(self, rule, starttime, endtime, index, key, qk=None, size=Non
 
         try:
             if not rule['five']:
-                res = self.current_es.search(
+                res = self.thread_data.current_es.search(
                     index=index,
                     doc_type=rule['doc_type'],
                     body=query,
@@ -517,7 +522,8 @@ def get_hits_terms(self, rule, starttime, endtime, index, key, qk=None, size=Non
                     ignore_unavailable=True
                 )
             else:
-                res = self.current_es.search(index=index, doc_type=rule['doc_type'], body=query, size=0, ignore_unavailable=True)
+                res = self.thread_data.current_es.search(index=index, doc_type=rule['doc_type'],
+                                                         body=query, size=0, ignore_unavailable=True)
         except ElasticsearchException as e:
             # Elasticsearch sometimes gives us GIGANTIC error messages
             # (so big that they will fill the entire terminal buffer)
@@ -532,7 +538,7 @@ def get_hits_terms(self, rule, starttime, endtime, index, key, qk=None, size=Non
             buckets = res['aggregations']['filtered']['counts']['buckets']
         else:
             buckets = res['aggregations']['counts']['buckets']
-        self.num_hits += len(buckets)
+        self.thread_data.num_hits += len(buckets)
         lt = rule.get('use_local_time')
         elastalert_logger.info(
             'Queried rule %s from %s to %s: %s buckets' % (rule['name'], pretty_ts(starttime, lt), pretty_ts(endtime, lt), len(buckets))
         )
@@ -555,7 +561,7 @@ def get_hits_aggregation(self, rule, starttime, endtime, index, query_key, term_
         query = self.get_aggregation_query(base_query, rule, query_key, term_size, rule['timestamp_field'])
         try:
             if not rule['five']:
-                res = self.current_es.search(
+                res = self.thread_data.current_es.search(
                     index=index,
                     doc_type=rule.get('doc_type'),
                     body=query,
@@ -563,7 +569,8 @@ def get_hits_aggregation(self, rule, starttime, endtime, index, query_key, term_
                     ignore_unavailable=True
                 )
             else:
-                res = self.current_es.search(index=index, doc_type=rule.get('doc_type'), body=query, size=0, ignore_unavailable=True)
+                res = self.thread_data.current_es.search(index=index, doc_type=rule.get('doc_type'),
+                                                         body=query, size=0, ignore_unavailable=True)
         except ElasticsearchException as e:
             if len(str(e)) > 1024:
                 e = str(e)[:1024] + '... (%d characters removed)' % (len(str(e)) - 1024)
@@ -575,7 +582,7 @@ def get_hits_aggregation(self, rule, starttime, endtime, index, query_key, term_
             payload = res['aggregations']['filtered']
         else:
             payload = res['aggregations']
-        self.num_hits += res['hits']['total']
+        self.thread_data.num_hits += res['hits']['total']
         return {endtime: payload}
 
     def remove_duplicate_events(self, data, rule):
@@ -629,7 +636,7 @@ def run_query(self, rule, start=None, end=None, scroll=False):
         if data:
             old_len = len(data)
             data = self.remove_duplicate_events(data, rule)
-            self.num_dupes += old_len - len(data)
+            self.thread_data.num_dupes += old_len - len(data)
 
         # There was an exception while querying
         if data is None:
@@ -645,7 +652,7 @@ def run_query(self, rule, start=None, end=None, scroll=False):
             rule_inst.add_data(data)
 
         try:
-            if rule.get('scroll_id') and self.num_hits < self.total_hits:
+            if rule.get('scroll_id') and self.thread_data.num_hits < self.thread_data.total_hits:
                 self.run_query(rule, start, end, scroll=True)
         except RuntimeError:
             # It's possible to scroll far enough to hit max recursive depth
@@ -689,7 +696,6 @@ def get_starttime(self, rule):
     def set_starttime(self, rule, endtime):
         """ Given a rule and an endtime, sets the appropriate starttime for it. """
-
         # This means we are starting fresh
         if 'starttime' not in rule:
             if not rule.get('scan_entire_timeframe'):
@@ -844,8 +850,7 @@ def run_rule(self, rule, endtime, starttime=None):
         """
         run_start = time.time()
 
-        self.current_es = elasticsearch_client(rule)
-        self.current_es_addr = (rule['es_host'], rule['es_port'])
+        self.thread_data.current_es = elasticsearch_client(rule)
 
         # If there are pending aggregate matches, try processing them
         for x in range(len(rule['agg_matches'])):
@@ -866,9 +871,9 @@ def run_rule(self, rule, endtime, starttime=None):
             return 0
 
         # Run the rule. If querying over a large time period, split it up into segments
-        self.num_hits = 0
-        self.num_dupes = 0
-        self.cumulative_hits = 0
+        self.thread_data.num_hits = 0
+        self.thread_data.num_dupes = 0
+        self.thread_data.cumulative_hits = 0
         segment_size = self.get_segment_size(rule)
 
         tmp_endtime = rule['starttime']
@@ -877,15 +882,15 @@ def run_rule(self, rule, endtime, starttime=None):
             tmp_endtime = tmp_endtime + segment_size
             if not self.run_query(rule, rule['starttime'], tmp_endtime):
                 return 0
-            self.cumulative_hits += self.num_hits
-            self.num_hits = 0
+            self.thread_data.cumulative_hits += self.thread_data.num_hits
+            self.thread_data.num_hits = 0
             rule['starttime'] = tmp_endtime
             rule['type'].garbage_collect(tmp_endtime)
 
         if rule.get('aggregation_query_element'):
             if endtime - tmp_endtime == segment_size:
                 self.run_query(rule, tmp_endtime, endtime)
-                self.cumulative_hits += self.num_hits
+                self.thread_data.cumulative_hits += self.thread_data.num_hits
             elif total_seconds(rule['original_starttime'] - tmp_endtime) == 0:
                 rule['starttime'] = rule['original_starttime']
                 return 0
@@ -894,14 +899,14 @@ def run_rule(self, rule, endtime, starttime=None):
         else:
             if not self.run_query(rule, rule['starttime'], endtime):
                 return 0
-            self.cumulative_hits += self.num_hits
+            self.thread_data.cumulative_hits += self.thread_data.num_hits
             rule['type'].garbage_collect(endtime)
 
         # Process any new matches
         num_matches = len(rule['type'].matches)
         while rule['type'].matches:
             match = rule['type'].matches.pop(0)
-            match['num_hits'] = self.cumulative_hits
+            match['num_hits'] = self.thread_data.cumulative_hits
             match['num_matches'] = num_matches
 
             # If realert is set, silence the rule for that duration
@@ -947,7 +952,7 @@ def run_rule(self, rule, endtime, starttime=None):
                 'endtime': endtime,
                 'starttime': rule['original_starttime'],
                 'matches': num_matches,
-                'hits': max(self.num_hits, self.cumulative_hits),
+                'hits': max(self.thread_data.num_hits, self.thread_data.cumulative_hits),
                 '@timestamp': ts_now(),
                 'time_taken': time_taken}
         self.writeback('elastalert_status', body)
@@ -956,6 +961,9 @@ def run_rule(self, rule, endtime, starttime=None):
     def init_rule(self, new_rule, new=True):
         ''' Copies some necessary non-config state from an exiting rule to a new rule. '''
+        if not new:
+            self.scheduler.remove_job(job_id=new_rule['name'])
+
         try:
             self.modify_rule_for_ES5(new_rule)
         except TransportError as e:
@@ -990,7 +998,9 @@ def init_rule(self, new_rule, new=True):
             blank_rule = {'agg_matches': [],
                           'aggregate_alert_time': {},
                           'current_aggregate_id': {},
-                          'processed_hits': {}}
+                          'processed_hits': {},
+                          'run_every': self.run_every,
+                          'has_run_once': False}
             rule = blank_rule
 
         # Set rule to either a blank template or existing rule with same name
@@ -1006,12 +1016,22 @@ def init_rule(self, new_rule, new=True):
                            'aggregate_alert_time',
                            'processed_hits',
                            'starttime',
-                           'minimum_starttime']
+                           'minimum_starttime',
+                           'has_run_once',
+                           'run_every']
         for prop in copy_properties:
             if prop not in rule:
                 continue
             new_rule[prop] = rule[prop]
 
+        job = self.scheduler.add_job(self.handle_rule_execution, 'interval',
+                                     args=[new_rule],
+                                     seconds=new_rule['run_every'].total_seconds(),
+                                     id=new_rule['name'],
+                                     max_instances=1,
+                                     jitter=5)
+        job.modify(next_run_time=datetime.datetime.now() + datetime.timedelta(seconds=random.randint(0, 15)))
+
         return new_rule
 
     @staticmethod
@@ -1118,14 +1138,20 @@ def start(self):
             except (TypeError, ValueError):
                 self.handle_error("%s is not a valid ISO8601 timestamp (YYYY-MM-DDTHH:MM:SS+XX:00)" % (self.starttime))
                 exit(1)
+
+        for rule in self.rules:
+            rule['initial_starttime'] = self.starttime
         self.wait_until_responsive(timeout=self.args.timeout)
         self.running = True
         elastalert_logger.info("Starting up")
+        self.scheduler.add_job(self.handle_pending_alerts, 'interval',
+                               seconds=self.run_every.total_seconds(), id='_internal_handle_pending_alerts')
+        self.scheduler.add_job(self.handle_config_change, 'interval',
+                               seconds=self.run_every.total_seconds(), id='_internal_handle_config_change')
+        self.scheduler.start()
         while self.running:
             next_run = datetime.datetime.utcnow() + self.run_every
-            self.run_all_rules()
-
             # Quit after end_time has been reached
             if self.args.end:
                 endtime = ts_to_dt(self.args.end)
@@ -1176,53 +1202,95 @@ def wait_until_responsive(self, timeout, clock=timeit.default_timer):
     def run_all_rules(self):
         """ Run each rule one time """
+        self.handle_pending_alerts()
+
+        for rule in self.rules:
+            self.handle_rule_execution(rule)
+
+        self.handle_config_change()
+
+    def handle_pending_alerts(self):
+        self.thread_data.alerts_sent = 0
         self.send_pending_alerts()
+        elastalert_logger.info("Background alerts thread %s pending alerts sent at %s" % (self.thread_data.alerts_sent,
+                                                                                          pretty_ts(ts_now())))
 
-        next_run = datetime.datetime.utcnow() + self.run_every
+    def handle_config_change(self):
+        if not self.args.pin_rules:
+            self.load_rule_changes()
+        elastalert_logger.info("Background configuration change check run at %s" % (pretty_ts(ts_now())))
+
+    def handle_rule_execution(self, rule):
+        self.thread_data.alerts_sent = 0
+        next_run = datetime.datetime.utcnow() + rule['run_every']
+        # Set endtime based on the rule's delay
+        delay = rule.get('query_delay')
+        if hasattr(self.args, 'end') and self.args.end:
+            endtime = ts_to_dt(self.args.end)
+        elif delay:
+            endtime = ts_now() - delay
+        else:
+            endtime = ts_now()
+
+        # Apply rules based on execution time limits
+        if rule.get('limit_execution'):
+            rule['next_starttime'] = None
+            rule['next_min_starttime'] = None
+            exec_next = croniter(rule['limit_execution']).next()
+            endtime_epoch = dt_to_unix(endtime)
+            # If the estimated next endtime (end + run_every) isn't at least a minute past the next exec time
+            # That means that we need to pause execution after this run
+            if endtime_epoch + rule['run_every'].total_seconds() < exec_next - 59:
+                # apscheduler requires pytz tzinfos, so don't use unix_to_dt here!
+                rule['next_starttime'] = datetime.datetime.utcfromtimestamp(exec_next).replace(tzinfo=pytz.utc)
+                if rule.get('limit_execution_coverage'):
+                    rule['next_min_starttime'] = rule['next_starttime']
+                if not rule['has_run_once']:
+                    self.reset_rule_schedule(rule)
+                    return
-        for rule in self.rules:
-            # Set endtime based on the rule's delay
-            delay = rule.get('query_delay')
-            if hasattr(self.args, 'end') and self.args.end:
-                endtime = ts_to_dt(self.args.end)
-            elif delay:
-                endtime = ts_now() - delay
-            else:
-                endtime = ts_now()
+        rule['has_run_once'] = True
+        try:
+            num_matches = self.run_rule(rule, endtime, rule.get('initial_starttime'))
+        except EAException as e:
+            self.handle_error("Error running rule %s: %s" % (rule['name'], e), {'rule': rule['name']})
+        except Exception as e:
+            self.handle_uncaught_exception(e, rule)
+        else:
+            old_starttime = pretty_ts(rule.get('original_starttime'), rule.get('use_local_time'))
+            elastalert_logger.info("Ran %s from %s to %s: %s query hits (%s already seen), %s matches,"
+                                   " %s alerts sent" % (rule['name'], old_starttime, pretty_ts(endtime, rule.get('use_local_time')),
                                                         self.thread_data.num_hits, self.thread_data.num_dupes, num_matches,
                                                         self.thread_data.alerts_sent))
+        self.thread_data.alerts_sent = 0
-            try:
-                num_matches = self.run_rule(rule, endtime, self.starttime)
-            except EAException as e:
-                self.handle_error("Error running rule %s: %s" % (rule['name'], e), {'rule': rule['name']})
-            except Exception as e:
-                self.handle_uncaught_exception(e, rule)
-            else:
-                old_starttime = pretty_ts(rule.get('original_starttime'), rule.get('use_local_time'))
-                total_hits = max(self.num_hits, self.cumulative_hits)
-                elastalert_logger.info("Ran %s from %s to %s: %s query hits (%s already seen), %s matches,"
-                                       " %s alerts sent" % (rule['name'], old_starttime, pretty_ts(endtime, rule.get('use_local_time')),
-                                                            total_hits, self.num_dupes, num_matches, self.alerts_sent))
-                self.alerts_sent = 0
-
-                if next_run < datetime.datetime.utcnow():
-                    # We were processing for longer than our refresh interval
-                    # This can happen if --start was specified with a large time period
-                    # or if we are running too slow to process events in real time.
-                    logging.warning(
-                        "Querying from %s to %s took longer than %s!" % (
-                            old_starttime,
-                            pretty_ts(endtime, rule.get('use_local_time')),
-                            self.run_every
-                        )
                     )
+        if next_run < datetime.datetime.utcnow():
+            # We were processing for longer than our refresh interval
+            # This can happen if --start was specified with a large time period
+            # or if we are running too slow to process events in real time.
+            logging.warning(
+                "Querying from %s to %s took longer than %s!" % (
+                    old_starttime,
+                    pretty_ts(endtime, rule.get('use_local_time')),
+                    self.run_every
                 )
+            )
-            self.remove_old_events(rule)
+        rule['initial_starttime'] = None
-            # Only force starttime once
-            self.starttime = None
+        self.remove_old_events(rule)
-        if not self.args.pin_rules:
-            self.load_rule_changes()
+        self.reset_rule_schedule(rule)
+
+    def reset_rule_schedule(self, rule):
+        # We hit the end of an execution schedule, pause ourselves until next run
+        if rule.get('limit_execution') and rule['next_starttime']:
+            self.scheduler.modify_job(job_id=rule['name'], next_run_time=rule['next_starttime'])
+            # If we are preventing covering non-scheduled time periods, reset min_starttime and previous_endtime
+            if rule['next_min_starttime']:
+                rule['minimum_starttime'] = rule['next_min_starttime']
+                rule['previous_endtime'] = rule['next_min_starttime']
+            elastalert_logger.info('Pausing %s until next run at %s' % (rule['name'], pretty_ts(rule['next_starttime'])))
 
     def stop(self):
         """ Stop an ElastAlert runner that's been started """
@@ -1453,7 +1521,7 @@ def send_alert(self, matches, rule, alert_time=None, retried=False):
                 self.handle_error('Error while running alert %s: %s' % (alert.get_info()['type'], e), {'rule': rule['name']})
                 alert_exception = str(e)
             else:
-                self.alerts_sent += 1
+                self.thread_data.alerts_sent += 1
                 alert_sent = True
 
         # Write the alert(s) to ES
@@ -1834,6 +1902,7 @@ def handle_uncaught_exception(self, exception, rule):
         if self.disable_rules_on_error:
             self.rules = [running_rule for running_rule in self.rules if running_rule['name'] != rule['name']]
             self.disabled_rules.append(rule)
+            self.scheduler.pause_job(job_id=rule['name'])
             elastalert_logger.info('Rule %s disabled', rule['name'])
         if self.notify_email:
             self.send_notification_email(exception=exception, rule=rule)
@@ -1892,7 +1961,7 @@ def get_top_counts(self, rule, starttime, endtime, keys, number=None, qk=None):
             buckets = hits_terms.values()[0]
 
             # get_hits_terms adds to num_hits, but we don't want to count these
-            self.num_hits -= len(buckets)
+            self.thread_data.num_hits -= len(buckets)
             terms = {}
             for bucket in buckets:
                 terms[bucket['key']] = bucket['doc_count']
diff --git a/elastalert/enhancements.py b/elastalert/enhancements.py
index d6c902514..2744e35c8 100644
--- a/elastalert/enhancements.py
+++ b/elastalert/enhancements.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+from util import pretty_ts
 
 
 class BaseEnhancement(object):
@@ -14,6 +15,11 @@ def process(self, match):
         raise NotImplementedError()
 
 
+class TimeEnhancement(BaseEnhancement):
+    def process(self, match):
+        match['@timestamp'] = pretty_ts(match['@timestamp'])
+
+
 class DropMatchException(Exception):
     """ ElastAlert will drop a match if this exception type is raised by an enhancement """
     pass
diff --git a/elastalert/util.py b/elastalert/util.py
index 33f0b4e71..29cf24fbe 100644
--- a/elastalert/util.py
+++ b/elastalert/util.py
@@ -5,7 +5,7 @@
 import os
 
 import dateutil.parser
-import dateutil.tz
+import pytz
 from auth import Auth
 from elasticsearch import RequestsHttpConnection
 from elasticsearch.client import Elasticsearch
@@ -112,7 +112,7 @@ def ts_to_dt(timestamp):
     dt = dateutil.parser.parse(timestamp)
     # Implicitly convert local timestamps to UTC
     if dt.tzinfo is None:
-        dt = dt.replace(tzinfo=dateutil.tz.tzutc())
+        dt = dt.replace(tzinfo=pytz.utc)
     return dt
 
 
@@ -365,6 +365,15 @@ def build_es_conn_config(conf):
     return parsed_conf
 
 
+def pytzfy(dt):
+    # apscheduler requires pytz timezone objects
+    # This function will replace a dateutil.tz one with a pytz one
+    if dt.tzinfo is not None:
+        new_tz = pytz.timezone(dt.tzinfo.tzname('Y is this even required??'))
+        return dt.replace(tzinfo=new_tz)
+    return dt
+
+
 def parse_duration(value):
     """Convert ``unit=num`` spec into a ``timedelta`` object."""
     unit, num = value.split('=')
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 36daa0ebd..1cb67cb8e 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,3 +1,4 @@
+-r requirements.txt
 coverage
 flake8
 pre-commit
diff --git a/requirements.txt b/requirements.txt
index 4f23f2ec1..418f92869 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,8 @@
+apscheduler>=3.3.0
 aws-requests-auth>=0.3.0
 blist>=1.3.6
 boto3>=1.4.4
+cffi>=1.11.5
 configparser>=3.5.0
 croniter>=0.3.16
 elasticsearch
@@ -11,11 +13,10 @@ jsonschema>=2.6.0
 mock>=2.0.0
 PyStaticConfiguration>=0.10.3
 python-dateutil>=2.6.0,<2.7.0
+python-magic>=0.4.15
 PyYAML>=3.12
 requests>=2.0.0
 stomp.py>=4.1.17
 texttable>=0.8.8
-twilio==6.0.0
 thehive4py>=1.4.4
-python-magic>=0.4.15
-cffi>=1.11.5
+twilio==6.0.0
diff --git a/setup.py b/setup.py
index 865d7974f..b2528e3b4 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,7 @@
     packages=find_packages(),
     package_data={'elastalert': ['schema.yaml']},
     install_requires=[
+        'apscheduler>=3.3.0',
         'aws-requests-auth>=0.3.0',
         'blist>=1.3.6',
         'boto3>=1.4.4',
diff --git a/tests/base_test.py b/tests/base_test.py
index b10eb5a74..fa9018b11 100644
--- a/tests/base_test.py
+++ b/tests/base_test.py
@@ -88,9 +88,9 @@ def test_init_rule(ea):
 
 
 def test_query(ea):
-    ea.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
+    ea.thread_data.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
     ea.run_query(ea.rules[0], START, END)
-    ea.current_es.search.assert_called_with(body={
+    ea.thread_data.current_es.search.assert_called_with(body={
         'query': {'filtered': {'filter': {'bool': {'must': [{'range': {'@timestamp': {'lte': END_TIMESTAMP, 'gt': START_TIMESTAMP}}}]}}}},
         'sort': [{'@timestamp': {'order': 'asc'}}]}, index='idx', _source_include=['@timestamp'], ignore_unavailable=True,
         size=ea.rules[0]['max_query_size'], scroll=ea.conf['scroll_keepalive'])
@@ -98,9 +98,9 @@ def test_query(ea):
 
 def test_query_with_fields(ea):
     ea.rules[0]['_source_enabled'] = False
-    ea.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
+    ea.thread_data.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
     ea.run_query(ea.rules[0], START, END)
-    ea.current_es.search.assert_called_with(body={
+    ea.thread_data.current_es.search.assert_called_with(body={
         'query': {'filtered': {'filter': {'bool': {'must': [{'range': {'@timestamp': {'lte': END_TIMESTAMP, 'gt': START_TIMESTAMP}}}]}}}},
         'sort': [{'@timestamp': {'order': 'asc'}}], 'fields': ['@timestamp']}, index='idx', ignore_unavailable=True,
         size=ea.rules[0]['max_query_size'], scroll=ea.conf['scroll_keepalive'])
@@ -109,11 +109,11 @@ def test_query_with_fields(ea):
 def test_query_with_unix(ea):
     ea.rules[0]['timestamp_type'] = 'unix'
     ea.rules[0]['dt_to_ts'] = dt_to_unix
-    ea.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
+    ea.thread_data.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
     ea.run_query(ea.rules[0], START, END)
     start_unix = dt_to_unix(START)
     end_unix = dt_to_unix(END)
-    ea.current_es.search.assert_called_with(
+    ea.thread_data.current_es.search.assert_called_with(
         body={'query': {'filtered': {'filter': {'bool': {'must': [{'range': {'@timestamp': {'lte': end_unix, 'gt': start_unix}}}]}}}},
               'sort': [{'@timestamp': {'order': 'asc'}}]}, index='idx', _source_include=['@timestamp'], ignore_unavailable=True,
         size=ea.rules[0]['max_query_size'], scroll=ea.conf['scroll_keepalive'])
@@ -122,18 +122,18 @@ def test_query_with_unix(ea):
 def test_query_with_unixms(ea):
     ea.rules[0]['timestamp_type'] = 'unixms'
     ea.rules[0]['dt_to_ts'] = dt_to_unixms
-    ea.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
+    ea.thread_data.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
     ea.run_query(ea.rules[0], START, END)
     start_unix = dt_to_unixms(START)
     end_unix = dt_to_unixms(END)
-    ea.current_es.search.assert_called_with(
+    ea.thread_data.current_es.search.assert_called_with(
         body={'query': {'filtered': {'filter': {'bool': {'must': [{'range': {'@timestamp': {'lte': end_unix, 'gt': start_unix}}}]}}}},
               'sort': [{'@timestamp': {'order': 'asc'}}]}, index='idx', _source_include=['@timestamp'], ignore_unavailable=True,
         size=ea.rules[0]['max_query_size'], scroll=ea.conf['scroll_keepalive'])
 
 
 def test_no_hits(ea):
-    ea.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
+    ea.thread_data.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
     ea.run_query(ea.rules[0], START, END)
     assert ea.rules[0]['type'].add_data.call_count == 0
 
@@ -142,7 +142,7 @@ def test_no_terms_hits(ea):
     ea.rules[0]['use_terms_query'] = True
     ea.rules[0]['query_key'] = 'QWERTY'
     ea.rules[0]['doc_type'] = 'uiop'
-    ea.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
+    ea.thread_data.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
     ea.run_query(ea.rules[0], START, END)
     assert ea.rules[0]['type'].add_terms_data.call_count == 0
 
@@ -150,7 +150,7 @@ def test_no_terms_hits(ea):
 def test_some_hits(ea):
     hits = generate_hits([START_TIMESTAMP, END_TIMESTAMP])
     hits_dt = generate_hits([START, END])
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     ea.run_query(ea.rules[0], START, END)
     assert ea.rules[0]['type'].add_data.call_count == 1
     ea.rules[0]['type'].add_data.assert_called_with([x['_source'] for x in hits_dt['hits']['hits']])
@@ -162,7 +162,7 @@ def test_some_hits_unix(ea):
     ea.rules[0]['ts_to_dt'] = unix_to_dt
     hits = generate_hits([dt_to_unix(START), dt_to_unix(END)])
     hits_dt = generate_hits([START, END])
-    ea.current_es.search.return_value = copy.deepcopy(hits)
+    ea.thread_data.current_es.search.return_value = copy.deepcopy(hits)
    ea.run_query(ea.rules[0], START, END)
     assert ea.rules[0]['type'].add_data.call_count == 1
     ea.rules[0]['type'].add_data.assert_called_with([x['_source'] for x in hits_dt['hits']['hits']])
@@ -176,7 +176,7 @@ def _duplicate_hits_generator(timestamps, **kwargs):
 
 
 def test_duplicate_timestamps(ea):
-    ea.current_es.search.side_effect = _duplicate_hits_generator([START_TIMESTAMP] * 3, blah='duplicate')
+    ea.thread_data.current_es.search.side_effect = _duplicate_hits_generator([START_TIMESTAMP] * 3, blah='duplicate')
     ea.run_query(ea.rules[0], START, ts_to_dt('2014-01-01T00:00:00Z'))
 
     assert len(ea.rules[0]['type'].add_data.call_args_list[0][0][0]) == 3
@@ -189,7 +189,7 @@ def test_duplicate_timestamps(ea):
 
 def test_match(ea):
     hits = generate_hits([START_TIMESTAMP, END_TIMESTAMP])
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     ea.rules[0]['type'].matches = [{'@timestamp': END}]
     with mock.patch('elastalert.elastalert.elasticsearch_client'):
         ea.run_rule(ea.rules[0], END, START)
@@ -280,7 +280,7 @@ def test_match_with_module_with_agg(ea):
     ea.rules[0]['match_enhancements'] = [mod]
     ea.rules[0]['aggregation'] = datetime.timedelta(minutes=15)
     hits = generate_hits([START_TIMESTAMP, END_TIMESTAMP])
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     ea.rules[0]['type'].matches = [{'@timestamp': END}]
     with mock.patch('elastalert.elastalert.elasticsearch_client'):
         ea.run_rule(ea.rules[0], END, START)
@@ -294,7 +294,7 @@ def test_match_with_enhancements_first(ea):
     ea.rules[0]['aggregation'] = datetime.timedelta(minutes=15)
     ea.rules[0]['run_enhancements_first'] = True
     hits = generate_hits([START_TIMESTAMP, END_TIMESTAMP])
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     ea.rules[0]['type'].matches = [{'@timestamp': END}]
     with mock.patch('elastalert.elastalert.elasticsearch_client'):
         with mock.patch.object(ea, 'add_aggregated_alert') as add_alert:
@@ -317,7 +317,7 @@ def test_agg_matchtime(ea):
     hits_timestamps = ['2014-09-26T12:34:45', '2014-09-26T12:40:45', '2014-09-26T12:47:45']
     alerttime1 = dt_to_ts(ts_to_dt(hits_timestamps[0]) + datetime.timedelta(minutes=10))
     hits = generate_hits(hits_timestamps)
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     with mock.patch('elastalert.elastalert.elasticsearch_client'):
         # Aggregate first two, query over full range
         ea.rules[0]['aggregate_by_match_time'] = True
@@ -373,7 +373,7 @@ def test_agg_not_matchtime(ea):
     hits_timestamps = ['2014-09-26T12:34:45', '2014-09-26T12:40:45', '2014-09-26T12:47:45']
     match_time = ts_to_dt('2014-09-26T12:55:00Z')
     hits = generate_hits(hits_timestamps)
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     with mock.patch('elastalert.elastalert.elasticsearch_client'):
         with mock.patch('elastalert.elastalert.ts_now', return_value=match_time):
             ea.rules[0]['aggregation'] = datetime.timedelta(minutes=10)
@@ -402,7 +402,7 @@ def test_agg_cron(ea):
     ea.max_aggregation = 1337
     hits_timestamps = ['2014-09-26T12:34:45', '2014-09-26T12:40:45', '2014-09-26T12:47:45']
     hits = generate_hits(hits_timestamps)
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     alerttime1 = dt_to_ts(ts_to_dt('2014-09-26T12:46:00'))
     alerttime2 = dt_to_ts(ts_to_dt('2014-09-26T13:04:00'))
 
@@ -439,7 +439,7 @@ def test_agg_no_writeback_connectivity(ea):
     run again, that they will be passed again to add_aggregated_alert """
     hit1, hit2, hit3 = '2014-09-26T12:34:45', '2014-09-26T12:40:45', '2014-09-26T12:47:45'
     hits = generate_hits([hit1, hit2, hit3])
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     ea.rules[0]['aggregation'] = datetime.timedelta(minutes=10)
     ea.rules[0]['type'].matches = [{'@timestamp': hit1},
                                    {'@timestamp': hit2},
@@ -453,7 +453,7 @@ def test_agg_no_writeback_connectivity(ea):
                                   {'@timestamp': hit2, 'num_hits': 0, 'num_matches': 3},
                                   {'@timestamp': hit3, 'num_hits': 0, 'num_matches': 3}]
 
-    ea.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
+    ea.thread_data.current_es.search.return_value = {'hits': {'total': 0, 'hits': []}}
     ea.add_aggregated_alert = mock.Mock()
 
     with mock.patch('elastalert.elastalert.elasticsearch_client'):
@@ -469,7 +469,7 @@ def test_agg_with_aggregation_key(ea):
     hits_timestamps = ['2014-09-26T12:34:45', '2014-09-26T12:40:45', '2014-09-26T12:43:45']
     match_time = ts_to_dt('2014-09-26T12:45:00Z')
     hits = generate_hits(hits_timestamps)
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     with mock.patch('elastalert.elastalert.elasticsearch_client'):
         with mock.patch('elastalert.elastalert.ts_now', return_value=match_time):
             ea.rules[0]['aggregation'] = datetime.timedelta(minutes=10)
@@ -562,7 +562,7 @@ def test_compound_query_key(ea):
     ea.rules[0]['query_key'] = 'this,that,those'
     ea.rules[0]['compound_query_key'] = ['this', 'that', 'those']
     hits = generate_hits([START_TIMESTAMP, END_TIMESTAMP], this='abc', that=u'☃', those=4)
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     ea.run_query(ea.rules[0], START, END)
     call_args = ea.rules[0]['type'].add_data.call_args_list[0]
     assert 'this,that,those' in call_args[0][0][0]
@@ -604,7 +604,7 @@ def test_silence_query_key(ea):
 def test_realert(ea):
     hits = ['2014-09-26T12:35:%sZ' % (x) for x in range(60)]
     matches = [{'@timestamp': x} for x in hits]
-    ea.current_es.search.return_value = hits
+    ea.thread_data.current_es.search.return_value = hits
     with mock.patch('elastalert.elastalert.elasticsearch_client'):
         ea.rules[0]['realert'] = datetime.timedelta(seconds=50)
         ea.rules[0]['type'].matches = matches
@@ -703,7 +703,7 @@ def test_count(ea):
         query['query']['filtered']['filter']['bool']['must'][0]['range']['@timestamp']['lte'] = dt_to_ts(end)
         query['query']['filtered']['filter']['bool']['must'][0]['range']['@timestamp']['gt'] = dt_to_ts(start)
         start = start + ea.run_every
-        ea.current_es.count.assert_any_call(body=query, doc_type='doctype', index='idx', ignore_unavailable=True)
+        ea.thread_data.current_es.count.assert_any_call(body=query, doc_type='doctype', index='idx', ignore_unavailable=True)
 
 
 def run_and_assert_segmented_queries(ea, start, end, segment_size):
@@ -727,8 +727,8 @@ def run_and_assert_segmented_queries(ea, start, end, segment_size):
 def test_query_segmenting_reset_num_hits(ea):
     # Tests that num_hits gets reset every time run_query is run
     def assert_num_hits_reset():
-        assert ea.num_hits == 0
-        ea.num_hits += 10
+        assert ea.thread_data.num_hits == 0
+        ea.thread_data.num_hits += 10
     with mock.patch.object(ea, 'run_query') as mock_run_query:
         mock_run_query.side_effect = assert_num_hits_reset()
         ea.run_rule(ea.rules[0], END, START)
@@ -915,6 +915,7 @@ def test_kibana_dashboard(ea):
 
 
 def test_rule_changes(ea):
+    re = datetime.timedelta(minutes=10)
     ea.rule_hashes = {'rules/rule1.yaml': 'ABC',
                       'rules/rule2.yaml': 'DEF'}
     ea.rules = [ea.init_rule(rule, True) for rule in [{'rule_file': 'rules/rule1.yaml', 'name': 'rule1', 'filter': []},
                                                       {'rule_file': 'rules/rule2.yaml', 'name': 'rule2', 'filter': []}]]
@@ -926,8 +927,8 @@ def test_rule_changes(ea):
 
     with mock.patch('elastalert.elastalert.get_rule_hashes') as mock_hashes:
         with mock.patch('elastalert.elastalert.load_configuration') as mock_load:
-            mock_load.side_effect = [{'filter': [], 'name': 'rule2', 'rule_file': 'rules/rule2.yaml'},
-                                     {'filter': [], 'name': 'rule3', 'rule_file': 'rules/rule3.yaml'}]
+            mock_load.side_effect = [{'filter': [], 'name': 'rule2', 'rule_file': 'rules/rule2.yaml', 'run_every': re},
+                                     {'filter': [], 'name': 'rule3', 'rule_file': 'rules/rule3.yaml', 'run_every': re}]
             mock_hashes.return_value = new_hashes
             ea.load_rule_changes()
@@ -1004,9 +1005,9 @@ def test_count_keys(ea):
     ea.rules[0]['doc_type'] = 'blah'
     buckets = [{'aggregations': {'filtered': {'counts': {'buckets': [{'key': 'a', 'doc_count': 10}, {'key': 'b', 'doc_count': 5}]}}}},
               {'aggregations': {'filtered': {'counts': {'buckets': [{'key': 'd', 'doc_count': 10}, {'key': 'c', 'doc_count': 12}]}}}}]
-    ea.current_es.search.side_effect = buckets
+    ea.thread_data.current_es.search.side_effect = buckets
     counts = ea.get_top_counts(ea.rules[0], START, END, ['this', 'that'])
-    calls = ea.current_es.search.call_args_list
+    calls = ea.thread_data.current_es.search.call_args_list
     assert calls[0][1]['search_type'] == 'count'
     assert calls[0][1]['body']['aggs']['filtered']['aggs']['counts']['terms'] == {'field': 'this', 'size': 5}
     assert counts['top_events_this'] == {'a': 10, 'b': 5}
@@ -1131,7 +1132,7 @@ def mock_loop():
         ea.stop()
 
     with mock.patch.object(ea, 'sleep_for', return_value=None):
-        with mock.patch.object(ea, 'run_all_rules') as mock_run:
+        with mock.patch.object(ea, 'sleep_for') as mock_run:
             mock_run.side_effect = mock_loop()
             start_thread = threading.Thread(target=ea.start)
             # Set as daemon to prevent a failed test from blocking exit
diff --git a/tests/conftest.py b/tests/conftest.py
index ca50a101a..bf066122a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,9 @@
 # -*- coding: utf-8 -*-
 import datetime
-
 import logging
-import mock
 import os
+
+import mock
 import pytest
 
 import elastalert.elastalert
@@ -87,7 +87,8 @@ def ea():
                 'max_query_size': 10000,
                 'ts_to_dt': ts_to_dt,
                 'dt_to_ts': dt_to_ts,
-                '_source_enabled': True}]
+                '_source_enabled': True,
+                'run_every': datetime.timedelta(seconds=15)}]
     conf = {'rules_folder': 'rules',
             'run_every': datetime.timedelta(minutes=10),
             'buffer_time': datetime.timedelta(minutes=5),
@@ -103,14 +104,18 @@ def ea():
     elastalert.elastalert.elasticsearch_client = mock_es_client
     with mock.patch('elastalert.elastalert.get_rule_hashes'):
         with mock.patch('elastalert.elastalert.load_rules') as load_conf:
-            load_conf.return_value = conf
-            ea = elastalert.elastalert.ElastAlerter(['--pin_rules'])
+            with mock.patch('elastalert.elastalert.BackgroundScheduler'):
+                load_conf.return_value = conf
+                ea = elastalert.elastalert.ElastAlerter(['--pin_rules'])
     ea.rules[0]['type'] = mock_ruletype()
     ea.rules[0]['alert'] = [mock_alert()]
     ea.writeback_es = mock_es_client()
     ea.writeback_es.search.return_value = {'hits': {'hits': []}}
     ea.writeback_es.index.return_value = {'_id': 'ABCD'}
     ea.current_es = mock_es_client('', '')
+    ea.thread_data.current_es = ea.current_es
+    ea.thread_data.num_hits = 0
+    ea.thread_data.num_dupes = 0
    return ea
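To see the concurrency model of this change in isolation: the patch replaces the single run_all_rules loop with one APScheduler interval job per rule, and moves per-query state (num_hits, num_dupes, current_es, ...) onto threading.local() so rules running concurrently do not overwrite each other's counters. The sketch below is a standalone illustration of that pattern, not ElastAlert code; the rule names and the run_rule/fake_query bodies are made up, and it assumes an APScheduler release recent enough to support the jitter option the diff also uses.

    # Illustrative sketch only: per-rule interval jobs plus thread-local counters.
    import threading
    import time

    from apscheduler.schedulers.background import BackgroundScheduler

    thread_data = threading.local()  # each scheduler worker thread sees its own copy

    def fake_query(rule):
        return ['hit'] * 3  # stand-in for the Elasticsearch search

    def run_rule(rule):
        # Counters live on thread_data, so two rules running at once cannot clobber each other
        thread_data.num_hits = 0
        thread_data.num_hits += len(fake_query(rule))
        print('%s: %d hits' % (rule['name'], thread_data.num_hits))

    scheduler = BackgroundScheduler()
    for rule in [{'name': 'rule1', 'run_every': 15}, {'name': 'rule2', 'run_every': 30}]:
        scheduler.add_job(run_rule, 'interval',
                          args=[rule],
                          seconds=rule['run_every'],
                          id=rule['name'],     # the job id doubles as the handle for pause/remove
                          max_instances=1,     # a rule never overlaps with itself
                          jitter=5)            # stagger start times slightly, as the diff does
    scheduler.start()
    time.sleep(60)
    scheduler.shutdown()

This is also why handle_uncaught_exception now calls scheduler.pause_job(job_id=rule['name']): disabling a rule has to stop its job as well as remove it from self.rules.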
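A note on the new per-rule run_every option added to load_options: it follows the same convention as buffer_time and query_delay, a mapping of datetime.timedelta keyword arguments taken from the rule YAML. A minimal sketch of the conversion (the rule name and values here are hypothetical, not part of this change):

    import datetime

    # What a rule might hand to load_options after YAML parsing (illustrative only)
    rule = {'name': 'example_frequency', 'run_every': {'minutes': 30}}

    # Same datetime.timedelta(**...) pattern the diff adds to load_options
    rule['run_every'] = datetime.timedelta(**rule['run_every'])
    print(rule['run_every'].total_seconds())  # 1800.0, later handed to the scheduler as seconds

Rules that omit run_every fall back to the global value, which init_rule copies in via the blank_rule template and copy_properties.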