diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..2879763
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,48 @@
+# Changelog
+
+All notable changes to PyHSS are documented in this file, beginning from [Service Overhaul #168](https://github.com/nickvsnetworking/pyhss/pull/168).
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [1.0.0] - 2023-09-27
+
+### Added
+
+- Systemd service files for PyHSS services
+- /oam/diameter_peers endpoint
+- /oam/deregister/{imsi} endpoint
+- /geored/peers endpoint
+- /geored/webhooks endpoint
+- Dependency on Redis 7 for inter-service messaging
+- Significant performance improvements under load
+- Basic Rx support for RAA, AAA, ASA and STA
+- Rx MO call flow support (AAR -> RAR -> RAA -> AAA)
+- Dedicated bearer setup and teardown on Rx call
+- Asymmetric geored support
+- Configurable redis connection (Unix socket or TCP)
+- Basic database upgrade support in tools/databaseUpgrade
+- PCSCF state storage in ims_subscriber
+- (Experimental) Working horizontal scalability
+
+### Changed
+
+- Split logical functions of PyHSS into 6 service processes
+- Logtool no longer handles metric processing
+- Updated config.yaml
+- Gx CCR-T now flushes PGW / IMS data, depending on Called-Station-Id
+- Benchmarked capability of at least ~500 Diameter requests per second with a response time of under 2 seconds on a local network
+
+### Fixed
+
+- Memory leaking in diameter.py
+- Gx CCA now supports an APN inside a PLMN-based URI
+- AVP_Preemption_Capability and AVP_Preemption_Vulnerability now present correctly in all diameter messages
+- Crash when webhook or geored endpoints enabled and no peers defined
+- CPU overutilization on all services
+
+### Removed
+
+- Multithreading in all services, except for metricService
+
+[1.0.0]: https://github.com/nickvsnetworking/pyhss/releases/tag/v1.0.0
\ No newline at end of file
diff --git a/README.md b/README.md
index cff724e..b3ee019 100644
--- a/README.md
+++ b/README.md
@@ -41,20 +41,28 @@ Basic configuration is set in the ``config.yaml`` file,
 You will need to set the IP address to bind to (IPv4 or IPv6), the Diameter hostname, realm, your PLMN and transport type to use (SCTP or TCP).
 
-Once the configuration is done you can run the HSS by running ``hss.py`` and the server will run using whichever transport (TCP/SCTP) you have selected.
+The diameter service runs in a trusting mode, allowing Diameter connections from any other Diameter host.
 
-The service runs in a trusting mode allowing Diameter connections from any other Diameter hosts.
+To perform as a functioning HSS, the following services must be run as a minimum:
+- diameterService.py
+- hssService.py
 
-## Structure
+If you're provisioning the HSS for the first time, you'll also want to run:
+ - apiService.py
 
-The file *hss.py* runs a threaded Sockets based listener (SCTP or TCP) to receive Diameter requests, process them and send back Diameter responses.
+The rest of the services aren't strictly necessary; your own configuration will dictate whether they are required.
 
-Most of the heavy lifting in this is managed by the Diameter class, in ``diameter.py``. This:
+## Structure
 
- * Decodes incoming packets (Requests)(Returns AVPs as an array, called *avp*, and a Dict containing the packet variables (called *packet_vars*)
- * Generates responses (Answer messages) to Requests (when provided with the AVP and packet_vars of the original Request)
- * Generates Requests to send to other peers
+PyHSS uses a queued microservices model. Each service performs a specific set of tasks, and uses redis messages to communicate with other services.
+The following services make up PyHSS:
+ - diameterService.py: Handles receiving and sending of diameter messages, and diameter client connection state.
+ - hssService.py: Provides decoding and encoding of diameter requests and responses, as well as logic to perform as an HSS.
+ - apiService.py: Provides the API, allowing management of PyHSS.
+ - georedService.py: Sends geographic redundancy messages to geored peers when defined. Also handles webhook messages.
+ - logService.py: Handles logging for all services.
+ - metricService.py: Exposes prometheus metrics from other services.
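+
+As a rough sketch of this queued model (illustrative only; the queue name and message fields below are invented for the example and are not PyHSS's actual internal schema), one service pushes work onto a redis list and another blocks waiting on it:
+
+```python
+import json
+
+import redis
+
+r = redis.Redis(host='localhost', port=6379)
+
+# Producer side: queue a message for another service to pick up.
+r.lpush('example:inbound_queue', json.dumps({'peer': '10.0.0.1', 'payload': 'deadbeef'}))
+
+# Consumer side: block until a message arrives, then decode it.
+_, raw = r.brpop('example:inbound_queue')
+message = json.loads(raw)
+print(message['peer'])
+```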
 
 ## Subscriber Information Storage
@@ -71,12 +79,17 @@ Dependencies can be installed using Pip3:
 ```shell
 pip3 install -r requirements.txt
 ```
 
-Then after setting up the config, you can fire up the HSS itself by running:
+PyHSS also requires [Redis 7.0.0](https://redis.io/docs/getting-started/installation/install-redis-on-linux/) or above.
+
+Then after setting up the config, you can fire up the necessary PyHSS services by running:
 ```shell
-python3 hss.py
+python3 diameterService.py
+python3 hssService.py
+python3 apiService.py
 ```
 
-All going well you'll have a functioning HSS at this point.
+All going well, you'll have a functioning HSS at this point. For production use, systemd scripts are located in `./systemd`.
+
+The PyHSS API uses Flask, and can be configured with your favourite WSGI server. To get everything more production-ready, check out [Monit with PyHSS](docs/monit.md) for more info.
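+
+Each service connects to redis using the ``redis`` section of ``config.yaml``. As a rough sketch of how those settings map onto a client connection (illustrative only; the services' internals may differ):
+
+```python
+import redis
+import yaml
+
+with open("config.yaml", 'r') as stream:
+    config = yaml.safe_load(stream)
+
+redis_config = config['redis']
+if redis_config.get('useUnixSocket', False):
+    # Host and port are ignored when useUnixSocket is True.
+    connection = redis.Redis(unix_socket_path=redis_config['unixSocketPath'])
+else:
+    connection = redis.Redis(host=redis_config['host'], port=redis_config['port'])
+
+connection.ping()
+```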
diff --git a/config.yaml b/config.yaml
index 75fb78a..e0e2a84 100644
--- a/config.yaml
+++ b/config.yaml
@@ -30,17 +30,14 @@ hss:
   #IMSI of Test Subscriber for Unit Checks (Optional)
   test_sub_imsi: '001021234567890'
 
-  #Device Watchdog Request Interval (In Seconds - If set to 0 disabled)
-  device_watchdog_request_interval: 0
-
-  #Async Queue Check Interval (In Seconds - If set to 0 disabled)
-  async_check_interval: 0
-
-  #The maximum time to wait, in seconds, before disconnecting a client when no data is received.
-  client_socket_timeout: 120
-
-  #The maximum amount of times a failed diameter response/query should be resent before considering the peer offline and terminating their connection
-  diameter_max_retries: 1
+  #The maximum time to wait, in seconds, before disconnecting a client when no data is received.
+  client_socket_timeout: 300
+
+  #The maximum time to wait, in seconds, before discarding a diameter request.
+  diameter_request_timeout: 3
+
+  #The amount of time, in seconds, before purging a disconnected client from the Active Diameter Peers key in redis.
+  active_diameter_peers_timeout: 10
 
   #Prevent updates from being performed without a valid 'Provisioning-Key' in the header
   lock_provisioning: False
@@ -68,22 +65,24 @@ hss:
 api:
   page_size: 200
 
-external:
-  external_webhook_notification_enabled: False
-  external_webhook_notification_url: https://api.example.com/webhook
+benchmarking:
+  # Whether to enable benchmark logging
+  enabled: True
+  # How often to report, in seconds. Not all benchmarking supports interval reporting.
+  reporting_interval: 3600
 
 eir:
   imsi_imei_logging: True    #Store current IMEI / IMSI pair in backend
-  sim_swap_notify_webhook: http://localhost:5000/webhooks/sim_swap_notify/
   no_match_response: 2 #Greylist
   tac_database_csv: '/etc/pyhss/tac_database_Nov2022.csv'
 
 logging:
-  level: DEBUG
+  level: INFO
   logfiles:
-    hss_logging_file: log/hss.log
-    diameter_logging_file: log/diameter.log
-    database_logging_file: log/db.log
+    hss_logging_file: /var/log/pyhss_hss.log
+    diameter_logging_file: /var/log/pyhss_diameter.log
+    geored_logging_file: /var/log/pyhss_geored.log
+    metric_logging_file: /var/log/pyhss_metrics.log
   log_to_terminal: True
   sqlalchemy_sql_echo: True
   sqlalchemy_pool_recycle: 15
@@ -98,18 +97,25 @@ database:
   password: password
   database: hss2
 
+## External Webhook Notifications
+webhooks:
+  enabled: False
+  endpoints:
+    - http://127.0.0.1:8181
+
 ## Geographic Redundancy Parameters
 geored:
   enabled: False
   sync_actions: ['HSS', 'IMS', 'PCRF', 'EIR']   #What event actions should be synced
-  sync_endpoints: #List of PyHSS API Endpoints to update
+  endpoints: #List of PyHSS API Endpoints to update
     - 'http://hss01.mnc001.mcc001.3gppnetwork.org:8080'
     - 'http://hss02.mnc001.mcc001.3gppnetwork.org:8080'
 
-## Stats Parameters
+#Redis is required to run PyHSS. A locally running instance is recommended for production.
 redis:
-  enabled: False
-  clear_stats_on_boot: True
+  # Whether to use a UNIX socket instead of a TCP connection to redis. Host and port are ignored if useUnixSocket is True.
+  useUnixSocket: False
+  unixSocketPath: '/var/run/redis/redis-server.sock'
   host: localhost
   port: 6379
 
diff --git a/database.py b/database.py
deleted file mode 100755
index cf78d87..0000000
--- a/database.py
+++ /dev/null
@@ -1,2392 +0,0 @@
-from sqlalchemy import Column, Integer, String, MetaData, Table, Boolean, ForeignKey, select, UniqueConstraint, DateTime, BigInteger, event, Text, DateTime, Float
-from sqlalchemy import create_engine
-from sqlalchemy.engine.reflection import Inspector
-from sqlalchemy.sql import desc, func
-from sqlalchemy_utils import database_exists, create_database
-from sqlalchemy.orm import sessionmaker, relationship, Session, class_mapper
-from sqlalchemy.orm.attributes import History, get_history
-import sys, os
-sys.path.append(os.path.realpath('lib'))
-from functools import wraps
-import json
-import datetime, time
-from datetime import timezone
-import re
-import binascii
-import uuid
-import socket
-import traceback
-from contextlib import contextmanager
-import logging
-import logtool
-import pprint
-from logtool import *
-from construct import Default
-import S6a_crypt
-import requests
-from requests.exceptions import ConnectionError, Timeout
-from requests.adapters import HTTPAdapter
-from urllib3.util.retry import Retry
-import threading
-
-import yaml
-with open("config.yaml", 'r') as stream:
-    yaml_config = (yaml.safe_load(stream))
-
-logtool = logtool.LogTool()
-logtool.setup_logger('DBLogger', yaml_config['logging']['logfiles']['database_logging_file'], level=yaml_config['logging']['level'])
-DBLogger = logging.getLogger('DBLogger')
-DBLogger.info("DB Log Initialised.")
-
-db_string = 'mysql://' + str(yaml_config['database']['username']) + ':' + str(yaml_config['database']['password']) + '@' + str(yaml_config['database']['server']) + '/' + str(yaml_config['database']['database'] + "?autocommit=true")
-print(db_string)
-engine = create_engine(
-    db_string,
-    echo = yaml_config['logging'].get('sqlalchemy_sql_echo', True),
-
pool_recycle=yaml_config['logging'].get('sqlalchemy_pool_recycle', 5), - pool_size=yaml_config['logging'].get('sqlalchemy_pool_size', 30), - max_overflow=yaml_config['logging'].get('sqlalchemy_max_overflow', 0)) -from sqlalchemy.ext.declarative import declarative_base -Base = declarative_base() - -class OPERATION_LOG_BASE(Base): - __tablename__ = 'operation_log' - id = Column(Integer, primary_key=True) - item_id = Column(Integer, nullable=False) - operation_id = Column(String(36), nullable=False) - operation = Column(String(10)) - changes = Column(Text) - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc)) - timestamp = Column(DateTime, default=func.now()) - table_name = Column('table_name', String(255)) - __mapper_args__ = {'polymorphic_on': table_name} - -class APN_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'apn'} - apn = relationship("APN", back_populates="operation_logs") - apn_id = Column(Integer, ForeignKey('apn.apn_id')) - -class SUBSCRIBER_ROUTING_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'subscriber_routing'} - subscriber_routing = relationship("SUBSCRIBER_ROUTING", back_populates="operation_logs") - subscriber_routing_id = Column(Integer, ForeignKey('subscriber_routing.subscriber_routing_id')) - -class SERVING_APN_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'serving_apn'} - serving_apn = relationship("SERVING_APN", back_populates="operation_logs") - serving_apn_id = Column(Integer, ForeignKey('serving_apn.serving_apn_id')) - -class AUC_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'auc'} - auc = relationship("AUC", back_populates="operation_logs") - auc_id = Column(Integer, ForeignKey('auc.auc_id')) - -class SUBSCRIBER_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'subscriber'} - subscriber = relationship("SUBSCRIBER", back_populates="operation_logs") - subscriber_id = Column(Integer, ForeignKey('subscriber.subscriber_id')) - -class IMS_SUBSCRIBER_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'ims_subscriber'} - ims_subscriber = relationship("IMS_SUBSCRIBER", back_populates="operation_logs") - ims_subscriber_id = Column(Integer, ForeignKey('ims_subscriber.ims_subscriber_id')) - -class CHARGING_RULE_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'charging_rule'} - charging_rule = relationship("CHARGING_RULE", back_populates="operation_logs") - charging_rule_id = Column(Integer, ForeignKey('charging_rule.charging_rule_id')) - -class TFT_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'tft'} - tft = relationship("TFT", back_populates="operation_logs") - tft_id = Column(Integer, ForeignKey('tft.tft_id')) - -class EIR_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'eir'} - eir = relationship("EIR", back_populates="operation_logs") - eir_id = Column(Integer, ForeignKey('eir.eir_id')) - -class IMSI_IMEI_HISTORY_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'eir_history'} - eir_history = relationship("IMSI_IMEI_HISTORY", back_populates="operation_logs") - imsi_imei_history_id = Column(Integer, ForeignKey('eir_history.imsi_imei_history_id')) - -class SUBSCRIBER_ATTRIBUTES_OPERATION_LOG(OPERATION_LOG_BASE): - __mapper_args__ = {'polymorphic_identity': 'subscriber_attributes'} - subscriber_attributes = 
relationship("SUBSCRIBER_ATTRIBUTES", back_populates="operation_logs") - subscriber_attributes_id = Column(Integer, ForeignKey('subscriber_attributes.subscriber_attributes_id')) - -class APN(Base): - __tablename__ = 'apn' - apn_id = Column(Integer, primary_key=True, doc='Unique ID of APN') - apn = Column(String(50), nullable=False, doc='Short name of the APN') - ip_version = Column(Integer, default=0, doc="IP version used - 0: ipv4, 1: ipv6 2: ipv4+6 3: ipv4 or ipv6 [3GPP TS 29.272 7.3.62]") - pgw_address = Column(String(50), doc='IP of the PGW') - sgw_address = Column(String(50), doc='IP of the SGW') - charging_characteristics = Column(String(4), default='0800', doc='For the encoding of this information element see 3GPP TS 32.298 [9]') - apn_ambr_dl = Column(Integer, nullable=False, doc='Downlink Maximum Bit Rate for this APN') - apn_ambr_ul = Column(Integer, nullable=False, doc='Uplink Maximum Bit Rate for this APN') - qci = Column(Integer, default=9, doc='QoS Class Identifier') - arp_priority = Column(Integer, default=4, doc='Allocation and Retention Policy - Bearer priority level (1-15)') - arp_preemption_capability = Column(Boolean, default=False, doc='Allocation and Retention Policy - Capability to Preempt resources from other Subscribers') - arp_preemption_vulnerability = Column(Boolean, default=True, doc='Allocation and Retention Policy - Vulnerability to have resources Preempted by other Subscribers') - charging_rule_list = Column(String(18), doc='Comma separated list of predefined ChargingRules to be installed in CCA-I') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - operation_logs = relationship("APN_OPERATION_LOG", back_populates="apn") - -class SUBSCRIBER_ROUTING(Base): - __tablename__ = 'subscriber_routing' - __table_args__ = ( - # this can be db.PrimaryKeyConstraint if you want it to be a primary key - UniqueConstraint('subscriber_id', 'apn_id'), - ) - subscriber_routing_id = Column(Integer, primary_key=True, doc='Unique ID of Subscriber Routing item') - subscriber_id = Column(Integer, ForeignKey('subscriber.subscriber_id', ondelete='CASCADE'), doc='subscriber_id of the served subscriber') - apn_id = Column(Integer, ForeignKey('apn.apn_id', ondelete='CASCADE'), doc='apn_id of the target apn') - ip_version = Column(Integer, default=0, doc="IP version used - 0: ipv4, 1: ipv6 2: ipv4+6 3: ipv4 or ipv6 [3GPP TS 29.272 7.3.62]") - ip_address = Column(String(254), doc='IP of the UE') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - operation_logs = relationship("SUBSCRIBER_ROUTING_OPERATION_LOG", back_populates="subscriber_routing") - -class AUC(Base): - __tablename__ = 'auc' - auc_id = Column(Integer, primary_key = True, doc='Unique ID of AuC entry') - ki = Column(String(32), doc='SIM Key - Authentication Key - Ki', nullable=False) - opc = Column(String(32), doc='SIM Key - Network Operators key OPc', nullable=False) - amf = Column(String(4), doc='Authentication Management Field', nullable=False) - sqn = Column(BigInteger, doc='Authentication sequence number') - iccid = Column(String(20), unique=True, doc='Integrated Circuit Card Identification Number') - imsi = Column(String(18), unique=True, doc='International Mobile Subscriber Identity') - batch_name = Column(String(20), doc='Name of SIM Batch') - sim_vendor = Column(String(20), doc='SIM Vendor') - esim = Column(Boolean, default=0, doc='Card is eSIM') - lpa = 
Column(String(128), doc='LPA URL for activating eSIM') - pin1 = Column(String(20), doc='PIN1') - pin2 = Column(String(20), doc='PIN2') - puk1 = Column(String(20), doc='PUK1') - puk2 = Column(String(20), doc='PUK2') - kid = Column(String(20), doc='KID') - psk = Column(String(128), doc='PSK') - des = Column(String(128), doc='DES') - adm1 = Column(String(20), doc='ADM1') - misc1 = Column(String(128), doc='For misc data storage 1') - misc2 = Column(String(128), doc='For misc data storage 2') - misc3 = Column(String(128), doc='For misc data storage 3') - misc4 = Column(String(128), doc='For misc data storage 4') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - operation_logs = relationship("AUC_OPERATION_LOG", back_populates="auc") - -class SUBSCRIBER(Base): - __tablename__ = 'subscriber' - subscriber_id = Column(Integer, primary_key = True, doc='Unique ID of Subscriber entry') - imsi = Column(String(18), unique=True, doc='International Mobile Subscriber Identity') - enabled = Column(Boolean, default=1, doc='Subscriber enabled/disabled') - auc_id = Column(Integer, ForeignKey('auc.auc_id'), doc='Reference to AuC ID defined with SIM Auth data', nullable=False) - default_apn = Column(Integer, ForeignKey('apn.apn_id'), doc='APN ID to use for the default APN', nullable=False) - apn_list = Column(String(64), doc='Comma separated list of allowed APNs', nullable=False) - msisdn = Column(String(18), doc='Primary Phone number of Subscriber') - ue_ambr_dl = Column(Integer, default=999999, doc='Downlink Aggregate Maximum Bit Rate') - ue_ambr_ul = Column(Integer, default=999999, doc='Uplink Aggregate Maximum Bit Rate') - nam = Column(Integer, default=0, doc='Network Access Mode [3GPP TS. 123 008 2.1.1.2] - 0 (PACKET_AND_CIRCUIT) or 2 (ONLY_PACKET)') - subscribed_rau_tau_timer = Column(Integer, default=300, doc='Subscribed periodic TAU/RAU timer value in seconds') - serving_mme = Column(String(512), doc='MME serving this subscriber') - serving_mme_timestamp = Column(DateTime, doc='Timestamp of attach to MME') - serving_mme_realm = Column(String(512), doc='Realm of serving mme') - serving_mme_peer = Column(String(512), doc='Diameter peer used to reach MME then ; then the HSS the Diameter peer is connected to') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - operation_logs = relationship("SUBSCRIBER_OPERATION_LOG", back_populates="subscriber") - -class SERVING_APN(Base): - __tablename__ = 'serving_apn' - serving_apn_id = Column(Integer, primary_key=True, doc='Unique ID of SERVING_APN') - subscriber_id = Column(Integer, ForeignKey('subscriber.subscriber_id', ondelete='CASCADE'), doc='subscriber_id of the served subscriber') - apn = Column(Integer, ForeignKey('apn.apn_id', ondelete='CASCADE'), doc='apn_id of the APN served') - pcrf_session_id = Column(String(100), doc='Session ID from the PCRF') - subscriber_routing = Column(String(100), doc='IP Address allocated to the UE') - ip_version = Column(Integer, default=0, doc=APN.ip_version.doc) - serving_pgw = Column(String(512), doc='PGW serving this subscriber') - serving_pgw_timestamp = Column(DateTime, doc='Timestamp of attach to PGW') - serving_pgw_realm = Column(String(512), doc='Realm of serving PGW') - serving_pgw_peer = Column(String(512), doc='Diameter peer used to reach PGW') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - 
operation_logs = relationship("SERVING_APN_OPERATION_LOG", back_populates="serving_apn") - -class IMS_SUBSCRIBER(Base): - __tablename__ = 'ims_subscriber' - ims_subscriber_id = Column(Integer, primary_key = True, doc='Unique ID of IMS_Subscriber entry') - msisdn = Column(String(18), unique=True, doc=SUBSCRIBER.msisdn.doc) - msisdn_list = Column(String(1200), doc='Comma Separated list of additional MSISDNs for Subscriber') - imsi = Column(String(18), unique=False, doc=SUBSCRIBER.imsi.doc) - ifc_path = Column(String(18), doc='Path to template file for the Initial Filter Criteria') - sh_profile = Column(Text(12000), doc='Sh Subscriber Profile') - scscf = Column(String(512), doc='Serving-CSCF serving this subscriber') - scscf_timestamp = Column(DateTime, doc='Timestamp of attach to S-CSCF') - scscf_realm = Column(String(512), doc='Realm of SCSCF') - scscf_peer = Column(String(512), doc='Diameter peer used to reach SCSCF') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - operation_logs = relationship("IMS_SUBSCRIBER_OPERATION_LOG", back_populates="ims_subscriber") - -class CHARGING_RULE(Base): - __tablename__ = 'charging_rule' - charging_rule_id = Column(Integer, primary_key = True, doc='Unique ID of CHARGING_RULE entry') - rule_name = Column(String(20), doc='Name of rule pushed to PGW (Short, no special chars)') - - qci = Column(Integer, default=9, doc=APN.qci.doc) - arp_priority = Column(Integer, default=4, doc=APN.arp_priority.doc) - arp_preemption_capability = Column(Boolean, default=False, doc=APN.arp_preemption_capability.doc) - arp_preemption_vulnerability = Column(Boolean, default=True, doc=APN.arp_preemption_vulnerability.doc) - - mbr_dl = Column(Integer, nullable=False, doc='Maximum Downlink Bitrate for traffic matching this rule') - mbr_ul = Column(Integer, nullable=False, doc='Maximum Uplink Bitrate for traffic matching this rule') - gbr_dl = Column(Integer, nullable=False, doc='Guaranteed Downlink Bitrate for traffic matching this rule') - gbr_ul = Column(Integer, nullable=False, doc='Guaranteed Uplink Bitrate for traffic matching this rule') - tft_group_id = Column(Integer, doc='Will match any TFTs using this TFT Group to form the TFT list used in the Charging Rule') - precedence = Column(Integer, doc='Precedence of this rule, allows rule to override or be overridden by a higher priority rule') - rating_group = Column(Integer, doc='Rating Group in OCS / OFCS that traffic matching this rule will be charged under') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - operation_logs = relationship("CHARGING_RULE_OPERATION_LOG", back_populates="charging_rule") - -class TFT(Base): - __tablename__ = 'tft' - tft_id = Column(Integer, primary_key = True, doc='Unique ID of CHARGING_RULE entry') - tft_group_id = Column(Integer, nullable=False, doc=CHARGING_RULE.tft_group_id.doc) - tft_string = Column(String(100), nullable=False, doc='IPFilterRules as defined in [RFC 6733] taking the format: action dir proto from src to dst') - direction = Column(Integer, nullable=False, doc='Traffic Direction: 0- Unspecified, 1 - Downlink, 2 - Uplink, 3 - Bidirectional') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - operation_logs = relationship("TFT_OPERATION_LOG", back_populates="tft") - -class EIR(Base): - __tablename__ = 'eir' - eir_id = Column(Integer, primary_key = True, 
doc='Unique ID of EIR entry') - imei = Column(String(60), doc='Exact IMEI or Regex to match IMEI (Depending on regex_mode value)') - imsi = Column(String(60), doc='Exact IMSI or Regex to match IMSI (Depending on regex_mode value)') - regex_mode = Column(Integer, default=1, doc='0 - Exact Match mode, 1 - Regex Mode') - match_response_code = Column(Integer, doc='0 - Whitelist, 1 - Blacklist, 2 - Greylist') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - operation_logs = relationship("EIR_OPERATION_LOG", back_populates="eir") - -class IMSI_IMEI_HISTORY(Base): - __tablename__ = 'eir_history' - imsi_imei_history_id = Column(Integer, primary_key = True, doc='Unique ID of IMSI_IMEI_HISTORY entry') - imsi_imei = Column(String(60), unique=True, doc='Combined IMSI + IMEI value') - match_response_code = Column(Integer, doc='Response code that was returned') - imsi_imei_timestamp = Column(DateTime, doc='Timestamp of last match') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - operation_logs = relationship("IMSI_IMEI_HISTORY_OPERATION_LOG", back_populates="eir_history") - -class SUBSCRIBER_ATTRIBUTES(Base): - __tablename__ = 'subscriber_attributes' - subscriber_attributes_id = Column(Integer, primary_key = True, doc='Unique ID of Attribute') - subscriber_id = Column(Integer, ForeignKey('subscriber.subscriber_id', ondelete='CASCADE'), doc='Reference to Subscriber ID defined within Subscriber Section', nullable=False) - key = Column(String(60), doc='Arbitrary key') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - value = Column(String(12000), doc='Arbitrary value') - operation_logs = relationship("SUBSCRIBER_ATTRIBUTES_OPERATION_LOG", back_populates="subscriber_attributes") - -# Create database if it does not exist. 
-if not database_exists(engine.url): - DBLogger.debug("Creating database") - create_database(engine.url) - Base.metadata.create_all(engine) -else: - DBLogger.debug("Database already created") - -def load_IMEI_database_into_Redis(): - try: - DBLogger.info("Reading IMEI TAC database CSV from " + str(yaml_config['eir']['tac_database_csv'])) - csvfile = open(str(yaml_config['eir']['tac_database_csv'])) - DBLogger.info("This may take a few seconds to buffer into Redis...") - except: - DBLogger.error("Failed to read CSV file of IMEI TAC database") - return - try: - count = 0 - for line in csvfile: - line = line.replace('"', '') #Strip excess invered commas - line = line.replace("'", '') #Strip excess invered commas - line = line.rstrip() #Strip newlines - result = line.split(',') - tac_prefix = result[0] - name = result[1].lstrip() - model = result[2].lstrip() - if count == 0: - DBLogger.info("Checking to see if entries are already present...") - #DBLogger.info("Searching Redis for key " + str(tac_prefix) + " to see if data already provisioned") - redis_imei_result = logtool.RedisHMGET(key=str(tac_prefix)) - if len(redis_imei_result) != 0: - DBLogger.info("IMEI TAC Database already loaded into Redis - Skipping reading from file...") - break - else: - DBLogger.info("No data loaded into Redis, proceeding to load...") - imei_result = {'tac_prefix': tac_prefix, 'name': name, 'model': model} - logtool.RedisHMSET(key=str(tac_prefix), value_dict=imei_result) - count = count +1 - DBLogger.info("Loaded " + str(count) + " IMEI TAC entries into Redis") - except Exception as E: - DBLogger.error("Failed to load IMEI Database into Redis due to error: " + (str(E))) - return - -#Load IMEI TAC database into Redis if enabled -if ('tac_database_csv' in yaml_config['eir']) and (yaml_config['redis']['enabled'] == True): - load_IMEI_database_into_Redis() -else: - DBLogger.info("Not loading EIR IMEI TAC Database as Redis not enabled or TAC CSV Database not set in config") - - -def safe_rollback(session): - try: - if session.is_active: - session.rollback() - except Exception as E: - DBLogger.error(f"Failed to rollback session, error: {E}") - -def safe_close(session): - try: - if session.is_active: - session.close() - except Exception as E: - DBLogger.error(f"Failed to run safe_close on session, error: {E}") - -def sqlalchemy_type_to_json_schema_type(sqlalchemy_type): - """ - Map SQLAlchemy types to JSON Schema types. - """ - if isinstance(sqlalchemy_type, Integer): - return "integer" - elif isinstance(sqlalchemy_type, String): - return "string" - elif isinstance(sqlalchemy_type, Boolean): - return "boolean" - elif isinstance(sqlalchemy_type, DateTime): - return "string" - elif isinstance(sqlalchemy_type, Float): - return "number" - else: - return "string" # Default to string for unsupported types. - -def generate_json_schema(model_class, required=None): - properties = {} - required = required or [] - - for column in model_class.__table__.columns: - prop_type = sqlalchemy_type_to_json_schema_type(column.type) - prop_dict = { - "type": prop_type, - "description": column.doc - } - if prop_type == "string": - if hasattr(column.type, 'length'): - prop_dict["maxLength"] = column.type.length - if isinstance(column.type, DateTime): - prop_dict["format"] = "date-time" - if not column.nullable: - required.append(column.name) - properties[column.name] = prop_dict - - return {"type": "object", "title" : str(model_class.__name__), "properties": properties, "required": required} - -# Create individual tables if they do not exist. 
-inspector = Inspector.from_engine(engine) -for table_name in Base.metadata.tables.keys(): - if table_name not in inspector.get_table_names(): - DBLogger.debug(f"Creating table {table_name}") - Base.metadata.tables[table_name].create(bind=engine) - else: - DBLogger.debug(f"Table {table_name} already exists") - -def update_old_record(session, operation_log): - oldest_log = session.query(OPERATION_LOG_BASE).order_by(OPERATION_LOG_BASE.timestamp.asc()).first() - if oldest_log is not None: - for attr in class_mapper(oldest_log.__class__).column_attrs: - if attr.key != 'id' and hasattr(operation_log, attr.key): - setattr(oldest_log, attr.key, getattr(operation_log, attr.key)) - oldest_log.timestamp = datetime.datetime.now(tz=timezone.utc) - session.flush() - else: - raise ValueError("Unable to find record to update") - -def notify_webhook(operation, external_webhook_notification_url, externalNotification, externalNotificationHeaders): - try: - if operation == 'UPDATE': - webhookResponse = requests.patch(external_webhook_notification_url, json=externalNotification, headers=externalNotificationHeaders, timeout=5) - elif operation == 'DELETE': - webhookResponse = requests.delete(external_webhook_notification_url, json=externalNotification, headers=externalNotificationHeaders, timeout=5) - elif operation == 'CREATE': - webhookResponse = requests.put(external_webhook_notification_url, json=externalNotification, headers=externalNotificationHeaders, timeout=5) - except requests.exceptions.Timeout: - DBLogger.error(f"Timeout occurred when sending webhook to {external_webhook_notification_url}") - return False - except requests.exceptions.RequestException as e: - DBLogger.error(f"Request exception when sending webhook to {external_webhook_notification_url}") - return False - - if webhookResponse.status_code != 200: - DBLogger.error(f"Response code from external webhook at {external_webhook_notification_url} is != 200.\nResponse Code is: {webhookResponse.status_code}\nResponse Body is: {webhookResponse.content}") - return False - return True - -def handle_external_webhook(objectData, operation): - external_webhook_notification_enabled = yaml_config.get('external', {}).get('external_webhook_notification_enabled', False) - external_webhook_notification_url = yaml_config.get('external', {}).get('external_webhook_notification_url', '') - if not external_webhook_notification_enabled: - return False - if not external_webhook_notification_url: - DBLogger.error("External webhook notification enabled, but external_webhook_notification_url is not defined.") - - externalNotification = Sanitize_Datetime(objectData) - externalNotificationHeaders = {'Content-Type': 'application/json', 'Referer': socket.gethostname()} - - # Using separate thread to process webhook - threading.Thread(target=notify_webhook, args=(operation, external_webhook_notification_url, externalNotification, externalNotificationHeaders), daemon=True).start() - return True - -def log_change(session, item_id, operation, changes, table_name, operation_id, generated_id=None): - # We don't want to log rollback operations - if session.info.get("operation") == 'ROLLBACK': - return - max_records = 1000 - count = session.query(OPERATION_LOG_BASE).count() - - # Combine all changes into a single string with their types - changes_string = '\r\n\r\n'.join(f"{column_name}: [{type(old_value).__name__}] {old_value} ----> [{type(new_value).__name__}] {new_value}" for column_name, old_value, new_value in changes) - - change = OPERATION_LOG_BASE( - item_id=item_id or 
generated_id, - operation_id=operation_id, - operation=operation, - last_modified=datetime.datetime.now(tz=timezone.utc), - changes=changes_string, - table_name=table_name - ) - - if count >= max_records: - update_old_record(session, change) - else: - try: - session.add(change) - session.flush() - except Exception as E: - DBLogger.error("Failed to commit changelog, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - return operation_id - - -def log_changes_before_commit(session): - - operation_id = session.info.get("operation_id", None) or str(uuid.uuid4()) - if session.info.get("operation") == 'ROLLBACK': - return - - changelog_pending = any(isinstance(obj, OPERATION_LOG_BASE) for obj in session.new) - if changelog_pending: - return # Skip if there are pending OPERATION_LOG_BASE objects - - for state, operation in [ - (session.new, 'INSERT'), - (session.dirty, 'UPDATE'), - (session.deleted, 'DELETE') - ]: - for obj in state: - if isinstance(obj, OPERATION_LOG_BASE): - continue # Skip change log entries - - item_id = getattr(obj, list(obj.__table__.primary_key.columns.keys())[0]) - generated_id = None - - #Avoid logging rollback operations - if operation == 'ROLLBACK': - return - - # Flush the session to generate primary key for new objects - if operation == 'INSERT': - session.flush() - - if operation == 'UPDATE': - changes = [] - for attr in class_mapper(obj.__class__).column_attrs: - hist = get_history(obj, attr.key) - DBLogger.info(f"History {hist}") - if hist.has_changes() and hist.added and hist.deleted: - old_value, new_value = hist.deleted[0], hist.added[0] - DBLogger.info(f"Old Value {old_value}") - DBLogger.info(f"New Value {new_value}") - changes.append((attr.key, old_value, new_value)) - continue - - if not changes: - continue - - operation_id = log_change(session, item_id, operation, changes, obj.__table__.name, operation_id) - - elif operation in ['INSERT', 'DELETE']: - changes = [] - for column in obj.__table__.columns: - column_name = column.name - value = getattr(obj, column_name) - if operation == 'INSERT': - old_value, new_value = None, value - if item_id is None: - generated_id = getattr(obj, list(obj.__table__.primary_key.columns.keys())[0]) - elif operation == 'DELETE': - old_value, new_value = value, None - changes.append((column_name, old_value, new_value)) - operation_id = log_change(session, item_id, operation, changes, obj.__table__.name, operation_id, generated_id) - -def get_class_by_tablename(base, tablename): - """ - Returns a class object based on the given tablename. 
- - :param base: Base class of SQLAlchemy models - :param tablename: Name of the table to retrieve the class for - :return: Class object or None if not found - """ - for mapper in base.registry.mappers: - cls = mapper.class_ - if hasattr(cls, '__tablename__') and cls.__tablename__ == tablename: - return cls - return None - -def str_to_type(type_str, value_str): - if type_str == 'int': - return int(value_str) - elif type_str == 'float': - return float(value_str) - elif type_str == 'str': - return value_str - elif type_str == 'bool': - return value_str == 'True' - elif type_str == 'NoneType': - return None - else: - raise ValueError(f'Cannot convert to type: {type_str}') - - -def rollback_last_change(existingSession=None): - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession - - try: - # Get the most recent operation - last_operation = session.query(OPERATION_LOG_BASE).order_by(desc(OPERATION_LOG_BASE.timestamp)).first() - - if last_operation is None: - return "No operations to roll back." - - rollback_messages = [] - operation_id = str(uuid.uuid4()) - - target_class = get_class_by_tablename(Base, last_operation.table_name) - if not target_class: - return f"Error: Could not find table {last_operation.table_name}" - - primary_key_col = target_class.__mapper__.primary_key[0].key - filter_by_kwargs = {primary_key_col: last_operation.item_id} - target_item = session.query(target_class).filter_by(**filter_by_kwargs).one_or_none() - - if last_operation.operation == 'UPDATE': - if not target_item: - return f"Error: Could not find item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table" - - # Split the changes string into separate changes - changes = last_operation.changes.split('\r\n\r\n') - for change in changes: - column_name, old_new_values = change.split(": ", 1) - old_value_str, new_value_str = old_new_values.split(" ----> ", 1) - - # Extract type and value - old_type_str, old_value_repr = old_value_str[1:-1].split("] ", 1) - old_value = str_to_type(old_type_str, old_value_repr) - - # Revert the change - setattr(target_item, column_name, old_value) - - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Reverted changes" - ) - - elif last_operation.operation == 'INSERT': - if target_item: - session.delete(target_item) - - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Deleted item" - ) - - elif last_operation.operation == 'DELETE': - # Aggregate old values of all columns into a single dictionary - old_values_dict = {} - # Split the changes string into separate changes - changes = last_operation.changes.split('\r\n\r\n') - for change in changes: - column_name, old_new_values = change.split(": ", 1) - old_value_str, new_value_str = old_new_values.split(" ----> ", 1) - - # Extract type and value - old_type_str, old_value_repr = old_value_str[1:].split("] ", 1) - DBLogger.error(f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}") - old_value = str_to_type(old_type_str, old_value_repr) - - old_values_dict[column_name] = old_value - DBLogger.error("old_value_dict: " + str(old_values_dict)) - - if not target_item: - try: - # Create the target item using the aggregated old values - target_item = target_class(**old_values_dict) - 
session.add(target_item) - except Exception as e: - return f"Error: Failed to recreate item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table - {str(e)}" - - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Re-inserted item" - ) - - else: - return f"Error: Unknown operation {last_operation.operation}" - - try: - session.commit() - safe_close(session) - except Exception as E: - DBLogger.error("rollback_last_change error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - return f"Rolled back operation with operation_id: {operation_id}\n" + "\n".join(rollback_messages) - - except Exception as E: - DBLogger.error("rollback_last_change error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - -def rollback_change_by_operation_id(operation_id, existingSession=None): - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession - - try: - # Get the most recent operation - last_operation = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id).order_by(desc(OPERATION_LOG_BASE.timestamp)).first() - - if last_operation is None: - return "No operation to roll back." - - rollback_messages = [] - operation_id = str(uuid.uuid4()) - - target_class = get_class_by_tablename(Base, last_operation.table_name) - if not target_class: - return f"Error: Could not find table {last_operation.table_name}" - - primary_key_col = target_class.__mapper__.primary_key[0].key - filter_by_kwargs = {primary_key_col: last_operation.item_id} - target_item = session.query(target_class).filter_by(**filter_by_kwargs).one_or_none() - - if last_operation.operation == 'UPDATE': - if not target_item: - return f"Error: Could not find item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table" - - # Split the changes string into separate changes - changes = last_operation.changes.split('\r\n\r\n') - for change in changes: - column_name, old_new_values = change.split(": ", 1) - old_value_str, new_value_str = old_new_values.split(" ----> ", 1) - - # Extract type and value - old_type_str, old_value_repr = old_value_str[1:-1].split("] ", 1) - old_value = str_to_type(old_type_str, old_value_repr) - - # Revert the change - setattr(target_item, column_name, old_value) - - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Reverted changes" - ) - - elif last_operation.operation == 'INSERT': - if target_item: - session.delete(target_item) - - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Deleted item" - ) - - elif last_operation.operation == 'DELETE': - # Aggregate old values of all columns into a single dictionary - old_values_dict = {} - # Split the changes string into separate changes - changes = last_operation.changes.split('\r\n\r\n') - for change in changes: - column_name, old_new_values = change.split(": ", 1) - old_value_str, new_value_str = old_new_values.split(" ----> ", 1) - - # Extract type and value - old_type_str, old_value_repr = old_value_str[1:].split("] ", 1) - DBLogger.error(f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}") - old_value = 
str_to_type(old_type_str, old_value_repr) - - old_values_dict[column_name] = old_value - DBLogger.error("old_value_dict: " + str(old_values_dict)) - - if not target_item: - try: - # Create the target item using the aggregated old values - target_item = target_class(**old_values_dict) - session.add(target_item) - except Exception as e: - return f"Error: Failed to recreate item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table - {str(e)}" - - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Re-inserted item" - ) - - else: - return f"Error: Unknown operation {last_operation.operation}" - - try: - session.commit() - safe_close(session) - except Exception as E: - DBLogger.error("rollback_last_change error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - return f"Rolled back operation with operation_id: {operation_id}\n" + "\n".join(rollback_messages) - - except Exception as E: - DBLogger.error("rollback_last_change error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - -def get_all_operation_logs(page=0, page_size=yaml_config['api'].get('page_size', 100), existingSession=None): - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession - - try: - # Get all distinct operation_ids ordered by max timestamp (descending order) - operation_ids = session.query(OPERATION_LOG_BASE.operation_id).group_by(OPERATION_LOG_BASE.operation_id).order_by(desc(func.max(OPERATION_LOG_BASE.timestamp))) - - operation_ids = operation_ids.limit(page_size).offset(page * page_size) - - operation_ids = operation_ids.all() - - all_operations = [] - - for operation_id in operation_ids: - operation_log = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id[0]).order_by(OPERATION_LOG_BASE.id.asc()).first() - - if operation_log is not None: - # Convert the object to dictionary - obj_dict = operation_log.__dict__ - obj_dict.pop('_sa_instance_state') - sanitized_obj_dict = Sanitize_Datetime(obj_dict) - all_operations.append(sanitized_obj_dict) - - safe_close(session) - return all_operations - except Exception as E: - DBLogger.error(f"get_all_operation_logs error: {E}") - DBLogger.error(E) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - -def get_all_operation_logs_by_table(table_name, page=0, page_size=yaml_config['api'].get('page_size', 100), existingSession=None): - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession - - try: - # Get all distinct operation_ids ordered by max timestamp (descending order) - operation_ids = session.query(OPERATION_LOG_BASE.operation_id).filter(OPERATION_LOG_BASE.table_name == table_name).group_by(OPERATION_LOG_BASE.operation_id).order_by(desc(func.max(OPERATION_LOG_BASE.timestamp))) - - operation_ids = operation_ids.limit(page_size).offset(page * page_size) - - operation_ids = operation_ids.all() - - all_operations = [] - - for operation_id in operation_ids: - operation_log = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id[0]).order_by(OPERATION_LOG_BASE.id.asc()).first() - - if operation_log is not None: - # Convert the object to dictionary - obj_dict = operation_log.__dict__ - 
obj_dict.pop('_sa_instance_state') - sanitized_obj_dict = Sanitize_Datetime(obj_dict) - all_operations.append(sanitized_obj_dict) - - safe_close(session) - return all_operations - except Exception as E: - DBLogger.error(f"get_all_operation_logs_by_table error: {E}") - DBLogger.error(E) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - -def get_last_operation_log(existingSession=None): - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession - - try: - # Get the top 100 records ordered by timestamp (descending order) - top_100_records = session.query(OPERATION_LOG_BASE).order_by(desc(OPERATION_LOG_BASE.timestamp)).limit(100) - - # Get the most recent operation_id - most_recent_operation_log = top_100_records.first() - - # Convert the object to dictionary - if most_recent_operation_log is not None: - obj_dict = most_recent_operation_log.__dict__ - obj_dict.pop('_sa_instance_state') - sanitized_obj_dict = Sanitize_Datetime(obj_dict) - return sanitized_obj_dict - - safe_close(session) - return None - except Exception as E: - DBLogger.error(f"get_last_operation_log error: {E}") - DBLogger.error(E) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - - -def GeoRed_Push_Request(remote_hss, json_data, transaction_id, url=None): - headers = {"Content-Type": "application/json", "Transaction-Id": str(transaction_id)} - DBLogger.debug("transaction_id: " + str(transaction_id) + " pushing update to " + str(remote_hss).replace('http://', '')) - try: - session = requests.Session() - # Create a Retry object with desired parameters - retries = Retry(total=3, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504]) - - # Create an HTTPAdapter and pass the Retry object - adapter = HTTPAdapter(max_retries=retries) - - session.mount('http://', adapter) - if url == None: - endpoint = 'geored' - r = session.patch(str(remote_hss) + '/geored/', data=json.dumps(json_data), headers=headers) - else: - endpoint = url.split('/', 1)[0] - r = session.patch(url, data=json.dumps(json_data), headers=headers) - DBLogger.debug("transaction_id: " + str(transaction_id) + " updated on " + str(remote_hss).replace('http://', '') + " with status code " + str(r.status_code)) - if str(r.status_code).startswith('2'): - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code=str(r.status_code), - error="" - ).inc() - else: - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code=str(r.status_code), - error=str(r.reason) - ).inc() - except ConnectionError as e: - error_message = str(e) - if "Name or service not known" in error_message: - DBLogger.error("transaction_id: " + str(transaction_id) + " name or service not known") - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code='000', - error="No matching DNS entry found" - ).inc() - else: - print("Other ConnectionError:", error_message) - DBLogger.error("transaction_id: " + str(transaction_id) + " " + str(error_message)) - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code='000', - error="Connection Refused" - ).inc() - except Timeout: - DBLogger.error("transaction_id: " + str(transaction_id) + " timed out connecting to peer " + str(remote_hss).replace('http://', '')) - 
prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code='000', - error="Timeout" - ).inc() - except Exception as e: - DBLogger.error("transaction_id: " + str(transaction_id) + " unexpected error " + str(e) + " when connecting to peer " + str(remote_hss).replace('http://', '')) - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code='000', - error=str(e) - ).inc() - - - -def GeoRed_Push_Async(json_data): - try: - if yaml_config['geored']['enabled'] == True: - if yaml_config['geored']['sync_endpoints'] is not None and len(yaml_config['geored']['sync_endpoints']) > 0: - transaction_id = str(uuid.uuid4()) - DBLogger.info("Pushing out data to GeoRed peers with transaction_id " + str(transaction_id) + " and JSON body: " + str(json_data)) - for remote_hss in yaml_config['geored']['sync_endpoints']: - GeoRed_Push_thread = threading.Thread(target=GeoRed_Push_Request, args=(remote_hss, json_data, transaction_id)) - GeoRed_Push_thread.start() - except Exception as E: - DBLogger.debug("Failed to push Async jobs due to error: " + str(E)) - -def Webhook_Push_Async(target, json_data): - transaction_id = str(uuid.uuid4()) - Webook_Push_thread = threading.Thread(target=GeoRed_Push_Request, args=(target, json_data, transaction_id)) - Webook_Push_thread.start() - -def Sanitize_Datetime(result): - for keys in result: - if "timestamp" in keys: - if result[keys] == None: - continue - else: - DBLogger.debug("Key " + str(keys) + " is type DateTime with value: " + str(result[keys]) + " - Formatting to String") - result[keys] = str(result[keys]) - return result - -def Sanitize_Keys(result): - names_to_strip = ['opc', 'ki', 'des', 'kid', 'psk', 'adm1'] - for name_to_strip in names_to_strip: - try: - result.pop(name_to_strip) - except: - pass - return result - -def GetObj(obj_type, obj_id=None, page=None, page_size=None): - DBLogger.debug("Called GetObj for type " + str(obj_type)) - - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - - try: - if obj_id is not None: - result = session.query(obj_type).get(obj_id) - if result is None: - raise ValueError(f"No {obj_type} found with id {obj_id}") - - result = result.__dict__ - result.pop('_sa_instance_state') - result = Sanitize_Datetime(result) - elif page is not None and page_size is not None: - if page < 1 or page_size < 1: - raise ValueError("page and page_size should be positive integers") - - offset = (page - 1) * page_size - results = ( - session.query(obj_type) - .order_by(obj_type.id) # Assuming obj_type has an attribute 'id' - .offset(offset) - .limit(page_size) - .all() - ) - - result = [] - for item in results: - item_dict = item.__dict__ - item_dict.pop('_sa_instance_state') - result.append(Sanitize_Datetime(item_dict)) - else: - raise ValueError("Provide either obj_id or both page and page_size") - - except Exception as E: - DBLogger.error("Failed to query, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - safe_close(session) - return result - -def GetAll(obj_type): - DBLogger.debug("Called GetAll for type " + str(obj_type)) - - Base.metadata.create_all(engine) - Session = sessionmaker(bind = engine) - session = Session() - final_result_list = [] - - try: - result = session.query(obj_type) - except Exception as E: - DBLogger.error("Failed to query, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - for 
record in result: - record = record.__dict__ - record.pop('_sa_instance_state') - record = Sanitize_Datetime(record) - record = Sanitize_Keys(record) - final_result_list.append(record) - - safe_close(session) - return final_result_list - -def getAllPaginated(obj_type, page=0, page_size=0, existingSession=None): - DBLogger.debug("Called getAllPaginated for type " + str(obj_type)) - - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession - - final_result_list = [] - - try: - # Query object type - result = session.query(obj_type) - - # Apply pagination - if page_size != 0: - result = result.limit(page_size).offset(page * page_size) - - result = result.all() - - for record in result: - record = record.__dict__ - record.pop('_sa_instance_state') - record = Sanitize_Datetime(record) - record = Sanitize_Keys(record) - final_result_list.append(record) - - safe_close(session) - return final_result_list - - except Exception as E: - DBLogger.error("Failed to query, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - -def GetAllByTable(obj_type, table): - DBLogger.debug(f"Called GetAll for type {str(obj_type)} and table {table}") - - Base.metadata.create_all(engine) - Session = sessionmaker(bind = engine) - session = Session() - final_result_list = [] - - try: - result = session.query(obj_type).filter_by(table_name=str(table)) - except Exception as E: - DBLogger.error("Failed to query, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - for record in result: - record = record.__dict__ - record.pop('_sa_instance_state') - record = Sanitize_Datetime(record) - record = Sanitize_Keys(record) - final_result_list.append(record) - - safe_close(session) - return final_result_list - -def UpdateObj(obj_type, json_data, obj_id, disable_logging=False, operation_id=None): - DBLogger.debug(f"Called UpdateObj() for type {obj_type} id {obj_id} with JSON data: {json_data} and operation_id: {operation_id}") - Session = sessionmaker(bind=engine) - session = Session() - obj_type_str = str(obj_type.__table__.name).upper() - DBLogger.debug(f"obj_type_str is {obj_type_str}") - filter_input = eval(obj_type_str + "." 
+ obj_type_str.lower() + "_id==obj_id") - try: - obj = session.query(obj_type).filter(filter_input).one() - for key, value in json_data.items(): - if hasattr(obj, key): - setattr(obj, key, value) - setattr(obj, "last_modified", datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') + 'Z') - except Exception as E: - DBLogger.error(f"Failed to query or update object, error: {E}") - raise ValueError(E) - try: - session.info["operation_id"] = operation_id # Pass the operation id - try: - if not disable_logging: - log_changes_before_commit(session) - objectData = GetObj(obj_type, obj_id) - session.commit() - handle_external_webhook(objectData, 'UPDATE') - except Exception as E: - DBLogger.error(f"Failed to commit session, error: {E}") - safe_rollback(session) - raise ValueError(E) - except Exception as E: - DBLogger.error(f"Exception in UpdateObj, error: {E}") - raise ValueError(E) - finally: - safe_close(session) - - return GetObj(obj_type, obj_id) - -def DeleteObj(obj_type, obj_id, disable_logging=False, operation_id=None): - DBLogger.debug(f"Called DeleteObj for type {obj_type} with id {obj_id}") - - Session = sessionmaker(bind=engine) - session = Session() - - try: - res = session.query(obj_type).get(obj_id) - if res is None: - raise ValueError("The specified row does not exist") - objectData = GetObj(obj_type, obj_id) - session.delete(res) - session.info["operation_id"] = operation_id # Pass the operation id - try: - if not disable_logging: - log_changes_before_commit(session) - session.commit() - handle_external_webhook(objectData, 'DELETE') - except Exception as E: - DBLogger.error(f"Failed to commit session, error: {E}") - safe_rollback(session) - raise ValueError(E) - - except Exception as E: - DBLogger.error(f"Exception in DeleteObj, error: {E}") - raise ValueError(E) - finally: - safe_close(session) - - return {'Result': 'OK'} - - -def CreateObj(obj_type, json_data, disable_logging=False, operation_id=None): - DBLogger.debug("Called CreateObj to create " + str(obj_type) + " with value: " + str(json_data)) - last_modified_value = datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') + 'Z' - json_data["last_modified"] = last_modified_value # set last_modified value in json_data - newObj = obj_type(**json_data) - Session = sessionmaker(bind=engine) - session = Session() - - session.add(newObj) - try: - session.info["operation_id"] = operation_id # Pass the operation id - try: - if not disable_logging: - log_changes_before_commit(session) - session.commit() - except Exception as E: - DBLogger.error(f"Failed to commit session, error: {E}") - safe_rollback(session) - raise ValueError(E) - session.refresh(newObj) - result = newObj.__dict__ - DBLogger.debug("Created new object OK") - result.pop('_sa_instance_state') - handle_external_webhook(result, 'CREATE') - return result - except Exception as E: - DBLogger.error(f"Exception in CreateObj, error: {E}") - raise ValueError(E) - finally: - safe_close(session) - -def Generate_JSON_Model_for_Flask(obj_type): - DBLogger.debug("Generating JSON model for Flask for object type: " + str(obj_type)) - - dictty = dict(generate_json_schema(obj_type)) - pprint.pprint(dictty) - - - #dictty['properties'] = dict(dictty['properties']) - - # Exclude 'table_name' column from the properties - if 'properties' in dictty: - dictty['properties'].pop('discriminator', None) - dictty['properties'].pop('last_modified', None) - - - # Set the ID Object to not required - obj_type_str = str(dictty['title']).lower() - 
dictty['required'].remove(obj_type_str + '_id') - - return dictty - -def Get_AuC(**kwargs): - #Get AuC data by IMSI or ICCID - - Session = sessionmaker(bind = engine) - session = Session() - - if 'iccid' in kwargs: - DBLogger.debug("Get_AuC for iccid " + str(kwargs['iccid'])) - try: - result = session.query(AUC).filter_by(iccid=str(kwargs['iccid'])).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - elif 'imsi' in kwargs: - DBLogger.debug("Get_AuC for imsi " + str(kwargs['imsi'])) - try: - result = session.query(AUC).filter_by(imsi=str(kwargs['imsi'])).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - - result = result.__dict__ - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') - - DBLogger.debug("Got back result: " + str(result)) - safe_close(session) - return result - -def Get_IMS_Subscriber(**kwargs): - #Get subscriber by IMSI or MSISDN - Session = sessionmaker(bind = engine) - session = Session() - if 'msisdn' in kwargs: - DBLogger.debug("Get_IMS_Subscriber for msisdn " + str(kwargs['msisdn'])) - try: - result = session.query(IMS_SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - elif 'imsi' in kwargs: - DBLogger.debug("Get_IMS_Subscriber for imsi " + str(kwargs['imsi'])) - try: - result = session.query(IMS_SUBSCRIBER).filter_by(imsi=str(kwargs['imsi'])).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - DBLogger.debug("Converting result to dict") - result = result.__dict__ - try: - result.pop('_sa_instance_state') - except: - pass - result = Sanitize_Datetime(result) - DBLogger.debug("Returning IMS Subscriber Data: " + str(result)) - safe_close(session) - return result - -def Get_Subscriber(**kwargs): - #Get subscriber by IMSI or MSISDN - - Session = sessionmaker(bind = engine) - session = Session() - - if 'msisdn' in kwargs: - DBLogger.debug("Get_Subscriber for msisdn " + str(kwargs['msisdn'])) - try: - result = session.query(SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - elif 'imsi' in kwargs: - DBLogger.debug("Get_Subscriber for imsi " + str(kwargs['imsi'])) - try: - result = session.query(SUBSCRIBER).filter_by(imsi=str(kwargs['imsi'])).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - - result = result.__dict__ - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') - - if 'get_attributes' in kwargs: - if kwargs['get_attributes'] == True: - attributes = Get_Subscriber_Attributes(result['subscriber_id']) - result['attributes'] = attributes - - DBLogger.debug("Got back result: " + str(result)) - safe_close(session) - return result - -def Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id): - Session = sessionmaker(bind = engine) - session = Session() - - DBLogger.debug("Get_SUBSCRIBER_ROUTING for subscriber_id " + str(subscriber_id) + " and apn_id " + str(apn_id)) - try: - result = session.query(SUBSCRIBER_ROUTING).filter_by(subscriber_id=subscriber_id, apn_id=apn_id).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - - result = result.__dict__ - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') - - DBLogger.debug("Got back result: " + str(result)) - safe_close(session) - return result - -def Get_Subscriber_Attributes(subscriber_id): - #Get subscriber attributes - - Session = sessionmaker(bind = engine) - session = Session() - - 
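The lookup helpers above (Get_AuC, Get_IMS_Subscriber, Get_Subscriber, Get_SUBSCRIBER_ROUTING) share a keyword-argument pattern: pass exactly one supported identifier and a sanitized dict comes back. A hypothetical usage sketch, with invented subscriber identifiers:

```python
# Hypothetical usage of the lookup helpers above; the IMSI / MSISDN values
# are invented for illustration. Each helper re-raises a failed .one()
# lookup (SQLAlchemy NoResultFound) as ValueError.
auc_data = Get_AuC(imsi='001021234567890')
sub_data = Get_Subscriber(msisdn='12345678', get_attributes=True)
ims_data = Get_IMS_Subscriber(imsi='001021234567890')
print(sub_data['attributes'])  # populated only when get_attributes=True is passed
```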
DBLogger.debug("Get_Subscriber_Attributes for subscriber_id " + str(subscriber_id)) - try: - result = session.query(SUBSCRIBER_ATTRIBUTES).filter_by(subscriber_id=subscriber_id) - except Exception as E: - safe_close(session) - raise ValueError(E) - final_res = [] - for record in result: - result = record.__dict__ - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') - final_res.append(result) - DBLogger.debug("Got back result: " + str(final_res)) - safe_close(session) - return final_res - -def Get_Served_Subscribers(get_local_users_only=False): - DBLogger.debug("Getting all subscribers served by this HSS") - - Session = sessionmaker(bind = engine) - session = Session() - - Served_Subs = {} - try: - results = session.query(SUBSCRIBER).filter(SUBSCRIBER.serving_mme.isnot(None)) - for result in results: - result = result.__dict__ - DBLogger.debug("Result: " + str(result) + " type: " + str(type(result))) - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') - - if get_local_users_only == True: - DBLogger.debug("Filtering to locally served IMS Subs only") - try: - serving_hss = result['serving_mme_peer'].split(';')[1] - DBLogger.debug("Serving HSS: " + str(serving_hss) + " and this is: " + str(yaml_config['hss']['OriginHost'])) - if serving_hss == yaml_config['hss']['OriginHost']: - DBLogger.debug("Serving HSS matches local HSS") - Served_Subs[result['imsi']] = {} - Served_Subs[result['imsi']] = result - #DBLogger.debug("Processed result") - continue - else: - DBLogger.debug("Sub is served by remote HSS: " + str(serving_hss)) - except Exception as E: - DBLogger.debug("Error in filtering Get_Served_Subscribers to local peer only: " + str(E)) - continue - else: - Served_Subs[result['imsi']] = result - DBLogger.debug("Processed result") - - - except Exception as E: - safe_close(session) - raise ValueError(E) - DBLogger.debug("Final Served_Subs: " + str(Served_Subs)) - safe_close(session) - return Served_Subs - -def Get_Served_IMS_Subscribers(get_local_users_only=False): - DBLogger.debug("Getting all subscribers served by this IMS-HSS") - Session = sessionmaker(bind=engine) - session = Session() - - Served_Subs = {} - try: - - results = session.query(IMS_SUBSCRIBER).filter( - IMS_SUBSCRIBER.scscf.isnot(None)) - for result in results: - result = result.__dict__ - DBLogger.debug("Result: " + str(result) + - " type: " + str(type(result))) - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') - if get_local_users_only == True: - DBLogger.debug("Filtering Get_Served_IMS_Subscribers to locally served IMS Subs only") - try: - serving_ims_hss = result['scscf_peer'].split(';')[1] - DBLogger.debug("Serving IMS-HSS: " + str(serving_ims_hss) + " and this is: " + str(yaml_config['hss']['OriginHost'])) - if serving_ims_hss == yaml_config['hss']['OriginHost']: - DBLogger.debug("Serving IMS-HSS matches local HSS for " + str(result['imsi'])) - Served_Subs[result['imsi']] = {} - Served_Subs[result['imsi']] = result - DBLogger.debug("Processed result") - continue - else: - DBLogger.debug("Sub is served by remote IMS-HSS: " + str(serving_ims_hss)) - except Exception as E: - DBLogger.debug("Error in filtering to local peer only: " + str(E)) - continue - else: - Served_Subs[result['imsi']] = result - DBLogger.debug("Processed result") - - except Exception as E: - safe_close(session) - raise ValueError(E) - DBLogger.debug("Final Served_Subs: " + str(Served_Subs)) - safe_close(session) - return Served_Subs - -def Get_Served_PCRF_Subscribers(get_local_users_only=False): 
- DBLogger.debug("Getting all subscribers served by this PCRF") - Session = sessionmaker(bind=engine) - session = Session() - Served_Subs = {} - try: - results = session.query(SERVING_APN).all() - for result in results: - result = result.__dict__ - DBLogger.debug("Result: " + str(result) + " type: " + str(type(result))) - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') - - if get_local_users_only == True: - DBLogger.debug("Filtering to locally served IMS Subs only") - try: - serving_pcrf = result['serving_pgw_peer'].split(';')[1] - DBLogger.debug("Serving PCRF: " + str(serving_pcrf) + " and this is: " + str(yaml_config['hss']['OriginHost'])) - if serving_pcrf == yaml_config['hss']['OriginHost']: - DBLogger.debug("Serving PCRF matches local PCRF") - DBLogger.debug("Processed result") - - else: - DBLogger.debug("Sub is served by remote PCRF: " + str(serving_pcrf)) - continue - except Exception as E: - DBLogger.debug("Error in filtering Get_Served_PCRF_Subscribers to local peer only: " + str(E)) - continue - - # Get APN Info - apn_info = GetObj(APN, result['apn']) - #DBLogger.debug("Got APN Info: " + str(apn_info)) - result['apn_info'] = apn_info - - # Get Subscriber Info - subscriber_info = GetObj(SUBSCRIBER, result['subscriber_id']) - result['subscriber_info'] = subscriber_info - - #DBLogger.debug("Got Subscriber Info: " + str(subscriber_info)) - - Served_Subs[subscriber_info['imsi']] = result - DBLogger.debug("Processed result") - except Exception as E: - raise ValueError(E) - #DBLogger.debug("Final SERVING_APN: " + str(Served_Subs)) - safe_close(session) - return Served_Subs - -def Get_Vectors_AuC(auc_id, action, **kwargs): - DBLogger.debug("Getting Vectors for auc_id " + str(auc_id) + " with action " + str(action)) - key_data = GetObj(AUC, auc_id) - vector_dict = {} - - if action == "air": - rand, xres, autn, kasme = S6a_crypt.generate_eutran_vector(key_data['ki'], key_data['opc'], key_data['amf'], key_data['sqn'], kwargs['plmn']) - vector_dict['rand'] = rand - vector_dict['xres'] = xres - vector_dict['autn'] = autn - vector_dict['kasme'] = kasme - - #Incriment SQN - Update_AuC(auc_id, sqn=key_data['sqn']+100) - - return vector_dict - - elif action == "sqn_resync": - DBLogger.debug("Resync SQN") - rand = kwargs['rand'] - sqn, mac_s = S6a_crypt.generate_resync_s6a(key_data['ki'], key_data['opc'], key_data['amf'], kwargs['auts'], rand) - DBLogger.debug("SQN from resync: " + str(sqn) + " SQN in DB is " + str(key_data['sqn']) + "(Difference of " + str(int(sqn) - int(key_data['sqn'])) + ")") - Update_AuC(auc_id, sqn=sqn+100) - return - - elif action == "sip_auth": - rand, autn, xres, ck, ik = S6a_crypt.generate_maa_vector(key_data['ki'], key_data['opc'], key_data['amf'], key_data['sqn'], kwargs['plmn']) - DBLogger.debug("RAND is: " + str(rand)) - DBLogger.debug("AUTN is: " + str(autn)) - vector_dict['SIP_Authenticate'] = rand + autn - vector_dict['xres'] = xres - vector_dict['ck'] = ck - vector_dict['ik'] = ik - Update_AuC(auc_id, sqn=key_data['sqn']+100) - return vector_dict - - elif action == "Digest-MD5": - DBLogger.debug("Generating Digest-MD5 Auth vectors") - DBLogger.debug("key_data: " + str(key_data)) - nonce = uuid.uuid4().hex - #nonce = "beef4d878f2642ed98afe491b943ca60" - vector_dict['nonce'] = nonce - vector_dict['SIP_Authenticate'] = key_data['ki'] - return vector_dict - -def Get_APN(apn_id): - DBLogger.debug("Getting APN " + str(apn_id)) - Session = sessionmaker(bind = engine) - session = Session() - - try: - result = 
session.query(APN).filter_by(apn_id=apn_id).one()
-    except Exception as E:
-        safe_close(session)
-        raise ValueError(E)
-    result = result.__dict__
-    result.pop('_sa_instance_state')
-    safe_close(session)
-    return result
-
-def Get_APN_by_Name(apn):
-    DBLogger.debug("Getting APN named " + str(apn))
-    Session = sessionmaker(bind = engine)
-    session = Session()
-    try:
-        result = session.query(APN).filter_by(apn=str(apn)).one()
-    except Exception as E:
-        safe_close(session)
-        raise ValueError(E)
-    result = result.__dict__
-    result.pop('_sa_instance_state')
-    safe_close(session)
-    return result
-
-def Update_AuC(auc_id, sqn=1):
-    DBLogger.debug("Updating AuC record for sub " + str(auc_id))
-    DBLogger.debug(UpdateObj(AUC, {'sqn': sqn}, auc_id, True))
-    return
-
-def Update_Serving_MME(imsi, serving_mme, serving_mme_realm=None, serving_mme_peer=None, propagate=True):
-    DBLogger.debug("Updating Serving MME for sub " + str(imsi) + " to MME " + str(serving_mme))
-    Session = sessionmaker(bind = engine)
-    session = Session()
-    try:
-        result = session.query(SUBSCRIBER).filter_by(imsi=imsi).one()
-        if yaml_config['hss']['CancelLocationRequest_Enabled'] == True:
-            DBLogger.debug("Evaluating if we should trigger sending a CLR.")
-            serving_hss = str(result.serving_mme_peer).split(';',1)[1]
-            serving_mme_peer = str(result.serving_mme_peer).split(';',1)[0]
-            DBLogger.debug("Subscriber is currently served by serving_mme: " + str(result.serving_mme) + " at realm " + str(result.serving_mme_realm) + " through Diameter peer " + str(result.serving_mme_peer))
-            DBLogger.debug("Subscriber is now served by serving_mme: " + str(serving_mme) + " at realm " + str(serving_mme_realm) + " through Diameter peer " + str(serving_mme_peer))
-            #Evaluate if we need to send a CLR to the old MME
-            if result.serving_mme != None:
-                if str(result.serving_mme) == str(serving_mme):
-                    DBLogger.debug("This MME is unchanged (" + str(serving_mme) + ") - so no need to send a CLR")
-                elif (str(result.serving_mme) != str(serving_mme)):
-                    DBLogger.debug("There is a difference in serving MME, old MME is '" + str(result.serving_mme) + "' new MME is '" + str(serving_mme) + "' - We need to trigger sending a CLR")
-                    if serving_hss != yaml_config['hss']['OriginHost']:
-                        DBLogger.debug("This subscriber is not served by this HSS it is served by HSS at " + serving_hss + " - We need to trigger sending a CLR on " + str(serving_hss))
-                        URL = 'http://' + serving_hss + '.' + yaml_config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi)
-                    else:
-                        DBLogger.debug("This subscriber is served by this HSS we need to send a CLR to old MME from this HSS")
-
-                        URL = 'http://' + serving_hss + '.'
+ yaml_config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) - DBLogger.debug("Sending CLR to API at " + str(URL)) - json_data = { - "DestinationRealm": result.serving_mme_realm, - "DestinationHost": result.serving_mme, - "cancellationType": 2, - "diameterPeer": serving_mme_peer, - } - - DBLogger.debug("Pushing CLR to API on " + str(URL) + " with JSON body: " + str(json_data)) - transaction_id = str(uuid.uuid4()) - GeoRed_Push_thread = threading.Thread(target=GeoRed_Push_Request, args=(serving_hss, json_data, transaction_id, URL)) - GeoRed_Push_thread.start() - else: - #No currently serving MME - No action to take - DBLogger.debug("No currently serving MME - No need to send CLR") - - if type(serving_mme) == str: - DBLogger.debug("Updating serving MME & Timestamp") - result.serving_mme = serving_mme - result.serving_mme_timestamp = datetime.datetime.now(tz=timezone.utc) - result.serving_mme_realm = serving_mme_realm - result.serving_mme_peer = serving_mme_peer - else: - #Clear values - DBLogger.debug("Clearing serving MME") - result.serving_mme = None - result.serving_mme_timestamp = None - result.serving_mme_realm = None - result.serving_mme_peer = None - - session.commit() - objectData = GetObj(SUBSCRIBER, result.subscriber_id) - handle_external_webhook(objectData, 'UPDATE') - - #Sync state change with geored - if propagate == True: - if 'HSS' in yaml_config['geored'].get('sync_actions', []) and yaml_config['geored'].get('enabled', False) == True: - DBLogger.debug("Propagate MME changes to Geographic PyHSS instances") - GeoRed_Push_Async({ - "imsi": str(imsi), - "serving_mme": result.serving_mme, - "serving_mme_realm": str(result.serving_mme_realm), - "serving_mme_peer": str(result.serving_mme_peer) - }) - else: - DBLogger.debug("Config does not allow sync of HSS events") - except Exception as E: - DBLogger.error("Error occurred, rolling back session: " + str(E)) - raise - finally: - safe_close(session) - -def Update_Serving_CSCF(imsi, serving_cscf, scscf_realm=None, scscf_peer=None, propagate=True): - DBLogger.debug("Update_Serving_CSCF for sub " + str(imsi) + " to SCSCF " + str(serving_cscf) + " with realm " + str(scscf_realm) + " and peer " + str(scscf_peer)) - Session = sessionmaker(bind = engine) - session = Session() - - try: - result = session.query(IMS_SUBSCRIBER).filter_by(imsi=imsi).one() - try: - assert(type(serving_cscf) == str) - assert(len(serving_cscf) > 0) - DBLogger.debug("Setting serving CSCF") - #Strip duplicate SIP prefix before storing - serving_cscf = serving_cscf.replace("sip:sip:", "sip:") - result.scscf = serving_cscf - result.scscf_timestamp = datetime.datetime.now(tz=timezone.utc) - result.scscf_realm = scscf_realm - result.scscf_peer = str(scscf_peer) - except: - #Clear values - DBLogger.debug("Clearing serving CSCF") - result.scscf = None - result.scscf_timestamp = None - result.scscf_realm = None - result.scscf_peer = None - - session.commit() - objectData = GetObj(IMS_SUBSCRIBER, result.ims_subscriber_id) - handle_external_webhook(objectData, 'UPDATE') - - #Sync state change with geored - if propagate == True: - if 'IMS' in yaml_config['geored']['sync_actions'] and yaml_config['geored']['enabled'] == True: - DBLogger.debug("Propagate IMS changes to Geographic PyHSS instances") - GeoRed_Push_Async({"imsi": str(imsi), "scscf": result.scscf, "scscf_realm": str(result.scscf_realm), "scscf_peer": str(result.scscf_peer)}) - else: - DBLogger.debug("Config does not allow sync of IMS events") - except Exception as E: - DBLogger.error("An error occurred, 
rolling back session: " + str(E)) - safe_rollback(session) - raise - finally: - safe_close(session) - -def Update_Serving_APN(imsi, apn, pcrf_session_id, serving_pgw, subscriber_routing, serving_pgw_realm=None, serving_pgw_peer=None, propagate=True): - DBLogger.debug("Called Update_Serving_APN() for imsi " + str(imsi) + " with APN " + str(apn)) - DBLogger.debug("PCRF Session ID " + str(pcrf_session_id) + " and serving PGW " + str(serving_pgw) + " and subscriber routing " + str(subscriber_routing)) - DBLogger.debug("Serving PGW Realm is: " + str(serving_pgw_realm) + " and peer is: " + str(serving_pgw_peer)) - DBLogger.debug("subscriber_routing: " + str(subscriber_routing)) - - #Get Subscriber ID from IMSI - subscriber_details = Get_Subscriber(imsi=str(imsi)) - subscriber_id = subscriber_details['subscriber_id'] - - #Split the APN list into a list - apn_list = subscriber_details['apn_list'].split(',') - DBLogger.debug("Current APN List: " + str(apn_list)) - #Remove the default APN from the list - try: - apn_list.remove(str(subscriber_details['default_apn'])) - except: - DBLogger.debug("Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List") - pass - #Add default APN in first position - apn_list.insert(0, str(subscriber_details['default_apn'])) - - #Get APN ID from APN - for apn_id in apn_list: - #Get each APN in List - apn_data = Get_APN(apn_id) - DBLogger.debug(apn_data) - if str(apn_data['apn']).lower() == str(apn).lower(): - DBLogger.debug("Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id)) - break - DBLogger.debug("APN ID is " + str(apn_id)) - - json_data = { - 'apn' : apn_id, - 'subscriber_id' : subscriber_id, - 'pcrf_session_id' : str(pcrf_session_id), - 'serving_pgw' : str(serving_pgw), - 'serving_pgw_realm' : str(serving_pgw_realm), - 'serving_pgw_peer' : str(serving_pgw_peer), - 'serving_pgw_timestamp' : datetime.datetime.now(tz=timezone.utc), - 'subscriber_routing' : str(subscriber_routing) - } - - try: - #Check if already a serving APN on record - DBLogger.debug("Checking to see if subscriber id " + str(subscriber_id) + " already has an active PCRF profile on APN id " + str(apn_id)) - ServingAPN = Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id) - assert(ServingAPN is not None) - DBLogger.debug("Existing Serving APN ID on record, updating") - try: - assert(type(serving_pgw) == str) - assert(len(serving_pgw) > 0) - assert("None" not in serving_pgw) - - UpdateObj(SERVING_APN, json_data, ServingAPN['serving_apn_id'], True) - objectData = GetObj(SERVING_APN, ServingAPN['serving_apn_id']) - handle_external_webhook(objectData, 'UPDATE') - except: - DBLogger.debug("Clearing PCRF session ID on serving_apn_id: " + str(ServingAPN['serving_apn_id'])) - objectData = GetObj(SERVING_APN, ServingAPN['serving_apn_id']) - handle_external_webhook(objectData, 'DELETE') - DeleteObj(SERVING_APN, ServingAPN['serving_apn_id'], True) - except Exception as E: - DBLogger.info("Failed to update existing APN " + str(E)) - #Create if does not exist - ServingAPN = CreateObj(SERVING_APN, json_data, True) - objectData = GetObj(SERVING_APN, ServingAPN['serving_apn_id']) - handle_external_webhook(objectData, 'CREATE') - - #Sync state change with geored - if propagate == True: - try: - if 'PCRF' in yaml_config['geored']['sync_actions'] and yaml_config['geored']['enabled'] == True: - DBLogger.debug("Propagate PCRF changes to Geographic PyHSS instances") - GeoRed_Push_Async({"imsi": str(imsi), - 'serving_apn' : str(apn), - 'pcrf_session_id': 
str(pcrf_session_id), - 'serving_pgw': str(serving_pgw), - 'serving_pgw_realm': str(serving_pgw_realm), - 'serving_pgw_peer': str(serving_pgw_peer), - 'subscriber_routing': str(subscriber_routing) - }) - else: - DBLogger.debug("Config does not allow sync of PCRF events") - except Exception as E: - DBLogger.debug("Nothing synced to Geographic PyHSS instances for event PCRF") - - - return - -def Get_Serving_APN(subscriber_id, apn_id): - DBLogger.debug("Getting Serving APN " + str(apn_id) + " with subscriber_id " + str(subscriber_id)) - Session = sessionmaker(bind = engine) - session = Session() - - try: - result = session.query(SERVING_APN).filter_by(subscriber_id=subscriber_id, apn=apn_id).first() - if result is None: - return result - result = result.__dict__ - result.pop('_sa_instance_state') - - except Exception as E: - DBLogger.debug(E) - safe_close(session) - raise ValueError(E) - - safe_close(session) - return result - -def Get_Charging_Rule(charging_rule_id): - DBLogger.debug("Called Get_Charging_Rule() for charging_rule_id " + str(charging_rule_id)) - Session = sessionmaker(bind = engine) - session = Session() - #Get base Rule - ChargingRule = GetObj(CHARGING_RULE, charging_rule_id) - ChargingRule['tft'] = [] - #Get TFTs - try: - results = session.query(TFT).filter_by(tft_group_id=ChargingRule['tft_group_id']) - for result in results: - result = result.__dict__ - result.pop('_sa_instance_state') - ChargingRule['tft'].append(result) - except Exception as E: - safe_close(session) - raise ValueError(E) - safe_close(session) - return ChargingRule - -def Get_Charging_Rules(imsi, apn): - DBLogger.debug("Called Get_Charging_Rules() for IMSI " + str(imsi) + " and APN " + str(apn)) - #Get Subscriber ID from IMSI - subscriber_details = Get_Subscriber(imsi=str(imsi)) - - #Split the APN list into a list - apn_list = subscriber_details['apn_list'].split(',') - DBLogger.debug("Current APN List: " + str(apn_list)) - #Remove the default APN from the list - try: - apn_list.remove(str(subscriber_details['default_apn'])) - except: - DBLogger.debug("Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List") - pass - #Add default APN in first position - apn_list.insert(0, str(subscriber_details['default_apn'])) - - #Get APN ID from APN - for apn_id in apn_list: - DBLogger.debug("Getting APN ID " + str(apn_id) + " to see if it matches APN " + str(apn)) - #Get each APN in List - apn_data = Get_APN(apn_id) - DBLogger.debug(apn_data) - if str(apn_data['apn']).lower() == str(apn).lower(): - DBLogger.debug("Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id)) - - DBLogger.debug("Getting charging rule list from " + str(apn_data['charging_rule_list'])) - ChargingRule = {} - ChargingRule['charging_rule_list'] = str(apn_data['charging_rule_list']).split(',') - ChargingRule['apn_data'] = apn_data - - #Get Charging Rules list - if apn_data['charging_rule_list'] == None: - DBLogger.debug("No Charging Rule associated with this APN") - ChargingRule['charging_rules'] = None - return ChargingRule - - DBLogger.debug("ChargingRule['charging_rule_list'] is: " + str(ChargingRule['charging_rule_list'])) - #Empty dict for the Charging Rules to go into - ChargingRule['charging_rules'] = [] - #Add each of the Charging Rules for the APN - for individual_charging_rule in ChargingRule['charging_rule_list']: - DBLogger.debug("Getting Charging rule " + str(individual_charging_rule)) - individual_charging_rule_complete = Get_Charging_Rule(individual_charging_rule) - 
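An aside before the loop above continues: Get_Charging_Rule() returns the base CHARGING_RULE row with every TFT row sharing its tft_group_id attached under the 'tft' key. A sketch of the resulting shape, with invented values:

```python
# Sketch (abridged) of the structure returned by Get_Charging_Rule() above,
# with invented example values; 'tft' holds every TFT row sharing the
# rule's tft_group_id.
charging_rule = {
    'charging_rule_id': 1,
    'rule_name': 'charging_rule_A',
    'qci': 4,
    'tft_group_id': 1,
    'tft': [
        {'tft_id': 1, 'tft_group_id': 1, 'direction': 1,
         'tft_string': 'permit out ip from any to any'},
        {'tft_id': 2, 'tft_group_id': 1, 'direction': 2,
         'tft_string': 'permit out ip from any to any'},
    ],
}
```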
DBLogger.debug("Got individual_charging_rule_complete: " + str(individual_charging_rule_complete)) - ChargingRule['charging_rules'].append(individual_charging_rule_complete) - DBLogger.debug("Completed Get_Charging_Rules()") - DBLogger.debug(ChargingRule) - return ChargingRule - -def Get_UE_by_IP(subscriber_routing): - DBLogger.debug("Called Get_UE_by_IP() for IP " + str(subscriber_routing)) - - Session = sessionmaker(bind = engine) - session = Session() - - try: - result = session.query(SERVING_APN).filter_by(subscriber_routing=subscriber_routing).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - result = result.__dict__ - result.pop('_sa_instance_state') - result = Sanitize_Datetime(result) - return result - #Get Subscriber ID from IMSI - subscriber_details = Get_Subscriber(imsi=str(imsi)) - -def Store_IMSI_IMEI_Binding(imsi, imei, match_response_code, propagate=True): - #IMSI 14-15 Digits - #IMEI 15 Digits - #IMEI-SV 2 Digits - DBLogger.debug("Called Store_IMSI_IMEI_Binding() with IMSI: " + str(imsi) + " IMEI: " + str(imei) + " match_response_code: " + str(match_response_code)) - if yaml_config['eir']['imsi_imei_logging'] != True: - DBLogger.debug("Skipping storing binding") - return - #Concat IMEI + IMSI - imsi_imei = str(imsi) + "," + str(imei) - Session = sessionmaker(bind = engine) - session = Session() - - #Check if exist already & update - try: - session.query(IMSI_IMEI_HISTORY).filter_by(imsi_imei=imsi_imei).one() - DBLogger.debug("Entry already present for IMSI/IMEI Combo") - safe_close(session) - return - except Exception as E: - newObj = IMSI_IMEI_HISTORY(imsi_imei=imsi_imei, match_response_code=match_response_code, imsi_imei_timestamp = datetime.datetime.now(tz=timezone.utc)) - session.add(newObj) - try: - session.commit() - except Exception as E: - DBLogger.error("Failed to commit session, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - safe_close(session) - DBLogger.debug("Added new IMSI_IMEI_HISTORY binding") - - if 'sim_swap_notify_webhook' in yaml_config['eir']: - DBLogger.debug("Sending SIM Swap notification to Webhook") - try: - dictToSend = {'imei':imei, 'imsi': imsi, 'match_response_code': match_response_code} - Webhook_Push_Async(str(yaml_config['eir']['sim_swap_notify_webhook']), json_data=dictToSend) - except Exception as E: - DBLogger.debug("Failed to post to Webhook") - DBLogger.debug(str(E)) - - #Lookup Device Info - if 'tac_database_csv' in yaml_config['eir']: - try: - device_info = get_device_info_from_TAC(imei=str(imei)) - DBLogger.debug("Got Device Info: " + str(device_info)) - prom_eir_devices.labels( - imei_prefix=device_info['tac_prefix'], - device_type=device_info['name'], - device_name=device_info['model'] - ).inc() - except Exception as E: - DBLogger.debug("Failed to get device info from TAC") - prom_eir_devices.labels( - imei_prefix=str(imei)[0:8], - device_type='Unknown', - device_name='Unknown' - ).inc() - else: - DBLogger.debug("No TAC database configured, skipping device info lookup") - - #Sync state change with geored - if propagate == True: - try: - if 'EIR' in yaml_config['geored']['sync_actions'] and yaml_config['geored']['enabled'] == True: - DBLogger.debug("Propagate EIR changes to Geographic PyHSS instances") - GeoRed_Push_Async( - {"imsi": str(imsi), - "imei": str(imei), - "match_response_code": str(match_response_code)} - ) - else: - DBLogger.debug("Config does not allow sync of EIR events") - except Exception as E: - DBLogger.debug("Nothing synced to Geographic PyHSS 
instances for EIR event") - DBLogger.debug(E) - - return - -def Get_IMEI_IMSI_History(attribute): - DBLogger.debug("Called Get_IMEI_IMSI_History() for entry matching " + str(Get_IMEI_IMSI_History)) - Session = sessionmaker(bind = engine) - session = Session() - result_array = [] - try: - results = session.query(IMSI_IMEI_HISTORY).filter(IMSI_IMEI_HISTORY.imsi_imei.ilike("%" + str(attribute) + "%")).all() - for result in results: - result = result.__dict__ - result.pop('_sa_instance_state') - result = Sanitize_Datetime(result) - try: - result['imsi'] = result['imsi_imei'].split(",")[0] - except: - continue - try: - result['imei'] = result['imsi_imei'].split(",")[1] - except: - continue - result_array.append(result) - safe_close(session) - return result_array - except Exception as E: - safe_close(session) - raise ValueError(E) - -def Check_EIR(imsi, imei): - eir_response_code_table = {0 : 'Whitelist', 1: 'Blacklist', 2: 'Greylist'} - DBLogger.debug("Called Check_EIR() for imsi " + str(imsi) + " and imei: " + str(imei)) - Session = sessionmaker(bind = engine) - session = Session() - #Check for Exact Matches - DBLogger.debug("Looking for exact matches") - #Check for exact Matches - try: - results = session.query(EIR).filter_by(imei=str(imei), regex_mode=0) - for result in results: - result = result.__dict__ - match_response_code = result['match_response_code'] - if result['imsi'] == '': - DBLogger.debug("No IMSI specified in DB, so matching only on IMEI") - Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) - return match_response_code - elif result['imsi'] == str(imsi): - DBLogger.debug("Matched on IMEI and IMSI") - Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) - return match_response_code - except Exception as E: - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - DBLogger.debug("Did not match any Exact Matches - Checking Regex") - try: - results = session.query(EIR).filter_by(regex_mode=1) #Get all Regex records from DB - for result in results: - result = result.__dict__ - match_response_code = result['match_response_code'] - if re.match(result['imei'], imei): - DBLogger.debug("IMEI matched " + str(result['imei'])) - #Check if IMSI also specified - if len(result['imsi']) != 0: - DBLogger.debug("With IMEI matched, now checking if IMSI matches regex") - if re.match(result['imsi'], imsi): - DBLogger.debug("IMSI also matched, so match OK!") - Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) - return match_response_code - else: - DBLogger.debug("No IMSI specified, so match OK!") - Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) - return match_response_code - except Exception as E: - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - try: - session.commit() - except Exception as E: - DBLogger.error("Failed to commit session, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - DBLogger.debug("No matches at all - Returning default response") - Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=yaml_config['eir']['no_match_response']) - safe_close(session) - return yaml_config['eir']['no_match_response'] - -def Get_EIR_Rules(): - DBLogger.debug("Getting all EIR Rules") - Session = sessionmaker(bind = engine) - session = Session() - EIR_Rules = [] - try: - results = session.query(EIR) - for result in results: - result = result.__dict__ - 
result.pop('_sa_instance_state') - EIR_Rules.append(result) - except Exception as E: - safe_rollback(session) - safe_close(session) - raise ValueError(E) - DBLogger.debug("Final EIR_Rules: " + str(EIR_Rules)) - safe_close(session) - return EIR_Rules - - -def dict_bytes_to_dict_string(dict_bytes): - dict_string = {} - for key, value in dict_bytes.items(): - dict_string[key.decode()] = value.decode() - return dict_string - - -def get_device_info_from_TAC(imei): - DBLogger.debug("Getting Device Info from IMEI: " + str(imei)) - #Try 8 digit TAC - try: - DBLogger.debug("Trying to match on 8 Digit IMEI") - imei_result = logtool.RedisHMGET(str(imei[0:8])) - print("Got back: " + str(imei_result)) - imei_result = dict_bytes_to_dict_string(imei_result) - assert(len(imei_result) != 0) - DBLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) - return imei_result - except: - DBLogger.debug("Failed to match on 8 digit IMEI") - - try: - DBLogger.debug("Trying to match on 6 Digit IMEI") - imei_result = logtool.RedisHMGET(str(imei[0:6])) - print("Got back: " + str(imei_result)) - imei_result = dict_bytes_to_dict_string(imei_result) - assert(len(imei_result) != 0) - DBLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) - return imei_result - except: - DBLogger.debug("Failed to match on 6 digit IMEI") - - raise ValueError("No matching TAC in IMEI Database") - -if __name__ == "__main__": - import binascii,os,pprint - DeleteAfter = True - - #Define Charging Rule - charging_rule = { - 'rule_name' : 'charging_rule_A', - 'qci' : 4, - 'arp_priority' : 5, - 'arp_preemption_capability' : True, - 'arp_preemption_vulnerability' : False, - 'mbr_dl' : 128000, - 'mbr_ul' : 128000, - 'gbr_dl' : 128000, - 'gbr_ul' : 128000, - 'tft_group_id' : 1, - 'precedence' : 100, - 'rating_group' : 20000 - } - print("Creating Charging Rule A") - ChargingRule_newObj_A = CreateObj(CHARGING_RULE, charging_rule) - print("ChargingRule_newObj A: " + str(ChargingRule_newObj_A)) - charging_rule['gbr_ul'], charging_rule['gbr_dl'], charging_rule['mbr_ul'], charging_rule['mbr_dl'] = 256000, 256000, 256000, 256000 - print("Creating Charging Rule B") - charging_rule['rule_name'], charging_rule['precedence'], charging_rule['tft_group_id'] = 'charging_rule_B', 80, 2 - ChargingRule_newObj_B = CreateObj(CHARGING_RULE, charging_rule) - print("ChargingRule_newObj B: " + str(ChargingRule_newObj_B)) - - #Define TFTs - tft_template1 = { - 'tft_group_id' : 1, - 'tft_string' : 'permit out ip from any to any', - 'direction' : 1 - } - tft_template2 = { - 'tft_group_id' : 1, - 'tft_string' : 'permit out ip from any to any', - 'direction' : 2 - } - print("Creating TFT") - CreateObj(TFT, tft_template1) - CreateObj(TFT, tft_template2) - - tft_template3 = { - 'tft_group_id' : 2, - 'tft_string' : 'permit out ip from 10.98.0.0 255.255.255.0 to any', - 'direction' : 1 - } - tft_template4 = { - 'tft_group_id' : 2, - 'tft_string' : 'permit out ip from any to 10.98.0.0 255.255.255.0', - 'direction' : 2 - } - print("Creating TFT") - CreateObj(TFT, tft_template3) - CreateObj(TFT, tft_template4) - - - apn2 = { - 'apn':'ims', - 'apn_ambr_dl' : 9999, - 'apn_ambr_ul' : 9999, - 'arp_priority': 1, - 'arp_preemption_capability' : False, - 'arp_preemption_vulnerability': True, - 'charging_rule_list' : str(ChargingRule_newObj_A['charging_rule_id']) + "," + str(ChargingRule_newObj_B['charging_rule_id']) - } - print("Creating APN " + str(apn2['apn'])) - newObj = CreateObj(APN, apn2) - print(newObj) - - print("Getting 
APN " + str(apn2['apn'])) - print(GetObj(APN, newObj['apn_id'])) - apn_id = newObj['apn_id'] - UpdatedObj = newObj - UpdatedObj['apn'] = 'UpdatedInUnitTest' - - print("Updating APN " + str(apn2['apn'])) - newObj = UpdateObj(APN, UpdatedObj, newObj['apn_id']) - print(newObj) - - #Create AuC - auc_json = { - "ki": binascii.b2a_hex(os.urandom(16)).zfill(16), - "opc": binascii.b2a_hex(os.urandom(16)).zfill(16), - "amf": "9000", - "sqn": 0 - } - print(auc_json) - print("Creating AuC entry") - newObj = CreateObj(AUC, auc_json) - print(newObj) - - #Get AuC - print("Getting AuC entry") - newObj = GetObj(AUC, newObj['auc_id']) - auc_id = newObj['auc_id'] - print(newObj) - - #Update AuC - print("Updating AuC entry") - newObj['sqn'] = newObj['sqn'] + 10 - newObj = UpdateObj(AUC, newObj, auc_id) - - #Generate Vectors - print("Generating Vectors") - Get_Vectors_AuC(auc_id, "air", plmn='12ff') - print(Get_Vectors_AuC(auc_id, "sip_auth", plmn='12ff')) - - - #Update AuC - Update_AuC(auc_id, sqn=100) - - #New Subscriber - subscriber_json = { - "imsi": "001001000000006", - "enabled": True, - "msisdn": "12345678", - "ue_ambr_dl": 999999, - "ue_ambr_ul": 999999, - "nam": 0, - "subscribed_rau_tau_timer": 600, - "auc_id" : auc_id, - "default_apn" : apn_id, - "apn_list" : apn_id - } - - #Delete IMSI if already exists - try: - existing_sub_data = Get_Subscriber(imsi=subscriber_json['imsi']) - DeleteObj(SUBSCRIBER, existing_sub_data['subscriber_id']) - except: - print("Did not find old sub to delete") - - print("Creating new Subscriber") - print(subscriber_json) - newObj = CreateObj(SUBSCRIBER, subscriber_json) - print(newObj) - subscriber_id = newObj['subscriber_id'] - - #Get SUBSCRIBER - print("Getting Subscriber") - newObj = GetObj(SUBSCRIBER, subscriber_id) - print(newObj) - - #Update SUBSCRIBER - print("Updating Subscriber") - newObj['ue_ambr_ul'] = 999995 - newObj = UpdateObj(SUBSCRIBER, newObj, subscriber_id) - - #Set MME Location for Subscriber - print("Updating Serving MME for Subscriber") - Update_Serving_MME(imsi=newObj['imsi'], serving_mme="Test123", serving_mme_peer="Test123", serving_mme_realm="TestRealm") - - #Update Serving APN for Subscriber - print("Updating Serving APN for Subscriber") - Update_Serving_APN(imsi=newObj['imsi'], apn=apn2['apn'], pcrf_session_id='kjsdlkjfd', serving_pgw='pgw.test.com', subscriber_routing='1.2.3.4') - - print("Getting Charging Rule for Subscriber / APN Combo") - ChargingRule = Get_Charging_Rules(imsi=newObj['imsi'], apn=apn2['apn']) - pprint.pprint(ChargingRule) - - #New IMS Subscriber - ims_subscriber_json = { - "msisdn": newObj['msisdn'], - "msisdn_list": newObj['msisdn'], - "imsi": subscriber_json['imsi'], - "ifc_path" : "default_ifc.xml", - "sh_profile" : "default_sh_user_data.xml" - } - print(ims_subscriber_json) - newObj = CreateObj(IMS_SUBSCRIBER, ims_subscriber_json) - print(newObj) - ims_subscriber_id = newObj['ims_subscriber_id'] - - - #Test Get Subscriber - print("Test Getting Subscriber") - GetSubscriber_Result = Get_Subscriber(imsi=subscriber_json['imsi']) - print(GetSubscriber_Result) - - #Test IMS Get Subscriber - print("Getting IMS Subscribers") - print(Get_IMS_Subscriber(imsi='001001000000006')) - print(Get_IMS_Subscriber(msisdn='12345678')) - - #Set SCSCF for Subscriber - Update_Serving_CSCF(newObj['imsi'], "NickTestCSCF") - #Get Served Subscriber List - print(Get_Served_IMS_Subscribers()) - - #Clear Serving PGW for PCRF Subscriber - print("Clear Serving PGW for PCRF Subscriber") - Update_Serving_APN(imsi=newObj['imsi'], apn=apn2['apn'], 
pcrf_session_id='sessionid123', serving_pgw=None, subscriber_routing=None) - - #Clear MME Location for Subscriber - print("Clear MME Location for Subscriber") - Update_Serving_MME(newObj['imsi'], None) - - #Generate Vectors for IMS Subscriber - print("Generating Vectors for IMS Subscriber") - print(Get_Vectors_AuC(auc_id, "sip_auth", plmn='12ff')) - - #print("Generating Resync for IMS Subscriber") - #print(Get_Vectors_AuC(auc_id, "sqn_resync", auts='7964347dfdfe432289522183fcfb', rand='1bc9f096002d3716c65e4e1f4c1c0d17')) - - #Test getting APNs - GetAPN_Result = Get_APN(GetSubscriber_Result['default_apn']) - print(GetAPN_Result) - - #GeoRed_Push_Async({"imsi": "001001000000006", "serving_mme": "abc123"}) - - - if DeleteAfter == True: - #Delete IMS Subscriber - print(DeleteObj(IMS_SUBSCRIBER, ims_subscriber_id)) - #Delete Subscriber - print(DeleteObj(SUBSCRIBER, subscriber_id)) - #Delete AuC - print(DeleteObj(AUC, auc_id)) - #Delete APN - print(DeleteObj(APN, apn_id)) - - #Whitelist IMEI / IMSI Binding - eir_template = {'imei': '1234', 'imsi': '567', 'regex_mode': 0, 'match_response_code': 0} - CreateObj(EIR, eir_template) - - #Blacklist Example - eir_template = {'imei': '99881232', 'imsi': '', 'regex_mode': 0, 'match_response_code': 1} - CreateObj(EIR, eir_template) - - #IMEI Prefix Regex Example (Blacklist all IMEIs starting with 666) - eir_template = {'imei': '^666.*', 'imsi': '', 'regex_mode': 1, 'match_response_code': 1} - CreateObj(EIR, eir_template) - - #IMEI Prefix Regex Example (Greylist response for IMEI starting with 777 and IMSI is 1234123412341234) - eir_template = {'imei': '^777.*', 'imsi': '^1234123412341234$', 'regex_mode': 1, 'match_response_code': 2} - CreateObj(EIR, eir_template) - - print("\n\n\n\n") - #Check Whitelist (No Match) - assert Check_EIR(imei='1234', imsi='') == 2 - - print("\n\n\n\n") - #Check Whitelist (Matched) - assert Check_EIR(imei='1234', imsi='567') == 0 - - print("\n\n\n\n") - #Check Blacklist (Match) - assert Check_EIR(imei='99881232', imsi='567') == 1 - - print("\n\n\n\n") - #IMEI Prefix Regex Example (Greylist response for IMEI starting with 777 and IMSI is 1234123412341234) - assert Check_EIR(imei='7771234', imsi='1234123412341234') == 2 - - print(Get_IMEI_IMSI_History('1234123412')) - - - print("\n\n\n") - print(Generate_JSON_Model_for_Flask(SUBSCRIBER)) - - - diff --git a/diameter.py b/diameter.py deleted file mode 100644 index 67d8573..0000000 --- a/diameter.py +++ /dev/null @@ -1,2603 +0,0 @@ -#Diameter Packet Decoder / Encoder & Tools -from multiprocessing.sharedctypes import Value -import socket -import logging -import sys -import binascii -import math -import uuid -import os -import random -import ipaddress -sys.path.append(os.path.realpath('lib')) -import S6a_crypt - -import jinja2 -import yaml -import time -with open("config.yaml", 'r') as stream: - yaml_config = (yaml.safe_load(stream)) - -#Setup Logging -import logtool -from logtool import * -logtool = logtool.LogTool() -logtool.setup_logger('DiameterLogger', yaml_config['logging']['logfiles']['diameter_logging_file'], level=yaml_config['logging']['level']) -DiameterLogger = logging.getLogger('DiameterLogger') - -DiameterLogger.info("Initialised Diameter Logger, importing database") -import database -DiameterLogger.info("Imported database") - -if yaml_config['redis']['enabled'] == True: - DiameterLogger.debug("Redis support enabled") - import redis - - -class Diameter: - ##Function Definitions - - - #Generates rounding for calculating padding - def myround(self, n, base=4): - if(n > 0): 
- return math.ceil(n/4.0) * 4 - elif( n < 0): - return math.floor(n/4.0) * 4 - else: - return 4 - - #Converts a dotted-decimal IPv4 address or IPV6 address to hex - def ip_to_hex(self, ip): - #Determine IPvX version: - if "." in ip: - ip = ip.split('.') - ip_hex = "0001" #IPv4 - ip_hex = ip_hex + str(format(int(ip[0]), 'x').zfill(2)) - ip_hex = ip_hex + str(format(int(ip[1]), 'x').zfill(2)) - ip_hex = ip_hex + str(format(int(ip[2]), 'x').zfill(2)) - ip_hex = ip_hex + str(format(int(ip[3]), 'x').zfill(2)) - else: - ip_hex = "0002" #IPv6 - ip_hex += format(ipaddress.IPv6Address(ip), 'X') - #DiameterLogger.debug("Converted IP to hex - Input: " + str(ip) + " output: " + str(ip_hex)) - return ip_hex - - def hex_to_int(self, hex): - return int(str(hex), base=16) - - - #Converts a hex formatted IPv4 address or IPV6 address to dotted-decimal - def hex_to_ip(self, hex_ip): - if len(hex_ip) == 8: - octet_1 = int(str(hex_ip[0:2]), base=16) - octet_2 = int(str(hex_ip[2:4]), base=16) - octet_3 = int(str(hex_ip[4:6]), base=16) - octet_4 = int(str(hex_ip[6:8]), base=16) - return str(octet_1) + "." + str(octet_2) + "." + str(octet_3) + "." + str(octet_4) - elif len(hex_ip) == 32: - n=4 - ipv6_split = [hex_ip[idx:idx + n] for idx in range(0, len(hex_ip), n)] - ipv6_str = '' - for octect in ipv6_split: - ipv6_str += str(octect).lstrip('0') + ":" - #Strip last Colon - ipv6_str = ipv6_str[:-1] - return ipv6_str - - #Converts string to hex - def string_to_hex(self, string): - string_bytes = string.encode('utf-8') - return str(binascii.hexlify(string_bytes), 'ascii') - - #Converts int to hex padded to required number of bytes - def int_to_hex(self, input_int, output_bytes): - - return format(input_int,"x").zfill(output_bytes*2) - - #Converts Hex byte to Binary - def hex_to_bin(self, input_hex): - return bin(int(str(input_hex), 16))[2:].zfill(8) - - #Generates a valid random ID to use - def generate_id(self, length): - length = length * 2 - return str(uuid.uuid4().hex[:length]) - - def Reverse(self, str): - stringlength=len(str) - slicedString=str[stringlength::-1] - return (slicedString) - - def DecodePLMN(self, plmn): - DiameterLogger.debug("Decoded PLMN: " + str(plmn)) - mcc = self.Reverse(plmn[0:2]) + self.Reverse(plmn[2:4]).replace('f', '') - DiameterLogger.debug("Decoded MCC: " + mcc) - - mnc = self.Reverse(plmn[4:6]) - DiameterLogger.debug("Decoded MNC: " + mnc) - return mcc, mnc - - def EncodePLMN(self, mcc, mnc): - plmn = list('XXXXXX') - plmn[0] = self.Reverse(mcc)[1] - plmn[1] = self.Reverse(mcc)[2] - plmn[2] = "f" - plmn[3] = self.Reverse(mcc)[0] - plmn[4] = self.Reverse(mnc)[0] - plmn[5] = self.Reverse(mnc)[1] - plmn_list = plmn - plmn = '' - for bits in plmn_list: - plmn = plmn + bits - DiameterLogger.debug("Encoded PLMN: " + str(plmn)) - return plmn - - def TBCD_special_chars(self, input): - DiameterLogger.debug("Special character possible in " + str(input)) - if input == "*": - DiameterLogger.debug("Found * - Returning 1010") - return "1010" - elif input == "#": - DiameterLogger.debug("Found # - Returning 1011") - return "1011" - elif input == "a": - DiameterLogger.debug("Found a - Returning 1100") - return "1100" - elif input == "b": - DiameterLogger.debug("Found b - Returning 1101") - return "1101" - elif input == "c": - DiameterLogger.debug("Found c - Returning 1100") - return "1100" - else: - binform = "{:04b}".format(int(input)) - DiameterLogger.debug("input " + str(input) + " is not a special char, converted to bin: " + str(binform)) - return (binform) - - def TBCD_encode(self, input): - 
DiameterLogger.debug("TBCD_encode input value is " + str(input)) - offset = 0 - output = '' - matches = ['*', '#', 'a', 'b', 'c'] - while offset < len(input): - if len(input[offset:offset+2]) == 2: - DiameterLogger.debug("processing bits " + str(input[offset:offset+2]) + " at position offset " + str(offset)) - bit = input[offset:offset+2] #Get two digits at a time - bit = bit[::-1] #Reverse them - #Check if *, #, a, b or c - if any(x in bit for x in matches): - DiameterLogger.debug("Special char in bit " + str(bit)) - new_bit = '' - new_bit = new_bit + str(self.TBCD_special_chars(bit[0])) - new_bit = new_bit + str(self.TBCD_special_chars(bit[1])) - DiameterLogger.debug("Final bin output of new_bit is " + str(new_bit)) - bit = hex(int(new_bit, 2))[2:] #Get Hex value - DiameterLogger.debug("Formatted as Hex this is " + str(bit)) - output = output + bit - offset = offset + 2 - else: - #If odd-length input - last_digit = str(input[offset:offset+2]) - #Check if *, #, a, b or c - if any(x in last_digit for x in matches): - DiameterLogger.debug("Special char in bit " + str(bit)) - new_bit = '' - new_bit = new_bit + '1111' #Add the F first - #Encode the symbol into binary and append it to the new_bit var - new_bit = new_bit + str(self.TBCD_special_chars(last_digit)) - DiameterLogger.debug("Final bin output of new_bit is " + str(new_bit)) - bit = hex(int(new_bit, 2))[2:] #Get Hex value - DiameterLogger.debug("Formatted as Hex this is " + str(bit)) - else: - bit = "f" + last_digit - offset = offset + 2 - output = output + bit - DiameterLogger.debug("TBCD_encode final output value is " + str(output)) - return output - - def TBCD_decode(self, input): - DiameterLogger.debug("TBCD_decode Input value is " + str(input)) - offset = 0 - output = '' - while offset < len(input): - if "f" not in input[offset:offset+2]: - bit = input[offset:offset+2] #Get two digits at a time - bit = bit[::-1] #Reverse them - output = output + bit - offset = offset + 2 - else: #If f in bit strip it - bit = input[offset:offset+2] - output = output + bit[1] - DiameterLogger.debug("TBCD_decode output value is " + str(output)) - return output - - #Hexify the vars we got when initializing the class - def __init__(self, OriginHost, OriginRealm, ProductName, MNC, MCC): - self.OriginHost = self.string_to_hex(OriginHost) - self.OriginRealm = self.string_to_hex(OriginRealm) - self.ProductName = self.string_to_hex(ProductName) - self.MNC = str(MNC) - self.MCC = str(MCC) - - DiameterLogger.info("Initialized Diameter for " + str(OriginHost) + " at Realm " + str(OriginRealm) + " serving as Product Name " + str(ProductName)) - DiameterLogger.info("PLMN is " + str(MCC) + "/" + str(MNC)) - - #Generates an AVP with inputs provided (AVP Code, AVP Flags, AVP Content, Padding) - #AVP content must already be in HEX - This can be done with binascii.hexlify(avp_content.encode()) - def generate_avp(self, avp_code, avp_flags, avp_content, avps=None, packet_vars=None): - if avp_code == 268 or avp_code == 298 and packet_vars['command_code'] != 280: - try: - DiameterLogger.debug("Incrementing Prometheus Stats for prom_diam_result_code") - try: - imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request - imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI - except: - imsi = '' - - try: - OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP - OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it - except: - OriginHost = '' - DiameterLogger.debug("Generating result code: " + str(int(avp_content, 
16)) + " for OriginHost: " + str(OriginHost) + " and IMSI: " + str(imsi)) - #Turn result code into int - prom_diam_result_code.labels( - diameter_application_id = packet_vars['ApplicationId'], - diameter_cmd_code = packet_vars['command_code'], - result_code = int(avp_content, 16), - endpoint = OriginHost, - imsi = imsi, - ).inc() - DiameterLogger.debug("Incremented Prometheus Stats for prom_diam_result_code") - except Exception as E: - DiameterLogger.debug("Failed to increment Prometheus Stats for prom_diam_result_code") - DiameterLogger.debug(E) - avp_code = format(avp_code,"x").zfill(8) - avp_length = 1 ##This is a placeholder that's overwritten later - - #AVP Must always be a multiple of 4 - Round up to nearest multiple of 4 and fill remaining bits with padding - avp = str(avp_code) + str(avp_flags) + str("000000") + str(avp_content) - avp_length = int(len(avp)/2) - - if avp_length % 4 == 0: #Multiple of 4 - No Padding needed - avp_padding = '' - else: #Not multiple of 4 - Padding needed - rounded_value = self.myround(avp_length) - avp_padding = format(0,"x").zfill(int( rounded_value - avp_length) * 2) - - avp = str(avp_code) + str(avp_flags) + str(format(avp_length,"x").zfill(6)) + str(avp_content) + str(avp_padding) - return avp - - #Generates an AVP with inputs provided (AVP Code, AVP Flags, AVP Content, Padding) - #AVP content must already be in HEX - This can be done with binascii.hexlify(avp_content.encode()) - def generate_vendor_avp(self, avp_code, avp_flags, avp_vendorid, avp_content): - avp_code = format(avp_code,"x").zfill(8) - - avp_length = 1 ##This is a placeholder that gets overwritten later - - avp_vendorid = format(int(avp_vendorid),"x").zfill(8) - - #AVP Must always be a multiple of 4 - Round up to nearest multiple of 4 and fill remaining bits with padding - avp = str(avp_code) + str(avp_flags) + str("000000") + str(avp_vendorid) + str(avp_content) - avp_length = int(len(avp)/2) - - if avp_length % 4 == 0: #Multiple of 4 - No Padding needed - avp_padding = '' - else: #Not multiple of 4 - Padding needed - rounded_value = self.myround(avp_length) - DiameterLogger.debug("Rounded value is " + str(rounded_value)) - DiameterLogger.debug("Has " + str( int( rounded_value - avp_length)) + " bytes of padding") - avp_padding = format(0,"x").zfill(int( rounded_value - avp_length) * 2) - - - - avp = str(avp_code) + str(avp_flags) + str(format(avp_length,"x").zfill(6)) + str(avp_vendorid) + str(avp_content) + str(avp_padding) - return avp - - def generate_diameter_packet(self, packet_version, packet_flags, packet_command_code, packet_application_id, packet_hop_by_hop_id, packet_end_to_end_id, avp): - #Placeholder that is updated later on - packet_length = 228 - packet_length = format(packet_length,"x").zfill(6) - - packet_command_code = format(packet_command_code,"x").zfill(6) - - packet_application_id = format(packet_application_id,"x").zfill(8) - - packet_hex = packet_version + packet_length + packet_flags + packet_command_code + packet_application_id + packet_hop_by_hop_id + packet_end_to_end_id + avp - packet_length = int(round(len(packet_hex))/2) - packet_length = format(packet_length,"x").zfill(6) - - packet_hex = packet_version + packet_length + packet_flags + packet_command_code + packet_application_id + packet_hop_by_hop_id + packet_end_to_end_id + avp - return packet_hex - - def decode_diameter_packet(self, data): - packet_vars = {} - avps = [] - - if type(data) is bytes: - data = data.hex() - - - packet_vars['packet_version'] = data[0:2] - packet_vars['length'] = 
int(data[2:8], 16) - packet_vars['flags'] = data[8:10] - packet_vars['flags_bin'] = bin(int(data[8:10], 16))[2:].zfill(8) - packet_vars['command_code'] = int(data[10:16], 16) - packet_vars['ApplicationId'] = int(data[16:24], 16) - packet_vars['hop-by-hop-identifier'] = data[24:32] - packet_vars['end-to-end-identifier'] = data[32:40] - - avp_sum = data[40:] - - avp_vars, remaining_avps = self.decode_avp_packet(avp_sum) - avps.append(avp_vars) - - while len(remaining_avps) > 0: - avp_vars, remaining_avps = self.decode_avp_packet(remaining_avps) - avps.append(avp_vars) - else: - pass - return packet_vars, avps - - def decode_avp_packet(self, data): - - if len(data) <= 8: - #if length is less than 8 it is too short to be an AVP and is most likely the data from the last AVP being attempted to be parsed as another AVP - raise ValueError("Length of data is too short to be valid AVP") - - avp_vars = {} - avp_vars['avp_code'] = int(data[0:8], 16) - - avp_vars['avp_flags'] = data[8:10] - avp_vars['avp_length'] = int(data[10:16], 16) - if avp_vars['avp_flags'] == "c0": - #If c0 is present AVP is Vendor AVP - avp_vars['vendor_id'] = int(data[16:24], 16) - avp_vars['misc_data'] = data[24:(avp_vars['avp_length']*2)] - else: - #if is not a vendor AVP - avp_vars['misc_data'] = data[16:(avp_vars['avp_length']*2)] - - if avp_vars['avp_length'] % 4 == 0: - #Multiple of 4 - No Padding needed - avp_vars['padding'] = 0 - else: - #Not multiple of 4 - Padding needed - rounded_value = self.myround(avp_vars['avp_length']) - avp_vars['padding'] = int( rounded_value - avp_vars['avp_length']) * 2 - avp_vars['padded_data'] = data[(avp_vars['avp_length']*2):(avp_vars['avp_length']*2)+avp_vars['padding']] - - - #If body of avp_vars['misc_data'] contains AVPs, then decode each of them as a list of dicts like avp_vars['misc_data'] = [avp_vars, avp_vars] - try: - sub_avp_vars, sub_remaining_avps = self.decode_avp_packet(avp_vars['misc_data']) - #Sanity check - If the avp code is greater than 9999 it's probably not an AVP after all... 
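Stepping out of the decoder for a moment: the encoder's myround() / padding arithmetic above is worth a worked example. generate_avp() pads every AVP out to a 4-byte boundary while the AVP's embedded length field still records the unpadded size, per RFC 6733. A sketch with an invented 5-byte payload:

```python
# Worked example of the padding arithmetic in generate_avp() above, using
# an invented 5-byte payload. The AVP header (code + flags + length) is
# 8 bytes, so the unpadded AVP is 13 bytes; myround() rounds this up to
# 16, giving 3 bytes (6 hex zeros) of padding, while the AVP length
# field still encodes 13.
import math

content = '6873733031'                       # "hss01" hex-encoded, 5 bytes
avp_length = 8 + len(content) // 2           # 13 bytes including the header
rounded = math.ceil(avp_length / 4.0) * 4    # 16, as myround() computes
padding_hex_chars = (rounded - avp_length) * 2
print(avp_length, rounded, padding_hex_chars)   # 13 16 6
```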
- if int(sub_avp_vars['avp_code']) > 9999: - pass - else: - #If the decoded AVP is valid store it - avp_vars['misc_data'] = [] - avp_vars['misc_data'].append(sub_avp_vars) - #While there are more AVPs to be decoded, decode them: - while len(sub_remaining_avps) > 0: - sub_avp_vars, sub_remaining_avps = self.decode_avp_packet(sub_remaining_avps) - avp_vars['misc_data'].append(sub_avp_vars) - - except Exception as e: - if str(e) == "invalid literal for int() with base 16: ''": - logging.debug("AVP length 0 error") - pass - elif str(e) == "Length of data is too short to be valid AVP": - logging.debug("AVP length 0 error v2") - pass - else: - DiameterLogger.debug("failed to decode sub-avp - error: " + str(e)) - pass - - - remaining_avps = data[(avp_vars['avp_length']*2)+avp_vars['padding']:] #returns remaining data in avp string back for processing again - return avp_vars, remaining_avps - - def get_avp_data(self, avps, avp_code): #Loops through list of dicts generated by the packet decoder, and returns the data for a specific AVP code in list (May be more than one AVP with same code but different data) - misc_data = [] - for keys in avps: - if keys['avp_code'] == avp_code: - misc_data.append(keys['misc_data']) - return misc_data - - def decode_diameter_packet_length(self, data): - packet_vars = {} - data = data.hex() - packet_vars['packet_version'] = data[0:2] - packet_vars['length'] = int(data[2:8], 16) - if packet_vars['packet_version'] == "01": - return packet_vars['length'] - else: - return False - - def AVP_278_Origin_State_Incriment(self, avps): #Capabilities Exchange Answer incriment AVP body - for avp_dicts in avps: - if avp_dicts['avp_code'] == 278: - origin_state_incriment_int = int(avp_dicts['misc_data'], 16) - origin_state_incriment_int = origin_state_incriment_int + 1 - origin_state_incriment_hex = format(origin_state_incriment_int,"x").zfill(8) - return origin_state_incriment_hex - - def Charging_Rule_Generator(self, ChargingRules, ue_ip): - DiameterLogger.debug("Called Charging_Rule_Generator") - #Install Charging Rules - DiameterLogger.info("Naming Charging Rule") - Charging_Rule_Name = self.generate_vendor_avp(1005, "c0", 10415, str(binascii.hexlify(str.encode(str(ChargingRules['rule_name']))),'ascii')) - DiameterLogger.info("Named Charging Rule") - - #Populate all Flow Information AVPs - Flow_Information = '' - for tft in ChargingRules['tft']: - DiameterLogger.info(tft) - #If {{ UE_IP }} in TFT splice in the real UE IP Value - try: - tft['tft_string'] = tft['tft_string'].replace('{{ UE_IP }}', str(ue_ip)) - tft['tft_string'] = tft['tft_string'].replace('{{UE_IP}}', str(ue_ip)) - DiameterLogger.info("Spliced in UE IP into TFT: " + str(tft['tft_string'])) - except Exception as E: - DiameterLogger.error("Failed to splice in UE IP into flow description") - - #Valid Values for Flow_Direction: 0- Unspecified, 1 - Downlink, 2 - Uplink, 3 - Bidirectional - Flow_Direction = self.generate_vendor_avp(1080, "80", 10415, self.int_to_hex(tft['direction'], 4)) - Flow_Description = self.generate_vendor_avp(507, "c0", 10415, str(binascii.hexlify(str.encode(tft['tft_string'])),'ascii')) - Flow_Information += self.generate_vendor_avp(1058, "80", 10415, Flow_Direction + Flow_Description) - - Flow_Status = self.generate_vendor_avp(511, "c0", 10415, self.int_to_hex(2, 4)) - DiameterLogger.info("Defined Flow_Status: " + str(Flow_Status)) - - DiameterLogger.info("Defining QoS information") - #QCI - QCI = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(ChargingRules['qci'], 4)) - - #ARP - 
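An aside on the TFT loop just above: any `{{ UE_IP }}` placeholder stored in a tft_string is spliced with the subscriber's allocated IP before the Flow-Description AVP (507) is encoded. A small sketch with an invented UE address:

```python
# Sketch of the {{ UE_IP }} substitution performed in the TFT loop above,
# with an invented UE address. Both the spaced and unspaced placeholder
# forms are handled.
tft_string = 'permit out ip from {{ UE_IP }} to any'
ue_ip = '10.45.0.2'
tft_string = tft_string.replace('{{ UE_IP }}', ue_ip).replace('{{UE_IP}}', ue_ip)
print(tft_string)   # permit out ip from 10.45.0.2 to any
```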
- def Charging_Rule_Generator(self, ChargingRules, ue_ip): - DiameterLogger.debug("Called Charging_Rule_Generator") - #Install Charging Rules - DiameterLogger.info("Naming Charging Rule") - Charging_Rule_Name = self.generate_vendor_avp(1005, "c0", 10415, str(binascii.hexlify(str.encode(str(ChargingRules['rule_name']))),'ascii')) - DiameterLogger.info("Named Charging Rule") - - #Populate all Flow Information AVPs - Flow_Information = '' - for tft in ChargingRules['tft']: - DiameterLogger.info(tft) - #If {{ UE_IP }} appears in the TFT, splice in the real UE IP value - try: - tft['tft_string'] = tft['tft_string'].replace('{{ UE_IP }}', str(ue_ip)) - tft['tft_string'] = tft['tft_string'].replace('{{UE_IP}}', str(ue_ip)) - DiameterLogger.info("Spliced in UE IP into TFT: " + str(tft['tft_string'])) - except Exception as E: - DiameterLogger.error("Failed to splice in UE IP into flow description") - - #Valid values for Flow_Direction: 0 - Unspecified, 1 - Downlink, 2 - Uplink, 3 - Bidirectional - Flow_Direction = self.generate_vendor_avp(1080, "80", 10415, self.int_to_hex(tft['direction'], 4)) - Flow_Description = self.generate_vendor_avp(507, "c0", 10415, str(binascii.hexlify(str.encode(tft['tft_string'])),'ascii')) - Flow_Information += self.generate_vendor_avp(1058, "80", 10415, Flow_Direction + Flow_Description) - - Flow_Status = self.generate_vendor_avp(511, "c0", 10415, self.int_to_hex(2, 4)) - DiameterLogger.info("Defined Flow_Status: " + str(Flow_Status)) - - DiameterLogger.info("Defining QoS information") - #QCI - QCI = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(ChargingRules['qci'], 4)) - - #ARP - DiameterLogger.info("Defining ARP information") - AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(ChargingRules['arp_priority']), 4)) - AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(ChargingRules['arp_preemption_capability']), 4)) - AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(ChargingRules['arp_preemption_vulnerability']), 4)) - ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) - - DiameterLogger.info("Defining MBR information") - #Max Requested Bandwidth - Bandwidth_info = '' - Bandwidth_info += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_ul']), 4)) - Bandwidth_info += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_dl']), 4)) - - DiameterLogger.info("Defining GBR information") - #GBR - if int(ChargingRules['gbr_ul']) != 0: - Bandwidth_info += self.generate_vendor_avp(1026, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_ul']), 4)) - if int(ChargingRules['gbr_dl']) != 0: - Bandwidth_info += self.generate_vendor_avp(1025, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_dl']), 4)) - DiameterLogger.info("Defined Bandwidth Info: " + str(Bandwidth_info)) - - #Populate QoS Information - QoS_Information = self.generate_vendor_avp(1016, "c0", 10415, QCI + ARP + Bandwidth_info) - DiameterLogger.info("Defined QoS_Information: " + str(QoS_Information)) - - #Precedence - DiameterLogger.info("Defining Precedence information") - Precedence = self.generate_vendor_avp(1010, "c0", 10415, self.int_to_hex(ChargingRules['precedence'], 4)) - DiameterLogger.info("Defined Precedence " + str(Precedence)) - - #Rating Group - DiameterLogger.info("Defining Rating Group information") - if ChargingRules['rating_group'] is not None: - RatingGroup = self.generate_avp(432, 40, format(int(ChargingRules['rating_group']),"x").zfill(8)) #Rating-Group-ID - else: - RatingGroup = '' - DiameterLogger.info("Defined Rating Group " + str(ChargingRules['rating_group'])) - - - #Complete Charging Rule Definition - DiameterLogger.info("Collating ChargingRuleDef") - ChargingRuleDef = Charging_Rule_Name + Flow_Information + Flow_Status + QoS_Information + Precedence + RatingGroup - ChargingRuleDef = self.generate_vendor_avp(1003, "c0", 10415, ChargingRuleDef) - - #Charging Rule Install - DiameterLogger.info("Collating Charging-Rule-Install") - return self.generate_vendor_avp(1001, "c0", 10415, ChargingRuleDef) - - def Get_IMS_Subscriber_Details_from_AVP(self, username): - #Feed the Username AVP with Tel URI, SIP URI and either MSISDN or IMSI and this returns user data - username = binascii.unhexlify(username).decode('utf-8') - DiameterLogger.info("Username AVP is present, value is " + str(username)) - username = username.split('@')[0] #Strip Domain to get User part - username = username[4:] #Strip tel: or sip: prefix - #Determine if dealing with IMSI or MSISDN - if (len(username) == 15) or (len(username) == 16): - DiameterLogger.debug("We have an IMSI: " + str(username)) - ims_subscriber_details = database.Get_IMS_Subscriber(imsi=username) - else: - DiameterLogger.debug("We have an MSISDN: " + str(username)) - ims_subscriber_details = database.Get_IMS_Subscriber(msisdn=username) - DiameterLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) - return ims_subscriber_details - 
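An editorial aside (not part of the diff): the user-part extraction that `Get_IMS_Subscriber_Details_from_AVP()` above performs before its database lookup can be sketched standalone. The helper name below is invented for illustration.

```python
# Illustrative sketch only - decode a hex Username AVP, strip the domain and
# the sip:/tel: prefix, as the function above does before the DB lookup.
import binascii

def user_part_from_username_avp(avp_hex: str) -> str:
    username = binascii.unhexlify(avp_hex).decode('utf-8')
    username = username.split('@')[0]   # strip domain to get user part
    return username[4:]                 # strip 'sip:' or 'tel:' prefix

avp_hex = binascii.hexlify(b'sip:001021234567890@ims.mnc001.mcc001.3gppnetwork.org').decode()
user = user_part_from_username_avp(avp_hex)
assert len(user) in (15, 16)  # 15-16 digits is treated as an IMSI, shorter as an MSISDN
```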
- - def Generate_Prom_Stats(self): - DiameterLogger.debug("Called Generate_Prom_Stats") - try: - prom_ims_subs_value = len(database.Get_Served_IMS_Subscribers(get_local_users_only=True)) - prom_ims_subs.set(prom_ims_subs_value) - prom_mme_subs_value = len(database.Get_Served_Subscribers(get_local_users_only=True)) - prom_mme_subs.set(prom_mme_subs_value) - prom_pcrf_subs_value = len(database.Get_Served_PCRF_Subscribers(get_local_users_only=True)) - prom_pcrf_subs.set(prom_pcrf_subs_value) - DiameterLogger.debug("Generated Prometheus Stats for IMS Subscribers") - except Exception as e: - DiameterLogger.debug("Failed to generate Prometheus Stats for IMS Subscribers") - DiameterLogger.debug(e) - - return - - - #### Diameter Answers #### - - #Capabilities Exchange Answer - def Answer_257(self, packet_vars, avps, recv_ip): - avp = '' #Initiate empty var AVP - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - for avps_to_check in avps: #Only include AVP 278 (Origin State) if initial request included it - if avps_to_check['avp_code'] == 278: - avp += self.generate_avp(278, 40, self.AVP_278_Origin_State_Incriment(avps)) #Origin State (Has to be incremented (Handled by AVP_278_Origin_State_Incriment)) - for host in yaml_config['hss']['bind_ip']: #Loop through all IPs from Config and add to response - avp += self.generate_avp(257, 40, self.ip_to_hex(host)) #Host-IP-Address (On Linux this is the IP defined in the hosts file for localhost) - avp += self.generate_avp(266, 40, "00000000") #Vendor-Id - avp += self.generate_avp(269, "00", self.ProductName) #Product-Name - - avp += self.generate_avp(267, 40, "000027d9") #Firmware-Revision - avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP) - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a) - avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP) - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777216),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Cx) - avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP) - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777252),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S13) - avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP) - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777291),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (SLh) - avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP) - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777217),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Sh) - avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP) - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777236),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Rx) - avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP) - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777238),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Gx) - avp += self.generate_avp(258, 40, format(int(16777238),"x").zfill(8)) #Auth-Application-ID - Diameter Gx - avp += self.generate_avp(258, 40, format(int(10),"x").zfill(8)) #Auth-Application-ID - Diameter CER - avp += self.generate_avp(265, 40, format(int(5535),"x").zfill(8)) #Supported-Vendor-ID (3GPP2) - avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP) - avp += self.generate_avp(265, 40, format(int(13019),"x").zfill(8)) #Supported-Vendor-ID 13019 (ETSI) - - response = self.generate_diameter_packet("01", "00", 257, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated CEA") - return response - - #Device Watchdog Answer - def Answer_280(self, packet_vars, avps): - - avp = '' #Initiate empty var AVP - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - for avps_to_check in avps: #Only include AVP 278 (Origin State) if initial request included it - if avps_to_check['avp_code'] == 278: - avp += self.generate_avp(278, 40, self.AVP_278_Origin_State_Incriment(avps)) #Origin State (Has to be incremented (Handled by AVP_278_Origin_State_Incriment)) - response = self.generate_diameter_packet("01", "00", 280, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated DWA") - originHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP - originHost = binascii.unhexlify(originHost).decode('utf-8') #Format it - return response - - #Disconnect Peer Answer - def Answer_282(self, packet_vars, avps): - avp = '' #Initiate empty var AVP - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - response = self.generate_diameter_packet("01", "00", 282, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated DPA") - return response - 
- #3GPP S6a/S6d Update Location Answer - def Answer_16777251_316(self, packet_vars, avps): - avp = '' #Initiate empty var AVP - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - - #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M- - VendorSpecificApplicationId = '' - VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777251),"x").zfill(8)) #Auth-Application-ID (S6a) - avp += self.generate_avp(260, 40, VendorSpecificApplicationId) #AVP: Auth-Application-Id(258) l=12 f=-M- val=3GPP S6a/S6d (16777251) - - - #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP - SupportedFeatures = '' - SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - SupportedFeatures += self.generate_vendor_avp(629, 80, 10415, self.int_to_hex(1, 4)) #Feature-List ID - SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "1c000607") #Feature-List Flags - avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP - - - #APNs from DB - APN_Configuration = '' - imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request - imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI - try: - subscriber_details = database.Get_Subscriber(imsi=imsi) #Get subscriber details - DiameterLogger.debug("Got back subscriber_details: " + str(subscriber_details)) - - if subscriber_details['enabled'] == 0: - DiameterLogger.info("Subscriber is disabled") - - - - #Experimental Result AVP(Response Code for Failure) - avp_experimental_result = '' - avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code: DIAMETER_ERROR_USER_UNKNOWN (5001) - avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) - - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State - DiameterLogger.debug("Successfully Generated ULA for disabled Sub") - response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) - return response - - except ValueError as e: - DiameterLogger.error("Failed to get data back from database for IMSI " + str(imsi)) - DiameterLogger.error("Error is " + str(e)) - DiameterLogger.error("Responding with DIAMETER_ERROR_USER_UNKNOWN") - avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4), avps=avps, packet_vars=packet_vars) #Result Code - response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.info("Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN") - return response - except Exception as ex: - template = "An exception of type {0} occurred. 
Arguments:\n{1!r}" - message = template.format(type(ex).__name__, ex.args) - DiameterLogger.critical(message) - DiameterLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) - raise - - #Store MME Location into Database - OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP - OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it - OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP - OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it - DiameterLogger.debug("Subscriber is served by MME " + str(OriginHost) + " at realm " + str(OriginRealm)) - - #Find Remote Peer we need to address CLRs through - try: #Check if we have a Record-Route set as that's where we'll need to send the response - remote_peer = self.get_avp_data(avps, 282)[-1] #Get the last Record-Route header - remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it - except: #If we don't have a Record-Route set, we'll send the response to the OriginHost - remote_peer = OriginHost - remote_peer = remote_peer + ";" + str(yaml_config['hss']['OriginHost']) - DiameterLogger.debug("Remote Peer is " + str(remote_peer)) - - database.Update_Serving_MME(imsi=imsi, serving_mme=OriginHost, serving_mme_peer=remote_peer, serving_mme_realm=OriginRealm) - - - #Boilerplate AVPs - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State - avp += self.generate_vendor_avp(1406, "c0", 10415, "00000001") #ULA Flags - - - #Subscription Data: - subscription_data = '' - subscription_data += self.generate_vendor_avp(1426, "c0", 10415, "00000000") #Access Restriction Data - subscription_data += self.generate_vendor_avp(1424, "c0", 10415, "00000000") #Subscriber-Status (SERVICE_GRANTED) - subscription_data += self.generate_vendor_avp(1417, "c0", 10415, self.int_to_hex(int(subscriber_details['nam']), 4)) #Network-Access-Mode (PACKET_AND_CIRCUIT) - - #AMBR is a sub-AVP of Subscription Data - AMBR = '' #Initiate empty var AVP for AMBR - ue_ambr_ul = int(subscriber_details['ue_ambr_ul']) - ue_ambr_dl = int(subscriber_details['ue_ambr_dl']) - AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(ue_ambr_ul, 4)) #Max-Requested-Bandwidth-UL - AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(ue_ambr_dl, 4)) #Max-Requested-Bandwidth-DL - subscription_data += self.generate_vendor_avp(1435, "c0", 10415, AMBR) #Add AMBR AVP in two sub-AVPs - - - subscription_data += self.generate_vendor_avp(1619, "80", 10415, self.int_to_hex(int(subscriber_details['subscribed_rau_tau_timer']), 4)) #Subscribed-Periodic-RAU-TAU-Timer (from DB) - - - #APN Configuration Profile is a sub-AVP of Subscription Data - APN_Configuration_Profile = '' - APN_Configuration_Profile += self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(1, 4)) #Context Identifier for default APN (First APN is default in our case) - APN_Configuration_Profile += self.generate_vendor_avp(1428, "c0", 10415, self.int_to_hex(0, 4)) #All-APN-Configurations-Included-Indicator - - #Split the APN string into a list - apn_list = subscriber_details['apn_list'].split(',') - DiameterLogger.debug("Current APN List: " + str(apn_list)) - #Remove the default APN from the list - try: - apn_list.remove(str(subscriber_details['default_apn'])) - except: - DiameterLogger.debug("Failed to remove default APN (" + 
str(subscriber_details['default_apn']) + ") from APN List") - pass - #Add default APN in first position - apn_list.insert(0, str(subscriber_details['default_apn'])) - - DiameterLogger.debug("APN list: " + str(apn_list)) - APN_context_identifer_count = 1 - for apn_id in apn_list: - #Per APN Setup - DiameterLogger.debug("Processing APN ID " + str(apn_id)) - try: - apn_data = database.Get_APN(apn_id) - except: - DiameterLogger.error("Failed to get APN " + str(apn_id)) - continue - APN_Service_Selection = self.generate_avp(493, "40", self.string_to_hex(str(apn_data['apn']))) - - DiameterLogger.debug("Setting APN Configuration Profile") - #Sub-AVPs of APN Configuration Profile - APN_context_identifer = self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(APN_context_identifer_count, 4)) - APN_PDN_type = self.generate_vendor_avp(1456, "c0", 10415, self.int_to_hex(int(apn_data['ip_version']), 4)) - - DiameterLogger.debug("Setting APN AMBR") - #AMBR - AMBR = '' #Initiate empty var AVP for AMBR - apn_ambr_ul = int(apn_data['apn_ambr_ul']) - apn_ambr_dl = int(apn_data['apn_ambr_dl']) - AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL - AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL - APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) - - DiameterLogger.debug("Setting APN Allocation-Retention-Priority") - #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP - AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) - AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4)) - AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_vulnerability']), 4)) - AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) - AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_data['qci']), 4)) - APN_EPS_Subscribed_QoS_Profile = self.generate_vendor_avp(1431, "c0", 10415, AVP_QoS + AVP_ARP) - - #Try static IP allocation - try: - subscriber_routing_dict = database.Get_SUBSCRIBER_ROUTING(subscriber_id=subscriber_details['subscriber_id'], apn_id=apn_id) #Get subscriber routing details - DiameterLogger.info("Got static UE IP " + str(subscriber_routing_dict)) - DiameterLogger.debug("Found static IP for UE " + str(subscriber_routing_dict['ip_address'])) - Served_Party_Address = self.generate_vendor_avp(848, "c0", 10415, self.ip_to_hex(subscriber_routing_dict['ip_address'])) - except Exception as E: - DiameterLogger.debug("Error getting static UE IP: " + str(E)) - Served_Party_Address = "" - - - #if 'PDN_GW_Allocation_Type' in apn_profile: - # DiameterLogger.info("PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type'])) - # PDN_GW_Allocation_Type = self.generate_vendor_avp(1438, 'c0', 10415, self.int_to_hex(int(apn_profile['PDN_GW_Allocation_Type']), 4)) - # DiameterLogger.info("PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type)) - # else: - # PDN_GW_Allocation_Type = '' - # if 'VPLMN_Dynamic_Address_Allowed' in apn_profile: - # DiameterLogger.info("VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed'])) - # VPLMN_Dynamic_Address_Allowed = self.generate_vendor_avp(1432, 'c0', 10415, 
self.int_to_hex(int(apn_profile['VPLMN_Dynamic_Address_Allowed']), 4)) - # DiameterLogger.info("VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed)) - # else: - # VPLMN_Dynamic_Address_Allowed = '' - PDN_GW_Allocation_Type = '' - VPLMN_Dynamic_Address_Allowed = '' - - #If static SMF / PGW-C defined - if apn_data['pgw_address'] is not None: - DiameterLogger.info("MIP6-Agent-Info present (Static SMF/PGW-C), value " + str(apn_data['pgw_address'])) - MIP_Home_Agent_Address = self.generate_avp(334, '40', self.ip_to_hex(apn_data['pgw_address'])) - MIP6_Agent_Info = self.generate_avp(486, '40', MIP_Home_Agent_Address) - else: - MIP6_Agent_Info = '' - - APN_Configuration_AVPS = APN_context_identifer + APN_PDN_type + APN_AMBR + APN_Service_Selection \ - + APN_EPS_Subscribed_QoS_Profile + Served_Party_Address + MIP6_Agent_Info + PDN_GW_Allocation_Type + VPLMN_Dynamic_Address_Allowed - - APN_Configuration += self.generate_vendor_avp(1430, "c0", 10415, APN_Configuration_AVPS) - - #Increment Context Identifier count to keep track of how many APN Profiles are returned - APN_context_identifer_count = APN_context_identifer_count + 1 - DiameterLogger.debug("Completed processing APN ID " + str(apn_id)) - - subscription_data += self.generate_vendor_avp(1429, "c0", 10415, APN_Configuration_Profile + APN_Configuration) - - try: - DiameterLogger.debug("MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA") - msisdn_avp = self.generate_vendor_avp(701, 'c0', 10415, self.TBCD_encode(str(subscriber_details['msisdn']))) #MSISDN - DiameterLogger.debug(msisdn_avp) - subscription_data += msisdn_avp - except Exception as E: - DiameterLogger.error("Failed to populate MSISDN in ULA due to error " + str(E)) - - if 'RAT_freq_priorityID' in subscriber_details: - DiameterLogger.debug("RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA") - rat_freq_priorityID = self.generate_vendor_avp(1440, "C0", 10415, self.int_to_hex(int(subscriber_details['RAT_freq_priorityID']), 4)) #RAT-Frequency-Selection-Priority ID - DiameterLogger.debug("Adding rat_freq_priorityID: " + str(rat_freq_priorityID)) - subscription_data += rat_freq_priorityID - - if 'charging_characteristics' in subscriber_details: - DiameterLogger.debug("3gpp-charging-characteristics " + str(subscriber_details['charging_characteristics']) + " - Adding in ULA") - _3gpp_charging_characteristics = self.generate_vendor_avp(13, "80", 10415, str(subscriber_details['charging_characteristics'])) - subscription_data += _3gpp_charging_characteristics - DiameterLogger.debug("Adding _3gpp_charging_characteristics: " + str(_3gpp_charging_characteristics)) - - #ToDo - Fix this - # if 'APN_OI_replacement' in subscriber_details: - # DiameterLogger.debug("APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA") - # subscription_data += self.generate_vendor_avp(1427, "C0", 10415, self.string_to_hex(str(subscriber_details['APN_OI_replacement']))) - - avp += self.generate_vendor_avp(1400, "c0", 10415, subscription_data) #Subscription-Data - - response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - - DiameterLogger.debug("Successfully Generated ULA") - return response - 
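An editorial aside (not part of the diff): the MSISDN AVP in the ULA above relies on `TBCD_encode()`. PyHSS's own implementation may differ in detail; the sketch below is the textbook "swapped nibble" TBCD transform, with an invented function name.

```python
# Illustrative sketch only - textbook TBCD encoding: pad odd-length numbers
# with 'f', then swap each pair of digits (nibble swap).
def tbcd_encode(digits: str) -> str:
    if len(digits) % 2:
        digits += 'f'
    return ''.join(digits[i + 1] + digits[i] for i in range(0, len(digits), 2))

assert tbcd_encode('447700900123') == '447700091032'
```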
- #3GPP S6a/S6d Authentication Information Answer - def Answer_16777251_318(self, packet_vars, avps): - imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request - imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI - plmn = self.get_avp_data(avps, 1407)[0] #Get PLMN from Visited-PLMN-Id AVP in request - avp = '' - try: - subscriber_details = database.Get_Subscriber(imsi=imsi) #Get subscriber details - - if subscriber_details['enabled'] == 0: - DiameterLogger.info("Subscriber is disabled") - avp += self.generate_avp(268, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #Result Code - prom_diam_auth_event_count.labels( - diameter_application_id = 16777251, - diameter_cmd_code = 318, - event='Disabled User', - imsi_prefix = str(imsi[0:6]), - ).inc() - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - - #Experimental Result AVP(Response Code for Failure) - avp_experimental_result = '' - avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code: DIAMETER_ERROR_USER_UNKNOWN (5001) - avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) - - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a) - response = self.generate_diameter_packet("01", "40", 318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated AIA for disabled Sub") - DiameterLogger.debug(response) - return response - - except ValueError as e: - DiameterLogger.info("Minor error getting subscriber details for IMSI " + str(imsi)) - DiameterLogger.info(e) - #Handle if the subscriber is not present in HSS, return "DIAMETER_ERROR_USER_UNKNOWN" - prom_diam_auth_event_count.labels( - diameter_application_id = 16777251, - diameter_cmd_code = 318, - event='Unknown User', - imsi_prefix = str(imsi[0:6]), - ).inc() - - DiameterLogger.info("Subscriber " + str(imsi) + " is unknown in database") - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - - #Experimental Result AVP(Response Code for Failure) - avp_experimental_result = '' - avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code: DIAMETER_ERROR_USER_UNKNOWN (5001) - avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) - - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a) - response = self.generate_diameter_packet("01", "40", 318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - return response - except Exception as ex: - template = "An exception of type {0} occurred. 
Arguments:\n{1!r}" - message = template.format(type(ex).__name__, ex.args) - DiameterLogger.critical(message) - DiameterLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) - raise - - - - requested_vectors = 1 - for avp in avps: - if avp['avp_code'] == 1408: - DiameterLogger.debug("AVP: Requested-EUTRAN-Authentication-Info(1408) l=44 f=VM- vnd=TGPP") - EUTRAN_Authentication_Info = avp['misc_data'] - DiameterLogger.debug("EUTRAN_Authentication_Info is " + str(EUTRAN_Authentication_Info)) - for sub_avp in EUTRAN_Authentication_Info: - #If resync request - if sub_avp['avp_code'] == 1411: - DiameterLogger.debug("Re-Synchronization required - SQN is out of sync") - prom_diam_auth_event_count.labels( - diameter_application_id = 16777251, - diameter_cmd_code = 318, - event='Resync', - imsi_prefix = str(imsi[0:6]), - ).inc() - auts = str(sub_avp['misc_data'])[32:] - rand = str(sub_avp['misc_data'])[:32] - rand = binascii.unhexlify(rand) - #Calculate correct SQN - database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand) - - #Get number of requested vectors - if sub_avp['avp_code'] == 1410: - DiameterLogger.debug("Raw value of requested vectors is " + str(sub_avp['misc_data'])) - requested_vectors = int(sub_avp['misc_data'], 16) - if requested_vectors >= 32: - DiameterLogger.info("Client has requested " + str(requested_vectors) + " vectors, limiting this to 32") - requested_vectors = 32 - - DiameterLogger.debug("Generating " + str(requested_vectors) + " vectors as requested") - eutranvector_complete = '' - while requested_vectors != 0: - DiameterLogger.debug("Generating vector number " + str(requested_vectors)) - plmn = self.get_avp_data(avps, 1407)[0] #Get PLMN from request - vector_dict = database.Get_Vectors_AuC(subscriber_details['auc_id'], "air", plmn=plmn) - eutranvector = '' #This goes into the payload of the E-UTRAN-Vector AVP (1414) - eutranvector += self.generate_vendor_avp(1419, "c0", 10415, self.int_to_hex(requested_vectors, 4)) #Item-Number - eutranvector += self.generate_vendor_avp(1447, "c0", 10415, vector_dict['rand']) #RAND - eutranvector += self.generate_vendor_avp(1448, "c0", 10415, vector_dict['xres']) #XRES - eutranvector += self.generate_vendor_avp(1449, "c0", 10415, vector_dict['autn']) #AUTN - eutranvector += self.generate_vendor_avp(1450, "c0", 10415, vector_dict['kasme']) #KASME - - requested_vectors = requested_vectors - 1 - eutranvector_complete += self.generate_vendor_avp(1414, "c0", 10415, eutranvector) #Put E-UTRAN vectors in E-UTRAN-Vector AVP - - avp = '' #Initiate empty var AVP - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set - avp += self.generate_vendor_avp(1413, "c0", 10415, eutranvector_complete) #Authentication-Info (3GPP) - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") - #avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a) - - response = self.generate_diameter_packet("01", "40", 
318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated AIA") - DiameterLogger.debug(response) - return response - - #Purge UE Answer (PUA) - def Answer_16777251_321(self, packet_vars, avps): - - imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request - imsi = binascii.unhexlify(imsi).decode('utf-8') - - avp = '' - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a) - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained) - - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - - #1442 - PUA-Flags - avp += self.generate_vendor_avp(1442, "c0", 10415, self.int_to_hex(1, 4)) - - #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP - SupportedFeatures = '' - SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - SupportedFeatures += self.generate_vendor_avp(629, 80, 10415, self.int_to_hex(1, 4)) #Feature-List ID - SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "1c000607") #Feature-List Flags - avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP - - - response = self.generate_diameter_packet("01", "40", 321, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - - - database.Update_Serving_MME(imsi, None) - DiameterLogger.debug("Successfully Generated PUA") - return response - - #Notify Answer (NOA) - def Answer_16777251_323(self, packet_vars, avps): - avp = '' - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a) - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained) - - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - - #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP - SupportedFeatures = '' - SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - SupportedFeatures += self.generate_avp(258, 40, format(int(16777251),"x").zfill(8)) #Auth-Application-ID (S6a) - avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP - response = self.generate_diameter_packet("01", "40", 323, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated NOA") - return response - 
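An editorial aside (not part of the diff): the Gx Credit Control Answer handler below branches on CC-Request-Type. The mapping is standard (RFC 4006 / 3GPP TS 29.212); the dictionary and helper below are invented for illustration.

```python
# Illustrative sketch only - the CC-Request-Type values the CCA handler
# below branches on.
CC_REQUEST_TYPE = {
    1: "INITIAL_REQUEST",      # session setup: store serving PGW, push charging rules
    2: "UPDATE_REQUEST",
    3: "TERMINATION_REQUEST",  # session teardown: clear serving PGW
    4: "EVENT_REQUEST",
}

def describe_ccr(cc_request_type_hex: str) -> str:
    return CC_REQUEST_TYPE.get(int(cc_request_type_hex, 16), "UNKNOWN")

assert describe_ccr("00000001") == "INITIAL_REQUEST"
```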
- - #3GPP Gx Credit Control Answer - def Answer_16777238_272(self, packet_vars, avps): - CC_Request_Type = self.get_avp_data(avps, 416)[0] - CC_Request_Number = self.get_avp_data(avps, 415)[0] - #Called Station ID - DiameterLogger.debug("Attempting to find APN in CCR") - apn = bytes.fromhex(self.get_avp_data(avps, 30)[0]).decode('utf-8') - DiameterLogger.debug("CCR for APN " + str(apn)) - - OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP - OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it - - OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP - OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it - - try: #Check if we have a Record-Route set as that's where we'll need to send the response - remote_peer = self.get_avp_data(avps, 282)[-1] #Get the last Record-Route header - remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it - except: #If we don't have a Record-Route set, we'll send the response to the OriginHost - remote_peer = OriginHost - DiameterLogger.debug("Remote Peer is " + str(remote_peer)) - remote_peer = remote_peer + ";" + str(yaml_config['hss']['OriginHost']) - - avp = '' #Initiate empty var AVP - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set - avp += self.generate_avp(258, 40, "01000016") #Auth-Application-Id (3GPP Gx 16777238) - avp += self.generate_avp(416, 40, format(int(CC_Request_Type),"x").zfill(8)) #CC-Request-Type - avp += self.generate_avp(415, 40, format(int(CC_Request_Number),"x").zfill(8)) #CC-Request-Number - - - #Get Subscriber info from Subscription ID - for SubscriptionIdentifier in self.get_avp_data(avps, 443): - for UniqueSubscriptionIdentifier in SubscriptionIdentifier: - DiameterLogger.debug("Evaluating UniqueSubscriptionIdentifier AVP " + str(UniqueSubscriptionIdentifier) + " to find IMSI") - if UniqueSubscriptionIdentifier['avp_code'] == 444: - imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') - DiameterLogger.debug("Found IMSI " + str(imsi)) - - DiameterLogger.info("SubscriptionID: " + str(self.get_avp_data(avps, 443))) - try: - DiameterLogger.info("Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database") #Get subscriber details - ChargingRules = database.Get_Charging_Rules(imsi=imsi, apn=apn) - DiameterLogger.info("Got Charging Rules: " + str(ChargingRules)) - except Exception as E: - #Handle if the subscriber is not present in HSS, return "DIAMETER_ERROR_USER_UNKNOWN" - DiameterLogger.debug(E) - DiameterLogger.debug("Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists") - - if int(CC_Request_Type) == 1: - DiameterLogger.info("Request type for CCA is 1 - Initial") - - #Get UE IP - try: - ue_ip = self.get_avp_data(avps, 8)[0] - ue_ip = str(self.hex_to_ip(ue_ip)) - except Exception as E: - DiameterLogger.error("Failed to get UE IP") - DiameterLogger.error(E) - ue_ip = 'Failed to Decode / Get UE IP' - - #Store PGW location into Database (remote_peer already includes our OriginHost from above) - database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=OriginHost, subscriber_routing=str(ue_ip), serving_pgw_realm=OriginRealm, serving_pgw_peer=remote_peer) - - #Supported-Features(628) (Gx feature list) - avp += self.generate_vendor_avp(628, "80", 10415, "0000010a4000000c000028af0000027580000010000028af000000010000027680000010000028af0000000b") - - #Default EPS Bearer QoS (From database with fallback source CCR-I) - try: - apn_data = ChargingRules['apn_data'] - 
DiameterLogger.debug("Setting APN AMBR") - #AMBR - AMBR = '' #Initiate empty var AVP for AMBR - apn_ambr_ul = int(apn_data['apn_ambr_ul']) - apn_ambr_dl = int(apn_data['apn_ambr_dl']) - AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL - AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL - APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) - - DiameterLogger.debug("Setting APN Allocation-Retention-Priority") - #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP - AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) - AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4)) - AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_vulnerability']), 4)) - AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) - AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_data['qci']), 4)) - avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP) - except Exception as E: - DiameterLogger.error(E) - DiameterLogger.error("Failed to populate default_EPS_QoS from DB for sub " + str(imsi)) - default_EPS_QoS = self.get_avp_data(avps, 1049)[0][8:] - avp += self.generate_vendor_avp(1049, "80", 10415, default_EPS_QoS) - - - DiameterLogger.info("Creating QoS Information") - #QoS-Information - try: - apn_data = ChargingRules['apn_data'] - apn_ambr_ul = int(apn_data['apn_ambr_ul']) - apn_ambr_dl = int(apn_data['apn_ambr_dl']) - QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4)) - QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4)) - DiameterLogger.info("Created both QoS AVPs from data from Database") - DiameterLogger.info("Populated QoS_Information") - avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) - except Exception as E: - DiameterLogger.error("Failed to get QoS information dynamically for sub " + str(imsi)) - DiameterLogger.error(E) - - QoS_Information = '' - for AMBR_Part in self.get_avp_data(avps, 1016)[0]: - DiameterLogger.debug(AMBR_Part) - AMBR_AVP = self.generate_vendor_avp(AMBR_Part['avp_code'], "80", 10415, AMBR_Part['misc_data'][8:]) - QoS_Information += AMBR_AVP - DiameterLogger.debug("QoS_Information added " + str(AMBR_AVP)) - avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) - DiameterLogger.debug("QoS information set statically") - - DiameterLogger.info("Added to AVP List") - DiameterLogger.debug("QoS Information: " + str(QoS_Information)) - - #If database returned an existing ChargingRule definition add ChargingRule to CCA-I - if ChargingRules and ChargingRules['charging_rules'] is not None: - try: - DiameterLogger.debug(ChargingRules) - for individual_charging_rule in ChargingRules['charging_rules']: - DiameterLogger.debug("Processing Charging Rule: " + str(individual_charging_rule)) - avp += self.Charging_Rule_Generator(ChargingRules=individual_charging_rule, ue_ip=ue_ip) - - except Exception as E: - DiameterLogger.debug("Error in populating dynamic charging rules: " + str(E)) - - elif int(CC_Request_Type) == 3: - DiameterLogger.info("Request type for CCA is 3 - Termination") - 
database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=None, subscriber_routing=None) - - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - return response - - #3GPP Cx User Authorization Answer - def Answer_16777216_300(self, packet_vars, avps): - - avp = '' #Initiate empty var AVP #Session-ID - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained) - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx - - - OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP - OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it - OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP - OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it - - try: #Check if we have a Record-Route set as that's where we'll need to send the response - remote_peer = self.get_avp_data(avps, 282)[-1] #Get the last Record-Route header - remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it - except: #If we don't have a Record-Route set, we'll send the response to the OriginHost - remote_peer = OriginHost - DiameterLogger.debug("Remote Peer is " + str(remote_peer)) - - try: - DiameterLogger.info("Checking if username present") - username = self.get_avp_data(avps, 1)[0] - username = binascii.unhexlify(username).decode('utf-8') - DiameterLogger.info("Username AVP is present, value is " + str(username)) - imsi = username.split('@')[0] #Strip Domain - domain = username.split('@')[1] #Get Domain Part - DiameterLogger.debug("Extracted imsi: " + str(imsi) + " now checking backend for this IMSI") - ims_subscriber_details = database.Get_IMS_Subscriber(imsi=imsi) - except Exception as E: - DiameterLogger.error("Threw Exception: " + str(E)) - DiameterLogger.error("No known MSISDN or IMSI in Answer_16777216_300() input") - prom_diam_auth_event_count.labels( - diameter_application_id = 16777216, - diameter_cmd_code = 300, - event='Unknown User', - imsi_prefix = str(imsi[0:6]), - ).inc() - result_code = 5001 #IMS User Unknown - #Experimental Result AVP - avp_experimental_result = '' - avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code - avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) - response = self.generate_diameter_packet("01", "40", 300, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - return response - - #Determine User-Authorization-Type & Store - user_authorization_type_avp_data = self.get_avp_data(avps, 623) - if 
user_authorization_type_avp_data: - try: - User_Authorization_Type = int(user_authorization_type_avp_data[0]) - DiameterLogger.debug("User_Authorization_Type is: " + str(User_Authorization_Type)) - if (User_Authorization_Type == 1): - DiameterLogger.debug("This is Deregister") - database.Update_Serving_CSCF(imsi, serving_cscf=None) - #Populate S-CSCF Address - avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(ims_subscriber_details['scscf'])),'ascii')) - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - response = self.generate_diameter_packet("01", "40", 300, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - return response - - except Exception as E: - DiameterLogger.debug("Failed to get User_Authorization_Type AVP & Update_Serving_CSCF error: " + str(E)) - DiameterLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) - if ims_subscriber_details['scscf'] is not None: - DiameterLogger.debug("Already has S-CSCF Assigned from DB: " + str(ims_subscriber_details['scscf'])) - avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(ims_subscriber_details['scscf'])),'ascii')) - experimental_avp = '' - experimental_avp += self.generate_avp(266, 40, format(int(10415),"x").zfill(8)) #3GPP Vendor ID - experimental_avp += self.generate_avp(298, 40, format(int(2002),"x").zfill(8), avps=avps, packet_vars=packet_vars) #DIAMETER_SUBSEQUENT_REGISTRATION (2002) - avp += self.generate_avp(297, 40, experimental_avp) #Experimental-Result - else: - DiameterLogger.debug("No S-CSCF Assigned from DB") - if 'scscf_pool' in yaml_config['hss']: - try: - scscf = random.choice(yaml_config['hss']['scscf_pool']) - DiameterLogger.debug("Randomly picked S-CSCF address " + str(scscf) + " from pool") - avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii')) - except Exception as E: - avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - DiameterLogger.info("Using generated S-CSCF Address as failed to source from list due to " + str(E)) - else: - avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - DiameterLogger.info("Using generated S-CSCF Address as none set in scscf_pool in config") - experimental_avp = '' - experimental_avp += self.generate_avp(266, 40, format(int(10415),"x").zfill(8)) #3GPP Vendor ID - experimental_avp += self.generate_avp(298, 40, format(int(2001),"x").zfill(8), avps=avps, packet_vars=packet_vars) #DIAMETER_FIRST_REGISTRATION (2001) - avp += self.generate_avp(297, 40, experimental_avp) #Experimental-Result - - response = self.generate_diameter_packet("01", "40", 300, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - - return response - 
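An editorial aside (not part of the diff): the S-CSCF selection order used by the UAA handler above can be summarised in a few lines. The helper name and defaults below are invented; the fallback order (stored S-CSCF, then a random pick from `scscf_pool`, then a generated default) mirrors the code above.

```python
# Illustrative sketch only - prefer the S-CSCF already stored against the
# subscriber, then a random pick from the configured pool, then a generated
# default address.
import random

def pick_scscf(stored_scscf, scscf_pool, mnc="001", mcc="001"):
    if stored_scscf:
        return stored_scscf
    if scscf_pool:
        return random.choice(scscf_pool)
    return f"sip:scscf.ims.mnc{mnc.zfill(3)}.mcc{mcc.zfill(3)}.3gppnetwork.org"

assert pick_scscf(None, []) == "sip:scscf.ims.mnc001.mcc001.3gppnetwork.org"
```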
- - #3GPP Cx Server Assignment Answer - def Answer_16777216_301(self, packet_vars, avps): - avp = '' #Initiate empty var AVP #Session-ID - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained) - - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx - - OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP - OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it - - OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP - OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it - - #Find Remote Peer we need to address requests through - try: #Check if we have a Record-Route set as that's where we'll need to send the response - remote_peer = self.get_avp_data(avps, 282)[-1] #Get the last Record-Route header - remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it - except: #If we don't have a Record-Route set, we'll send the response to the OriginHost - remote_peer = OriginHost - DiameterLogger.debug("Remote Peer is " + str(remote_peer)) - - try: - DiameterLogger.info("Checking if username present") - username = self.get_avp_data(avps, 601)[0] - ims_subscriber_details = self.Get_IMS_Subscriber_Details_from_AVP(username) - DiameterLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) - imsi = ims_subscriber_details['imsi'] - domain = "ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org" - except Exception as E: - DiameterLogger.error("Threw Exception: " + str(E)) - DiameterLogger.error("No known MSISDN or IMSI in Answer_16777216_301() input") - result_code = 5005 - #Experimental Result AVP - avp_experimental_result = '' - avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code - avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) - response = self.generate_diameter_packet("01", "40", 301, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - return response - - avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(str(imsi) + '@' + str(domain))),'ascii')) - #Cx-User-Data (XML) - - #This loads a Jinja XML template as the default iFC - templateLoader = jinja2.FileSystemLoader(searchpath="./") - templateEnv = jinja2.Environment(loader=templateLoader) - DiameterLogger.debug("Loading iFC from path " + str(ims_subscriber_details['ifc_path'])) - template = templateEnv.get_template(ims_subscriber_details['ifc_path']) - - #These variables are passed to the template for use - ims_subscriber_details['mnc'] = self.MNC.zfill(3) - ims_subscriber_details['mcc'] = self.MCC.zfill(3) - - xmlbody = template.render(iFC_vars=ims_subscriber_details) #This is where to put args to the template renderer - avp += self.generate_vendor_avp(606, "c0", 10415, str(binascii.hexlify(str.encode(xmlbody)),'ascii')) - - #Charging Information - #avp += self.generate_vendor_avp(618, "c0", 10415, "0000026dc000001b000028af7072695f6363665f6164647265737300") - #avp += self.generate_avp(268, 40, "000007d1") #DIAMETER_SUCCESS - - #Determine SAR Type & Store - Server_Assignment_Type_Hex = self.get_avp_data(avps, 614)[0] - Server_Assignment_Type = self.hex_to_int(Server_Assignment_Type_Hex) - 
DiameterLogger.debug("Server-Assignment-Type is: " + str(Server_Assignment_Type)) - ServingCSCF = self.get_avp_data(avps, 602)[0] #Get Server-Name (S-CSCF) from AVP - ServingCSCF = binascii.unhexlify(ServingCSCF).decode('utf-8') #Format it - DiameterLogger.debug("Subscriber is served by S-CSCF " + str(ServingCSCF)) - if (Server_Assignment_Type == 1) or (Server_Assignment_Type == 2): - DiameterLogger.debug("SAR is Register / Re-Register") - remote_peer = remote_peer + ";" + str(yaml_config['hss']['OriginHost']) - database.Update_Serving_CSCF(imsi, serving_cscf=ServingCSCF, scscf_realm=OriginRealm, scscf_peer=remote_peer) - else: - DiameterLogger.debug("SAR is not Register") - database.Update_Serving_CSCF(imsi, serving_cscf=None) - - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - - response = self.generate_diameter_packet("01", "40", 301, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - return response - 
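An editorial aside (not part of the diff): the Cx-User-Data XML in the SAA above is produced by rendering a Jinja2 template with subscriber fields and hex-encoding the result into vendor AVP 606. The template text below is invented; only the render-then-hexlify shape mirrors the code above.

```python
# Illustrative sketch only - render an iFC-style template with subscriber
# variables, then hex-encode the XML for the Cx-User-Data AVP (606).
import binascii
import jinja2

template = jinja2.Environment().from_string(
    "<IMSSubscription><PrivateID>{{ iFC_vars.imsi }}@ims.mnc{{ iFC_vars.mnc }}"
    ".mcc{{ iFC_vars.mcc }}.3gppnetwork.org</PrivateID></IMSSubscription>"
)
xmlbody = template.render(iFC_vars={'imsi': '001021234567890', 'mnc': '001', 'mcc': '001'})
avp_606_payload = binascii.hexlify(xmlbody.encode()).decode('ascii')
```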
- - #3GPP Cx Location Information Answer - def Answer_16777216_302(self, packet_vars, avps): - avp = '' #Initiate empty var AVP #Session-ID - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) - avp += self.generate_avp(277, 40, "00000001") #Auth Session State - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx - - - - try: - DiameterLogger.info("Checking if username present") - username = self.get_avp_data(avps, 601)[0] - ims_subscriber_details = self.Get_IMS_Subscriber_Details_from_AVP(username) - if ims_subscriber_details['scscf'] is not None: - DiameterLogger.debug("Got S-CSCF on record for Sub") - #Strip double sip prefix - avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(str(ims_subscriber_details['scscf']))),'ascii')) - else: - DiameterLogger.debug("No S-CSCF assigned - Using S-CSCF Pool") - if 'scscf_pool' in yaml_config['hss']: - try: - scscf = random.choice(yaml_config['hss']['scscf_pool']) - DiameterLogger.debug("Randomly picked S-CSCF address " + str(scscf) + " from pool") - avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii')) - except Exception as E: - avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - DiameterLogger.info("Using generated S-CSCF Address as failed to source from list due to " + str(E)) - else: - avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - DiameterLogger.info("Using generated S-CSCF Address as none set in scscf_pool in config") - except Exception as E: - DiameterLogger.error("Threw Exception: " + str(E)) - DiameterLogger.error("No known MSISDN or IMSI in Answer_16777216_302() input") - result_code = 5001 - prom_diam_auth_event_count.labels( - diameter_application_id = 16777216, - diameter_cmd_code = 302, - event='Unknown User', - imsi_prefix = str(username[0:6]), - ).inc() - #Experimental Result AVP - avp_experimental_result = '' - avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code - avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) - response = self.generate_diameter_packet("01", "40", 302, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - return response - - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001)) - response = self.generate_diameter_packet("01", "40", 302, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - - return response - - #3GPP Cx Multimedia Authentication Answer - def Answer_16777216_303(self, packet_vars, avps): - public_identity = self.get_avp_data(avps, 601)[0] - public_identity = binascii.unhexlify(public_identity).decode('utf-8') - DiameterLogger.debug("Got MAR for public_identity : " + str(public_identity)) - username = self.get_avp_data(avps, 1)[0] - username = binascii.unhexlify(username).decode('utf-8') - imsi = username.split('@')[0] #Strip Domain - domain = username.split('@')[1] #Get Domain Part - DiameterLogger.debug("Got MAR username: " + str(username)) - - avp = '' #Initiate empty var AVP - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx - avp += self.generate_avp(277, 40, "00000001") #Auth Session State - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - - try: - subscriber_details = database.Get_Subscriber(imsi=imsi) #Get subscriber details - except: - #Handle if the subscriber is not present in HSS, return "DIAMETER_ERROR_USER_UNKNOWN" - DiameterLogger.debug("Subscriber " + str(imsi) + " unknown in HSS for MAA") - prom_diam_auth_event_count.labels( - diameter_application_id = 16777216, - diameter_cmd_code = 303, - event='Unknown User', - imsi_prefix = str(username[0:6]), - ).inc() - experimental_result = self.generate_avp(298, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER ERROR - User Unknown) - experimental_result = experimental_result + self.generate_vendor_avp(266, 40, 10415, "") - #Experimental Result (297) - avp += self.generate_avp(297, 40, experimental_result) - response = self.generate_diameter_packet("01", "40", 303, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - return response - - DiameterLogger.debug("Got subscriber data for MAA OK") - - mcc, mnc = imsi[0:3], imsi[3:5] - plmn = self.EncodePLMN(mcc, mnc) - - #Determine if SQN Resync is required & auth type to use - for sub_avp_612 in self.get_avp_data(avps, 612)[0]: - if sub_avp_612['avp_code'] == 610: - DiameterLogger.info("SQN in HSS is out of sync - Performing resync") - auts = str(sub_avp_612['misc_data'])[32:] - rand = str(sub_avp_612['misc_data'])[:32] - rand = binascii.unhexlify(rand) - database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand) - DiameterLogger.debug("Resynced SQN in DB") - prom_diam_auth_event_count.labels( - diameter_application_id = 16777216, - diameter_cmd_code = 303, - event='ReAuth', - 
- if sub_avp_612['avp_code'] == 608:
- DiameterLogger.info("Auth mechanism requested: " + str(sub_avp_612['misc_data']))
- auth_scheme = binascii.unhexlify(sub_avp_612['misc_data']).decode('utf-8')
- DiameterLogger.info("Auth mechanism decoded: " + str(auth_scheme))
-
- DiameterLogger.debug("IMSI is " + str(imsi))
- avp += self.generate_vendor_avp(601, "c0", 10415, str(binascii.hexlify(str.encode(public_identity)),'ascii')) #Public-Identity
- avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(imsi + "@" + domain)),'ascii')) #Username
-
- #Determine Vectors to Generate
- if auth_scheme == "Digest-MD5":
- DiameterLogger.debug("Generating MD5 Challenge")
- vector_dict = database.Get_Vectors_AuC(subscriber_details['auc_id'], "Digest-MD5", username=imsi, plmn=plmn)
- avp_SIP_Item_Number = self.generate_vendor_avp(613, "c0", 10415, format(int(0),"x").zfill(8))
- avp_SIP_Authentication_Scheme = self.generate_vendor_avp(608, "c0", 10415, str(binascii.hexlify(b'Digest-MD5'),'ascii'))
- #Nonce
- avp_SIP_Authenticate = self.generate_vendor_avp(609, "c0", 10415, str(vector_dict['nonce']))
- #Expected Response
- avp_SIP_Authorization = self.generate_vendor_avp(610, "c0", 10415, str(binascii.hexlify(str.encode(vector_dict['SIP_Authenticate'])),'ascii'))
- auth_data_item = avp_SIP_Item_Number + avp_SIP_Authentication_Scheme + avp_SIP_Authenticate + avp_SIP_Authorization
- else:
- DiameterLogger.debug("Generating Digest-AKAv1-MD5 Auth Challenge")
- vector_dict = database.Get_Vectors_AuC(subscriber_details['auc_id'], "sip_auth", plmn=plmn)
-
- #diameter.3GPP-SIP-Auth-Data-Items:
- #AVP Code: 613 3GPP-SIP-Item-Number
- avp_SIP_Item_Number = self.generate_vendor_avp(613, "c0", 10415, format(int(0),"x").zfill(8))
- #AVP Code: 608 3GPP-SIP-Authentication-Scheme
- avp_SIP_Authentication_Scheme = self.generate_vendor_avp(608, "c0", 10415, str(binascii.hexlify(b'Digest-AKAv1-MD5'),'ascii'))
- #AVP Code: 609 3GPP-SIP-Authenticate
- avp_SIP_Authenticate = self.generate_vendor_avp(609, "c0", 10415, str(binascii.hexlify(vector_dict['SIP_Authenticate']),'ascii')) #RAND + AUTN
- #AVP Code: 610 3GPP-SIP-Authorization
- avp_SIP_Authorization = self.generate_vendor_avp(610, "c0", 10415, str(binascii.hexlify(vector_dict['xres']),'ascii')) #XRES
- #AVP Code: 625 Confidentiality-Key
- avp_Confidentiality_Key = self.generate_vendor_avp(625, "c0", 10415, str(binascii.hexlify(vector_dict['ck']),'ascii')) #CK
- #AVP Code: 626 Integrity-Key
- avp_Integrity_Key = self.generate_vendor_avp(626, "c0", 10415, str(binascii.hexlify(vector_dict['ik']),'ascii')) #IK
-
- auth_data_item = avp_SIP_Item_Number + avp_SIP_Authentication_Scheme + avp_SIP_Authenticate + avp_SIP_Authorization + avp_Confidentiality_Key + avp_Integrity_Key
- avp += self.generate_vendor_avp(612, "c0", 10415, auth_data_item) #3GPP-SIP-Auth-Data-Item
-
- avp += self.generate_vendor_avp(607, "c0", 10415, "00000001") #3GPP-SIP-Number-Auth-Items
-
- avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001))
-
- response = self.generate_diameter_packet("01", "40", 303, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
- return response
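-
- #Illustrative sketch (not part of the original code): the AVP helpers above expect
- #hex-string payloads, hence the repeated binascii round-trips, e.g.:
- #  str(binascii.hexlify(b'Digest-AKAv1-MD5'),'ascii')    #-> '4469676573742d414b4176312d4d4435'
- #  binascii.unhexlify('4469676573742d414b4176312d4d4435') #-> b'Digest-AKAv1-MD5'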
-
- #Generic error handler with Result-Code as input
- def Respond_ResultCode(self, packet_vars, avps, result_code):
- logging.error("Responding with result code " + str(result_code) + " to request with command code " + str(packet_vars['command_code']))
- avp = '' #Initiate empty var AVP
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- try:
- session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
- avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
- except:
- DiameterLogger.info("Respond_ResultCode: Failed to add SessionID into error")
- for avps_to_check in avps: #Only include AVP 260 (Vendor-Specific-Application-ID) if initial request included it
- if avps_to_check['avp_code'] == 260:
- concat_subavp = ''
- for sub_avp in avps_to_check['misc_data']:
- concat_subavp += self.generate_avp(sub_avp['avp_code'], sub_avp['avp_flags'], sub_avp['misc_data'])
- avp += self.generate_avp(260, 40, concat_subavp) #Vendor-Specific-Application-ID
- avp += self.generate_avp(268, 40, self.int_to_hex(result_code, 4), avps=avps, packet_vars=packet_vars) #Result Code (set to the supplied result_code)
-
- #Experimental Result AVP (response code for failure)
- avp_experimental_result = ''
- avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
- avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code (set to the supplied result_code)
- avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
-
- response = self.generate_diameter_packet("01", "60", int(packet_vars['command_code']), int(packet_vars['ApplicationId']), packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
- return response
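-
- #Illustrative note (not part of the original code): per RFC 6733 an answer carries either
- #Result-Code (268) or, for vendor-defined errors, a grouped Experimental-Result (297)
- #containing Vendor-Id (266) plus Experimental-Result-Code (298); this generic handler
- #sets both so that it can answer a request from any application.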
-
- #3GPP Cx Registration Termination Answer
- def Answer_16777216_304(self, packet_vars, avps):
- avp = '' #Initiate empty var AVP #Session-ID
- session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
- avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
- vendor_id = self.generate_avp(266, 40, str(binascii.hexlify(b'10415'),'ascii')) #(debug only - not appended to the response)
- DiameterLogger.debug("vendor_id avp: " + str(vendor_id))
- auth_application_id = self.generate_avp(258, 40, self.int_to_hex(16777252, 8)) #Auth-Application-Id (debug only - not appended to the response)
- DiameterLogger.debug("auth_application_id: " + auth_application_id)
- avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx
- avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001))
- avp += self.generate_avp(277, 40, "00000001") #Auth Session State
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- #* [ Proxy-Info ]
- proxy_host_avp = self.generate_avp(280, "40", str(binascii.hexlify(b'localdomain'),'ascii'))
- proxy_state_avp = self.generate_avp(33, "40", "0001")
- avp += self.generate_avp(284, "40", proxy_host_avp + proxy_state_avp) #Proxy-Info AVP ( 284 )
-
- #* [ Route-Record ]
- avp += self.generate_avp(282, "40", str(binascii.hexlify(b'localdomain'),'ascii'))
-
- response = self.generate_diameter_packet("01", "40", 304, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
- return response
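-
- #Illustrative note (not part of the original code): the second argument to
- #generate_avp()/generate_vendor_avp() is the AVP flags octet as hex:
- #  "40" = M (mandatory) bit set
- #  "c0" = V (vendor-specific) + M bits set, so a Vendor-Id field follows the header
- #  "80" = V bit only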
-
- #3GPP Sh User-Data Answer
- def Answer_16777217_306(self, packet_vars, avps):
- avp = '' #Initiate empty var AVP #Session-ID
-
- #Define values so we can check if they've been changed
- msisdn = None
- try:
- user_identity_avp = self.get_avp_data(avps, 700)[0]
- msisdn = self.get_avp_data(user_identity_avp, 701)[0] #Get MSISDN from AVP in request
- DiameterLogger.info("Got raw MSISDN with value " + str(msisdn))
- msisdn = self.TBCD_decode(msisdn)
- DiameterLogger.info("Got MSISDN with value " + str(msisdn))
- except:
- DiameterLogger.error("No MSISDN")
-
- if msisdn is not None:
- DiameterLogger.debug("Getting subscriber IMS info based on MSISDN")
- subscriber_ims_details = database.Get_IMS_Subscriber(msisdn=msisdn)
- DiameterLogger.debug("Got subscriber IMS details: " + str(subscriber_ims_details))
- DiameterLogger.debug("Getting subscriber info based on MSISDN")
- subscriber_details = database.Get_Subscriber(msisdn=msisdn)
- DiameterLogger.debug("Got subscriber details: " + str(subscriber_details))
- subscriber_details = {**subscriber_details, **subscriber_ims_details}
- DiameterLogger.debug("Merged subscriber details: " + str(subscriber_details))
- else:
- DiameterLogger.error("No MSISDN or IMSI in Answer_16777217_306() input")
- prom_diam_auth_event_count.labels(
- diameter_application_id = 16777217,
- diameter_cmd_code = 306,
- event='Unknown User',
- imsi_prefix = str(msisdn)[0:6],
- ).inc()
- result_code = 5005
- #Experimental Result AVP
- avp_experimental_result = ''
- avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
- avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code
- avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
- response = self.generate_diameter_packet("01", "40", 306, 16777217, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
- return response
-
- session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
- avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained)
-
- avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000001") #Vendor-Specific-Application-ID for Sh
-
- #Sh-User-Data (XML)
- #This loads a Jinja XML template containing the Sh-User-Data
- templateLoader = jinja2.FileSystemLoader(searchpath="./")
- templateEnv = jinja2.Environment(loader=templateLoader)
- sh_userdata_template = yaml_config['hss']['Default_Sh_UserData']
- DiameterLogger.info("Using template " + str(sh_userdata_template) + " for SH user data")
- template = templateEnv.get_template(sh_userdata_template)
- #These variables are passed to the template for use
- subscriber_details['mnc'] = self.MNC.zfill(3)
- subscriber_details['mcc'] = self.MCC.zfill(3)
-
- DiameterLogger.debug("Rendering template with values: " + str(subscriber_details))
- xmlbody = template.render(Sh_template_vars=subscriber_details) #This is where to put args to the template renderer
- avp += self.generate_vendor_avp(702, "c0", 10415, str(binascii.hexlify(str.encode(xmlbody)),'ascii'))
-
- avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001))
-
- response = self.generate_diameter_packet("01", "40", 306, 16777217, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
-
- return response
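-
- #Illustrative sketch (not part of the original code) of the Sh-User-Data rendering above,
- #assuming a minimal template file; the template name here is hypothetical:
- #  env = jinja2.Environment(loader=jinja2.FileSystemLoader('./'))
- #  template = env.get_template('sh_data.xml')                       #hypothetical file name
- #  xml = template.render(Sh_template_vars={'msisdn': '123456789', 'mcc': '001', 'mnc': '001'})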
-
- #3GPP Sh Profile-Update Answer
- def Answer_16777217_307(self, packet_vars, avps):
-
- #Get IMSI
- imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request
- imsi = binascii.unhexlify(imsi).decode('utf-8')
-
- #Get Sh User Data
- sh_user_data = self.get_avp_data(avps, 702)[0] #Get Sh-User-Data AVP from request
- sh_user_data = binascii.unhexlify(sh_user_data).decode('utf-8')
-
- DiameterLogger.debug("Got Sh User data: " + str(sh_user_data))
-
- #Push updated User Data into IMS Backend
- #Start with the Current User Data
- subscriber_ims_details = database.Get_IMS_Subscriber(imsi=imsi)
- database.UpdateObj(database.IMS_SUBSCRIBER, {'sh_profile': sh_user_data}, subscriber_ims_details['ims_subscriber_id'])
-
- avp = '' #Initiate empty var AVP #Session-ID
- session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
- avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained)
- #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
- VendorSpecificApplicationId = ''
- VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
- VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777217),"x").zfill(8)) #Auth-Application-ID Sh
- avp += self.generate_avp(260, 40, VendorSpecificApplicationId)
-
- response = self.generate_diameter_packet("01", "40", 307, 16777217, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
- return response
-
- #3GPP S13 - ME-Identity-Check Answer
- def Answer_16777252_324(self, packet_vars, avps):
-
- #Get IMSI
- try:
- imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request
- imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI
- #avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
- DiameterLogger.info("Got IMSI with value " + str(imsi))
- except Exception as e:
- DiameterLogger.debug("Failed to get IMSI from ME-Identity-Check Request")
- DiameterLogger.debug("Error was: " + str(e))
-
- #Get IMEI
- for sub_avp in self.get_avp_data(avps, 1401)[0]:
- DiameterLogger.debug("Evaluating sub_avp AVP " + str(sub_avp) + " to find IMEI")
- if sub_avp['avp_code'] == 1402:
- imei = binascii.unhexlify(sub_avp['misc_data']).decode('utf-8')
- DiameterLogger.debug("Found IMEI " + str(imei))
-
- avp = '' #Initiate empty var AVP
- session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
- avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
- avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000024") #Vendor-Specific-Application-ID for S13
- avp += self.generate_avp(277, 40, "00000001") #Auth Session State
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- #Experimental Result AVP (response code for failure)
- avp_experimental_result = ''
- avp_experimental_result += self.generate_vendor_avp(266, 'c0', 10415, '') #AVP Vendor ID
- avp_experimental_result += self.generate_avp(298, 'c0', self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code: SUCCESS (2001) (built but not appended)
- avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_SUCCESS (2001))
-
- #Equipment-Status
- EquipmentStatus = database.Check_EIR(imsi=imsi, imei=imei)
- avp += self.generate_vendor_avp(1445, 'c0', 10415, self.int_to_hex(EquipmentStatus, 4))
- prom_diam_eir_event_count.labels(response=EquipmentStatus).inc()
-
- response = self.generate_diameter_packet("01", "40", 324, 16777252, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
- return response
-
- #3GPP SLh - LCS-Routing-Info-Answer
- def Answer_16777291_8388622(self, packet_vars, avps):
- avp = ''
- session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
- avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
- #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
- VendorSpecificApplicationId = ''
- VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
- VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777291),"x").zfill(8)) #Auth-Application-ID SLh
- avp += self.generate_avp(260, 40, VendorSpecificApplicationId)
- avp += self.generate_avp(277, 40, "00000001") #Auth Session State (NO_STATE_MAINTAINED)
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
-
- #Create list of valid AVPs
- present_avps = []
- for avp_id in avps:
- present_avps.append(avp_id['avp_code'])
-
- #Define values so we can check if they've been changed
- msisdn = None
- imsi = None
-
- #Try and get IMSI if present
- if 1 in present_avps:
- DiameterLogger.info("IMSI AVP is present")
- try:
- imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request
- imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI
- avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
- DiameterLogger.info("Got IMSI with value " + str(imsi))
- except Exception as e:
- DiameterLogger.debug("Failed to get IMSI from LCS-Routing-Info-Request")
- DiameterLogger.debug("Error was: " + str(e))
- elif 701 in present_avps:
- #Try and get MSISDN if present
- try:
- msisdn = self.get_avp_data(avps, 701)[0] #Get MSISDN from AVP in request
- DiameterLogger.info("Got MSISDN with value " + str(msisdn))
- avp += self.generate_vendor_avp(701, 'c0', 10415, self.get_avp_data(avps, 701)[0]) #MSISDN
- DiameterLogger.info("Got MSISDN with encoded value " + str(msisdn))
- msisdn = self.TBCD_decode(msisdn)
- DiameterLogger.info("Got MSISDN with decoded value " + str(msisdn))
- except Exception as e:
- DiameterLogger.debug("Failed to get MSISDN from LCS-Routing-Info-Request")
- DiameterLogger.debug("Error was: " + str(e))
- else:
- DiameterLogger.error("No MSISDN or IMSI")
-
- try:
- if imsi is not None:
- DiameterLogger.debug("Getting subscriber location based on IMSI")
- subscriber_details = database.Get_Subscriber(imsi=imsi)
- DiameterLogger.debug("Got subscriber_details from IMSI: " + str(subscriber_details))
- elif msisdn is not None:
- DiameterLogger.debug("Getting subscriber location based on MSISDN")
- subscriber_details = database.Get_Subscriber(msisdn=msisdn)
- DiameterLogger.debug("Got subscriber_details from MSISDN: " + str(subscriber_details))
- except Exception as E:
- DiameterLogger.error("No MSISDN or IMSI returned in Answer_16777291_8388622 input")
- DiameterLogger.error("Error is " + str(E))
- DiameterLogger.error("Responding with DIAMETER_ERROR_USER_UNKNOWN")
- avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4), avps=avps, packet_vars=packet_vars) #Result Code (DIAMETER_ERROR_USER_UNKNOWN (5030))
- response = self.generate_diameter_packet("01", "40", 8388622, 16777291, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
- DiameterLogger.info("Diameter user unknown - Sending Routing-Info-Answer with DIAMETER_ERROR_USER_UNKNOWN")
- return response
-
- DiameterLogger.info("Got subscriber_details for subscriber: " + str(subscriber_details))
-
- if subscriber_details['serving_mme'] is None:
- #DB has no location on record for subscriber
- DiameterLogger.info("No location on record for Subscriber")
- result_code = 4201
- #DIAMETER_ERROR_ABSENT_USER (4201)
- #This result code shall be sent by the HSS to indicate that the location of the targeted user is not known at this time to
- #satisfy the requested operation.
-
- avp_experimental_result = ''
- avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
- avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code
- avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
-
- response = self.generate_diameter_packet("01", "40", 8388622, 16777291, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
- return response
-
- #Serving Node AVP
- avp_serving_node = ''
- avp_serving_node += self.generate_vendor_avp(2402, "c0", 10415, self.string_to_hex(subscriber_details['serving_mme'])) #MME-Name
- avp_serving_node += self.generate_vendor_avp(2408, "c0", 10415, self.OriginRealm) #MME-Realm
- avp_serving_node += self.generate_vendor_avp(2405, "c0", 10415, self.ip_to_hex(yaml_config['hss']['bind_ip'][0])) #GMLC-Address
- avp += self.generate_vendor_avp(2401, "c0", 10415, avp_serving_node) #Serving-Node AVP
-
- #Set Result-Code
- result_code = 2001 #Diameter Success
- avp += self.generate_avp(268, 40, self.int_to_hex(result_code, 4), avps=avps, packet_vars=packet_vars) #Result Code
-
- response = self.generate_diameter_packet("01", "40", 8388622, 16777291, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
- return response
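-
- #Illustrative note (not part of the original code): every Answer_* above echoes the
- #request's Hop-by-Hop and End-to-End identifiers back via packet_vars, while the
- #Request_* builders below mint fresh identifiers with self.generate_id(4), matching
- #the RFC 6733 rules for answers vs. newly originated requests.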
-
- #### Diameter Requests ####
-
- #Capabilities Exchange Request
- def Request_257(self):
- avp = ''
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- avp += self.generate_avp(257, 40, self.ip_to_hex(socket.gethostbyname(socket.gethostname()))) #Host-IP-Address (For this to work on Linux this is the IP defined in the hosts file for localhost)
- avp += self.generate_avp(266, 40, "00000000") #Vendor-Id
- avp += self.generate_avp(269, "00", self.ProductName) #Product-Name
- avp += self.generate_avp(260, 40, "000001024000000c01000023" + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a)
- avp += self.generate_avp(260, 40, "000001024000000c01000016" + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Gx)
- avp += self.generate_avp(260, 40, "000001024000000c01000027" + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (SLg)
- avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777217),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Sh)
- avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777216),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Cx)
- avp += self.generate_avp(258, 40, format(int(4294967295),"x").zfill(8)) #Auth-Application-ID Relay
- avp += self.generate_avp(265, 40, format(int(5535),"x").zfill(8)) #Supported-Vendor-ID (3GPP2)
- avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP)
- avp += self.generate_avp(265, 40, format(int(13019),"x").zfill(8)) #Supported-Vendor-ID 13019 (ETSI)
- response = self.generate_diameter_packet("01", "80", 257, 0, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
- return response
-
- #Device Watchdog Request
- def Request_280(self):
- avp = ''
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- response = self.generate_diameter_packet("01", "80", 280, 0, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
- return response
-
- #Disconnect Peer Request
- def Request_282(self):
- avp = '' #Initiate empty var AVP
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- avp += self.generate_avp(273, 40, "00000000") #Disconnect-Cause (REBOOTING (0))
- response = self.generate_diameter_packet("01", "80", 282, 0, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
- return response
-
- #3GPP S6a/S6d Authentication Information Request
- def Request_16777251_318(self, imsi, DestinationHost, DestinationRealm, requested_vectors=1):
- avp = '' #Initiate empty var AVP #Session-ID
- sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Generate Session-Id
- avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id AVP
- avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm
- #avp += self.generate_avp(293, 40, self.string_to_hex(DestinationHost)) #Destination Host
- avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
- number_of_requested_vectors = self.generate_vendor_avp(1410, "c0", 10415, format(int(requested_vectors),"x").zfill(8))
- immediate_response_preferred = self.generate_vendor_avp(1412, "c0", 10415, format(int(1),"x").zfill(8))
- avp += self.generate_vendor_avp(1408, "c0", 10415, str(number_of_requested_vectors) + str(immediate_response_preferred)) #Requested-EUTRAN-Authentication-Info
-
- mcc = str(imsi)[:3]
- mnc = str(imsi)[3:5]
- avp += self.generate_vendor_avp(1407, "c0", 10415, self.EncodePLMN(mcc, mnc)) #Visited-PLMN-Id(1407) (Derived from start of IMSI)
- avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") #Vendor-Specific-Application-ID
- response = self.generate_diameter_packet("01", "c0", 318, 16777251, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
- return response
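-
- #Illustrative sketch (not part of the original code): EncodePLMN() packs the MCC/MNC
- #digits into the 3-byte nibble-swapped TBCD form used by Visited-PLMN-Id; for
- #MCC 001 / MNC 01 the layout is roughly:
- #  mcc, mnc = '001', '01'
- #  plmn = bytes([int(mcc[1] + mcc[0], 16), int('f' + mcc[2], 16), int(mnc[1] + mnc[0], 16)])
- #  plmn.hex()  #-> '00f110'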
str(binascii.hexlify(str.encode("testclient." + yaml_config['hss']['OriginHost'])),'ascii')) - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm - avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - avp += self.generate_vendor_avp(1032, "80", 10415, self.int_to_hex(1004, 4)) #RAT-Type val=EUTRAN (1004) - avp += self.generate_vendor_avp(1405, "c0", 10415, "00000002") #ULR-Flags val=2 - avp += self.generate_vendor_avp(1407, "c0", 10415, self.EncodePLMN(mcc, mnc)) #Visited-PLMN-Id(1407) (Derrived from start of IMSI) - avp += self.generate_vendor_avp(1615, "80", 10415, "00000000") #E-SRVCC-Capability val=UE-SRVCC-NOT-SUPPORTED (0) - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") #Vendor-Specific-Application-ID - response = self.generate_diameter_packet("01", "c0", 316, 16777251, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - return response - - #3GPP S6a/S6d Purge UE Request PUR - def Request_16777251_321(self, imsi, DestinationRealm, DestinationHost): - avp = '' - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm - #avp += self.generate_avp(293, 40, self.string_to_hex(DestinationHost)) #Destination Host - avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") #Vendor-Specific-Application-ID - response = self.generate_diameter_packet("01", "c0", 321, 16777251, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - return response - - #3GPP S6a/S6d NOtify Request NOR - def Request_16777251_323(self, imsi, DestinationRealm, DestinationHost): - avp = '' - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm - #avp += self.generate_avp(293, 40, self.string_to_hex(DestinationHost)) #Destination Host - avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") #Vendor-Specific-Application-ID - response = self.generate_diameter_packet("01", "c0", 323, 16777251, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - return response - - #3GPP S6a/S6d Cancel-Location-Request Request CLR - def Request_16777251_317(self, imsi, DestinationRealm, DestinationHost=None, CancellationType=2): - avp = '' - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate - avp += self.generate_avp(263, 40, 
-
- #3GPP S6a/S6d Cancel Location Request (CLR)
- def Request_16777251_317(self, imsi, DestinationRealm, DestinationHost=None, CancellationType=2):
- avp = ''
- sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Generate Session-Id
- avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id AVP
- avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm
- if DestinationHost is not None:
- avp += self.generate_avp(293, 40, self.string_to_hex(DestinationHost)) #Destination Host
- avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
- avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") #Vendor-Specific-Application-ID
- avp += self.generate_vendor_avp(1420, "c0", 10415, self.int_to_hex(CancellationType, 4)) #Cancellation-Type (default: Subscription Withdrawal)
- response = self.generate_diameter_packet("01", "c0", 317, 16777251, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
- return response
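-
- #Illustrative note (not part of the original code): the ISD builder below nests grouped
- #AVPs by string concatenation - Subscription-Data (1400) wraps AMBR (1435),
- #APN-Configuration-Profile (1429) and one APN-Configuration (1430) per APN profile,
- #each built as a concatenated hex string before being wrapped by generate_vendor_avp().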
-
- #3GPP S6a/S6d Insert Subscriber Data Request (ISD)
- def Request_16777251_319(self, packet_vars, avps, **kwargs):
- avp = '' #Initiate empty var AVP
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Generate Session-Id
- avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id AVP
- avp += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
- #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
- VendorSpecificApplicationId = ''
- VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
- avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
-
- avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a)
-
- #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP
- SupportedFeatures = ''
- SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
- SupportedFeatures += self.generate_vendor_avp(629, 80, 10415, self.int_to_hex(1, 4)) #Feature-List ID
- SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "1c000607") #Feature-List Flags
- if 'GetLocation' in kwargs:
- DiameterLogger.debug("Requested Get Location ISD")
- #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP
- SupportedFeatures = ''
- SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
- SupportedFeatures += self.generate_vendor_avp(629, 80, 10415, self.int_to_hex(1, 4)) #Feature-List ID
- SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "18000007") #Feature-List Flags
- avp += self.generate_vendor_avp(1490, "c0", 10415, "00000018") #IDR-Flags
- avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP
-
- try:
- user_identity_avp = self.get_avp_data(avps, 700)[0]
- DiameterLogger.info(user_identity_avp)
- msisdn = self.get_avp_data(user_identity_avp, 701)[0] #Get MSISDN from AVP in request
- msisdn = self.TBCD_decode(msisdn)
- DiameterLogger.info("Got MSISDN with value " + str(msisdn))
- except:
- DiameterLogger.error("No MSISDN present")
- return
- #Get Subscriber Location from Database
- subscriber_location = database.GetSubscriberLocation(msisdn=msisdn)
- DiameterLogger.debug("Got subscriber location: " + subscriber_location)
-
- DiameterLogger.info("Getting IMSI for MSISDN " + str(msisdn))
- imsi = database.Get_IMSI_from_MSISDN(msisdn)
- avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
-
- DiameterLogger.info("Got back location data: " + str(subscriber_location))
-
- #Populate Destination Host & Realm
- avp += self.generate_avp(293, 40, self.string_to_hex(subscriber_location)) #Destination-Host
- avp += self.generate_avp(283, 40, self.string_to_hex('epc.mnc001.mcc214.3gppnetwork.org')) #Destination Realm
-
- else:
- #APNs from DB
- imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request
- imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI
- avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
- avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP
- avp += self.generate_vendor_avp(1490, "c0", 10415, "00000000") #IDR-Flags
-
- destinationHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP
- destinationHost = binascii.unhexlify(destinationHost).decode('utf-8') #Format it
- DiameterLogger.debug("Received originHost to use as destinationHost is " + str(destinationHost))
- destinationRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP
- destinationRealm = binascii.unhexlify(destinationRealm).decode('utf-8') #Format it
- DiameterLogger.debug("Received originRealm to use as destinationRealm is " + str(destinationRealm))
- avp += self.generate_avp(293, 40, self.string_to_hex(destinationHost)) #Destination-Host
- avp += self.generate_avp(283, 40, self.string_to_hex(destinationRealm)) #Destination-Realm
-
- APN_Configuration = ''
-
- try:
- subscriber_details = database.Get_Subscriber(imsi=imsi) #Get subscriber details
- except ValueError as e:
- DiameterLogger.error("Failed to get data back from database for imsi " + str(imsi))
- DiameterLogger.error("Error is " + str(e))
- raise
- except Exception as ex:
- template = "An exception of type {0} occurred. Arguments:\n{1!r}"
- message = template.format(type(ex).__name__, ex.args)
- DiameterLogger.critical(message)
- DiameterLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi))
- raise
-
- #Subscription Data:
- subscription_data = ''
- subscription_data += self.generate_vendor_avp(1426, "c0", 10415, "00000000") #Access Restriction Data
- subscription_data += self.generate_vendor_avp(1424, "c0", 10415, "00000000") #Subscriber-Status (SERVICE_GRANTED)
- subscription_data += self.generate_vendor_avp(1417, "c0", 10415, "00000000") #Network-Access-Mode (PACKET_AND_CIRCUIT)
-
- #AMBR is a sub-AVP of Subscription Data
- AMBR = '' #Initiate empty var AVP for AMBR
- if 'ue_ambr_ul' in subscriber_details:
- ue_ambr_ul = int(subscriber_details['ue_ambr_ul'])
- else:
- #Use default AMBR of unlimited if no value in subscriber_details
- ue_ambr_ul = 1048576000
-
- if 'ue_ambr_dl' in subscriber_details:
- ue_ambr_dl = int(subscriber_details['ue_ambr_dl'])
- else:
- #Use default AMBR of unlimited if no value in subscriber_details
- ue_ambr_dl = 1048576000
-
- AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(ue_ambr_ul, 4)) #Max-Requested-Bandwidth-UL
- AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(ue_ambr_dl, 4)) #Max-Requested-Bandwidth-DL
- subscription_data += self.generate_vendor_avp(1435, "c0", 10415, AMBR) #Add AMBR AVP in two sub-AVPs
-
- #APN Configuration Profile is a sub AVP of Subscription Data
- APN_Configuration_Profile = ''
- APN_Configuration_Profile += self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(1, 4)) #Context Identifier
- APN_Configuration_Profile += self.generate_vendor_avp(1428, "c0", 10415, self.int_to_hex(0, 4)) #All-APN-Configurations-Included-Indicator
-
- apn_list = subscriber_details['pdn']
- DiameterLogger.debug("APN list: " + str(apn_list))
- APN_context_identifier_count = 1
- for apn_profile in apn_list:
- DiameterLogger.debug("Processing APN profile " + str(apn_profile))
- APN_Service_Selection = self.generate_avp(493, "40", self.string_to_hex(str(apn_profile['apn'])))
-
- DiameterLogger.debug("Setting APN Configuration Profile")
- #Sub AVPs of APN Configuration Profile
- APN_context_identifier = self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(APN_context_identifier_count, 4))
- APN_PDN_type = self.generate_vendor_avp(1456, "c0", 10415, self.int_to_hex(0, 4))
-
- DiameterLogger.debug("Setting APN AMBR")
- #AMBR
- AMBR = '' #Initiate empty var AVP for AMBR
- if 'AMBR' in apn_profile:
- ue_ambr_ul = int(apn_profile['AMBR']['apn_ambr_ul'])
- ue_ambr_dl = int(apn_profile['AMBR']['apn_ambr_dl'])
- else:
- #Use default AMBR if no value in the APN profile
- ue_ambr_ul = 50000000
- ue_ambr_dl = 100000000
-
- AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(ue_ambr_ul, 4)) #Max-Requested-Bandwidth-UL
- AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(ue_ambr_dl, 4)) #Max-Requested-Bandwidth-DL
- APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR)
-
- DiameterLogger.debug("Setting APN Allocation-Retention-Priority")
- #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP
- AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_profile['qos']['arp']['priority_level']), 4))
- AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_profile['qos']['arp']['pre_emption_capability']), 4))
- AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "c0", 10415, self.int_to_hex(int(apn_profile['qos']['arp']['pre_emption_vulnerability']), 4))
- AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
- AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_profile['qos']['qci']), 4))
- APN_EPS_Subscribed_QoS_Profile = self.generate_vendor_avp(1431, "c0", 10415, AVP_QoS + AVP_ARP)
-
- #If static UE IP is specified
- try:
- apn_ip = apn_profile['ue']['addr']
- DiameterLogger.debug("Found static IP for UE " + str(apn_ip))
- Served_Party_Address = self.generate_vendor_avp(848, "c0", 10415, self.ip_to_hex(apn_ip))
- except:
- Served_Party_Address = ""
-
- if 'MIP6-Agent-Info' in apn_profile:
- DiameterLogger.info("MIP6-Agent-Info present, value " + str(apn_profile['MIP6-Agent-Info']))
- MIP6_Destination_Host = self.generate_avp(293, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_HOST'])))
- MIP6_Destination_Realm = self.generate_avp(283, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_REALM'])))
- MIP6_Home_Agent_Host = self.generate_avp(348, '40', MIP6_Destination_Host + MIP6_Destination_Realm)
- MIP6_Agent_Info = self.generate_avp(486, '40', MIP6_Home_Agent_Host)
- DiameterLogger.info("MIP6 value is " + str(MIP6_Agent_Info))
- else:
- MIP6_Agent_Info = ''
-
- if 'PDN_GW_Allocation_Type' in apn_profile:
- DiameterLogger.info("PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type']))
- PDN_GW_Allocation_Type = self.generate_vendor_avp(1438, 'c0', 10415, self.int_to_hex(int(apn_profile['PDN_GW_Allocation_Type']), 4))
- DiameterLogger.info("PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type))
- else:
- PDN_GW_Allocation_Type = ''
-
- if 'VPLMN_Dynamic_Address_Allowed' in apn_profile:
- DiameterLogger.info("VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed']))
- VPLMN_Dynamic_Address_Allowed = self.generate_vendor_avp(1432, 'c0', 10415, self.int_to_hex(int(apn_profile['VPLMN_Dynamic_Address_Allowed']), 4))
- DiameterLogger.info("VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed))
- else:
- VPLMN_Dynamic_Address_Allowed = ''
-
- APN_Configuration_AVPS = APN_context_identifier + APN_PDN_type + APN_AMBR + APN_Service_Selection \
- + APN_EPS_Subscribed_QoS_Profile + Served_Party_Address + MIP6_Agent_Info + PDN_GW_Allocation_Type + VPLMN_Dynamic_Address_Allowed
-
- APN_Configuration += self.generate_vendor_avp(1430, "c0", 10415, APN_Configuration_AVPS)
-
- #Increment Context Identifier Count to keep track of how many APN Profiles returned
- APN_context_identifier_count = APN_context_identifier_count + 1
- DiameterLogger.debug("Processed APN profile " + str(apn_profile['apn']))
-
- subscription_data += self.generate_vendor_avp(1619, "80", 10415, self.int_to_hex(720, 4)) #Subscribed-Periodic-RAU-TAU-Timer (value 720)
- subscription_data += self.generate_vendor_avp(1429, "c0", 10415, APN_context_identifier + \
- self.generate_vendor_avp(1428, "c0", 10415, self.int_to_hex(0, 4)) + APN_Configuration)
-
- #If MSISDN is present include it in Subscription Data
- if 'msisdn' in subscriber_details:
- DiameterLogger.debug("MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA")
- msisdn_avp = self.generate_vendor_avp(701, 'c0', 10415, str(subscriber_details['msisdn'])) #MSISDN
- DiameterLogger.debug(msisdn_avp)
- subscription_data += msisdn_avp
-
- 
if 'RAT_freq_priorityID' in subscriber_details: - DiameterLogger.debug("RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA") - rat_freq_priorityID = self.generate_vendor_avp(1440, "C0", 10415, self.int_to_hex(int(subscriber_details['RAT_freq_priorityID']), 4)) #RAT-Frequency-Selection-Priority ID - DiameterLogger.debug(rat_freq_priorityID) - subscription_data += rat_freq_priorityID - - if '3gpp-charging-characteristics' in subscriber_details: - DiameterLogger.debug("3gpp-charging-characteristics " + str(subscriber_details['3gpp-charging-characteristics']) + " - Adding in ULA") - _3gpp_charging_characteristics = self.generate_vendor_avp(13, "80", 10415, self.string_to_hex(str(subscriber_details['3gpp-charging-characteristics']))) - subscription_data += _3gpp_charging_characteristics - DiameterLogger.debug(_3gpp_charging_characteristics) - - - if 'APN_OI_replacement' in subscriber_details: - DiameterLogger.debug("APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA") - subscription_data += self.generate_vendor_avp(1427, "C0", 10415, self.string_to_hex(str(subscriber_details['APN_OI_replacement']))) - - - if 'GetLocation' in kwargs: - avp += self.generate_vendor_avp(1400, "c0", 10415, "") #Subscription-Data - else: - avp += self.generate_vendor_avp(1400, "c0", 10415, subscription_data) #Subscription-Data - - response = self.generate_diameter_packet("01", "C0", 319, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - return response - - #3GPP Cx Location Information Request (LIR) - #ToDo - Check the command code here... - def Request_16777216_302(self, sipaor): - avp = '' #Initiate empty var AVP #Session-ID - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate - #Auth Session state - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm - avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex(sipaor)) #Public-Identity / SIP-AOR - avp += self.generate_avp(293, 40, str(binascii.hexlify(b'hss.localdomain'),'ascii')) #Destination Host - - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID - - - response = self.generate_diameter_packet("01", "c0", 302, 16777216, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - return response - - #3GPP Cx User Authorization Request (UAR) - def Request_16777216_300(self, imsi, domain): - avp = '' #Initiate empty var AVP #Session-ID - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx - avp += 
self.generate_avp(277, 40, "00000001") #Auth-Session-State
- avp += self.generate_avp(1, 40, self.string_to_hex(imsi + "@" + domain)) #User-Name
- avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex("sip:" + imsi + "@" + domain)) #Public-Identity
- avp += self.generate_vendor_avp(600, "c0", 10415, self.string_to_hex(domain)) #Visited Network Identifier
- response = self.generate_diameter_packet("01", "c0", 300, 16777216, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
- return response
-
- #3GPP Cx Server Assignment Request (SAR)
- def Request_16777216_301(self, imsi, domain, server_assignment_type):
- avp = '' #Initiate empty var AVP #Session-ID
- sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Generate Session-Id
- avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id AVP
- avp += self.generate_avp(264, 40, str(binascii.hexlify(str.encode("testclient." + yaml_config['hss']['OriginHost'])),'ascii')) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm
- avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx
- avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained)
- avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex("sip:" + imsi + "@" + domain)) #Public-Identity
- avp += self.generate_vendor_avp(602, "c0", 10415, self.string_to_hex('sip:scscf.ims.mnc' + self.MNC + '.mcc' + self.MCC + '.3gppnetwork.org:5060')) #Server-Name
- avp += self.generate_avp(1, 40, self.string_to_hex(imsi + "@" + domain)) #User-Name
- avp += self.generate_vendor_avp(614, "c0", 10415, format(int(server_assignment_type),"x").zfill(8)) #Server Assignment Type
- avp += self.generate_vendor_avp(624, "c0", 10415, "00000000") #User Data Already Available (Not Available)
- response = self.generate_diameter_packet("01", "c0", 301, 16777216, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
- return response
-
- #3GPP Cx Multimedia Authentication Request (MAR)
- def Request_16777216_303(self, imsi, domain):
- avp = '' #Initiate empty var AVP #Session-ID
- sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Generate Session-Id
- avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id AVP
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm
- avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx
- avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained)
- avp += self.generate_avp(1, 40, self.string_to_hex(str(imsi) + "@" + domain)) #User-Name
- avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex("sip:" + str(imsi) + "@" + domain)) #Public-Identity
- avp += self.generate_vendor_avp(607, "c0", 10415, "00000001") #3GPP-SIP-Number-Auth-Items
- #3GPP-SIP-Auth-Data-Item (static example value)
-
- avp += self.generate_vendor_avp(612, "c0", 10415, "00000260c0000013000028af756e6b6e6f776e0000000262c000002a000028af02e3fe1064bea4dd52602bef1c80a34ededbeb4ccabfa0430f4ffd5f1d8c0000")
- avp +=
self.generate_vendor_avp(602, "c0", 10415, self.ProductName) #Server-Name - response = self.generate_diameter_packet("01", "c0", 303, 16777216, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - return response - - #3GPP Cx Registration Termination Request (RTR) - def Request_16777216_304(self, imsi, domain): - avp = '' #Initiate empty var AVP #Session-ID - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session ID AVP - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777216),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Cx) - - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - - #SIP-Deregistration-Reason - reason_code_avp = self.generate_vendor_avp(616, "c0", 10415, "00000000") - reason_info_avp = self.generate_vendor_avp(617, "c0", 10415, self.string_to_hex("Test Reason")) - avp += self.generate_vendor_avp(615, "c0", 10415, reason_code_avp + reason_info_avp) - - avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm - avp += self.generate_avp(293, 40, str(binascii.hexlify(b'hss.localdomain'),'ascii')) #Destination Host - - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained) - avp += self.generate_avp(1, 40, self.string_to_hex(str(imsi) + "@" + domain)) #User-Name - avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex("sip:" + str(imsi) + "@" + domain)) #Public-Identity - avp += self.generate_vendor_avp(602, "c0", 10415, self.ProductName) #Server-Name - #* [ Proxy-Info ] - proxy_host_avp = self.generate_avp(280, "40", str(binascii.hexlify(b'localdomain'),'ascii')) - proxy_state_avp = self.generate_avp(33, "40", "0001") - avp += self.generate_avp(284, "40", proxy_host_avp + proxy_state_avp) #Proxy-Info AVP ( 284 ) - - #* [ Route-Record ] - avp += self.generate_avp(282, "40", str(binascii.hexlify(b'localdomain'),'ascii')) - - - response = self.generate_diameter_packet("01", "c0", 304, 16777216, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - - return response - - #3GPP Sh User-Data Request (UDR) - def Request_16777217_306(self, **kwargs): - avp = '' #Initiate empty var AVP #Session-ID - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_sh' #Session state generate - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session ID AVP - avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777217),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Sh) - - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - - avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm - avp += self.generate_avp(293, 40, str(binascii.hexlify(b'hss.localdomain'),'ascii')) #Destination Host - - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained) - - avp += self.generate_vendor_avp(602, "c0", 10415, self.ProductName) #Server-Name - - #* [ Route-Record ] - avp += self.generate_avp(282, "40", str(binascii.hexlify(b'localdomain'),'ascii')) - - if "msisdn" in kwargs: - msisdn = kwargs['msisdn'] - msisdn_avp = self.generate_vendor_avp(701, 'c0', 10415, 
self.TBCD_encode(str(msisdn))) #MSISDN
- avp += self.generate_vendor_avp(700, "c0", 10415, msisdn_avp) #User-Identity
- avp += self.generate_vendor_avp(701, 'c0', 10415, self.TBCD_encode(str(msisdn)))
- elif "imsi" in kwargs:
- imsi = kwargs['imsi']
- public_identity_avp = self.generate_vendor_avp(601, 'c0', 10415, self.string_to_hex(imsi)) #Public-Identity (IMSI)
- avp += self.generate_vendor_avp(700, "c0", 10415, public_identity_avp) #User-Identity
-
- response = self.generate_diameter_packet("01", "c0", 306, 16777217, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
-
- return response
-
- #3GPP S13 - ME-Identity-Check Request
- def Request_16777252_324(self, imsi, imei, software_version):
- avp = ''
- avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000024") #Vendor-Specific-Application-ID for S13
- avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained)
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm
- avp += self.generate_avp(293, 40, str(binascii.hexlify(b'eir.localdomain'),'ascii')) #Destination Host
- imei = self.generate_vendor_avp(1402, "c0", 10415, str(binascii.hexlify(str.encode(imei)),'ascii'))
- software_version = self.generate_vendor_avp(1403, "c0", 10415, self.string_to_hex(software_version))
- avp += self.generate_vendor_avp(1401, "c0", 10415, imei + software_version) #Terminal Information
- avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
- response = self.generate_diameter_packet("01", "c0", 324, 16777252, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
- return response
-
- #3GPP SLg - Provide Subscriber Location Request
- def Request_16777255_8388620(self, imsi):
- avp = ''
- #ToDo - Update the Vendor Specific Application ID
- avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000024") #Vendor-Specific-Application-ID
- avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained)
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
- avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm
- avp += self.generate_avp(293, 40, str(binascii.hexlify(b'mme-slg.localdomain'),'ascii')) #Destination Host
- #SLg Location Type AVP
- avp += self.generate_vendor_avp(2500, "c0", 10415, "00000000")
- #Username (IMSI)
- avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
- #LCS-EPS-Client-Name
- LCS_EPS_Client_Name = self.generate_vendor_avp(1238, "c0", 10415, str(binascii.hexlify(b'PyHSS GMLC'),'ascii')) #LCS Name String
- LCS_EPS_Client_Name += self.generate_vendor_avp(1237, "c0", 10415, "00000002") #LCS Format Indicator
- avp += self.generate_vendor_avp(2501, "c0", 10415, LCS_EPS_Client_Name)
- #LCS-Client-Type (Emergency Services)
- avp += self.generate_vendor_avp(1241, "c0", 10415, "00000000")
- response = self.generate_diameter_packet("01", "c0", 8388620, 16777255, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
- return response
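-
- #Illustrative sketch (not part of the original code): TBCD_encode()/TBCD_decode(), used
- #for the MSISDN AVPs above and below, swap the nibbles of each digit pair, padding
- #odd-length numbers with 0xF, e.g.:
- #  msisdn = '123456789'
- #  padded = msisdn + 'f'
- #  tbcd = ''.join(padded[i+1] + padded[i] for i in range(0, len(padded), 2))  #-> '21436587f9'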
-
- #3GPP SLh - LCS-Routing-Info-Request
- def Request_16777291_8388622(self, **kwargs):
- avp = ''
- #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
- VendorSpecificApplicationId = ''
- VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
- VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777291),"x").zfill(8)) #Auth-Application-ID SLh
- avp += self.generate_avp(260, 40, VendorSpecificApplicationId)
- avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained)
- avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
- avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
-
- sessionid = 'nickpc.localdomain;' + self.generate_id(5) + ';1;app_slh' #Generate Session-Id
- avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id AVP
-
- #Username (IMSI)
- if 'imsi' in kwargs:
- avp += self.generate_avp(1, 40, self.string_to_hex(str(kwargs.get('imsi')))) #Username (IMSI)
-
- #MSISDN (Optional)
- if 'msisdn' in kwargs:
- avp += self.generate_vendor_avp(701, 'c0', 10415, self.TBCD_encode(str(kwargs.get('msisdn')))) #MSISDN
-
- #GMLC Address
- avp += self.generate_vendor_avp(2405, 'c0', 10415, self.ip_to_hex('127.0.0.1')) #GMLC-Address
-
- response = self.generate_diameter_packet("01", "c0", 8388622, 16777291, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
- return response
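-
- #Illustrative note (not part of the original code): in the Gx/Gy builders below,
- #ccr_type / CC_Request_Type follows the RFC 4006 CC-Request-Type enumeration:
- #1 = INITIAL_REQUEST, 2 = UPDATE_REQUEST, 3 = TERMINATION_REQUEST (4 = EVENT_REQUEST,
- #unused here).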
self.generate_vendor_avp(1032, 'c0', 10415, self.int_to_hex(1004, 4)) #RAT-Type (EUTRAN) - #Default EPS Bearer QoS - avp += self.generate_vendor_avp(1049, 80, 10415, - '0000041980000058000028af00000404c0000010000028af000000090000040a8000003c000028af0000041680000010000028af000000080000041780000010000028af000000010000041880000010000028af00000001') - #3GPP-User-Location-Information - avp += self.generate_vendor_avp(22, 80, 10415, - '8205f539007b05f53900000001') - avp += self.generate_vendor_avp(23, 80, 10415, '00000000') #MS Timezone - - #Called Station ID (APN) - avp += self.generate_avp(30, 40, str(binascii.hexlify(str.encode(apn)),'ascii')) - - response = self.generate_diameter_packet("01", "c0", 272, 16777238, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - return response - - #3GPP Gx - Re-Auth Request - def Request_16777238_258(self, sessionid, ChargingRules, ue_ip, Serving_PGW, Serving_Realm): - avp = '' - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id set AVP - - #Setup Charging Rule - DiameterLogger.debug(ChargingRules) - avp += self.Charging_Rule_Generator(ChargingRules=ChargingRules, ue_ip=ue_ip) - - - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(293, 40, self.string_to_hex(Serving_PGW)) #Destination Host - avp += self.generate_avp(283, 40, self.string_to_hex(Serving_Realm)) #Destination Realm - - avp += self.generate_avp(258, 40, format(int(16777238),"x").zfill(8)) #Auth-Application-ID Gx - - avp += self.generate_avp(285, 40, format(int(0),"x").zfill(8)) #Re-Auth-Request-Type - - response = self.generate_diameter_packet("01", "c0", 258, 16777238, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - return response - - #3GPP Gy - Credit Control Request - def Request_4_272(self, sessionid, imsi, CC_Request_Type, input_octets, output_octets): - avp = '' - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id set AVP - - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(283, 40, self.OriginRealm) #Destination Realm - - avp += self.generate_avp(258, 40, format(int(4),"x").zfill(8)) #Auth-Application-ID Diameter Credit Control (4) - avp += self.generate_avp(461, 40, self.string_to_hex("open5gs-smfd@open5gs.org")) #Service Context ID - avp += self.generate_avp(416, 40, format(int(CC_Request_Type),"x").zfill(8)) #CC Request Type - avp += self.generate_avp(415, 40, format(int(0),"x").zfill(8)) #CC Request Number - avp += self.generate_avp(55, 40, '00000000') #Event Timestamp - - #Subscription ID - Subscription_ID_Data = self.generate_avp(444, 40, str(binascii.hexlify(str.encode(imsi)),'ascii')) - Subscription_ID_Type = self.generate_avp(450, 40, format(int(1),"x").zfill(8)) - avp += self.generate_avp(443, 40, Subscription_ID_Type + Subscription_ID_Data) - - avp += self.generate_avp(436, 40, format(int(0),"x").zfill(8)) #Requested Action (Direct Debiting) - - avp += self.generate_vendor_avp(2055, 'c0', 10415, "00000001") #AoC_FULL (1) - - avp += self.generate_avp(455, 40, format(int(0),"x").zfill(8)) #Multiple Services Indicator (Not Supported) - if int(CC_Request_Type) == 1: - mscc = '' #Multiple Services Credit Control - mscc += self.generate_avp(437, 40, '') #Requested Service Unit - used_service_unit = '' - used_service_unit += 
self.generate_avp(420, 40, format(int(0),"x").zfill(8)) #Time - used_service_unit += self.generate_avp(412, 40, format(int(0),"x").zfill(16)) #Input Octets - used_service_unit += self.generate_avp(414, 40, format(int(0),"x").zfill(16)) #Output Octets - mscc += self.generate_avp(446, 40, used_service_unit) #Used Service Unit - mscc += self.generate_vendor_avp(1016, 'c0', 10415, #QoS Information - "00000404c0000010000028af000000090000040a8000003c000028af0000041680000010000028af000000090000041780000010000028af000000000000041880000010000028af000000000000041180000010000028af061a80000000041080000010000028af061a8000") - mscc += self.generate_vendor_avp(21, 'c0', 10415, '000028af') #3GPP RAT Type (WB-EUTRAN) - avp += self.generate_avp(456, 40, mscc) - - elif int(CC_Request_Type) == 2: - mscc = '' #Multiple Services Credit Control - mscc += self.generate_avp(437, 40, '') #Requested Service Unit - used_service_unit = '' - used_service_unit += self.generate_avp(420, 40, format(int(0),"x").zfill(8)) #Time - used_service_unit += self.generate_avp(412, 40, format(int(input_octets),"x").zfill(16)) #Input Octets - used_service_unit += self.generate_avp(414, 40, format(int(output_octets),"x").zfill(16)) #Output Octets - mscc += self.generate_avp(446, 40, used_service_unit) #Used Service Unit - mscc += self.generate_vendor_avp(872, 'c0', 10415, format(int(4),"x").zfill(8)) #3GPP Reporting Reason (Validity Time (4)) - mscc += self.generate_vendor_avp(1016, 'c0', 10415, #QoS Information - "00000404c0000010000028af000000090000040a8000003c000028af0000041680000010000028af000000090000041780000010000028af000000000000041880000010000028af000000000000041180000010000028af061a80000000041080000010000028af061a8000") - mscc += self.generate_vendor_avp(21, 'c0', 10415, '000028af') #3GPP RAT Type (WB-EUTRAN) - avp += self.generate_avp(456, 40, mscc) - elif int(CC_Request_Type) == 3: - #Multiple Services Credit Control - avp += self.generate_avp(456, 40, - "000001be40000034000001a44000000c000000000000019c4000001000000000000000000000019e40000010000000000000000000000368c0000010000028af00000002000003f8c0000078000028af00000404c0000010000028af000000090000040a8000003c000028af0000041680000010000028af000000020000041780000010000028af000000010000041880000010000028af000000000000041180000010000028af020000000000041080000010000028af0320000000000015c000000d000028af06000000") - - #Service Information - avp += self.generate_vendor_avp(873, 'c0', 10415, - "0000036ac00000d8000028af00000002c0000010000028af0000010400000003c0000010000028af00000000000004cbc0000012000028af00010a2d01050000000004ccc0000012000028af0001ac1212ca00000000034fc0000012000028af0001ac12120400000000001e40000010696e7465726e65740000000cc000000d000028af300000000000000dc0000010000028af3030303000000012c0000011000028af30303130310000000000000ac000000d000028af0100000000000016c0000019000028af8200f110000100f11000000017000000") - response = self.generate_diameter_packet("01", "c0", 272, 4, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - return response - - - #3GPP Sh - Profile Update Request - def Request_16777217_307(self, msisdn): - avp = '' - sessionid = 'nickpc.localdomain;' + self.generate_id(5) + ';1;app_sh' #Session state generate - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP - #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M- - VendorSpecificApplicationId = '' - VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - VendorSpecificApplicationId += 
self.generate_avp(258, 40, format(int(16777217),"x").zfill(8)) #Auth-Application-ID Sh - avp += self.generate_avp(260, 40, VendorSpecificApplicationId) - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained) - avp += self.generate_avp(264, 40, self.string_to_hex('ExamplePGW.com')) #Origin Host - avp += self.generate_avp(283, 40, self.OriginRealm) #Destination Realm - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - - DiameterLogger.debug("Getting subscriber IMS info based on MSISDN") - subscriber_ims_details = database.Get_IMS_Subscriber(msisdn=msisdn) - DiameterLogger.debug("Got subscriber IMS details: " + str(subscriber_ims_details)) - DiameterLogger.debug("Getting subscriber info based on MSISDN") - subscriber_details = database.Get_Subscriber(msisdn=msisdn) - DiameterLogger.debug("Got subscriber details: " + str(subscriber_details)) - subscriber_details = {**subscriber_details, **subscriber_ims_details} - DiameterLogger.debug("Merged subscriber details: " + str(subscriber_details)) - - avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(subscriber_details['imsi'])),'ascii')) #Username AVP - - - #Sh-User-Data (XML) - #This loads a Jinja XML template containing the Sh-User-Data - templateLoader = jinja2.FileSystemLoader(searchpath="./") - templateEnv = jinja2.Environment(loader=templateLoader) - sh_userdata_template = yaml_config['hss']['Default_Sh_UserData'] - DiameterLogger.info("Using template " + str(sh_userdata_template) + " for SH user data") - template = templateEnv.get_template(sh_userdata_template) - #These variables are passed to the template for use - subscriber_details['mnc'] = self.MNC.zfill(3) - subscriber_details['mcc'] = self.MCC.zfill(3) - - DiameterLogger.debug("Rendering template with values: " + str(subscriber_details)) - xmlbody = template.render(Sh_template_vars=subscriber_details) # this is where to put args to the template renderer - avp += self.generate_vendor_avp(702, "c0", 10415, str(binascii.hexlify(str.encode(xmlbody)),'ascii')) - - response = self.generate_diameter_packet("01", "c0", 307, 16777217, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - return response - - #3GPP S13 - ME-Identity-Check Request - def Request_16777252_324(self, imei, imsi): - avp = '' - sessionid = 'nickpc.localdomain;' + self.generate_id(5) + ';1;app_s13' #Generate Session-Id - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id AVP - #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M- - VendorSpecificApplicationId = '' - VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID - VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777252),"x").zfill(8)) #Auth-Application-ID S13 - avp += self.generate_avp(260, 40, VendorSpecificApplicationId) - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained) - avp += self.generate_avp(264, 40, self.string_to_hex('ExamplePGW.com')) #Origin Host - avp += self.generate_avp(283, 40, self.OriginRealm) #Destination Realm - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - - avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(imsi)),'ascii')) #Username AVP - TerminalInformation = '' - TerminalInformation += self.generate_vendor_avp(1402, 'c0', 10415, str(binascii.hexlify(str.encode(imei)),'ascii')) - TerminalInformation += self.generate_vendor_avp(1403, 'c0', 10415, 
str(binascii.hexlify(str.encode('00')),'ascii')) - avp += self.generate_vendor_avp(1401, 'c0', 10415, TerminalInformation) - - - response = self.generate_diameter_packet("01", "c0", 324, 16777252, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet - return response \ No newline at end of file diff --git a/hss.py b/hss.py deleted file mode 100644 index f32c53e..0000000 --- a/hss.py +++ /dev/null @@ -1,1012 +0,0 @@ -# PyHSS -# This serves as a basic 3GPP Home Subscriber Server implementing EIR & IMS HSS functionality -import logging -import yaml -import os -import sys -import socket -import socketserver -import binascii -import time -import _thread -import threading -import sctp -import traceback -import pprint -import diameter as DiameterLib -import systemd.daemon -from threading import Thread, Lock -from logtool import * -import contextlib -import queue - - -class ThreadJoiner: - def __init__(self, threads, thread_event): - self.threads = threads - self.thread_event = thread_event - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is not None: - self.thread_event.set() - for thread in self.threads: - while thread.is_alive(): - try: - thread.join(timeout=1) - except Exception as e: - print( - f"ThreadJoiner Exception: failed to join thread {thread}: {e}" - ) - break - - -class PyHSS: - def __init__(self): - # Load config from yaml file - try: - with open("config.yaml", "r") as config_stream: - self.yaml_config = yaml.safe_load(config_stream) - except: - print(f"config.yaml not found, exiting PyHSS.") - quit() - - # Setup logging - self.logtool = LogTool(HSS_Init=True) - self.logtool.setup_logger( - "HSS_Logger", - self.yaml_config["logging"]["logfiles"]["hss_logging_file"], - level=self.yaml_config["logging"]["level"], - ) - self.logger = logging.getLogger("HSS_Logger") - if self.yaml_config["logging"]["log_to_terminal"]: - logging.getLogger().addHandler(logging.StreamHandler()) - - # Setup Diameter - self.diameter_instance = DiameterLib.Diameter( - str(self.yaml_config["hss"].get("OriginHost", "")), - str(self.yaml_config["hss"].get("OriginRealm", "")), - str(self.yaml_config["hss"].get("ProductName", "")), - str(self.yaml_config["hss"].get("MNC", "")), - str(self.yaml_config["hss"].get("MCC", "")), - ) - - self.max_diameter_retries = int( - self.yaml_config["hss"].get("diameter_max_retries", 1) - ) - - - - try: - assert(self.yaml_config['prometheus']['enabled'] == True) - assert(self.yaml_config['prometheus']['async_subscriber_count'] == True) - - self.logger.info("Enabling Prometheus Async Sub thread") - #Add Prometheus Async Calls - prom_async_thread = threading.Thread( - target=self.prom_async_function, - name=f"prom_async_function", - args=(), - ) - prom_async_thread.start() - except: - self.logger.info("Prometheus Async Sub Count thread disabled") - - - - def terminate_connection(self, clientsocket, client_address, thread_event): - thread_event.set() - clientsocket.close() - self.logtool.Manage_Diameter_Peer(client_address, client_address, "remove") - - def handle_new_connection(self, clientsocket, client_address): - # Create our threading event, accessible by sibling threads in this connection.
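That per-connection event is the hinge of the teardown logic in the deleted hss.py: every worker thread watches it, and the ThreadJoiner context manager defined above sets it when the with-block raises and then joins each worker before the socket is torn down. A minimal standalone sketch of the same pattern, reusing the ThreadJoiner class from the deleted code (the worker function here is hypothetical):

```python
import threading
import time

def connection_worker(stop_event: threading.Event):
    # Hypothetical worker: poll until the shared shutdown event is set.
    while not stop_event.is_set():
        time.sleep(0.1)

socket_close_event = threading.Event()
workers = [
    threading.Thread(target=connection_worker, args=(socket_close_event,))
    for _ in range(2)
]
for worker in workers:
    worker.start()

# ThreadJoiner.__exit__ sets the event if the block raised, then joins every
# worker, so nothing is left running when the connection is closed.
with ThreadJoiner(workers, socket_close_event):
    socket_close_event.set()  # hss.py instead blocks here on socket_close_event.wait()
```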
- socket_close_event = threading.Event() - try: - send_queue = queue.Queue() - self.logger.debug(f"New connection from {client_address}") - if ( - "client_socket_timeout" not in self.yaml_config["hss"] - or self.yaml_config["hss"]["client_socket_timeout"] == 0 - ): - self.yaml_config["hss"]["client_socket_timeout"] = 120 - clientsocket.settimeout( - self.yaml_config["hss"].get("client_socket_timeout", 120) - ) - - send_data_thread = threading.Thread( - target=self.send_data, - name=f"send_data_thread", - args=(clientsocket, send_queue, socket_close_event), - ) - self.logger.debug("handle_new_connection: Starting send_data thread") - send_data_thread.start() - - self.logtool.Manage_Diameter_Peer(client_address, client_address, "add") - manage_client_thread = threading.Thread( - target=self.manage_client, - name=f"manage_client_thread: client_address: {client_address}", - args=( - clientsocket, - client_address, - self.diameter_instance, - socket_close_event, - send_queue, - ), - ) - self.logger.debug("handle_new_connection: Starting manage_client thread") - manage_client_thread.start() - - threads_to_join = [manage_client_thread] - threads_to_join.append(send_data_thread) - - # If Redis is enabled, start manage_client_async and manage_client_dwr threads. - if self.yaml_config["redis"]["enabled"]: - if ( - "async_check_interval" not in self.yaml_config["hss"] - or self.yaml_config["hss"]["async_check_interval"] == 0 - ): - self.yaml_config["hss"]["async_check_interval"] = 10 - manage_client_async_thread = threading.Thread( - target=self.manage_client_async, - name=f"manage_client_async_thread: client_address: {client_address}", - args=( - clientsocket, - client_address, - self.diameter_instance, - socket_close_event, - send_queue, - ), - ) - self.logger.debug( - "handle_new_connection: Starting manage_client_async thread" - ) - manage_client_async_thread.start() - - manage_client_dwr_thread = threading.Thread( - target=self.manage_client_dwr, - name=f"manage_client_dwr_thread: client_address: {client_address}", - args=( - clientsocket, - client_address, - self.diameter_instance, - socket_close_event, - send_queue, - ), - ) - self.logger.debug( - "handle_new_connection: Starting manage_client_dwr thread" - ) - manage_client_dwr_thread.start() - - threads_to_join.append(manage_client_async_thread) - threads_to_join.append(manage_client_dwr_thread) - - self.logger.debug( - f"handle_new_connection: Total PyHSS Active Threads: {threading.active_count()}" - ) - for thread in threading.enumerate(): - if "dummy" not in thread.name.lower(): - self.logger.debug(f"Active Thread name: {thread.name}") - - with ThreadJoiner(threads_to_join, socket_close_event): - socket_close_event.wait() - self.terminate_connection( - clientsocket, client_address, socket_close_event - ) - self.logger.debug(f"Closing thread for client; {client_address}") - return - - except Exception as e: - self.logger.error(f"Exception for client {client_address}: {e}") - self.logger.error(f"Closing connection for {client_address}") - self.terminate_connection(clientsocket, client_address, socket_close_event) - return - - @prom_diam_response_time_diam.time() - def process_Diameter_request( - self, clientsocket, client_address, diameter, data, thread_event, send_queue - ): - packet_length = diameter.decode_diameter_packet_length( - data - ) # Calculate length of packet from start of packet - if packet_length <= 32: - self.logger.error("Received an invalid packet with length <= 32") - self.terminate_connection(clientsocket, 
client_address, thread_event) - return - - data_sum = data + clientsocket.recv( - packet_length - 32 - ) # Receive remainder of packet from buffer - packet_vars, avps = diameter.decode_diameter_packet( - data_sum - ) # Decode packet into array of AVPs and Dict of Packet Variables (packet_vars) - try: - packet_vars["Source_IP"] = client_address[0] - except: - self.logger.debug("Failed to add Source_IP to packet_vars") - - start_time = time.time() - origin_host = diameter.get_avp_data(avps, 264)[0] # Get OriginHost from AVP - origin_host = binascii.unhexlify(origin_host).decode("utf-8") # Format it - - # label_values = str(packet_vars['ApplicationId']), str(packet_vars['command_code']), origin_host, 'request' - prom_diam_request_count.labels( - str(packet_vars["ApplicationId"]), - str(packet_vars["command_code"]), - origin_host, - "request", - ).inc() - - - self.logger.info( - "\n\nNew request with Command Code: " - + str(packet_vars["command_code"]) - + ", ApplicationID: " - + str(packet_vars["ApplicationId"]) - + ", flags " - + str(packet_vars["flags"]) - + ", e2e ID: " - + str(packet_vars["end-to-end-identifier"]) - ) - - # Gobble up any Response traffic that is sent to us: - if packet_vars["flags_bin"][0:1] == "0": - self.logger.info("Got a Response, not a request - dropping it.") - self.logger.info(packet_vars) - return - - # Send Capabilities Exchange Answer (CEA) response to Capabilities Exchange Request (CER) - elif ( - packet_vars["command_code"] == 257 - and packet_vars["ApplicationId"] == 0 - and packet_vars["flags"] == "80" - ): - self.logger.info( - f"Received Request with command code 257 (CER) from {origin_host}" - + "\n\tSending response (CEA)" - ) - try: - response = diameter.Answer_257( - packet_vars, avps, str(self.yaml_config["hss"]["bind_ip"][0]) - ) # Generate Diameter packet - # prom_diam_response_count_successful.inc() - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - # prom_diam_response_count_fail.inc() - self.logger.info("Generated CEA") - self.logtool.Manage_Diameter_Peer(origin_host, client_address, "update") - prom_diam_connected_peers.labels(origin_host).set(1) - - # Send Credit Control Answer (CCA) response to Credit Control Request (CCR) - elif ( - packet_vars["command_code"] == 272 - and packet_vars["ApplicationId"] == 16777238 - ): - self.logger.info( - f"Received 3GPP Credit-Control-Request from {origin_host}" - + "\n\tGenerating (CCA)" - ) - try: - response = diameter.Answer_16777238_272( - packet_vars, avps - ) # Generate Diameter packet - except Exception as E: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.error(f"Failed to generate response {str(E)}") - self.logger.info("Generated CCA") - - # Send Device Watchdog Answer (DWA) response to Device Watchdog Requests (DWR) - elif ( - packet_vars["command_code"] == 280 - and packet_vars["ApplicationId"] == 0 - and packet_vars["flags"] == "80" - ): - self.logger.info( - f"Received Request with command code 280 (DWR) from {origin_host}" - + "\n\tSending response (DWA)" - ) - self.logger.debug(f"Total PyHSS Active Threads: {threading.active_count()}") - try: - response = diameter.Answer_280( - packet_vars, avps - ) # Generate Diameter packet - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.info("Generated DWA") - self.logtool.Manage_Diameter_Peer(origin_host, client_address, "update")
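The framing at the top of process_Diameter_request leans on the Diameter header itself: octet 0 carries the protocol version and octets 1-3 the total message length (RFC 6733), so after reading a fixed 32-byte prefix the service knows exactly how many further bytes to pull off the socket. A standalone sketch of that length decode (this is not PyHSS's actual decode_diameter_packet_length, just the same idea):

```python
def diameter_message_length(prefix: bytes) -> int:
    """Return the total Diameter message length from a received prefix.

    RFC 6733 header layout: 1 octet version (always 1), followed by a
    3-octet unsigned Message Length covering the entire message.
    """
    if len(prefix) < 4 or prefix[0] != 1:
        raise ValueError("not a Diameter version 1 header")
    return int.from_bytes(prefix[1:4], "big")

# Usage against a connected socket (hypothetical `sock`):
# head = sock.recv(32)
# body = sock.recv(diameter_message_length(head) - 32)
# packet_vars, avps = diameter.decode_diameter_packet(head + body)
```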
self.logger.info("Generated DWA") - self.logtool.Manage_Diameter_Peer(origin_host, client_address, "update") - - # Send Disconnect Peer Answer (DPA) response to Disconnect Peer Request (DPR) - elif ( - packet_vars["command_code"] == 282 - and packet_vars["ApplicationId"] == 0 - and packet_vars["flags"] == "80" - ): - self.logger.info( - f"Received Request with command code 282 (DPR) from {origin_host}" - + "\n\tForwarding request..." - ) - response = diameter.Answer_282( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated DPA") - self.logtool.Manage_Diameter_Peer(origin_host, client_address, "remove") - prom_diam_connected_peers.labels(origin_host).set(0) - - # S6a Authentication Information Answer (AIA) response to Authentication Information Request (AIR) - elif ( - packet_vars["command_code"] == 318 - and packet_vars["ApplicationId"] == 16777251 - and packet_vars["flags"] == "c0" - ): - self.logger.info( - f"Received Request with command code 318 (3GPP Authentication-Information-Request) from {origin_host}" - + "\n\tGenerating (AIA)" - ) - try: - response = diameter.Answer_16777251_318( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated AIR") - except Exception as e: - self.logger.info("Failed to generate Diameter Response for AIR") - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated DIAMETER_USER_DATA_NOT_AVAILABLE AIR") - - # S6a Update Location Answer (ULA) response to Update Location Request (ULR) - elif ( - packet_vars["command_code"] == 316 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received Request with command code 316 (3GPP Update Location-Request) from {origin_host}" - + "\n\tGenerating (ULA)" - ) - try: - response = diameter.Answer_16777251_316( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated ULA") - except Exception as e: - self.logger.info("Failed to generate Diameter Response for ULR") - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated error DIAMETER_USER_DATA_NOT_AVAILABLE ULA") - - # Send ULA data & clear tx buffer - clientsocket.sendall(bytes.fromhex(response)) - response = "" - if "Insert_Subscriber_Data_Force" in yaml_config["hss"]: - if yaml_config["hss"]["Insert_Subscriber_Data_Force"] == True: - self.logger.debug("ISD triggered after ULA") - # Generate Insert Subscriber Data Request - response = diameter.Request_16777251_319( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated IDR") - # Send ISD data - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) - self.logger.info("Sent IDR") - return - # S6a inbound Insert-Data-Answer in response to our IDR - elif ( - packet_vars["command_code"] == 319 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received response with command code 319 (3GPP Insert-Subscriber-Answer) from {origin_host}" - ) - return - # S6a Purge UE Answer (PUA) response to Purge UE Request (PUR) - elif ( - packet_vars["command_code"] == 321 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received Request with command code 321 (3GPP Purge UE Request) from {origin_host}" - + "\n\tGenerating (PUA)" - ) - try: - response = diameter.Answer_16777251_321( - 
packet_vars, avps - ) # Generate Diameter packet - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.error("Failed to generate PUA") - self.logger.info("Generated PUA") - # S6a Notify Answer (NOA) response to Notify Request (NOR) - elif ( - packet_vars["command_code"] == 323 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received Request with command code 323 (3GPP Notify Request) from {origin_host}" - + "\n\tGenerating (NOA)" - ) - try: - response = diameter.Answer_16777251_323( - packet_vars, avps - ) # Generate Diameter packet - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.error("Failed to generate NOA") - self.logger.info("Generated NOA") - # S6a Cancel Location Answer eater - elif ( - packet_vars["command_code"] == 317 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info("Received Response with command code 317 (3GPP Cancel Location Request) from " + str(origin_host)) - - # Cx Authentication Answer - elif ( - packet_vars["command_code"] == 300 - and packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 300 (3GPP Cx User Authentication Request) from {origin_host}" - + "\n\tGenerating (UAA)" - ) - try: - response = diameter.Answer_16777216_300( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Auth Answer" - ) - self.logger.info(e) - self.logger.info(traceback.print_exc()) - self.logger.info( - type(e).__name__, # TypeError - __file__, # /tmp/example.py - e.__traceback__.tb_lineno # 2 - ) - - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Auth Answer") - - # Cx Server Assignment Answer - elif ( - packet_vars["command_code"] == 301 - and packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 301 (3GPP Cx Server Assignment Request) from {origin_host}" - + "\n\tGenerating (SAA)" - ) - try: - response = diameter.Answer_16777216_301( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Server Assignment Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Server Assignment Answer") - - # Cx Location Information Answer - elif ( - packet_vars["command_code"] == 302 - and packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 302 (3GPP Cx Location Information Request) from {origin_host}" - + "\n\tGenerating (LIA)" - ) - try: - response = diameter.Answer_16777216_302( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Location Information Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Location Information Answer") - - # Cx Multimedia Authentication Answer - elif ( - packet_vars["command_code"] == 303 - and 
packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 303 (3GPP Cx Multimedia Authentication Request) from {origin_host}" - + "\n\tGenerating (MAA)" - ) - try: - response = diameter.Answer_16777216_303( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Multimedia Authentication Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Multimedia Authentication Answer") - - # Sh User-Data-Answer - elif ( - packet_vars["command_code"] == 306 - and packet_vars["ApplicationId"] == 16777217 - ): - self.logger.info( - f"Received Request with command code 306 (3GPP Sh User-Data Request) from {origin_host}" - ) - try: - response = diameter.Answer_16777217_306( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Sh User-Data Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 5001 - ) # DIAMETER_ERROR_USER_UNKNOWN - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) - self.logger.info("Sent negative response") - return - self.logger.info("Generated Sh User-Data Answer") - - # Sh Profile-Update-Answer - elif ( - packet_vars["command_code"] == 307 - and packet_vars["ApplicationId"] == 16777217 - ): - self.logger.info( - f"Received Request with command code 307 (3GPP Sh Profile-Update Request) from {origin_host}" - ) - try: - response = diameter.Answer_16777217_307( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Sh User-Data Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 5001 - ) # DIAMETER_ERROR_USER_UNKNOWN - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) - self.logger.info("Sent negative response") - return - self.logger.info("Generated Sh Profile-Update Answer") - - # S13 ME-Identity-Check Answer - elif ( - packet_vars["command_code"] == 324 - and packet_vars["ApplicationId"] == 16777252 - ): - self.logger.info( - f"Received Request with command code 324 (3GPP S13 ME-Identity-Check Request) from {origin_host}" - + "\n\tGenerating (MICA)" - ) - try: - response = diameter.Answer_16777252_324( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for S13 ME-Identity Check Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated S13 ME-Identity Check Answer") - - # SLh LCS-Routing-Info-Answer - elif ( - packet_vars["command_code"] == 8388622 - and packet_vars["ApplicationId"] == 16777291 - ): - self.logger.info( - f"Received Request with command code 324 (3GPP SLh LCS-Routing-Info-Answer Request) from {origin_host}" - + "\n\tGenerating (MICA)" - ) - try: - response = diameter.Answer_16777291_8388622( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for SLh LCS-Routing-Info-Answer" - ) - self.logger.info(e) - 
- # Handle Responses generated by the Async functions - elif packet_vars["flags"] == "00": - self.logger.info( - "Got response back with command code " - + str(packet_vars["command_code"]) - ) - self.logger.info("response packet_vars: " + str(packet_vars)) - self.logger.info("response avps: " + str(avps)) - response = "" - else: - self.logger.error( - "\n\nReceived unrecognised request with Command Code: " - + str(packet_vars["command_code"]) - + ", ApplicationID: " - + str(packet_vars["ApplicationId"]) - + " and flags " - + str(packet_vars["flags"]) - ) - for keys in packet_vars: - self.logger.error(keys) - self.logger.error("\t" + str(packet_vars[keys])) - self.logger.error(avps) - self.logger.error("Sending negative response") - response = diameter.Respond_ResultCode( - packet_vars, avps, 3001 - ) # Generate Diameter response with "Command Unsupported" (3001) - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) # Send it - - prom_diam_response_time_method.labels( - str(packet_vars["ApplicationId"]), - str(packet_vars["command_code"]), - origin_host, - "request", - ).observe(time.time() - start_time) - - # Diameter Transmission - retries = 0 - while retries < self.max_diameter_retries: - try: - send_queue.put(bytes.fromhex(response)) - break - except socket.error as e: - self.logger.error(f"Socket error for client {client_address}: {e}") - retries += 1 - if retries > self.max_diameter_retries: - self.logger.error( - f"Max retries reached for client {client_address}. Closing connection." - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - break - time.sleep(1) # Wait for 1 second before retrying - except Exception as e: - self.logger.info("Failed to send Diameter Response") - self.logger.debug(f"Diameter Response Body: {str(response)}") - self.logger.info(e) - traceback.print_exc() - self.terminate_connection(clientsocket, client_address, thread_event) - self.logger.info("Thread terminated to " + str(client_address)) - break - - def manage_client( - self, clientsocket, client_address, diameter, thread_event, send_queue - ): - while True: - try: - data = clientsocket.recv(32) - if not data: - self.logger.info( - f"manage_client: Connection closed by {str(client_address)}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - self.process_Diameter_request( - clientsocket, - client_address, - diameter, - data, - thread_event, - send_queue, - ) - - except socket.timeout: - self.logger.warning( - f"manage_client: Socket timeout for client: {client_address}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - except socket.error as e: - self.logger.error( - f"manage_client: Socket error for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - except KeyboardInterrupt: - # Clean up the connection on keyboard interrupt - response = ( - diameter.Request_282() - ) # Generate Disconnect Peer Request Diameter packet - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) # Send it - self.terminate_connection(clientsocket, client_address, thread_event) - self.logger.info( - "manage_client: Connection closed nicely due to keyboard interrupt" - ) - sys.exit() - - except 
Exception as manage_client_exception: - self.logger.error( - f"manage_client: Exception in manage_client: {manage_client_exception}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - def manage_client_async( - self, clientsocket, client_address, diameter, thread_event, send_queue - ): - # # Sleep for 10 seconds to wait for the connection to come up - time.sleep(10) - self.logger.debug("manage_client_async: Getting ActivePeerDict") - self.logger.debug( - f"manage_client_async: Total PyHSS Active Threads: {threading.active_count()}" - ) - ActivePeerDict = self.logtool.GetDiameterPeers() - self.logger.debug( - f"manage_client_async: Got Active Peer dict in Async Thread: {str(ActivePeerDict)}" - ) - if client_address[0] in ActivePeerDict: - self.logger.debug( - "manage_client_async: This is host: " - + str(ActivePeerDict[str(client_address[0])]["DiameterHostname"]) - ) - DiameterHostname = str( - ActivePeerDict[str(client_address[0])]["DiameterHostname"] - ) - else: - self.logger.debug("manage_client_async: No matching Diameter Host found.") - return - - while True: - try: - if thread_event.is_set(): - self.logger.debug( - f"manage_client_async: Closing manage_client_async thread for client: {client_address}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - time.sleep(self.yaml_config["hss"]["async_check_interval"]) - self.logger.debug( - f"manage_client_async: Sleep interval expired for Diameter Peer {str(DiameterHostname)}" - ) - if int(self.yaml_config["hss"]["async_check_interval"]) == 0: - self.logger.error( - f"manage_client_async: No async_check_interval Timer set - Not checking Async Queue for host connection {str(DiameterHostname)}" - ) - return - try: - self.logger.debug( - "manage_client_async: Reading from request queue '" - + str(DiameterHostname) - + "_request_queue'" - ) - data_to_send = self.logtool.RedisHMGET( - str(DiameterHostname) + "_request_queue" - ) - for key in data_to_send: - data = data_to_send[key].decode("utf-8") - send_queue.put(bytes.fromhex(data)) - self.logtool.RedisHDEL( - str(DiameterHostname) + "_request_queue", key - ) - except Exception as redis_exception: - self.logger.error( - f"manage_client_async: Redis exception in manage_client_async: {redis_exception}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - - except socket.timeout: - self.logger.warning( - f"manage_client_async: Socket timeout for client: {client_address}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - except socket.error as e: - self.logger.error( - f"manage_client_async: Socket error for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - except Exception: - self.logger.error( - f"manage_client_async: Terminating for host connection {str(DiameterHostname)}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - def manage_client_dwr( - self, clientsocket, client_address, diameter, thread_event, send_queue - ): - while True: - try: - if thread_event.is_set(): - self.logger.debug( - f"Closing manage_client_dwr thread for client: {client_address}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - if ( - int(self.yaml_config["hss"]["device_watchdog_request_interval"]) - != 0 - ): - time.sleep( - self.yaml_config["hss"]["device_watchdog_request_interval"] - ) - else: - 
self.logger.info("DWR Timer to set to 0 - Not sending DWRs") - return - - except: - self.logger.error( - "No DWR Timer set - Not sending Device Watchdog Requests" - ) - return - try: - self.logger.debug("Sending Keepalive to " + str(client_address) + "...") - request = diameter.Request_280() - send_queue.put(bytes.fromhex(request)) - # clientsocket.sendall(bytes.fromhex(request)) # Send it - self.logger.debug("Sent Keepalive to " + str(client_address) + "...") - except socket.error as e: - self.logger.error( - f"manage_client_dwr: Socket error for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - except Exception as e: - self.logger.error( - f"manage_client_dwr: General exception for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - - def get_socket_family(self): - if ":" in self.yaml_config["hss"]["bind_ip"][0]: - self.logger.info("IPv6 Address Specified") - return socket.AF_INET6 - else: - self.logger.info("IPv4 Address Specified") - return socket.AF_INET - - def send_data(self, clientsocket, send_queue, thread_event): - while not thread_event.is_set(): - try: - data = send_queue.get(timeout=1) - # Check if data is bytes, otherwise convert it using bytes.fromhex() - if not isinstance(data, bytes): - data = bytes.fromhex(data) - - clientsocket.sendall(data) - except ( - queue.Empty - ): # Catch the Empty exception when the queue is empty and the timeout has expired - continue - except Exception as e: - self.logger.error(f"send_data_thread: Exception: {e}") - return - - def start_server(self): - if self.yaml_config["hss"]["transport"] == "SCTP": - self.logger.debug("Using SCTP for Transport") - # Create a SCTP socket - sock = sctp.sctpsocket_tcp(self.get_socket_family()) - sock.initparams.num_ostreams = 64 - # Loop through the possible Binding IPs from the config and bind to each for Multihoming - server_addresses = [] - - # Prepend each entry into list, so the primary IP is bound first - for host in self.yaml_config["hss"]["bind_ip"]: - self.logger.info("Seting up SCTP binding on IP address " + str(host)) - this_IP_binding = [ - (str(host), int(self.yaml_config["hss"]["bind_port"])) - ] - server_addresses = this_IP_binding + server_addresses - - print("server_addresses are: " + str(server_addresses)) - sock.bindx(server_addresses) - self.logger.info("PyHSS listening on SCTP port " + str(server_addresses)) - systemd.daemon.notify("READY=1") - # Listen for up to 20 incoming SCTP connections - sock.listen(20) - elif self.yaml_config["hss"]["transport"] == "TCP": - self.logger.debug("Using TCP socket") - # Create a TCP/IP socket - sock = socket.socket(self.get_socket_family(), socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # Bind the socket to the port - server_address = ( - str(self.yaml_config["hss"]["bind_ip"][0]), - int(self.yaml_config["hss"]["bind_port"]), - ) - sock.bind(server_address) - self.logger.debug( - "PyHSS listening on TCP port " - + str(self.yaml_config["hss"]["bind_ip"][0]) - ) - systemd.daemon.notify("READY=1") - # Listen for up to 20 incoming TCP connections - sock.listen(20) - else: - self.logger.error("No valid transports found (No SCTP or TCP) - Exiting") - quit() - - while True: - # Wait for a connection - self.logger.info("Waiting for a connection...") - connection, client_address = sock.accept() - _thread.start_new_thread( - self.handle_new_connection, - ( - connection, - client_address, - ), - ) - - - def 
prom_async_function(self): - while True: - self.logger.debug("Running prom_async_function") - self.diameter_instance.Generate_Prom_Stats() - time.sleep(120) - - -if __name__ == "__main__": - pyHss = PyHSS() - pyHss.start_server() diff --git a/lib/S6a_crypt.py b/lib/S6a_crypt.py index 0a489b3..c1ab38f 100755 --- a/lib/S6a_crypt.py +++ b/lib/S6a_crypt.py @@ -2,17 +2,16 @@ import binascii import base64 import logging -import logtool import os import sys sys.path.append(os.path.realpath('../')) import yaml -with open("config.yaml", 'r') as stream: +with open("../config.yaml", 'r') as stream: yaml_config = (yaml.safe_load(stream)) -logtool = logtool.LogTool() -logtool.setup_logger('CryptoLogger', yaml_config['logging']['logfiles']['database_logging_file'], level=yaml_config['logging']['level']) +# logtool = logtool.LogTool() +# logtool.setup_logger('CryptoLogger', yaml_config['logging']['logfiles']['database_logging_file'], level=yaml_config['logging']['level']) CryptoLogger = logging.getLogger('CryptoLogger') CryptoLogger.info("Initialised Diameter Logger, importing database") diff --git a/lib/banners.py b/lib/banners.py new file mode 100644 index 0000000..933a361 --- /dev/null +++ b/lib/banners.py @@ -0,0 +1,92 @@ +class Banners: + + def diameterService(self) -> str: + bannerText = """ + + ###### ## ## ##### ##### + ## ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## + ###### ## ## ####### ##### ##### + ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## ## + ## ##### ## ## ##### ##### + ## + #### + + Diameter Service + +""" + return bannerText + + + def hssService(self) -> str: + bannerText = """ + + ###### ## ## ##### ##### + ## ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## + ###### ## ## ####### ##### ##### + ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## ## + ## ##### ## ## ##### ##### + ## + #### + + HSS Service + +""" + return bannerText + + def georedService(self) -> str: + bannerText = """ + + ###### ## ## ##### ##### + ## ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## + ###### ## ## ####### ##### ##### + ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## ## + ## ##### ## ## ##### ##### + ## + #### + + Geographic Redundancy Service + +""" + return bannerText + + def metricService(self) -> str: + bannerText = """ + + ###### ## ## ##### ##### + ## ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## + ###### ## ## ####### ##### ##### + ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## ## + ## ##### ## ## ##### ##### + ## + #### + + Metric Service + +""" + return bannerText + + def logService(self) -> str: + bannerText = """ + + ###### ## ## ##### ##### + ## ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## + ###### ## ## ####### ##### ##### + ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## ## + ## ##### ## ## ##### ##### + ## + #### + + Log Service + +""" + return bannerText \ No newline at end of file diff --git a/lib/database.py b/lib/database.py new file mode 100755 index 0000000..7def195 --- /dev/null +++ b/lib/database.py @@ -0,0 +1,2502 @@ +from sqlalchemy import Column, Integer, String, MetaData, Table, Boolean, ForeignKey, select, UniqueConstraint, DateTime, BigInteger, Text, DateTime, Float +from sqlalchemy import create_engine +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.sql import desc, func +from sqlalchemy_utils import database_exists, create_database +from sqlalchemy.orm import sessionmaker, relationship, Session, class_mapper +from sqlalchemy.orm.attributes import History, get_history +from sqlalchemy.ext.declarative import declarative_base +import os +import datetime, 
time +from datetime import timezone +import re +import binascii +import uuid +import socket +import pprint +import S6a_crypt +from messaging import RedisMessaging +import yaml +import json +import traceback + +Base = declarative_base() + +class OPERATION_LOG_BASE(Base): + __tablename__ = 'operation_log' + id = Column(Integer, primary_key=True) + item_id = Column(Integer, nullable=False) + operation_id = Column(String(36), nullable=False) + operation = Column(String(10)) + changes = Column(Text) + last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc)) + timestamp = Column(DateTime, default=func.now()) + table_name = Column('table_name', String(255)) + __mapper_args__ = {'polymorphic_on': table_name} + +class APN_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'apn'} + apn = relationship("APN", back_populates="operation_logs") + apn_id = Column(Integer, ForeignKey('apn.apn_id')) + +class SUBSCRIBER_ROUTING_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'subscriber_routing'} + subscriber_routing = relationship("SUBSCRIBER_ROUTING", back_populates="operation_logs") + subscriber_routing_id = Column(Integer, ForeignKey('subscriber_routing.subscriber_routing_id')) + +class SERVING_APN_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'serving_apn'} + serving_apn = relationship("SERVING_APN", back_populates="operation_logs") + serving_apn_id = Column(Integer, ForeignKey('serving_apn.serving_apn_id')) + +class AUC_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'auc'} + auc = relationship("AUC", back_populates="operation_logs") + auc_id = Column(Integer, ForeignKey('auc.auc_id')) + +class SUBSCRIBER_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'subscriber'} + subscriber = relationship("SUBSCRIBER", back_populates="operation_logs") + subscriber_id = Column(Integer, ForeignKey('subscriber.subscriber_id')) + +class IMS_SUBSCRIBER_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'ims_subscriber'} + ims_subscriber = relationship("IMS_SUBSCRIBER", back_populates="operation_logs") + ims_subscriber_id = Column(Integer, ForeignKey('ims_subscriber.ims_subscriber_id')) + +class CHARGING_RULE_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'charging_rule'} + charging_rule = relationship("CHARGING_RULE", back_populates="operation_logs") + charging_rule_id = Column(Integer, ForeignKey('charging_rule.charging_rule_id')) + +class TFT_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'tft'} + tft = relationship("TFT", back_populates="operation_logs") + tft_id = Column(Integer, ForeignKey('tft.tft_id')) + +class EIR_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'eir'} + eir = relationship("EIR", back_populates="operation_logs") + eir_id = Column(Integer, ForeignKey('eir.eir_id')) + +class IMSI_IMEI_HISTORY_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'eir_history'} + eir_history = relationship("IMSI_IMEI_HISTORY", back_populates="operation_logs") + imsi_imei_history_id = Column(Integer, ForeignKey('eir_history.imsi_imei_history_id')) + +class SUBSCRIBER_ATTRIBUTES_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'subscriber_attributes'} + subscriber_attributes = relationship("SUBSCRIBER_ATTRIBUTES", back_populates="operation_logs") + 
subscriber_attributes_id = Column(Integer, ForeignKey('subscriber_attributes.subscriber_attributes_id')) + +class APN(Base): + __tablename__ = 'apn' + apn_id = Column(Integer, primary_key=True, doc='Unique ID of APN') + apn = Column(String(50), nullable=False, doc='Short name of the APN') + ip_version = Column(Integer, default=0, doc="IP version used - 0: ipv4, 1: ipv6 2: ipv4+6 3: ipv4 or ipv6 [3GPP TS 29.272 7.3.62]") + pgw_address = Column(String(50), doc='IP of the PGW') + sgw_address = Column(String(50), doc='IP of the SGW') + charging_characteristics = Column(String(4), default='0800', doc='For the encoding of this information element see 3GPP TS 32.298 [9]') + apn_ambr_dl = Column(Integer, nullable=False, doc='Downlink Maximum Bit Rate for this APN') + apn_ambr_ul = Column(Integer, nullable=False, doc='Uplink Maximum Bit Rate for this APN') + qci = Column(Integer, default=9, doc='QoS Class Identifier') + arp_priority = Column(Integer, default=4, doc='Allocation and Retention Policy - Bearer priority level (1-15)') + arp_preemption_capability = Column(Boolean, default=False, doc='Allocation and Retention Policy - Capability to Preempt resources from other Subscribers') + arp_preemption_vulnerability = Column(Boolean, default=True, doc='Allocation and Retention Policy - Vulnerability to have resources Preempted by other Subscribers') + charging_rule_list = Column(String(18), doc='Comma separated list of predefined ChargingRules to be installed in CCA-I') + last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') + operation_logs = relationship("APN_OPERATION_LOG", back_populates="apn") + +class SUBSCRIBER_ROUTING(Base): + __tablename__ = 'subscriber_routing' + __table_args__ = ( + # this can be db.PrimaryKeyConstraint if you want it to be a primary key + UniqueConstraint('subscriber_id', 'apn_id'), + ) + subscriber_routing_id = Column(Integer, primary_key=True, doc='Unique ID of Subscriber Routing item') + subscriber_id = Column(Integer, ForeignKey('subscriber.subscriber_id', ondelete='CASCADE'), doc='subscriber_id of the served subscriber') + apn_id = Column(Integer, ForeignKey('apn.apn_id', ondelete='CASCADE'), doc='apn_id of the target apn') + ip_version = Column(Integer, default=0, doc="IP version used - 0: ipv4, 1: ipv6 2: ipv4+6 3: ipv4 or ipv6 [3GPP TS 29.272 7.3.62]") + ip_address = Column(String(254), doc='IP of the UE') + last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') + operation_logs = relationship("SUBSCRIBER_ROUTING_OPERATION_LOG", back_populates="subscriber_routing") + +class AUC(Base): + __tablename__ = 'auc' + auc_id = Column(Integer, primary_key = True, doc='Unique ID of AuC entry') + ki = Column(String(32), doc='SIM Key - Authentication Key - Ki', nullable=False) + opc = Column(String(32), doc='SIM Key - Network Operators key OPc', nullable=False) + amf = Column(String(4), doc='Authentication Management Field', nullable=False) + sqn = Column(BigInteger, doc='Authentication sequence number') + iccid = Column(String(20), unique=True, doc='Integrated Circuit Card Identification Number') + imsi = Column(String(18), unique=True, doc='International Mobile Subscriber Identity') + batch_name = Column(String(20), doc='Name of SIM Batch') + sim_vendor = Column(String(20), doc='SIM Vendor') + esim = Column(Boolean, default=0, doc='Card is eSIM') + lpa = Column(String(128), doc='LPA URL for activating eSIM') + pin1 = Column(String(20), 
doc='PIN1') + pin2 = Column(String(20), doc='PIN2') + puk1 = Column(String(20), doc='PUK1') + puk2 = Column(String(20), doc='PUK2') + kid = Column(String(20), doc='KID') + psk = Column(String(128), doc='PSK') + des = Column(String(128), doc='DES') + adm1 = Column(String(20), doc='ADM1') + misc1 = Column(String(128), doc='For misc data storage 1') + misc2 = Column(String(128), doc='For misc data storage 2') + misc3 = Column(String(128), doc='For misc data storage 3') + misc4 = Column(String(128), doc='For misc data storage 4') + last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') + operation_logs = relationship("AUC_OPERATION_LOG", back_populates="auc") + +class SUBSCRIBER(Base): + __tablename__ = 'subscriber' + subscriber_id = Column(Integer, primary_key = True, doc='Unique ID of Subscriber entry') + imsi = Column(String(18), unique=True, doc='International Mobile Subscriber Identity') + enabled = Column(Boolean, default=1, doc='Subscriber enabled/disabled') + auc_id = Column(Integer, ForeignKey('auc.auc_id'), doc='Reference to AuC ID defined with SIM Auth data', nullable=False) + default_apn = Column(Integer, ForeignKey('apn.apn_id'), doc='APN ID to use for the default APN', nullable=False) + apn_list = Column(String(64), doc='Comma separated list of allowed APNs', nullable=False) + msisdn = Column(String(18), doc='Primary Phone number of Subscriber') + ue_ambr_dl = Column(Integer, default=999999, doc='Downlink Aggregate Maximum Bit Rate') + ue_ambr_ul = Column(Integer, default=999999, doc='Uplink Aggregate Maximum Bit Rate') + nam = Column(Integer, default=0, doc='Network Access Mode [3GPP TS. 123 008 2.1.1.2] - 0 (PACKET_AND_CIRCUIT) or 2 (ONLY_PACKET)') + subscribed_rau_tau_timer = Column(Integer, default=300, doc='Subscribed periodic TAU/RAU timer value in seconds') + serving_mme = Column(String(512), doc='MME serving this subscriber') + serving_mme_timestamp = Column(DateTime, doc='Timestamp of attach to MME') + serving_mme_realm = Column(String(512), doc='Realm of serving mme') + serving_mme_peer = Column(String(512), doc='Diameter peer used to reach MME then ; then the HSS the Diameter peer is connected to') + last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') + operation_logs = relationship("SUBSCRIBER_OPERATION_LOG", back_populates="subscriber") + +class SERVING_APN(Base): + __tablename__ = 'serving_apn' + serving_apn_id = Column(Integer, primary_key=True, doc='Unique ID of SERVING_APN') + subscriber_id = Column(Integer, ForeignKey('subscriber.subscriber_id', ondelete='CASCADE'), doc='subscriber_id of the served subscriber') + apn = Column(Integer, ForeignKey('apn.apn_id', ondelete='CASCADE'), doc='apn_id of the APN served') + pcrf_session_id = Column(String(100), doc='Session ID from the PCRF') + subscriber_routing = Column(String(100), doc='IP Address allocated to the UE') + ip_version = Column(Integer, default=0, doc=APN.ip_version.doc) + serving_pgw = Column(String(512), doc='PGW serving this subscriber') + serving_pgw_timestamp = Column(DateTime, doc='Timestamp of attach to PGW') + serving_pgw_realm = Column(String(512), doc='Realm of serving PGW') + serving_pgw_peer = Column(String(512), doc='Diameter peer used to reach PGW') + last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') + operation_logs = relationship("SERVING_APN_OPERATION_LOG", 
back_populates="serving_apn")
+
+class IMS_SUBSCRIBER(Base):
+    __tablename__ = 'ims_subscriber'
+    ims_subscriber_id = Column(Integer, primary_key = True, doc='Unique ID of IMS_Subscriber entry')
+    msisdn = Column(String(18), unique=True, doc=SUBSCRIBER.msisdn.doc)
+    msisdn_list = Column(String(1200), doc='Comma Separated list of additional MSISDNs for Subscriber')
+    imsi = Column(String(18), unique=False, doc=SUBSCRIBER.imsi.doc)
+    ifc_path = Column(String(18), doc='Path to template file for the Initial Filter Criteria')
+    pcscf = Column(String(512), doc='Proxy-CSCF serving this subscriber')
+    pcscf_realm = Column(String(512), doc='Realm of PCSCF')
+    pcscf_active_session = Column(String(512), doc='Session Id for the PCSCF when in a call')
+    pcscf_timestamp = Column(DateTime, doc='Timestamp of last UE attach to PCSCF')
+    pcscf_peer = Column(String(512), doc='Diameter peer used to reach PCSCF')
+    sh_profile = Column(Text(12000), doc='Sh Subscriber Profile')
+    scscf = Column(String(512), doc='Serving-CSCF serving this subscriber')
+    scscf_timestamp = Column(DateTime, doc='Timestamp of last UE attach to SCSCF')
+    scscf_realm = Column(String(512), doc='Realm of SCSCF')
+    scscf_peer = Column(String(512), doc='Diameter peer used to reach SCSCF')
+    last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification')
+    operation_logs = relationship("IMS_SUBSCRIBER_OPERATION_LOG", back_populates="ims_subscriber")
+
+class CHARGING_RULE(Base):
+    __tablename__ = 'charging_rule'
+    charging_rule_id = Column(Integer, primary_key = True, doc='Unique ID of CHARGING_RULE entry')
+    rule_name = Column(String(20), doc='Name of rule pushed to PGW (Short, no special chars)')
+
+    qci = Column(Integer, default=9, doc=APN.qci.doc)
+    arp_priority = Column(Integer, default=4, doc=APN.arp_priority.doc)
+    arp_preemption_capability = Column(Boolean, default=False, doc=APN.arp_preemption_capability.doc)
+    arp_preemption_vulnerability = Column(Boolean, default=True, doc=APN.arp_preemption_vulnerability.doc)
+
+    mbr_dl = Column(Integer, nullable=False, doc='Maximum Downlink Bitrate for traffic matching this rule')
+    mbr_ul = Column(Integer, nullable=False, doc='Maximum Uplink Bitrate for traffic matching this rule')
+    gbr_dl = Column(Integer, nullable=False, doc='Guaranteed Downlink Bitrate for traffic matching this rule')
+    gbr_ul = Column(Integer, nullable=False, doc='Guaranteed Uplink Bitrate for traffic matching this rule')
+    tft_group_id = Column(Integer, doc='Will match any TFTs using this TFT Group to form the TFT list used in the Charging Rule')
+    precedence = Column(Integer, doc='Precedence of this rule, allows rule to override or be overridden by a higher priority rule')
+    rating_group = Column(Integer, doc='Rating Group in OCS / OFCS that traffic matching this rule will be charged under')
+    last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification')
+    operation_logs = relationship("CHARGING_RULE_OPERATION_LOG", back_populates="charging_rule")
+
+class TFT(Base):
+    __tablename__ = 'tft'
+    tft_id = Column(Integer, primary_key = True, doc='Unique ID of TFT entry')
+    tft_group_id = Column(Integer, nullable=False, doc=CHARGING_RULE.tft_group_id.doc)
+    tft_string = Column(String(100), nullable=False, doc='IPFilterRules as defined in [RFC 6733] taking the format: action dir proto from src to dst')
+    direction = Column(Integer, nullable=False, doc='Traffic Direction: 0 - Unspecified, 1 - Downlink, 2 - Uplink, 3 - Bidirectional')
+    last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification')
+    operation_logs = relationship("TFT_OPERATION_LOG", back_populates="tft")
+
+class EIR(Base):
+    __tablename__ = 'eir'
+    eir_id = Column(Integer, primary_key = True, doc='Unique ID of EIR entry')
+    imei = Column(String(60), doc='Exact IMEI or Regex to match IMEI (Depending on regex_mode value)')
+    imsi = Column(String(60), doc='Exact IMSI or Regex to match IMSI (Depending on regex_mode value)')
+    regex_mode = Column(Integer, default=1, doc='0 - Exact Match mode, 1 - Regex Mode')
+    match_response_code = Column(Integer, doc='0 - Whitelist, 1 - Blacklist, 2 - Greylist')
+    last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification')
+    operation_logs = relationship("EIR_OPERATION_LOG", back_populates="eir")
+
+class IMSI_IMEI_HISTORY(Base):
+    __tablename__ = 'eir_history'
+    imsi_imei_history_id = Column(Integer, primary_key = True, doc='Unique ID of IMSI_IMEI_HISTORY entry')
+    imsi_imei = Column(String(60), unique=True, doc='Combined IMSI + IMEI value')
+    match_response_code = Column(Integer, doc='Response code that was returned')
+    imsi_imei_timestamp = Column(DateTime, doc='Timestamp of last match')
+    last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification')
+    operation_logs = relationship("IMSI_IMEI_HISTORY_OPERATION_LOG", back_populates="eir_history")
+
+class SUBSCRIBER_ATTRIBUTES(Base):
+    __tablename__ = 'subscriber_attributes'
+    subscriber_attributes_id = Column(Integer, primary_key = True, doc='Unique ID of Attribute')
+    subscriber_id = Column(Integer, ForeignKey('subscriber.subscriber_id', ondelete='CASCADE'), doc='Reference to Subscriber ID defined within Subscriber Section', nullable=False)
+    key = Column(String(60), doc='Arbitrary key')
+    last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification')
+    value = Column(String(12000), doc='Arbitrary value')
+    operation_logs = relationship("SUBSCRIBER_ATTRIBUTES_OPERATION_LOG", back_populates="subscriber_attributes")
+
+class Database:
+
+    def __init__(self, logTool, redisMessaging=None):
+        with open("../config.yaml", 'r') as stream:
+            self.config = (yaml.safe_load(stream))
+
+        self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False)
+        self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock')
+        self.redisHost = self.config.get('redis', {}).get('host', 'localhost')
+        self.redisPort = self.config.get('redis', {}).get('port', 6379)
+
+        self.logTool = logTool
+        if redisMessaging:
+            self.redisMessaging = redisMessaging
+        else:
+            self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath)
+
+        db_string = 'mysql://' + str(self.config['database']['username']) + ':' + str(self.config['database']['password']) + '@' + str(self.config['database']['server']) + '/' + str(self.config['database']['database'] + "?autocommit=true")
+        self.engine = create_engine(
+            db_string,
+            echo = self.config['logging'].get('sqlalchemy_sql_echo', True),
+            pool_recycle=self.config['logging'].get('sqlalchemy_pool_recycle', 5),
+            pool_size=self.config['logging'].get('sqlalchemy_pool_size', 30),
+            max_overflow=self.config['logging'].get('sqlalchemy_max_overflow', 0))
+
+        # Create database if it does not exist.
+        if not database_exists(self.engine.url):
+            self.logTool.log(service='Database', level='debug', message="Creating database", redisClient=self.redisMessaging)
+            create_database(self.engine.url)
+            Base.metadata.create_all(self.engine)
+        else:
+            self.logTool.log(service='Database', level='debug', message="Database already created", redisClient=self.redisMessaging)
+
+        #Load IMEI TAC database into Redis if enabled
+        if ('tac_database_csv' in self.config['eir']):
+            self.load_IMEI_database_into_Redis()
+            self.tacData = json.loads(self.redisMessaging.getValue(key="tacDatabase"))
+        else:
+            self.logTool.log(service='Database', level='info', message="Not loading EIR IMEI TAC Database as Redis not enabled or TAC CSV Database not set in config", redisClient=self.redisMessaging)
+            self.tacData = {}
+
+        # Create individual tables if they do not exist.
+        inspector = Inspector.from_engine(self.engine)
+        for table_name in Base.metadata.tables.keys():
+            if table_name not in inspector.get_table_names():
+                self.logTool.log(service='Database', level='debug', message=f"Creating table {table_name}", redisClient=self.redisMessaging)
+                Base.metadata.tables[table_name].create(bind=self.engine)
+            else:
+                self.logTool.log(service='Database', level='debug', message=f"Table {table_name} already exists", redisClient=self.redisMessaging)
+
+
+    def load_IMEI_database_into_Redis(self):
+        try:
+            self.logTool.log(service='Database', level='info', message="Reading IMEI TAC database CSV from " + str(self.config['eir']['tac_database_csv']), redisClient=self.redisMessaging)
+            csvfile = open(str(self.config['eir']['tac_database_csv']))
+            self.logTool.log(service='Database', level='info', message="This may take a few seconds to buffer into Redis...", redisClient=self.redisMessaging)
+        except Exception:
+            self.logTool.log(service='Database', level='error', message="Failed to read CSV file of IMEI TAC database", redisClient=self.redisMessaging)
+            return
+        try:
+            count = 0
+            tacList = {"tacList": []}
+            for line in csvfile:
+                line = line.replace('"', '')    #Strip excess inverted commas
+                line = line.replace("'", '')    #Strip excess inverted commas
+                line = line.rstrip()            #Strip newlines
+                result = line.split(',')
+                tacPrefix = result[0]
+                name = result[1].lstrip()
+                model = result[2].lstrip()
+
+                if count == 0:
+                    self.logTool.log(service='Database', level='info', message="Checking to see if entries are already present...", redisClient=self.redisMessaging)
+                    redis_imei_result = self.redisMessaging.getValue(key="tacDatabase")
+                    if redis_imei_result is not None:
+                        if len(redis_imei_result) > 0:
+                            self.logTool.log(service='Database', level='info', message="IMEI TAC Database already loaded into Redis - Skipping reading from file...", redisClient=self.redisMessaging)
+                            return
+                    self.logTool.log(service='Database', level='info', message="No data loaded into Redis, proceeding to load...", redisClient=self.redisMessaging)
+                tacList['tacList'].append({str(tacPrefix): {'name': name, 'model': model}})
+                count += 1
+            self.redisMessaging.setValue(key="tacDatabase", value=json.dumps(tacList))
+            self.tacData = tacList
+            self.logTool.log(service='Database', level='info', message="Loaded " + str(count) + " IMEI TAC entries into Redis", redisClient=self.redisMessaging)
+        except Exception as E:
+            self.logTool.log(service='Database', level='error', message="Failed to load IMEI Database into Redis due to error: " + str(E), redisClient=self.redisMessaging)
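For reference, here is a minimal, self-contained sketch (not part of PyHSS) of the structure that `load_IMEI_database_into_Redis()` builds and stores under the `tacDatabase` key. The three-column CSV layout (TAC prefix, vendor name, model) is inferred from the parsing logic above, and the sample rows are fabricated:

```python
# Illustrative only: mirrors the tacList structure built by
# load_IMEI_database_into_Redis(). Sample TACs are fabricated.
import json

sample_rows = [
    '"35876110","Apple","iPhone 12"',
    '"86723604","Samsung","Galaxy S21"',
]

tacList = {"tacList": []}
for line in sample_rows:
    line = line.replace('"', '').replace("'", '').rstrip()
    tacPrefix, name, model = (field.lstrip() for field in line.split(','))
    tacList['tacList'].append({str(tacPrefix): {'name': name, 'model': model}})

print(json.dumps(tacList))
# {"tacList": [{"35876110": {"name": "Apple", "model": "iPhone 12"}}, ...]}
```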
return + + def safe_rollback(self, session): + try: + if session.is_active: + session.rollback() + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Failed to rollback session, error: {E}", redisClient=self.redisMessaging) + + def safe_close(self, session): + try: + if session.is_active: + session.close() + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Failed to run safe_close on session, error: {E}", redisClient=self.redisMessaging) + + def sqlalchemy_type_to_json_schema_type(self, sqlalchemy_type): + """ + Map SQLAlchemy types to JSON Schema types. + """ + if isinstance(sqlalchemy_type, Integer): + return "integer" + elif isinstance(sqlalchemy_type, String): + return "string" + elif isinstance(sqlalchemy_type, Boolean): + return "boolean" + elif isinstance(sqlalchemy_type, DateTime): + return "string" + elif isinstance(sqlalchemy_type, Float): + return "number" + else: + return "string" # Default to string for unsupported types. + + def generate_json_schema(self, model_class, required=None): + properties = {} + required = required or [] + + for column in model_class.__table__.columns: + prop_type = self.sqlalchemy_type_to_json_schema_type(column.type) + prop_dict = { + "type": prop_type, + "description": column.doc + } + if prop_type == "string": + if hasattr(column.type, 'length'): + prop_dict["maxLength"] = column.type.length + if isinstance(column.type, DateTime): + prop_dict["format"] = "date-time" + if not column.nullable: + required.append(column.name) + properties[column.name] = prop_dict + + return {"type": "object", "title" : str(model_class.__name__), "properties": properties, "required": required} + + def update_old_record(self, session, operation_log): + oldest_log = session.query(OPERATION_LOG_BASE).order_by(OPERATION_LOG_BASE.timestamp.asc()).first() + if oldest_log is not None: + for attr in class_mapper(oldest_log.__class__).column_attrs: + if attr.key != 'id' and hasattr(operation_log, attr.key): + setattr(oldest_log, attr.key, getattr(operation_log, attr.key)) + oldest_log.timestamp = datetime.datetime.now(tz=timezone.utc) + session.flush() + else: + raise ValueError("Unable to find record to update") + + def log_change(self, session, item_id, operation, changes, table_name, operation_id, generated_id=None): + # We don't want to log rollback operations + if session.info.get("operation") == 'ROLLBACK': + return + max_records = 1000 + count = session.query(OPERATION_LOG_BASE).count() + + # Combine all changes into a single string with their types + changes_string = '\r\n\r\n'.join(f"{column_name}: [{type(old_value).__name__}] {old_value} ----> [{type(new_value).__name__}] {new_value}" for column_name, old_value, new_value in changes) + + change = OPERATION_LOG_BASE( + item_id=item_id or generated_id, + operation_id=operation_id, + operation=operation, + last_modified=datetime.datetime.now(tz=timezone.utc), + changes=changes_string, + table_name=table_name + ) + + if count >= max_records: + self.update_old_record(session, change) + else: + try: + session.add(change) + session.flush() + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to commit changelog, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + return operation_id + + + def log_changes_before_commit(self, session): + + operation_id = session.info.get("operation_id", None) or str(uuid.uuid4()) + if 
session.info.get("operation") == 'ROLLBACK': + return + + changelog_pending = any(isinstance(obj, OPERATION_LOG_BASE) for obj in session.new) + if changelog_pending: + return # Skip if there are pending OPERATION_LOG_BASE objects + + for state, operation in [ + (session.new, 'INSERT'), + (session.dirty, 'UPDATE'), + (session.deleted, 'DELETE') + ]: + for obj in state: + if isinstance(obj, OPERATION_LOG_BASE): + continue # Skip change log entries + + item_id = getattr(obj, list(obj.__table__.primary_key.columns.keys())[0]) + generated_id = None + + #Avoid logging rollback operations + if operation == 'ROLLBACK': + return + + # Flush the session to generate primary key for new objects + if operation == 'INSERT': + session.flush() + + if operation == 'UPDATE': + changes = [] + for attr in class_mapper(obj.__class__).column_attrs: + hist = get_history(obj, attr.key) + self.logTool.log(service='Database', level='debug', message=f"History {hist}", redisClient=self.redisMessaging) + if hist.has_changes() and hist.added and hist.deleted: + old_value, new_value = hist.deleted[0], hist.added[0] + self.logTool.log(service='Database', level='debug', message=f"Old Value {old_value}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message=f"New Value {new_value}", redisClient=self.redisMessaging) + changes.append((attr.key, old_value, new_value)) + continue + + if not changes: + continue + + operation_id = self.log_change(session, item_id, operation, changes, obj.__table__.name, operation_id) + + elif operation in ['INSERT', 'DELETE']: + changes = [] + for column in obj.__table__.columns: + column_name = column.name + value = getattr(obj, column_name) + if operation == 'INSERT': + old_value, new_value = None, value + if item_id is None: + generated_id = getattr(obj, list(obj.__table__.primary_key.columns.keys())[0]) + elif operation == 'DELETE': + old_value, new_value = value, None + changes.append((column_name, old_value, new_value)) + operation_id = self.log_change(session, item_id, operation, changes, obj.__table__.name, operation_id, generated_id) + + def get_class_by_tablename(self, base, tablename): + """ + Returns a class object based on the given tablename. + + :param base: Base class of SQLAlchemy models + :param tablename: Name of the table to retrieve the class for + :return: Class object or None if not found + """ + for mapper in base.registry.mappers: + cls = mapper.class_ + if hasattr(cls, '__tablename__') and cls.__tablename__ == tablename: + return cls + return None + + def str_to_type(self, type_str, value_str): + if type_str == 'int': + return int(value_str) + elif type_str == 'float': + return float(value_str) + elif type_str == 'str': + return value_str + elif type_str == 'bool': + return value_str == 'True' + elif type_str == 'NoneType': + return None + else: + raise ValueError(f'Cannot convert to type: {type_str}') + + + def rollback_last_change(self, existingSession=None): + if not existingSession: + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + session = Session() + else: + session = existingSession + + try: + # Get the most recent operation + last_operation = session.query(OPERATION_LOG_BASE).order_by(desc(OPERATION_LOG_BASE.timestamp)).first() + + if last_operation is None: + return "No operations to roll back." 
+
+            rollback_messages = []
+            operation_id = last_operation.operation_id
+
+            target_class = self.get_class_by_tablename(Base, last_operation.table_name)
+            if not target_class:
+                return f"Error: Could not find table {last_operation.table_name}"
+
+            primary_key_col = target_class.__mapper__.primary_key[0].key
+            filter_by_kwargs = {primary_key_col: last_operation.item_id}
+            target_item = session.query(target_class).filter_by(**filter_by_kwargs).one_or_none()
+
+            if last_operation.operation == 'UPDATE':
+                if not target_item:
+                    return f"Error: Could not find item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table"
+
+                # Split the changes string into separate changes
+                changes = last_operation.changes.split('\r\n\r\n')
+                for change in changes:
+                    column_name, old_new_values = change.split(": ", 1)
+                    old_value_str, new_value_str = old_new_values.split(" ----> ", 1)
+
+                    # Extract type and value (strip only the leading '[' so the value is not truncated)
+                    old_type_str, old_value_repr = old_value_str[1:].split("] ", 1)
+                    old_value = self.str_to_type(old_type_str, old_value_repr)
+
+                    # Revert the change
+                    setattr(target_item, column_name, old_value)
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Reverted changes"
+                )
+
+            elif last_operation.operation == 'INSERT':
+                if target_item:
+                    session.delete(target_item)
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Deleted item"
+                )
+
+            elif last_operation.operation == 'DELETE':
+                # Aggregate old values of all columns into a single dictionary
+                old_values_dict = {}
+                # Split the changes string into separate changes
+                changes = last_operation.changes.split('\r\n\r\n')
+                for change in changes:
+                    column_name, old_new_values = change.split(": ", 1)
+                    old_value_str, new_value_str = old_new_values.split(" ----> ", 1)
+
+                    # Extract type and value
+                    old_type_str, old_value_repr = old_value_str[1:].split("] ", 1)
+                    self.logTool.log(service='Database', level='debug', message=f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}", redisClient=self.redisMessaging)
+                    old_value = self.str_to_type(old_type_str, old_value_repr)
+
+                    old_values_dict[column_name] = old_value
+                    self.logTool.log(service='Database', level='debug', message="old_values_dict: " + str(old_values_dict), redisClient=self.redisMessaging)
+
+                if not target_item:
+                    try:
+                        # Create the target item using the aggregated old values
+                        target_item = target_class(**old_values_dict)
+                        session.add(target_item)
+                    except Exception as e:
+                        return f"Error: Failed to recreate item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table - {str(e)}"
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Re-inserted item"
+                )
+
+            else:
+                return f"Error: Unknown operation {last_operation.operation}"
+
+            rollback_messages.append(rollback_message)
+
+            try:
+                session.commit()
+                self.safe_close(session)
+            except Exception as E:
+                self.logTool.log(service='Database', level='error', message="rollback_last_change error: " + str(E), redisClient=self.redisMessaging)
+                self.safe_rollback(session)
+                self.safe_close(session)
+                raise ValueError(E)
+
+            return f"Rolled back operation with operation_id: {operation_id}\n" + "\n".join(rollback_messages)
+
+        except Exception as E:
+            self.logTool.log(service='Database', level='error', message="rollback_last_change error: " + str(E), redisClient=self.redisMessaging)
+            self.safe_rollback(session)
+            self.safe_close(session)
+            raise ValueError(E)
+
+    def rollback_change_by_operation_id(self, operation_id, existingSession=None):
+        if not existingSession:
+            Base.metadata.create_all(self.engine)
+            Session = sessionmaker(bind=self.engine)
+            session = Session()
+        else:
+            session = existingSession
+
+        try:
+            # Get the most recent operation with the given operation_id
+            last_operation = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id).order_by(desc(OPERATION_LOG_BASE.timestamp)).first()
+
+            if last_operation is None:
+                return "No operation to roll back."
+
+            rollback_messages = []
+
+            target_class = self.get_class_by_tablename(Base, last_operation.table_name)
+            if not target_class:
+                return f"Error: Could not find table {last_operation.table_name}"
+
+            primary_key_col = target_class.__mapper__.primary_key[0].key
+            filter_by_kwargs = {primary_key_col: last_operation.item_id}
+            target_item = session.query(target_class).filter_by(**filter_by_kwargs).one_or_none()
+
+            if last_operation.operation == 'UPDATE':
+                if not target_item:
+                    return f"Error: Could not find item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table"
+
+                # Split the changes string into separate changes
+                changes = last_operation.changes.split('\r\n\r\n')
+                for change in changes:
+                    column_name, old_new_values = change.split(": ", 1)
+                    old_value_str, new_value_str = old_new_values.split(" ----> ", 1)
+
+                    # Extract type and value (strip only the leading '[' so the value is not truncated)
+                    old_type_str, old_value_repr = old_value_str[1:].split("] ", 1)
+                    old_value = self.str_to_type(old_type_str, old_value_repr)
+
+                    # Revert the change
+                    setattr(target_item, column_name, old_value)
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Reverted changes"
+                )
+
+            elif last_operation.operation == 'INSERT':
+                if target_item:
+                    session.delete(target_item)
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Deleted item"
+                )
+
+            elif last_operation.operation == 'DELETE':
+                # Aggregate old values of all columns into a single dictionary
+                old_values_dict = {}
+                # Split the changes string into separate changes
+                changes = last_operation.changes.split('\r\n\r\n')
+                for change in changes:
+                    column_name, old_new_values = change.split(": ", 1)
+                    old_value_str, new_value_str = old_new_values.split(" ----> ", 1)
+
+                    # Extract type and value
+                    old_type_str, old_value_repr = old_value_str[1:].split("] ", 1)
+                    self.logTool.log(service='Database', level='debug', message=f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}", redisClient=self.redisMessaging)
+                    old_value = self.str_to_type(old_type_str, old_value_repr)
+
+                    old_values_dict[column_name] = old_value
+                    self.logTool.log(service='Database', level='debug', message="old_values_dict: " + str(old_values_dict), redisClient=self.redisMessaging)
+
+                if not target_item:
+                    try:
+                        # Create the target item using the aggregated old values
+                        target_item = target_class(**old_values_dict)
+                        session.add(target_item)
+                    except Exception as e:
+                        return f"Error: Failed to recreate item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table - {str(e)}"
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Re-inserted item"
+                )
+
+            else:
+                return f"Error: Unknown operation {last_operation.operation}"
+
+            rollback_messages.append(rollback_message)
+
+            try:
+                session.commit()
+                self.safe_close(session)
+            except Exception as E:
+                self.logTool.log(service='Database', level='error', message="rollback_change_by_operation_id error: " + str(E), redisClient=self.redisMessaging)
+                self.safe_rollback(session)
+                self.safe_close(session)
+                raise ValueError(E)
+
+            return f"Rolled back operation with operation_id: {operation_id}\n" + "\n".join(rollback_messages)
+
+        except Exception as E:
+            self.logTool.log(service='Database', level='error', message="rollback_change_by_operation_id error: " + str(E), redisClient=self.redisMessaging)
+            self.safe_rollback(session)
+            self.safe_close(session)
+            raise ValueError(E)
+
+    def get_all_operation_logs(self, page=0, page_size=100, existingSession=None):
+        if not existingSession:
+            Base.metadata.create_all(self.engine)
+            Session = sessionmaker(bind=self.engine)
+            session = Session()
+        else:
+            session = existingSession
+
+        try:
+            # Get all distinct operation_ids ordered by max timestamp (descending order)
+            operation_ids = session.query(OPERATION_LOG_BASE.operation_id).group_by(OPERATION_LOG_BASE.operation_id).order_by(desc(func.max(OPERATION_LOG_BASE.timestamp)))
+
+            operation_ids = operation_ids.limit(page_size).offset(page * page_size)
+
+            operation_ids = operation_ids.all()
+
+            all_operations = []
+
+            for operation_id in operation_ids:
+                operation_log = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id[0]).order_by(OPERATION_LOG_BASE.id.asc()).first()
+
+                if operation_log is not None:
+                    # Convert the object to dictionary
+                    obj_dict = operation_log.__dict__
+                    obj_dict.pop('_sa_instance_state')
+                    sanitized_obj_dict = self.Sanitize_Datetime(obj_dict)
+                    all_operations.append(sanitized_obj_dict)
+
+            self.safe_close(session)
+            return all_operations
+        except Exception as E:
+            self.logTool.log(service='Database', level='error', message=f"get_all_operation_logs error: {E}", redisClient=self.redisMessaging)
+            self.safe_rollback(session)
+            self.safe_close(session)
+            raise ValueError(E)
+
+    def get_all_operation_logs_by_table(self, table_name, page=0, page_size=100, existingSession=None):
+        if not existingSession:
+            Base.metadata.create_all(self.engine)
+            Session = sessionmaker(bind=self.engine)
+            session = Session()
+        else:
+            session = existingSession
+
+        try:
+            # Get all distinct operation_ids for the table, ordered by max timestamp (descending order)
+            operation_ids = session.query(OPERATION_LOG_BASE.operation_id).filter(OPERATION_LOG_BASE.table_name == table_name).group_by(OPERATION_LOG_BASE.operation_id).order_by(desc(func.max(OPERATION_LOG_BASE.timestamp)))
+
+            operation_ids = operation_ids.limit(page_size).offset(page * page_size)
+
+            operation_ids = operation_ids.all()
+
+            all_operations = []
+
+            for operation_id in operation_ids:
+                operation_log = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id[0]).order_by(OPERATION_LOG_BASE.id.asc()).first()
+
+                if operation_log is not None:
+                    # Convert the object to dictionary
+                    obj_dict = operation_log.__dict__
+                    obj_dict.pop('_sa_instance_state')
+                    sanitized_obj_dict = self.Sanitize_Datetime(obj_dict)
+                    all_operations.append(sanitized_obj_dict)
+
+            self.safe_close(session)
+            return all_operations
+        except Exception as E:
+            self.logTool.log(service='Database', level='error', message=f"get_all_operation_logs_by_table error: {E}", redisClient=self.redisMessaging)
message=f"get_all_operation_logs_by_table error: {E}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='error', message=E, redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + + def get_last_operation_log(self, existingSession=None): + if not existingSession: + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + session = Session() + else: + session = existingSession + + try: + # Get the top 100 records ordered by timestamp (descending order) + top_100_records = session.query(OPERATION_LOG_BASE).order_by(desc(OPERATION_LOG_BASE.timestamp)).limit(100) + + # Get the most recent operation_id + most_recent_operation_log = top_100_records.first() + + # Convert the object to dictionary + if most_recent_operation_log is not None: + obj_dict = most_recent_operation_log.__dict__ + obj_dict.pop('_sa_instance_state') + sanitized_obj_dict = self.Sanitize_Datetime(obj_dict) + return sanitized_obj_dict + + self.safe_close(session) + return None + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"get_last_operation_log error: {E}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='error', message=E, redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + + def handleGeored(self, jsonData, operation: str="PATCH", asymmetric: bool=False, asymmetricUrls: list=[]) -> bool: + """ + Validate the request, check configuration and queue the geored message. + Asymmetric geored is supported (where one or more specific or foreign geored endpoints are specified). + """ + try: + operation = operation.upper() + if operation not in ['PUT', 'PATCH', 'DELETE']: + self.logTool.log(service='Database', level='warning', message="Failed to send Geored message invalid operation type, received: " + str(operation), redisClient=self.redisMessaging) + return + georedDict = {} + if self.config.get('geored', {}).get('enabled', False): + if self.config.get('geored', {}).get('endpoints', []) is not None: + if len(self.config.get('geored', {}).get('endpoints', [])) > 0: + georedDict['body'] = jsonData + georedDict['operation'] = operation + georedDict['timestamp'] = time.time_ns() + self.redisMessaging.sendMessage(queue=f'geored', message=json.dumps(georedDict), queueExpiry=120) + if asymmetric: + if len(asymmetricUrls) > 0: + georedDict['body'] = jsonData + georedDict['operation'] = operation + georedDict['timestamp'] = time.time_ns() + georedDict['urls'] = asymmetricUrls + self.redisMessaging.sendMessage(queue=f'asymmetric-geored', message=json.dumps(georedDict), queueExpiry=120) + return True + + except Exception as E: + self.logTool.log(service='Database', level='warning', message="Failed to send Geored message due to error: " + str(E), redisClient=self.redisMessaging) + return False + + def handleWebhook(self, objectData, operation: str="PATCH"): + webhooksEnabled = self.config.get('webhooks', {}).get('enabled', False) + endpointList = self.config.get('webhooks', {}).get('endpoints', []) + webhook = {} + + if not webhooksEnabled: + return False + + if endpointList is None: + return False + + if not len (endpointList) > 0: + self.logTool.log(service='Database', level='error', message="Webhooks enabled, but endpoints are missing.", redisClient=self.redisMessaging) + return False + + webhookHeaders = {'Content-Type': 'application/json', 'Referer': socket.gethostname()} + + webhook['body'] = 
self.Sanitize_Datetime(objectData) + webhook['headers'] = webhookHeaders + webhook['operation'] = operation + webhook['timestamp'] = time.time_ns() + self.redisMessaging.sendMessage(queue=f'webhook', message=json.dumps(webhook), queueExpiry=120) + return True + + def Sanitize_Datetime(self, result): + for keys in result: + if "timestamp" in keys: + if result[keys] == None: + continue + else: + self.logTool.log(service='Database', level='debug', message="Key " + str(keys) + " is type DateTime with value: " + str(result[keys]) + " - Formatting to String", redisClient=self.redisMessaging) + try: + result[keys] = result[keys].strftime('%Y-%m-%dT%H:%M:%SZ') + except Exception as e: + result[keys] = str(result[keys]) + return result + + def Sanitize_Keys(self, result): + names_to_strip = ['opc', 'ki', 'des', 'kid', 'psk', 'adm1'] + for name_to_strip in names_to_strip: + try: + result.pop(name_to_strip) + except: + pass + return result + + def GetObj(self, obj_type, obj_id=None, page=None, page_size=None): + self.logTool.log(service='Database', level='debug', message="Called GetObj for type " + str(obj_type), redisClient=self.redisMessaging) + + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + session = Session() + + try: + if obj_id is not None: + result = session.query(obj_type).get(obj_id) + if result is None: + raise ValueError(f"No {obj_type} found with id {obj_id}") + + result = result.__dict__ + result.pop('_sa_instance_state') + result = self.Sanitize_Datetime(result) + elif page is not None and page_size is not None: + if page < 1 or page_size < 1: + raise ValueError("page and page_size should be positive integers") + + offset = (page - 1) * page_size + results = ( + session.query(obj_type) + .order_by(obj_type.id) # Assuming obj_type has an attribute 'id' + .offset(offset) + .limit(page_size) + .all() + ) + + result = [] + for item in results: + item_dict = item.__dict__ + item_dict.pop('_sa_instance_state') + result.append(self.Sanitize_Datetime(item_dict)) + else: + raise ValueError("Provide either obj_id or both page and page_size") + + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to query, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + + self.safe_close(session) + return result + + def GetAll(self, obj_type): + self.logTool.log(service='Database', level='debug', message="Called GetAll for type " + str(obj_type), redisClient=self.redisMessaging) + + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind = self.engine) + session = Session() + final_result_list = [] + + try: + result = session.query(obj_type) + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to query, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + + for record in result: + record = record.__dict__ + record.pop('_sa_instance_state') + record = self.Sanitize_Datetime(record) + record = self.Sanitize_Keys(record) + final_result_list.append(record) + + self.safe_close(session) + return final_result_list + + def getAllPaginated(self, obj_type, page=0, page_size=0, existingSession=None): + self.logTool.log(service='Database', level='debug', message="Called getAllPaginated for type " + str(obj_type), redisClient=self.redisMessaging) + + if not existingSession: + Base.metadata.create_all(self.engine) + Session = 
sessionmaker(bind=self.engine) + session = Session() + else: + session = existingSession + + final_result_list = [] + + try: + # Query object type + result = session.query(obj_type) + + # Apply pagination + if page_size != 0: + result = result.limit(page_size).offset(page * page_size) + + result = result.all() + + for record in result: + record = record.__dict__ + record.pop('_sa_instance_state') + record = self.Sanitize_Datetime(record) + record = self.Sanitize_Keys(record) + final_result_list.append(record) + + self.safe_close(session) + return final_result_list + + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to query, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + + + def GetAllByTable(self, obj_type, table): + self.logTool.log(service='Database', level='debug', message=f"Called GetAll for type {str(obj_type)} and table {table}", redisClient=self.redisMessaging) + + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind = self.engine) + session = Session() + final_result_list = [] + + try: + result = session.query(obj_type).filter_by(table_name=str(table)) + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to query, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + + for record in result: + record = record.__dict__ + record.pop('_sa_instance_state') + record = self.Sanitize_Datetime(record) + record = self.Sanitize_Keys(record) + final_result_list.append(record) + + self.safe_close(session) + return final_result_list + + def UpdateObj(self, obj_type, json_data, obj_id, disable_logging=False, operation_id=None): + self.logTool.log(service='Database', level='debug', message=f"Called UpdateObj() for type {obj_type} id {obj_id} with JSON data: {json_data} and operation_id: {operation_id}", redisClient=self.redisMessaging) + Session = sessionmaker(bind=self.engine) + session = Session() + obj_type_str = str(obj_type.__table__.name).upper() + self.logTool.log(service='Database', level='debug', message=f"obj_type_str is {obj_type_str}", redisClient=self.redisMessaging) + filter_input = eval(obj_type_str + "." 
+ obj_type_str.lower() + "_id==obj_id") + try: + obj = session.query(obj_type).filter(filter_input).one() + for key, value in json_data.items(): + if hasattr(obj, key): + setattr(obj, key, value) + setattr(obj, "last_modified", datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') + 'Z') + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Failed to query or update object, error: {E}", redisClient=self.redisMessaging) + raise ValueError(E) + try: + session.info["operation_id"] = operation_id # Pass the operation id + try: + if not disable_logging: + self.log_changes_before_commit(session) + objectData = self.GetObj(obj_type, obj_id) + session.commit() + self.handleWebhook(objectData, 'PATCH') + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Failed to commit session, error: {E}", redisClient=self.redisMessaging) + self.safe_rollback(session) + raise ValueError(E) + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Exception in UpdateObj, error: {E}", redisClient=self.redisMessaging) + raise ValueError(E) + finally: + self.safe_close(session) + + return self.GetObj(obj_type, obj_id) + + def DeleteObj(self, obj_type, obj_id, disable_logging=False, operation_id=None): + self.logTool.log(service='Database', level='debug', message=f"Called DeleteObj for type {obj_type} with id {obj_id}", redisClient=self.redisMessaging) + + Session = sessionmaker(bind=self.engine) + session = Session() + + try: + res = session.query(obj_type).get(obj_id) + if res is None: + raise ValueError("The specified row does not exist") + objectData = self.GetObj(obj_type, obj_id) + session.delete(res) + session.info["operation_id"] = operation_id # Pass the operation id + try: + if not disable_logging: + self.log_changes_before_commit(session) + session.commit() + self.handleWebhook(objectData, 'DELETE') + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Failed to commit session, error: {E}", redisClient=self.redisMessaging) + self.safe_rollback(session) + raise ValueError(E) + + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Exception in DeleteObj, error: {E}", redisClient=self.redisMessaging) + raise ValueError(E) + finally: + self.safe_close(session) + + return {'Result': 'OK'} + + + def CreateObj(self, obj_type, json_data, disable_logging=False, operation_id=None): + self.logTool.log(service='Database', level='debug', message="Called CreateObj to create " + str(obj_type) + " with value: " + str(json_data), redisClient=self.redisMessaging) + last_modified_value = datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') + 'Z' + json_data["last_modified"] = last_modified_value # set last_modified value in json_data + newObj = obj_type(**json_data) + Session = sessionmaker(bind=self.engine) + session = Session() + + session.add(newObj) + try: + session.info["operation_id"] = operation_id # Pass the operation id + try: + if not disable_logging: + self.log_changes_before_commit(session) + session.commit() + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Failed to commit session, error: {E}", redisClient=self.redisMessaging) + self.safe_rollback(session) + raise ValueError(E) + session.refresh(newObj) + result = newObj.__dict__ + result.pop('_sa_instance_state') + self.handleWebhook(result, 'PUT') + return result + except Exception as E: + 
self.logTool.log(service='Database', level='error', message=f"Exception in CreateObj, error: {E}", redisClient=self.redisMessaging) + raise ValueError(E) + finally: + self.safe_close(session) + + def Generate_JSON_Model_for_Flask(self, obj_type): + self.logTool.log(service='Database', level='debug', message="Generating JSON model for Flask for object type: " + str(obj_type), redisClient=self.redisMessaging) + + dictty = dict(self.generate_json_schema(obj_type)) + # pprint.pprint(dictty) + + + #dictty['properties'] = dict(dictty['properties']) + + # Exclude 'table_name' column from the properties + if 'properties' in dictty: + dictty['properties'].pop('discriminator', None) + dictty['properties'].pop('last_modified', None) + + + # Set the ID Object to not required + obj_type_str = str(dictty['title']).lower() + dictty['required'].remove(obj_type_str + '_id') + + return dictty + + def Get_AuC(self, **kwargs): + #Get AuC data by IMSI or ICCID + + Session = sessionmaker(bind = self.engine) + session = Session() + + if 'iccid' in kwargs: + self.logTool.log(service='Database', level='debug', message="Get_AuC for iccid " + str(kwargs['iccid']), redisClient=self.redisMessaging) + try: + result = session.query(AUC).filter_by(iccid=str(kwargs['iccid'])).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + elif 'imsi' in kwargs: + self.logTool.log(service='Database', level='debug', message="Get_AuC for imsi " + str(kwargs['imsi']), redisClient=self.redisMessaging) + try: + result = session.query(AUC).filter_by(imsi=str(kwargs['imsi'])).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + + result = result.__dict__ + result = self.Sanitize_Datetime(result) + result.pop('_sa_instance_state') + + self.logTool.log(service='Database', level='debug', message="Got back result: " + str(result), redisClient=self.redisMessaging) + self.safe_close(session) + return result + + def Get_IMS_Subscriber(self, **kwargs): + #Get subscriber by IMSI or MSISDN + Session = sessionmaker(bind = self.engine) + session = Session() + if 'msisdn' in kwargs: + self.logTool.log(service='Database', level='debug', message="Get_IMS_Subscriber for msisdn " + str(kwargs['msisdn']), redisClient=self.redisMessaging) + try: + result = session.query(IMS_SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + elif 'imsi' in kwargs: + self.logTool.log(service='Database', level='debug', message="Get_IMS_Subscriber for imsi " + str(kwargs['imsi']), redisClient=self.redisMessaging) + try: + result = session.query(IMS_SUBSCRIBER).filter_by(imsi=str(kwargs['imsi'])).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + self.logTool.log(service='Database', level='debug', message="Converting result to dict", redisClient=self.redisMessaging) + result = result.__dict__ + try: + result.pop('_sa_instance_state') + except: + pass + result = self.Sanitize_Datetime(result) + self.logTool.log(service='Database', level='debug', message="Returning IMS Subscriber Data: " + str(result), redisClient=self.redisMessaging) + self.safe_close(session) + return result + + def Get_Subscriber(self, **kwargs): + #Get subscriber by IMSI or MSISDN + + Session = sessionmaker(bind = self.engine) + session = Session() + + if 'subscriber_id' in kwargs: + self.logTool.log(service='Database', level='debug', message="Get_Subscriber for id " + str(kwargs['subscriber_id']), redisClient=self.redisMessaging) + try: + 
result = session.query(SUBSCRIBER).filter_by(subscriber_id=int(kwargs['subscriber_id'])).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + elif 'msisdn' in kwargs: + self.logTool.log(service='Database', level='debug', message="Get_Subscriber for msisdn " + str(kwargs['msisdn']), redisClient=self.redisMessaging) + try: + result = session.query(SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + elif 'imsi' in kwargs: + self.logTool.log(service='Database', level='debug', message="Get_Subscriber for imsi " + str(kwargs['imsi']), redisClient=self.redisMessaging) + try: + result = session.query(SUBSCRIBER).filter_by(imsi=str(kwargs['imsi'])).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + + result = result.__dict__ + result = self.Sanitize_Datetime(result) + result.pop('_sa_instance_state') + + if 'get_attributes' in kwargs: + if kwargs['get_attributes'] == True: + attributes = self.Get_Subscriber_Attributes(result['subscriber_id']) + result['attributes'] = attributes + + self.logTool.log(service='Database', level='debug', message="Got back result: " + str(result), redisClient=self.redisMessaging) + self.safe_close(session) + return result + + def Get_SUBSCRIBER_ROUTING(self, subscriber_id, apn_id): + Session = sessionmaker(bind = self.engine) + session = Session() + + self.logTool.log(service='Database', level='debug', message="Get_SUBSCRIBER_ROUTING for subscriber_id " + str(subscriber_id) + " and apn_id " + str(apn_id), redisClient=self.redisMessaging) + try: + result = session.query(SUBSCRIBER_ROUTING).filter_by(subscriber_id=subscriber_id, apn_id=apn_id).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + + result = result.__dict__ + result = self.Sanitize_Datetime(result) + result.pop('_sa_instance_state') + + self.logTool.log(service='Database', level='debug', message="Got back result: " + str(result), redisClient=self.redisMessaging) + self.safe_close(session) + return result + + def Get_Subscriber_Attributes(self, subscriber_id): + #Get subscriber attributes + + Session = sessionmaker(bind = self.engine) + session = Session() + + self.logTool.log(service='Database', level='debug', message="Get_Subscriber_Attributes for subscriber_id " + str(subscriber_id), redisClient=self.redisMessaging) + try: + result = session.query(SUBSCRIBER_ATTRIBUTES).filter_by(subscriber_id=subscriber_id) + except Exception as E: + self.safe_close(session) + raise ValueError(E) + final_res = [] + for record in result: + result = record.__dict__ + result = self.Sanitize_Datetime(result) + result.pop('_sa_instance_state') + final_res.append(result) + self.logTool.log(service='Database', level='debug', message="Got back result: " + str(final_res), redisClient=self.redisMessaging) + self.safe_close(session) + return final_res + + + def Get_Served_Subscribers(self, get_local_users_only=False): + self.logTool.log(service='Database', level='debug', message="Getting all subscribers served by this HSS", redisClient=self.redisMessaging) + + Session = sessionmaker(bind = self.engine) + session = Session() + + Served_Subs = {} + try: + results = session.query(SUBSCRIBER).filter(SUBSCRIBER.serving_mme.isnot(None)) + for result in results: + result = result.__dict__ + self.logTool.log(service='Database', level='debug', message="Result: " + str(result) + " type: " + str(type(result)), redisClient=self.redisMessaging) + result = 
self.Sanitize_Datetime(result)
+                result.pop('_sa_instance_state')
+
+                if get_local_users_only == True:
+                    self.logTool.log(service='Database', level='debug', message="Filtering to locally served IMS Subs only", redisClient=self.redisMessaging)
+                    try:
+                        serving_hss = result['serving_mme_peer'].split(';')[1]
+                        self.logTool.log(service='Database', level='debug', message="Serving HSS: " + str(serving_hss) + " and this is: " + str(self.config['hss']['OriginHost']), redisClient=self.redisMessaging)
+                        if serving_hss == self.config['hss']['OriginHost']:
+                            self.logTool.log(service='Database', level='debug', message="Serving HSS matches local HSS", redisClient=self.redisMessaging)
+                            Served_Subs[result['imsi']] = {}
+                            Served_Subs[result['imsi']] = result
+                            #self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging)
+                            continue
+                        else:
+                            self.logTool.log(service='Database', level='debug', message="Sub is served by remote HSS: " + str(serving_hss), redisClient=self.redisMessaging)
+                    except Exception as E:
+                        self.logTool.log(service='Database', level='debug', message="Error in filtering Get_Served_Subscribers to local peer only: " + str(E), redisClient=self.redisMessaging)
+                        continue
+                else:
+                    Served_Subs[result['imsi']] = result
+                    self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging)
+
+
+        except Exception as E:
+            self.safe_close(session)
+            raise ValueError(E)
+        self.logTool.log(service='Database', level='debug', message="Final Served_Subs: " + str(Served_Subs), redisClient=self.redisMessaging)
+        self.safe_close(session)
+        return Served_Subs
+
+
+    def Get_Served_IMS_Subscribers(self, get_local_users_only=False):
+        self.logTool.log(service='Database', level='debug', message="Getting all subscribers served by this IMS-HSS", redisClient=self.redisMessaging)
+        Session = sessionmaker(bind=self.engine)
+        session = Session()
+
+        Served_Subs = {}
+        try:
+
+            results = session.query(IMS_SUBSCRIBER).filter(
+                IMS_SUBSCRIBER.scscf.isnot(None))
+            for result in results:
+                result = result.__dict__
+                self.logTool.log(service='Database', level='debug', message="Result: " + str(result) + " type: " + str(type(result)), redisClient=self.redisMessaging)
+                result = self.Sanitize_Datetime(result)
+                result.pop('_sa_instance_state')
+                if get_local_users_only == True:
+                    self.logTool.log(service='Database', level='debug', message="Filtering Get_Served_IMS_Subscribers to locally served IMS Subs only", redisClient=self.redisMessaging)
+                    try:
+                        serving_ims_hss = result['scscf_peer'].split(';')[1]
+                        self.logTool.log(service='Database', level='debug', message="Serving IMS-HSS: " + str(serving_ims_hss) + " and this is: " + str(self.config['hss']['OriginHost']), redisClient=self.redisMessaging)
+                        if serving_ims_hss == self.config['hss']['OriginHost']:
+                            self.logTool.log(service='Database', level='debug', message="Serving IMS-HSS matches local HSS for " + str(result['imsi']), redisClient=self.redisMessaging)
+                            Served_Subs[result['imsi']] = {}
+                            Served_Subs[result['imsi']] = result
+                            self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging)
+                            continue
+                        else:
+                            self.logTool.log(service='Database', level='debug', message="Sub is served by remote IMS-HSS: " + str(serving_ims_hss), redisClient=self.redisMessaging)
+                    except Exception as E:
+                        self.logTool.log(service='Database', level='debug', message="Error in filtering to local peer only: " + str(E), 
redisClient=self.redisMessaging) + continue + else: + Served_Subs[result['imsi']] = result + self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging) + + except Exception as E: + self.safe_close(session) + raise ValueError(E) + self.logTool.log(service='Database', level='debug', message="Final Served_Subs: " + str(Served_Subs), redisClient=self.redisMessaging) + self.safe_close(session) + return Served_Subs + + + def Get_Served_PCRF_Subscribers(self, get_local_users_only=False): + self.logTool.log(service='Database', level='debug', message="Getting all subscribers served by this PCRF", redisClient=self.redisMessaging) + Session = sessionmaker(bind=self.engine) + session = Session() + Served_Subs = {} + try: + results = session.query(SERVING_APN).all() + for result in results: + result = result.__dict__ + self.logTool.log(service='Database', level='debug', message="Result: " + str(result) + " type: " + str(type(result)), redisClient=self.redisMessaging) + result = self.Sanitize_Datetime(result) + result.pop('_sa_instance_state') + + if get_local_users_only == True: + self.logTool.log(service='Database', level='debug', message="Filtering to locally served IMS Subs only", redisClient=self.redisMessaging) + try: + serving_pcrf = result['serving_pgw_peer'].split(';')[1] + self.logTool.log(service='Database', level='debug', message="Serving PCRF: " + str(serving_pcrf) + " and this is: " + str(self.config['hss']['OriginHost']), redisClient=self.redisMessaging) + if serving_pcrf == self.config['hss']['OriginHost']: + self.logTool.log(service='Database', level='debug', message="Serving PCRF matches local PCRF", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging) + + else: + self.logTool.log(service='Database', level='debug', message="Sub is served by remote PCRF: " + str(serving_pcrf), redisClient=self.redisMessaging) + continue + except Exception as E: + self.logTool.log(service='Database', level='debug', message="Error in filtering Get_Served_PCRF_Subscribers to local peer only: " + str(E), redisClient=self.redisMessaging) + continue + + # Get APN Info + apn_info = self.GetObj(APN, result['apn']) + #self.logTool.log(service='Database', level='debug', message="Got APN Info: " + str(apn_info), redisClient=self.redisMessaging) + result['apn_info'] = apn_info + + # Get Subscriber Info + subscriber_info = self.GetObj(SUBSCRIBER, result['subscriber_id']) + result['subscriber_info'] = subscriber_info + + #self.logTool.log(service='Database', level='debug', message="Got Subscriber Info: " + str(subscriber_info), redisClient=self.redisMessaging) + + Served_Subs[subscriber_info['imsi']] = result + self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging) + except Exception as E: + raise ValueError(E) + #self.logTool.log(service='Database', level='debug', message="Final SERVING_APN: " + str(Served_Subs), redisClient=self.redisMessaging) + self.safe_close(session) + return Served_Subs + + def Get_Vectors_AuC(self, auc_id, action, **kwargs): + self.logTool.log(service='Database', level='debug', message="Getting Vectors for auc_id " + str(auc_id) + " with action " + str(action), redisClient=self.redisMessaging) + key_data = self.GetObj(AUC, auc_id) + vector_dict = {} + + if action == "air": + rand, xres, autn, kasme = S6a_crypt.generate_eutran_vector(key_data['ki'], key_data['opc'], key_data['amf'], 
key_data['sqn'], kwargs['plmn'])
+            vector_dict['rand'] = rand
+            vector_dict['xres'] = xres
+            vector_dict['autn'] = autn
+            vector_dict['kasme'] = kasme
+
+            #Increment SQN
+            self.Update_AuC(auc_id, sqn=key_data['sqn']+100)
+
+            return vector_dict
+
+        elif action == "sqn_resync":
+            self.logTool.log(service='Database', level='debug', message="Resync SQN", redisClient=self.redisMessaging)
+            rand = kwargs['rand']
+            sqn, mac_s = S6a_crypt.generate_resync_s6a(key_data['ki'], key_data['opc'], key_data['amf'], kwargs['auts'], rand)
+            self.logTool.log(service='Database', level='debug', message="SQN from resync: " + str(sqn) + " SQN in DB is " + str(key_data['sqn']) + " (difference of " + str(int(sqn) - int(key_data['sqn'])) + ")", redisClient=self.redisMessaging)
+            self.Update_AuC(auc_id, sqn=sqn+100)
+            return
+
+        elif action == "sip_auth":
+            rand, autn, xres, ck, ik = S6a_crypt.generate_maa_vector(key_data['ki'], key_data['opc'], key_data['amf'], key_data['sqn'], kwargs['plmn'])
+            self.logTool.log(service='Database', level='debug', message="RAND is: " + str(rand), redisClient=self.redisMessaging)
+            self.logTool.log(service='Database', level='debug', message="AUTN is: " + str(autn), redisClient=self.redisMessaging)
+            vector_dict['SIP_Authenticate'] = rand + autn
+            vector_dict['xres'] = xres
+            vector_dict['ck'] = ck
+            vector_dict['ik'] = ik
+            self.Update_AuC(auc_id, sqn=key_data['sqn']+100)
+            return vector_dict
+
+        elif action == "Digest-MD5":
+            self.logTool.log(service='Database', level='debug', message="Generating Digest-MD5 Auth vectors", redisClient=self.redisMessaging)
+            self.logTool.log(service='Database', level='debug', message="key_data: " + str(key_data), redisClient=self.redisMessaging)
+            nonce = uuid.uuid4().hex
+            #nonce = "beef4d878f2642ed98afe491b943ca60"
+            vector_dict['nonce'] = nonce
+            vector_dict['SIP_Authenticate'] = key_data['ki']
+            return vector_dict
+
+    def Get_APN(self, apn_id):
+        self.logTool.log(service='Database', level='debug', message="Getting APN " + str(apn_id), redisClient=self.redisMessaging)
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+
+        try:
+            result = session.query(APN).filter_by(apn_id=apn_id).one()
+        except Exception as E:
+            self.safe_close(session)
+            raise ValueError(E)
+        result = result.__dict__
+        result.pop('_sa_instance_state')
+        self.safe_close(session)
+        return result
+
+    def Get_APN_by_Name(self, apn):
+        self.logTool.log(service='Database', level='debug', message="Getting APN named " + str(apn), redisClient=self.redisMessaging)
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+        try:
+            result = session.query(APN).filter_by(apn=str(apn)).one()
+        except Exception as E:
+            self.safe_close(session)
+            raise ValueError(E)
+        result = result.__dict__
+        result.pop('_sa_instance_state')
+        self.safe_close(session)
+        return result
+
+    def Update_AuC(self, auc_id, sqn=1):
+        self.logTool.log(service='Database', level='debug', message="Updating AuC record for sub " + str(auc_id), redisClient=self.redisMessaging)
+        self.logTool.log(service='Database', level='debug', message=self.UpdateObj(AUC, {'sqn': sqn}, auc_id, True), redisClient=self.redisMessaging)
+        return
+
+    def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_mme_peer=None, serving_mme_timestamp=None, propagate=True):
+        self.logTool.log(service='Database', level='debug', message="Updating Serving MME for sub " + str(imsi) + " to MME " + str(serving_mme), redisClient=self.redisMessaging)
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
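As a quick illustration of the SQN handling in `Get_Vectors_AuC()` above: every vector generation advances the stored SQN by a fixed step of 100, and an AUTS-driven resync replaces it with the UE-reported SQN plus the same step. A minimal sketch, with all names local to the example:

```python
from typing import Optional

def next_sqn(stored_sqn: int, resync_sqn: Optional[int] = None, step: int = 100) -> int:
    """Return the SQN to persist after generating a vector or resyncing."""
    if resync_sqn is not None:
        return resync_sqn + step  # after resync, jump ahead of the UE's SQN
    return stored_sqn + step      # normal case: advance by a fixed step

assert next_sqn(4200) == 4300
assert next_sqn(4200, resync_sqn=9000) == 9100
```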
+ try: + result = session.query(SUBSCRIBER).filter_by(imsi=imsi).one() + if self.config['hss']['CancelLocationRequest_Enabled'] == True: + self.logTool.log(service='Database', level='debug', message="Evaluating if we should trigger sending a CLR.", redisClient=self.redisMessaging) + serving_hss = str(result.serving_mme_peer).split(';',1)[1] + serving_mme_peer = str(result.serving_mme_peer).split(';',1)[0] + self.logTool.log(service='Database', level='debug', message="Subscriber is currently served by serving_mme: " + str(result.serving_mme) + " at realm " + str(result.serving_mme_realm) + " through Diameter peer " + str(result.serving_mme_peer), redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message="Subscriber is now served by serving_mme: " + str(serving_mme) + " at realm " + str(serving_mme_realm) + " through Diameter peer " + str(serving_mme_peer), redisClient=self.redisMessaging) + #Evaluate if we need to send a CLR to the old MME + if result.serving_mme != None: + if str(result.serving_mme) == str(serving_mme): + self.logTool.log(service='Database', level='debug', message="This MME is unchanged (" + str(serving_mme) + ") - so no need to send a CLR", redisClient=self.redisMessaging) + elif (str(result.serving_mme) != str(serving_mme)): + self.logTool.log(service='Database', level='debug', message="There is a difference in serving MME, old MME is '" + str(result.serving_mme) + "' new MME is '" + str(serving_mme) + "' - We need to trigger sending a CLR", redisClient=self.redisMessaging) + if serving_hss != self.config['hss']['OriginHost']: + self.logTool.log(service='Database', level='debug', message="This subscriber is not served by this HSS it is served by HSS at " + serving_hss + " - We need to trigger sending a CLR on " + str(serving_hss), redisClient=self.redisMessaging) + URL = 'http://' + serving_hss + '.' + self.config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) + else: + self.logTool.log(service='Database', level='debug', message="This subscriber is served by this HSS we need to send a CLR to old MME from this HSS", redisClient=self.redisMessaging) + + URL = 'http://' + serving_hss + '.' 
+ self.config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) + self.logTool.log(service='Database', level='debug', message="Sending CLR to API at " + str(URL), redisClient=self.redisMessaging) + + clrBody = { + "imsi": str(imsi), + "DestinationRealm": result.serving_mme_realm, + "DestinationHost": result.serving_mme, + "cancellationType": 2, + "diameterPeer": serving_mme_peer, + } + + self.logTool.log(service='Database', level='debug', message="Pushing CLR to API on " + str(URL) + " with JSON body: " + str(clrBody), redisClient=self.redisMessaging) + transaction_id = str(uuid.uuid4()) + self.handleGeored(clrBody, asymmetric=True, asymmetricUrls=[URL]) + else: + #No currently serving MME - No action to take + self.logTool.log(service='Database', level='debug', message="No currently serving MME - No need to send CLR", redisClient=self.redisMessaging) + + if type(serving_mme) == str: + self.logTool.log(service='Database', level='debug', message="Updating serving MME & Timestamp", redisClient=self.redisMessaging) + result.serving_mme = serving_mme + try: + if serving_mme_timestamp != None and serving_mme_timestamp != 'None': + result.serving_mme_timestamp = datetime.strptime(serving_mme_timestamp, '%Y-%m-%dT%H:%M:%SZ') + result.serving_mme_timestamp = result.serving_mme_timestamp.replace(tzinfo=timezone.utc) + serving_mme_timestamp_string = result.serving_mme_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + else: + result.serving_mme_timestamp = datetime.datetime.now(tz=timezone.utc) + serving_mme_timestamp_string = result.serving_mme_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + except Exception as e: + result.serving_mme_timestamp = datetime.datetime.now(tz=timezone.utc) + serving_mme_timestamp_string = result.serving_mme_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + result.serving_mme_realm = serving_mme_realm + result.serving_mme_peer = serving_mme_peer + else: + #Clear values + self.logTool.log(service='Database', level='debug', message="Clearing serving MME", redisClient=self.redisMessaging) + result.serving_mme = None + result.serving_mme_timestamp = None + result.serving_mme_realm = None + result.serving_mme_peer = None + serving_mme_timestamp_string = result.serving_mme_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + + session.commit() + objectData = self.GetObj(SUBSCRIBER, result.subscriber_id) + self.handleWebhook(objectData, 'PATCH') + + if result.serving_mme_timestamp is not None: + result.serving_mme_timestamp = result.serving_mme_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + + #Sync state change with geored + if propagate == True: + if 'HSS' in self.config['geored'].get('sync_actions', []) and self.config['geored'].get('enabled', False) == True: + self.logTool.log(service='Database', level='debug', message="Propagate MME changes to Geographic PyHSS instances", redisClient=self.redisMessaging) + self.handleGeored({ + "imsi": str(imsi), + "serving_mme": result.serving_mme, + "serving_mme_realm": result.serving_mme_realm, + "serving_mme_peer": result.serving_mme_peer, + "serving_mme_timestamp": serving_mme_timestamp_string + }) + else: + self.logTool.log(service='Database', level='debug', message="Config does not allow sync of HSS events", redisClient=self.redisMessaging) + except Exception as E: + self.logTool.log(service='Database', level='error', message="Error occurred in Update_Serving_MME: " + str(E), redisClient=self.redisMessaging) + finally: + self.safe_close(session) + + + def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, pcscf_timestamp=None, 
pcscf_active_session=None, propagate=True): + self.logTool.log(service='Database', level='debug', message="Update_Proxy_CSCF for sub " + str(imsi) + " to pcscf " + str(proxy_cscf) + " with realm " + str(pcscf_realm) + " and peer " + str(pcscf_peer) + " for session id " + str(pcscf_active_session), redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + + try: + result = session.query(IMS_SUBSCRIBER).filter_by(imsi=imsi).one() + try: + assert(type(proxy_cscf) == str) + assert(len(proxy_cscf) > 0) + self.logTool.log(service='Database', level='debug', message="Setting Proxy CSCF", redisClient=self.redisMessaging) + #Strip duplicate SIP prefix before storing + proxy_cscf = proxy_cscf.replace("sip:sip:", "sip:") + result.pcscf = proxy_cscf + result.pcscf_active_session = pcscf_active_session + try: + if pcscf_timestamp != None and pcscf_timestamp != 'None': + result.pcscf_timestamp = datetime.strptime(pcscf_timestamp, '%Y-%m-%dT%H:%M:%SZ') + result.pcscf_timestamp = result.pcscf_timestamp.replace(tzinfo=timezone.utc) + pcscf_timestamp_string = result.pcscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + else: + result.pcscf_timestamp = datetime.datetime.now(tz=timezone.utc) + pcscf_timestamp_string = result.pcscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + except Exception as e: + result.pcscf_timestamp = datetime.datetime.now(tz=timezone.utc) + pcscf_timestamp_string = result.pcscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + result.pcscf_realm = pcscf_realm + result.pcscf_peer = str(pcscf_peer) + except: + #Clear values + self.logTool.log(service='Database', level='debug', message="Clearing Proxy CSCF", redisClient=self.redisMessaging) + result.pcscf = None + result.pcscf_timestamp = None + result.pcscf_realm = None + result.pcscf_peer = None + result.pcscf_active_session = None + pcscf_timestamp_string = None + + session.commit() + objectData = self.GetObj(IMS_SUBSCRIBER, result.ims_subscriber_id) + self.handleWebhook(objectData, 'PATCH') + + #Sync state change with geored + if propagate == True: + if 'IMS' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True: + self.logTool.log(service='Database', level='debug', message="Propagate IMS changes to Geographic PyHSS instances", redisClient=self.redisMessaging) + self.handleGeored({"imsi": str(imsi), "pcscf": result.pcscf, "pcscf_realm": result.pcscf_realm, "pcscf_timestamp": pcscf_timestamp_string, "pcscf_peer": result.pcscf_peer, "pcscf_active_session": pcscf_active_session}) + else: + self.logTool.log(service='Database', level='debug', message="Config does not allow sync of IMS events", redisClient=self.redisMessaging) + except Exception as E: + self.logTool.log(service='Database', level='error', message="An error occurred, rolling back session: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + raise + finally: + self.safe_close(session) + + def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=None, scscf_timestamp=None, propagate=True): + self.logTool.log(service='Database', level='debug', message="Update_Serving_CSCF for sub " + str(imsi) + " to SCSCF " + str(serving_cscf) + " with realm " + str(scscf_realm) + " and peer " + str(scscf_peer), redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + + try: + result = session.query(IMS_SUBSCRIBER).filter_by(imsi=imsi).one() + try: + assert(type(serving_cscf) == str) + assert(len(serving_cscf) > 0) + 
self.logTool.log(service='Database', level='debug', message="Setting serving CSCF", redisClient=self.redisMessaging) + #Strip duplicate SIP prefix before storing + serving_cscf = serving_cscf.replace("sip:sip:", "sip:") + result.scscf = serving_cscf + try: + if scscf_timestamp != None and scscf_timestamp != 'None': + result.scscf_timestamp = datetime.strptime(scscf_timestamp, '%Y-%m-%dT%H:%M:%SZ') + result.scscf_timestamp = result.scscf_timestamp.replace(tzinfo=timezone.utc) + scscf_timestamp_string = result.scscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + else: + result.scscf_timestamp = datetime.datetime.now(tz=timezone.utc) + scscf_timestamp_string = result.scscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + except Exception as e: + result.scscf_timestamp = datetime.datetime.now(tz=timezone.utc) + scscf_timestamp_string = result.scscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + result.scscf_realm = scscf_realm + result.scscf_peer = str(scscf_peer) + except: + #Clear values + self.logTool.log(service='Database', level='debug', message="Clearing serving CSCF", redisClient=self.redisMessaging) + result.scscf = None + result.scscf_timestamp = None + result.scscf_realm = None + result.scscf_peer = None + scscf_timestamp_string = None + + session.commit() + objectData = self.GetObj(IMS_SUBSCRIBER, result.ims_subscriber_id) + self.handleWebhook(objectData, 'PATCH') + + #Sync state change with geored + if propagate == True: + if 'IMS' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True: + self.logTool.log(service='Database', level='debug', message="Propagate IMS changes to Geographic PyHSS instances", redisClient=self.redisMessaging) + self.handleGeored({"imsi": str(imsi), "scscf": result.scscf, "scscf_realm": result.scscf_realm, "scscf_timestamp": scscf_timestamp_string, "scscf_peer": result.scscf_peer}) + else: + self.logTool.log(service='Database', level='debug', message="Config does not allow sync of IMS events", redisClient=self.redisMessaging) + except Exception as E: + self.logTool.log(service='Database', level='error', message="An error occurred, rolling back session: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + raise + finally: + self.safe_close(session) + + def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber_routing, serving_pgw_realm=None, serving_pgw_peer=None, serving_pgw_timestamp=None, propagate=True): + self.logTool.log(service='Database', level='debug', message="Called Update_Serving_APN() for imsi " + str(imsi) + " with APN " + str(apn), redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message="PCRF Session ID " + str(pcrf_session_id) + " and serving PGW " + str(serving_pgw) + " and subscriber routing " + str(subscriber_routing), redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message="Serving PGW Realm is: " + str(serving_pgw_realm) + " and peer is: " + str(serving_pgw_peer), redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message="subscriber_routing: " + str(subscriber_routing), redisClient=self.redisMessaging) + + #Get Subscriber ID from IMSI + subscriber_details = self.Get_Subscriber(imsi=str(imsi)) + subscriber_id = subscriber_details['subscriber_id'] + + #Split the APN list into a list + apn_list = subscriber_details['apn_list'].split(',') + self.logTool.log(service='Database', level='debug', message="Current APN List: " + str(apn_list), 
redisClient=self.redisMessaging)
+        #Remove the default APN from the list
+        try:
+            apn_list.remove(str(subscriber_details['default_apn']))
+        except:
+            self.logTool.log(service='Database', level='debug', message="Failed to remove default APN (" + str(subscriber_details['default_apn']) + ") from APN List", redisClient=self.redisMessaging)
+            pass
+        #Add default APN in first position
+        apn_list.insert(0, str(subscriber_details['default_apn']))
+
+        #Get APN ID from APN
+        for apn_id in apn_list:
+            #Get each APN in List
+            apn_data = self.Get_APN(apn_id)
+            self.logTool.log(service='Database', level='debug', message=apn_data, redisClient=self.redisMessaging)
+            if str(apn_data['apn']).lower() == str(apn).lower():
+                self.logTool.log(service='Database', level='debug', message="Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id), redisClient=self.redisMessaging)
+                break
+        self.logTool.log(service='Database', level='debug', message="APN ID is " + str(apn_id), redisClient=self.redisMessaging)
+
+        try:
+            if serving_pgw_timestamp != None and serving_pgw_timestamp != 'None':
+                serving_pgw_timestamp = datetime.strptime(serving_pgw_timestamp, '%Y-%m-%dT%H:%M:%SZ')
+                serving_pgw_timestamp = serving_pgw_timestamp.replace(tzinfo=timezone.utc)
+                serving_pgw_timestamp_string = serving_pgw_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+            else:
+                serving_pgw_timestamp = datetime.datetime.now(tz=timezone.utc)
+                serving_pgw_timestamp_string = serving_pgw_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+        except Exception as e:
+            serving_pgw_timestamp = datetime.datetime.now(tz=timezone.utc)
+            serving_pgw_timestamp_string = serving_pgw_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+
+        json_data = {
+            'apn' : apn_id,
+            'subscriber_id' : subscriber_id,
+            'pcrf_session_id' : str(pcrf_session_id),
+            'serving_pgw' : str(serving_pgw),
+            'serving_pgw_realm' : str(serving_pgw_realm),
+            'serving_pgw_peer' : str(serving_pgw_peer),
+            'serving_pgw_timestamp' : serving_pgw_timestamp,
+            'subscriber_routing' : str(subscriber_routing)
+        }
+
+        if serving_pgw is None:
+            try:
+                ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id)
+                self.logTool.log(service='Database', level='debug', message="Clearing PCRF session ID on serving_apn_id: " + str(ServingAPN['serving_apn_id']), redisClient=self.redisMessaging)
+                objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id'])
+                self.handleWebhook(objectData, 'DELETE')
+                self.DeleteObj(SERVING_APN, ServingAPN['serving_apn_id'], True)
+            except Exception as e:
+                self.logTool.log(service='Database', level='debug', message=f"Error when trying to delete serving_apn id: {apn_id}", redisClient=self.redisMessaging)
+        else:
+            try:
+                #Check if already a serving APN on record
+                self.logTool.log(service='Database', level='debug', message="Checking to see if subscriber id " + str(subscriber_id) + " already has an active PCRF profile on APN id " + str(apn_id), redisClient=self.redisMessaging)
+                ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id)
+                self.logTool.log(service='Database', level='debug', message="Existing Serving APN ID on record, updating", redisClient=self.redisMessaging)
+                try:
+                    assert(type(serving_pgw) == str)
+                    assert(len(serving_pgw) > 0)
+                    assert("None" not in serving_pgw)
+
+                    self.UpdateObj(SERVING_APN, json_data, ServingAPN['serving_apn_id'], True)
+                    objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id'])
+                    self.handleWebhook(objectData,
'PATCH')
+                except:
+                    self.logTool.log(service='Database', level='debug', message="Clearing PCRF session ID on serving_apn_id: " + str(ServingAPN['serving_apn_id']), redisClient=self.redisMessaging)
+                    objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id'])
+                    self.handleWebhook(objectData, 'DELETE')
+                    self.DeleteObj(SERVING_APN, ServingAPN['serving_apn_id'], True)
+            except Exception as E:
+                self.logTool.log(service='Database', level='debug', message="Failed to update existing APN " + str(E), redisClient=self.redisMessaging)
+                #Create if does not exist
+                self.CreateObj(SERVING_APN, json_data, True)
+                ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id)
+                objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id'])
+                self.handleWebhook(objectData, 'PUT')
+
+        #Sync state change with geored
+        if propagate == True:
+            try:
+                if 'PCRF' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True:
+                    self.logTool.log(service='Database', level='debug', message="Propagate PCRF changes to Geographic PyHSS instances", redisClient=self.redisMessaging)
+                    self.handleGeored({"imsi": str(imsi),
+                                'serving_apn' : apn,
+                                'pcrf_session_id': pcrf_session_id,
+                                'serving_pgw': serving_pgw,
+                                'serving_pgw_realm': serving_pgw_realm,
+                                'serving_pgw_peer': serving_pgw_peer,
+                                'serving_pgw_timestamp': serving_pgw_timestamp_string,
+                                'subscriber_routing': subscriber_routing
+                                })
+                else:
+                    self.logTool.log(service='Database', level='debug', message="Config does not allow sync of PCRF events", redisClient=self.redisMessaging)
+            except Exception as E:
+                self.logTool.log(service='Database', level='debug', message="Nothing synced to Geographic PyHSS instances for event PCRF", redisClient=self.redisMessaging)
+
+        return
+
+    def Get_Serving_APN(self, subscriber_id, apn_id):
+        self.logTool.log(service='Database', level='debug', message="Getting Serving APN " + str(apn_id) + " with subscriber_id " + str(subscriber_id), redisClient=self.redisMessaging)
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+
+        try:
+            result = session.query(SERVING_APN).filter_by(subscriber_id=subscriber_id, apn=apn_id).first()
+        except Exception as E:
+            self.logTool.log(service='Database', level='debug', message=E, redisClient=self.redisMessaging)
+            self.safe_close(session)
+            raise ValueError(E)
+        result = result.__dict__
+        result.pop('_sa_instance_state')
+
+        self.safe_close(session)
+        return result
+
+    def Get_Serving_APNs(self, subscriber_id: int) -> dict:
+        """
+        Returns a dictionary containing all APNs that a subscriber is configured for (subscriber/apn_list),
+        with active sessions being a populated dictionary, and inactive sessions being an empty dictionary.
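+        Illustrative shape of the returned dictionary (the APN names shown are examples only):
+            {'apns': {'internet': {<serving_apn fields>}, 'ims': {}}}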
+ """ + self.logTool.log(service='Database', level='debug', message=f"Getting Serving APNs for subscriber_id: {subscriber_id}", redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + apnDict = {'apns': {}} + + try: + subscriber = self.Get_Subscriber(subscriber_id=subscriber_id) + except: + self.logTool.log(service='Database', level='debug', message=f"Unable to get subscriber with ID: {subscriber_id}: {traceback.format_exc()} ", redisClient=self.redisMessaging) + return apnDict + + apnList = subscriber.get('apn_list', []).split(',') + for apnId in apnList: + try: + apnData = self.Get_APN(apnId) + apnName = apnData.get('apn', 'Unknown') + try: + servingApn = self.Sanitize_Datetime(self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apnId)) + self.logTool.log(service='Database', level='debug', message=f"Got serving APN: {servingApn}", redisClient=self.redisMessaging) + if len(servingApn) > 0: + apnDict['apns'][apnName] = servingApn + else: + apnDict['apns'][apnName] = {} + except Exception as e: + apnDict['apns'][apnName] = {} + continue + except Exception as E: + self.logTool.log(service='Database', level='debug', message=f"Error getting apn for subscriber id: {subscriber_id}: {traceback.format_exc()} ", redisClient=self.redisMessaging) + + self.logTool.log(service='Database', level='debug', message=f"Returning: {apnDict}", redisClient=self.redisMessaging) + + return apnDict + + def Get_Charging_Rule(self, charging_rule_id): + self.logTool.log(service='Database', level='debug', message="Called Get_Charging_Rule() for charging_rule_id " + str(charging_rule_id), redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + #Get base Rule + ChargingRule = self.GetObj(CHARGING_RULE, charging_rule_id) + ChargingRule['tft'] = [] + #Get TFTs + try: + results = session.query(TFT).filter_by(tft_group_id=ChargingRule['tft_group_id']) + for result in results: + result = result.__dict__ + result.pop('_sa_instance_state') + ChargingRule['tft'].append(result) + except Exception as E: + self.safe_close(session) + raise ValueError(E) + self.safe_close(session) + return ChargingRule + + def Get_Charging_Rules(self, imsi, apn): + self.logTool.log(service='Database', level='debug', message="Called Get_Charging_Rules() for IMSI " + str(imsi) + " and APN " + str(apn), redisClient=self.redisMessaging) + #Get Subscriber ID from IMSI + subscriber_details = self.Get_Subscriber(imsi=str(imsi)) + + #Split the APN list into a list + apn_list = subscriber_details['apn_list'].split(',') + self.logTool.log(service='Database', level='debug', message="Current APN List: " + str(apn_list), redisClient=self.redisMessaging) + #Remove the default APN from the list + try: + apn_list.remove(str(subscriber_details['default_apn'])) + except: + self.logTool.log(service='Database', level='debug', message="Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List", redisClient=self.redisMessaging) + pass + #Add default APN in first position + apn_list.insert(0, str(subscriber_details['default_apn'])) + + #Get APN ID from APN + for apn_id in apn_list: + self.logTool.log(service='Database', level='debug', message="Getting APN ID " + str(apn_id) + " to see if it matches APN " + str(apn), redisClient=self.redisMessaging) + #Get each APN in List + apn_data = self.Get_APN(apn_id) + self.logTool.log(service='Database', level='debug', message=apn_data, redisClient=self.redisMessaging) + if str(apn_data['apn']).lower() == 
str(apn).lower():
+                self.logTool.log(service='Database', level='debug', message="Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id), redisClient=self.redisMessaging)
+
+                self.logTool.log(service='Database', level='debug', message="Getting charging rule list from " + str(apn_data['charging_rule_list']), redisClient=self.redisMessaging)
+                ChargingRule = {}
+                ChargingRule['charging_rule_list'] = str(apn_data['charging_rule_list']).split(',')
+                ChargingRule['apn_data'] = apn_data
+
+                #Get Charging Rules list
+                if apn_data['charging_rule_list'] == None:
+                    self.logTool.log(service='Database', level='debug', message="No Charging Rule associated with this APN", redisClient=self.redisMessaging)
+                    ChargingRule['charging_rules'] = None
+                    return ChargingRule
+
+                self.logTool.log(service='Database', level='debug', message="ChargingRule['charging_rule_list'] is: " + str(ChargingRule['charging_rule_list']), redisClient=self.redisMessaging)
+                #Empty dict for the Charging Rules to go into
+                ChargingRule['charging_rules'] = []
+                #Add each of the Charging Rules for the APN
+                for individual_charging_rule in ChargingRule['charging_rule_list']:
+                    self.logTool.log(service='Database', level='debug', message="Getting Charging rule " + str(individual_charging_rule), redisClient=self.redisMessaging)
+                    individual_charging_rule_complete = self.Get_Charging_Rule(individual_charging_rule)
+                    self.logTool.log(service='Database', level='debug', message="Got individual_charging_rule_complete: " + str(individual_charging_rule_complete), redisClient=self.redisMessaging)
+                    ChargingRule['charging_rules'].append(individual_charging_rule_complete)
+                self.logTool.log(service='Database', level='debug', message="Completed Get_Charging_Rules()", redisClient=self.redisMessaging)
+                self.logTool.log(service='Database', level='debug', message=ChargingRule, redisClient=self.redisMessaging)
+                return ChargingRule
+
+    def Get_UE_by_IP(self, subscriber_routing):
+        self.logTool.log(service='Database', level='debug', message="Called Get_UE_by_IP() for IP " + str(subscriber_routing), redisClient=self.redisMessaging)
+
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+
+        try:
+            result = session.query(SERVING_APN).filter_by(subscriber_routing=subscriber_routing).one()
+        except Exception as E:
+            self.safe_close(session)
+            raise ValueError(E)
+        result = result.__dict__
+        result.pop('_sa_instance_state')
+        result = self.Sanitize_Datetime(result)
+        self.safe_close(session)
+        return result
+
+    def Get_IMS_Subscriber_By_Session_Id(self, sessionId):
+        self.logTool.log(service='Database', level='debug', message="Called Get_IMS_Subscriber_By_Session_Id() for Session " + str(sessionId), redisClient=self.redisMessaging)
+
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+
+        try:
+            result = session.query(IMS_SUBSCRIBER).filter_by(pcscf_active_session=sessionId).one()
+        except Exception as E:
+            self.safe_close(session)
+            raise ValueError(E)
+        result = result.__dict__
+        result.pop('_sa_instance_state')
+        result = self.Sanitize_Datetime(result)
+        self.safe_close(session)
+        return result
+
+    def Store_IMSI_IMEI_Binding(self, imsi, imei, match_response_code, propagate=True):
+        #IMSI 14-15 Digits
+        #IMEI 15 Digits
+        #IMEI-SV 2 Digits
+        self.logTool.log(service='Database', level='debug', message="Called Store_IMSI_IMEI_Binding() with IMSI: " + str(imsi) + " IMEI: " + str(imei) + " match_response_code: " + str(match_response_code), redisClient=self.redisMessaging)
+        if self.config['eir']['imsi_imei_logging'] != True:
+            self.logTool.log(service='Database',
level='debug', message="Skipping storing binding", redisClient=self.redisMessaging) + return + #Concat IMEI + IMSI + imsi_imei = str(imsi) + "," + str(imei) + Session = sessionmaker(bind = self.engine) + session = Session() + + #Check if exist already & update + try: + session.query(IMSI_IMEI_HISTORY).filter_by(imsi_imei=imsi_imei).one() + self.logTool.log(service='Database', level='debug', message="Entry already present for IMSI/IMEI Combo", redisClient=self.redisMessaging) + self.safe_close(session) + return + except Exception as E: + newObj = IMSI_IMEI_HISTORY(imsi_imei=imsi_imei, match_response_code=match_response_code, imsi_imei_timestamp = datetime.datetime.now(tz=timezone.utc)) + session.add(newObj) + try: + session.commit() + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to commit session, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + self.safe_close(session) + self.logTool.log(service='Database', level='debug', message="Added new IMSI_IMEI_HISTORY binding", redisClient=self.redisMessaging) + + if 'sim_swap_notify_webhook' in self.config['eir']: + self.logTool.log(service='Database', level='debug', message="Sending SIM Swap notification to Webhook", redisClient=self.redisMessaging) + try: + dictToSend = {'imei':imei, 'imsi': imsi, 'match_response_code': match_response_code} + self.handleWebhook(dictToSend) + except Exception as E: + self.logTool.log(service='Database', level='debug', message="Failed to post to Webhook", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message=str(E), redisClient=self.redisMessaging) + + #Lookup Device Info + if 'tac_database_csv' in self.config['eir']: + try: + device_info = self.get_device_info_from_TAC(imei=str(imei)) + self.logTool.log(service='Database', level='debug', message="Got Device Info: " + str(device_info), redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='database', metricName='prom_eir_devices', + metricType='counter', metricAction='inc', + metricValue=1, metricHelp='Profile of attached devices', + metricLabels={'imei_prefix': device_info['tacPrefix'], + 'device_type': device_info['name'], + 'device_name': device_info['model']}, + metricExpiry=60) + except Exception as E: + self.logTool.log(service='Database', level='debug', message="Failed to get device info from TAC", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='database', metricName='prom_eir_devices', + metricType='counter', metricAction='inc', + metricValue=1, metricHelp='Profile of attached devices', + metricLabels={'imei_prefix': str(imei)[0:8], + 'device_type': 'Unknown', + 'device_name': 'Unknown'}, + metricExpiry=60) + else: + self.logTool.log(service='Database', level='debug', message="No TAC database configured, skipping device info lookup", redisClient=self.redisMessaging) + + #Sync state change with geored + if propagate == True: + try: + if 'EIR' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True: + self.logTool.log(service='Database', level='debug', message="Propagate EIR changes to Geographic PyHSS instances", redisClient=self.redisMessaging) + self.handleGeored( + {"imsi": str(imsi), + "imei": str(imei), + "match_response_code": str(match_response_code)} + ) + else: + self.logTool.log(service='Database', level='debug', message="Config does not allow sync of EIR events", 
redisClient=self.redisMessaging)
+            except Exception as E:
+                self.logTool.log(service='Database', level='debug', message="Nothing synced to Geographic PyHSS instances for EIR event", redisClient=self.redisMessaging)
+                self.logTool.log(service='Database', level='debug', message=E, redisClient=self.redisMessaging)
+
+        return
+
+    def Get_IMEI_IMSI_History(self, attribute):
+        self.logTool.log(service='Database', level='debug', message="Called Get_IMEI_IMSI_History() for entry matching " + str(attribute), redisClient=self.redisMessaging)
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+        result_array = []
+        try:
+            results = session.query(IMSI_IMEI_HISTORY).filter(IMSI_IMEI_HISTORY.imsi_imei.ilike("%" + str(attribute) + "%")).all()
+            for result in results:
+                result = result.__dict__
+                result.pop('_sa_instance_state')
+                result = self.Sanitize_Datetime(result)
+                try:
+                    result['imsi'] = result['imsi_imei'].split(",")[0]
+                except:
+                    continue
+                try:
+                    result['imei'] = result['imsi_imei'].split(",")[1]
+                except:
+                    continue
+                result_array.append(result)
+            self.safe_close(session)
+            return result_array
+        except Exception as E:
+            self.safe_close(session)
+            raise ValueError(E)
+
+    def Check_EIR(self, imsi, imei):
+        eir_response_code_table = {0 : 'Whitelist', 1: 'Blacklist', 2: 'Greylist'}
+        self.logTool.log(service='Database', level='debug', message="Called Check_EIR() for imsi " + str(imsi) + " and imei: " + str(imei), redisClient=self.redisMessaging)
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+        #Check for exact matches
+        self.logTool.log(service='Database', level='debug', message="Looking for exact matches", redisClient=self.redisMessaging)
+        try:
+            results = session.query(EIR).filter_by(imei=str(imei), regex_mode=0)
+            for result in results:
+                result = result.__dict__
+                match_response_code = result['match_response_code']
+                if result['imsi'] == '':
+                    self.logTool.log(service='Database', level='debug', message="No IMSI specified in DB, so matching only on IMEI", redisClient=self.redisMessaging)
+                    self.Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code)
+                    return match_response_code
+                elif result['imsi'] == str(imsi):
+                    self.logTool.log(service='Database', level='debug', message="Matched on IMEI and IMSI", redisClient=self.redisMessaging)
+                    self.Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code)
+                    return match_response_code
+        except Exception as E:
+            self.safe_rollback(session)
+            self.safe_close(session)
+            raise ValueError(E)
+
+        self.logTool.log(service='Database', level='debug', message="Did not match any Exact Matches - Checking Regex", redisClient=self.redisMessaging)
+        try:
+            results = session.query(EIR).filter_by(regex_mode=1) #Get all Regex records from DB
+            for result in results:
+                result = result.__dict__
+                match_response_code = result['match_response_code']
+                if re.match(result['imei'], imei):
+                    self.logTool.log(service='Database', level='debug', message="IMEI matched " + str(result['imei']), redisClient=self.redisMessaging)
+                    #Check if IMSI also specified
+                    if len(result['imsi']) != 0:
+                        self.logTool.log(service='Database', level='debug', message="With IMEI matched, now checking if IMSI matches regex", redisClient=self.redisMessaging)
+                        if re.match(result['imsi'], imsi):
+                            self.logTool.log(service='Database', level='debug', message="IMSI also matched, so match OK!", redisClient=self.redisMessaging)
+                            self.Store_IMSI_IMEI_Binding(imsi=imsi,
imei=imei, match_response_code=match_response_code) + return match_response_code + else: + self.logTool.log(service='Database', level='debug', message="No IMSI specified, so match OK!", redisClient=self.redisMessaging) + self.Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) + return match_response_code + except Exception as E: + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + + try: + session.commit() + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to commit session, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + self.logTool.log(service='Database', level='debug', message="No matches at all - Returning default response", redisClient=self.redisMessaging) + self.Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=self.config['eir']['no_match_response']) + self.safe_close(session) + return self.config['eir']['no_match_response'] + + def Get_EIR_Rules(self): + self.logTool.log(service='Database', level='debug', message="Getting all EIR Rules", redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + EIR_Rules = [] + try: + results = session.query(EIR) + for result in results: + result = result.__dict__ + result.pop('_sa_instance_state') + EIR_Rules.append(result) + except Exception as E: + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + self.logTool.log(service='Database', level='debug', message="Final EIR_Rules: " + str(EIR_Rules), redisClient=self.redisMessaging) + self.safe_close(session) + return EIR_Rules + + + def dict_bytes_to_dict_string(self, dict_bytes): + dict_string = {} + for key, value in dict_bytes.items(): + dict_string[key.decode()] = value.decode() + return + + def find_imei_in_tac_list(self, imei, tacList): + """ + Iterate over every tac in the tacList and try to match the first 8 digits of the IMEI. + If that fails, try to match the first 6 digits of the IMEI. 
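+        Assumed (illustrative) tacList shape:
+            {'tacList': [{'35609804': {'name': '<vendor>', 'model': '<model>'}}]}
+        Returns {'tacPrefix': ..., 'name': ..., 'model': ...} on a match, or {} on no match.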
+ """ + for tac in tacList['tacList']: + for key, value in tac.items(): + if str(key) == str(imei[0:8]): + return {'tacPrefix': key, 'name': tac[key]['name'], 'model': tac[key]['model']} + for key, value in tac.items(): + if str(key) == str(imei[0:6]): + return {'tacPrefix': key, 'name': tac[key]['name'], 'model': tac[key]['model']} + return {} + + def get_device_info_from_TAC(self, imei) -> dict: + self.logTool.log(service='Database', level='debug', message="Getting Device Info from IMEI: " + str(imei), redisClient=self.redisMessaging) + try: + self.logTool.log(service='Database', level='debug', message="Taclist: self.tacList ", redisClient=self.redisMessaging) + imei_result = self.find_imei_in_tac_list(imei, self.tacData) + assert(len(imei_result) != 0) + self.logTool.log(service='Database', level='debug', message="Found match for IMEI " + str(imei) + " with result " + str(imei_result), redisClient=self.redisMessaging) + return imei_result + except: + self.logTool.log(service='Database', level='debug', message="Failed to match on 8 digit IMEI", redisClient=self.redisMessaging) + + raise ValueError("No matching TAC in IMEI Database") + + +if __name__ == "__main__": + import binascii,os,pprint + DeleteAfter = True + database = Database() + + #Define Charging Rule + charging_rule = { + 'rule_name' : 'charging_rule_A', + 'qci' : 4, + 'arp_priority' : 5, + 'arp_preemption_capability' : True, + 'arp_preemption_vulnerability' : False, + 'mbr_dl' : 128000, + 'mbr_ul' : 128000, + 'gbr_dl' : 128000, + 'gbr_ul' : 128000, + 'tft_group_id' : 1, + 'precedence' : 100, + 'rating_group' : 20000 + } + print("Creating Charging Rule A") + ChargingRule_newObj_A = database.CreateObj(CHARGING_RULE, charging_rule) + print("ChargingRule_newObj A: " + str(ChargingRule_newObj_A)) + charging_rule['gbr_ul'], charging_rule['gbr_dl'], charging_rule['mbr_ul'], charging_rule['mbr_dl'] = 256000, 256000, 256000, 256000 + print("Creating Charging Rule B") + charging_rule['rule_name'], charging_rule['precedence'], charging_rule['tft_group_id'] = 'charging_rule_B', 80, 2 + ChargingRule_newObj_B = database.CreateObj(CHARGING_RULE, charging_rule) + print("ChargingRule_newObj B: " + str(ChargingRule_newObj_B)) + + #Define TFTs + tft_template1 = { + 'tft_group_id' : 1, + 'tft_string' : 'permit out ip from any to any', + 'direction' : 1 + } + tft_template2 = { + 'tft_group_id' : 1, + 'tft_string' : 'permit out ip from any to any', + 'direction' : 2 + } + print("Creating TFT") + database.CreateObj(TFT, tft_template1) + database.CreateObj(TFT, tft_template2) + + tft_template3 = { + 'tft_group_id' : 2, + 'tft_string' : 'permit out ip from 10.98.0.0 255.255.255.0 to any', + 'direction' : 1 + } + tft_template4 = { + 'tft_group_id' : 2, + 'tft_string' : 'permit out ip from any to 10.98.0.0 255.255.255.0', + 'direction' : 2 + } + print("Creating TFT") + database.CreateObj(TFT, tft_template3) + database.CreateObj(TFT, tft_template4) + + + apn2 = { + 'apn':'ims', + 'apn_ambr_dl' : 9999, + 'apn_ambr_ul' : 9999, + 'arp_priority': 1, + 'arp_preemption_capability' : False, + 'arp_preemption_vulnerability': True, + 'charging_rule_list' : str(ChargingRule_newObj_A['charging_rule_id']) + "," + str(ChargingRule_newObj_B['charging_rule_id']) + } + print("Creating APN " + str(apn2['apn'])) + newObj = database.CreateObj(APN, apn2) + print(newObj) + + print("Getting APN " + str(apn2['apn'])) + print(database.GetObj(APN, newObj['apn_id'])) + apn_id = newObj['apn_id'] + UpdatedObj = newObj + UpdatedObj['apn'] = 'UpdatedInUnitTest' + + print("Updating 
APN " + str(apn2['apn'])) + newObj = database.UpdateObj(APN, UpdatedObj, newObj['apn_id']) + print(newObj) + + #Create AuC + auc_json = { + "ki": binascii.b2a_hex(os.urandom(16)).zfill(16), + "opc": binascii.b2a_hex(os.urandom(16)).zfill(16), + "amf": "9000", + "sqn": 0 + } + print(auc_json) + print("Creating AuC entry") + newObj = database.CreateObj(AUC, auc_json) + print(newObj) + + #Get AuC + print("Getting AuC entry") + newObj = database.GetObj(AUC, newObj['auc_id']) + auc_id = newObj['auc_id'] + print(newObj) + + #Update AuC + print("Updating AuC entry") + newObj['sqn'] = newObj['sqn'] + 10 + newObj = database.UpdateObj(AUC, newObj, auc_id) + + #Generate Vectors + print("Generating Vectors") + database.Get_Vectors_AuC(auc_id, "air", plmn='12ff') + print(database.Get_Vectors_AuC(auc_id, "sip_auth", plmn='12ff')) + + + #Update AuC + database.Update_AuC(auc_id, sqn=100) + + #New Subscriber + subscriber_json = { + "imsi": "001001000000006", + "enabled": True, + "msisdn": "12345678", + "ue_ambr_dl": 999999, + "ue_ambr_ul": 999999, + "nam": 0, + "subscribed_rau_tau_timer": 600, + "auc_id" : auc_id, + "default_apn" : apn_id, + "apn_list" : apn_id + } + + #Delete IMSI if already exists + try: + existing_sub_data = database.Get_Subscriber(imsi=subscriber_json['imsi']) + database.DeleteObj(SUBSCRIBER, existing_sub_data['subscriber_id']) + except: + print("Did not find old sub to delete") + + print("Creating new Subscriber") + print(subscriber_json) + newObj = database.CreateObj(SUBSCRIBER, subscriber_json) + print(newObj) + subscriber_id = newObj['subscriber_id'] + + #Get SUBSCRIBER + print("Getting Subscriber") + newObj = database.GetObj(SUBSCRIBER, subscriber_id) + print(newObj) + + #Update SUBSCRIBER + print("Updating Subscriber") + newObj['ue_ambr_ul'] = 999995 + newObj = database.UpdateObj(SUBSCRIBER, newObj, subscriber_id) + + #Set MME Location for Subscriber + print("Updating Serving MME for Subscriber") + database.Update_Serving_MME(imsi=newObj['imsi'], serving_mme="Test123", serving_mme_peer="Test123", serving_mme_realm="TestRealm") + + #Update Serving APN for Subscriber + print("Updating Serving APN for Subscriber") + database.Update_Serving_APN(imsi=newObj['imsi'], apn=apn2['apn'], pcrf_session_id='kjsdlkjfd', serving_pgw='pgw.test.com', subscriber_routing='1.2.3.4') + + print("Getting Charging Rule for Subscriber / APN Combo") + ChargingRule = database.Get_Charging_Rules(imsi=newObj['imsi'], apn=apn2['apn']) + pprint.pprint(ChargingRule) + + #New IMS Subscriber + ims_subscriber_json = { + "msisdn": newObj['msisdn'], + "msisdn_list": newObj['msisdn'], + "imsi": subscriber_json['imsi'], + "ifc_path" : "default_ifc.xml", + "sh_profile" : "default_sh_user_data.xml" + } + print(ims_subscriber_json) + newObj = database.CreateObj(IMS_SUBSCRIBER, ims_subscriber_json) + print(newObj) + ims_subscriber_id = newObj['ims_subscriber_id'] + + + #Test Get Subscriber + print("Test Getting Subscriber") + GetSubscriber_Result = database.Get_Subscriber(imsi=subscriber_json['imsi']) + print(GetSubscriber_Result) + + #Test IMS Get Subscriber + print("Getting IMS Subscribers") + print(database.Get_IMS_Subscriber(imsi='001001000000006')) + print(database.Get_IMS_Subscriber(msisdn='12345678')) + + #Set SCSCF for Subscriber + database.Update_Serving_CSCF(newObj['imsi'], "NickTestCSCF") + #Get Served Subscriber List + print(database.Get_Served_IMS_Subscribers()) + + #Clear Serving PGW for PCRF Subscriber + print("Clear Serving PGW for PCRF Subscriber") + database.Update_Serving_APN(imsi=newObj['imsi'], 
apn=apn2['apn'], pcrf_session_id='sessionid123', serving_pgw=None, subscriber_routing=None) + + #Clear MME Location for Subscriber + print("Clear MME Location for Subscriber") + database.Update_Serving_MME(newObj['imsi'], None) + + #Generate Vectors for IMS Subscriber + print("Generating Vectors for IMS Subscriber") + print(database.Get_Vectors_AuC(auc_id, "sip_auth", plmn='12ff')) + + #print("Generating Resync for IMS Subscriber") + #print(Get_Vectors_AuC(auc_id, "sqn_resync", auts='7964347dfdfe432289522183fcfb', rand='1bc9f096002d3716c65e4e1f4c1c0d17')) + + #Test getting APNs + GetAPN_Result = database.Get_APN(GetSubscriber_Result['default_apn']) + print(GetAPN_Result) + + #handleGeored({"imsi": "001001000000006", "serving_mme": "abc123"}) + + + if DeleteAfter == True: + #Delete IMS Subscriber + print(database.DeleteObj(IMS_SUBSCRIBER, ims_subscriber_id)) + #Delete Subscriber + print(database.DeleteObj(SUBSCRIBER, subscriber_id)) + #Delete AuC + print(database.DeleteObj(AUC, auc_id)) + #Delete APN + print(database.DeleteObj(APN, apn_id)) + + #Whitelist IMEI / IMSI Binding + eir_template = {'imei': '1234', 'imsi': '567', 'regex_mode': 0, 'match_response_code': 0} + database.CreateObj(EIR, eir_template) + + #Blacklist Example + eir_template = {'imei': '99881232', 'imsi': '', 'regex_mode': 0, 'match_response_code': 1} + database.CreateObj(EIR, eir_template) + + #IMEI Prefix Regex Example (Blacklist all IMEIs starting with 666) + eir_template = {'imei': '^666.*', 'imsi': '', 'regex_mode': 1, 'match_response_code': 1} + database.CreateObj(EIR, eir_template) + + #IMEI Prefix Regex Example (Greylist response for IMEI starting with 777 and IMSI is 1234123412341234) + eir_template = {'imei': '^777.*', 'imsi': '^1234123412341234$', 'regex_mode': 1, 'match_response_code': 2} + database.CreateObj(EIR, eir_template) + + print("\n\n\n\n") + #Check Whitelist (No Match) + assert database.Check_EIR(imei='1234', imsi='') == 2 + + print("\n\n\n\n") + #Check Whitelist (Matched) + assert database.Check_EIR(imei='1234', imsi='567') == 0 + + print("\n\n\n\n") + #Check Blacklist (Match) + assert database.Check_EIR(imei='99881232', imsi='567') == 1 + + print("\n\n\n\n") + #IMEI Prefix Regex Example (Greylist response for IMEI starting with 777 and IMSI is 1234123412341234) + assert database.Check_EIR(imei='7771234', imsi='1234123412341234') == 2 + + print(database.Get_IMEI_IMSI_History('1234123412')) + + + print("\n\n\n") + print(database.Generate_JSON_Model_for_Flask(SUBSCRIBER)) + + + diff --git a/lib/diameter.py b/lib/diameter.py new file mode 100644 index 0000000..4692925 --- /dev/null +++ b/lib/diameter.py @@ -0,0 +1,3580 @@ +#Diameter Packet Decoder / Encoder & Tools +import socket +import binascii +import math +import uuid +import os +import random +import ipaddress +import jinja2 +from database import Database +from messaging import RedisMessaging +import yaml +import json +import time +import traceback + +class Diameter: + + def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc999.mcc999.3gppnetwork.org", productName: str="PyHSS", mcc: str="999", mnc: str="999", redisMessaging=None): + with open("../config.yaml", 'r') as stream: + self.config = (yaml.safe_load(stream)) + + self.OriginHost = self.string_to_hex(originHost) + self.OriginRealm = self.string_to_hex(originRealm) + self.ProductName = self.string_to_hex(productName) + self.MNC = str(mnc) + self.MCC = str(mcc) + self.logTool = logTool + + self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False) 
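+        # Redis connectivity is configurable as either a Unix socket or TCP; each
+        # value below falls back to the default shown when absent from config.yaml.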
+ self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) + if redisMessaging: + self.redisMessaging = redisMessaging + else: + self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + + self.database = Database(logTool=logTool) + self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10)) + + self.templateLoader = jinja2.FileSystemLoader(searchpath="../") + self.templateEnv = jinja2.Environment(loader=self.templateLoader) + + self.logTool.log(service='HSS', level='info', message=f"Initialized Diameter Library", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message=f"Origin Host: {str(originHost)}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message=f"Realm: {str(originRealm)}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message=f"Product Name: {str(productName)}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message=f"PLMN: {str(self.MCC)}/{str(self.MNC)}", redisClient=self.redisMessaging) + + self.diameterResponseList = [ + {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, + {"commandCode": 280, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_280, "failureResultCode": 5012 ,"requestAcronym": "DWR", "responseAcronym": "DWA", "requestName": "Device Watchdog Request", "responseName": "Device Watchdog Answer"}, + {"commandCode": 282, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_282, "failureResultCode": 5012 ,"requestAcronym": "DPR", "responseAcronym": "DPA", "requestName": "Disconnect Peer Request", "responseName": "Disconnect Peer Answer"}, + {"commandCode": 300, "applicationId": 16777216, "responseMethod": self.Answer_16777216_300, "failureResultCode": 4100 ,"requestAcronym": "UAR", "responseAcronym": "UAA", "requestName": "User Authentication Request", "responseName": "User Authentication Answer"}, + {"commandCode": 301, "applicationId": 16777216, "responseMethod": self.Answer_16777216_301, "failureResultCode": 4100 ,"requestAcronym": "SAR", "responseAcronym": "SAA", "requestName": "Server Assignment Request", "responseName": "Server Assignment Answer"}, + {"commandCode": 302, "applicationId": 16777216, "responseMethod": self.Answer_16777216_302, "failureResultCode": 4100 ,"requestAcronym": "LIR", "responseAcronym": "LIA", "requestName": "Location Information Request", "responseName": "Location Information Answer"}, + {"commandCode": 303, "applicationId": 16777216, "responseMethod": self.Answer_16777216_303, "failureResultCode": 4100 ,"requestAcronym": "MAR", "responseAcronym": "MAA", "requestName": "Multimedia Authentication Request", "responseName": "Multimedia Authentication Answer"}, + {"commandCode": 306, "applicationId": 16777217, "responseMethod": self.Answer_16777217_306, "failureResultCode": 5001 ,"requestAcronym": "UDR", "responseAcronym": "UDA", "requestName": "User Data Request", "responseName": "User Data Answer"}, + {"commandCode": 307, "applicationId": 
16777217, "responseMethod": self.Answer_16777217_307, "failureResultCode": 5001 ,"requestAcronym": "PRUR", "responseAcronym": "PRUA", "requestName": "Profile Update Request", "responseName": "Profile Update Answer"}, + {"commandCode": 265, "applicationId": 16777236, "responseMethod": self.Answer_16777236_265, "failureResultCode": 4100 ,"requestAcronym": "AAR", "responseAcronym": "AAA", "requestName": "AA Request", "responseName": "AA Answer"}, + {"commandCode": 275, "applicationId": 16777236, "responseMethod": self.Answer_16777236_275, "failureResultCode": 4100 ,"requestAcronym": "STR", "responseAcronym": "STA", "requestName": "Session Termination Request", "responseName": "Session Termination Answer"}, + {"commandCode": 274, "applicationId": 16777236, "responseMethod": self.Answer_16777236_274, "failureResultCode": 4100 ,"requestAcronym": "ASR", "responseAcronym": "ASA", "requestName": "Abort Session Request", "responseName": "Abort Session Answer"}, + {"commandCode": 258, "applicationId": 16777238, "responseMethod": self.Answer_16777238_258, "failureResultCode": 4100 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, + {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, + {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, + {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"}, + {"commandCode": 321, "applicationId": 16777251, "responseMethod": self.Answer_16777251_321, "failureResultCode": 5012 ,"requestAcronym": "PUR", "responseAcronym": "PUA", "requestName": "Purge UE Request", "responseName": "Purge UE Answer"}, + {"commandCode": 323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012 ,"requestAcronym": "NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"}, + {"commandCode": 324, "applicationId": 16777252, "responseMethod": self.Answer_16777252_324, "failureResultCode": 4100 ,"requestAcronym": "ECR", "responseAcronym": "ECA", "requestName": "ME Identity Check Request", "responseName": "ME Identity Check Answer"}, + {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, + ] + + self.diameterRequestList = [ + {"commandCode": 304, "applicationId": 16777216, "requestMethod": self.Request_16777216_304, "failureResultCode": 5012 ,"requestAcronym": "RTR", "responseAcronym": "RTA", "requestName": "Registration Termination Request", "responseName": "Registration Termination Answer"}, + {"commandCode": 258, "applicationId": 16777238, "requestMethod": self.Request_16777238_258, "failureResultCode": 5012 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", 
"responseName": "Re Auth Answer"}, + {"commandCode": 272, "applicationId": 16777238, "requestMethod": self.Request_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, + {"commandCode": 317, "applicationId": 16777251, "requestMethod": self.Request_16777251_317, "failureResultCode": 5012 ,"requestAcronym": "CLR", "responseAcronym": "CLA", "requestName": "Cancel Location Request", "responseName": "Cancel Location Answer"}, + {"commandCode": 319, "applicationId": 16777251, "requestMethod": self.Request_16777251_319, "failureResultCode": 5012 ,"requestAcronym": "ISD", "responseAcronym": "ISA", "requestName": "Insert Subscriber Data Request", "responseName": "Insert Subscriber Data Answer"}, + ] + + #Generates rounding for calculating padding + def myround(self, n, base=4): + if(n > 0): + return math.ceil(n/4.0) * 4 + elif( n < 0): + return math.floor(n/4.0) * 4 + else: + return 4 + + #Converts a dotted-decimal IPv4 address or IPV6 address to hex + def ip_to_hex(self, ip): + #Determine IPvX version: + if "." in ip: + ip = ip.split('.') + ip_hex = "0001" #IPv4 + ip_hex = ip_hex + str(format(int(ip[0]), 'x').zfill(2)) + ip_hex = ip_hex + str(format(int(ip[1]), 'x').zfill(2)) + ip_hex = ip_hex + str(format(int(ip[2]), 'x').zfill(2)) + ip_hex = ip_hex + str(format(int(ip[3]), 'x').zfill(2)) + else: + ip_hex = "0002" #IPv6 + ip_hex += format(ipaddress.IPv6Address(ip), 'X') + return ip_hex + + def hex_to_int(self, hex): + return int(str(hex), base=16) + + + #Converts a hex formatted IPv4 address or IPV6 address to dotted-decimal + def hex_to_ip(self, hex_ip): + if len(hex_ip) == 8: + octet_1 = int(str(hex_ip[0:2]), base=16) + octet_2 = int(str(hex_ip[2:4]), base=16) + octet_3 = int(str(hex_ip[4:6]), base=16) + octet_4 = int(str(hex_ip[6:8]), base=16) + return str(octet_1) + "." + str(octet_2) + "." + str(octet_3) + "." 
+ str(octet_4) + elif len(hex_ip) == 32: + n=4 + ipv6_split = [hex_ip[idx:idx + n] for idx in range(0, len(hex_ip), n)] + ipv6_str = '' + for octect in ipv6_split: + ipv6_str += str(octect).lstrip('0') + ":" + #Strip last Colon + ipv6_str = ipv6_str[:-1] + return ipv6_str + + #Converts string to hex + def string_to_hex(self, string): + string_bytes = string.encode('utf-8') + return str(binascii.hexlify(string_bytes), 'ascii') + + #Converts int to hex padded to required number of bytes + def int_to_hex(self, input_int, output_bytes): + + return format(input_int,"x").zfill(output_bytes*2) + + #Converts Hex byte to Binary + def hex_to_bin(self, input_hex): + return bin(int(str(input_hex), 16))[2:].zfill(8) + + #Generates a valid random ID to use + def generate_id(self, length): + length = length * 2 + return str(uuid.uuid4().hex[:length]) + + def Reverse(self, str): + stringlength=len(str) + slicedString=str[stringlength::-1] + return (slicedString) + + def DecodePLMN(self, plmn): + self.logTool.log(service='HSS', level='debug', message="Decoded PLMN: " + str(plmn), redisClient=self.redisMessaging) + mcc = self.Reverse(plmn[0:2]) + self.Reverse(plmn[2:4]).replace('f', '') + self.logTool.log(service='HSS', level='debug', message="Decoded MCC: " + mcc, redisClient=self.redisMessaging) + + mnc = self.Reverse(plmn[4:6]) + self.logTool.log(service='HSS', level='debug', message="Decoded MNC: " + mnc, redisClient=self.redisMessaging) + return mcc, mnc + + def EncodePLMN(self, mcc, mnc): + plmn = list('XXXXXX') + plmn[0] = self.Reverse(mcc)[1] + plmn[1] = self.Reverse(mcc)[2] + plmn[2] = "f" + plmn[3] = self.Reverse(mcc)[0] + plmn[4] = self.Reverse(mnc)[0] + plmn[5] = self.Reverse(mnc)[1] + plmn_list = plmn + plmn = '' + for bits in plmn_list: + plmn = plmn + bits + self.logTool.log(service='HSS', level='debug', message="Encoded PLMN: " + str(plmn), redisClient=self.redisMessaging) + return plmn + + def TBCD_special_chars(self, input): + self.logTool.log(service='HSS', level='debug', message="Special character possible in " + str(input), redisClient=self.redisMessaging) + if input == "*": + self.logTool.log(service='HSS', level='debug', message="Found * - Returning 1010", redisClient=self.redisMessaging) + return "1010" + elif input == "#": + self.logTool.log(service='HSS', level='debug', message="Found # - Returning 1011", redisClient=self.redisMessaging) + return "1011" + elif input == "a": + self.logTool.log(service='HSS', level='debug', message="Found a - Returning 1100", redisClient=self.redisMessaging) + return "1100" + elif input == "b": + self.logTool.log(service='HSS', level='debug', message="Found b - Returning 1101", redisClient=self.redisMessaging) + return "1101" + elif input == "c": + self.logTool.log(service='HSS', level='debug', message="Found c - Returning 1100", redisClient=self.redisMessaging) + return "1100" + else: + binform = "{:04b}".format(int(input)) + self.logTool.log(service='HSS', level='debug', message="input " + str(input) + " is not a special char, converted to bin: " + str(binform), redisClient=self.redisMessaging) + return (binform) + + def TBCD_encode(self, input): + self.logTool.log(service='HSS', level='debug', message="TBCD_encode input value is " + str(input), redisClient=self.redisMessaging) + offset = 0 + output = '' + matches = ['*', '#', 'a', 'b', 'c'] + while offset < len(input): + if len(input[offset:offset+2]) == 2: + self.logTool.log(service='HSS', level='debug', message="processing bits " + str(input[offset:offset+2]) + " at position offset " + 
str(offset), redisClient=self.redisMessaging) + bit = input[offset:offset+2] #Get two digits at a time + bit = bit[::-1] #Reverse them + #Check if *, #, a, b or c + if any(x in bit for x in matches): + self.logTool.log(service='HSS', level='debug', message="Special char in bit " + str(bit), redisClient=self.redisMessaging) + new_bit = '' + new_bit = new_bit + str(self.TBCD_special_chars(bit[0])) + new_bit = new_bit + str(self.TBCD_special_chars(bit[1])) + self.logTool.log(service='HSS', level='debug', message="Final bin output of new_bit is " + str(new_bit), redisClient=self.redisMessaging) + bit = hex(int(new_bit, 2))[2:] #Get Hex value + self.logTool.log(service='HSS', level='debug', message="Formatted as Hex this is " + str(bit), redisClient=self.redisMessaging) + output = output + bit + offset = offset + 2 + else: + #If odd-length input + last_digit = str(input[offset:offset+2]) + #Check if *, #, a, b or c + if any(x in last_digit for x in matches): + self.logTool.log(service='HSS', level='debug', message="Special char in bit " + str(bit), redisClient=self.redisMessaging) + new_bit = '' + new_bit = new_bit + '1111' #Add the F first + #Encode the symbol into binary and append it to the new_bit var + new_bit = new_bit + str(self.TBCD_special_chars(last_digit)) + self.logTool.log(service='HSS', level='debug', message="Final bin output of new_bit is " + str(new_bit), redisClient=self.redisMessaging) + bit = hex(int(new_bit, 2))[2:] #Get Hex value + self.logTool.log(service='HSS', level='debug', message="Formatted as Hex this is " + str(bit), redisClient=self.redisMessaging) + else: + bit = "f" + last_digit + offset = offset + 2 + output = output + bit + self.logTool.log(service='HSS', level='debug', message="TBCD_encode final output value is " + str(output), redisClient=self.redisMessaging) + return output + + def TBCD_decode(self, input): + self.logTool.log(service='HSS', level='debug', message="TBCD_decode Input value is " + str(input), redisClient=self.redisMessaging) + offset = 0 + output = '' + while offset < len(input): + if "f" not in input[offset:offset+2]: + bit = input[offset:offset+2] #Get two digits at a time + bit = bit[::-1] #Reverse them + output = output + bit + offset = offset + 2 + else: #If f in bit strip it + bit = input[offset:offset+2] + output = output + bit[1] + self.logTool.log(service='HSS', level='debug', message="TBCD_decode output value is " + str(output), redisClient=self.redisMessaging) + return output + + #Generates an AVP with inputs provided (AVP Code, AVP Flags, AVP Content, Padding) + #AVP content must already be in HEX - This can be done with binascii.hexlify(avp_content.encode()) + def generate_avp(self, avp_code, avp_flags, avp_content): + avp_code = format(avp_code,"x").zfill(8) + + avp_length = 1 ##This is a placeholder that's overwritten later + + #AVP Must always be a multiple of 4 - Round up to nearest multiple of 4 and fill remaining bits with padding + avp = str(avp_code) + str(avp_flags) + str("000000") + str(avp_content) + avp_length = int(len(avp)/2) + + if avp_length % 4 == 0: #Multiple of 4 - No Padding needed + avp_padding = '' + else: #Not multiple of 4 - Padding needed + rounded_value = self.myround(avp_length) + avp_padding = format(0,"x").zfill(int( rounded_value - avp_length) * 2) + + avp = str(avp_code) + str(avp_flags) + str(format(avp_length,"x").zfill(6)) + str(avp_content) + str(avp_padding) + return avp + + #Generates an AVP with inputs provided (AVP Code, AVP Flags, AVP Content, Padding) + #AVP content must already be in HEX - 
+    #Generates a Vendor-Specific AVP with inputs provided (AVP Code, AVP Flags, AVP Vendor ID, AVP Content, Padding)
+    #AVP content must already be in HEX - This can be done with binascii.hexlify(avp_content.encode())
+    def generate_vendor_avp(self, avp_code, avp_flags, avp_vendorid, avp_content):
+        avp_code = format(avp_code,"x").zfill(8)
+
+        avp_length = 1 ##This is a placeholder that gets overwritten later
+
+        avp_vendorid = format(int(avp_vendorid),"x").zfill(8)
+
+        #AVP Must always be a multiple of 4 - Round up to nearest multiple of 4 and fill remaining bits with padding
+        avp = str(avp_code) + str(avp_flags) + str("000000") + str(avp_vendorid) + str(avp_content)
+        avp_length = int(len(avp)/2)
+
+        if avp_length % 4 == 0:     #Multiple of 4 - No Padding needed
+            avp_padding = ''
+        else:                       #Not multiple of 4 - Padding needed
+            rounded_value = self.myround(avp_length)
+            self.logTool.log(service='HSS', level='debug', message="Rounded value is " + str(rounded_value), redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='debug', message="Has " + str( int( rounded_value - avp_length)) + " bytes of padding", redisClient=self.redisMessaging)
+            avp_padding = format(0,"x").zfill(int( rounded_value - avp_length) * 2)
+
+        avp = str(avp_code) + str(avp_flags) + str(format(avp_length,"x").zfill(6)) + str(avp_vendorid) + str(avp_content) + str(avp_padding)
+        return avp
+
+    def generate_diameter_packet(self, packet_version, packet_flags, packet_command_code, packet_application_id, packet_hop_by_hop_id, packet_end_to_end_id, avp):
+        try:
+            packet_length = 228 #Placeholder length, recalculated below once the AVPs are known
+            packet_length = format(packet_length,"x").zfill(6)
+
+            packet_command_code = format(packet_command_code,"x").zfill(6)
+
+            packet_application_id = format(packet_application_id,"x").zfill(8)
+
+            packet_hex = packet_version + packet_length + packet_flags + packet_command_code + packet_application_id + packet_hop_by_hop_id + packet_end_to_end_id + avp
+            packet_length = int(round(len(packet_hex))/2)
+            packet_length = format(packet_length,"x").zfill(6)
+
+            packet_hex = packet_version + packet_length + packet_flags + packet_command_code + packet_application_id + packet_hop_by_hop_id + packet_end_to_end_id + avp
+            return packet_hex
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [generate_diameter_packet] Exception: {e}", redisClient=self.redisMessaging)
+
+    def roundUpToMultiple(self, n, multiple):
+        return ((n + multiple - 1) // multiple) * multiple
+
+    def validateSingleAvp(self, data) -> bool:
+        """
+        Attempts to validate a single hex string diameter AVP as being an AVP.
+        """
+        try:
+            avpCode = int(data[0:8], 16)
+            # The next byte contains the AVP Flags
+            avpFlags = data[8:10]
+            # The next 3 bytes contain the AVP Length
+            avpLength = int(data[10:16], 16)
+            if avpFlags not in ['80', '40', '20', '00', 'c0']:
+                return False
+            if int(len(data[16:]) / 2) < ((avpLength - 8)):
+                return False
+            return True
+        except Exception as e:
+            return False
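+    # --- Illustrative note (not part of the runtime path) ---
+    # Vendor AVPs carry a 4 byte Vendor-Id between the header and payload:
+    #   code (4 bytes) | flags (1) | length (3) | vendor-id (4) | payload | padding
+    # e.g. a 3GPP (10415) ULA-Flags (1406) AVP with value 1:
+    #   generate_vendor_avp(1406, "c0", 10415, "00000001")
+    #   -> '0000057ec0000010000028af00000001'
+    # (0x57e = 1406, flags 0xc0, length 0x10 = 16 bytes, vendor 0x28af = 10415)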
+    def decode_diameter_packet(self, data):
+        """
+        Handles decoding of a full diameter packet.
+        """
+        packet_vars = {}
+        avps = []
+
+        if type(data) is bytes:
+            data = data.hex()
+
+        # One byte is 2 hex characters
+        # First Byte is the Diameter Packet Version
+        packet_vars['packet_version'] = data[0:2]
+        # Next 3 Bytes are the length of the entire Diameter packet
+        packet_vars['length'] = int(data[2:8], 16)
+        # Next Byte is the Diameter Flags
+        packet_vars['flags'] = data[8:10]
+        packet_vars['flags_bin'] = bin(int(data[8:10], 16))[2:].zfill(8)
+        # Next 3 Bytes are the Diameter Command Code
+        packet_vars['command_code'] = int(data[10:16], 16)
+        # Next 4 Bytes are the Application Id
+        packet_vars['ApplicationId'] = int(data[16:24], 16)
+        # Next 4 Bytes are the Hop By Hop Identifier
+        packet_vars['hop-by-hop-identifier'] = data[24:32]
+        # Next 4 Bytes are the End to End Identifier
+        packet_vars['end-to-end-identifier'] = data[32:40]
+
+        lengthOfDiameterVars = int(len(data[:40]) / 2)
+
+        #Length of all AVPs, in bytes
+        avpLength = int(packet_vars['length'] - lengthOfDiameterVars)
+        avpCharLength = int((avpLength * 2))
+        remaining_avps = data[40:]
+
+        avps = self.decodeAvpPacket(remaining_avps)
+
+        return packet_vars, avps
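+    # --- Illustrative note (not part of the runtime path) ---
+    # The fixed Diameter header is 20 bytes (40 hex chars). For example, the
+    # hex '010000b0800001010000000013debc3a9c6ba5f1' decodes as:
+    #   packet_version: '01', length: 0xb0 = 176, flags: '80' (request),
+    #   command_code: 0x101 = 257 (Capabilities-Exchange),
+    #   ApplicationId: 0, hop-by-hop: '13debc3a', end-to-end: '9c6ba5f1'
+    # (the hop-by-hop / end-to-end values here are arbitrary examples)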
+    def decodeAvpPacket(self, data):
+        """
+        Returns a list of decoded AVP Packet dictionaries.
+        This function is called at a high frequency, decoding methods should stick to iteration and not recursion, to avoid a memory leak.
+        """
+        # Note: After spending hours on this, I'm leaving the following technical debt:
+        # Subavps and all their descendants are lifted up, flat, side by side into the parent's sub_avps list.
+        # It's definitely possible to keep the nested tree structure, if anyone wants to improve this function. But I can't figure out a simple way to do so, without invoking recursion.
+
+        # Our final list of AVP Dictionaries, which will be returned once processing is complete.
+        processed_avps = []
+        # Initialize a failsafe counter, to prevent packets that pass validation but aren't AVPs from causing an infinite loop
+        failsafeCounter = 0
+
+        # If the avp data is 8 bytes (16 chars) or less, it's invalid.
+        if len(data) < 16:
+            return []
+
+        # Working stack to aid in iterative processing of sub-avps.
+        subAvpUnprocessedStack = []
+
+        # Keep processing AVPs until they're all dealt with
+        while len(data) > 16:
+            try:
+                failsafeCounter += 1
+
+                if failsafeCounter > 100:
+                    break
+
+                avp_vars = {}
+                # The first 4 bytes contains the AVP code
+                avp_vars['avp_code'] = int(data[0:8], 16)
+                # The next byte contains the AVP Flags
+                avp_vars['avp_flags'] = data[8:10]
+                # The next 3 bytes contains the AVP Length
+                avp_vars['avp_length'] = int(data[10:16], 16)
+                # The remaining bytes (until the end, defined by avp_length) is the AVP payload.
+                # Padding is excluded from avp_length. It's calculated separately, and unknown by the AVP itself.
+                # We calculate the avp payload length (in bytes) by subtracting 8, because the avp headers are always 8 bytes long.
+                # The result is then multiplied by 2 to give us chars.
+                avpPayloadLength = int((avp_vars['avp_length'])*2)
+
+                # Work out our vendor id and add the payload itself (misc_data)
+                if avp_vars['avp_flags'] == 'c0' or avp_vars['avp_flags'] == '80':
+                    avp_vars['vendor_id'] = int(data[16:24], 16)
+                    avp_vars['misc_data'] = data[24:avpPayloadLength]
+                else:
+                    avp_vars['vendor_id'] = ''
+                    avp_vars['misc_data'] = data[16:avpPayloadLength]
+
+                payloadContainsSubAvps = self.validateSingleAvp(avp_vars['misc_data'])
+                if payloadContainsSubAvps:
+                    # If the payload contains sub or grouped AVPs, append misc_data to the subAvpUnprocessedStack to start working through one or more subavp
+                    subAvpUnprocessedStack.append(avp_vars["misc_data"])
+                    avp_vars["misc_data"] = ''
+
+                # Rounds up the length to the nearest multiple of 4, which we can compare against the avp length to give us the padding length (if required)
+                avp_padded_length = int((self.roundUpToMultiple(avp_vars['avp_length'], 4)))
+                avpPaddingLength = ((avp_padded_length - avp_vars['avp_length']) * 2)
+
+                # Initialize a blank sub_avps list, regardless of whether or not we have any sub avps.
+                avp_vars['sub_avps'] = []
+
+                while payloadContainsSubAvps:
+                    # Increment our failsafe counter, which will fail after 100 tries. This prevents a rare validation error from causing the function to hang permanently.
+                    failsafeCounter += 1
+
+                    if failsafeCounter > 100:
+                        break
+
+                    # Pop the sub avp data from the list (remove from the end)
+                    sub_avp_data = subAvpUnprocessedStack.pop()
+
+                    # Initialize our sub avp dictionary, and grab the usual values
+                    sub_avp = {}
+                    sub_avp['avp_code'] = int(sub_avp_data[0:8], 16)
+                    sub_avp['avp_flags'] = sub_avp_data[8:10]
+                    sub_avp['avp_length'] = int(sub_avp_data[10:16], 16)
+                    sub_avpPayloadLength = int((sub_avp['avp_length'])*2)
+
+                    if sub_avp['avp_flags'] == 'c0' or sub_avp['avp_flags'] == '80':
+                        sub_avp['vendor_id'] = int(sub_avp_data[16:24], 16)
+                        sub_avp['misc_data'] = sub_avp_data[24:sub_avpPayloadLength]
+                    else:
+                        sub_avp['vendor_id'] = ''
+                        sub_avp['misc_data'] = sub_avp_data[16:sub_avpPayloadLength]
+
+                    containsSubAvps = self.validateSingleAvp(sub_avp["misc_data"])
+                    if containsSubAvps:
+                        subAvpUnprocessedStack.append(sub_avp["misc_data"])
+                        sub_avp["misc_data"] = ''
+
+                    avp_vars['sub_avps'].append(sub_avp)
+
+                    sub_avp_padded_length = int((self.roundUpToMultiple(sub_avp['avp_length'], 4)))
+                    subAvpPaddingLength = ((sub_avp_padded_length - sub_avp['avp_length']) * 2)
+
+                    sub_avp_data = sub_avp_data[sub_avpPayloadLength+subAvpPaddingLength:]
+                    containsNestedSubAvps = self.validateSingleAvp(sub_avp_data)
+
+                    # Check for nested sub avps and bring them to the top of the stack, for further processing.
+                    if containsNestedSubAvps:
+                        subAvpUnprocessedStack.append(sub_avp_data)
+
+                    if containsSubAvps or containsNestedSubAvps:
+                        payloadContainsSubAvps = True
+                    else:
+                        payloadContainsSubAvps = False
+
+                if avpPaddingLength > 0:
+                    processed_avps.append(avp_vars)
+                    data = data[avpPayloadLength+avpPaddingLength:]
+                else:
+                    processed_avps.append(avp_vars)
+                    data = data[avpPayloadLength:]
+            except Exception as e:
+                print(e)
+                continue
+
+        return processed_avps
+
+    def get_avp_data(self, avps, avp_code): #Loops through list of dicts generated by the packet decoder, and returns the data for a specific AVP code in list (May be more than one AVP with same code but different data)
+        misc_data = []
+        for avpObject in avps:
+            if int(avpObject['avp_code']) == int(avp_code):
+                if len(avpObject['misc_data']) == 0:
+                    misc_data.append(avpObject['sub_avps'])
+                else:
+                    misc_data.append(avpObject['misc_data'])
+            if 'sub_avps' in avpObject:
+                for sub_avp in avpObject['sub_avps']:
+                    if int(sub_avp['avp_code']) == int(avp_code):
+                        misc_data.append(sub_avp['misc_data'])
+        return misc_data
+
+    def decode_diameter_packet_length(self, data):
+        packet_vars = {}
+        data = data.hex()
+        packet_vars['packet_version'] = data[0:2]
+        packet_vars['length'] = int(data[2:8], 16)
+        if packet_vars['packet_version'] == "01":
+            return packet_vars['length']
+        else:
+            return False
+
+    def getPeerType(self, originHost: str) -> str:
+        try:
+            peerTypes = ['mme', 'pgw', 'pcscf', 'icscf', 'scscf', 'hss', 'ocs', 'dra']
+
+            for peer in peerTypes:
+                if peer in originHost.lower():
+                    return peer
+            return ''
+
+        except Exception as e:
+            return ''
+
+    def getConnectedPeersByType(self, peerType: str) -> list:
+        try:
+            peerType = peerType.lower()
+            peerTypes = ['mme', 'pgw', 'pcscf', 'icscf', 'scscf', 'hss', 'ocs', 'dra']
+
+            if peerType not in peerTypes:
+                return []
+            filteredConnectedPeers = []
+            activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode())
+
+            for key, value in activePeers.items():
+                if activePeers.get(key, {}).get('peerType', '') == peerType and activePeers.get(key, {}).get('connectionStatus', '') == 'connected':
+                    filteredConnectedPeers.append(activePeers.get(key, {}))
+
+            return filteredConnectedPeers
+
+        except Exception as e:
+            return []
+
+    def getPeerByHostname(self, hostname: str) -> dict:
+        try:
+            hostname = hostname.lower()
+            activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode())
+
+            for key, value in activePeers.items():
+                if activePeers.get(key, {}).get('diameterHostname', '').lower() == hostname and activePeers.get(key, {}).get('connectionStatus', '') == 'connected':
+                    return activePeers.get(key, {})
+            return {}
+
+        except Exception as e:
+            return {}
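+    # --- Illustrative note (hypothetical values) ---
+    # The ActiveDiameterPeers redis key (maintained by diameterService) is
+    # assumed to hold JSON along these lines, which the helpers above filter:
+    # {
+    #   "mme01.epc.mnc001.mcc001.3gppnetwork.org": {
+    #     "diameterHostname": "mme01.epc.mnc001.mcc001.3gppnetwork.org",
+    #     "peerType": "mme",
+    #     "ipAddress": "10.0.0.10",
+    #     "port": "3868",
+    #     "connectionStatus": "connected"
+    #   }
+    # }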
+    def getDiameterMessageType(self, binaryData: str) -> dict:
+        """
+        Determines whether a message is a request or a response, and the appropriate acronyms for each type.
+        """
+        packet_vars, avps = self.decode_diameter_packet(binaryData)
+        response = {}
+
+        for diameterApplication in self.diameterResponseList:
+            try:
+                assert(packet_vars["command_code"] == diameterApplication["commandCode"])
+                assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"])
+                if packet_vars["flags_bin"][0:1] == "1":
+                    response['inbound'] = diameterApplication["requestAcronym"]
+                    response['outbound'] = diameterApplication["responseAcronym"]
+                else:
+                    response['inbound'] = diameterApplication["responseAcronym"]
+                    response['outbound'] = diameterApplication["requestAcronym"]
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] Matched message types: {response}", redisClient=self.redisMessaging)
+            except Exception as e:
+                continue
+        return response
+
+    def sendDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str:
+        """
+        Sends a given diameter request of requestType to the provided peer hostname, if the peer is connected.
+        """
+        try:
+            request = ''
+            requestType = requestType.upper()
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Generating a diameter outbound request", redisClient=self.redisMessaging)
+
+            for diameterApplication in self.diameterRequestList:
+                try:
+                    assert(requestType == diameterApplication["requestAcronym"])
+                except Exception as e:
+                    continue
+                connectedPeer = self.getPeerByHostname(hostname=hostname)
+                try:
+                    peerIp = connectedPeer['ipAddress']
+                    peerPort = connectedPeer['port']
+                except Exception as e:
+                    return ''
+                request = diameterApplication["requestMethod"](**kwargs)
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging)
+                outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}"
+                sendTime = time.time_ns()
+                outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime})
+                self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout)
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging)
+            return request
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Error generating diameter outbound request: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            return ''
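+    # --- Illustrative note (not part of the runtime path) ---
+    # sendDiameterRequest() doesn't write to the socket itself; it queues hex for
+    # diameterService via redis. For a peer at 10.0.0.10:3868 (example values),
+    # the queue is 'diameter-outbound-10.0.0.10-3868' and the payload resembles:
+    #   {"diameter-outbound": "010000b080...", "inbound-received-timestamp": 1695800000000000000}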
+    def broadcastDiameterRequest(self, requestType: str, peerType: str, **kwargs) -> list:
+        """
+        Sends a diameter request of requestType to one or more connected peers, specified by peerType.
+        """
+        try:
+            request = ''
+            connectedPeerList = []
+            requestType = requestType.upper()
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Broadcasting a diameter outbound request of type: {requestType} to peers of type: {peerType}", redisClient=self.redisMessaging)
+
+            for diameterApplication in self.diameterRequestList:
+                try:
+                    assert(requestType == diameterApplication["requestAcronym"])
+                except Exception as e:
+                    continue
+                connectedPeerList = self.getConnectedPeersByType(peerType=peerType)
+                for connectedPeer in connectedPeerList:
+                    try:
+                        peerIp = connectedPeer['ipAddress']
+                        peerPort = connectedPeer['port']
+                    except Exception as e:
+                        continue
+                    request = diameterApplication["requestMethod"](**kwargs)
+                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging)
+                    outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}"
+                    sendTime = time.time_ns()
+                    outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime})
+                    self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout)
+                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Queueing for peer type: {peerType} on {peerIp}-{peerPort}", redisClient=self.redisMessaging)
+            return connectedPeerList
+        except Exception as e:
+            return []
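+    # --- Illustrative usage sketch (hypothetical values) ---
+    # e.g. pushing a Re-Auth-Request to every connected PCSCF:
+    #   self.broadcastDiameterRequest(requestType='RAR', peerType='pcscf', sessionId='pcscf.ims;1234')
+    # Returns the list of peers the request was queued for.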
+    def awaitDiameterRequestAndResponse(self, requestType: str, hostname: str, timeout: float=0.12, **kwargs) -> str:
+        """
+        Sends a given diameter request of requestType to the provided peer hostname.
+        Ensures the peer is connected, sends the request, then waits on and returns the response.
+        If the timeout is reached, the function fails.
+
+        Diameter lacks a unique identifier for all message types, the closest being Session-ID which exists for most.
+        We attempt to get the associated response given the following logic:
+          - If sessionId is none, attempt to return the first response that matches the expected response method (eg AAA, CEA, etc.) which has a timestamp greater than sendTime.
+          - If sessionId is not none, perform the logic above, and also ensure that sessionId matches.
+
+        Returns an empty string if it fails.
+
+        Until diameter.py is rewritten to be asynchronous, this method should be called only when strictly necessary. It potentially adds up to 120ms of delay per invocation.
+        """
+        try:
+            request = ''
+            requestType = requestType.upper()
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Generating a diameter outbound request", redisClient=self.redisMessaging)
+
+            for diameterApplication in self.diameterRequestList:
+                try:
+                    assert(requestType == diameterApplication["requestAcronym"])
+                except Exception as e:
+                    continue
+                connectedPeer = self.getPeerByHostname(hostname=hostname)
+                try:
+                    peerIp = connectedPeer['ipAddress']
+                    peerPort = connectedPeer['port']
+                except Exception as e:
+                    return ''
+                request = diameterApplication["requestMethod"](**kwargs)
+                responseType = diameterApplication["responseAcronym"]
+                sessionId = kwargs.get('sessionId', None)
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging)
+                sendTime = time.time_ns()
+                outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}"
+                outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime})
+                self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout)
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging)
+                startTimer = time.time()
+                while True:
+                    try:
+                        if not time.time() >= startTimer + timeout:
+                            if sessionId is None:
+                                queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound")
+                                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] queuedMessages(NoSessionId): {queuedMessages}", redisClient=self.redisMessaging)
+                                for queuedMessage in queuedMessages:
+                                    queuedMessage = json.loads(queuedMessage)
+                                    clientAddress = queuedMessage.get('clientAddress', None)
+                                    clientPort = queuedMessage.get('clientPort', None)
+                                    if clientAddress != peerIp or clientPort != peerPort:
+                                        continue
+                                    messageReceiveTime = queuedMessage.get('inbound-received-timestamp', None)
+                                    if float(messageReceiveTime) > sendTime:
+                                        messageHex = queuedMessage.get('diameter-inbound')
+                                        messageType = self.getDiameterMessageType(messageHex)
+                                        if messageType['inbound'].upper() == responseType.upper():
+                                            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Found inbound response: {messageHex}", redisClient=self.redisMessaging)
+                                            return messageHex
+                                time.sleep(0.02)
+                            else:
+                                queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound")
+                                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] queuedMessages({sessionId}): {queuedMessages} responseType: {responseType}", redisClient=self.redisMessaging)
+                                for queuedMessage in queuedMessages:
+                                    queuedMessage = json.loads(queuedMessage)
+                                    clientAddress = queuedMessage.get('clientAddress', None)
+                                    clientPort = queuedMessage.get('clientPort', None)
+                                    if clientAddress != peerIp or clientPort != peerPort:
+                                        continue
+                                    messageReceiveTime = queuedMessage.get('inbound-received-timestamp', None)
+                                    if float(messageReceiveTime) > sendTime:
+                                        messageHex = queuedMessage.get('diameter-inbound')
+                                        messageType = self.getDiameterMessageType(messageHex)
+                                        if messageType['inbound'].upper() == responseType.upper():
+                                            packetVars, avps = self.decode_diameter_packet(messageHex)
+                                            messageSessionId = bytes.fromhex(self.get_avp_data(avps, 263)[0]).decode('ascii')
+                                            if messageSessionId == sessionId:
+                                                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Matched on Session Id: {sessionId}", redisClient=self.redisMessaging)
+                                                return messageHex
+                                time.sleep(0.02)
+                        else:
+                            return ''
+                    except Exception as e:
+                        self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Traceback: {traceback.format_exc()}", redisClient=self.redisMessaging)
+                        return ''
+            return ''
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Error generating diameter outbound request: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            return ''
+
+    def generateDiameterResponse(self, binaryData: str) -> str:
+        try:
+            packet_vars, avps = self.decode_diameter_packet(binaryData)
+            origin_host = self.get_avp_data(avps, 264)[0]
+            origin_host = binascii.unhexlify(origin_host).decode("utf-8")
+            response = ''
+
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] Generating a diameter response", redisClient=self.redisMessaging)
+
+            # Drop packet if it's a response packet:
+            if packet_vars["flags_bin"][0:1] == "0":
+                self.logTool.log(service='HSS', level='debug', message="[diameter.py] [generateDiameterResponse] Got a Response, not a request - dropping it.", redisClient=self.redisMessaging)
+                self.logTool.log(service='HSS', level='debug', message=packet_vars, redisClient=self.redisMessaging)
+                return ''
+
+            for diameterApplication in self.diameterResponseList:
+                try:
+                    assert(packet_vars["command_code"] == diameterApplication["commandCode"])
+                    assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"])
+                    if 'flags' in diameterApplication:
+                        assert(str(packet_vars["flags"]) == str(diameterApplication["flags"]))
+                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Attempting to generate response", redisClient=self.redisMessaging)
+                    response = diameterApplication["responseMethod"](packet_vars, avps)
+                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Successfully generated response: {response}", redisClient=self.redisMessaging)
+                    break
+                except Exception as e:
+                    continue
+
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_response_count_successful',
+                                           metricType='counter', metricAction='inc',
+                                           metricValue=1.0, metricHelp='Number of Successful Diameter Responses',
+                                           metricExpiry=60)
+            return response
+        except Exception as e:
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_response_count_fail',
+                                           metricType='counter', metricAction='inc',
+                                           metricValue=1.0, metricHelp='Number of Failed Diameter Responses',
+                                           metricExpiry=60)
+            return ''
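+    # --- Illustrative usage sketch (hypothetical values) ---
+    # Example of a synchronous request/response exchange, matched on Session-Id:
+    #   cca = self.awaitDiameterRequestAndResponse(
+    #       requestType='CCR', hostname='pgw01.epc.mnc001.mcc001.3gppnetwork.org',
+    #       sessionId='pgw01;1234567;89', ccr_type=3, imsi='001020000000001')
+    # An empty string means no matching answer arrived within the timeout.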
+    def validateImsSubscriber(self, imsi=None, msisdn=None) -> bool:
+        """
+        Ensures that a given IMSI or MSISDN (Or both, if specified) are associated with a subscriber that is enabled, and has an associated IMS Subscriber record.
+        """
+        if imsi is None and msisdn is None:
+            return False
+
+        try:
+            if imsi is not None:
+                subscriberDetails = self.database.Get_Subscriber(imsi=imsi)
+                if not subscriberDetails.get('enabled', False):
+                    return False
+                imsSubscriberDetails = self.database.Get_IMS_Subscriber(imsi=imsi)
+        except Exception as e:
+            return False
+        try:
+            if msisdn is not None:
+                subscriberDetails = self.database.Get_Subscriber(msisdn=msisdn)
+                if not subscriberDetails.get('enabled', False):
+                    return False
+                imsSubscriberDetails = self.database.Get_IMS_Subscriber(msisdn=msisdn)
+        except Exception as e:
+            return False
+
+        return True
+
+    def deregisterApn(self, imsi: str=None, msisdn: str=None, apn: str=None) -> bool:
+        """
+        Revokes a given UE's session with the assigned PGW (If it exists), and sends a CLR to the MME.
+        """
+        try:
+            if imsi is None and msisdn is None:
+                return False
+
+            if imsi is not None:
+                subscriberDetails = self.database.Get_Subscriber(imsi=imsi)
+            if msisdn is not None:
+                subscriberDetails = self.database.Get_Subscriber(msisdn=msisdn)
+                imsi = subscriberDetails.get('imsi', '')
+
+            if subscriberDetails is None:
+                return False
+
+            subscriberId = subscriberDetails.get('subscriber_id', None)
+
+            # If a subscriber has an active serving apn, grab the pcrf session id for that apn and send a CCR-T, then a Registration Termination Request to the serving pgw peer.
+            if subscriberId is not None:
+                servingApns = self.database.Get_Serving_APNs(subscriber_id=subscriberId)
+                if len(servingApns.get('apns', {})) > 0:
+                    for apnKey, apnDict in servingApns['apns'].items():
+                        pcrfSessionId = None
+                        servingPgwPeer = None
+                        servingPgwRealm = None
+                        servingPgw = None
+                        for apnDataKey, apnDataValue in servingApns['apns'][apnKey].items():
+                            if apnDataKey == 'pcrf_session_id':
+                                pcrfSessionId = apnDataValue
+                            if apnDataKey == 'serving_pgw_peer':
+                                servingPgwPeer = apnDataValue
+                            if apnDataKey == 'serving_pgw_realm':
+                                servingPgwRealm = apnDataValue
+                            if apnDataKey == 'serving_pgw':
+                                servingPgw = apnDataValue
+
+                        if pcrfSessionId is not None and servingPgwPeer is not None and servingPgwRealm is not None and servingPgw is not None:
+                            if ';' in servingPgwPeer:
+                                servingPgwPeer = servingPgwPeer.split(';')[0]
+
+                            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [deregisterApn] Sending CCR-T with Session-ID:{pcrfSessionId} to peer: {servingPgwPeer} {apnKey}", redisClient=self.redisMessaging)
+
+                            self.sendDiameterRequest(
+                                requestType='CCR',
+                                hostname=servingPgwPeer,
+                                imsi=imsi,
+                                destinationHost=servingPgw,
+                                destinationRealm=servingPgwRealm,
+                                ccr_type=3,
+                                sessionId=pcrfSessionId,
+                                domain=servingPgwRealm
+                            )
+
+                            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [deregisterApn] Sending RTR to peer: {servingPgwPeer} {apnKey}", redisClient=self.redisMessaging)
+
+                            self.sendDiameterRequest(
+                                requestType='RTR',
+                                hostname=servingPgwPeer,
+                                imsi=imsi,
+                                destinationHost=servingPgw,
+                                destinationRealm=servingPgwRealm,
+                                domain=servingPgwRealm
+                            )
+
+                        self.database.Update_Serving_APN(imsi=imsi, apn=apnKey, pcrf_session_id=None, serving_pgw=None, subscriber_routing='')
+
+            return True
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [deregisterApn] Error deregistering APN: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            return False
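+    # --- Illustrative usage sketch (hypothetical values) ---
+    # Tearing down all serving APN sessions for a subscriber, e.g. from the
+    # /oam/deregister/{imsi} path:
+    #   self.deregisterApn(imsi='001020000000001')
+    # This queues a Gx CCR-T (ccr_type=3) and an RTR towards the serving PGW peer,
+    # then clears the stored serving APN state for the subscriber.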
+    def deregisterIms(self, imsi=None, msisdn=None) -> bool:
+        """
+        Revokes a given UE's IMS registration, and sends a RTR to the SCSCF (if defined).
+        Does not revoke the pgw session, or notify the mme.
+        """
+        try:
+            if imsi is None and msisdn is None:
+                return False
+
+            if imsi is not None:
+                imsSubscriberDetails = self.database.Get_Subscriber(imsi=imsi)
+            if msisdn is not None:
+                imsSubscriberDetails = self.database.Get_Subscriber(msisdn=msisdn)
+
+            if imsSubscriberDetails is None:
+                return False
+
+            servingScscf = imsSubscriberDetails.get('scscf', None)
+            servingScscfPeer = imsSubscriberDetails.get('scscf_peer', None)
+            servingScscfRealm = imsSubscriberDetails.get('scscf_realm', None)
+
+            if servingScscfPeer is not None and servingScscfRealm is not None and servingScscf is not None:
+                if ';' in servingScscfPeer:
+                    servingScscfPeer = servingScscfPeer.split(';')[0]
+                servingScscf = servingScscf.replace('sip:', '')
+                if ';' in servingScscf:
+                    servingScscf = servingScscf.split(';')[0]
+                self.sendDiameterRequest(
+                    requestType='RTR',
+                    hostname=servingScscfPeer,
+                    imsi=imsi,
+                    destinationHost=servingScscf,
+                    destinationRealm=servingScscfRealm,
+                    domain=servingScscfRealm
+                )
+
+            if imsi is not None:
+                self.database.Update_Serving_CSCF(imsi=imsi, serving_cscf=None)
+            elif msisdn is not None:
+                self.database.Update_Serving_CSCF(msisdn=msisdn, serving_cscf=None)
+
+            return True
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [deregisterIms] Error deregistering subscriber from IMS: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            return False
+
+    #Capabilities Exchange Answer - increment the Origin-State-Id AVP body
+    def AVP_278_Origin_State_Incriment(self, avps):
+        for avp_dicts in avps:
+            if avp_dicts['avp_code'] == 278:
+                origin_state_incriment_int = int(avp_dicts['misc_data'], 16)
+                origin_state_incriment_int = origin_state_incriment_int + 1
+                origin_state_incriment_hex = format(origin_state_incriment_int,"x").zfill(8)
+                return origin_state_incriment_hex
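+    # --- Illustrative note (not part of the runtime path) ---
+    # Origin-State-Id (AVP 278) is echoed back incremented by one, e.g. a peer
+    # sending '00000001' gets '00000002' in the CEA/DWA:
+    #   AVP_278_Origin_State_Incriment([{'avp_code': 278, 'misc_data': '00000001'}]) -> '00000002'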
+    def Charging_Rule_Generator(self, ChargingRules=None, ue_ip=None, chargingRuleName=None, action="install"):
+        self.logTool.log(service='HSS', level='debug', message=f"Called Charging_Rule_Generator with action: {action}", redisClient=self.redisMessaging)
+        if action not in ['install', 'remove']:
+            self.logTool.log(service='HSS', level='debug', message="Invalid action supplied to Charging_Rule_Generator", redisClient=self.redisMessaging)
+            return None
+
+        if action == 'remove':
+            if chargingRuleName is None:
+                self.logTool.log(service='HSS', level='error', message="chargingRuleName must be defined when removing a charging rule", redisClient=self.redisMessaging)
+                return None
+            Charging_Rule_Name = self.generate_vendor_avp(1005, "c0", 10415, str(binascii.hexlify(str.encode(str(chargingRuleName))),'ascii'))
+            ChargingRuleDef = Charging_Rule_Name
+            return self.generate_vendor_avp(1002, "c0", 10415, ChargingRuleDef)
+
+        else:
+            if ChargingRules is None or ue_ip is None:
+                self.logTool.log(service='HSS', level='error', message="ChargingRules and ue_ip must be defined when installing a charging rule", redisClient=self.redisMessaging)
+                return None
+
+            #Install Charging Rules
+            self.logTool.log(service='HSS', level='debug', message="Naming Charging Rule", redisClient=self.redisMessaging)
+            Charging_Rule_Name = self.generate_vendor_avp(1005, "c0", 10415, str(binascii.hexlify(str.encode(str(ChargingRules['rule_name']))),'ascii'))
+            self.logTool.log(service='HSS', level='debug', message="Named Charging Rule", redisClient=self.redisMessaging)
+
+            #Populate all Flow Information AVPs
+            Flow_Information = ''
+            for tft in ChargingRules['tft']:
+                self.logTool.log(service='HSS', level='debug', message=tft, redisClient=self.redisMessaging)
+                #If {{ UE_IP }} in TFT splice in the real UE IP Value
+                try:
+                    tft['tft_string'] = tft['tft_string'].replace('{{ UE_IP }}', str(ue_ip))
+                    tft['tft_string'] = tft['tft_string'].replace('{{UE_IP}}', str(ue_ip))
+                    self.logTool.log(service='HSS', level='debug', message="Spliced in UE IP into TFT: " + str(tft['tft_string']), redisClient=self.redisMessaging)
+                except Exception as E:
+                    self.logTool.log(service='HSS', level='error', message="Failed to splice in UE IP into flow description", redisClient=self.redisMessaging)
+
+                #Valid Values for Flow_Direction: 0 - Unspecified, 1 - Downlink, 2 - Uplink, 3 - Bidirectional
+                Flow_Direction = self.generate_vendor_avp(1080, "80", 10415, self.int_to_hex(tft['direction'], 4))
+                Flow_Description = self.generate_vendor_avp(507, "c0", 10415, str(binascii.hexlify(str.encode(tft['tft_string'])),'ascii'))
+                Flow_Information += self.generate_vendor_avp(1058, "80", 10415, Flow_Direction + Flow_Description)
+
+            Flow_Status = self.generate_vendor_avp(511, "c0", 10415, self.int_to_hex(2, 4))
+            self.logTool.log(service='HSS', level='debug', message="Defined Flow_Status: " + str(Flow_Status), redisClient=self.redisMessaging)
+
+            self.logTool.log(service='HSS', level='debug', message="Defining QoS information", redisClient=self.redisMessaging)
+            #QCI
+            QCI = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(ChargingRules['qci'], 4))
+
+            #ARP
+            self.logTool.log(service='HSS', level='debug', message="Defining ARP information", redisClient=self.redisMessaging)
+            AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(ChargingRules['arp_priority']), 4))
+            AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not ChargingRules['arp_preemption_capability']), 4))
+            AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(not ChargingRules['arp_preemption_vulnerability']), 4))
+            ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
+
+            self.logTool.log(service='HSS', level='debug', message="Defining MBR information", redisClient=self.redisMessaging)
+            #Max Requested Bandwidth
+            Bandwidth_info = ''
+            Bandwidth_info += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_ul']), 4))
+            Bandwidth_info += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_dl']), 4))
+
+            self.logTool.log(service='HSS', level='debug', message="Defining GBR information", redisClient=self.redisMessaging)
+            #GBR
+            if int(ChargingRules['gbr_ul']) != 0:
+                Bandwidth_info += self.generate_vendor_avp(1026, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_ul']), 4))
+            if int(ChargingRules['gbr_dl']) != 0:
+                Bandwidth_info += self.generate_vendor_avp(1025, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_dl']), 4))
+            self.logTool.log(service='HSS', level='debug', message="Defined Bandwidth Info: " + str(Bandwidth_info), redisClient=self.redisMessaging)
+
+            #Populate QoS Information
+            QoS_Information = self.generate_vendor_avp(1016, "c0", 10415, QCI + ARP + Bandwidth_info)
+            self.logTool.log(service='HSS', level='debug', message="Defined QoS_Information: " + str(QoS_Information), redisClient=self.redisMessaging)
+
+            #Precedence
+            self.logTool.log(service='HSS', level='debug', message="Defining Precedence information", redisClient=self.redisMessaging)
+            Precedence = self.generate_vendor_avp(1010, "c0", 10415, self.int_to_hex(ChargingRules['precedence'], 4))
+            self.logTool.log(service='HSS', level='debug', message="Defined Precedence " + str(Precedence), redisClient=self.redisMessaging)
+
+            #Rating Group
+            self.logTool.log(service='HSS', level='debug', message="Defining Rating Group information", redisClient=self.redisMessaging)
+            if ChargingRules['rating_group'] is not None:
+                RatingGroup = self.generate_avp(432, 40, format(int(ChargingRules['rating_group']),"x").zfill(8)) #Rating-Group-ID
+            else:
+                RatingGroup = ''
+            self.logTool.log(service='HSS', level='debug', message="Defined Rating Group " + str(ChargingRules['rating_group']), redisClient=self.redisMessaging)
+
+            #Complete Charging Rule Definition
+            self.logTool.log(service='HSS', level='debug', message="Collating ChargingRuleDef", redisClient=self.redisMessaging)
+            ChargingRuleDef = Charging_Rule_Name + Flow_Information + Flow_Status + QoS_Information + Precedence + RatingGroup
+            ChargingRuleDef = self.generate_vendor_avp(1003, "c0", 10415, ChargingRuleDef)
+
+            #Charging Rule Install
+            self.logTool.log(service='HSS', level='debug', message="Collating Charging-Rule-Install", redisClient=self.redisMessaging)
+            return self.generate_vendor_avp(1001, "c0", 10415, ChargingRuleDef)
+
+    def Get_IMS_Subscriber_Details_from_AVP(self, username):
+        #Feed the Username AVP with Tel URI, SIP URI and either MSISDN or IMSI and this returns user data
+        username = binascii.unhexlify(username).decode('utf-8')
+        self.logTool.log(service='HSS', level='debug', message="Username AVP is present, value is " + str(username), redisClient=self.redisMessaging)
+        username = username.split('@')[0] #Strip Domain to get User part
+        username = username[4:] #Strip tel: or sip: prefix
+        #Determine if dealing with IMSI or MSISDN
+        if (len(username) == 15) or (len(username) == 16):
+            self.logTool.log(service='HSS', level='debug', message="We have an IMSI: " + str(username), redisClient=self.redisMessaging)
+            ims_subscriber_details = self.database.Get_IMS_Subscriber(imsi=username)
+        else:
+            self.logTool.log(service='HSS', level='debug', message="We have an MSISDN: " + str(username), redisClient=self.redisMessaging)
+            ims_subscriber_details = self.database.Get_IMS_Subscriber(msisdn=username)
+        self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + str(ims_subscriber_details), redisClient=self.redisMessaging)
+        return ims_subscriber_details
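+    # --- Illustrative note (hypothetical values) ---
+    # Get_IMS_Subscriber_Details_from_AVP() expects a hex encoded SIP/Tel URI,
+    # e.g. the hex of 'sip:001020000000001@ims.mnc001.mcc001.3gppnetwork.org'.
+    # Stripping the domain and the 4 character 'sip:'/'tel:' prefix leaves
+    # '001020000000001' (15 digits), which is treated as an IMSI; anything
+    # shorter is looked up as an MSISDN.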
+    def Generate_Prom_Stats(self):
+        self.logTool.log(service='HSS', level='debug', message="Called Generate_Prom_Stats", redisClient=self.redisMessaging)
+        try:
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_ims_subs',
+                                           metricType='gauge', metricAction='set',
+                                           metricValue=len(self.database.Get_Served_IMS_Subscribers(get_local_users_only=True)), metricHelp='Number of attached IMS Subscribers',
+                                           metricExpiry=60)
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_mme_subs',
+                                           metricType='gauge', metricAction='set',
+                                           metricValue=len(self.database.Get_Served_Subscribers(get_local_users_only=True)), metricHelp='Number of attached MME Subscribers',
+                                           metricExpiry=60)
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_pcrf_subs',
+                                           metricType='gauge', metricAction='set',
+                                           metricValue=len(self.database.Get_Served_PCRF_Subscribers(get_local_users_only=True)), metricHelp='Number of attached PCRF Subscribers',
+                                           metricExpiry=60)
+        except Exception as e:
+            self.logTool.log(service='HSS', level='debug', message="Failed to generate Prometheus Stats for IMS Subscribers", redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='debug', message=e, redisClient=self.redisMessaging)
+        self.logTool.log(service='HSS', level='debug', message="Generated Prometheus Stats for IMS Subscribers", redisClient=self.redisMessaging)
+
+        return
+
+    #### Diameter Answers ####
+
+    #Capabilities Exchange Answer
+    def Answer_257(self, packet_vars, avps):
+        avp = '' #Initiate empty var AVP
+        avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        for avps_to_check in avps: #Only include AVP 278 (Origin State) if initial request included it
+            if avps_to_check['avp_code'] == 278:
+                avp += self.generate_avp(278, 40, self.AVP_278_Origin_State_Incriment(avps)) #Origin State (Has to be incremented (Handled by AVP_278_Origin_State_Incriment))
+        for host in self.config['hss']['bind_ip']: #Loop through all IPs from Config and add to response
+            avp += self.generate_avp(257, 40, self.ip_to_hex(host)) #Host-IP-Address (For this to work on Linux this is the IP defined in the hostsfile for localhost)
+        avp += self.generate_avp(266, 40, "00000000") #Vendor-Id
+        avp += self.generate_avp(269, "00", self.ProductName) #Product-Name
+
+        avp += self.generate_avp(267, 40, "000027d9") #Firmware-Revision
+        avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP)
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a)
+        avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP)
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777216),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Cx)
+        avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP)
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777252),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S13)
+        avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP)
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777291),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (SLh)
+        avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP)
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777217),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Sh)
+        avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP)
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777236),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Rx)
+        avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP)
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777238),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Gx)
+        avp += self.generate_avp(258, 40, format(int(16777238),"x").zfill(8)) #Auth-Application-ID - Diameter Gx
+        avp += self.generate_avp(258, 40, format(int(10),"x").zfill(8)) #Auth-Application-ID - Diameter CER
+        avp += self.generate_avp(265, 40, format(int(5535),"x").zfill(8)) #Supported-Vendor-ID (3GPP2)
+        avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP)
+        avp += self.generate_avp(265, 40, format(int(13019),"x").zfill(8)) #Supported-Vendor-ID 13019 (ETSI)
+
+        response = self.generate_diameter_packet("01", "00", 257, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        self.logTool.log(service='HSS', level='debug', message="Successfully Generated CEA", redisClient=self.redisMessaging)
+        return response
+
+    #Device Watchdog Answer
+    def Answer_280(self, packet_vars, avps):
+        avp = '' #Initiate empty var AVP
+        avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        for avps_to_check in avps: #Only include AVP 278 (Origin State) if initial request included it
+            if avps_to_check['avp_code'] == 278:
+                avp += self.generate_avp(278, 40, self.AVP_278_Origin_State_Incriment(avps)) #Origin State (Has to be incremented (Handled by AVP_278_Origin_State_Incriment))
+        response = self.generate_diameter_packet("01", "00", 280, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        self.logTool.log(service='HSS', level='debug', message="Successfully Generated DWA", redisClient=self.redisMessaging)
+        originHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP
+        originHost = binascii.unhexlify(originHost).decode('utf-8') #Format it
+        return response
+
+    #Disconnect Peer Answer
+    def Answer_282(self, packet_vars, avps):
+        avp = '' #Initiate empty var AVP
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(268, 40, "000007d1") #Result Code (DIAMETER_SUCCESS (2001))
+        response = self.generate_diameter_packet("01", "00", 282, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        self.logTool.log(service='HSS', level='debug', message="Successfully Generated DPA", redisClient=self.redisMessaging)
+        return response
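+    # --- Illustrative note (not part of the runtime path) ---
+    # The CEA above advertises each supported application as a
+    # Vendor-Specific-Application-Id grouped AVP; the grouped payload
+    # '000001024000000c' + app-id + '0000010a4000000c000028af' is an
+    # Auth-Application-Id (258) AVP followed by a Vendor-Id (266) AVP of
+    # 0x28af (10415, 3GPP), e.g. app-id '01000023' = 16777251 for S6a.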
#Supported-Features(628) l=36 f=V-- vnd=TGPP + + #APNs from DB + APN_Configuration = '' + imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request + imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI + try: + subscriber_details = self.database.Get_Subscriber(imsi=imsi) #Get subscriber details + self.logTool.log(service='HSS', level='debug', message="Got back subscriber_details: " + str(subscriber_details), redisClient=self.redisMessaging) + + if subscriber_details['enabled'] == 0: + self.logTool.log(service='HSS', level='debug', message=f"Subscriber {imsi} is disabled", redisClient=self.redisMessaging) + + #Experimental Result AVP(Response Code for Failure) + avp_experimental_result = '' + avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID + avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code: DIAMETER_ERROR_USER_UNKNOWN (5001) + avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) + + avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State + self.logTool.log(service='HSS', level='debug', message=f"Successfully Generated ULA for disabled Subscriber: {imsi}", redisClient=self.redisMessaging) + response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) + return response + + except ValueError as e: + self.logTool.log(service='HSS', level='info', message="failed to get data backfrom database for imsi " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="Error is " + str(e), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="Responding with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging) + avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4)) + response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + self.logTool.log(service='HSS', level='info', message="Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging) + return response + except Exception as ex: + template = "An exception of type {0} occurred. 
Arguments:\n{1!r}" + message = template.format(type(ex).__name__, ex.args) + raise + + #Store MME Location into Database + OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP + OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it + OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP + OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it + self.logTool.log(service='HSS', level='debug', message="Subscriber is served by MME " + str(OriginHost) + " at realm " + str(OriginRealm), redisClient=self.redisMessaging) + + #Find Remote Peer we need to address CLRs through + try: #Check if we have a record-route set as that's where we'll need to send the response + remote_peer = self.get_avp_data(avps, 282)[-1] #Get first record-route header + remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it + except: #If we don't have a record-route set, we'll send the response to the OriginHost + remote_peer = OriginHost + remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost']) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777251_316] [ULR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) + + self.database.Update_Serving_MME(imsi=imsi, serving_mme=OriginHost, serving_mme_peer=remote_peer, serving_mme_realm=OriginRealm) + + #Boilerplate AVPs + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) + avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State + avp += self.generate_vendor_avp(1406, "c0", 10415, "00000001") #ULA Flags + + #Subscription Data: + subscription_data = '' + subscription_data += self.generate_vendor_avp(1426, "c0", 10415, "00000000") #Access Restriction Data + subscription_data += self.generate_vendor_avp(1424, "c0", 10415, "00000000") #Subscriber-Status (SERVICE_GRANTED) + subscription_data += self.generate_vendor_avp(1417, "c0", 10415, self.int_to_hex(int(subscriber_details['nam']), 4)) #Network-Access-Mode (PACKET_AND_CIRCUIT) + + #AMBR is a sub-AVP of Subscription Data + AMBR = '' #Initiate empty var AVP for AMBR + ue_ambr_ul = int(subscriber_details['ue_ambr_ul']) + ue_ambr_dl = int(subscriber_details['ue_ambr_dl']) + AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(ue_ambr_ul, 4)) #Max-Requested-Bandwidth-UL + AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(ue_ambr_dl, 4)) #Max-Requested-Bandwidth-DL + subscription_data += self.generate_vendor_avp(1435, "c0", 10415, AMBR) #Add AMBR AVP in two sub-AVPs + + + subscription_data += self.generate_vendor_avp(1619, "80", 10415, self.int_to_hex(int(subscriber_details['subscribed_rau_tau_timer']), 4)) #Subscribed-Periodic-RAU-TAU-Timer (value 720) + + + #APN Configuration Profile is a sub AVP of Subscription Data + APN_Configuration_Profile = '' + APN_Configuration_Profile += self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(1, 4)) #Context Identifier for default APN (First APN is default in our case) + APN_Configuration_Profile += self.generate_vendor_avp(1428, "c0", 10415, self.int_to_hex(0, 4)) #All-APN-Configurations-Included-Indicator + + #Split the APN list into a list + apn_list = subscriber_details['apn_list'].split(',') + self.logTool.log(service='HSS', level='debug', message="Current APN List: " + str(apn_list), redisClient=self.redisMessaging) + #Remove the default APN from the list + try: + apn_list.remove(str(subscriber_details['default_apn'])) + except: + 
self.logTool.log(service='HSS', level='debug', message="Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List", redisClient=self.redisMessaging) + pass + #Add default APN in first position + apn_list.insert(0, str(subscriber_details['default_apn'])) + + self.logTool.log(service='HSS', level='debug', message="APN list: " + str(apn_list), redisClient=self.redisMessaging) + APN_context_identifer_count = 1 + for apn_id in apn_list: + #Per APN Setup + self.logTool.log(service='HSS', level='debug', message="Processing APN ID " + str(apn_id), redisClient=self.redisMessaging) + try: + apn_data = self.database.Get_APN(apn_id) + except: + self.logTool.log(service='HSS', level='error', message="Failed to get APN " + str(apn_id), redisClient=self.redisMessaging) + continue + APN_Service_Selection = self.generate_avp(493, "40", self.string_to_hex(str(apn_data['apn']))) + + self.logTool.log(service='HSS', level='debug', message="Setting APN Configuration Profile", redisClient=self.redisMessaging) + #Sub AVPs of APN Configuration Profile + APN_context_identifer = self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(APN_context_identifer_count, 4)) + APN_PDN_type = self.generate_vendor_avp(1456, "c0", 10415, self.int_to_hex(int(apn_data['ip_version']), 4)) + + self.logTool.log(service='HSS', level='debug', message="Setting APN AMBR", redisClient=self.redisMessaging) + #AMBR + AMBR = '' #Initiate empty var AVP for AMBR + apn_ambr_ul = int(apn_data['apn_ambr_ul']) + apn_ambr_dl = int(apn_data['apn_ambr_dl']) + AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL + AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL + APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) + + self.logTool.log(service='HSS', level='debug', message="Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging) + #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP + AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) + AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not apn_data['arp_preemption_capability']), 4)) + AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "c0", 10415, self.int_to_hex(int(not apn_data['arp_preemption_vulnerability']), 4)) + AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) + AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_data['qci']), 4)) + APN_EPS_Subscribed_QoS_Profile = self.generate_vendor_avp(1431, "c0", 10415, AVP_QoS + AVP_ARP) + + #Try static IP allocation + try: + subscriber_routing_dict = self.database.Get_SUBSCRIBER_ROUTING(subscriber_id=subscriber_details['subscriber_id'], apn_id=apn_id) #Get subscriber details + self.logTool.log(service='HSS', level='debug', message="Got static UE IP " + str(subscriber_routing_dict), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Found static IP for UE " + str(subscriber_routing_dict['ip_address']), redisClient=self.redisMessaging) + Served_Party_Address = self.generate_vendor_avp(848, "c0", 10415, self.ip_to_hex(subscriber_routing_dict['ip_address'])) + except Exception as E: + self.logTool.log(service='HSS', level='debug', message="No static UE IP found: " + str(E), 
redisClient=self.redisMessaging) + Served_Party_Address = "" + + + #if 'PDN_GW_Allocation_Type' in apn_profile: + # self.logTool.log(service='HSS', level='info', message="PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type']), redisClient=self.redisMessaging) + # PDN_GW_Allocation_Type = self.generate_vendor_avp(1438, 'c0', 10415, self.int_to_hex(int(apn_profile['PDN_GW_Allocation_Type']), 4)) + # self.logTool.log(service='HSS', level='info', message="PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type), redisClient=self.redisMessaging) + # else: + # PDN_GW_Allocation_Type = '' + # if 'VPLMN_Dynamic_Address_Allowed' in apn_profile: + # self.logTool.log(service='HSS', level='info', message="VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed']), redisClient=self.redisMessaging) + # VPLMN_Dynamic_Address_Allowed = self.generate_vendor_avp(1432, 'c0', 10415, self.int_to_hex(int(apn_profile['VPLMN_Dynamic_Address_Allowed']), 4)) + # self.logTool.log(service='HSS', level='info', message="VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed), redisClient=self.redisMessaging) + # else: + # VPLMN_Dynamic_Address_Allowed = '' + PDN_GW_Allocation_Type = '' + VPLMN_Dynamic_Address_Allowed = '' + + #If static SMF / PGW-C defined + if apn_data['pgw_address'] is not None: + self.logTool.log(service='HSS', level='debug', message="MIP6-Agent-Info present (Static SMF/PGW-C), value " + str(apn_data['pgw_address']), redisClient=self.redisMessaging) + MIP_Home_Agent_Address = self.generate_avp(334, '40', self.ip_to_hex(apn_data['pgw_address'])) + MIP6_Agent_Info = self.generate_avp(486, '40', MIP_Home_Agent_Address) + else: + MIP6_Agent_Info = '' + + APN_Configuration_AVPS = APN_context_identifer + APN_PDN_type + APN_AMBR + APN_Service_Selection \ + + APN_EPS_Subscribed_QoS_Profile + Served_Party_Address + MIP6_Agent_Info + PDN_GW_Allocation_Type + VPLMN_Dynamic_Address_Allowed + + APN_Configuration += self.generate_vendor_avp(1430, "c0", 10415, APN_Configuration_AVPS) + + #Incriment Context Identifier Count to keep track of how many APN Profiles returned + APN_context_identifer_count = APN_context_identifer_count + 1 + self.logTool.log(service='HSS', level='debug', message="Completed processing APN ID " + str(apn_id), redisClient=self.redisMessaging) + + subscription_data += self.generate_vendor_avp(1429, "c0", 10415, APN_Configuration_Profile + APN_Configuration) + + try: + self.logTool.log(service='HSS', level='debug', message="MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA", redisClient=self.redisMessaging) + msisdn_avp = self.generate_vendor_avp(701, 'c0', 10415, self.TBCD_encode(str(subscriber_details['msisdn']))) #MSISDN + self.logTool.log(service='HSS', level='debug', message=msisdn_avp, redisClient=self.redisMessaging) + subscription_data += msisdn_avp + except Exception as E: + self.logTool.log(service='HSS', level='error', message="Failed to populate MSISDN in ULA due to error " + str(E), redisClient=self.redisMessaging) + + if 'RAT_freq_priorityID' in subscriber_details: + self.logTool.log(service='HSS', level='debug', message="RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA", redisClient=self.redisMessaging) + rat_freq_priorityID = self.generate_vendor_avp(1440, "C0", 10415, self.int_to_hex(int(subscriber_details['RAT_freq_priorityID']), 4)) #RAT-Frequency-Selection-Priority ID + self.logTool.log(service='HSS', 
+                self.logTool.log(service='HSS', level='debug', message=msisdn_avp, redisClient=self.redisMessaging)
+                subscription_data += msisdn_avp
+            except Exception as E:
+                self.logTool.log(service='HSS', level='error', message="Failed to populate MSISDN in ULA due to error " + str(E), redisClient=self.redisMessaging)
+
+            if 'RAT_freq_priorityID' in subscriber_details:
+                self.logTool.log(service='HSS', level='debug', message="RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA", redisClient=self.redisMessaging)
+                rat_freq_priorityID = self.generate_vendor_avp(1440, "C0", 10415, self.int_to_hex(int(subscriber_details['RAT_freq_priorityID']), 4)) #RAT-Frequency-Selection-Priority ID
+                self.logTool.log(service='HSS', level='debug', message="Adding rat_freq_priorityID: " + str(rat_freq_priorityID), redisClient=self.redisMessaging)
+                subscription_data += rat_freq_priorityID
+
+            if 'charging_characteristics' in subscriber_details:
+                self.logTool.log(service='HSS', level='debug', message="3gpp-charging-characteristics " + str(subscriber_details['charging_characteristics']) + " - Adding in ULA", redisClient=self.redisMessaging)
+                _3gpp_charging_characteristics = self.generate_vendor_avp(13, "80", 10415, str(subscriber_details['charging_characteristics']))
+                subscription_data += _3gpp_charging_characteristics
+                self.logTool.log(service='HSS', level='debug', message="Adding _3gpp_charging_characteristics: " + str(_3gpp_charging_characteristics), redisClient=self.redisMessaging)
+
+            #ToDo - Fix this
+            # if 'APN_OI_replacement' in subscriber_details:
+            #     self.logTool.log(service='HSS', level='debug', message="APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA", redisClient=self.redisMessaging)
+            #     subscription_data += self.generate_vendor_avp(1427, "C0", 10415, self.string_to_hex(str(subscriber_details['APN_OI_replacement'])))
+
+            avp += self.generate_vendor_avp(1400, "c0", 10415, subscription_data) #Subscription-Data
+
+        response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+
+        self.logTool.log(service='HSS', level='debug', message="Successfully Generated ULA", redisClient=self.redisMessaging)
+        return response
+
+    #3GPP S6a/S6d Authentication Information Answer
+    def Answer_16777251_318(self, packet_vars, avps):
+        self.logTool.log(service='HSS', level='debug', message=f"AIA AVPS: {avps}", redisClient=self.redisMessaging)
+        imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request
+        imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI
+        plmn = self.get_avp_data(avps, 1407)[0] #Get PLMN from Visited-PLMN-Id AVP in request
+
+        try:
+            subscriber_details = self.database.Get_Subscriber(imsi=imsi) #Get subscriber details
+            if subscriber_details['enabled'] == 0:
+                self.logTool.log(service='HSS', level='debug', message=f"Subscriber {imsi} is disabled", redisClient=self.redisMessaging)
+                avp = '' #Initiate empty var AVP
+                avp += self.generate_avp(268, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #Result Code
+                self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count',
+                                               metricType='counter', metricAction='inc',
+                                               metricValue=1.0,
+                                               metricLabels={
+                                                   "diameter_application_id": 16777251,
+                                                   "diameter_cmd_code": 318,
+                                                   "event": "Disabled User",
+                                                   "imsi_prefix": str(imsi[0:6])},
+                                               metricHelp='Diameter Authentication related Counters',
+                                               metricExpiry=60)
+                session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+                avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set
+                avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+                avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+
+                #Experimental Result AVP(Response Code for Failure)
+                avp_experimental_result = ''
+                avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+                avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code: DIAMETER_ERROR_USER_UNKNOWN (5001)
+                avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
+
+                avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
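+                #For reference, the hardcoded group below decodes as Auth-Application-Id(258) = 0x01000023 (16777251, S6a)
+                #followed by Vendor-Id(266) = 0x000028af (10415, 3GPP) - i.e. a prebuilt Vendor-Specific-Application-Id(260) payload.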
+                avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a)
+                response = self.generate_diameter_packet("01", "40", 318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+                self.logTool.log(service='HSS', level='debug', message=f"Successfully Generated AIA for disabled Subscriber: {imsi}", redisClient=self.redisMessaging)
+                self.logTool.log(service='HSS', level='debug', message=f"{response}", redisClient=self.redisMessaging)
+                return response
+        except ValueError as e:
+            self.logTool.log(service='HSS', level='debug', message="Error getting subscriber details for IMSI " + str(imsi), redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='debug', message=e, redisClient=self.redisMessaging)
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count',
+                                           metricType='counter', metricAction='inc',
+                                           metricValue=1.0,
+                                           metricLabels={
+                                               "diameter_application_id": 16777251,
+                                               "diameter_cmd_code": 318,
+                                               "event": "Unknown User",
+                                               "imsi_prefix": str(imsi[0:6])},
+                                           metricHelp='Diameter Authentication related Counters',
+                                           metricExpiry=60)
+            #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN"
+            self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " is unknown in database", redisClient=self.redisMessaging)
+            avp = ''
+            session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+
+            #Experimental Result AVP(Response Code for Failure)
+            avp_experimental_result = ''
+            avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+            avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(5001, 4)) #AVP Experimental-Result-Code: DIAMETER_ERROR_USER_UNKNOWN (5001)
+            avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
+
+            avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
+            avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a)
+            response = self.generate_diameter_packet("01", "40", 318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+        except Exception as ex:
+            template = "An exception of type {0} occurred. Arguments:\n{1!r}"
+            message = template.format(type(ex).__name__, ex.args)
+            self.logTool.log(service='HSS', level='error', message=message, redisClient=self.redisMessaging) #Log the formatted exception before re-raising
+            raise
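+        #Per TS 29.272, the grouped Requested-EUTRAN-Authentication-Info(1408) AVP can carry
+        #Number-Of-Requested-Vectors(1410), Immediate-Response-Preferred(1412) and, after an SQN mismatch,
+        #Re-Synchronization-Info(1411) holding RAND concatenated with AUTS; the loop below acts on 1411 and 1410.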
Arguments:\n{1!r}" + message = template.format(type(ex).__name__, ex.args) + raise + + + try: + requested_vectors = 1 + EUTRAN_Authentication_Info = self.get_avp_data(avps, 1408) + self.logTool.log(service='HSS', level='debug', message=f"authInfo: {EUTRAN_Authentication_Info}", redisClient=self.redisMessaging) + if len(EUTRAN_Authentication_Info) > 0: + EUTRAN_Authentication_Info = EUTRAN_Authentication_Info[0] + self.logTool.log(service='HSS', level='debug', message="AVP: Requested-EUTRAN-Authentication-Info(1408) l=44 f=VM- vnd=TGPP", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="EUTRAN_Authentication_Info is " + str(EUTRAN_Authentication_Info), redisClient=self.redisMessaging) + for sub_avp in EUTRAN_Authentication_Info: + #If resync request + if sub_avp['avp_code'] == 1411: + self.logTool.log(service='HSS', level='debug', message="Re-Synchronization required - SQN is out of sync", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777251, + "diameter_cmd_code": 318, + "event": "Resync", + "imsi_prefix": str(imsi[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) + auts = str(sub_avp['misc_data'])[32:] + rand = str(sub_avp['misc_data'])[:32] + rand = binascii.unhexlify(rand) + #Calculate correct SQN + self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand) + + #Get number of requested vectors + if sub_avp['avp_code'] == 1410: + self.logTool.log(service='HSS', level='debug', message="Raw value of requested vectors is " + str(sub_avp['misc_data']), redisClient=self.redisMessaging) + requested_vectors = int(sub_avp['misc_data'], 16) + if requested_vectors >= 32: + self.logTool.log(service='HSS', level='debug', message="Client has requested " + str(requested_vectors) + " vectors, limiting this to 32", redisClient=self.redisMessaging) + requested_vectors = 32 + + self.logTool.log(service='HSS', level='debug', message="Generating " + str(requested_vectors) + " vectors as requested", redisClient=self.redisMessaging) + eutranvector_complete = '' + while requested_vectors != 0: + self.logTool.log(service='HSS', level='debug', message="Generating vector number " + str(requested_vectors), redisClient=self.redisMessaging) + plmn = self.get_avp_data(avps, 1407)[0] #Get PLMN from request + vector_dict = self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "air", plmn=plmn) + eutranvector = '' #This goes into the payload of AVP 10415 (Authentication info) + eutranvector += self.generate_vendor_avp(1419, "c0", 10415, self.int_to_hex(requested_vectors, 4)) + eutranvector += self.generate_vendor_avp(1447, "c0", 10415, vector_dict['rand']) #And is made up of other AVPs joined together with RAND + eutranvector += self.generate_vendor_avp(1448, "c0", 10415, vector_dict['xres']) #XRes + eutranvector += self.generate_vendor_avp(1449, "c0", 10415, vector_dict['autn']) #AUTN + eutranvector += self.generate_vendor_avp(1450, "c0", 10415, vector_dict['kasme']) #And KASME + + requested_vectors = requested_vectors - 1 + eutranvector_complete += self.generate_vendor_avp(1414, "c0", 10415, eutranvector) #Put EUTRAN vectors in E-UTRAN-Vector AVP + + avp = '' #Initiate empty var AVP + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Session-ID AVP 
+            avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set
+            avp += self.generate_vendor_avp(1413, "c0", 10415, eutranvector_complete) #Authentication-Info (3GPP)
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
+            avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
+            avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023")
+            #avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a)
+
+            response = self.generate_diameter_packet("01", "40", 318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            self.logTool.log(service='HSS', level='debug', message="Successfully Generated AIA", redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='debug', message=response, redisClient=self.redisMessaging)
+            return response
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=traceback.format_exc(), redisClient=self.redisMessaging)
+
+    #Purge UE Answer (PUA)
+    def Answer_16777251_321(self, packet_vars, avps):
+
+        imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request
+        imsi = binascii.unhexlify(imsi).decode('utf-8')
+
+        avp = ''
+        session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+        avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set
+        avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a)
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained)
+
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+
+        #1442 - PUA-Flags
+        avp += self.generate_vendor_avp(1442, "c0", 10415, self.int_to_hex(1, 4))
+
+        #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP
+        SupportedFeatures = ''
+        SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+        SupportedFeatures += self.generate_vendor_avp(629, 80, 10415, self.int_to_hex(1, 4)) #Feature-List ID
+        SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "1c000607") #Feature-List Flags
+        avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP
+
+        response = self.generate_diameter_packet("01", "40", 321, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+
+        self.database.Update_Serving_MME(imsi, None)
+        self.logTool.log(service='HSS', level='debug', message="Successfully Generated PUA", redisClient=self.redisMessaging)
+        return response
+
+    #Notify Answer (NOA)
+    def Answer_16777251_323(self, packet_vars, avps):
+        avp = ''
+        session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+        avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set
+        avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a)
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained)
+
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+
+        #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP
+        SupportedFeatures = ''
+        SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+        SupportedFeatures += self.generate_avp(258, 40, format(int(16777251),"x").zfill(8)) #Auth-Application-ID Relay
+        avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP
+        response = self.generate_diameter_packet("01", "40", 323, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        self.logTool.log(service='HSS', level='debug', message="Successfully Generated NOA", redisClient=self.redisMessaging)
+        return response
+
+    #3GPP Gx Credit Control Answer
+    def Answer_16777238_272(self, packet_vars, avps):
+        try:
+            CC_Request_Type = self.get_avp_data(avps, 416)[0]
+            CC_Request_Number = self.get_avp_data(avps, 415)[0]
+            #Called Station ID
+            self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Attempting to find APN in CCR", redisClient=self.redisMessaging)
+            apn = bytes.fromhex(self.get_avp_data(avps, 30)[0]).decode('utf-8')
+            #Strip the PLMN-based domain from the APN, if present
+            try:
+                if '.' in apn:
+                    assert('mcc' in apn)
+                    assert('mnc' in apn)
+                    apn = apn.split('.')[0]
+            except Exception as e:
+                apn = bytes.fromhex(self.get_avp_data(avps, 30)[0]).decode('utf-8')
+            self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] CCR for APN " + str(apn), redisClient=self.redisMessaging)
+
+            OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP
+            OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it
+
+            OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP
+            OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it
+
+            try: #Check if we have a Record-Route set, as that's where we'll need to send the response
+                remote_peer = self.get_avp_data(avps, 282)[-1] #Get the last Record-Route header
+                remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it
+            except: #If we don't have a Record-Route set, we'll send the response to the OriginHost
+                remote_peer = OriginHost
+            self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging)
+            remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost'])
+
+            avp = '' #Initiate empty var AVP
+            session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Session Id is " + str(binascii.unhexlify(session_id).decode()), redisClient=self.redisMessaging)
+            avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set
+            avp += self.generate_avp(258, 40, "01000016") #Auth-Application-Id (3GPP Gx 16777238)
+            avp += self.generate_avp(416, 40, format(int(CC_Request_Type),"x").zfill(8)) #CC-Request-Type
+            avp += self.generate_avp(415, 40, format(int(CC_Request_Number),"x").zfill(8)) #CC-Request-Number
+
+            #Get Subscriber info from Subscription ID
+            for SubscriptionIdentifier in self.get_avp_data(avps, 443):
+                for UniqueSubscriptionIdentifier in SubscriptionIdentifier:
+                    self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Evaluating UniqueSubscriptionIdentifier AVP " + str(UniqueSubscriptionIdentifier) + " to find IMSI", redisClient=self.redisMessaging)
+                    if UniqueSubscriptionIdentifier['avp_code'] == 444:
+                        imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8')
+                        self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Found IMSI " + str(imsi), redisClient=self.redisMessaging)
+
+            self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] SubscriptionID: " + str(self.get_avp_data(avps, 443)), redisClient=self.redisMessaging)
+            try:
+                self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database", redisClient=self.redisMessaging)
+                ChargingRules = self.database.Get_Charging_Rules(imsi=imsi, apn=apn)
+                self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Got Charging Rules: " + str(ChargingRules), redisClient=self.redisMessaging)
+            except Exception as E:
+                #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN"
+                self.logTool.log(service='HSS', level='debug', message=E, redisClient=self.redisMessaging)
+                self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists", redisClient=self.redisMessaging)
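+            #CC-Request-Type values, per RFC 8506: INITIAL_REQUEST (1), UPDATE_REQUEST (2), TERMINATION_REQUEST (3), EVENT_REQUEST (4).
+            #Only Initial and Termination requests receive special handling below.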
+            # CCR - Initial Request
+            if int(CC_Request_Type) == 1:
+                self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Request type for CCA is 1 - Initial", redisClient=self.redisMessaging)
+
+                #Get UE IP
+                try:
+                    ue_ip = self.get_avp_data(avps, 8)[0]
+                    ue_ip = str(self.hex_to_ip(ue_ip))
+                except Exception as E:
+                    self.logTool.log(service='HSS', level='error', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to get UE IP", redisClient=self.redisMessaging)
+                    self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging)
+                    ue_ip = 'Failed to Decode / Get UE IP'
+
+                #Store PGW location in the database (remote_peer already has our OriginHost appended above)
+                self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=OriginHost, subscriber_routing=str(ue_ip), serving_pgw_realm=OriginRealm, serving_pgw_peer=remote_peer)
+
+                #Supported-Features(628) (Gx feature list)
+                avp += self.generate_vendor_avp(628, "80", 10415, "0000010a4000000c000028af0000027580000010000028af000000010000027680000010000028af0000000b")
+
+                #Default EPS Bearer QoS (sourced from the database, falling back to the CCR-I contents, else omitted)
+                try:
+                    apn_data = ChargingRules['apn_data']
+                    self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Setting APN AMBR", redisClient=self.redisMessaging)
+                    #AMBR
+                    AMBR = '' #Initiate empty var AVP for AMBR
+                    apn_ambr_ul = int(apn_data['apn_ambr_ul'])
+                    apn_ambr_dl = int(apn_data['apn_ambr_dl'])
+                    AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL
+                    AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL
+                    APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR)
+                    self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging)
+                    #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP
+                    # Per TS 29.212, we need to flip our stored values for capability and vulnerability:
+                    # PRE-EMPTION_CAPABILITY_ENABLED (0)
+                    # PRE-EMPTION_CAPABILITY_DISABLED (1)
+                    # PRE-EMPTION_VULNERABILITY_ENABLED (0)
+                    # PRE-EMPTION_VULNERABILITY_DISABLED (1)
+                    AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4))
+                    AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not apn_data['arp_preemption_capability']), 4))
+                    AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(not apn_data['arp_preemption_vulnerability']), 4))
+                    AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
+                    AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_data['qci']), 4))
+                    avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP)
+                except Exception as E:
+                    self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging)
+                    self.logTool.log(service='HSS', level='error', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to populate default_EPS_QoS from DB for sub " + str(imsi), redisClient=self.redisMessaging)
+                    default_EPS_QoS = self.get_avp_data(avps, 1049)[0][8:]
+                    if len(default_EPS_QoS) > 0:
+                        avp += self.generate_vendor_avp(1049, "80", 10415, default_EPS_QoS)
+
+                self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Creating QoS Information", redisClient=self.redisMessaging)
+                #QoS-Information
+                try:
+                    apn_data = ChargingRules['apn_data']
+                    apn_ambr_ul = int(apn_data['apn_ambr_ul'])
+                    apn_ambr_dl = int(apn_data['apn_ambr_dl'])
+                    QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4))
+                    QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4))
+                    self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Created both QoS AVPs from data from Database", redisClient=self.redisMessaging)
+                    self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Populated QoS_Information", redisClient=self.redisMessaging)
+                    avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information)
+                except Exception as E:
+                    self.logTool.log(service='HSS', level='error', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to get QoS information dynamically for sub " + str(imsi), redisClient=self.redisMessaging)
+                    self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging)
+
+                    QoS_Information = ''
+                    for AMBR_Part in self.get_avp_data(avps, 1016)[0]:
+                        self.logTool.log(service='HSS', level='debug', message=AMBR_Part, redisClient=self.redisMessaging)
+                        AMBR_AVP = self.generate_vendor_avp(AMBR_Part['avp_code'], "80", 10415, AMBR_Part['misc_data'][8:])
+                        QoS_Information += AMBR_AVP
+                        self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] QoS_Information added " + str(AMBR_AVP), redisClient=self.redisMessaging)
+                    avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information)
+                    self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] QoS information set from the CCR-I contents", redisClient=self.redisMessaging)
+
+                self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Added to AVP List", redisClient=self.redisMessaging)
+                self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] QoS Information: " + str(QoS_Information), redisClient=self.redisMessaging)
+
+                #If the database returned an existing ChargingRule definition, add the ChargingRule to the CCA-I.
+                #A Charging-Rule-Install AVP being present may trigger the creation of a dedicated bearer.
+                if ChargingRules and ChargingRules['charging_rules'] is not None:
+                    try:
+                        self.logTool.log(service='HSS', level='debug', message=ChargingRules, redisClient=self.redisMessaging)
+                        for individual_charging_rule in ChargingRules['charging_rules']:
+                            self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Processing Charging Rule: " + str(individual_charging_rule), redisClient=self.redisMessaging)
+                            chargingRule = self.Charging_Rule_Generator(ChargingRules=individual_charging_rule, ue_ip=ue_ip)
+                            if len(chargingRule) > 0:
+                                avp += chargingRule
+
+                    except Exception as E:
+                        self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Error in populating dynamic charging rules: " + str(E), redisClient=self.redisMessaging)
+
+            # CCR - Termination Request
+            elif int(CC_Request_Type) == 3:
+                self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Request type for CCA is 3 - Termination", redisClient=self.redisMessaging)
+                if 'ims' in apn:
+                    try:
+                        self.database.Update_Serving_CSCF(imsi=imsi, serving_cscf=None)
+                        self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=None)
+                        self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=str(binascii.unhexlify(session_id).decode()), serving_pgw=OriginHost, subscriber_routing='')
+                        self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Successfully cleared stored IMS state", redisClient=self.redisMessaging)
+                    except Exception as e:
+                        self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Failed to clear stored IMS state: {traceback.format_exc()}", redisClient=self.redisMessaging)
+                else:
+                    try:
+                        self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=str(binascii.unhexlify(session_id).decode()), serving_pgw=OriginHost, subscriber_routing='')
+                        self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Successfully cleared stored state for: {apn}", redisClient=self.redisMessaging)
+                    except Exception as e:
+                        self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Failed to clear apn state for {apn}: {traceback.format_exc()}", redisClient=self.redisMessaging)
+
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
+            response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        except Exception as e:
+            #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN"
+ str(imsi) + " unknown in HSS for CCR", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=traceback.format_exc(), redisClient=self.redisMessaging) + + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777238, + "diameter_cmd_code": 272, + "event": "Unknown User", + "imsi_prefix": str(imsi[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4)) #Result Code (DIAMETER ERROR - User Unknown) + response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + return response + + #3GPP Cx User Authorization Answer + def Answer_16777216_300(self, packet_vars, avps): + + avp = '' #Initiate empty var AVP #Session-ID + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained) + avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx + + + OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP + OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it + OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP + OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it + + try: #Check if we have a record-route set as that's where we'll need to send the response + remote_peer = self.get_avp_data(avps, 282)[-1] #Get first record-route header + remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it + except: #If we don't have a record-route set, we'll send the response to the OriginHost + remote_peer = OriginHost + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777216_300] [UAR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) + + try: + self.logTool.log(service='HSS', level='debug', message="Checking if username present", redisClient=self.redisMessaging) + username = self.get_avp_data(avps, 1)[0] + username = binascii.unhexlify(username).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message="Username AVP is present, value is " + str(username), redisClient=self.redisMessaging) + imsi = username.split('@')[0] #Strip Domain + domain = username.split('@')[1] #Get Domain Part + self.logTool.log(service='HSS', level='debug', message="Extracted imsi: " + str(imsi) + " now checking backend for this IMSI", redisClient=self.redisMessaging) + ims_subscriber_details = self.database.Get_IMS_Subscriber(imsi=imsi) + except Exception as E: + self.logTool.log(service='HSS', level='error', message="Threw Exception: " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="No known MSISDN or IMSI in Answer_16777216_300() input", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='diameter', 
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count',
+                                           metricType='counter', metricAction='inc',
+                                           metricValue=1.0,
+                                           metricLabels={
+                                               "diameter_application_id": 16777216,
+                                               "diameter_cmd_code": 300,
+                                               "event": "Unknown User",
+                                               "imsi_prefix": str(imsi[0:6])},
+                                           metricHelp='Diameter Authentication related Counters',
+                                           metricExpiry=60)
+            result_code = 5001 #IMS User Unknown
+            #Experimental Result AVP
+            avp_experimental_result = ''
+            avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+            avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4)) #AVP Experimental-Result-Code
+            avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
+            response = self.generate_diameter_packet("01", "40", 300, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+
+        #Determine the User-Authorization-Type & store state accordingly
+        user_authorization_type_avp_data = self.get_avp_data(avps, 623)
+        if user_authorization_type_avp_data:
+            try:
+                User_Authorization_Type = int(user_authorization_type_avp_data[0])
+                self.logTool.log(service='HSS', level='debug', message="User_Authorization_Type is: " + str(User_Authorization_Type), redisClient=self.redisMessaging)
+                if (User_Authorization_Type == 1):
+                    self.logTool.log(service='HSS', level='debug', message="This is a Deregister", redisClient=self.redisMessaging)
+                    self.database.Update_Serving_CSCF(imsi, serving_cscf=None)
+                    #Populate S-CSCF Address
+                    avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(ims_subscriber_details['scscf'])),'ascii'))
+                    avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
+                    response = self.generate_diameter_packet("01", "40", 300, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+                    return response
+            except Exception as E:
+                self.logTool.log(service='HSS', level='debug', message="Failed to get User_Authorization_Type AVP & Update_Serving_CSCF error: " + str(E), redisClient=self.redisMessaging)
+        self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + str(ims_subscriber_details), redisClient=self.redisMessaging)
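+        #Cx registration outcome is signalled via Experimental-Result (TS 29.229):
+        #2001 = DIAMETER_FIRST_REGISTRATION (no S-CSCF previously assigned), 2002 = DIAMETER_SUBSEQUENT_REGISTRATION (S-CSCF already on record).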
+        if ims_subscriber_details['scscf'] != None:
+            self.logTool.log(service='HSS', level='debug', message="Already has SCSCF Assigned from DB: " + str(ims_subscriber_details['scscf']), redisClient=self.redisMessaging)
+            avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(ims_subscriber_details['scscf'])),'ascii'))
+            experimental_avp = ''
+            experimental_avp += self.generate_avp(266, 40, format(int(10415),"x").zfill(8)) #3GPP Vendor ID
+            experimental_avp += self.generate_avp(298, 40, format(int(2002),"x").zfill(8)) #DIAMETER_SUBSEQUENT_REGISTRATION (2002)
+            avp += self.generate_avp(297, 40, experimental_avp) #Experimental-Result
+        else:
+            self.logTool.log(service='HSS', level='debug', message="No SCSCF Assigned from DB", redisClient=self.redisMessaging)
+            if 'scscf_pool' in self.config['hss']:
+                try:
+                    scscf = random.choice(self.config['hss']['scscf_pool'])
+                    self.logTool.log(service='HSS', level='debug', message="Randomly picked SCSCF address " + str(scscf) + " from pool", redisClient=self.redisMessaging)
+                    avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii'))
+                except Exception as E:
+                    avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii'))
+                    self.logTool.log(service='HSS', level='debug', message="Using generated S-CSCF Address as failed to source from list due to " + str(E), redisClient=self.redisMessaging)
+            else:
+                avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii'))
+                self.logTool.log(service='HSS', level='debug', message="Using generated S-CSCF Address as none set in scscf_pool in config", redisClient=self.redisMessaging)
+            experimental_avp = ''
+            experimental_avp += self.generate_avp(266, 40, format(int(10415),"x").zfill(8)) #3GPP Vendor ID
+            experimental_avp += self.generate_avp(298, 40, format(int(2001),"x").zfill(8)) #DIAMETER_FIRST_REGISTRATION (2001)
+            avp += self.generate_avp(297, 40, experimental_avp) #Experimental-Result
+
+        response = self.generate_diameter_packet("01", "40", 300, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        return response
+
+    #3GPP Cx Server Assignment Answer
+    def Answer_16777216_301(self, packet_vars, avps):
+        avp = '' #Initiate empty var AVP #Session-ID
+        session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+        avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained)
+
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx
+
+        OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP
+        OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it
+
+        OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP
+        OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it
+
+        #Find the remote peer we need to address CLRs through
+        try: #Check if we have a Record-Route set, as that's where we'll need to send the response
+            remote_peer = self.get_avp_data(avps, 282)[-1] #Get the last Record-Route header
+            remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it
+        except: #If we don't have a Record-Route set, we'll send the response to the OriginHost
+            remote_peer = OriginHost
+        self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777216_301] [SAR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging)
+
+        try:
+            self.logTool.log(service='HSS', level='debug', message="Checking if username present", redisClient=self.redisMessaging)
+            username = self.get_avp_data(avps, 601)[0]
+            ims_subscriber_details = self.Get_IMS_Subscriber_Details_from_AVP(username)
+            self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + str(ims_subscriber_details), redisClient=self.redisMessaging)
+            imsi = ims_subscriber_details['imsi']
+            domain = "ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org"
+        except Exception as E:
+            self.logTool.log(service='HSS', level='error', message="Threw Exception: " + str(E), redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='error', message="No known MSISDN or IMSI in Answer_16777216_301() input", redisClient=self.redisMessaging)
+            result_code = 5005
+            #Experimental Result AVP
+            avp_experimental_result = ''
+            avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+            avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4)) #AVP Experimental-Result-Code
+            avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
+            response = self.generate_diameter_packet("01", "40", 301, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet (Cx application ID 16777216)
+            return response
+
+        avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(str(imsi) + '@' + str(domain))),'ascii'))
+        #Cx-User-Data (XML)
+
+        #This loads a Jinja XML template as the default iFC
+        templateLoader = jinja2.FileSystemLoader(searchpath="../")
+        templateEnv = jinja2.Environment(loader=templateLoader)
+        self.logTool.log(service='HSS', level='debug', message="Loading iFC from path " + str(ims_subscriber_details['ifc_path']), redisClient=self.redisMessaging)
+        template = templateEnv.get_template(ims_subscriber_details['ifc_path'])
+
+        #These variables are passed to the template for use
+        ims_subscriber_details['mnc'] = self.MNC.zfill(3)
+        ims_subscriber_details['mcc'] = self.MCC.zfill(3)
+
+        xmlbody = template.render(iFC_vars=ims_subscriber_details) #This is where to put args to the template renderer
+        avp += self.generate_vendor_avp(606, "c0", 10415, str(binascii.hexlify(str.encode(xmlbody)),'ascii'))
+
+        #Charging Information
+        #avp += self.generate_vendor_avp(618, "c0", 10415, "0000026dc000001b000028af7072695f6363665f6164647265737300")
+        #avp += self.generate_avp(268, 40, "000007d1") #DIAMETER_SUCCESS
+
+        #Determine the Server-Assignment-Type & store state accordingly
+        Server_Assignment_Type_Hex = self.get_avp_data(avps, 614)[0]
+        Server_Assignment_Type = self.hex_to_int(Server_Assignment_Type_Hex)
+        self.logTool.log(service='HSS', level='debug', message="Server-Assignment-Type is: " + str(Server_Assignment_Type), redisClient=self.redisMessaging)
+        ServingCSCF = self.get_avp_data(avps, 602)[0] #Get Server-Name (S-CSCF) from AVP
+        ServingCSCF = binascii.unhexlify(ServingCSCF).decode('utf-8') #Format it
+        self.logTool.log(service='HSS', level='debug', message="Subscriber is served by S-CSCF " + str(ServingCSCF), redisClient=self.redisMessaging)
+        if (Server_Assignment_Type == 1) or (Server_Assignment_Type == 2):
+            self.logTool.log(service='HSS', level='debug', message="SAR is Register / Re-Register", redisClient=self.redisMessaging)
+            remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost'])
+            self.database.Update_Serving_CSCF(imsi, serving_cscf=ServingCSCF, scscf_realm=OriginRealm, scscf_peer=remote_peer)
+        else:
+            self.logTool.log(service='HSS', level='debug', message="SAR is not Register", redisClient=self.redisMessaging)
+            self.database.Update_Serving_CSCF(imsi, serving_cscf=None)
+
+        avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
+
+        response = self.generate_diameter_packet("01", "40", 301, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        return response
+
+    #3GPP Cx Location Information Answer
+    def Answer_16777216_302(self, packet_vars, avps):
+        avp = '' #Initiate empty var AVP #Session-ID
+        session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+        avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(277, 40, "00000001") #Auth Session State
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx
+
+        try:
+            self.logTool.log(service='HSS', level='debug', message="Checking if username present", redisClient=self.redisMessaging)
+            username = self.get_avp_data(avps, 601)[0]
+            ims_subscriber_details = self.Get_IMS_Subscriber_Details_from_AVP(username)
+            if ims_subscriber_details['scscf'] != None:
+                self.logTool.log(service='HSS', level='debug', message="Got SCSCF on record for Sub", redisClient=self.redisMessaging)
+                #Strip double sip prefix
+                avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(str(ims_subscriber_details['scscf']))),'ascii'))
+            else:
+                self.logTool.log(service='HSS', level='debug', message="No SCSCF assigned - Using SCSCF Pool", redisClient=self.redisMessaging)
+                if 'scscf_pool' in self.config['hss']:
+                    try:
+                        scscf = random.choice(self.config['hss']['scscf_pool'])
+                        self.logTool.log(service='HSS', level='debug', message="Randomly picked SCSCF address " + str(scscf) + " from pool", redisClient=self.redisMessaging)
+                        avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii'))
+                    except Exception as E:
+                        avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii'))
+                        self.logTool.log(service='HSS', level='debug', message="Using generated S-CSCF Address as failed to source from pool due to " + str(E), redisClient=self.redisMessaging)
+                else:
+                    avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii'))
+                    self.logTool.log(service='HSS', level='debug', message="Using generated S-CSCF Address as none set in scscf_pool in config", redisClient=self.redisMessaging)
+        except Exception as E:
+            self.logTool.log(service='HSS', level='error', message="Threw Exception: " + str(E), redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='error', message="No known MSISDN or IMSI in Answer_16777216_302() input", redisClient=self.redisMessaging)
+            result_code = 5001
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count',
+                                           metricType='counter', metricAction='inc',
+                                           metricValue=1.0,
+                                           metricLabels={
+                                               "diameter_application_id": 16777216,
+                                               "diameter_cmd_code": 302,
+                                               "event": "Unknown User",
+                                               "imsi_prefix": str(username[0:6])},
+                                           metricHelp='Diameter Authentication related Counters',
+                                           metricExpiry=60)
+            #Experimental Result AVP
+            avp_experimental_result = ''
+            avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+            avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4)) #AVP Experimental-Result-Code
+            avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
+            response = self.generate_diameter_packet("01", "40", 302, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+
+        avp += self.generate_avp(268, 40, "000007d1") #DIAMETER_SUCCESS
+        response = self.generate_diameter_packet("01", "40", 302, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+
+        return response
+
+    #3GPP Cx Multimedia Authentication Answer
+    def Answer_16777216_303(self, packet_vars, avps):
+        public_identity = self.get_avp_data(avps, 601)[0]
+        public_identity = binascii.unhexlify(public_identity).decode('utf-8')
+        self.logTool.log(service='HSS', level='debug', message="Got MAR for public_identity : " + str(public_identity), redisClient=self.redisMessaging)
+        username = self.get_avp_data(avps, 1)[0]
+        username = binascii.unhexlify(username).decode('utf-8')
+        imsi = username.split('@')[0] #Strip Domain
+        domain = username.split('@')[1] #Get Domain Part
+        self.logTool.log(service='HSS', level='debug', message="Got MAR username: " + str(username), redisClient=self.redisMessaging)
+        auth_scheme = ''
+
+        avp = '' #Initiate empty var AVP
+        session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+        avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx
+        avp += self.generate_avp(277, 40, "00000001") #Auth Session State
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+
+        try:
+            subscriber_details = self.database.Get_Subscriber(imsi=imsi) #Get subscriber details
+        except:
+            #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN"
+            self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " unknown in HSS for MAA", redisClient=self.redisMessaging)
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count',
+                                           metricType='counter', metricAction='inc',
+                                           metricValue=1.0,
+                                           metricLabels={
+                                               "diameter_application_id": 16777216,
+                                               "diameter_cmd_code": 303,
+                                               "event": "Unknown User",
+                                               "imsi_prefix": str(imsi[0:6])},
+                                           metricHelp='Diameter Authentication related Counters',
+                                           metricExpiry=60)
+            experimental_result = self.generate_avp(298, 40, self.int_to_hex(5001, 4)) #Experimental-Result-Code (DIAMETER_ERROR_USER_UNKNOWN (5001))
+            experimental_result = experimental_result + self.generate_vendor_avp(266, 40, 10415, "")
+            #Experimental Result (297)
+            avp += self.generate_avp(297, 40, experimental_result)
+            response = self.generate_diameter_packet("01", "40", 303, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+
+        self.logTool.log(service='HSS', level='debug', message="Got subscriber data for MAA OK", redisClient=self.redisMessaging)
+
+        mcc, mnc = imsi[0:3], imsi[3:5]
+        plmn = self.EncodePLMN(mcc, mnc)
+
+        #Determine if SQN Resync is required & the auth type to use
+        for sub_avp_612 in self.get_avp_data(avps, 612)[0]:
+            if sub_avp_612['avp_code'] == 610:
+                self.logTool.log(service='HSS', level='debug', message="SQN in HSS is out of sync - Performing resync", redisClient=self.redisMessaging)
+                auts = str(sub_avp_612['misc_data'])[32:]
+                rand = str(sub_avp_612['misc_data'])[:32]
+                rand = binascii.unhexlify(rand)
+                self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand)
+                self.logTool.log(service='HSS', level='debug', message="Resynced SQN in DB", redisClient=self.redisMessaging)
+                self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count',
+                                               metricType='counter', metricAction='inc',
+                                               metricValue=1.0,
+                                               metricLabels={
+                                                   "diameter_application_id": 16777216,
+                                                   "diameter_cmd_code": 303,
+                                                   "event": "ReAuth",
+                                                   "imsi_prefix": str(imsi[0:6])},
+                                               metricHelp='Diameter Authentication related Counters',
+                                               metricExpiry=60)
+            if sub_avp_612['avp_code'] == 608:
+                self.logTool.log(service='HSS', level='debug', message="Auth mechanism requested: " + str(sub_avp_612['misc_data']), redisClient=self.redisMessaging)
+                auth_scheme = binascii.unhexlify(sub_avp_612['misc_data']).decode('utf-8')
+                self.logTool.log(service='HSS', level='debug', message="Auth mechanism requested: " + str(auth_scheme), redisClient=self.redisMessaging)
+
+        self.logTool.log(service='HSS', level='debug', message="IMSI is " + str(imsi), redisClient=self.redisMessaging)
+        avp += self.generate_vendor_avp(601, "c0", 10415, str(binascii.hexlify(str.encode(public_identity)),'ascii')) #Public-Identity
+        avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(imsi + "@" + domain)),'ascii')) #Username
+
+        #Determine the vectors to generate
+        if auth_scheme == "Digest-MD5":
+            self.logTool.log(service='HSS', level='debug', message="Generating MD5 Challenge", redisClient=self.redisMessaging)
+            vector_dict = self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "Digest-MD5", username=imsi, plmn=plmn)
+            avp_SIP_Item_Number = self.generate_vendor_avp(613, "c0", 10415, format(int(0),"x").zfill(8))
+            avp_SIP_Authentication_Scheme = self.generate_vendor_avp(608, "c0", 10415, str(binascii.hexlify(b'Digest-MD5'),'ascii'))
+            #Nonce
+            avp_SIP_Authenticate = self.generate_vendor_avp(609, "c0", 10415, str(vector_dict['nonce']))
+            #Expected Response
+            avp_SIP_Authorization = self.generate_vendor_avp(610, "c0", 10415, str(binascii.hexlify(str.encode(vector_dict['SIP_Authenticate'])),'ascii'))
+            auth_data_item = avp_SIP_Item_Number + avp_SIP_Authentication_Scheme + avp_SIP_Authenticate + avp_SIP_Authorization
+        else:
+            self.logTool.log(service='HSS', level='debug', message="Generating Digest-AKAv1-MD5 Auth Challenge", redisClient=self.redisMessaging)
+            vector_dict = self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "sip_auth", plmn=plmn)
+
+            #diameter.3GPP-SIP-Auth-Data-Items:
+            #AVP Code: 613 3GPP-SIP-Item-Number
+            avp_SIP_Item_Number = self.generate_vendor_avp(613, "c0", 10415, format(int(0),"x").zfill(8))
+            #AVP Code: 608 3GPP-SIP-Authentication-Scheme
+            avp_SIP_Authentication_Scheme = self.generate_vendor_avp(608, "c0", 10415, str(binascii.hexlify(b'Digest-AKAv1-MD5'),'ascii'))
+            #AVP Code: 609 3GPP-SIP-Authenticate
+            avp_SIP_Authenticate = self.generate_vendor_avp(609, "c0", 10415, str(binascii.hexlify(vector_dict['SIP_Authenticate']),'ascii')) #RAND + AUTN
+            #AVP Code: 610 3GPP-SIP-Authorization
+            avp_SIP_Authorization = self.generate_vendor_avp(610, "c0", 10415, str(binascii.hexlify(vector_dict['xres']),'ascii')) #XRES
+            #AVP Code: 625 Confidentiality-Key
+            avp_Confidentiality_Key = self.generate_vendor_avp(625, "c0", 10415, str(binascii.hexlify(vector_dict['ck']),'ascii')) #CK
+            #AVP Code: 626 Integrity-Key
+            avp_Integrity_Key = self.generate_vendor_avp(626, "c0", 10415, str(binascii.hexlify(vector_dict['ik']),'ascii')) #IK
+
+            auth_data_item = avp_SIP_Item_Number + avp_SIP_Authentication_Scheme + avp_SIP_Authenticate + avp_SIP_Authorization + avp_Confidentiality_Key + avp_Integrity_Key
+        avp += self.generate_vendor_avp(612, "c0", 10415, auth_data_item) #3GPP-SIP-Auth-Data-Item
+
+        avp += self.generate_vendor_avp(607, "c0", 10415, "00000001") #3GPP-SIP-Number-Auth-Items
+
+        avp += self.generate_avp(268, 40, "000007d1") #DIAMETER_SUCCESS
+
+        response = self.generate_diameter_packet("01", "40", 303, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        return response
+
+    #Generate a generic error handler with the Result Code as input
+    def Respond_ResultCode(self, packet_vars, avps, result_code):
+        self.logTool.log(service='HSS', level='error', message="Responding with result code " + str(result_code) + " to request with command code " + str(packet_vars['command_code']), redisClient=self.redisMessaging)
+        avp = '' #Initiate empty var AVP
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        try:
+            session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+        except:
+            self.logTool.log(service='HSS', level='debug', message="Failed to add SessionID into error", redisClient=self.redisMessaging)
+        for avps_to_check in avps: #Only include AVP 260 (Vendor-Specific-Application-ID) if the initial request included it
+            if avps_to_check['avp_code'] == 260:
+                concat_subavp = ''
+                for sub_avp in avps_to_check['misc_data']:
+                    concat_subavp += self.generate_avp(sub_avp['avp_code'], sub_avp['avp_flags'], sub_avp['misc_data'])
+                avp += self.generate_avp(260, 40, concat_subavp) #Vendor-Specific-Application-ID
+        avp += self.generate_avp(268, 40, self.int_to_hex(result_code, 4)) #Result Code
+
+        #Experimental Result AVP(Response Code for Failure)
+        avp_experimental_result = ''
+        avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+        avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4)) #AVP Experimental-Result-Code
+        avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
+
+        response = self.generate_diameter_packet("01", "60", int(packet_vars['command_code']), int(packet_vars['ApplicationId']), packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        return response
+
+    #3GPP Cx Registration Termination Answer
+    def Answer_16777216_304(self, packet_vars, avps):
+        avp = '' #Initiate empty var AVP #Session-ID
+        session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+        avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+        vendor_id = self.generate_avp(266, 40, str(binascii.hexlify(b'10415'),'ascii'))
+        self.logTool.log(service='HSS', level='debug', message="vendor_id avp: " + str(vendor_id), redisClient=self.redisMessaging)
+        auth_application_id = self.generate_avp(248, 40, self.int_to_hex(16777252, 8))
+        self.logTool.log(service='HSS', level='debug', message="auth_application_id: " + auth_application_id, redisClient=self.redisMessaging)
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx
+        avp += self.generate_avp(268, 40, "000007d1") #Result Code - DIAMETER_SUCCESS
+        avp += self.generate_avp(277, 40, "00000001") #Auth Session State
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        #* [ Proxy-Info ]
+        proxy_host_avp = self.generate_avp(280, "40", str(binascii.hexlify(b'localdomain'),'ascii'))
+        proxy_state_avp = self.generate_avp(33, "40", "0001")
+        avp += self.generate_avp(284, "40", proxy_host_avp + proxy_state_avp) #Proxy-Info AVP ( 284 )
+
+        #* [ Route-Record ]
+        avp += self.generate_avp(282, "40", str(binascii.hexlify(b'localdomain'),'ascii'))
+
+        response = self.generate_diameter_packet("01", "40", 304, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        return response
+
+    #3GPP Sh User-Data Answer
+    def Answer_16777217_306(self, packet_vars, avps):
+        avp = '' #Initiate empty var AVP #Session-ID
+
+        #Define values so we can check if they've been changed
+        msisdn = None
+        try:
+            user_identity_avp = self.get_avp_data(avps, 700)[0]
+            msisdn = self.get_avp_data(user_identity_avp, 701)[0] #Get MSISDN from AVP in request
+            self.logTool.log(service='HSS', level='debug', message="Got raw MSISDN with value " + str(msisdn), redisClient=self.redisMessaging)
+            msisdn = self.TBCD_decode(msisdn)
+            self.logTool.log(service='HSS', level='debug', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging)
+        except:
+            self.logTool.log(service='HSS', level='error', message="No MSISDN", redisClient=self.redisMessaging)
+        try:
+            username = self.get_avp_data(avps, 601)[0]
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message="No Username", redisClient=self.redisMessaging)
+
+        if msisdn is not None:
+            self.logTool.log(service='HSS', level='debug', message="Getting subscriber IMS info based on MSISDN", redisClient=self.redisMessaging)
+            subscriber_ims_details = self.database.Get_IMS_Subscriber(msisdn=msisdn)
+            self.logTool.log(service='HSS', level='debug', message="Got subscriber IMS details: " + str(subscriber_ims_details), redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='debug', message="Getting subscriber info based on MSISDN", redisClient=self.redisMessaging)
+            subscriber_details = self.database.Get_Subscriber(msisdn=msisdn)
+            self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + str(subscriber_details), redisClient=self.redisMessaging)
+            subscriber_details = {**subscriber_details, **subscriber_ims_details}
+            self.logTool.log(service='HSS', level='debug', message="Merged subscriber details: " + str(subscriber_details), redisClient=self.redisMessaging)
+        else:
+            self.logTool.log(service='HSS', level='error', message="No MSISDN or IMSI in Answer_16777217_306() input", redisClient=self.redisMessaging)
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count',
+                                           metricType='counter', metricAction='inc',
+                                           metricValue=1.0,
+                                           metricLabels={
+                                               "diameter_application_id": 16777217,
+                                               "diameter_cmd_code": 306,
+                                               "event": "Unknown User",
+                                               "imsi_prefix": str(username[0:6])},
+                                           metricHelp='Diameter Authentication related Counters',
+                                           metricExpiry=60)
+            result_code = 5005
+            #Experimental Result AVP
+            avp_experimental_result = ''
+            avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+            avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4)) #AVP Experimental-Result-Code
+            avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
+            response = self.generate_diameter_packet("01", "40", 306, 16777217, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+
+        session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+        avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+    #3GPP Sh Profile-Update Answer
+    def Answer_16777217_307(self, packet_vars, avps):
+
+        #Get IMSI
+        imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request
+        imsi = binascii.unhexlify(imsi).decode('utf-8')
+
+        #Get Sh User Data
+        sh_user_data = self.get_avp_data(avps, 702)[0] #Get Sh User-Data from AVP 702 in request
+        sh_user_data = binascii.unhexlify(sh_user_data).decode('utf-8')
+
+        self.logTool.log(service='HSS', level='debug', message="Got Sh User data: " + str(sh_user_data), redisClient=self.redisMessaging)
+
+        #Push updated User Data into IMS Backend
+        #Start with the Current User Data
+        subscriber_ims_details = self.database.Get_IMS_Subscriber(imsi=imsi)
+        self.database.UpdateObj(self.database.IMS_SUBSCRIBER, {'sh_profile': sh_user_data}, subscriber_ims_details['ims_subscriber_id'])
+
+        avp = '' #Initiate empty var AVP #Session-ID
+        session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+        avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (No state maintained)
+        #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
+        VendorSpecificApplicationId = ''
+        VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+        VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777217),"x").zfill(8)) #Auth-Application-ID Sh
+        avp += self.generate_avp(260, 40, VendorSpecificApplicationId)
+        response = self.generate_diameter_packet("01", "40", 307, 16777217, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        return response
+
+    ################################
+    ####         3GPP RX        ####
+    ################################
+
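+    # --- Illustrative sketch (not part of this PR) ---
+    # The Rx handlers below gate authorization on self.validateImsSubscriber(),
+    # which is not shown in this hunk. A rough sketch of the check it is
+    # described as performing (subscriber enabled, with a matching
+    # ims_subscriber entry); the 'enabled' flag is an assumption and the real
+    # implementation may differ:
+    def validateImsSubscriber_sketch(self, imsi=None, msisdn=None):
+        try:
+            if imsi is not None:
+                subscriber = self.database.Get_Subscriber(imsi=imsi)
+                imsSubscriber = self.database.Get_IMS_Subscriber(imsi=imsi)
+            else:
+                subscriber = self.database.Get_Subscriber(msisdn=msisdn)
+                imsSubscriber = self.database.Get_IMS_Subscriber(msisdn=msisdn)
+            #Both records must exist and the subscriber must be enabled
+            return bool(subscriber.get('enabled', False)) and imsSubscriber is not None
+        except Exception:
+            return False
+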
+ """ + avp = '' + sessionId = bytes.fromhex(self.get_avp_data(avps, 263)[0]).decode('ascii') #Get Session-ID + avp += self.generate_avp(263, 40, self.string_to_hex(sessionId)) #Set session ID to received session ID + avp += self.generate_avp(258, 40, format(int(16777236),"x").zfill(8)) + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + subscriptionId = bytes.fromhex(self.get_avp_data(avps, 444)[0]).decode('ascii') + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Received subscription ID: {subscriptionId}", redisClient=self.redisMessaging) + subscriptionId = subscriptionId.replace('sip:', '') + imsi = None + msisdn = None + identifier = None + if '@' in subscriptionId: + subscriberIdentifier = subscriptionId.split('@')[0] + # Subscriber Identifier can be either an IMSI or an MSISDN + try: + subscriberDetails = self.database.Get_Subscriber(imsi=subscriberIdentifier) + imsSubscriberDetails = self.database.Get_IMS_Subscriber(imsi=subscriberIdentifier) + identifier = 'imsi' + imsi = imsSubscriberDetails.get('imsi', None) + except Exception as e: + pass + try: + subscriberDetails = self.database.Get_Subscriber(msisdn=subscriberIdentifier) + imsSubscriberDetails = self.database.Get_IMS_Subscriber(msisdn=subscriberIdentifier) + identifier = 'msisdn' + msisdn = imsSubscriberDetails.get('msisdn', None) + except Exception as e: + pass + else: + imsi = None + msisdn = None + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] IMSI: {imsi}\nMSISDN: {msisdn}", redisClient=self.redisMessaging) + imsEnabled = self.validateImsSubscriber(imsi=imsi, msisdn=msisdn) + + if imsEnabled: + """ + Add the PCSCF to the IMS_Subscriber object, and set the result code to 2001. + """ + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Request authorized", redisClient=self.redisMessaging) + + if imsi is None: + imsi = subscriberDetails.get('imsi', None) + + aarOriginHost = self.get_avp_data(avps, 264)[0] + aarOriginHost = bytes.fromhex(aarOriginHost).decode('ascii') + aarOriginRealm = self.get_avp_data(avps, 296)[0] + aarOriginRealm = bytes.fromhex(aarOriginRealm).decode('ascii') + #Check if we have a record-route set as that's where we'll need to send the response + try: + #Get first record-route header, then parse it + remotePeer = self.get_avp_data(avps, 282)[-1] + remotePeer = binascii.unhexlify(remotePeer).decode('utf-8') + except Exception as e: + #If we don't have a record-route set, we'll send the response to the OriginHost + remotePeer = aarOriginHost + + remotePeer = f"{remotePeer};{self.config['hss']['OriginHost']}" + + self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer, pcscf_active_session=None) + """ + Check for AVP's 504 (AF-Application-Identifier) and 520 (Media-Type), which indicates the UE is making a call. + Media-Type: 0 = Audio, 4 = Control + """ + try: + afApplicationIdentifier = self.get_avp_data(avps, 504)[0] + mediaType = self.get_avp_data(avps, 520)[0] + assert(bytes.fromhex(afApplicationIdentifier).decode('ascii') == "IMS Services") + assert(int(mediaType, 16) == 0) + + # At this point, we know the AAR is indicating a call setup, so we'll send get the serving pgw information, then send a + # RAR to the PGW over Gx, asking it to setup the dedicated bearer. 
+
+                    try:
+                        subscriberId = subscriberDetails.get('subscriber_id', None)
+                        apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None)
+                        servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId)
+                        servingPgwPeer = servingApn.get('serving_pgw_peer', None).split(';')[0]
+                        servingPgw = servingApn.get('serving_pgw', None)
+                        servingPgwRealm = servingApn.get('serving_pgw_realm', None)
+                        pcrfSessionId = servingApn.get('pcrf_session_id', None)
+                        ueIp = servingApn.get('subscriber_routing', None)
+
+                        """
+                        The below charging rule needs to be replaced by the following logic (see the sketch after this function):
+                        1. Grab the Flow Rules and bitrates from the PCSCF in the AAR,
+                        2. Compare them to a given backup rule
+                        - If the flowrates are greater than the backup rule (UE is asking for more than allowed), use the backup rule
+                        - If the flowrates are less than the backup rule, use the requested flowrates. This will allow for better utilization of radio resources.
+                        3. Maybe something to do with the TFTs
+                        4. Send the winning rule.
+                        """
+
+                        chargingRule = {
+                            "charging_rule_id": 1000,
+                            "qci": 1,
+                            "arp_preemption_capability": True,
+                            "mbr_dl": 128000,
+                            "mbr_ul": 128000,
+                            "gbr_ul": 128000,
+                            "precedence": 100,
+                            "arp_priority": 2,
+                            "rule_name": "GBR-Voice",
+                            "arp_preemption_vulnerability": False,
+                            "gbr_dl": 128000,
+                            "tft_group_id": 1,
+                            "rating_group": None,
+                            "tft": [
+                                {
+                                    "tft_group_id": 1,
+                                    "direction": 1,
+                                    "tft_id": 1,
+                                    "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535"
+                                },
+                                {
+                                    "tft_group_id": 1,
+                                    "direction": 2,
+                                    "tft_id": 2,
+                                    "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535"
+                                }
+                            ]
+                        }
+
+                        self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer, pcscf_active_session=sessionId)
+
+                        reAuthAnswer = self.awaitDiameterRequestAndResponse(
+                            requestType='RAR',
+                            hostname=servingPgwPeer,
+                            sessionId=pcrfSessionId,
+                            chargingRules=chargingRule,
+                            ueIp=ueIp,
+                            servingPgw=servingPgw,
+                            servingRealm=servingPgwRealm
+                        )
+
+                        if not len(reAuthAnswer) > 0:
+                            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] RAA Timeout: {reAuthAnswer}", redisClient=self.redisMessaging)
+                            assert()
+
+                        raaPacketVars, raaAvps = self.decode_diameter_packet(reAuthAnswer)
+                        raaResultCode = int(self.get_avp_data(raaAvps, 268)[0], 16)
+
+                        if raaResultCode == 2001:
+                            avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+                            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] RAA returned Successfully, authorizing request", redisClient=self.redisMessaging)
+                        else:
+                            avp += self.generate_avp(268, 40, self.int_to_hex(4001, 4))
+                            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] RAA returned Unauthorized, declining request", redisClient=self.redisMessaging)
+
+                    except Exception as e:
+                        self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Error processing RAR / RAA, Authorizing request: {traceback.format_exc()}", redisClient=self.redisMessaging)
+                        avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+
+                except Exception as e:
+                    avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+                    pass
+            else:
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Request unauthorized", redisClient=self.redisMessaging)
+                avp += self.generate_avp(268, 40, self.int_to_hex(4001, 4))
+
response = self.generate_diameter_packet("01", "40", 265, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + return response + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_265] [AAA] Error generating AAA: {traceback.format_exc()}", redisClient=self.redisMessaging) + avp = '' + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID + avp += self.generate_avp(258, 40, format(int(16777236),"x").zfill(8)) + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + avp += self.generate_avp(268, 40, self.int_to_hex(5012, 4)) #Result Code 5012 UNABLE_TO_COMPLY + response = self.generate_diameter_packet("01", "40", 265, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + return response + + #3GPP Rx - Re Auth Answer (RAA) + def Answer_16777236_258(self, packet_vars, avps): + try: + """ + Generates a response to a provided RAR. + The response is determined by whether or not the subscriber is enabled, and has a matching ims_subscriber entry. + """ + avp = '' + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID + avp += self.generate_avp(258, 40, format(int(16777236),"x").zfill(8)) + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + subscriptionId = bytes.fromhex(self.get_avp_data(avps, 444)[0]).decode('ascii') + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] Received subscription ID: {subscriptionId}", redisClient=self.redisMessaging) + subscriptionId = subscriptionId.replace('sip:', '') + imsi = None + msisdn = None + identifier = None + if '@' in subscriptionId: + subscriberIdentifier = subscriptionId.split('@')[0] + # Subscriber Identifier can be either an IMSI or an MSISDN + try: + subscriberDetails = self.database.Get_Subscriber(imsi=subscriberIdentifier) + imsSubscriberDetails = self.database.Get_IMS_Subscriber(imsi=subscriberIdentifier) + identifier = 'imsi' + imsi = imsSubscriberDetails.get('imsi', None) + except Exception as e: + pass + try: + subscriberDetails = self.database.Get_Subscriber(msisdn=subscriberIdentifier) + imsSubscriberDetails = self.database.Get_IMS_Subscriber(msisdn=subscriberIdentifier) + identifier = 'msisdn' + msisdn = imsSubscriberDetails.get('msisdn', None) + except Exception as e: + pass + else: + imsi = None + msisdn = None + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] IMSI: {imsi}\nMSISDN: {msisdn}", redisClient=self.redisMessaging) + imsEnabled = self.validateImsSubscriber(imsi=imsi, msisdn=msisdn) + + if imsEnabled: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] Request authorized", redisClient=self.redisMessaging) + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) + else: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] Request unauthorized", redisClient=self.redisMessaging) + avp += self.generate_avp(268, 40, self.int_to_hex(4001, 4)) + + response = self.generate_diameter_packet("01", "40", 258, 
+    #3GPP Rx - Re-Auth Answer (RAA)
+    def Answer_16777236_258(self, packet_vars, avps):
+        try:
+            """
+            Generates a response to a provided RAR.
+            The response is determined by whether or not the subscriber is enabled, and has a matching ims_subscriber entry.
+            """
+            avp = ''
+            session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+            avp += self.generate_avp(258, 40, format(int(16777236),"x").zfill(8))
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            subscriptionId = bytes.fromhex(self.get_avp_data(avps, 444)[0]).decode('ascii')
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] Received subscription ID: {subscriptionId}", redisClient=self.redisMessaging)
+            subscriptionId = subscriptionId.replace('sip:', '')
+            imsi = None
+            msisdn = None
+            identifier = None
+            if '@' in subscriptionId:
+                subscriberIdentifier = subscriptionId.split('@')[0]
+                # Subscriber Identifier can be either an IMSI or an MSISDN
+                try:
+                    subscriberDetails = self.database.Get_Subscriber(imsi=subscriberIdentifier)
+                    imsSubscriberDetails = self.database.Get_IMS_Subscriber(imsi=subscriberIdentifier)
+                    identifier = 'imsi'
+                    imsi = imsSubscriberDetails.get('imsi', None)
+                except Exception as e:
+                    pass
+                try:
+                    subscriberDetails = self.database.Get_Subscriber(msisdn=subscriberIdentifier)
+                    imsSubscriberDetails = self.database.Get_IMS_Subscriber(msisdn=subscriberIdentifier)
+                    identifier = 'msisdn'
+                    msisdn = imsSubscriberDetails.get('msisdn', None)
+                except Exception as e:
+                    pass
+            else:
+                imsi = None
+                msisdn = None
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] IMSI: {imsi}\nMSISDN: {msisdn}", redisClient=self.redisMessaging)
+            imsEnabled = self.validateImsSubscriber(imsi=imsi, msisdn=msisdn)
+
+            if imsEnabled:
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] Request authorized", redisClient=self.redisMessaging)
+                avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+            else:
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] Request unauthorized", redisClient=self.redisMessaging)
+                avp += self.generate_avp(268, 40, self.int_to_hex(4001, 4))
+
+            response = self.generate_diameter_packet("01", "40", 258, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_258] [RAA] Error generating RAA: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            avp = ''
+            session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+            avp += self.generate_avp(258, 40, format(int(16777236),"x").zfill(8))
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            avp += self.generate_avp(268, 40, self.int_to_hex(5012, 4)) #Result Code 5012 UNABLE_TO_COMPLY
+            response = self.generate_diameter_packet("01", "40", 258, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+
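+    # --- Illustrative sketch (not part of this PR) ---
+    # Answer_16777236_265() and Answer_16777236_258() both strip a SIP URI
+    # ("sip:<imsi-or-msisdn>@<realm>") down to an identifier and then probe the
+    # database both ways. A minimal sketch of that shared step, with a
+    # hypothetical helper name:
+    def resolveRxIdentifier_sketch(self, subscriptionId):
+        identifier = subscriptionId.replace('sip:', '').split('@')[0]
+        try:
+            return 'imsi', self.database.Get_IMS_Subscriber(imsi=identifier).get('imsi', None)
+        except Exception:
+            pass
+        try:
+            return 'msisdn', self.database.Get_IMS_Subscriber(msisdn=identifier).get('msisdn', None)
+        except Exception:
+            return None, None
+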
+    #3GPP Rx - Session Termination Answer (STA)
+    def Answer_16777236_275(self, packet_vars, avps):
+        try:
+            """
+            Triggers a Re-Auth-Request to the PGW, then returns a Session Termination Answer.
+            """
+            avp = ''
+            sessionId = bytes.fromhex(self.get_avp_data(avps, 263)[0]).decode('ascii') #Get Session-ID
+            avp += self.generate_avp(263, 40, self.string_to_hex(sessionId)) #Set session ID to received session ID
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            imsSubscriber = self.database.Get_IMS_Subscriber_By_Session_Id(sessionId=sessionId)
+            imsi = imsSubscriber.get('imsi', None)
+            pcscf = imsSubscriber.get('pcscf', None)
+            pcscf_realm = imsSubscriber.get('pcscf_realm', None)
+            pcscf_peer = imsSubscriber.get('pcscf_peer', None)
+            subscriber = self.database.Get_Subscriber(imsi=imsi)
+            subscriberId = subscriber.get('subscriber_id', None)
+            apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None)
+            servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId)
+            self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=pcscf, pcscf_realm=pcscf_realm, pcscf_peer=pcscf_peer, pcscf_active_session=None)
+
+            if servingApn is not None:
+                servingPgw = servingApn.get('serving_pgw', '')
+                servingPgwRealm = servingApn.get('serving_pgw_realm', '')
+                servingPgwPeer = servingApn.get('serving_pgw_peer', '').split(';')[0]
+                pcrfSessionId = servingApn.get('pcrf_session_id', None)
+                reAuthAnswer = self.awaitDiameterRequestAndResponse(
+                    requestType='RAR',
+                    hostname=servingPgwPeer,
+                    sessionId=pcrfSessionId,
+                    servingPgw=servingPgw,
+                    servingRealm=servingPgwRealm,
+                    chargingRuleName='GBR-Voice',
+                    chargingRuleAction='remove'
+                )
+
+                if not len(reAuthAnswer) > 0:
+                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] RAA Timeout: {reAuthAnswer}", redisClient=self.redisMessaging)
+                    assert()
+
+                raaPacketVars, raaAvps = self.decode_diameter_packet(reAuthAnswer)
+                raaResultCode = int(self.get_avp_data(raaAvps, 268)[0], 16)
+
+                if raaResultCode == 2001:
+                    avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] RAA returned Successfully, authorizing request", redisClient=self.redisMessaging)
+                else:
+                    avp += self.generate_avp(268, 40, self.int_to_hex(5001, 4))
+                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] RAA returned Unauthorized, returning Result-Code 5001", redisClient=self.redisMessaging)
+
+            else:
+                self.logTool.log(service='HSS', level='info', message=f"[diameter.py] [Answer_16777236_275] [STA] Unable to find serving APN for RAR, returning Result-Code 2001", redisClient=self.redisMessaging)
+
+            avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+            response = self.generate_diameter_packet("01", "40", 275, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+        except Exception as e:
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] Error generating STA, returning 2001", redisClient=self.redisMessaging)
+            avp = ''
+            sessionId = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            avp += self.generate_avp(263, 40, sessionId) #Set session ID to received session ID
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+            response = self.generate_diameter_packet("01", "40", 275, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+
+    #3GPP Rx - Abort Session Answer (ASA)
+    def Answer_16777236_274(self, packet_vars, avps):
+        try:
+            """
+            Generates a response to a provided ASR.
+            Returns Result-Code 2001.
+            """
+            avp = ''
+            session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+            response = self.generate_diameter_packet("01", "40", 274, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_274] [ASA] Error generating ASA: {traceback.format_exc()}", redisClient=self.redisMessaging)
+
+
+    #Gx Re-Auth Answer (RAA)
+    def Answer_16777238_258(self, packet_vars, avps):
+        try:
+            avp = ''
+            session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+            response = self.generate_diameter_packet("01", "40", 258, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777238_258] [RAA] Error generating RAA: {traceback.format_exc()}", redisClient=self.redisMessaging)
+
+    #3GPP S13 - ME-Identity-Check Answer
+    def Answer_16777252_324(self, packet_vars, avps):
+
+        #Get IMSI
+        try:
+            imei = ''
+            imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request
+            imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI
+            #avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
+            self.logTool.log(service='HSS', level='debug', message="Got IMSI with value " + str(imsi), redisClient=self.redisMessaging)
+        except Exception as e:
+            self.logTool.log(service='HSS', level='debug', message="Failed to get IMSI from ME-Identity-Check-Request", redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='debug', message="Error was: " + str(e), redisClient=self.redisMessaging)
+
+        try:
+            #Get IMEI
+            for sub_avp in self.get_avp_data(avps, 1401)[0]:
+                self.logTool.log(service='HSS', level='debug', message="Evaluating sub_avp AVP " + str(sub_avp) + " to find IMEI", redisClient=self.redisMessaging)
+                if sub_avp['avp_code'] == 1402:
+                    imei = binascii.unhexlify(sub_avp['misc_data']).decode('utf-8')
+                    self.logTool.log(service='HSS', level='debug', message="Found IMEI " + str(imei), redisClient=self.redisMessaging)
+
+            avp = '' #Initiate empty var AVP
+            session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+            avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000024") #Vendor-Specific-Application-ID for S13
+            avp += self.generate_avp(277, 40, "00000001") #Auth Session State
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            #Experimental Result AVP(Response Code for Failure)
+            avp_experimental_result = ''
+            avp_experimental_result += self.generate_vendor_avp(266, 'c0', 10415, '') #AVP Vendor ID
+            avp_experimental_result += self.generate_avp(298, 'c0', self.int_to_hex(2001, 4)) #AVP Experimental-Result-Code: SUCCESS (2001)
+            avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
+
+            #Equipment-Status
+            EquipmentStatus = self.database.Check_EIR(imsi=imsi, imei=imei)
+            avp += self.generate_vendor_avp(1445, 'c0', 10415, self.int_to_hex(EquipmentStatus, 4))
+            self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_eir_event_count',
+                                           metricType='counter', metricAction='inc',
+                                           metricValue=1.0,
+                                           metricLabels={
+                                               "response": EquipmentStatus},
+                                           metricHelp='Diameter EIR event related Counters',
+                                           metricExpiry=60)
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=traceback.format_exc(), redisClient=self.redisMessaging)
+
+
+        response = self.generate_diameter_packet("01", "40", 324, 16777252, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        return response
+
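+    # --- Illustrative sketch (not part of this PR) ---
+    # Answer_16777252_324() above maps self.database.Check_EIR() onto the
+    # Equipment-Status AVP (1445). Check_EIR() is not shown in this hunk; a
+    # rough sketch of a prefix-match rule table, assuming Equipment-Status
+    # values of 0 = WHITELISTED, 1 = BLACKLISTED, 2 = GREYLISTED and a
+    # hypothetical rule shape:
+    @staticmethod
+    def checkEir_sketch(imei, eirRules):
+        for rule in eirRules:
+            #Each hypothetical rule: {'imei': '35123456', 'match_response_code': 1}
+            if imei.startswith(rule['imei']):
+                return rule['match_response_code']
+        return 0 #Default to WHITELISTED when no rule matches
+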
+    #3GPP SLh - LCS-Routing-Info-Answer
+    def Answer_16777291_8388622(self, packet_vars, avps):
+        avp = ''
+        session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+        avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+        #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
+        VendorSpecificApplicationId = ''
+        VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+        VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777291),"x").zfill(8)) #Auth-Application-ID SLh
+        avp += self.generate_avp(260, 40, VendorSpecificApplicationId)
+        avp += self.generate_avp(277, 40, "00000001") #Auth Session State (NO_STATE_MAINTAINED)
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+
+        #Create list of valid AVPs
+        present_avps = []
+        for avp_id in avps:
+            present_avps.append(avp_id['avp_code'])
+
+        #Define values so we can check if they've been changed
+        msisdn = None
+        imsi = None
+
+        #Try and get IMSI if present
+        if 1 in present_avps:
+            self.logTool.log(service='HSS', level='debug', message="IMSI AVP is present", redisClient=self.redisMessaging)
+            try:
+                imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request
+                imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI
+                avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
+                self.logTool.log(service='HSS', level='debug', message="Got IMSI with value " + str(imsi), redisClient=self.redisMessaging)
+            except Exception as e:
+                self.logTool.log(service='HSS', level='debug', message="Failed to get IMSI from LCS-Routing-Info-Request", redisClient=self.redisMessaging)
+                self.logTool.log(service='HSS', level='debug', message="Error was: " + str(e), redisClient=self.redisMessaging)
+        elif 701 in present_avps:
+            #Try and get MSISDN if present
+            try:
+                msisdn = self.get_avp_data(avps, 701)[0] #Get MSISDN from AVP in request
+                self.logTool.log(service='HSS', level='debug', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging)
+                avp += self.generate_vendor_avp(701, 'c0', 10415, self.get_avp_data(avps, 701)[0]) #MSISDN
+                self.logTool.log(service='HSS', level='debug', message="Got MSISDN with encoded value " + str(msisdn), redisClient=self.redisMessaging)
+                msisdn = self.TBCD_decode(msisdn)
+                self.logTool.log(service='HSS', level='debug', message="Got MSISDN with decoded value " + str(msisdn), redisClient=self.redisMessaging)
+            except Exception as e:
+                self.logTool.log(service='HSS', level='debug', message="Failed to get MSISDN from LCS-Routing-Info-Request", redisClient=self.redisMessaging)
+                self.logTool.log(service='HSS', level='debug', message="Error was: " + str(e), redisClient=self.redisMessaging)
+        else:
+            self.logTool.log(service='HSS', level='error', message="No MSISDN or IMSI", redisClient=self.redisMessaging)
+
+        try:
+            if imsi is not None:
+                self.logTool.log(service='HSS', level='debug', message="Getting subscriber location based on IMSI", redisClient=self.redisMessaging)
+                subscriber_details = self.database.Get_Subscriber(imsi=imsi)
+                self.logTool.log(service='HSS', level='debug', message="Got subscriber_details from IMSI: " + str(subscriber_details), redisClient=self.redisMessaging)
+            elif msisdn is not None:
+                self.logTool.log(service='HSS', level='debug', message="Getting subscriber location based on MSISDN", redisClient=self.redisMessaging)
+                subscriber_details = self.database.Get_Subscriber(msisdn=msisdn)
+                self.logTool.log(service='HSS', level='debug', message="Got subscriber_details from MSISDN: " + str(subscriber_details), redisClient=self.redisMessaging)
+        except Exception as E:
+            self.logTool.log(service='HSS', level='info', message="No MSISDN or IMSI returned in Answer_16777291_8388622 input", redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='info', message="Error is " + str(E), redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='info', message="Responding with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging)
+            avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4))
+            response = self.generate_diameter_packet("01", "40", 8388622, 16777291, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            self.logTool.log(service='HSS', level='info', message="Diameter user unknown - Sending LRA with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging)
+            return response
+
+        self.logTool.log(service='HSS', level='debug', message="Got subscriber_details for subscriber: " + str(subscriber_details), redisClient=self.redisMessaging)
+
+        if subscriber_details['serving_mme'] is None:
+            #DB has no location on record for subscriber
+            self.logTool.log(service='HSS', level='debug', message="No location on record for Subscriber", redisClient=self.redisMessaging)
+            result_code = 4201
+            #DIAMETER_ERROR_ABSENT_USER (4201)
+            #This result code shall be sent by the HSS to indicate that the location of the targeted user is not known at this time to
+            #satisfy the requested operation.
+
+            avp_experimental_result = ''
+            avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+            avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4)) #AVP Experimental-Result-Code
+            avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297)
+
+            response = self.generate_diameter_packet("01", "40", 8388622, 16777291, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+
+        #Serving Node AVP
+        avp_serving_node = ''
+        avp_serving_node += self.generate_vendor_avp(2402, "c0", 10415, self.string_to_hex(subscriber_details['serving_mme'])) #MME-Name
+        avp_serving_node += self.generate_vendor_avp(2408, "c0", 10415, self.OriginRealm) #MME-Realm
+        avp_serving_node += self.generate_vendor_avp(2405, "c0", 10415, self.ip_to_hex(self.config['hss']['bind_ip'][0])) #GMLC-Address
+        avp += self.generate_vendor_avp(2401, "c0", 10415, avp_serving_node) #Serving-Node AVP
+
+        #Set Result-Code
+        result_code = 2001 #Diameter Success
+        avp += self.generate_avp(268, 40, self.int_to_hex(result_code, 4)) #Result Code - DIAMETER_SUCCESS
+
+        response = self.generate_diameter_packet("01", "40", 8388622, 16777291, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        return response
+
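+    # --- Illustrative sketch (not part of this PR) ---
+    # TBCD_decode(), used above for the MSISDN, reverses 3GPP "swapped nibble"
+    # TBCD encoding. A minimal standalone sketch, assuming an even-length hex
+    # string input with 'f' as the filler nibble:
+    @staticmethod
+    def tbcdDecode_sketch(tbcd_hex):
+        digits = ''
+        for i in range(0, len(tbcd_hex) - 1, 2):
+            digits += tbcd_hex[i + 1] + tbcd_hex[i] #Each octet carries two digits, swapped
+        return digits.rstrip('fF')
+    #Example: tbcdDecode_sketch("214365") returns "123456".
+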
+    #### Diameter Requests ####
+
+    #Capabilities Exchange Request
+    def Request_257(self):
+        avp = ''
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(257, 40, self.ip_to_hex(socket.gethostbyname(socket.gethostname()))) #Host-IP-Address (For this to work on Linux this is the IP defined in the hosts file for localhost)
+        avp += self.generate_avp(266, 40, "00000000") #Vendor-Id
+        avp += self.generate_avp(269, "00", self.ProductName) #Product-Name
+        avp += self.generate_avp(260, 40, "000001024000000c01000023" + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a)
+        avp += self.generate_avp(260, 40, "000001024000000c01000016" + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Gx)
+        avp += self.generate_avp(260, 40, "000001024000000c01000027" + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (SLg)
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777217),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Sh)
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777216),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Cx)
+        avp += self.generate_avp(258, 40, format(int(4294967295),"x").zfill(8)) #Auth-Application-ID Relay
+        avp += self.generate_avp(265, 40, format(int(5535),"x").zfill(8)) #Supported-Vendor-ID (3GPP2)
+        avp += self.generate_avp(265, 40, format(int(10415),"x").zfill(8)) #Supported-Vendor-ID (3GPP)
+        avp += self.generate_avp(265, 40, format(int(13019),"x").zfill(8)) #Supported-Vendor-ID 13019 (ETSI)
+        response = self.generate_diameter_packet("01", "80", 257, 0, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
+        return response
+
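+    # --- Illustrative note (not part of this PR) ---
+    # The hardcoded Vendor-Specific-Application-Id payloads above are two
+    # pre-encoded sub-AVPs: Auth-Application-Id (code 258, 0x102) carrying the
+    # application, and Vendor-Id (code 266, 0x10a) carrying 10415 (0x28af, 3GPP).
+    # A sketch that unpacks one of them by fixed offsets:
+    @staticmethod
+    def decodeVsai_sketch(vsai_hex="000001024000000c01000023" + "0000010a4000000c000028af"):
+        authApplicationId = int(vsai_hex[16:24], 16) #0x01000023 = 16777251 (S6a)
+        vendorId = int(vsai_hex[40:48], 16)          #0x28af = 10415 (3GPP)
+        return authApplicationId, vendorId
+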
+    #Device Watchdog Request
+    def Request_280(self):
+        avp = ''
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        response = self.generate_diameter_packet("01", "80", 280, 0, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
+        return response
+
+    #Disconnect Peer Request
+    def Request_282(self):
+        avp = '' #Initiate empty var AVP
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(273, 40, "00000000") #Disconnect-Cause (REBOOTING (0))
+        response = self.generate_diameter_packet("01", "80", 282, 0, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
+        return response
+
+    #3GPP S6a/S6d Authentication Information Request
+    def Request_16777251_318(self, imsi, DestinationHost, DestinationRealm, requested_vectors=1):
+        avp = '' #Initiate empty var AVP #Session-ID
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm
+        #avp += self.generate_avp(293, 40, self.string_to_hex(DestinationHost)) #Destination Host
+        avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
+        number_of_requested_vectors = self.generate_vendor_avp(1410, "c0", 10415, format(int(requested_vectors),"x").zfill(8))
+        immediate_response_preferred = self.generate_vendor_avp(1412, "c0", 10415, format(int(1),"x").zfill(8))
+        avp += self.generate_vendor_avp(1408, "c0", 10415, str(number_of_requested_vectors) + str(immediate_response_preferred))
+
+        mcc = str(imsi)[:3]
+        mnc = str(imsi)[3:5]
+        avp += self.generate_vendor_avp(1407, "c0", 10415, self.EncodePLMN(mcc, mnc)) #Visited-PLMN-Id(1407) (Derived from start of IMSI)
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") #Vendor-Specific-Application-ID
+        response = self.generate_diameter_packet("01", "c0", 318, 16777251, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
+        return response
+
+    #3GPP S6a/S6d Update Location Request (ULR)
+    def Request_16777251_316(self, imsi, DestinationRealm):
+        mcc = imsi[0:3]
+        mnc = imsi[3:5]
+        avp = '' #Initiate empty var AVP #Session-ID
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
+        avp += self.generate_avp(264, 40, str(binascii.hexlify(str.encode("testclient." + self.config['hss']['OriginHost'])),'ascii'))
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm
+        avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
+        avp += self.generate_vendor_avp(1032, "80", 10415, self.int_to_hex(1004, 4)) #RAT-Type val=EUTRAN (1004)
+        avp += self.generate_vendor_avp(1405, "c0", 10415, "00000002") #ULR-Flags val=2
+        avp += self.generate_vendor_avp(1407, "c0", 10415, self.EncodePLMN(mcc, mnc)) #Visited-PLMN-Id(1407) (Derived from start of IMSI)
+        avp += self.generate_vendor_avp(1615, "80", 10415, "00000000") #E-SRVCC-Capability val=UE-SRVCC-NOT-SUPPORTED (0)
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") #Vendor-Specific-Application-ID
+        response = self.generate_diameter_packet("01", "c0", 316, 16777251, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
+        return response
+
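+    # --- Illustrative sketch (not part of this PR) ---
+    # The AIR and ULR above derive the Visited-PLMN-Id from the IMSI via
+    # EncodePLMN(). A minimal standalone sketch of the standard nibble-swapped
+    # PLMN encoding (not necessarily the project's implementation):
+    @staticmethod
+    def encodePlmn_sketch(mcc, mnc):
+        #Two-digit MNCs get a 0xF filler nibble
+        mnc = mnc if len(mnc) == 3 else 'f' + mnc
+        return mcc[1] + mcc[0] + mnc[0] + mcc[2] + mnc[2] + mnc[1]
+    #Example: encodePlmn_sketch("001", "01") returns "00f110".
+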
+    #3GPP S6a/S6d Purge UE Request (PUR)
+    def Request_16777251_321(self, imsi, DestinationRealm, DestinationHost):
+        avp = ''
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm
+        #avp += self.generate_avp(293, 40, self.string_to_hex(DestinationHost)) #Destination Host
+        avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") #Vendor-Specific-Application-ID
+        response = self.generate_diameter_packet("01", "c0", 321, 16777251, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
+        return response
+
+    #3GPP S6a/S6d Notify Request (NOR)
+    def Request_16777251_323(self, imsi, DestinationRealm, DestinationHost):
+        avp = ''
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm
+        #avp += self.generate_avp(293, 40, self.string_to_hex(DestinationHost)) #Destination Host
+        avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") #Vendor-Specific-Application-ID
+        response = self.generate_diameter_packet("01", "c0", 323, 16777251, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
+        return response
+
+    #3GPP S6a/S6d Cancel-Location Request (CLR)
+    def Request_16777251_317(self, imsi, DestinationRealm, DestinationHost=None, CancellationType=2):
+        avp = ''
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm
+        if DestinationHost is not None:
+            avp += self.generate_avp(293, 40, self.string_to_hex(DestinationHost)) #Destination Host
+        avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") #Vendor-Specific-Application-ID
+        avp += self.generate_vendor_avp(1420, "c0", 10415, self.int_to_hex(CancellationType, 4)) #Cancellation-Type (Subscription Withdrawal)
+        response = self.generate_diameter_packet("01", "c0", 317, 16777251, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
+        return response
+
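+    # --- Illustrative note (not part of this PR) ---
+    # Cancellation-Type (AVP 1420) values per 3GPP TS 29.272, for the
+    # CancellationType parameter of Request_16777251_317() above:
+    #   0 = MME_UPDATE_PROCEDURE
+    #   1 = SGSN_UPDATE_PROCEDURE
+    #   2 = SUBSCRIPTION_WITHDRAWAL (the default used here)
+    #   3 = UPDATE_PROCEDURE_IWF
+    #   4 = INITIAL_ATTACH_PROCEDURE
+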
+    #3GPP S6a/S6d Insert Subscriber Data Request (ISD)
+    def Request_16777251_319(self, packet_vars, avps, **kwargs):
+        avp = '' #Initiate empty var AVP
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a' #Session ID generate
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session ID set AVP
+        avp += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+        #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
+        VendorSpecificApplicationId = ''
+        VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
+
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a)
+
+        #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP
+        SupportedFeatures = ''
+        SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+        SupportedFeatures += self.generate_vendor_avp(629, 80, 10415, self.int_to_hex(1, 4)) #Feature-List ID
+        SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "1c000607") #Feature-List Flags
+        if 'GetLocation' in kwargs:
+            self.logTool.log(service='HSS', level='debug', message="Requested Get Location ISD", redisClient=self.redisMessaging)
+            #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP
+            SupportedFeatures = ''
+            SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID
+            SupportedFeatures += self.generate_vendor_avp(629, 80, 10415, self.int_to_hex(1, 4)) #Feature-List ID
+            SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "18000007") #Feature-List Flags
+            avp += self.generate_vendor_avp(1490, "c0", 10415, "00000018") #IDR-Flags
+            avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP
+
+            try:
+                user_identity_avp = self.get_avp_data(avps, 700)[0]
+                self.logTool.log(service='HSS', level='debug', message=user_identity_avp, redisClient=self.redisMessaging)
+                msisdn = self.get_avp_data(user_identity_avp, 701)[0] #Get MSISDN from AVP in request
+                msisdn = self.TBCD_decode(msisdn)
+                self.logTool.log(service='HSS', level='debug', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging)
+            except:
+                self.logTool.log(service='HSS', level='error', message="No MSISDN present", redisClient=self.redisMessaging)
+                return
+            #Get Subscriber Location from Database
+            subscriber_location = self.database.GetSubscriberLocation(msisdn=msisdn)
+            self.logTool.log(service='HSS', level='debug', message="Got subscriber location: " + subscriber_location, redisClient=self.redisMessaging)
+
+            self.logTool.log(service='HSS', level='debug', message="Getting IMSI for MSISDN " + str(msisdn), redisClient=self.redisMessaging)
+            imsi = self.database.Get_IMSI_from_MSISDN(msisdn)
+            avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
+
+            self.logTool.log(service='HSS', level='debug', message="Got back location data: " + str(subscriber_location), redisClient=self.redisMessaging)
+
+            #Populate Destination Host & Realm
+            avp += self.generate_avp(293, 40, self.string_to_hex(subscriber_location)) #Destination-Host
+            avp += self.generate_avp(283, 40, self.string_to_hex('epc.mnc001.mcc214.3gppnetwork.org')) #Destination Realm
+
+        else:
+            #APNs from DB
+            imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request
+            imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI
+            avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI)
+            avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP
+            avp += self.generate_vendor_avp(1490, "c0", 10415, "00000000") #IDR-Flags
+
+            destinationHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP
+            destinationHost = binascii.unhexlify(destinationHost).decode('utf-8') #Format it
+            self.logTool.log(service='HSS', level='debug', message="Received originHost to use as destinationHost is " + str(destinationHost), redisClient=self.redisMessaging)
+            destinationRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP
+            destinationRealm = binascii.unhexlify(destinationRealm).decode('utf-8') #Format it
+            self.logTool.log(service='HSS', level='debug', message="Received originRealm to use as destinationRealm is " + str(destinationRealm), redisClient=self.redisMessaging)
+            avp += self.generate_avp(293, 40, self.string_to_hex(destinationHost)) #Destination-Host
+            avp += self.generate_avp(283, 40, self.string_to_hex(destinationRealm)) #Destination Realm
+
+            APN_Configuration = ''
+
+            try:
+                subscriber_details = self.database.Get_Subscriber(imsi=imsi) #Get subscriber details
+            except ValueError as e:
+                self.logTool.log(service='HSS', level='error', message="Failed to get data back from database for imsi " + str(imsi), redisClient=self.redisMessaging)
+                self.logTool.log(service='HSS', level='error', message="Error is " + str(e), redisClient=self.redisMessaging)
+                raise
+            except Exception as ex:
+                template = "An exception of type {0} occurred. Arguments:\n{1!r}"
+                message = template.format(type(ex).__name__, ex.args)
+                raise
+
+            #Subscription Data:
+            subscription_data = ''
+            subscription_data += self.generate_vendor_avp(1426, "c0", 10415, "00000000") #Access Restriction Data
+            subscription_data += self.generate_vendor_avp(1424, "c0", 10415, "00000000") #Subscriber-Status (SERVICE_GRANTED)
+            subscription_data += self.generate_vendor_avp(1417, "c0", 10415, "00000000") #Network-Access-Mode (PACKET_AND_CIRCUIT)
+
+            #AMBR is a sub-AVP of Subscription Data
+            AMBR = '' #Initiate empty var AVP for AMBR
+            if 'ue_ambr_ul' in subscriber_details:
+                ue_ambr_ul = int(subscriber_details['ue_ambr_ul'])
+            else:
+                #Use default AMBR of unlimited if no value in subscriber_details
+                ue_ambr_ul = 1048576000
+
+            if 'ue_ambr_dl' in subscriber_details:
+                ue_ambr_dl = int(subscriber_details['ue_ambr_dl'])
+            else:
+                #Use default AMBR of unlimited if no value in subscriber_details
+                ue_ambr_dl = 1048576000
+
+            AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(ue_ambr_ul, 4)) #Max-Requested-Bandwidth-UL
+            AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(ue_ambr_dl, 4)) #Max-Requested-Bandwidth-DL
+            subscription_data += self.generate_vendor_avp(1435, "c0", 10415, AMBR) #Add AMBR AVP in two sub-AVPs
+
+            #APN Configuration Profile is a sub-AVP of Subscription Data
+            APN_Configuration_Profile = ''
+            APN_Configuration_Profile += self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(1, 4)) #Context Identifier
+            APN_Configuration_Profile += self.generate_vendor_avp(1428, "c0", 10415, self.int_to_hex(0, 4)) #All-APN-Configurations-Included-Indicator
+
+            apn_list = subscriber_details['pdn']
+            self.logTool.log(service='HSS', level='debug', message="APN list: " + str(apn_list), redisClient=self.redisMessaging)
+            APN_context_identifer_count = 1
+            for apn_profile in apn_list:
+                self.logTool.log(service='HSS', level='debug', message="Processing APN profile " + str(apn_profile), redisClient=self.redisMessaging)
+                APN_Service_Selection = self.generate_avp(493, "40", self.string_to_hex(str(apn_profile['apn'])))
+
+                self.logTool.log(service='HSS', level='debug', message="Setting APN Configuration Profile", redisClient=self.redisMessaging)
+                #Sub-AVPs of APN Configuration Profile
+                APN_context_identifer = self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(APN_context_identifer_count, 4))
+                APN_PDN_type = self.generate_vendor_avp(1456, "c0", 10415, self.int_to_hex(0, 4))
+
+                self.logTool.log(service='HSS', level='debug', message="Setting APN AMBR", redisClient=self.redisMessaging)
+                #AMBR
+                AMBR = '' #Initiate empty var AVP for AMBR
+                if 'AMBR' in apn_profile:
+                    ue_ambr_ul = int(apn_profile['AMBR']['apn_ambr_ul'])
+                    ue_ambr_dl = int(apn_profile['AMBR']['apn_ambr_dl'])
+                else:
+                    #Use default AMBR if no value in the APN profile
+                    ue_ambr_ul = 50000000
+                    ue_ambr_dl = 100000000
+
+                AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(ue_ambr_ul, 4)) #Max-Requested-Bandwidth-UL
+                AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(ue_ambr_dl, 4)) #Max-Requested-Bandwidth-DL
+                APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR)
+
+                self.logTool.log(service='HSS', level='debug', message="Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging)
+                #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP
+                AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_profile['qos']['arp']['priority_level']), 4))
+                AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_profile['qos']['arp']['pre_emption_capability']), 4))
+                AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "c0", 10415, self.int_to_hex(int(apn_profile['qos']['arp']['pre_emption_vulnerability']), 4))
+                AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
+                AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_profile['qos']['qci']), 4))
+                APN_EPS_Subscribed_QoS_Profile = self.generate_vendor_avp(1431, "c0", 10415, AVP_QoS + AVP_ARP)
+
+                #If static UE IP is specified
+                try:
+                    apn_ip = apn_profile['ue']['addr']
+                    self.logTool.log(service='HSS', level='debug', message="Found static IP for UE " + str(apn_ip), redisClient=self.redisMessaging)
+                    Served_Party_Address = self.generate_vendor_avp(848, "c0", 10415, self.ip_to_hex(apn_ip))
+                except:
+                    Served_Party_Address = ""
+
+                if 'MIP6-Agent-Info' in apn_profile:
+                    self.logTool.log(service='HSS', level='debug', message="MIP6-Agent-Info present, value " + str(apn_profile['MIP6-Agent-Info']), redisClient=self.redisMessaging)
+                    MIP6_Destination_Host = self.generate_avp(293, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_HOST'])))
+                    MIP6_Destination_Realm = self.generate_avp(283, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_REALM'])))
+                    MIP6_Home_Agent_Host = self.generate_avp(348, '40', MIP6_Destination_Host + MIP6_Destination_Realm)
+                    MIP6_Agent_Info = self.generate_avp(486, '40', MIP6_Home_Agent_Host)
+                    self.logTool.log(service='HSS', level='debug', message="MIP6 value is " + str(MIP6_Agent_Info), redisClient=self.redisMessaging)
+                else:
+                    MIP6_Agent_Info = ''
+
+                if 'PDN_GW_Allocation_Type' in apn_profile:
+                    self.logTool.log(service='HSS', level='debug', message="PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type']), redisClient=self.redisMessaging)
+                    PDN_GW_Allocation_Type = self.generate_vendor_avp(1438, 'c0', 10415, self.int_to_hex(int(apn_profile['PDN_GW_Allocation_Type']), 4))
+                    self.logTool.log(service='HSS', level='debug', message="PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type), redisClient=self.redisMessaging)
+                else:
+                    PDN_GW_Allocation_Type = ''
+
+                if 'VPLMN_Dynamic_Address_Allowed' in apn_profile:
+                    self.logTool.log(service='HSS', level='debug', message="VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed']), redisClient=self.redisMessaging)
+                    VPLMN_Dynamic_Address_Allowed = self.generate_vendor_avp(1432, 'c0', 10415, self.int_to_hex(int(apn_profile['VPLMN_Dynamic_Address_Allowed']), 4))
+                    self.logTool.log(service='HSS', level='debug', message="VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed), redisClient=self.redisMessaging)
+                else:
+                    VPLMN_Dynamic_Address_Allowed = ''
+
+                APN_Configuration_AVPS = APN_context_identifer + APN_PDN_type + APN_AMBR + APN_Service_Selection \
+                    + APN_EPS_Subscribed_QoS_Profile + Served_Party_Address + MIP6_Agent_Info + PDN_GW_Allocation_Type + VPLMN_Dynamic_Address_Allowed
+
+                APN_Configuration += self.generate_vendor_avp(1430, "c0", 10415, APN_Configuration_AVPS)
+
+                #Increment Context Identifier Count to keep track of how many APN Profiles returned
+                APN_context_identifer_count = APN_context_identifer_count + 1
+                self.logTool.log(service='HSS', level='debug', message="Processed APN profile " + str(apn_profile['apn']), redisClient=self.redisMessaging)
+
+            subscription_data += self.generate_vendor_avp(1619, "80", 10415, self.int_to_hex(720, 4)) #Subscribed-Periodic-RAU-TAU-Timer (value 720)
+            subscription_data += self.generate_vendor_avp(1429, "c0", 10415, APN_context_identifer + \
+                self.generate_vendor_avp(1428, "c0", 10415, self.int_to_hex(0, 4)) + APN_Configuration)
+
+            #If MSISDN is present include it in Subscription Data
+            if 'msisdn' in subscriber_details:
+                self.logTool.log(service='HSS', level='debug', message="MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA", redisClient=self.redisMessaging)
+                msisdn_avp = self.generate_vendor_avp(701, 'c0', 10415, str(subscriber_details['msisdn'])) #MSISDN
+                self.logTool.log(service='HSS', level='debug', message=msisdn_avp, redisClient=self.redisMessaging)
+                subscription_data += msisdn_avp
+
+            if 'RAT_freq_priorityID' in subscriber_details:
+                self.logTool.log(service='HSS', level='debug', message="RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA", redisClient=self.redisMessaging)
+                rat_freq_priorityID = self.generate_vendor_avp(1440, "C0", 10415, self.int_to_hex(int(subscriber_details['RAT_freq_priorityID']), 4)) #RAT-Frequency-Selection-Priority ID
+                self.logTool.log(service='HSS', level='debug', message=rat_freq_priorityID, redisClient=self.redisMessaging)
+                subscription_data += rat_freq_priorityID
+
+            if '3gpp-charging-characteristics' in subscriber_details:
+                self.logTool.log(service='HSS', level='debug', message="3gpp-charging-characteristics " + str(subscriber_details['3gpp-charging-characteristics']) + " - Adding in ULA", redisClient=self.redisMessaging)
+                _3gpp_charging_characteristics = self.generate_vendor_avp(13, "80", 10415, self.string_to_hex(str(subscriber_details['3gpp-charging-characteristics'])))
+                subscription_data += _3gpp_charging_characteristics
+                self.logTool.log(service='HSS', level='debug', message=_3gpp_charging_characteristics, redisClient=self.redisMessaging)
+
+            if 'APN_OI_replacement' in subscriber_details:
+                self.logTool.log(service='HSS', level='debug', message="APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA", redisClient=self.redisMessaging)
+                subscription_data += self.generate_vendor_avp(1427, "C0", 10415, self.string_to_hex(str(subscriber_details['APN_OI_replacement'])))
+
+        if 'GetLocation' in kwargs:
+            avp += self.generate_vendor_avp(1400, "c0", 10415, "") #Subscription-Data
+        else:
+            avp += self.generate_vendor_avp(1400, "c0", 10415, subscription_data) #Subscription-Data
+
+        response = self.generate_diameter_packet("01", "C0", 319, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+        return response
+
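+    # --- Illustrative sketch (not part of this PR) ---
+    # The GetLocation branch of Request_16777251_319() above hardcodes
+    # IDR-Flags "00000018". Per 3GPP TS 29.272 this appears to set bits 3 and 4
+    # (EPS-Location-Information-Request and Current-Location-Request). A sketch
+    # of building that bitmask; the helper name is hypothetical:
+    @staticmethod
+    def idrFlags_sketch(eps_location=False, current_location=False):
+        flags = (int(eps_location) << 3) | (int(current_location) << 4)
+        return format(flags, 'x').zfill(8) #"00000018" when both are set
+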
+    #3GPP Cx Location Information Request (LIR)
+    #ToDo - Check the command code here...
+    def Request_16777216_302(self, sipaor):
+        avp = '' #Initiate empty var AVP #Session-ID
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate
+        #Auth Session state
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm
+        avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex(sipaor)) #Public-Identity / SIP-AOR
+        avp += self.generate_avp(293, 40, str(binascii.hexlify(b'hss.localdomain'),'ascii')) #Destination Host
+
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID
+
+        response = self.generate_diameter_packet("01", "c0", 302, 16777216, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
+        return response
+
+    #3GPP Cx User Authorization Request (UAR)
+    def Request_16777216_300(self, imsi, domain):
+        avp = '' #Initiate empty var AVP #Session-ID
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP
+        avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+        avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx
+        avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State
+        avp += self.generate_avp(1, 40, self.string_to_hex(imsi + "@" + domain)) #User-Name
+        avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex("sip:" + imsi + "@" + domain)) #Public-Identity
+        avp += self.generate_vendor_avp(600, "c0", 10415, self.string_to_hex(domain)) #Visited Network Identifier
+        response = self.generate_diameter_packet("01", "c0", 300, 16777216, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet
+        return response
+
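+    # --- Illustrative note (not part of this PR) ---
+    # The SAR below takes a server_assignment_type parameter. Common
+    # Server-Assignment-Type values per 3GPP TS 29.229 include:
+    #   0 = NO_ASSIGNMENT
+    #   1 = REGISTRATION
+    #   2 = RE_REGISTRATION
+    #   3 = UNREGISTERED_USER
+    #   4 = TIMEOUT_DEREGISTRATION
+    #   5 = USER_DEREGISTRATION
+    # (further deregistration variants exist beyond these)
+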
+
+    #3GPP Cx Server Assignment Request (SAR)
+    def Request_16777216_301(self, imsi, domain, server_assignment_type):
+        avp = ''                                                                                    #Initiate empty var AVP
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_cx'      #Generate Session-Id
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session-Id
+        avp += self.generate_avp(264, 40, str(binascii.hexlify(str.encode("testclient." + self.config['hss']['OriginHost'])),'ascii'))      #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+        avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii'))            #Destination Realm
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000")       #Vendor-Specific-Application-ID for Cx
+        avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State (Not maintained)
+        avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex("sip:" + imsi + "@" + domain))      #Public-Identity
+        avp += self.generate_vendor_avp(602, "c0", 10415, self.string_to_hex('sip:scscf.ims.mnc' + self.MNC + '.mcc' + self.MCC + '.3gppnetwork.org:5060'))      #Server-Name
+        avp += self.generate_avp(1, 40, self.string_to_hex(imsi + "@" + domain))                    #User-Name
+        avp += self.generate_vendor_avp(614, "c0", 10415, format(int(server_assignment_type),"x").zfill(8))      #Server-Assignment-Type (TS 29.229; e.g. 1 = REGISTRATION)
+        avp += self.generate_vendor_avp(624, "c0", 10415, "00000000")                               #User Data Already Available (Not Available)
+        response = self.generate_diameter_packet("01", "c0", 301, 16777216, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+        return response
+
+    #3GPP Cx Multimedia Authentication Request (MAR)
+    def Request_16777216_303(self, imsi, domain):
+        avp = ''                                                                                    #Initiate empty var AVP
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_cx'      #Generate Session-Id
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session-Id
+        avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+        avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii'))            #Destination Realm
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000")       #Vendor-Specific-Application-ID for Cx
+        avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State (Not maintained)
+        avp += self.generate_avp(1, 40, self.string_to_hex(str(imsi) + "@" + domain))               #User-Name
+        avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex("sip:" + str(imsi) + "@" + domain))      #Public-Identity
+        avp += self.generate_vendor_avp(607, "c0", 10415, "00000001")                               #3GPP-SIP-Number-Auth-Items
+        #3GPP-SIP-Number-Auth-Data-Item
+        avp += self.generate_vendor_avp(612, "c0", 10415, "00000260c0000013000028af756e6b6e6f776e0000000262c000002a000028af02e3fe1064bea4dd52602bef1c80a34ededbeb4ccabfa0430f4ffd5f1d8c0000")
+        avp += self.generate_vendor_avp(602, "c0", 10415, self.ProductName)                         #Server-Name
+        response = self.generate_diameter_packet("01", "c0", 303, 16777216, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+        return response
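+
+    # Note: the 3GPP-SIP-Number-Auth-Data-Item (612) above is a canned, pre-encoded blob (a SIP-Authentication-Scheme
+    # of 'unknown' plus fixed authorisation data), suited to test traffic rather than live authentication.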
+
+    #3GPP Cx Registration Termination Request (RTR)
+    def Request_16777216_304(self, imsi, domain, destinationHost, destinationRealm):
+        avp = ''                                                                                    #Initiate empty var AVP
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_cx'      #Generate Session-Id
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session-Id
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777216),"x").zfill(8) + "0000010a4000000c000028af")      #Vendor-Specific-Application-ID (Cx)
+
+        avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+
+        #SIP-Deregistration-Reason
+        reason_code_avp = self.generate_vendor_avp(616, "c0", 10415, "00000000")
+        reason_info_avp = self.generate_vendor_avp(617, "c0", 10415, self.string_to_hex("Administrative Deregistration"))
+        avp += self.generate_vendor_avp(615, "c0", 10415, reason_code_avp + reason_info_avp)
+
+        avp += self.generate_avp(283, 40, self.string_to_hex(destinationRealm))                     #Destination Realm
+        avp += self.generate_avp(293, 40, self.string_to_hex(destinationHost))                      #Destination Host
+
+        avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State (Not maintained)
+        avp += self.generate_avp(1, 40, self.string_to_hex(str(imsi) + "@" + domain))               #User-Name
+        avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex("sip:" + str(imsi) + "@" + domain))      #Public-Identity
+        avp += self.generate_vendor_avp(602, "c0", 10415, self.ProductName)                         #Server-Name
+
+        #* [ Route-Record ]
+        avp += self.generate_avp(282, "40", self.OriginHost)
+
+        response = self.generate_diameter_packet("01", "c0", 304, 16777216, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+
+        return response
+
+    #3GPP Sh User-Data Request (UDR)
+    def Request_16777217_306(self, **kwargs):
+        avp = ''                                                                                    #Initiate empty var AVP
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_sh'      #Generate Session-Id
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session-Id
+        avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777217),"x").zfill(8) + "0000010a4000000c000028af")      #Vendor-Specific-Application-ID (Sh)
+
+        avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+
+        avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii'))            #Destination Realm
+        avp += self.generate_avp(293, 40, str(binascii.hexlify(b'hss.localdomain'),'ascii'))        #Destination Host
+
+        avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State (Not maintained)
+
+        avp += self.generate_vendor_avp(602, "c0", 10415, self.ProductName)                         #Server-Name
+
+        #* [ Route-Record ]
+        avp += self.generate_avp(282, "40", str(binascii.hexlify(b'localdomain'),'ascii'))
+
+        if "msisdn" in kwargs:
+            msisdn = kwargs['msisdn']
+            msisdn_avp = self.generate_vendor_avp(701, 'c0', 10415, self.TBCD_encode(str(msisdn)))  #MSISDN
+            avp += self.generate_vendor_avp(700, "c0", 10415, msisdn_avp)                           #User-Identity
+            avp += self.generate_vendor_avp(701, 'c0', 10415, self.TBCD_encode(str(msisdn)))
+        elif "imsi" in kwargs:
+            imsi = kwargs['imsi']
+            public_identity_avp = self.generate_vendor_avp(601, 'c0', 10415, self.string_to_hex(imsi))      #Public-Identity (IMSI)
+            avp += self.generate_vendor_avp(700, "c0", 10415, public_identity_avp)                  #User-Identity (IMSI)
+
+        response = self.generate_diameter_packet("01", "c0", 306, 16777217, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+
+        return response
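+
+    # Example (hypothetical MSISDN): Request_16777217_306(msisdn='61412345678') nests the TBCD-encoded MSISDN
+    # AVP (701) inside the grouped User-Identity AVP (700), per the kwargs branch above.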
+
+    #3GPP S13 - ME-Identity-Check Request
+    def Request_16777252_324(self, imsi, imei, software_version):
+        avp = ''
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000024")       #Vendor-Specific-Application-ID for S13
+        avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State (Not maintained)
+        avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+        avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii'))            #Destination Realm
+        avp += self.generate_avp(293, 40, str(binascii.hexlify(b'eir.localdomain'),'ascii'))        #Destination Host
+        imei = self.generate_vendor_avp(1402, "c0", 10415, str(binascii.hexlify(str.encode(imei)),'ascii'))
+        software_version = self.generate_vendor_avp(1403, "c0", 10415, self.string_to_hex(software_version))
+        avp += self.generate_vendor_avp(1401, "c0", 10415, imei + software_version)                 #Terminal Information
+        avp += self.generate_avp(1, 40, self.string_to_hex(imsi))                                   #Username (IMSI)
+        response = self.generate_diameter_packet("01", "c0", 324, 16777252, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+        return response
+
+    #3GPP SLg - Provide Subscriber Location Request
+    def Request_16777255_8388620(self, imsi):
+        avp = ''
+        #ToDo - Update the Vendor Specific Application ID
+        avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000024")       #Vendor-Specific-Application-ID
+        avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State (Not maintained)
+        avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+        avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii'))            #Destination Realm
+        avp += self.generate_avp(293, 40, str(binascii.hexlify(b'mme-slg.localdomain'),'ascii'))    #Destination Host
+        #SLg-Location-Type (0 = Current Location)
+        avp += self.generate_vendor_avp(2500, "c0", 10415, "00000000")
+        #Username (IMSI)
+        avp += self.generate_avp(1, 40, self.string_to_hex(imsi))                                   #Username (IMSI)
+        #LCS-EPS-Client-Name
+        LCS_EPS_Client_Name = self.generate_vendor_avp(1238, "c0", 10415, str(binascii.hexlify(b'PyHSS GMLC'),'ascii'))      #LCS Name String
+        LCS_EPS_Client_Name += self.generate_vendor_avp(1237, "c0", 10415, "00000002")              #LCS Format Indicator
+        avp += self.generate_vendor_avp(2501, "c0", 10415, LCS_EPS_Client_Name)
+        #LCS-Client-Type (Emergency Services)
+        avp += self.generate_vendor_avp(1241, "c0", 10415, "00000000")
+        response = self.generate_diameter_packet("01", "c0", 8388620, 16777255, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+        return response
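+
+    # Example (hypothetical IMSI): Request_16777255_8388620('001010000000001') builds a PSL towards
+    # mme-slg.localdomain with SLg-Location-Type 0 (Current Location) and LCS-Client-Type 0 (Emergency Services).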
+
+    #3GPP SLh - LCS-Routing-Info-Request (LRR)
+    def Request_16777291_8388622(self, **kwargs):
+        avp = ''
+        #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
+        VendorSpecificApplicationId = ''
+        VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '')                 #AVP Vendor ID
+        VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777291),"x").zfill(8))      #Auth-Application-ID SLh
+        avp += self.generate_avp(260, 40, VendorSpecificApplicationId)
+        avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State (Not maintained)
+        avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_slh'      #Generate Session-Id
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session-Id
+
+        #Username (IMSI)
+        if 'imsi' in kwargs:
+            avp += self.generate_avp(1, 40, self.string_to_hex(str(kwargs.get('imsi'))))            #Username (IMSI)
+
+        #MSISDN (Optional)
+        if 'msisdn' in kwargs:
+            avp += self.generate_vendor_avp(701, 'c0', 10415, self.TBCD_encode(str(kwargs.get('msisdn'))))      #MSISDN
+
+        #GMLC Address
+        avp += self.generate_vendor_avp(2405, 'c0', 10415, self.ip_to_hex('127.0.0.1'))             #GMLC-Address
+
+        response = self.generate_diameter_packet("01", "c0", 8388622, 16777291, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+        return response
+
+    #3GPP Gx - Credit Control Request
+    def Request_16777238_272(self, imsi, apn, ccr_type, destinationHost, destinationRealm, sessionId=None):
+        avp = ''
+        if sessionId is None:
+            sessionid = 'nickpc.localdomain;' + self.generate_id(5) + ';1;app_gx'                   #Generate Session-Id
+        else:
+            sessionid = sessionId
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session-Id
+        #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
+        VendorSpecificApplicationId = ''
+        VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '')                 #AVP Vendor ID
+        VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777238),"x").zfill(8))      #Auth-Application-ID Gx
+        avp += self.generate_avp(260, 40, VendorSpecificApplicationId)
+        avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State (Not maintained)
+        avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+
+        avp += self.generate_avp(258, 40, format(int(16777238),"x").zfill(8))                       #Auth-Application-ID Gx
+
+        #CCR Type
+        avp += self.generate_avp(416, 40, format(int(ccr_type),"x").zfill(8))
+        avp += self.generate_avp(415, 40, format(int(0),"x").zfill(8))                              #CC-Request-Number
+
+        #Subscription ID
+        Subscription_ID_Data = self.generate_avp(444, 40, str(binascii.hexlify(str.encode(imsi)),'ascii'))
+        Subscription_ID_Type = self.generate_avp(450, 40, format(int(1),"x").zfill(8))
+        avp += self.generate_avp(443, 40, Subscription_ID_Type + Subscription_ID_Data)
+
+        #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP
+        SupportedFeatures = ''
+        SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '')                           #AVP Vendor ID
+        SupportedFeatures += self.generate_vendor_avp(629, 80, 10415, self.int_to_hex(1, 4))        #Feature-List ID
+        SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "0000000b")                   #Feature-List Flags
+        avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures)                        #Supported-Features(628) l=36 f=V-- vnd=TGPP
+
+        avp += self.generate_vendor_avp(1024, 80, 10415, self.int_to_hex(1, 4))                     #Network Requests Supported
+
+        avp += self.generate_avp(8, 40, binascii.b2a_hex(os.urandom(4)).decode('utf-8'))            #Framed IP Address Randomly Generated
+
+        avp += self.generate_vendor_avp(1027, 'c0', 10415, self.int_to_hex(5, 4))                   #IP CAN Type (EPS)
+        avp += self.generate_vendor_avp(1032, 'c0', 10415, self.int_to_hex(1004, 4))                #RAT-Type (EUTRAN)
+        #Default EPS Bearer QoS
+        avp += self.generate_vendor_avp(1049, 80, 10415,
+            '0000041980000058000028af00000404c0000010000028af000000090000040a8000003c000028af0000041680000010000028af000000080000041780000010000028af000000010000041880000010000028af00000001')
+        #3GPP-User-Location-Information
+        avp += self.generate_vendor_avp(22, 80, 10415,
+            '8205f539007b05f53900000001')
+        avp += self.generate_vendor_avp(23, 80, 10415, '00000000')                                  #MS Timezone
+
+        #Called Station ID (APN)
+        avp += self.generate_avp(30, 40, str(binascii.hexlify(str.encode(apn)),'ascii'))
+
+        response = self.generate_diameter_packet("01", "c0", 272, 16777238, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+        return response
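+
+    # Example (hypothetical peers): Request_16777238_272(imsi='001010000000001', apn='internet', ccr_type=1,
+    #     destinationHost='pgw01.localdomain', destinationRealm='localdomain')
+    # CC-Request-Type follows RFC 4006: 1 = INITIAL_REQUEST, 2 = UPDATE_REQUEST, 3 = TERMINATION_REQUEST.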
+
+    #3GPP Gx - Re Auth Request
+    def Request_16777238_258(self, sessionId, servingPgw, servingRealm, chargingRules=None, ueIp=None, chargingRuleAction='install', chargingRuleName=None):
+        avp = ''
+        self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Creating Re Auth Request", redisClient=self.redisMessaging)
+
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionId)),'ascii'))     #Session-Id
+
+        #Setup Charging Rule
+        self.logTool.log(service='HSS', level='debug', message=chargingRules, redisClient=self.redisMessaging)
+        if chargingRules is not None and ueIp is not None:
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Charging Rules: {chargingRules}", redisClient=self.redisMessaging)
+            avp += self.Charging_Rule_Generator(ChargingRules=chargingRules, ue_ip=ueIp)
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Generated Charging Rules", redisClient=self.redisMessaging)
+        elif chargingRuleName is not None and chargingRuleAction == 'remove':
+            avp += self.Charging_Rule_Generator(action=chargingRuleAction, chargingRuleName=chargingRuleName)
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Removing Charging Rule: {chargingRuleName}", redisClient=self.redisMessaging)
+
+        avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+        avp += self.generate_avp(293, 40, self.string_to_hex(servingPgw))                           #Destination Host
+        avp += self.generate_avp(283, 40, self.string_to_hex(servingRealm))                         #Destination Realm
+
+        avp += self.generate_avp(258, 40, format(int(16777238),"x").zfill(8))                       #Auth-Application-ID Gx
+
+        avp += self.generate_avp(285, 40, format(int(0),"x").zfill(8))                              #Re-Auth Request Type
+
+        response = self.generate_diameter_packet("01", "c0", 258, 16777238, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+        return response
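+
+    # Per the branches above: supplying chargingRules and ueIp installs rules via Charging_Rule_Generator,
+    # while chargingRuleAction='remove' with a chargingRuleName tears down the named rule on the serving PGW.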
+
+    #3GPP Gy - Credit Control Request
+    def Request_4_272(self, sessionid, imsi, CC_Request_Type, input_octets, output_octets):
+        avp = ''
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session-Id
+
+        avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+        avp += self.generate_avp(283, 40, self.OriginRealm)                                         #Destination Realm
+
+        avp += self.generate_avp(258, 40, format(int(4),"x").zfill(8))                              #Auth-Application-ID (4 = Diameter Credit Control, used on Gy)
+        avp += self.generate_avp(461, 40, self.string_to_hex("open5gs-smfd@open5gs.org"))           #Service Context ID
+        avp += self.generate_avp(416, 40, format(int(CC_Request_Type),"x").zfill(8))                #CC Request Type
+        avp += self.generate_avp(415, 40, format(int(0),"x").zfill(8))                              #CC Request Number
+        avp += self.generate_avp(55, 40, '00000000')                                                #Event Timestamp
+
+        #Subscription ID
+        Subscription_ID_Data = self.generate_avp(444, 40, str(binascii.hexlify(str.encode(imsi)),'ascii'))
+        Subscription_ID_Type = self.generate_avp(450, 40, format(int(1),"x").zfill(8))
+        avp += self.generate_avp(443, 40, Subscription_ID_Type + Subscription_ID_Data)
+
+        avp += self.generate_avp(436, 40, format(int(0),"x").zfill(8))                              #Requested Action (Direct Debiting)
+
+        avp += self.generate_vendor_avp(2055, 'c0', 10415, "00000001")                              #AoC_FULL (1)
+
+        avp += self.generate_avp(455, 40, format(int(0),"x").zfill(8))                              #Multiple Services Indicator (Not Supported)
+        if int(CC_Request_Type) == 1:
+            mscc = ''                                                                               #Multiple Services Credit Control
+            mscc += self.generate_avp(437, 40, '')                                                  #Requested Service Unit
+            used_service_unit = ''
+            used_service_unit += self.generate_avp(420, 40, format(int(0),"x").zfill(8))            #Time
+            used_service_unit += self.generate_avp(412, 40, format(int(0),"x").zfill(16))           #Input Octets
+            used_service_unit += self.generate_avp(414, 40, format(int(0),"x").zfill(16))           #Output Octets
+            mscc += self.generate_avp(446, 40, used_service_unit)                                   #Used Service Unit
+            mscc += self.generate_vendor_avp(1016, 'c0', 10415,                                     #QoS Information
+                "00000404c0000010000028af000000090000040a8000003c000028af0000041680000010000028af000000090000041780000010000028af000000000000041880000010000028af000000000000041180000010000028af061a80000000041080000010000028af061a8000")
+            mscc += self.generate_vendor_avp(21, 'c0', 10415, '000028af')                           #3GPP RAT Type (WB-EUTRAN)
+            avp += self.generate_avp(456, 40, mscc)
+
+        elif int(CC_Request_Type) == 2:
+            mscc = ''                                                                               #Multiple Services Credit Control
+            mscc += self.generate_avp(437, 40, '')                                                  #Requested Service Unit
+            used_service_unit = ''
+            used_service_unit += self.generate_avp(420, 40, format(int(0),"x").zfill(8))            #Time
+            used_service_unit += self.generate_avp(412, 40, format(int(input_octets),"x").zfill(16))      #Input Octets
+            used_service_unit += self.generate_avp(414, 40, format(int(output_octets),"x").zfill(16))     #Output Octets
+            mscc += self.generate_avp(446, 40, used_service_unit)                                   #Used Service Unit
+            mscc += self.generate_vendor_avp(872, 'c0', 10415, format(int(4),"x").zfill(8))         #3GPP Reporting Reason (Validity Time (4))
+            mscc += self.generate_vendor_avp(1016, 'c0', 10415,                                     #QoS Information
+                "00000404c0000010000028af000000090000040a8000003c000028af0000041680000010000028af000000090000041780000010000028af000000000000041880000010000028af000000000000041180000010000028af061a80000000041080000010000028af061a8000")
+            mscc += self.generate_vendor_avp(21, 'c0', 10415, '000028af')                           #3GPP RAT Type (WB-EUTRAN)
+            avp += self.generate_avp(456, 40, mscc)
+        elif int(CC_Request_Type) == 3:
+            #Multiple Services Credit Control
+            avp += self.generate_avp(456, 40,
+                "000001be40000034000001a44000000c000000000000019c4000001000000000000000000000019e40000010000000000000000000000368c0000010000028af00000002000003f8c0000078000028af00000404c0000010000028af000000090000040a8000003c000028af0000041680000010000028af000000020000041780000010000028af000000010000041880000010000028af000000000000041180000010000028af020000000000041080000010000028af0320000000000015c000000d000028af06000000")
+
+            #Service Information
+            avp += self.generate_vendor_avp(873, 'c0', 10415,
+                "0000036ac00000d8000028af00000002c0000010000028af0000010400000003c0000010000028af00000000000004cbc0000012000028af00010a2d01050000000004ccc0000012000028af0001ac1212ca00000000034fc0000012000028af0001ac12120400000000001e40000010696e7465726e65740000000cc000000d000028af300000000000000dc0000010000028af3030303000000012c0000011000028af30303130310000000000000ac000000d000028af0100000000000016c0000019000028af8200f110000100f11000000017000000")
+        response = self.generate_diameter_packet("01", "c0", 272, 4, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+        return response
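+
+    # CC-Request-Type drives the MSCC built above: 1 (initial) carries an empty Requested-Service-Unit and a
+    # zeroed Used-Service-Unit, 2 (update) reports the supplied input/output octets, and 3 (termination)
+    # falls back to a canned MSCC plus Service-Information blob.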
+
+    #3GPP Sh - Profile Update Request
+    def Request_16777217_307(self, msisdn):
+        avp = ''
+        sessionid = 'nickpc.localdomain;' + self.generate_id(5) + ';1;app_sh'                       #Generate Session-Id
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session-Id
+        #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
+        VendorSpecificApplicationId = ''
+        VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '')                 #AVP Vendor ID
+        VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777217),"x").zfill(8))      #Auth-Application-ID Sh
+        avp += self.generate_avp(260, 40, VendorSpecificApplicationId)
+        avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State (Not maintained)
+        avp += self.generate_avp(264, 40, self.string_to_hex('ExamplePGW.com'))                     #Origin Host
+        avp += self.generate_avp(283, 40, self.OriginRealm)                                         #Destination Realm
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+
+        self.logTool.log(service='HSS', level='debug', message="Getting subscriber IMS info based on MSISDN", redisClient=self.redisMessaging)
+        subscriber_ims_details = self.database.Get_IMS_Subscriber(msisdn=msisdn)
+        self.logTool.log(service='HSS', level='debug', message="Got subscriber IMS details: " + str(subscriber_ims_details), redisClient=self.redisMessaging)
+        self.logTool.log(service='HSS', level='debug', message="Getting subscriber info based on MSISDN", redisClient=self.redisMessaging)
+        subscriber_details = self.database.Get_Subscriber(msisdn=msisdn)
+        self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + str(subscriber_details), redisClient=self.redisMessaging)
+        subscriber_details = {**subscriber_details, **subscriber_ims_details}
+        self.logTool.log(service='HSS', level='debug', message="Merged subscriber details: " + str(subscriber_details), redisClient=self.redisMessaging)
+
+        avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(subscriber_details['imsi'])),'ascii'))      #Username AVP
+
+        #Sh-User-Data (XML)
+        #This loads a Jinja XML template containing the Sh-User-Data
+        templateLoader = jinja2.FileSystemLoader(searchpath="./")
+        templateEnv = jinja2.Environment(loader=templateLoader)
+        sh_userdata_template = self.config['hss']['Default_Sh_UserData']
+        self.logTool.log(service='HSS', level='debug', message="Using template " + str(sh_userdata_template) + " for SH user data", redisClient=self.redisMessaging)
+        template = templateEnv.get_template(sh_userdata_template)
+        #These variables are passed to the template for use
+        subscriber_details['mnc'] = self.MNC.zfill(3)
+        subscriber_details['mcc'] = self.MCC.zfill(3)
+
+        self.logTool.log(service='HSS', level='debug', message="Rendering template with values: " + str(subscriber_details), redisClient=self.redisMessaging)
+        xmlbody = template.render(Sh_template_vars=subscriber_details)                              #This is where to put args to the template renderer
+        avp += self.generate_vendor_avp(702, "c0", 10415, str(binascii.hexlify(str.encode(xmlbody)),'ascii'))
+
+        response = self.generate_diameter_packet("01", "c0", 307, 16777217, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+        return response
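+
+    # The Jinja2 template configured at hss.Default_Sh_UserData receives the merged subscriber dict as
+    # Sh_template_vars, so the XML body can reference fields such as {{ Sh_template_vars.msisdn }} (illustrative).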
+
+    #3GPP S13 - ME-Identity-Check Request
+    #Note: this redefines Request_16777252_324 above with a different signature; being later in the class, this definition takes precedence.
+    def Request_16777252_324(self, imei, imsi):
+        avp = ''
+        sessionid = 'nickpc.localdomain;' + self.generate_id(5) + ';1;app_s13'                      #Generate Session-Id
+        avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session-Id
+        #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M-
+        VendorSpecificApplicationId = ''
+        VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '')                 #AVP Vendor ID
+        VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777252),"x").zfill(8))      #Auth-Application-ID S13
+        avp += self.generate_avp(260, 40, VendorSpecificApplicationId)
+        avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State (Not maintained)
+        avp += self.generate_avp(264, 40, self.string_to_hex('ExamplePGW.com'))                     #Origin Host
+        avp += self.generate_avp(283, 40, self.OriginRealm)                                         #Destination Realm
+        avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
+
+        avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(imsi)),'ascii'))            #Username AVP
+        TerminalInformation = ''
+        TerminalInformation += self.generate_vendor_avp(1402, 'c0', 10415, str(binascii.hexlify(str.encode(imei)),'ascii'))
+        TerminalInformation += self.generate_vendor_avp(1403, 'c0', 10415, str(binascii.hexlify(str.encode('00')),'ascii'))
+        avp += self.generate_vendor_avp(1401, 'c0', 10415, TerminalInformation)
+
+        response = self.generate_diameter_packet("01", "c0", 324, 16777252, self.generate_id(4), self.generate_id(4), avp)      #Generate Diameter packet
+        return response
\ No newline at end of file
diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py
new file mode 100644
index 0000000..f8c5be9
--- /dev/null
+++ b/lib/diameterAsync.py
@@ -0,0 +1,360 @@
+#Diameter Packet Decoder / Encoder & Tools
+import math
+import asyncio
+import yaml
+from messagingAsync import RedisMessagingAsync
+
+
+class DiameterAsync:
+
+    def __init__(self, logTool):
+        self.diameterCommandList = [
+            {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012, "requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilities Exchange Request", "responseName": "Capabilities Exchange Answer"},
+            {"commandCode": 280, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_280, "failureResultCode": 5012, "requestAcronym": "DWR", "responseAcronym": "DWA", "requestName": "Device Watchdog Request", "responseName": "Device Watchdog Answer"},
+            {"commandCode": 282, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_282, "failureResultCode": 5012, "requestAcronym": "DPR", "responseAcronym": "DPA", "requestName": "Disconnect Peer Request", "responseName": "Disconnect Peer Answer"},
+            {"commandCode": 300, "applicationId": 16777216, "responseMethod": self.Answer_16777216_300, "failureResultCode": 4100, "requestAcronym": "UAR", "responseAcronym": "UAA", "requestName": "User Authorization Request", "responseName": "User Authorization Answer"},
+            {"commandCode": 301, "applicationId": 16777216, "responseMethod": self.Answer_16777216_301, "failureResultCode": 4100, "requestAcronym": "SAR", "responseAcronym": "SAA", "requestName": "Server Assignment Request", "responseName": "Server Assignment Answer"},
+            {"commandCode": 302, "applicationId": 16777216, "responseMethod": self.Answer_16777216_302, "failureResultCode": 4100, "requestAcronym": "LIR", "responseAcronym": "LIA", "requestName": "Location Information Request", "responseName": "Location Information Answer"},
+            {"commandCode": 303, "applicationId": 16777216, "responseMethod": self.Answer_16777216_303, "failureResultCode": 4100, "requestAcronym": "MAR", "responseAcronym": "MAA", "requestName": "Multimedia Authentication Request", "responseName": "Multimedia Authentication Answer"},
+            {"commandCode": 306, "applicationId": 16777217, "responseMethod": self.Answer_16777217_306, "failureResultCode": 5001, "requestAcronym": "UDR", "responseAcronym": "UDA", "requestName": "User Data Request", "responseName": "User Data Answer"},
+            {"commandCode": 307, "applicationId": 16777217, "responseMethod": self.Answer_16777217_307, "failureResultCode": 5001, "requestAcronym": "PRUR", "responseAcronym": "PRUA", "requestName": "Profile Update Request", "responseName": "Profile Update Answer"},
+            {"commandCode": 265, "applicationId": 16777236, "responseMethod": self.Answer_16777236_265, "failureResultCode": 4100, "requestAcronym": "AAR", "responseAcronym": "AAA", "requestName": "AA Request", "responseName": "AA Answer"},
+            {"commandCode": 275, "applicationId": 16777236, "responseMethod": self.Answer_16777236_275, "failureResultCode": 4100, "requestAcronym": "STR", "responseAcronym": "STA", "requestName": "Session Termination Request", "responseName": "Session Termination Answer"},
+            {"commandCode": 274, "applicationId": 16777236, "responseMethod": self.Answer_16777236_274, "failureResultCode": 4100, "requestAcronym": "ASR", "responseAcronym": "ASA", "requestName": "Abort Session Request", "responseName": "Abort Session Answer"},
+            {"commandCode": 258, "applicationId": 16777238, "responseMethod": self.Answer_16777238_258, "failureResultCode": 4100, "requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"},
+            {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012, "requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"},
+            {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100, "requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"},
+            {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100, "requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"},
+            {"commandCode": 321, "applicationId": 16777251, "responseMethod": self.Answer_16777251_321, "failureResultCode": 5012, "requestAcronym": "PUR", "responseAcronym": "PUA", "requestName": "Purge UE Request", "responseName": "Purge UE Answer"},
+            {"commandCode": 323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012, "requestAcronym": "NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"},
+            {"commandCode": 324, "applicationId": 16777252, "responseMethod": self.Answer_16777252_324, "failureResultCode": 4100, "requestAcronym": "ECR", "responseAcronym": "ECA", "requestName": "ME Identity Check Request", "responseName": "ME Identity Check Answer"},
+            {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100, "requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"},
+        ]
+
+        with open("../config.yaml", 'r') as stream:
+            self.config = (yaml.safe_load(stream))
+
+        self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False)
+        self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock')
+        self.redisHost = self.config.get('redis', {}).get('host', 'localhost')
+        self.redisPort = self.config.get('redis', {}).get('port', 6379)
+        self.redisMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath)
+
+        self.logTool = logTool
+
+
+    #Generates rounding for calculating padding
+    async def myRound(self, n, base=4):
+        if n > 0:
+            return math.ceil(n/base) * base
+        elif n < 0:
+            return math.floor(n/base) * base
+        else:
+            return base
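+
+    # e.g. await self.myRound(13) returns 16 and await self.myRound(16) returns 16: an AVP of length 13
+    # occupies 16 bytes on the wire once padded to the next 4-byte boundary.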
+
+    async def roundUpToMultiple(self, n, multiple):
+        return ((n + multiple - 1) // multiple) * multiple
+
+    async def getAvpData(self, avps, avp_code):
+        #Loops through the list of dicts produced by the packet decoder and returns the data for the given AVP code
+        #as a list (there may be more than one AVP with the same code but different data)
+        misc_data = []
+        for keys in avps:
+            if keys['avp_code'] == avp_code:
+                misc_data.append(keys['misc_data'])
+        return misc_data
+
+    async def validateSingleAvp(self, data) -> bool:
+        """
+        Attempts to validate a single hex string diameter AVP as being an AVP.
+        """
+        try:
+            # The first 4 bytes contain the AVP Code
+            avpCode = int(data[0:8], 16)
+            # The next byte contains the AVP Flags
+            avpFlags = data[8:10]
+            # The next 3 bytes contain the AVP Length
+            avpLength = int(data[10:16], 16)
+            if avpFlags not in ['80', '40', '20', '00', 'c0']:
+                #print(f"[AVP VALIDATION] Failed to validate due to invalid Flag: {data}")
+                return False
+            if int(len(data[16:]) / 2) < ((avpLength - 8)):
+                #print(f"[AVP VALIDATION] Failed to validate due to invalid length: {data}")
+                return False
+            return True
+        except Exception as e:
+            return False
+
+
+    async def decodeDiameterPacket(self, data):
+        """
+        Handles decoding of a full diameter packet.
+        """
+        packet_vars = {}
+        avps = []
+
+        if type(data) is bytes:
+            data = data.hex()
+        # One byte is 2 hex characters
+        # First Byte is the Diameter Packet Version
+        packet_vars['packet_version'] = data[0:2]
+        # Next 3 Bytes are the length of the entire Diameter packet
+        packet_vars['length'] = int(data[2:8], 16)
+        # Next Byte is the Diameter Flags
+        packet_vars['flags'] = data[8:10]
+        packet_vars['flags_bin'] = bin(int(data[8:10], 16))[2:].zfill(8)
+        # Next 3 Bytes are the Diameter Command Code
+        packet_vars['command_code'] = int(data[10:16], 16)
+        # Next 4 Bytes are the Application Id
+        packet_vars['ApplicationId'] = int(data[16:24], 16)
+        # Next 4 Bytes are the Hop By Hop Identifier
+        packet_vars['hop-by-hop-identifier'] = data[24:32]
+        # Next 4 Bytes are the End to End Identifier
+        packet_vars['end-to-end-identifier'] = data[32:40]
+
+        # We enforce correct length, calculating the end byte from the known 'length' packet var and the length of the remaining AVPs.
+        lengthOfDiameterVars = int(len(data[:40]) / 2)
+        #print(f"Length of Diameter Vars (Bytes): {lengthOfDiameterVars}")
+
+        #Length of all AVPs, in bytes
+        avpLength = int(packet_vars['length'] - lengthOfDiameterVars)
+        #print(f"avpLength (bytes): {avpLength}")
+        avpCharLength = int((avpLength * 2))
+        #print(f"avpCharLength (chars): {avpCharLength}")
+        #print(f"Total Data Length (bytes) {len(data) / 2}")
+        remaining_avps = data[40:]
+
+        #print(remaining_avps)
+
+        avps = await self.decodeAvpPacket(remaining_avps)
+        #print(f"Got Back: {avps}")
+
+        return packet_vars, avps
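+
+    # Worked example (illustrative): the header '0100007080000101000000000000000100000001' decodes to
+    # version '01', length 0x000070 (112 bytes), flags '80' (a request), command code 257 (CER),
+    # application id 0, and hop-by-hop / end-to-end identifiers of '00000001'.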
+
+    async def decodeAvpPacket(self, data):
+        """
+        Returns a list of decoded AVP Packet dictionaries.
+        """
+        processed_avps = []
+        # Initialize a failsafe counter, to prevent packets that pass validation but aren't AVPs from causing an infinite loop
+        failsafeCounter = 0
+
+        # If the avp data is less than 8 bytes (16 hex chars), it cannot contain a complete AVP header and is invalid.
+        if len(data) < 16:
+            return []
+
+        # Keep processing AVPs until they're all dealt with
+        while len(data) > 16:
+            try:
+                failsafeCounter += 1
+
+                if failsafeCounter > 100:
+                    break
+                avp_vars = {}
+                #print(f"AVP Data: {data}")
+                # The first 4 bytes contains the AVP code
+                avp_vars['avp_code'] = int(data[0:8], 16)
+                # The next byte contains the AVP Flags
+                avp_vars['avp_flags'] = data[8:10]
+                # The next 3 bytes contains the AVP Length
+                avp_vars['avp_length'] = int(data[10:16], 16)
+                #print(f"Individual AVP Length: {avp_vars['avp_length']}")
+                # avp_length covers the whole AVP (header included) but excludes padding, which is unknown to the
+                # AVP itself and is calculated separately below.
+                # avpPayloadLength is that same length in hex characters (avp_length * 2), used to slice this AVP out of the buffer.
+                avpPayloadLength = int((avp_vars['avp_length'])*2)
+                #print(f"AVP Payload Length (Chars): {avpPayloadLength}")
+
+                # Work out our vendor id and add the payload itself (misc_data)
+                if avp_vars['avp_code'] == 266:
+                    avp_vars['vendor_id'] = int(data[16:24], 16)
+                    avp_vars['misc_data'] = data[16:avpPayloadLength]
+                else:
+                    avp_vars['vendor_id'] = ''
+                    avp_vars['misc_data'] = data[16:avpPayloadLength]
+
+                # Rounds up the length to the nearest multiple of 4, which we can compare against the avp length to give us the padding length (if required)
+                avp_padded_length = int((await(self.roundUpToMultiple(avp_vars['avp_length'], 4))))
+                # avp_padded_length = (avp_vars['avp_length'] + 3) // 4 * 4
+                avpPaddingLength = ((avp_padded_length - avp_vars['avp_length']) * 2)
+                #print(f"AVP Padding length (Chars): {avpPaddingLength}")
+
+                avp_vars['sub_avps'] = []
+
+                # Check if the payload data contains sub or grouped AVPs inside
+                payloadContainsSubAvps = await(self.validateSingleAvp(avp_vars['misc_data']))
+
+                if payloadContainsSubAvps:
+                    # If the payload contains sub or grouped AVPs, assign misc_data to sub_avps to start working through them
+                    sub_avp_data = avp_vars['misc_data']
+
+                while payloadContainsSubAvps:
+                    failsafeCounter += 1
+
+                    if failsafeCounter > 100:
+                        break
+                    sub_avp = {}
+                    sub_avp['avp_code'] = int(sub_avp_data[0:8], 16)
+                    sub_avp['avp_flags'] = sub_avp_data[8:10]
+                    sub_avp['avp_length'] = int(sub_avp_data[10:16], 16)
+                    sub_avpPayloadLength = int((sub_avp['avp_length'])*2)
+
+                    if sub_avp['avp_code'] == 266:
+                        sub_avp['vendor_id'] = int(sub_avp_data[16:24], 16)
+                        sub_avp['misc_data'] = sub_avp_data[16:sub_avpPayloadLength]
+                    else:
+                        sub_avp['vendor_id'] = ''
+                        sub_avp['misc_data'] = sub_avp_data[16:sub_avpPayloadLength]
+
+                    avp_vars['sub_avps'].append(sub_avp)
+
+                    #print(f"Sub Avp Data before trimming: {sub_avp_data}")
+                    #print(f"Sub Avp payload length: {sub_avpPayloadLength}")
+                    sub_avp_data = sub_avp_data[sub_avpPayloadLength:]
+                    avp_vars['misc_data'] = avp_vars['misc_data'][sub_avpPayloadLength:]
+                    #print(f"Sub Avp Data after trimming: {sub_avp_data}")
+                    payloadContainsSubAvps = await(self.validateSingleAvp(sub_avp_data))
+
+                if avpPaddingLength > 0:
+                    processed_avps.append(avp_vars)
+                    data = data[avpPayloadLength+avpPaddingLength:]
+                else:
+                    processed_avps.append(avp_vars)
+                    data = data[avpPayloadLength:]
+            except Exception as e:
+                #print(f"EXCEPTION: {e}")
+                continue
+
+        return processed_avps
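+
+    # Worked example (illustrative): '0000010840000017' + '6873732e6c6f63616c646f6d61696e' + '00' decodes to
+    # avp_code 264 (Origin-Host), flags '40', avp_length 23 and misc_data of 'hss.localdomain' in hex,
+    # followed by a single padding byte to reach the next 4-byte boundary.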
+
+    async def getPeerType(self, originHost: str) -> str:
+        try:
+            peerTypes = ['mme', 'pgw', 'pcscf', 'icscf', 'scscf', 'hss', 'ocs', 'dra']
+
+            for peer in peerTypes:
+                if peer in originHost.lower():
+                    return peer
+            return ''
+
+        except Exception as e:
+            return ''
+
+    async def getConnectedPeersByType(self, peerType: str) -> list:
+        try:
+            peerType = peerType.lower()
+            peerTypes = ['mme', 'pgw', 'pcscf', 'icscf', 'scscf', 'hss', 'ocs', 'dra']
+
+            if peerType not in peerTypes:
+                return []
+            filteredConnectedPeers = []
+            activePeers = await(self.redisMessaging.getValue(key="ActiveDiameterPeers"))
+
+            for key, value in activePeers.items():
+                if activePeers.get(key, {}).get('peerType', '') == peerType and activePeers.get(key, {}).get('connectionStatus', '') == 'connected':
+                    filteredConnectedPeers.append(activePeers.get(key, {}))
+
+            return filteredConnectedPeers
+
+        except Exception as e:
+            return []
+
+    async def getDiameterMessageType(self, binaryData: str) -> dict:
+        """
+        Determines whether a message is a request or a response, and the appropriate acronyms for each type.
+        """
+        packet_vars, avps = await(self.decodeDiameterPacket(binaryData))
+        response = {}
+
+        for diameterApplication in self.diameterCommandList:
+            try:
+                assert(packet_vars["command_code"] == diameterApplication["commandCode"])
+                assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"])
+                if packet_vars["flags_bin"][0:1] == "1":
+                    response['inbound'] = diameterApplication["requestAcronym"]
+                    response['outbound'] = diameterApplication["responseAcronym"]
+                else:
+                    response['inbound'] = diameterApplication["responseAcronym"]
+                    response['outbound'] = diameterApplication["requestAcronym"]
+            except Exception as e:
+                continue
+
+        return response
+
+    async def generateDiameterResponse(self, binaryData: str) -> str:
+        packet_vars, avps = await(self.decodeDiameterPacket(binaryData))
+        response = ''
+
+        # Drop packet if it's a response packet:
+        if packet_vars["flags_bin"][0:1] == "0":
+            return
+
+        for diameterApplication in self.diameterCommandList:
+            try:
+                assert(packet_vars["command_code"] == diameterApplication["commandCode"])
+                assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"])
+                if 'flags' in diameterApplication:
+                    assert(str(packet_vars["flags"]) == str(diameterApplication["flags"]))
+                response = diameterApplication["responseMethod"](packet_vars, avps)
+            except Exception as e:
+                continue
+
+        return response
+
+    async def Answer_257(self):
+        pass
+
+    async def Answer_16777238_272(self):
+        pass
+
+    async def Answer_280(self):
+        pass
+
+    async def Answer_282(self):
+        pass
+
+    async def Answer_16777251_318(self):
+        pass
+
+    async def Answer_16777251_316(self):
+        pass
+
+    async def Answer_16777251_321(self):
+        pass
+
+    async def Answer_16777251_323(self):
+        pass
+
+    async def Answer_16777216_300(self):
+        pass
+
+    async def Answer_16777216_301(self):
+        pass
+
+    async def Answer_16777216_302(self):
+        pass
+
+    async def Answer_16777216_303(self):
+        pass
+
+    async def Answer_16777217_306(self):
+        pass
+
+    async def Answer_16777217_307(self):
+        pass
+
+    async def Answer_16777252_324(self):
+        pass
+
+    async def Answer_16777291_8388622(self):
+        pass
+
+    async def Answer_16777236_265(self):
+        pass
+
+    async def Answer_16777236_275(self):
+        pass
+
+    async def Answer_16777236_274(self):
+        pass
+
+    async def Answer_16777238_258(self):
+        pass
\ No newline at end of file
diff --git a/lib/logtool.py b/lib/logtool.py
index 7fff24b..8506113 100644
--- a/lib/logtool.py
+++ b/lib/logtool.py
@@ -1,244 +1,98 @@
 import logging
 import logging.handlers as handlers
-import os
-import sys
-import inspect
+import os, sys, time
+from datetime import datetime
 sys.path.append(os.path.realpath('../'))
-import yaml
-from datetime import datetime as log_dt
-with open("config.yaml", 'r') as stream:
-    yaml_config = (yaml.safe_load(stream))
-
-import json
-import pickle
-
-from prometheus_client import Counter, Gauge, Histogram, Summary
-
-from prometheus_client import start_http_server
-
-if yaml_config['prometheus']['enabled'] == True:
-    #Check if this is the HSS service, and if it's not increment the port before starting
-    print(sys.argv[0])
-    if 'hss.py' in str(sys.argv[0]):
-        print("Starting Prometheus on port from config " + str(yaml_config['prometheus']['port']))
-    else:
-        print("This is not the HSS stack so offsetting Prometheus port")
-        yaml_config['prometheus']['port'] += 1
-    try:
-        start_http_server(yaml_config['prometheus']['port'])
-        print("Started Prometheus on port " + str(yaml_config['prometheus']['port']))
-    except Exception as E:
-        print("Error loading Prometheus")
-        print(E)
-
-
-tags = ['diameter_application_id', 'diameter_cmd_code', 'endpoint', 'type']
-prom_diam_request_count = Counter('prom_diam_request_count', 'Number of Diameter Requests', tags)
-prom_diam_response_count_successful = Counter('prom_diam_response_count_successful', 'Number of Successful Diameter Responses', tags)
-prom_diam_response_count_fail = Counter('prom_diam_response_count_fail', 'Number of Failed Diameter Responses', tags)
-prom_diam_connected_peers = Gauge('prom_diam_connected_peers', 'Connected Diameter Peer Count', ['endpoint'])
-prom_diam_connected_peers._metrics.clear()
-prom_diam_response_time_diam = Histogram('prom_diam_response_time_diam', 'Diameter Response Times')
-prom_diam_response_time_method = Histogram('prom_diam_response_time_method', 'Diameter Response Times', tags)
-prom_diam_response_time_db = Summary('prom_diam_response_time_db', 'Diameter Response Times from Database')
-prom_diam_response_time_h = Histogram('request_latency_seconds', 'Diameter Response Time Histogram')
-prom_diam_auth_event_count = Counter('prom_diam_auth_event_count', 'Diameter Authentication related Counters', ['diameter_application_id', 'diameter_cmd_code', 'event', 'imsi_prefix'])
-prom_diam_eir_event_count = Counter('prom_diam_eir_event_count', 'Diameter EIR event related Counters', ['response'])
-
-prom_eir_devices = Counter('prom_eir_devices', 'Profile of attached devices', ['imei_prefix', 'device_type', 'device_name'])
-
-prom_http_geored = Counter('prom_http_geored', 'Number of Geored Pushes', ['geored_host', 'endpoint', 'http_response_code', 'error'])
-prom_flask_http_geored_endpoints = Counter('prom_flask_http_geored_endpoints', 'Number of Geored Pushes Received', ['geored_host', 'endpoint'])
-
-prom_diam_result_code = Counter('prom_diam_result_code', 'Prometheus Result Codes', ['result_code', 'diameter_application_id', 'diameter_cmd_code', 'endpoint', 'imsi'])
-
-prom_pcrf_subs = Gauge('prom_pcrf_subs', 'Number of attached PCRF Subscribers')
-prom_mme_subs = Gauge('prom_mme_subs', 'Number of attached MME Subscribers')
-prom_ims_subs = Gauge('prom_ims_subs', 'Number of attached IMS Subscribers')
+class TimestampFilter (logging.Filter):
+    """
+    Logging filter which checks for a `timestamp` attribute on a
+    given LogRecord and, if present, overrides the LogRecord creation time with it.
+    Expects a UNIX epoch timestamp, as returned by time.time().
+ """ + + def filter(self, record): + if hasattr(record, 'timestamp'): + record.created = record.timestamp + return True class LogTool: - def __init__(self, **kwargs): - print("Instantiating LogTool with Kwargs " + str(kwargs.items())) - if yaml_config['redis']['enabled'] == True: - print("Redis support enabled") - import redis - redis_store = redis.Redis(host=str(yaml_config['redis']['host']), port=str(yaml_config['redis']['port']), db=0) - self.redis_store = redis_store - try: - if "HSS_Init" in kwargs: - print("Called Init for HSS_Init") - redis_store.incr('restart_count') - if yaml_config['redis']['clear_stats_on_boot'] == True: - logging.debug("Clearing ActivePeerDict") - redis_store.delete('ActivePeerDict') - else: - logging.debug("Leaving prexisting Redis keys") - #Clear ActivePeerDict - redis_store.delete('ActivePeerDict') - - #Clear Async Keys - for key in redis_store.scan_iter("*_request_queue"): - print("Deleting Key: " + str(key)) - redis_store.delete(key) - logging.info("Connected to Redis server") - else: - logging.info("Init of Logtool but not from HSS_Init") - except: - logging.error("Failed to connect to Redis server - Disabling") - yaml_config['redis']['enabled'] == False - - #function for handling incrimenting Redis counters with error handling - def RedisIncrimenter(self, name): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.incr(name) - except: - logging.error("failed to incriment " + str(name)) - - def RedisStore(self, key, value): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.set(key, value) - except: - logging.error("failed to set Redis key " + str(key) + " to value " + str(value)) - - def RedisGet(self, key): - if yaml_config['redis']['enabled'] == True: - try: - return self.redis_store.get(key) - except: - logging.error("failed to set Redis key " + str(key)) - - def RedisHMSET(self, key, value_dict): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.hmset(key, value_dict) - except: - logging.error("failed to set hm Redis key " + str(key) + " to value " + str(value_dict)) - - def Async_SendRequest(self, request, DiameterHostname): - if yaml_config['redis']['enabled'] == True: - try: - import time - print("Writing request to Queue '" + str(DiameterHostname) + "_request_queue'") - self.redis_store.hset(str(DiameterHostname) + "_request_queue", "hss_Async_client_" + str(int(time.time())), request) - print("Written to Queue to send.") - except Exception as E: - logging.error("failed to run Async_SendRequest to " + str(DiameterHostname)) + """ + Reusable logging class, providing both asynchronous and synchronous logging functions. 
+ """ + def __init__(self, config: dict): + self.logLevels = { + 'CRITICAL': {'verbosity': 1, 'logging': logging.CRITICAL}, + 'ERROR': {'verbosity': 2, 'logging': logging.ERROR}, + 'WARNING': {'verbosity': 3, 'logging': logging.WARNING}, + 'INFO': {'verbosity': 4, 'logging': logging.INFO}, + 'DEBUG': {'verbosity': 5, 'logging': logging.DEBUG}, + 'NOTSET': {'verbosity': 6, 'logging': logging.NOTSET}, + } + self.logLevel = config.get('logging', {}).get('level', 'INFO') + + self.redisUseUnixSocket = config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = config.get('redis', {}).get('host', 'localhost') + self.redisPort = config.get('redis', {}).get('port', 6379) + + self.redisMessagingAsync = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) - def RedisHMGET(self, key): - if yaml_config['redis']['enabled'] == True: - try: - logging.debug("Getting HM Get from " + str(key)) - data = self.redis_store.hgetall(key) - logging.debug("Result: " + str(data)) - return data - except: - logging.error("failed to get hm Redis key " + str(key)) - - def RedisHDEL(self, key, item): - if yaml_config['redis']['enabled'] == True: - try: - logging.debug("Removing item " + str(item) + " from key " + str(key)) - self.redis_store.hdel(key, item) - except: - logging.error("failed to hdel Redis key " + str(key) + " item " + str(item)) - - def RedisStoreDict(self, key, value): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.set(str(key), pickle.dumps(value)) - except: - logging.error("failed to set Redis dict " + str(key) + " to value " + str(value)) - - def RedisGetDict(self, key): - if yaml_config['redis']['enabled'] == True: - try: - read_dict = self.redis_store.get(key) - return pickle.loads(read_dict) - except: - logging.error("failed to hmget Redis key " + str(key)) - - def GetDiameterPeers(self): - if yaml_config['redis']['enabled'] == True: - try: - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - return ActivePeerDict - except: - logging.error("Failed to get ActivePeerDict") - + async def logAsync(self, service: str, level: str, message: str, redisClient=None) -> bool: + """ + Tests loglevel, prints to console and queues a log message to an asynchronous redis messaging client. 
+ """ + if redisClient == None: + redisClient = self.redisMessagingAsync + configLogLevelVerbosity = self.logLevels.get(self.logLevel.upper(), {}).get('verbosity', 4) + messageLogLevelVerbosity = self.logLevels.get(level.upper(), {}).get('verbosity', 4) + if not messageLogLevelVerbosity <= configLogLevelVerbosity: + return False + timestamp = time.time() + dateTimeString = datetime.fromtimestamp(timestamp).strftime("%m/%d/%Y %H:%M:%S %Z").strip() + print(f"[{dateTimeString}] [{level.upper()}] {message}") + await(redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60)) + return True - def Manage_Diameter_Peer(self, peername, ip, action): - if yaml_config['redis']['enabled'] == True: - try: - logging.debug("Managing Diameter peer to Redis with hostname" + str(peername) + " and IP " + str(ip)) - now = log_dt.now() - timestamp = str(now.strftime("%Y-%m-%d %H:%M:%S")) - - #Try and get IP and Port seperately - try: - ip = ip[0] - port = ip[1] - except: - pass - - if self.redis_store.exists('ActivePeerDict') == False: - #Initialise empty active peer dict in Redis - logging.debug("Populated new empty ActivePeerDict Redis key") - ActivePeerDict = {} - ActivePeerDict['internal_connection'] = {"connect_timestamp" : timestamp} - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - - if action == "add": - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - logging.debug("ActivePeerDict back from Redis" + str(ActivePeerDict) + " to add peer " + str(peername) + " with ip " + str(ip)) - - - #If key has already existed in dict due to disconnect / reconnect, get reconnection count - try: - reconnection_count = ActivePeerDict[str(ip)]['reconnection_count'] + 1 - except: - reconnection_count = 0 - - ActivePeerDict[str(ip)] = {"connect_timestamp" : timestamp, \ - "recv_ip_address" : str(ip), "DiameterHostname" : "Unknown - Socket connection only", \ - "reconnection_count" : reconnection_count, - "connection_status" : "Pending"} - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - - if action == "remove": - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - logging.debug("ActivePeerDict back from Redis" + str(ActivePeerDict)) - ActivePeerDict[str(ip)] = {"disconnect_timestamp" : str(timestamp), \ - "DiameterHostname" : str(ActivePeerDict[str(ip)]['DiameterHostname']), \ - "reconnection_count" : ActivePeerDict[str(ip)]['reconnection_count'], - "connection_status" : "Disconnected"} - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - - if action == "update": - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - ActivePeerDict[str(ip)]['DiameterHostname'] = str(peername) - ActivePeerDict[str(ip)]['last_dwr_timestamp'] = str(timestamp) - ActivePeerDict[str(ip)]['connection_status'] = "Connected" - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - except Exception as E: - logging.error("failed to add/update/remove Diameter peer from Redis") - logging.error(E) - - - def setup_logger(self, logger_name, log_file, level=logging.DEBUG): - l = logging.getLogger(logger_name) - formatter = logging.Formatter('%(asctime)s \t %(levelname)s \t {%(pathname)s:%(lineno)d} \t %(message)s') - fileHandler = logging.FileHandler(log_file, mode='a+') - fileHandler.setFormatter(formatter) - streamHandler = logging.StreamHandler() - streamHandler.setFormatter(formatter) - rolloverHandler = handlers.RotatingFileHandler(log_file, 
-    def Manage_Diameter_Peer(self, peername, ip, action):
-        if yaml_config['redis']['enabled'] == True:
-            try:
-                logging.debug("Managing Diameter peer to Redis with hostname" + str(peername) + " and IP " + str(ip))
-                now = log_dt.now()
-                timestamp = str(now.strftime("%Y-%m-%d %H:%M:%S"))
-
-                #Try and get IP and Port seperately
-                try:
-                    ip = ip[0]
-                    port = ip[1]
-                except:
-                    pass
-
-                if self.redis_store.exists('ActivePeerDict') == False:
-                    #Initialise empty active peer dict in Redis
-                    logging.debug("Populated new empty ActivePeerDict Redis key")
-                    ActivePeerDict = {}
-                    ActivePeerDict['internal_connection'] = {"connect_timestamp" : timestamp}
-                    self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict))
-
-                if action == "add":
-                    data = self.RedisGet('ActivePeerDict')
-                    ActivePeerDict = json.loads(data)
-                    logging.debug("ActivePeerDict back from Redis" + str(ActivePeerDict) + " to add peer " + str(peername) + " with ip " + str(ip))
-
-                    #If key has already existed in dict due to disconnect / reconnect, get reconnection count
-                    try:
-                        reconnection_count = ActivePeerDict[str(ip)]['reconnection_count'] + 1
-                    except:
-                        reconnection_count = 0
-
-                    ActivePeerDict[str(ip)] = {"connect_timestamp" : timestamp, \
-                        "recv_ip_address" : str(ip), "DiameterHostname" : "Unknown - Socket connection only", \
-                        "reconnection_count" : reconnection_count,
-                        "connection_status" : "Pending"}
-                    self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict))
-
-                if action == "remove":
-                    data = self.RedisGet('ActivePeerDict')
-                    ActivePeerDict = json.loads(data)
-                    logging.debug("ActivePeerDict back from Redis" + str(ActivePeerDict))
-                    ActivePeerDict[str(ip)] = {"disconnect_timestamp" : str(timestamp), \
-                        "DiameterHostname" : str(ActivePeerDict[str(ip)]['DiameterHostname']), \
-                        "reconnection_count" : ActivePeerDict[str(ip)]['reconnection_count'],
-                        "connection_status" : "Disconnected"}
-                    self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict))
-
-                if action == "update":
-                    data = self.RedisGet('ActivePeerDict')
-                    ActivePeerDict = json.loads(data)
-                    ActivePeerDict[str(ip)]['DiameterHostname'] = str(peername)
-                    ActivePeerDict[str(ip)]['last_dwr_timestamp'] = str(timestamp)
-                    ActivePeerDict[str(ip)]['connection_status'] = "Connected"
-                    self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict))
-            except Exception as E:
-                logging.error("failed to add/update/remove Diameter peer from Redis")
-                logging.error(E)
-
-    def setup_logger(self, logger_name, log_file, level=logging.DEBUG):
-        l = logging.getLogger(logger_name)
-        formatter = logging.Formatter('%(asctime)s \t %(levelname)s \t {%(pathname)s:%(lineno)d} \t %(message)s')
-        fileHandler = logging.FileHandler(log_file, mode='a+')
-        fileHandler.setFormatter(formatter)
-        streamHandler = logging.StreamHandler()
-        streamHandler.setFormatter(formatter)
-        rolloverHandler = handlers.RotatingFileHandler(log_file, maxBytes=50000000, backupCount=5)
-        l.setLevel(level)
-        l.addHandler(fileHandler)
-        l.addHandler(streamHandler)
-        l.addHandler(rolloverHandler)
+    def log(self, service: str, level: str, message: str, redisClient=None) -> bool:
+        """
+        Tests loglevel, prints to console and queues a log message to a synchronous redis messaging client.
+        """
+        if redisClient is None:
+            redisClient = self.redisMessaging
+        configLogLevelVerbosity = self.logLevels.get(self.logLevel.upper(), {}).get('verbosity', 4)
+        messageLogLevelVerbosity = self.logLevels.get(level.upper(), {}).get('verbosity', 4)
+        if not messageLogLevelVerbosity <= configLogLevelVerbosity:
+            return False
+        timestamp = time.time()
+        dateTimeString = datetime.fromtimestamp(timestamp).strftime("%m/%d/%Y %H:%M:%S %Z").strip()
+        print(f"[{dateTimeString}] [{level.upper()}] {message}")
+        redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60)
+        return True
+
+    def setupFileLogger(self, loggerName: str, logFilePath: str):
+        """
+        Sets up and returns a file logger, given a loggerName and logFilePath.
+        Defaults to {pyhssRootDir}/log/{logFileName} if the configured file location is not writable.
+        """
+        try:
+            rolloverHandler = handlers.RotatingFileHandler(logFilePath, maxBytes=50000000, backupCount=5)
+        except PermissionError:
+            logFileName = logFilePath.split('/')[-1]
+            pyhssRootDir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
+            print(f"[LogTool] Warning - Unable to write to {logFilePath}, using {pyhssRootDir}/log/{logFileName} instead.")
+            logFilePath = f"{pyhssRootDir}/log/{logFileName}"
+            rolloverHandler = handlers.RotatingFileHandler(logFilePath, maxBytes=50000000, backupCount=5)
+        fileLogger = logging.getLogger(loggerName)
+        formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s {%(pathname)s:%(lineno)d} %(message)s", datefmt="%m/%d/%Y %H:%M:%S %Z")
+        filter = TimestampFilter()
+        fileLogger.addFilter(filter)
+        rolloverHandler.setFormatter(formatter)
+        fileLogger.addHandler(rolloverHandler)
+        fileLogger.setLevel(logging.DEBUG)
+        return fileLogger
\ No newline at end of file
diff --git a/lib/messaging.py b/lib/messaging.py
new file mode 100644
index 0000000..7b376a3
--- /dev/null
+++ b/lib/messaging.py
@@ -0,0 +1,190 @@
+from redis import Redis
+import time, json, uuid, traceback
+
+class RedisMessaging:
+    """
+    PyHSS Redis Message Service
+    A class for sending and receiving redis messages.
+    """
+
+    def __init__(self, host: str='localhost', port: int=6379, useUnixSocket: bool=False, unixSocketPath: str='/var/run/redis/redis-server.sock'):
+        if useUnixSocket:
+            self.redisClient = Redis(unix_socket_path=unixSocketPath)
+        else:
+            self.redisClient = Redis(host=host, port=port)
+
+    def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str:
+        """
+        Stores a message in a given Queue (Key) and sets an expiry (in seconds) if provided.
+        """
+        try:
+            self.redisClient.rpush(queue, message)
+            if queueExpiry is not None:
+                self.redisClient.expire(queue, queueExpiry)
+            return f'{message} stored in {queue} successfully.'
+        except Exception as e:
+            return ''
+ """ + if not isinstance(metricValue, (int, float)): + return 'Invalid Argument: metricValue must be a digit' + metricValue = float(metricValue) + prometheusMetricBody = json.dumps([{ + 'serviceName': serviceName, + 'timestamp': metricTimestamp, + 'NAME': metricName, + 'TYPE': metricType, + 'HELP': metricHelp, + 'LABELS': metricLabels, + 'ACTION': metricAction, + 'VALUE': metricValue, + } + ]) + + metricQueueName = f"metric" + + try: + self.redisClient.rpush(metricQueueName, prometheusMetricBody) + if metricExpiry is not None: + self.redisClient.expire(metricQueueName, metricExpiry) + return f'Succesfully stored metric called: {metricName}, with value of: {metricType}' + except Exception as e: + return '' + + def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None) -> str: + """ + Stores a message in a given Queue (Key). + """ + try: + logQueueName = f"log" + logMessage = json.dumps({"message": message, "service": serviceName, "level": logLevel, "timestamp": logTimestamp}) + self.redisClient.rpush(logQueueName, logMessage) + if logExpiry is not None: + self.redisClient.expire(logQueueName, logExpiry) + return f'{message} stored in {logQueueName} successfully.' + except Exception as e: + return '' + + def getMessage(self, queue: str) -> str: + """ + Gets the oldest message from a given Queue (Key), while removing it from the key as well. Deletes the key if the last message is being removed. + """ + try: + message = self.redisClient.lpop(queue) + if message is None: + message = '' + else: + try: + message = message.decode() + except (UnicodeDecodeError, AttributeError): + pass + return message + except Exception as e: + return '' + + def getQueues(self, pattern: str='*') -> list: + """ + Returns all Queues (Keys) in the database. + """ + try: + allQueues = self.redisClient.scan_iter(match=pattern) + return [x.decode() for x in allQueues] + except Exception as e: + return f"{traceback.format_exc()}" + + def getNextQueue(self, pattern: str='*') -> dict: + """ + Returns the next Queue (Key) in the list. + """ + try: + for nextQueue in self.redisClient.scan_iter(match=pattern): + return nextQueue.decode() + except Exception as e: + return {} + + def awaitMessage(self, key: str): + """ + Blocks until a message is received at the given key, then returns the message. + """ + try: + message = self.redisClient.blpop(key) + return tuple(data.decode() for data in message) + except Exception as e: + return '' + + def awaitBulkMessage(self, key: str, count: int=100): + """ + Blocks until one or more messages are received at the given key, then returns the amount of messages specified by count. + """ + try: + message = self.redisClient.blmpop(0, 1, key, direction='RIGHT', count=count) + return message + except Exception as e: + print(traceback.format_exc()) + return '' + + def deleteQueue(self, queue: str) -> bool: + """ + Deletes the given Queue (Key) + """ + try: + self.redisClient.delete(queue) + return True + except Exception as e: + return False + + def setValue(self, key: str, value: str, keyExpiry: int=None) -> str: + """ + Stores a value under a given key and sets an expiry (in seconds) if provided. + """ + try: + self.redisClient.set(key, value) + if keyExpiry is not None: + self.redisClient.expire(key, keyExpiry) + return f'{value} stored in {key} successfully.' + except Exception as e: + return '' + + def getValue(self, key: str) -> str: + """ + Gets the value stored under a given key. 
+ """ + try: + message = self.redisClient.get(key) + if message is None: + message = '' + else: + return message + except Exception as e: + return '' + + def getList(self, key: str) -> list: + """ + Gets the list stored under a given key. + """ + try: + allResults = self.redisClient.lrange(key, 0, -1) + if allResults is None: + result = [] + else: + return [result.decode() for result in allResults] + except Exception as e: + return [] + + def RedisHGetAll(self, key: str): + """ + Wrapper for Redis HGETALL + *Deprecated: will be removed upon completed database cleanup. + """ + try: + data = self.redisClient.hgetall(key) + return data + except Exception as e: + return '' + +if __name__ == '__main__': + redisMessaging = RedisMessaging() + print(redisMessaging.getNextQueue()) \ No newline at end of file diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py new file mode 100644 index 0000000..6c33e0a --- /dev/null +++ b/lib/messagingAsync.py @@ -0,0 +1,193 @@ +import asyncio +import redis.asyncio as redis +import time, json, uuid + +class RedisMessagingAsync: + """ + PyHSS Redis Asynchronous Message Service + A class for sending and receiving redis messages asynchronously. + """ + + def __init__(self, host: str='localhost', port: int=6379, useUnixSocket: bool=False, unixSocketPath: str='/var/run/redis/redis-server.sock'): + if useUnixSocket: + self.redisClient = redis.Redis(unix_socket_path=unixSocketPath) + else: + self.redisClient = redis.Redis(host=host, port=port) + pass + + async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: + """ + Stores a message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided. + """ + try: + await(self.redisClient.rpush(queue, message)) + if queueExpiry is not None: + await(self.redisClient.expire(queue, queueExpiry)) + return f'{message} stored in {queue} successfully.' + except Exception as e: + return '' + + async def sendBulkMessage(self, queue: str, messageList: list, queueExpiry: int=None) -> str: + """ + Empties a given asyncio queue into a redis pipeline, then sends to redis. + """ + try: + redisPipe = self.redisClient.pipeline() + + for message in messageList: + redisPipe.rpush(queue, message) + if queueExpiry is not None: + redisPipe.expire(queue, queueExpiry) + + await(redisPipe.execute()) + + return f'Messages stored in {queue} successfully.' + + except Exception as e: + return '' + + async def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: list=[], metricTimestamp: int=time.time_ns(), metricExpiry: int=None) -> str: + """ + Stores a prometheus metric in a format readable by the metric service, asynchronously. 
+ """ + if not isinstance(metricValue, (int, float)): + return 'Invalid Argument: metricValue must be a digit' + metricValue = float(metricValue) + prometheusMetricBody = json.dumps([{ + 'serviceName': serviceName, + 'timestamp': metricTimestamp, + 'NAME': metricName, + 'TYPE': metricType, + 'HELP': metricHelp, + 'LABELS': metricLabels, + 'ACTION': metricAction, + 'VALUE': metricValue, + } + ]) + + metricQueueName = f"metric" + + try: + async with self.redisClient.pipeline(transaction=True) as redisPipe: + await(redisPipe.rpush(metricQueueName, prometheusMetricBody).execute()) + if metricExpiry is not None: + await(redisPipe.expire(metricQueueName, metricExpiry).execute()) + sendMetricResult, expireKeyResult = await redisPipe.execute() + return f'Succesfully stored metric called: {metricName}, with value of: {metricType}' + except Exception as e: + return '' + + async def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None) -> str: + """ + Stores a log message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided. + """ + try: + logQueueName = f"log" + logMessage = json.dumps({"message": message, "service": serviceName, "level": logLevel, "timestamp": logTimestamp}) + async with self.redisClient.pipeline(transaction=True) as redisPipe: + await redisPipe.rpush(logQueueName, logMessage) + if logExpiry is not None: + await redisPipe.expire(logQueueName, logExpiry) + sendMessageResult, expireKeyResult = await redisPipe.execute() + return f'{message} stored in {logQueueName} successfully.' + except Exception as e: + return '' + + async def getMessage(self, queue: str) -> str: + """ + Gets the oldest message from a given Queue (Key) asynchronously, while removing it from the key as well. Deletes the key if the last message is being removed. + """ + try: + message = await(self.redisClient.lpop(queue)) + if message is None: + message = '' + else: + try: + if message[0] is None: + return '' + else: + message = message[0].decode() + except (UnicodeDecodeError, AttributeError): + pass + return message + except Exception as e: + return '' + + async def getQueues(self, pattern: str='*') -> list: + """ + Returns all Queues (Keys) in the database, asynchronously. + """ + try: + allQueuesBinary = await(self.redisClient.scan_iter(match=pattern)) + allQueues = [x.decode() for x in allQueuesBinary] + return allQueues + except Exception as e: + return [] + + async def getNextQueue(self, pattern: str='*') -> str: + """ + Returns the next Queue (Key) in the list, asynchronously. + """ + try: + async for nextQueue in self.redisClient.scan_iter(match=pattern): + if nextQueue is not None: + return nextQueue.decode('utf-8') + except Exception as e: + print(e) + return '' + + async def awaitMessage(self, key: str): + """ + Asynchronously blocks until a message is received at the given key, then returns the message. + """ + try: + message = (await(self.redisClient.blpop(key))) + return tuple(data.decode() for data in message) + except Exception as e: + return '' + + async def deleteQueue(self, queue: str) -> bool: + """ + Deletes the given Queue (Key) asynchronously. + """ + try: + deleteQueueResult = await(self.redisClient.delete(queue)) + return True + except Exception as e: + return False + + async def setValue(self, key: str, value: str, keyExpiry: int=None) -> str: + """ + Stores a value under a given key asynchronously and sets an expiry (in seconds) if provided. 
+ """ + try: + async with self.redisClient.pipeline(transaction=True) as redisPipe: + await redisPipe.set(key, value) + if keyExpiry is not None: + await redisPipe.expire(key, value) + setValueResult, expireValueResult = await redisPipe.execute() + return f'{value} stored in {key} successfully.' + except Exception as e: + return '' + + async def getValue(self, key: str) -> str: + """ + Gets the value stored under a given key asynchronously. + """ + try: + message = await(self.redisClient.get(key)) + if message is None: + message = '' + else: + return message + except Exception as e: + return '' + + async def closeConnection(self) -> bool: + await self.redisClient.close() + return True + + +if __name__ == '__main__': + redisMessaging = RedisMessagingAsync() + print(redisMessaging.getNextQueue()) \ No newline at end of file diff --git a/lib/metrics.py b/lib/metrics.py new file mode 100644 index 0000000..3db7918 --- /dev/null +++ b/lib/metrics.py @@ -0,0 +1,41 @@ +class Metrics: + + def __init__(self, redisMessaging): + self.redisMessaging = redisMessaging + + def initializeMetrics(self) -> bool: + """ + Preloads all metrics, and sets their initial value to 0. + """ + + print("Initializing Metrics") + + metricList = [ + {'serviceName':'api', 'metricName':'prom_flask_http_geored_endpoints', 'metricType':'counter', 'metricHelp':'Number of Geored Pushes Received'}, + {'serviceName':'diameter', 'metricName':'prom_diam_inbound_count', 'metricType':'counter', 'metricHelp':'Number of Diameter Inbounds'}, + {'serviceName':'geored', 'metricName':'prom_http_geored', 'metricType':'counter', 'metricHelp':'Number of Geored Pushes'}, + {'serviceName':'webhook', 'metricName':'prom_http_webhook', 'metricType':'counter', 'metricHelp':'Number of Webhook Pushes'}, + {'serviceName':'database', 'metricName':'prom_eir_devices', 'metricType':'counter', 'metricHelp':'Profile of attached devices'}, + {'serviceName':'diameter', 'metricName':'prom_ims_subs', 'metricType':'gauge', 'metricHelp':'Number of attached IMS Subscribers'}, + {'serviceName':'diameter', 'metricName':'prom_mme_subs', 'metricType':'gauge', 'metricHelp':'Number of attached MME Subscribers'}, + {'serviceName':'diameter', 'metricName':'prom_pcrf_subs', 'metricType':'gauge', 'metricHelp':'Number of attached PCRF Subscribers'}, + {'serviceName':'diameter', 'metricName':'prom_diam_auth_event_count', 'metricType':'counter', 'metricHelp':'Diameter Authentication related Counters'}, + {'serviceName':'diameter', 'metricName':'prom_diam_response_count_successful', 'metricType':'counter', 'metricHelp':'Number of Successful Diameter Responses'}, + {'serviceName':'diameter', 'metricName':'prom_diam_response_count_fail', 'metricType':'counter', 'metricHelp':'Number of Failed Diameter Responses'} + ] + + for metric in metricList: + try: + self.redisMessaging.sendMetric(serviceName=metric['serviceName'], + metricName=metric['metricName'], + metricType=metric['metricType'], + metricAction='inc', + metricValue=0.0, + metricHelp=metric['metricHelp'], + metricLabels=metric['metricLabels'], + metricExpiry=60) + except Exception as e: + print(e) + pass + + return True \ No newline at end of file diff --git a/lib/milenage.py b/lib/milenage.py index 6175920..ef31daa 100644 --- a/lib/milenage.py +++ b/lib/milenage.py @@ -14,8 +14,6 @@ from lte import BaseLTEAuthAlgo import logging -import logtool -logtool = logtool.LogTool() import os import sys sys.path.append(os.path.realpath('../')) diff --git a/log/.gitkeep b/log/.gitkeep new file mode 100644 index 0000000..e69de29 diff 
--git a/requirements.txt b/requirements.txt index 548af43..3fecc47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -construct==2.10.68 +aiohttp==3.8.5 Flask==2.2.3 flask_restx==1.1.0 Jinja2==3.1.2 @@ -9,11 +9,10 @@ pymongo==4.3.3 pysctp==0.7.2 pysnmp==4.4.12 PyYAML==6.0 -redis==4.5.4 -Requests==2.28.2 +redis==5.0.0 +Requests==2.31.0 SQLAlchemy==2.0.9 -sqlalchemy_utils -systemd-python==234 +SQLAlchemy_Utils==0.41.1 Werkzeug==2.2.3 mysqlclient prometheus_flask_exporter \ No newline at end of file diff --git a/PyHSS_API.py b/services/apiService.py similarity index 61% rename from PyHSS_API.py rename to services/apiService.py index 0e74941..96f0710 100644 --- a/PyHSS_API.py +++ b/services/apiService.py @@ -4,33 +4,54 @@ from flask_restx import Api, Resource, fields, reqparse, abort from werkzeug.middleware.proxy_fix import ProxyFix from functools import wraps -import datetime +import os, sys +sys.path.append(os.path.realpath('../lib')) +import time +import requests import traceback import sqlalchemy import socket +from logtool import LogTool +from diameter import Diameter +from messaging import RedisMessaging +import database +import yaml +with open("../config.yaml", 'r') as stream: + config = (yaml.safe_load(stream)) -import logging -import yaml +siteName = config.get("hss", {}).get("site_name", "") +originHostname = socket.gethostname() +lockProvisioning = config.get('hss', {}).get('lock_provisioning', False) +provisioningKey = config.get('hss', {}).get('provisioning_key', '') +mnc = config.get('hss', {}).get('MNC', '999') +mcc = config.get('hss', {}).get('MCC', '999') +originRealm = config.get('hss', {}).get('OriginRealm', f'mnc{mnc}.mcc{mcc}.3gppnetwork.org') +originHost = config.get('hss', {}).get('OriginHost', 'hss01') +productName = config.get('hss', {}).get('ProductName', 'PyHSS') -with open("config.yaml", 'r') as stream: - yaml_config = (yaml.safe_load(stream)) +redisHost = config.get("redis", {}).get("host", "127.0.0.1") +redisPort = int(config.get("redis", {}).get("port", 6379)) +redisUseUnixSocket = config.get('redis', {}).get('useUnixSocket', False) +redisUnixSocketPath = config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') -import os -import sys -sys.path.append(os.path.realpath('lib')) +redisMessaging = RedisMessaging(host=redisHost, port=redisPort, useUnixSocket=redisUseUnixSocket, unixSocketPath=redisUnixSocketPath) -#Setup Logging -import logtool +logTool = LogTool(config) +diameterClient = Diameter( + redisMessaging=redisMessaging, + logTool=logTool, + originHost=originHost, + originRealm=originRealm, + mnc=mnc, + mcc=mcc, + productName='PyHSS-client-API' + ) -import database +databaseClient = database.Database(logTool=logTool, redisMessaging=redisMessaging) -from prometheus_flask_exporter import PrometheusMetrics -app = Flask(__name__) -metrics = PrometheusMetrics.for_app_factory() -metrics.init_app(app) -from logtool import prom_flask_http_geored_endpoints +apiService = Flask(__name__) APN = database.APN Serving_APN = database.SERVING_APN @@ -45,12 +66,8 @@ OPERATION_LOG = database.OPERATION_LOG_BASE SUBSCRIBER_ROUTING = database.SUBSCRIBER_ROUTING - -site_name = yaml_config.get("hss", {}).get("site_name", "") -origin_host_name = socket.gethostname() - -app.wsgi_app = ProxyFix(app.wsgi_app) +apiService.wsgi_app = ProxyFix(apiService.wsgi_app) +api = Api(apiService, version='1.0', title=f'{siteName + " - " if
siteName else ""}{originHostname} - PyHSS OAM API', description='Restful API for working with PyHSS', doc='/docs/' ) @@ -76,41 +93,40 @@ paginatorParser = reqparse.RequestParser() paginatorParser.add_argument('page', type=int, required=False, default=0, help='Page number for pagination') -paginatorParser.add_argument('page_size', type=int, required=False, default=yaml_config['api'].get('page_size', 100), help='Number of items per page for pagination') - +paginatorParser.add_argument('page_size', type=int, required=False, default=config['api'].get('page_size', 100), help='Number of items per page for pagination') APN_model = api.schema_model('APN JSON', - database.Generate_JSON_Model_for_Flask(APN) + databaseClient.Generate_JSON_Model_for_Flask(APN) ) Serving_APN_model = api.schema_model('Serving APN JSON', - database.Generate_JSON_Model_for_Flask(Serving_APN) + databaseClient.Generate_JSON_Model_for_Flask(Serving_APN) ) AUC_model = api.schema_model('AUC JSON', - database.Generate_JSON_Model_for_Flask(AUC) + databaseClient.Generate_JSON_Model_for_Flask(AUC) ) SUBSCRIBER_model = api.schema_model('SUBSCRIBER JSON', - database.Generate_JSON_Model_for_Flask(SUBSCRIBER) + databaseClient.Generate_JSON_Model_for_Flask(SUBSCRIBER) ) SUBSCRIBER_ROUTING_model = api.schema_model('SUBSCRIBER_ROUTING JSON', - database.Generate_JSON_Model_for_Flask(SUBSCRIBER_ROUTING) + databaseClient.Generate_JSON_Model_for_Flask(SUBSCRIBER_ROUTING) ) IMS_SUBSCRIBER_model = api.schema_model('IMS_SUBSCRIBER JSON', - database.Generate_JSON_Model_for_Flask(IMS_SUBSCRIBER) + databaseClient.Generate_JSON_Model_for_Flask(IMS_SUBSCRIBER) ) TFT_model = api.schema_model('TFT JSON', - database.Generate_JSON_Model_for_Flask(TFT) + databaseClient.Generate_JSON_Model_for_Flask(TFT) ) CHARGING_RULE_model = api.schema_model('CHARGING_RULE JSON', - database.Generate_JSON_Model_for_Flask(CHARGING_RULE) + databaseClient.Generate_JSON_Model_for_Flask(CHARGING_RULE) ) EIR_model = api.schema_model('EIR JSON', - database.Generate_JSON_Model_for_Flask(EIR) + databaseClient.Generate_JSON_Model_for_Flask(EIR) ) IMSI_IMEI_HISTORY_model = api.schema_model('IMSI_IMEI_HISTORY JSON', - database.Generate_JSON_Model_for_Flask(IMSI_IMEI_HISTORY) + databaseClient.Generate_JSON_Model_for_Flask(IMSI_IMEI_HISTORY) ) SUBSCRIBER_ATTRIBUTES_model = api.schema_model('SUBSCRIBER_ATTRIBUTES JSON', - database.Generate_JSON_Model_for_Flask(SUBSCRIBER_ATTRIBUTES) + databaseClient.Generate_JSON_Model_for_Flask(SUBSCRIBER_ATTRIBUTES) ) PCRF_Push_model = api.model('PCRF_Rule', { @@ -134,6 +150,11 @@ 'serving_mme_timestamp' : fields.String(description=SUBSCRIBER.serving_mme_timestamp.doc), 'serving_apn' : fields.String(description='Access Point Name of APN'), 'pcrf_session_id' : fields.String(description=Serving_APN.pcrf_session_id.doc), + 'pcscf' : fields.String(description=IMS_SUBSCRIBER.pcscf.doc), + 'pcscf_realm' : fields.String(description=IMS_SUBSCRIBER.pcscf_realm.doc), + 'pcscf_peer' : fields.String(description=IMS_SUBSCRIBER.pcscf_peer.doc), + 'pcscf_timestamp' : fields.String(description=IMS_SUBSCRIBER.pcscf_timestamp.doc), + 'pcscf_active_session' : fields.String(description=IMS_SUBSCRIBER.pcscf_active_session.doc), 'subscriber_routing' : fields.String(description=Serving_APN.subscriber_routing.doc), 'serving_pgw' : fields.String(description=Serving_APN.serving_pgw.doc), 'serving_pgw_realm' : fields.String(description=Serving_APN.serving_pgw_realm.doc), @@ -163,9 +184,6 @@ } -lock_provisioning = yaml_config.get('hss', {}).get('lock_provisioning', False) 
-provisioning_key = yaml_config.get('hss', {}).get('provisioning_key', '') - def no_auth_required(f): f.no_auth_required = True return f @@ -173,9 +191,9 @@ def no_auth_required(f): def auth_required(f): @wraps(f) def decorated_function(*args, **kwargs): - if getattr(f, 'no_auth_required', False) or (lock_provisioning == False): + if getattr(f, 'no_auth_required', False) or (lockProvisioning == False): return f(*args, **kwargs) - if 'Provisioning-Key' not in request.headers or request.headers['Provisioning-Key'] != yaml_config['hss']['provisioning_key']: + if 'Provisioning-Key' not in request.headers or request.headers['Provisioning-Key'] != config['hss']['provisioning_key']: return {'Result': 'Unauthorized - Provisioning-Key Invalid'}, 401 return f(*args, **kwargs) return decorated_function @@ -184,12 +202,12 @@ def auth_before_request(): if request.path.startswith('/docs') or request.path.startswith('/swagger') or request.path.startswith('/metrics'): return None if request.endpoint and 'static' not in request.endpoint: - view_function = app.view_functions[request.endpoint] + view_function = apiService.view_functions[request.endpoint] if hasattr(view_function, 'view_class'): view_class = view_function.view_class view_method = getattr(view_class, request.method.lower(), None) if view_method: - if(lock_provisioning == False): + if(lockProvisioning == False): return None if request.method == 'GET' and not getattr(view_method, 'auth_required', False): return None @@ -198,12 +216,13 @@ def auth_before_request(): else: return None - if 'Provisioning-Key' not in request.headers or request.headers['Provisioning-Key'] != yaml_config['hss']['provisioning_key']: + if 'Provisioning-Key' not in request.headers or request.headers['Provisioning-Key'] != config['hss']['provisioning_key']: return {'Result': 'Unauthorized - Provisioning-Key Invalid'}, 401 return None def handle_exception(e): - logging.error(f"An error occurred: {e}") + + logTool.log(service='API', level='error', message=f"[API] An error occurred: {e}", redisClient=redisMessaging) response_json = {'result': 'Failed'} if isinstance(e, sqlalchemy.exc.SQLAlchemyError): @@ -222,19 +241,18 @@ def handle_exception(e): return response_json, 410 else: response_json['reason'] = f'An internal server error occurred: {e}' - logging.error(f'{traceback.format_exc()}') - logging.error(f'{sys.exc_info()[2]}') + logTool.log(service='API', level='error', message=f"[API] Additional Error Information: {traceback.format_exc()}\n{sys.exc_info()[2]}", redisClient=redisMessaging) return response_json, 500 -app.before_request(auth_before_request) +apiService.before_request(auth_before_request) -@app.errorhandler(404) +@apiService.errorhandler(404) def page_not_found(e): return {"Result": "Not Found"}, 404 -@app.after_request +@apiService.after_request def apply_caching(response): - response.headers["HSS"] = str(yaml_config['hss']['OriginHost']) + response.headers["HSS"] = str(config['hss']['OriginHost']) return response @ns_apn.route('/<apn_id>') class PyHSS_APN_Get(Resource): def get(self, apn_id): '''Get all APN data for specified APN ID''' try: - apn_data = database.GetObj(APN, apn_id) + apn_data = databaseClient.GetObj(APN, apn_id) return apn_data, 200 except Exception as E: print(E) @@ -253,7 +271,7 @@ def delete(self, apn_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(APN, apn_id, False, operation_id) + data = databaseClient.DeleteObj(APN, apn_id, False, operation_id) return data, 200
except Exception as E: print(E) @@ -268,7 +286,7 @@ def patch(self, apn_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - apn_data = database.UpdateObj(APN, json_data, apn_id, False, operation_id) + apn_data = databaseClient.UpdateObj(APN, json_data, apn_id, False, operation_id) print("Updated object") print(apn_data) @@ -288,7 +306,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - apn_id = database.CreateObj(APN, json_data, False, operation_id) + apn_id = databaseClient.CreateObj(APN, json_data, False, operation_id) return apn_id, 200 except Exception as E: @@ -302,7 +320,7 @@ def get(self): '''Get all APNs''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(APN, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(APN, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -313,8 +331,8 @@ class PyHSS_AUC_Get(Resource): def get(self, auc_id): '''Get all AuC data for specified AuC ID''' try: - auc_data = database.GetObj(AUC, auc_id) - auc_data = database.Sanitize_Keys(auc_data) + auc_data = databaseClient.GetObj(AUC, auc_id) + auc_data = databaseClient.Sanitize_Keys(auc_data) return auc_data, 200 except Exception as E: print(E) @@ -325,7 +343,7 @@ def delete(self, auc_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(AUC, auc_id, False, operation_id) + data = databaseClient.DeleteObj(AUC, auc_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -340,8 +358,8 @@ def patch(self, auc_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - auc_data = database.UpdateObj(AUC, json_data, auc_id, False, operation_id) - auc_data = database.Sanitize_Keys(auc_data) + auc_data = databaseClient.UpdateObj(AUC, json_data, auc_id, False, operation_id) + auc_data = databaseClient.Sanitize_Keys(auc_data) print("Updated object") print(auc_data) @@ -355,8 +373,8 @@ class PyHSS_AUC_Get_ICCID(Resource): def get(self, iccid): '''Get all AuC data for specified ICCID''' try: - auc_data = database.Get_AuC(iccid=iccid) - auc_data = database.Sanitize_Keys(auc_data) + auc_data = databaseClient.Get_AuC(iccid=iccid) + auc_data = databaseClient.Sanitize_Keys(auc_data) return auc_data, 200 except Exception as E: print(E) @@ -367,8 +385,8 @@ class PyHSS_AUC_Get_IMSI(Resource): def get(self, imsi): '''Get all AuC data for specified IMSI''' try: - auc_data = database.Get_AuC(imsi=imsi) - auc_data = database.Sanitize_Keys(auc_data) + auc_data = databaseClient.Get_AuC(imsi=imsi) + auc_data = databaseClient.Sanitize_Keys(auc_data) return auc_data, 200 except Exception as E: print(E) @@ -385,7 +403,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(AUC, json_data, False, operation_id) + data = databaseClient.CreateObj(AUC, json_data, False, operation_id) return data, 200 except Exception as E: @@ -399,7 +417,7 @@ def get(self): '''Get all AuC Data (except keys)''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(AUC, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(AUC, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -410,7 +428,7 @@ class 
PyHSS_SUBSCRIBER_Get(Resource): def get(self, subscriber_id): '''Get all SUBSCRIBER data for specified subscriber_id''' try: - apn_data = database.GetObj(SUBSCRIBER, subscriber_id) + apn_data = databaseClient.GetObj(SUBSCRIBER, subscriber_id) return apn_data, 200 except Exception as E: print(E) @@ -421,7 +439,7 @@ def delete(self, subscriber_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(SUBSCRIBER, subscriber_id, False, operation_id) + data = databaseClient.DeleteObj(SUBSCRIBER, subscriber_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -436,7 +454,27 @@ def patch(self, subscriber_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(SUBSCRIBER, json_data, subscriber_id, False, operation_id) + data = databaseClient.UpdateObj(SUBSCRIBER, json_data, subscriber_id, False, operation_id) + + #If the "enabled" flag on the subscriber is now disabled, trigger a CLR + if 'enabled' in json_data and json_data['enabled'] == False: + print("Subscriber is now disabled, checking to see if we need to trigger a CLR") + #See if we have a serving MME set + try: + assert(json_data['serving_mme']) + print("Serving MME set - Sending CLR") + + diameterClient.sendDiameterRequest( + requestType='CLR', + hostname=json_data['serving_mme'], + imsi=json_data['imsi'], + DestinationHost=json_data['serving_mme'], + DestinationRealm=json_data['serving_mme_realm'], + CancellationType=1 + ) + print("Sent CLR via Peer " + str(json_data['serving_mme'])) + except: + print("No serving MME set - Not sending CLR") #If the "enabled" flag on the subscriber is now disabled, trigger a CLR @@ -484,7 +522,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(SUBSCRIBER, json_data, False, operation_id) + data = databaseClient.CreateObj(SUBSCRIBER, json_data, False, operation_id) return data, 200 except Exception as E: @@ -496,7 +534,7 @@ class PyHSS_SUBSCRIBER_IMSI(Resource): def get(self, imsi): '''Get data for IMSI''' try: - data = database.Get_Subscriber(imsi=imsi, get_attributes=True) + data = databaseClient.Get_Subscriber(imsi=imsi, get_attributes=True) return data, 200 except Exception as E: print(E) @@ -507,7 +545,7 @@ class PyHSS_SUBSCRIBER_MSISDN(Resource): def get(self, msisdn): '''Get data for MSISDN''' try: - data = database.Get_Subscriber(msisdn=msisdn, get_attributes=True) + data = databaseClient.Get_Subscriber(msisdn=msisdn, get_attributes=True) return data, 200 except Exception as E: print(E) @@ -520,7 +558,7 @@ def get(self): '''Get all Subscribers''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(SUBSCRIBER, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(SUBSCRIBER, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -537,7 +575,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(SUBSCRIBER_ROUTING, json_data, False, operation_id) + data = databaseClient.CreateObj(SUBSCRIBER_ROUTING, json_data, False, operation_id) return data, 200 except Exception as E: @@ -549,7 +587,7 @@ class PyHSS_SUBSCRIBER_SUBSCRIBER_ROUTING(Resource): def get(self, subscriber_id, apn_id): '''Get Subscriber Routing for specified subscriber_id & apn_id''' 
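+ # Routing entries are keyed on the (subscriber_id, apn_id) pair; the delete handler below resolves the same pair to a subscriber_routing_id before removing the record.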
try: - apn_data = database.Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id) + apn_data = databaseClient.Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id) return apn_data, 200 except Exception as E: print(E) @@ -560,8 +598,8 @@ def delete(self, subscriber_id, apn_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - apn_data = database.Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id) - data = database.DeleteObj(SUBSCRIBER_ROUTING, apn_data['subscriber_routing_id'], False, operation_id) + apn_data = databaseClient.Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id) + data = databaseClient.DeleteObj(SUBSCRIBER_ROUTING, apn_data['subscriber_routing_id'], False, operation_id) return data, 200 except Exception as E: print(E) @@ -578,7 +616,7 @@ def patch(self, subscriber_routing_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(SUBSCRIBER_ROUTING, json_data, subscriber_routing_id, False, operation_id) + data = databaseClient.UpdateObj(SUBSCRIBER_ROUTING, json_data, subscriber_routing_id, False, operation_id) print("Updated object") print(data) @@ -592,7 +630,7 @@ class PyHSS_IMS_SUBSCRIBER_Get(Resource): def get(self, ims_subscriber_id): '''Get all SUBSCRIBER data for specified ims_subscriber_id''' try: - apn_data = database.GetObj(IMS_SUBSCRIBER, ims_subscriber_id) + apn_data = databaseClient.GetObj(IMS_SUBSCRIBER, ims_subscriber_id) return apn_data, 200 except Exception as E: print(E) @@ -603,7 +641,7 @@ def delete(self, ims_subscriber_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(IMS_SUBSCRIBER, ims_subscriber_id, False, operation_id) + data = databaseClient.DeleteObj(IMS_SUBSCRIBER, ims_subscriber_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -618,7 +656,7 @@ def patch(self, ims_subscriber_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(IMS_SUBSCRIBER, json_data, ims_subscriber_id, False, operation_id) + data = databaseClient.UpdateObj(IMS_SUBSCRIBER, json_data, ims_subscriber_id, False, operation_id) print("Updated object") print(data) @@ -638,7 +676,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(IMS_SUBSCRIBER, json_data, False, operation_id) + data = databaseClient.CreateObj(IMS_SUBSCRIBER, json_data, False, operation_id) return data, 200 except Exception as E: @@ -650,7 +688,7 @@ class PyHSS_IMS_SUBSCRIBER_MSISDN(Resource): def get(self, msisdn): '''Get IMS data for MSISDN''' try: - data = database.Get_IMS_Subscriber(msisdn=msisdn) + data = databaseClient.Get_IMS_Subscriber(msisdn=msisdn) print("Got back: " + str(data)) return data, 200 except Exception as E: @@ -662,7 +700,7 @@ class PyHSS_IMS_SUBSCRIBER_IMSI(Resource): def get(self, imsi): '''Get IMS data for imsi''' try: - data = database.Get_IMS_Subscriber(imsi=imsi) + data = databaseClient.Get_IMS_Subscriber(imsi=imsi) print("Got back: " + str(data)) return data, 200 except Exception as E: @@ -676,7 +714,7 @@ def get(self): '''Get all IMS Subscribers''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(IMS_SUBSCRIBER, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(IMS_SUBSCRIBER, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) 
@@ -687,7 +725,7 @@ class PyHSS_TFT_Get(Resource): def get(self, tft_id): '''Get all TFT data for specified tft_id''' try: - apn_data = database.GetObj(TFT, tft_id) + apn_data = databaseClient.GetObj(TFT, tft_id) return apn_data, 200 except Exception as E: print(E) @@ -698,7 +736,7 @@ def delete(self, tft_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(TFT, tft_id, False, operation_id) + data = databaseClient.DeleteObj(TFT, tft_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -713,7 +751,7 @@ def patch(self, tft_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(TFT, json_data, tft_id, False, operation_id) + data = databaseClient.UpdateObj(TFT, json_data, tft_id, False, operation_id) print("Updated object") print(data) @@ -733,7 +771,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(TFT, json_data, False, operation_id) + data = databaseClient.CreateObj(TFT, json_data, False, operation_id) return data, 200 except Exception as E: @@ -747,7 +785,7 @@ def get(self): '''Get all TFTs''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(TFT, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(TFT, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -758,7 +796,7 @@ class PyHSS_Charging_Rule_Get(Resource): def get(self, charging_rule_id): '''Get all Charging Rule data for specified charging_rule_id''' try: - apn_data = database.GetObj(CHARGING_RULE, charging_rule_id) + apn_data = databaseClient.GetObj(CHARGING_RULE, charging_rule_id) return apn_data, 200 except Exception as E: print(E) @@ -769,7 +807,7 @@ def delete(self, charging_rule_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(CHARGING_RULE, charging_rule_id, False, operation_id) + data = databaseClient.DeleteObj(CHARGING_RULE, charging_rule_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -784,7 +822,7 @@ def patch(self, charging_rule_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(CHARGING_RULE, json_data, charging_rule_id, False, operation_id) + data = databaseClient.UpdateObj(CHARGING_RULE, json_data, charging_rule_id, False, operation_id) print("Updated object") print(data) @@ -804,7 +842,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(CHARGING_RULE, json_data, False, operation_id) + data = databaseClient.CreateObj(CHARGING_RULE, json_data, False, operation_id) return data, 200 except Exception as E: @@ -818,7 +856,7 @@ def get(self): '''Get all Charging Rules''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(CHARGING_RULE, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(CHARGING_RULE, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -829,7 +867,7 @@ class PyHSS_EIR_Get(Resource): def get(self, eir_id): '''Get all EIR data for specified eir_id''' try: - eir_data = database.GetObj(EIR, eir_id) + eir_data = databaseClient.GetObj(EIR, eir_id) return eir_data, 200 except 
Exception as E: print(E) @@ -840,7 +878,7 @@ def delete(self, eir_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(EIR, eir_id, False, operation_id) + data = databaseClient.DeleteObj(EIR, eir_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -855,7 +893,7 @@ def patch(self, eir_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(EIR, json_data, eir_id, False, operation_id) + data = databaseClient.UpdateObj(EIR, json_data, eir_id, False, operation_id) print("Updated object") print(data) @@ -875,7 +913,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(EIR, json_data, False, operation_id) + data = databaseClient.CreateObj(EIR, json_data, False, operation_id) return data, 200 except Exception as E: @@ -887,11 +925,11 @@ class PyHSS_EIR_HISTORY(Resource): def get(self, attribute): '''Get history for IMSI or IMEI''' try: - data = database.Get_IMEI_IMSI_History(attribute=attribute) + data = databaseClient.Get_IMEI_IMSI_History(attribute=attribute) #Add device info for each entry data_w_device_info = [] for record in data: - record['imei_result'] = database.get_device_info_from_TAC(imei=str(record['imei'])) + record['imei_result'] = databaseClient.get_device_info_from_TAC(imei=str(record['imei'])) data_w_device_info.append(record) return data_w_device_info, 200 except Exception as E: @@ -901,9 +939,9 @@ def get(self, attribute): def delete(self, attribute): '''Get Delete for IMSI or IMEI''' try: - data = database.Get_IMEI_IMSI_History(attribute=attribute) + data = databaseClient.Get_IMEI_IMSI_History(attribute=attribute) for record in data: - database.DeleteObj(IMSI_IMEI_HISTORY, record['imsi_imei_history_id']) + databaseClient.DeleteObj(IMSI_IMEI_HISTORY, record['imsi_imei_history_id']) return data, 200 except Exception as E: print(E) @@ -916,7 +954,7 @@ def get(self): '''Get EIR history for all subscribers''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(IMSI_IMEI_HISTORY, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(IMSI_IMEI_HISTORY, args['page'], args['page_size']) for record in data: record['imsi'] = record['imsi_imei'].split(',')[0] record['imei'] = record['imsi_imei'].split(',')[1] @@ -932,7 +970,7 @@ def get(self): '''Get all EIR Rules''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(EIR, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(EIR, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -943,13 +981,12 @@ class PyHSS_EIR_TAC(Resource): def get(self, imei): '''Get Device Info from IMEI''' try: - data = database.get_device_info_from_TAC(imei=imei) + data = databaseClient.get_device_info_from_TAC(imei=imei) return (data), 200 except Exception as E: print(E) return handle_exception(E) - @ns_subscriber_attributes.route('/list') class PyHSS_Subscriber_Attributes_All(Resource): @ns_subscriber_attributes.expect(paginatorParser) @@ -957,7 +994,7 @@ def get(self): '''Get all Subscriber Attributes''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(SUBSCRIBER_ATTRIBUTES, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(SUBSCRIBER_ATTRIBUTES, args['page'], args['page_size']) return (data), 200 except 
Exception as E: print(E) @@ -968,7 +1005,7 @@ class PyHSS_Attributes_Get(Resource): def get(self, subscriber_id): '''Get all attributes / values for specified Subscriber ID''' try: - apn_data = database.Get_Subscriber_Attributes(subscriber_id) + apn_data = databaseClient.Get_Subscriber_Attributes(subscriber_id) return apn_data, 200 except Exception as E: print(E) @@ -981,7 +1018,7 @@ def delete(self, subscriber_attributes_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(SUBSCRIBER_ATTRIBUTES, subscriber_attributes_id, False, operation_id) + data = databaseClient.DeleteObj(SUBSCRIBER_ATTRIBUTES, subscriber_attributes_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -996,7 +1033,7 @@ def patch(self, subscriber_attributes_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(SUBSCRIBER_ATTRIBUTES, json_data, subscriber_attributes_id, False, operation_id) + data = databaseClient.UpdateObj(SUBSCRIBER_ATTRIBUTES, json_data, subscriber_attributes_id, False, operation_id) print("Updated object") print(data) @@ -1016,7 +1053,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(SUBSCRIBER_ATTRIBUTES, json_data, False, operation_id) + data = databaseClient.CreateObj(SUBSCRIBER_ATTRIBUTES, json_data, False, operation_id) return data, 200 except Exception as E: @@ -1030,7 +1067,7 @@ def get(self): '''Get all Operation Logs''' try: args = paginatorParser.parse_args() - OperationLogs = database.get_all_operation_logs(args['page'], args['page_size']) + OperationLogs = databaseClient.get_all_operation_logs(args['page'], args['page_size']) return OperationLogs, 200 except Exception as E: print(E) @@ -1041,7 +1078,7 @@ class PyHSS_Operation_Log_Last(Resource): def get(self): '''Get the most recent Operation Log''' try: - OperationLogs = database.get_last_operation_log() + OperationLogs = databaseClient.get_last_operation_log() return OperationLogs, 200 except Exception as E: print(E) @@ -1054,7 +1091,7 @@ def get(self, table_name): '''Get all Operation Logs for a given table''' try: args = paginatorParser.parse_args() - OperationLogs = database.get_all_operation_logs_by_table(table_name, args['page'], args['page_size']) + OperationLogs = databaseClient.get_all_operation_logs_by_table(table_name, args['page'], args['page_size']) return OperationLogs, 200 except Exception as E: print(E) @@ -1063,11 +1100,153 @@ def get(self, table_name): @ns_oam.route('/diameter_peers') class PyHSS_OAM_Peers(Resource): def get(self): - '''Get all Diameter Peers''' + '''Get active Diameter Peers''' try: - logObj = logtool.LogTool() - DiameterPeers = logObj.GetDiameterPeers() - return DiameterPeers, 200 + diameterPeers = json.loads(redisMessaging.getValue("ActiveDiameterPeers")) + return diameterPeers, 200 + except Exception as E: + logTool.log(service='API', level='error', message=f"[API] An error occurred: {traceback.format_exc()}", redisClient=redisMessaging) + print(E) + return handle_exception(E) + +@ns_oam.route('/deregister/<string:imsi>') +class PyHSS_OAM_Deregister(Resource): + def get(self, imsi): + '''Deregisters a given IMSI from the entire network.''' + try: + subscriberInfo = databaseClient.Get_Subscriber(imsi=str(imsi)) + imsSubscriberInfo = databaseClient.Get_IMS_Subscriber(imsi=str(imsi)) + subscriberId =
subscriberInfo.get('subscriber_id', None) + servingMmePeer = subscriberInfo.get('serving_mme_peer', None) + servingMme = subscriberInfo.get('serving_mme', None) + servingMmeRealm = subscriberInfo.get('serving_mme_realm', None) + servingScscf = imsSubscriberInfo.get('scscf', None) + servingScscfPeer = imsSubscriberInfo.get('scscf_peer', None) + servingScscfRealm = imsSubscriberInfo.get('scscf_realm', None) + + if servingMmePeer is not None and servingMmeRealm is not None and servingMme is not None: + if ';' in servingMmePeer: + servingMmePeer = servingMmePeer.split(';')[0] + + # Send the CLR to the serving MME + diameterClient.sendDiameterRequest( + requestType='CLR', + hostname=servingMmePeer, + imsi=imsi, + DestinationHost=servingMme, + DestinationRealm=servingMmeRealm, + CancellationType=2 + ) + + #Broadcast the CLR to all connected MMEs, regardless of whether the subscriber is attached. + diameterClient.broadcastDiameterRequest( + requestType='CLR', + peerType='MME', + imsi=imsi, + DestinationHost=servingMme, + DestinationRealm=servingMmeRealm, + CancellationType=2 + ) + + databaseClient.Update_Serving_MME(imsi=imsi, serving_mme=None) + + if servingScscfPeer is not None and servingScscfRealm is not None and servingScscf is not None: + if ';' in servingScscfPeer: + servingScscfPeer = servingScscfPeer.split(';')[0] + servingScscf = servingScscf.replace('sip:', '') + if ';' in servingScscf: + servingScscf = servingScscf.split(';')[0] + diameterClient.sendDiameterRequest( + requestType='RTR', + hostname=servingScscfPeer, + imsi=imsi, + destinationHost=servingScscf, + destinationRealm=servingScscfRealm, + domain=servingScscfRealm + ) + + #Broadcast the RTR to all connected SCSCFs, regardless of whether the subscriber is attached. + diameterClient.broadcastDiameterRequest( + requestType='RTR', + peerType='SCSCF', + imsi=imsi, + destinationHost=servingScscf, + destinationRealm=servingScscfRealm, + domain=servingScscfRealm + ) + + databaseClient.Update_Serving_CSCF(imsi=imsi, serving_cscf=None) + + # If a subscriber has an active serving apn, grab the pcrf session id for that apn and send a CCR-T, then a Registration Termination Request to the serving pgw peer.
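+ # Teardown order, as implemented below: a CCR-T (ccr_type=3) and an RTR to the serving PGW peer, the same pair broadcast to all connected PGW peers, then the serving APN record is cleared via Update_Serving_APN.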
+ if subscriberId is not None: + servingApns = databaseClient.Get_Serving_APNs(subscriber_id=subscriberId) + if len(servingApns.get('apns', {})) > 0: + for apnKey, apnDict in servingApns['apns'].items(): + pcrfSessionId = None + servingPgwPeer = None + servingPgwRealm = None + servingPgw = None + for apnDataKey, apnDataValue in servingApns['apns'][apnKey].items(): + if apnDataKey == 'pcrf_session_id': + pcrfSessionId = apnDataValue + if apnDataKey == 'serving_pgw_peer': + servingPgwPeer = apnDataValue + if apnDataKey == 'serving_pgw_realm': + servingPgwRealm = apnDataValue + if apnDataKey == 'serving_pgw': + servingPgw = apnDataValue + + if pcrfSessionId is not None and servingPgwPeer is not None and servingPgwRealm is not None and servingPgw is not None: + if ';' in servingPgwPeer: + servingPgwPeer = servingPgwPeer.split(';')[0] + + diameterClient.sendDiameterRequest( + requestType='CCR', + hostname=servingPgwPeer, + imsi=imsi, + destinationHost=servingPgw, + destinationRealm=servingPgwRealm, + ccr_type=3, + sessionId=pcrfSessionId, + domain=servingPgwRealm + ) + + diameterClient.sendDiameterRequest( + requestType='RTR', + hostname=servingPgwPeer, + imsi=imsi, + destinationHost=servingPgw, + destinationRealm=servingPgwRealm, + domain=servingPgwRealm + ) + + diameterClient.broadcastDiameterRequest( + requestType='CCR', + peerType='PGW', + imsi=imsi, + destinationHost=servingPgw, + destinationRealm=servingPgwRealm, + ccr_type=3, + sessionId=pcrfSessionId, + domain=servingPgwRealm + ) + + diameterClient.broadcastDiameterRequest( + requestType='RTR', + peerType='PGW', + imsi=imsi, + destinationHost=servingPgw, + destinationRealm=servingPgwRealm, + domain=servingPgwRealm + ) + + databaseClient.Update_Serving_APN(imsi=imsi, apn=apnKey, pcrf_session_id=None, serving_pgw=None, subscriber_routing='') + + subscriberInfo = databaseClient.Get_Subscriber(imsi=str(imsi)) + imsSubscriberInfo = databaseClient.Get_IMS_Subscriber(imsi=str(imsi)) + servingApns = databaseClient.Get_Serving_APNs(subscriber_id=subscriberId) + + return {'subscriber': subscriberInfo, 'ims_subscriber': imsSubscriberInfo, 'pcrf': servingApns}, 200 except Exception as E: print(E) return handle_exception(E) @@ -1089,7 +1268,7 @@ class PyHSS_OAM_Rollback_Last(Resource): def get(self): '''Undo the last Insert/Update/Delete operation''' try: - RollbackResponse = database.rollback_last_change() + RollbackResponse = databaseClient.rollback_last_change() return RollbackResponse, 200 except Exception as E: print(E) @@ -1101,7 +1280,7 @@ class PyHSS_OAM_Rollback_Last_Table(Resource): def get(self, operation_id): '''Undo the last Insert/Update/Delete operation for a given operation id''' try: - RollbackResponse = database.rollback_change_by_operation_id(operation_id) + RollbackResponse = databaseClient.rollback_change_by_operation_id(operation_id) return RollbackResponse, 200 except Exception as E: print(E) @@ -1112,7 +1291,7 @@ class PyHSS_OAM_Serving_Subs(Resource): def get(self): '''Get all Subscribers served by HSS''' try: - data = database.Get_Served_Subscribers() + data = databaseClient.Get_Served_Subscribers() print("Got back served Subs: " + str(data)) return data, 200 except Exception as E: @@ -1124,7 +1303,7 @@ class PyHSS_OAM_Serving_Subs_PCRF(Resource): def get(self): '''Get all Subscribers served by PCRF''' try: - data = database.Get_Served_PCRF_Subscribers() + data = databaseClient.Get_Served_PCRF_Subscribers() print("Got back served Subs: " + str(data)) return data, 200 except Exception as E: @@ -1136,7 +1315,7 @@ class
PyHSS_OAM_Serving_Subs_IMS(Resource): def get(self): '''Get all Subscribers served by IMS''' try: - data = database.Get_Served_IMS_Subscribers() + data = databaseClient.Get_Served_IMS_Subscribers() print("Got back served Subs: " + str(data)) return data, 200 except Exception as E: @@ -1148,16 +1327,15 @@ class PyHSS_OAM_Reconcile_IMS(Resource): def get(self, imsi): '''Get current location of IMS Subscriber from all linked HSS nodes''' response_dict = {} - import requests try: #Get local database result - local_result = database.Get_IMS_Subscriber(imsi=imsi) + local_result = databaseClient.Get_IMS_Subscriber(imsi=imsi) response_dict['localhost'] = {} for keys in local_result: if 'cscf' in keys: response_dict['localhost'][keys] = local_result[keys] - for remote_HSS in yaml_config['geored']['sync_endpoints']: + for remote_HSS in config['geored']['sync_endpoints']: print("Pulling data from remote HSS: " + str(remote_HSS)) try: response = requests.get(remote_HSS + '/ims_subscriber/ims_subscriber_imsi/' + str(imsi)) @@ -1200,9 +1378,9 @@ def get(self, imsi): serving_sub_final['apns'] = {} #Resolve Subscriber ID - subscriber_data = database.Get_Subscriber(imsi=str(imsi)) + subscriber_data = databaseClient.Get_Subscriber(imsi=str(imsi)) print("subscriber_data: " + str(subscriber_data)) - serving_sub_final['subscriber_data'] = database.Sanitize_Datetime(subscriber_data) + serving_sub_final['subscriber_data'] = databaseClient.Sanitize_Datetime(subscriber_data) #Split the APN list into a list apn_list = subscriber_data['apn_list'].split(',') @@ -1220,11 +1398,11 @@ def get(self, imsi): #Get APN ID from APN for list_apn_id in apn_list: print("Getting APN ID " + str(list_apn_id)) - apn_data = database.Get_APN(list_apn_id) + apn_data = databaseClient.Get_APN(list_apn_id) print(apn_data) try: serving_sub_final['apns'][str(apn_data['apn'])] = {} - serving_sub_final['apns'][str(apn_data['apn'])] = database.Sanitize_Datetime(database.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=list_apn_id)) + serving_sub_final['apns'][str(apn_data['apn'])] = databaseClient.Sanitize_Datetime(databaseClient.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=list_apn_id)) except: serving_sub_final['apns'][str(apn_data['apn'])] = {} print("Failed to get Serving APN for APN ID " + str(list_apn_id)) @@ -1246,7 +1424,7 @@ def get(self, imsi, apn_id): apn_id_final = None #Resolve Subscriber ID - subscriber_data = database.Get_Subscriber(imsi=str(imsi)) + subscriber_data = databaseClient.Get_Subscriber(imsi=str(imsi)) print("subscriber_data: " + str(subscriber_data)) #Split the APN list into a list @@ -1265,14 +1443,14 @@ def get(self, imsi, apn_id): for list_apn_id in apn_list: print("Getting APN ID " + str(list_apn_id) + " to see if it matches APN " + str(apn_id)) #Get each APN in List - apn_data = database.Get_APN(list_apn_id) + apn_data = databaseClient.Get_APN(list_apn_id) print(apn_data) if str(apn_data['apn_id']).lower() == str(apn_id).lower(): print("Matched named APN with APN ID") apn_id_final = apn_data['apn_id'] - data = database.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=apn_id_final) - data = database.Sanitize_Datetime(data) + data = databaseClient.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=apn_id_final) + data = databaseClient.Sanitize_Datetime(data) print("Got back: " + str(data)) return data, 200 except Exception as E: @@ -1290,37 +1468,45 @@ def put(self): json_data = request.get_json(force=True) print("JSON Data 
sent: " + str(json_data)) #Get IMSI - subscriber_data = database.Get_Subscriber(imsi=str(json_data['imsi'])) + subscriber_data = databaseClient.Get_Subscriber(imsi=str(json_data['imsi'])) print("subscriber_data: " + str(subscriber_data)) #Get PCRF Session - pcrf_session_data = database.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=json_data['apn_id']) - print("pcrf_session_data: " + str(pcrf_session_data)) + servingApn = databaseClient.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=json_data['apn_id']) + print("pcrf_session_data: " + str(servingApn)) #Get Charging Rules - ChargingRule = database.Get_Charging_Rule(json_data['charging_rule_id']) - ChargingRule['apn_data'] = database.Get_APN(json_data['apn_id']) + ChargingRule = databaseClient.Get_Charging_Rule(json_data['charging_rule_id']) + ChargingRule['apn_data'] = databaseClient.Get_APN(json_data['apn_id']) print("Got ChargingRule: " + str(ChargingRule)) - diameter_host = yaml_config['hss']['OriginHost'] #Diameter Host of this Machine - OriginRealm = yaml_config['hss']['OriginRealm'] - DestinationRealm = OriginRealm - mcc = yaml_config['hss']['MCC'] #Mobile Country Code - mnc = yaml_config['hss']['MNC'] #Mobile Network Code - import diameter - diameter = diameter.Diameter(diameter_host, DestinationRealm, 'PyHSS-client-API', str(mcc), str(mnc)) - diam_hex = diameter.Request_16777238_258(pcrf_session_data['pcrf_session_id'], ChargingRule, pcrf_session_data['subscriber_routing'], pcrf_session_data['serving_pgw'], 'ServingRealm.com') - import time - logObj = logtool.LogTool() - logObj.Async_SendRequest(diam_hex, str(pcrf_session_data['serving_pgw'])) - return diam_hex, 200 + subscriberId = subscriber_data.get('subscriber_id', None) + apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None) + servingPgwPeer = servingApn.get('serving_pgw_peer', None).split(';')[0] + servingPgw = servingApn.get('serving_pgw', None) + servingPgwRealm = servingApn.get('serving_pgw_realm', None) + pcrfSessionId = servingApn.get('pcrf_session_id', None) + ueIp = servingApn.get('subscriber_routing', None) + + diameterResponse = diameterClient.sendDiameterRequest( + requestType='RAR', + hostname=servingPgwPeer, + sessionId=pcrfSessionId, + chargingRules=ChargingRule, + ueIp=ueIp, + servingPgw=servingPgw, + servingRealm=servingPgwRealm + ) + + result = {"Result": "Successfully sent Gx RAR", "destinationClients": str(servingPgw)} + return result, 200 @ns_pcrf.route('/') class PyHSS_PCRF_Complete(Resource): def get(self, charging_rule_id): '''Get full Charging Rule + TFTs''' try: - data = database.Get_Charging_Rule(charging_rule_id) + data = databaseClient.Get_Charging_Rule(charging_rule_id) return data, 200 except Exception as E: print(E) @@ -1331,38 +1517,43 @@ class PyHSS_PCRF_SUBSCRIBER_ROUTING(Resource): def get(self, subscriber_routing): '''Get Subscriber info from Subscriber Routing''' try: - data = database.Get_UE_by_IP(subscriber_routing) + data = databaseClient.Get_UE_by_IP(subscriber_routing) return data, 200 except Exception as E: print(E) return handle_exception(E) @ns_geored.route('/') - class PyHSS_Geored(Resource): @ns_geored.doc('Create ChargingRule Object') @ns_geored.expect(GeoRed_model) - # @metrics.counter('flask_http_geored_pushes', 'Count of GeoRed Pushes to this API', - # labels={'status': lambda r: r.status_code, 'source_endpoint': lambda r: r.remote_addr}) @no_auth_required def patch(self): '''Get Geored data Pushed''' try: json_data = request.get_json(force=True) print("JSON Data 
sent in Geored request: " + str(json_data)) - #Determine what actions to take / update based on keys returned response_data = [] if 'serving_mme' in json_data: print("Updating serving MME") - response_data.append(database.Update_Serving_MME(imsi=str(json_data['imsi']), serving_mme=json_data['serving_mme'], serving_mme_realm=json_data['serving_mme_realm'], serving_mme_peer=json_data['serving_mme_peer'], propagate=False)) - prom_flask_http_geored_endpoints.labels(endpoint='HSS', geored_host=request.remote_addr).inc() + response_data.append(databaseClient.Update_Serving_MME(imsi=str(json_data['imsi']), serving_mme=json_data['serving_mme'], serving_mme_realm=json_data['serving_mme_realm'], serving_mme_peer=json_data['serving_mme_peer'], propagate=False)) + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "HSS", + "geored_host": request.remote_addr, + }, + metricExpiry=60) if 'serving_apn' in json_data: print("Updating serving APN") if 'serving_pgw_realm' not in json_data: json_data['serving_pgw_realm'] = None if 'serving_pgw_peer' not in json_data: json_data['serving_pgw_peer'] = None - response_data.append(database.Update_Serving_APN( + if 'serving_pgw_timestamp' not in json_data: + json_data['serving_pgw_timestamp'] = None + response_data.append(databaseClient.Update_Serving_APN( imsi=str(json_data['imsi']), apn=json_data['serving_apn'], pcrf_session_id=json_data['pcrf_session_id'], @@ -1370,20 +1561,63 @@ def patch(self): subscriber_routing=json_data['subscriber_routing'], serving_pgw_realm=json_data['serving_pgw_realm'], serving_pgw_peer=json_data['serving_pgw_peer'], + serving_pgw_timestamp=json_data['serving_pgw_timestamp'], propagate=False)) - prom_flask_http_geored_endpoints.labels(endpoint='PCRF', geored_host=request.remote_addr).inc() + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "PCRF", + "geored_host": request.remote_addr, + }, + metricExpiry=60) if 'scscf' in json_data: - print("Updating serving SCSCF") + print("Updating Serving SCSCF") if 'scscf_realm' not in json_data: json_data['scscf_realm'] = None if 'scscf_peer' not in json_data: json_data['scscf_peer'] = None - response_data.append(database.Update_Serving_CSCF(imsi=str(json_data['imsi']), serving_cscf=json_data['scscf'], scscf_realm=str(json_data['scscf_realm']), scscf_peer=str(json_data['scscf_peer']), propagate=False)) - prom_flask_http_geored_endpoints.labels(endpoint='IMS', geored_host=request.remote_addr).inc() + if 'scscf_timestamp' not in json_data: + json_data['scscf_timestamp'] = None + response_data.append(databaseClient.Update_Serving_CSCF(imsi=str(json_data['imsi']), serving_cscf=json_data['scscf'], scscf_realm=json_data['scscf_realm'], scscf_peer=json_data['scscf_peer'], scscf_timestamp=json_data['scscf_timestamp'], propagate=False)) + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "IMS_SCSCF", + "geored_host": request.remote_addr, + }, + metricExpiry=60) + if 'pcscf' in json_data: + print("Updating Proxy SCSCF") + if 'pcscf_realm' not in json_data: + 
json_data['pcscf_realm'] = None + if 'pcscf_peer' not in json_data: + json_data['pcscf_peer'] = None + if 'pcscf_timestamp' not in json_data: + json_data['pcscf_timestamp'] = None + if 'pcscf_active_session' not in json_data: + json_data['pcscf_active_session'] = None + response_data.append(databaseClient.Update_Proxy_CSCF(imsi=str(json_data['imsi']), proxy_cscf=json_data['pcscf'], pcscf_realm=json_data['pcscf_realm'], pcscf_peer=json_data['pcscf_peer'], pcscf_timestamp=json_data['pcscf_timestamp'], pcscf_active_session=json_data['pcscf_active_session'], propagate=False)) + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "IMS_PCSCF", + "geored_host": request.remote_addr, + }, + metricExpiry=60) if 'imei' in json_data: print("Updating EIR") - response_data.append(database.Store_IMSI_IMEI_Binding(str(json_data['imsi']), str(json_data['imei']), str(json_data['match_response_code']), propagate=False)) - prom_flask_http_geored_endpoints.labels(endpoint='EIR', geored_host=request.remote_addr).inc() + response_data.append(databaseClient.Store_IMSI_IMEI_Binding(str(json_data['imsi']), str(json_data['imei']), str(json_data['match_response_code']), propagate=False)) + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "IMEI", + "geored_host": request.remote_addr, + }, + metricExpiry=60) return response_data, 200 except Exception as E: print("Exception when updating: " + str(E)) @@ -1402,34 +1636,68 @@ def get(self): response_json = {'result': 'Failed', 'Reason' : "Unable to return Geored Schema: " + str(E)} return response_json +@ns_geored.route('/peers') +class PyHSS_Geored_Peers(Resource): + def get(self): + '''Return the configured geored peers''' + try: + georedEnabled = config.get('geored', {}).get('enabled', False) + if not georedEnabled: + return {'result': 'Failed', 'Reason' : "Geored not enabled"} + georedPeers = config.get('geored', {}).get('endpoints', []) + return {'peers': georedPeers}, 200 + except Exception as E: + print("Exception when returning geored peers: " + str(E)) + response_json = {'result': 'Failed', 'Reason' : "Unable to return Geored peers: " + str(E)} + return response_json + +@ns_geored.route('/webhooks') +class PyHSS_Geored_Webhooks(Resource): + def get(self): + '''Return the configured geored webhooks''' + try: + georedEnabled = config.get('webhooks', {}).get('enabled', False) + if not georedEnabled: + return {'result': 'Failed', 'Reason' : "Webhooks not enabled"} + georedWebhooks = config.get('webhooks', {}).get('endpoints', []) + return {'endpoints': georedWebhooks}, 200 + except Exception as E: + print("Exception when returning geored webhooks: " + str(E)) + response_json = {'result': 'Failed', 'Reason' : "Unable to return Geored webhooks: " + str(E)} + return response_json + @ns_push.route('/clr/') class PyHSS_Push_CLR(Resource): @ns_push.expect(Push_CLR_Model) @ns_push.doc('Push CLR (Cancel Location Request) to MME') def put(self, imsi): - '''Push CLR (Cancel Location Request) to MME''' - json_data = request.get_json(force=True) - print("JSON Data sent: " + str(json_data)) - if 'DestinationHost' not in json_data: - json_data['DestinationHost'] = None - import diameter - diameter = diameter.Diameter( - 
OriginHost=yaml_config['hss']['OriginHost'], - OriginRealm=yaml_config['hss']['OriginRealm'], - MNC=yaml_config['hss']['MNC'], - MCC=yaml_config['hss']['MCC'], - ProductName='PyHSS-client-API' - ) - diam_hex = diameter.Request_16777251_317( - imsi=imsi, - DestinationHost=json_data['DestinationHost'], - DestinationRealm=json_data['DestinationRealm'], - CancellationType=json_data['cancellationType'] - ) - logObj = logtool.LogTool() - logObj.Async_SendRequest(diam_hex, str(json_data['diameterPeer'])) - return diam_hex, 200 + try: + '''Push CLR (Cancel Location Request) to MME''' + json_data = request.get_json(force=True) + print("JSON Data sent: " + str(json_data)) + if 'DestinationHost' not in json_data: + json_data['DestinationHost'] = None + diameterRequest = diameterClient.sendDiameterRequest( + requestType='CLR', + hostname=json_data['diameterPeer'], + imsi=imsi, + DestinationHost=json_data['DestinationHost'], + DestinationRealm=json_data['DestinationRealm'], + CancellationType=json_data['cancellationType'] + ) + if not len(diameterRequest) > 0: + return {'result': f'Failed queueing CLR to {json_data["diameterPeer"]}'}, 400 + + subscriber_details = databaseClient.Get_Subscriber(imsi=str(imsi)) + if subscriber_details['serving_mme'] == json_data['DestinationHost']: + databaseClient.Update_Serving_MME(imsi=imsi, serving_mme=None) + + return {'result': f'Successfully queued CLR to {json_data["diameterPeer"]}'}, 200 + except Exception as E: + print("Exception when sending CLR: " + str(E)) + response_json = {'result': 'Failed', 'Reason' : "Unable to send CLR: " + str(E)} + return response_json if __name__ == '__main__': - app.run(debug=False) + apiService.run(debug=False, host='0.0.0.0', port=8080) diff --git a/services/diameterService.py b/services/diameterService.py new file mode 100644 index 0000000..958ec72 --- /dev/null +++ b/services/diameterService.py @@ -0,0 +1,303 @@ +import asyncio +import sys, os, json +import time, yaml, uuid +from datetime import datetime +sys.path.append(os.path.realpath('../lib')) +from messagingAsync import RedisMessagingAsync +from diameterAsync import DiameterAsync +from banners import Banners +from logtool import LogTool +import traceback + +class DiameterService: + """ + PyHSS Diameter Service + A class for handling diameter inbounds and replies on Port 3868, via TCP. + Functions in this class are high-performance, please edit with care. Last profiled October 6th, 2023. 
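+
+    A minimal usage sketch, mirroring the __main__ entrypoint at the bottom of this file
+    (assumes config.yaml one directory up and a reachable Redis instance):
+
+        service = DiameterService()
+        asyncio.run(service.startServer())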
+ """ + + def __init__(self): + try: + with open("../config.yaml", "r") as self.configFile: + self.config = yaml.safe_load(self.configFile) + except: + print(f"[Diameter] [__init__] Fatal Error - config.yaml not found, exiting.") + quit() + + self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) + self.redisReaderMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.redisWriterMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.redisPeerMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.banners = Banners() + self.logTool = LogTool(config=self.config) + self.diameterLibrary = DiameterAsync(logTool=self.logTool) + self.activePeers = {} + self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10)) + self.benchmarking = self.config.get('benchmarking', {}).get('enabled', False) + self.benchmarkingInterval = self.config.get('benchmarking', {}).get('reporting_interval', 3600) + self.diameterRequests = 0 + self.diameterResponses = 0 + self.workerPoolSize = int(self.config.get('hss', {}).get('diameter_service_workers', 10)) + + async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inboundData) -> bool: + """ + Asynchronously validates a given diameter inbound, and increments the 'Number of Diameter Inbounds' metric. + """ + try: + packetVars, avps = await(self.diameterLibrary.decodeDiameterPacket(inboundData)) + originHost = (await(self.diameterLibrary.getAvpData(avps, 264)))[0] + originHost = bytes.fromhex(originHost).decode("utf-8") + peerType = await(self.diameterLibrary.getPeerType(originHost)) + self.activePeers[f"{clientAddress}-{clientPort}"].update({'diameterHostname': originHost, + 'peerType': (peerType if peerType != None else 'Unknown'), + }) + return True + except Exception as e: + await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] Exception: {e}\n{traceback.format_exc()}")) + await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] AVPs: {avps}\nPacketVars: {packetVars}")) + return False + + async def handleActiveDiameterPeers(self): + """ + Prunes stale entries from self.activePeers, and + keeps the ActiveDiameterPeers key in Redis current. 
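+
+        Entries in self.activePeers are keyed on "{clientAddress}-{clientPort}" and mirror
+        the structure built in handleConnection; an example entry (illustrative values):
+
+            "10.0.0.1-49152": {
+                "connectTimestamp": "2023-10-06 10:00:00",
+                "disconnectTimestamp": "",
+                "reconnectionCount": 0,
+                "ipAddress": "10.0.0.1",
+                "port": 49152,
+                "connectionStatus": "connected",
+                "diameterHostname": "mme01.epc.mnc001.mcc001.3gppnetwork.org",
+                "peerType": "MME",
+            }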
+ """ + while True: + try: + if not len(self.activePeers) > 0: + await(asyncio.sleep(1)) + continue + + activeDiameterPeersTimeout = self.config.get('hss', {}).get('active_diameter_peers_timeout', 3600) + + stalePeers = [] + + for key, connection in self.activePeers.items(): + if connection.get('connectionStatus', '') == 'disconnected': + if (datetime.now() - datetime.strptime(connection['disconnectTimestamp'], "%Y-%m-%d %H:%M:%S")).seconds > activeDiameterPeersTimeout: + stalePeers.append(key) + + if len(stalePeers) > 0: + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [handleActiveDiameterPeers] Pruning disconnected peers: {stalePeers}")) + for key in stalePeers: + del self.activePeers[key] + await(self.logActivePeers()) + + await(self.redisPeerMessaging.setValue(key='ActiveDiameterPeers', value=json.dumps(self.activePeers), keyExpiry=86400)) + + await(asyncio.sleep(1)) + except Exception as e: + print(e) + await(asyncio.sleep(1)) + continue + + async def logActivePeers(self): + """ + Logs the number of active connections on a rolling basis. + """ + activePeers = self.activePeers + if not len(activePeers) > 0: + activePeers = '' + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logActivePeers] {len(self.activePeers)} Active Peers {activePeers}")) + + async def logProcessedMessages(self): + """ + Logs the number of processed messages on a rolling basis. + """ + if not self.benchmarking: + return False + + benchmarkInterval = int(self.benchmarkingInterval) + + while True: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logProcessedMessages] Processed {self.diameterRequests} inbound diameter messages in the last {self.benchmarkingInterval} second(s)")) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logProcessedMessages] Processed {self.diameterResponses} outbound in the last {self.benchmarkingInterval} second(s)")) + self.diameterRequests = 0 + self.diameterResponses = 0 + await(asyncio.sleep(benchmarkInterval)) + + async def readInboundData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: + """ + Reads incoming data from a connected client. Data is sent to a shared memory-based queue, to be polled and processed by a worker coroutine. + Terminates the connection if the client disconnects, the queue fills or another exception occurs. + """ + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] New connection from {clientAddress} on port {clientPort}")) + clientConnection = f"{clientAddress}-{clientPort}" + while True: + try: + + inboundData = await(asyncio.wait_for(reader.read(8192), timeout=socketTimeout)) + + if reader.at_eof(): + return False + + if len(inboundData) > 0: + self.sharedQueue.put_nowait({"diameter-inbound": inboundData, "inbound-received-timestamp": time.time(), "clientAddress": clientAddress, "clientPort": clientPort}) + + except Exception as e: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Exception for {clientAddress} on port {clientPort}, closing connection.\n{e}")) + return False + + async def inboundDataWorker(self, coroutineUuid: str) -> bool: + """ + Collects messages from the memory queue, performs peer validation and fires off to redis every 0.1 seconds. 
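+
+        Each message is JSON-encoded with the diameter payload hex-encoded, matching what
+        readInboundData enqueues; an example message (illustrative values):
+
+            {"diameter-inbound": "010000c08000011c...", "inbound-received-timestamp": 1696550400.0,
+             "clientAddress": "10.0.0.1", "clientPort": 49152}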
+ """ + batchInterval = 0.1 + inboundQueueName = f"diameter-inbound" + while True: + try: + nextSendTime = time.time() + batchInterval + messageList = [] + while time.time() < nextSendTime: + try: + inboundData = await(asyncio.wait_for(self.sharedQueue.get(), timeout=nextSendTime - time.time())) + inboundHex = inboundData.get('diameter-inbound', '').hex() + inboundData['diameter-inbound'] = inboundHex + clientAddress = inboundData.get('clientAddress', '') + clientPort = inboundData.get('clientPort', '') + + if len(self.activePeers.get(f'{clientAddress}-{clientPort}', {}).get('peerType', '')) == 0: + if not await(self.validateDiameterInbound(clientAddress, clientPort, inboundHex)): + await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [inboundDataWorker] [{coroutineUuid}] Invalid Diameter Inbound, discarding data.")) + continue + else: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [inboundDataWorker] [{coroutineUuid}] Validated peer: {clientAddress} on port {clientPort}")) + + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [inboundDataWorker] [{coroutineUuid}] Queueing to redis: {inboundData}")) + messageList.append(json.dumps(inboundData)) + if self.benchmarking: + self.diameterRequests += 1 + except asyncio.TimeoutError: + break + + if messageList: + await self.redisReaderMessaging.sendBulkMessage(queue=inboundQueueName, messageList=messageList, queueExpiry=self.diameterRequestTimeout) + messageList = [] + + except Exception as e: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [inboundDataWorker] [{coroutineUuid}] Exception for inboundDataWorker, continuing.\n{e}")) + pass + + async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: + """ + Waits for a message to be received from Redis, then sends to the connected client. + """ + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] writeOutboundData with host {clientAddress} on port {clientPort}")) + while not writer.transport.is_closing(): + try: + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Waiting for messages for host {clientAddress} on port {clientPort}")) + pendingOutboundMessage = json.loads((await(self.redisWriterMessaging.awaitMessage(key=f"diameter-outbound-{clientAddress}-{clientPort}")))[1]) + diameterOutboundBinary = bytes.fromhex(pendingOutboundMessage.get('diameter-outbound', '')) + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Sending: {diameterOutboundBinary.hex()} to to {clientAddress} on {clientPort}.")) + + writer.write(diameterOutboundBinary) + await(writer.drain()) + if self.benchmarking: + self.diameterResponses += 1 + except Exception as e: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.")) + return False + + async def handleConnection(self, reader, writer): + """ + For each new connection on port 3868, create an asynchronous reader and writer, and handle adding and updating self.activePeers. + If a reader or writer returns false, ensure that the connection is torn down entirely. 
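+
+        Outbound replies for the peer are consumed by writeOutboundData from the
+        per-connection Redis key "diameter-outbound-{clientAddress}-{clientPort}",
+        e.g. "diameter-outbound-10.0.0.1-49152" (illustrative address and port).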
+ """ + try: + coroutineUuid = str(uuid.uuid4()) + (clientAddress, clientPort) = writer.get_extra_info('peername') + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] New Connection from: {clientAddress} on port {clientPort}")) + if f"{clientAddress}-{clientPort}" not in self.activePeers: + self.activePeers[f"{clientAddress}-{clientPort}"] = { + "connectTimestamp": '', + "disconnectTimestamp": '', + "reconnectionCount": 0, + "ipAddress":'', + "port":'', + "connectionStatus": '', + "diameterHostname": '', + "peerType": '', + } + else: + reconnectionCount = self.activePeers.get(f"{clientAddress}-{clientPort}", {}).get('reconnectionCount', 0) + reconnectionCount += 1 + self.activePeers[f"{clientAddress}-{clientPort}"].update({ + "reconnectionCount": reconnectionCount + }) + + self.activePeers[f"{clientAddress}-{clientPort}"].update({ + "connectTimestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "ipAddress":clientAddress, + "port": clientPort, + "connectionStatus": 'connected', + }) + await(self.logActivePeers()) + + readTask = asyncio.create_task(self.readInboundData(reader=reader, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid)) + writeTask = asyncio.create_task(self.writeOutboundData(writer=writer, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid)) + + completeTasks, pendingTasks = await(asyncio.wait([readTask, writeTask], return_when=asyncio.FIRST_COMPLETED)) + + for pendingTask in pendingTasks: + try: + pendingTask.cancel() + await(asyncio.sleep(0.1)) + except asyncio.CancelledError: + pass + + writer.close() + await(writer.wait_closed()) + self.activePeers[f"{clientAddress}-{clientPort}"].update({ + "connectionStatus": 'disconnected', + "disconnectTimestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + }) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}.")) + await(self.logActivePeers()) + return + + except Exception as e: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Unhandled exception in diameterService.handleConnection: {e}")) + return + + async def startServer(self, host: str=None, port: int=None, type: str=None): + """ + Start a server with the given parameters and handle new clients with self.handleConnection. + Also create a single instance of self.handleActiveDiameterPeers and self.logProcessedMessages. 
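+
+        A hypothetical invocation overriding the config-derived defaults:
+
+            asyncio.run(diameterService.startServer(host='0.0.0.0', port=3868, type='TCP'))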
+ """ + + self.sharedQueue = asyncio.Queue(maxsize=1024) + + for i in range(self.workerPoolSize): + asyncio.create_task(self.inboundDataWorker(coroutineUuid=f'inboundDataWorker-{i}')) + + if host is None: + host=str(self.config.get('hss', {}).get('bind_ip', '0.0.0.0')[0]) + + if port is None: + port=int(self.config.get('hss', {}).get('bind_port', 3868)) + + if type is None: + type=str(self.config.get('hss', {}).get('transport', 'TCP')) + + self.socketTimeout = int(self.config.get('hss', {}).get('client_socket_timeout', 300)) + + if type.upper() == 'TCP': + server = await(asyncio.start_server(self.handleConnection, host, port)) + else: + return False + servingAddresses = ', '.join(str(sock.getsockname()) for sock in server.sockets) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"{self.banners.diameterService()}\n[Diameter] Serving on {servingAddresses}")) + handleActiveDiameterPeerTask = asyncio.create_task(self.handleActiveDiameterPeers()) + if self.benchmarking: + logProcessedMessagesTask = asyncio.create_task(self.logProcessedMessages()) + + async with server: + await(server.serve_forever()) + + +if __name__ == '__main__': + diameterService = DiameterService() + asyncio.run(diameterService.startServer()) \ No newline at end of file diff --git a/services/georedService.py b/services/georedService.py new file mode 100644 index 0000000..861e8b8 --- /dev/null +++ b/services/georedService.py @@ -0,0 +1,387 @@ +import os, sys, json, yaml +import uuid, time +import asyncio, aiohttp +sys.path.append(os.path.realpath('../lib')) +from messagingAsync import RedisMessagingAsync +from banners import Banners +from logtool import LogTool + +class GeoredService: + """ + PyHSS Geored Service + Handles updating and sending webhooks to remote endpoints. 
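+
+    A minimal usage sketch, mirroring the __main__ entrypoint at the bottom of this file:
+
+        georedService = GeoredService()
+        asyncio.run(georedService.startService())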
+    """
+
+    def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379):
+        try:
+            with open("../config.yaml", "r") as self.configFile:
+                self.config = yaml.safe_load(self.configFile)
+        except:
+            print(f"[Geored] Fatal Error - config.yaml not found, exiting.")
+            quit()
+        self.logTool = LogTool(self.config)
+        self.banners = Banners()
+
+        self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False)
+        self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock')
+        self.redisHost = self.config.get('redis', {}).get('host', 'localhost')
+        self.redisPort = self.config.get('redis', {}).get('port', 6379)
+        self.redisGeoredMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath)
+        self.redisWebhookMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath)
+
+        self.georedPeers = self.config.get('geored', {}).get('endpoints', [])
+        self.webhookPeers = self.config.get('webhooks', {}).get('endpoints', [])
+        self.benchmarking = self.config.get('hss').get('enable_benchmarking', False)
+
+        if not self.config.get('geored', {}).get('enabled'):
+            print("[Geored] Fatal Error - geored not enabled under geored.enabled, exiting.")
+            quit()
+        if self.georedPeers is not None:
+            if not (len(self.georedPeers) > 0):
+                print("[Geored] Fatal Error - no peers defined under geored.endpoints, exiting.")
+                quit()
+
+    async def sendGeored(self, asyncSession, url: str, operation: str, body: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool:
+        """
+        Sends a Geored HTTP request to a given endpoint.
+        """
+        if self.benchmarking:
+            startTime = time.perf_counter()
+        operation = operation.upper()
+        requestOperations = {'GET': asyncSession.get, 'PUT': asyncSession.put, 'POST': asyncSession.post, 'PATCH':asyncSession.patch, 'DELETE': asyncSession.delete}
+
+        if not url or not operation or not body:
+            return False
+
+        if operation not in requestOperations:
+            return False
+
+        headers = {"Content-Type": "application/json", "Transaction-Id": str(transactionId)}
+
+        for attempt in range(retryCount):
+            try:
+                responseStatusCode = None
+                responseBody = None
+
+                if operation in ['PUT', 'POST', 'PATCH']:
+                    async with requestOperations[operation](url, json=body, headers=headers) as response:
+                        responseBody = await(response.text())
+                        responseStatusCode = response.status
+                else:
+                    async with requestOperations[operation](url, headers=headers) as response:
+                        responseBody = await(response.text())
+                        responseStatusCode = response.status
+
+                if 200 <= responseStatusCode <= 299:
+                    await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [sendGeored] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. 
Response code: {responseStatusCode}")) + + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": str(responseStatusCode), + "error": ""}, + metricExpiry=60)) + break + else: + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": str(responseStatusCode), + "error": str(response.reason)}, + metricExpiry=60)) + except aiohttp.ClientConnectionError as e: + error_message = str(e) + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}")) + if "Name or service not known" in error_message: + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": "000", + "error": "No matching DNS entry found"}, + metricExpiry=60)) + else: + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": "000", + "error": "Connection Refused"}, + metricExpiry=60)) + except aiohttp.ServerTimeoutError: + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}")) + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": "000", + "error": "Timeout"}, + metricExpiry=60)) + except Exception as e: + await(self.logTool.logAsync(service='Geored', level='error', message=f"[Geored] [sendGeored] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. 
Error Message: {e}")) + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": "000", + "error": e}, + metricExpiry=60)) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [sendGeored] Time taken to send individual geored request to {url}: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) + + return True + + async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, headers: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: + """ + Sends a Webhook HTTP request to a given endpoint. + """ + if self.benchmarking: + startTime = time.perf_counter() + operation = operation.upper() + requestOperations = {'GET': asyncSession.get, 'PUT': asyncSession.put, 'POST': asyncSession.post, 'PATCH':asyncSession.patch, 'DELETE': asyncSession.delete} + + if not url or not operation or not body or not headers: + return False + + if operation not in requestOperations: + return False + + for attempt in range(retryCount): + try: + responseStatusCode = None + responseBody = None + + if operation in ['PUT', 'POST', 'PATCH']: + async with requestOperations[operation](url, json=body, headers=headers) as response: + responseBody = await(response.text()) + responseStatusCode = response.status + else: + async with requestOperations[operation](url, headers=headers) as response: + responseBody = await(response.text()) + responseStatusCode = response.status + + if 200 <= responseStatusCode <= 299: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [sendWebhook] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}")) + + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": str(responseStatusCode), + "error": ""}, + metricExpiry=60)) + break + else: + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": str(responseStatusCode), + "error": str(response.reason)}, + metricExpiry=60)) + except aiohttp.ClientConnectionError as e: + error_message = str(e) + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendWebhook] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. 
Error Message: {e}")) + if "Name or service not known" in error_message: + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": "000", + "error": "No matching DNS entry found"}, + metricExpiry=60)) + else: + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": "000", + "error": "Connection Refused"}, + metricExpiry=60)) + except aiohttp.ServerTimeoutError: + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendWebhook] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}")) + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": "000", + "error": "Timeout"}, + metricExpiry=60)) + except Exception as e: + await(self.logTool.logAsync(service='Geored', level='error', message=f"[Geored] [sendWebhook] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}")) + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": "000", + "error": e}, + metricExpiry=60)) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [sendWebhook] Time taken to send individual webhook request to {url}: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) + + return True + + async def handleAsymmetricGeoredQueue(self): + """ + Collects and processes asymmetric geored messages. 
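+
+        Messages are awaited on the "asymmetric-geored" Redis key and carry their own target
+        urls, unlike the symmetric geored queue; an example payload (illustrative values):
+
+            {"operation": "PATCH", "urls": ["http://hss02.example.com:8080/geored/"],
+             "body": {"imsi": "001021234567890", "serving_mme": "mme01.example.com"}}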
+ """ + async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session: + while True: + try: + if self.benchmarking: + startTime = time.perf_counter() + georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='asymmetric-geored')))[1]) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Message: {georedMessage}")) + + georedOperation = georedMessage['operation'] + georedBody = georedMessage['body'] + georedUrls = georedMessage['urls'] + georedTasks = [] + + for georedEndpoint in georedUrls: + georedTasks.append(self.sendGeored(asyncSession=session, url=georedEndpoint, operation=georedOperation, body=georedBody)) + await asyncio.gather(*georedTasks) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleAsymmetricGeoredQueue] Time taken to send asymmetric geored message to specified peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) + + await(asyncio.sleep(0)) + + except Exception as e: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Error handling asymmetric geored queue: {e}")) + await(asyncio.sleep(0)) + continue + + async def handleGeoredQueue(self): + """ + Collects and processes queued geored messages. + """ + async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session: + while True: + try: + if self.benchmarking: + startTime = time.perf_counter() + georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='geored')))[1]) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Message: {georedMessage}")) + + georedOperation = georedMessage['operation'] + georedBody = georedMessage['body'] + georedTasks = [] + + for remotePeer in self.georedPeers: + georedTasks.append(self.sendGeored(asyncSession=session, url=remotePeer+'/geored/', operation=georedOperation, body=georedBody)) + await asyncio.gather(*georedTasks) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleGeoredQueue] Time taken to send geored message to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) + + await(asyncio.sleep(0)) + + except Exception as e: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Error handling geored queue: {e}")) + await(asyncio.sleep(0)) + continue + + async def handleWebhookQueue(self): + """ + Collects and processes queued webhook messages. 
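+
+        Messages are awaited on the "webhook" Redis key and fanned out to every peer in
+        webhooks.endpoints; an example payload (illustrative values, body shape depends on
+        the originating service):
+
+            {"operation": "POST", "headers": {"Content-Type": "application/json"},
+             "body": {"imsi": "001021234567890"}}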
+ """ + async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session: + while True: + try: + if self.benchmarking: + startTime = time.perf_counter() + webhookMessage = json.loads((await(self.redisWebhookMessaging.awaitMessage(key='webhook')))[1]) + + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Message: {webhookMessage}")) + + webhookHeaders = webhookMessage['headers'] + webhookOperation = webhookMessage['operation'] + webhookBody = webhookMessage['body'] + webhookTasks = [] + + for remotePeer in self.webhookPeers: + webhookTasks.append(self.sendWebhook(asyncSession=session, url=remotePeer, operation=webhookOperation, body=webhookBody, headers=webhookHeaders)) + await asyncio.gather(*webhookTasks) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleWebhookQueue] Time taken to send webhook to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) + + await(asyncio.sleep(0.001)) + + except Exception as e: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Error handling webhook queue: {e}")) + await(asyncio.sleep(0.001)) + continue + + async def startService(self): + """ + Performs sanity checks on configuration and starts the geored and webhook tasks, when enabled. + """ + await(self.logTool.logAsync(service='Geored', level='info', message=f"{self.banners.georedService()}")) + while True: + + georedEnabled = self.config.get('geored', {}).get('enabled', False) + webhooksEnabled = self.config.get('webhooks', {}).get('enabled', False) + + if self.georedPeers is not None: + if not len(self.georedPeers) > 0: + georedEnabled = False + + if self.webhookPeers is not None: + if not len(self.webhookPeers) > 0: + webhooksEnabled = False + + if not georedEnabled and not webhooksEnabled: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [startService] Geored and Webhook services both disabled or missing peers, exiting.")) + sys.exit() + + activeTasks = [] + + if georedEnabled: + georedTask = asyncio.create_task(self.handleGeoredQueue()) + asymmetricGeoredTask = asyncio.create_task(self.handleAsymmetricGeoredQueue()) + activeTasks.append(georedTask) + activeTasks.append(asymmetricGeoredTask) + + if webhooksEnabled: + webhookTask = asyncio.create_task(self.handleWebhookQueue()) + activeTasks.append(webhookTask) + + completeTasks, pendingTasks = await(asyncio.wait(activeTasks, return_when=asyncio.FIRST_COMPLETED)) + + if len(pendingTasks) > 0: + for pendingTask in pendingTasks: + try: + pendingTask.cancel() + await(asyncio.sleep(0.001)) + except asyncio.CancelledError: + pass + + +if __name__ == '__main__': + georedService = GeoredService() + asyncio.run(georedService.startService()) diff --git a/services/hssService.py b/services/hssService.py new file mode 100644 index 0000000..46abcbc --- /dev/null +++ b/services/hssService.py @@ -0,0 +1,101 @@ +import os, sys, json, yaml, time, traceback +sys.path.append(os.path.realpath('../lib')) +from messaging import RedisMessaging +from diameter import Diameter +from banners import Banners +from logtool import LogTool + +class HssService: + + def __init__(self): + + try: + with open("../config.yaml", "r") as self.configFile: + self.config = yaml.safe_load(self.configFile) + except: + print(f"[HSS] Fatal Error - config.yaml not found, exiting.") + quit() + self.redisUseUnixSocket = self.config.get('redis', 
{}).get('useUnixSocket', False) + self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) + self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.logTool = LogTool(config=self.config) + self.banners = Banners() + self.mnc = self.config.get('hss', {}).get('MNC', '999') + self.mcc = self.config.get('hss', {}).get('MCC', '999') + self.originRealm = self.config.get('hss', {}).get('OriginRealm', f'mnc{self.mnc}.mcc{self.mcc}.3gppnetwork.org') + self.originHost = self.config.get('hss', {}).get('OriginHost', f'hss01') + self.productName = self.config.get('hss', {}).get('ProductName', f'PyHSS') + self.logTool.log(service='HSS', level='info', message=f"{self.banners.hssService()}", redisClient=self.redisMessaging) + self.diameterLibrary = Diameter(logTool=self.logTool, originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) + self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) + + def handleQueue(self): + """ + Gets and parses inbound diameter requests, processes them and queues the response. + """ + while True: + try: + if self.benchmarking: + startTime = time.perf_counter() + + inboundMessageList = self.redisMessaging.awaitBulkMessage(key='diameter-inbound') + + if inboundMessageList == None: + continue + for inboundMessage in inboundMessageList[1]: + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] Message: {inboundMessage}", redisClient=self.redisMessaging) + + inboundMessage = json.loads(inboundMessage.decode('ascii')) + inboundBinary = bytes.fromhex(inboundMessage.get('diameter-inbound', None)) + + if inboundBinary == None: + continue + inboundHost = inboundMessage.get('clientAddress', None) + inboundPort = inboundMessage.get('clientPort', None) + inboundTimestamp = inboundMessage.get('inbound-received-timestamp', None) + + try: + diameterOutbound = self.diameterLibrary.generateDiameterResponse(binaryData=inboundBinary) + + if diameterOutbound == None: + continue + if not len(diameterOutbound) > 0: + continue + + diameterMessageTypeDict = self.diameterLibrary.getDiameterMessageType(binaryData=inboundBinary) + + if diameterMessageTypeDict == None: + continue + if not len(diameterMessageTypeDict) > 0: + continue + + diameterMessageTypeInbound = diameterMessageTypeDict.get('inbound', '') + diameterMessageTypeOutbound = diameterMessageTypeDict.get('outbound', '') + except Exception as e: + self.logTool.log(service='HSS', level='warning', message=f"[HSS] [handleQueue] Failed to generate diameter outbound: {e}", redisClient=self.redisMessaging) + continue + + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound: {inboundMessage}", redisClient=self.redisMessaging) + + outboundQueue = f"diameter-outbound-{inboundHost}-{inboundPort}" + outboundMessage = json.dumps({"diameter-outbound": diameterOutbound, "inbound-received-timestamp": inboundTimestamp}) + + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Generated Diameter Outbound: {diameterOutbound}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] 
[handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}", redisClient=self.redisMessaging) + + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60) + if self.benchmarking: + self.logTool.log(service='HSS', level='info', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging) + + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[HSS] [handleQueue] Exception: {traceback.format_exc()}", redisClient=self.redisMessaging) + continue + + + +if __name__ == '__main__': + hssService = HssService() + hssService.handleQueue() \ No newline at end of file diff --git a/services/logService.py b/services/logService.py new file mode 100644 index 0000000..34e7ae0 --- /dev/null +++ b/services/logService.py @@ -0,0 +1,77 @@ +import os, sys, json, yaml +from datetime import datetime +import time +import logging +sys.path.append(os.path.realpath('../lib')) +from messaging import RedisMessaging +from banners import Banners +from logtool import LogTool + +class LogService: + """ + PyHSS Log Service + A class for handling queued log entries in the Redis DB. + This class is synchronous and not high-performance. + """ + + def __init__(self): + try: + with open("../config.yaml", "r") as self.configFile: + self.config = yaml.safe_load(self.configFile) + except: + print(f"[Log] Fatal Error - config.yaml not found, exiting.") + quit() + self.logTool = LogTool(config=self.config) + self.banners = Banners() + self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) + self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.logFilePaths = self.config.get('logging', {}).get('logfiles', {}) + self.logLevels = { + 'CRITICAL': {'verbosity': 1, 'logging': logging.CRITICAL}, + 'ERROR': {'verbosity': 2, 'logging': logging.ERROR}, + 'WARNING': {'verbosity': 3, 'logging': logging.WARNING}, + 'INFO': {'verbosity': 4, 'logging': logging.INFO}, + 'DEBUG': {'verbosity': 5, 'logging': logging.DEBUG}, + 'NOTSET': {'verbosity': 6, 'logging': logging.NOTSET}, + } + print(f"{self.banners.logService()}") + + def handleLogs(self): + """ + Continually polls the Redis DB for queued log files. Parses and writes log files to disk, using LogTool. 
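+
+        An example queued entry, as consumed from the "log" Redis key (illustrative values):
+
+            {"message": "[HSS] [handleQueue] Processing request", "service": "HSS",
+             "level": "debug", "timestamp": 1696550400.0}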
+ """ + activeLoggers = {} + while True: + try: + logMessage = json.loads(self.redisMessaging.awaitMessage(key='log')[1]) + + print(f"[Log] Message: {logMessage}") + + logFileMessage = logMessage['message'] + logService = logMessage.get('service').lower() + logLevel = logMessage.get('level').lower() + logTimestamp = logMessage.get('timestamp') + + if f"{logService}_logging_file" not in self.logFilePaths: + continue + + logFileName = f"{logService}_logging_file" + logFilePath = self.logFilePaths.get(logFileName, '/var/log/pyhss.log') + + if logService not in activeLoggers: + activeLoggers[logService] = self.logTool.setupFileLogger(loggerName=logService, logFilePath=logFilePath) + + fileLogger = activeLoggers[logService] + fileLogger.log(self.logLevels.get(logLevel.upper(), {}).get('logging', logging.INFO), logFileMessage, extra={'timestamp': float(logTimestamp)}) + + + except Exception as e: + self.logTool.log(service='Log', level='error', message=f"[Log] Error: {e}", redisClient=self.redisMessaging) + continue + +if __name__ == '__main__': + logService = LogService() + logService.handleLogs() \ No newline at end of file diff --git a/services/metricService.py b/services/metricService.py new file mode 100644 index 0000000..12d51c1 --- /dev/null +++ b/services/metricService.py @@ -0,0 +1,99 @@ +import asyncio +import sys, os, json +import time, json, yaml +from prometheus_client import make_wsgi_app, start_http_server, Counter, Gauge, Summary, Histogram, CollectorRegistry +from werkzeug.middleware.dispatcher import DispatcherMiddleware +from flask import Flask +import threading +sys.path.append(os.path.realpath('../lib')) +from messaging import RedisMessaging +from banners import Banners +from logtool import LogTool + +class MetricService: + + def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): + try: + with open("../config.yaml", "r") as self.configFile: + self.config = yaml.safe_load(self.configFile) + except: + print(f"[Metric] Fatal Error - config.yaml not found, exiting.") + quit() + + self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) + self.banners = Banners() + self.logTool = LogTool(config=self.config) + self.registry = CollectorRegistry(auto_describe=True) + self.logTool.log(service='Metric', level='info', message=f"{self.banners.metricService()}", redisClient=self.redisMessaging) + + def handleMetrics(self): + """ + Collects queued metrics from redis, and exposes them using prometheus_client. 
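+
+        An example queued metric, as consumed from the "metric" Redis key (illustrative values):
+
+            [{"NAME": "prom_http_geored", "TYPE": "counter", "ACTION": "inc", "VALUE": 1.0,
+              "HELP": "Number of Geored Pushes", "LABELS": {"endpoint": "geored"}}]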
+ """ + try: + actions = {'inc': 'inc', 'dec': 'dec', 'set':'set'} + prometheusTypes = {'counter': Counter, 'gauge': Gauge, 'histogram': Histogram, 'summary': Summary} + + metric = self.redisMessaging.awaitMessage(key='metric')[1] + + self.logTool.log(service='Metric', level='debug', message=f"[Metric] [handleMetrics] Received Metric: {metric}", redisClient=self.redisMessaging) + prometheusJsonList = json.loads(metric) + + for prometheusJson in prometheusJsonList: + self.logTool.log(service='Metric', level='debug', message=f"[Metric] [handleMetrics] {prometheusJson}", redisClient=self.redisMessaging) + if not all(key in prometheusJson for key in ('NAME', 'TYPE', 'ACTION', 'VALUE')): + raise ValueError('All fields are not available for parsing') + counterName = prometheusJson['NAME'] + counterType = prometheusTypes.get(prometheusJson['TYPE'].lower()) + counterAction = prometheusJson['ACTION'].lower() + counterValue = float(prometheusJson['VALUE']) + counterHelp = prometheusJson.get('HELP', '') + counterLabels = prometheusJson.get('LABELS', {}) + + if isinstance(counterLabels, list): + counterLabels = dict() + + if counterType is not None: + try: + counterRecord = counterType(counterName, counterHelp, labelnames=counterLabels.keys(), registry=self.registry) + if counterLabels: + counterRecord = counterRecord.labels(*counterLabels.values()) + except ValueError as e: + counterRecord = self.registry._names_to_collectors.get(counterName) + if counterLabels and counterRecord: + counterRecord = counterRecord.labels(*counterLabels.values()) + action = actions.get(counterAction) + if action is not None: + prometheusMethod = getattr(counterRecord, action) + prometheusMethod(counterValue) + else: + self.logTool.log(service='Metric', level='warn', message=f"[Metric] [handleMetrics] Invalid action '{counterAction}' in message: {metric}, skipping.", redisClient=self.redisMessaging) + continue + else: + self.logTool.log(service='Metric', level='warn', message=f"[Metric] [handleMetrics] Invalid type '{counterType}' in message: {metric}, skipping.", redisClient=self.redisMessaging) + continue + + except Exception as e: + self.logTool.log(service='Metric', level='error', message=f"[Metric] [handleMetrics] Unable to parse message: {metric}, due to {e}. Skipping.", redisClient=self.redisMessaging) + return + + + def getMetrics(self): + while True: + self.handleMetrics() + + +if __name__ == '__main__': + + metricService = MetricService() + metricServiceThread = threading.Thread(target=metricService.getMetrics) + metricServiceThread.start() + + prometheusWebClient = Flask(__name__) + prometheusWebClient.wsgi_app = DispatcherMiddleware(prometheusWebClient.wsgi_app, { + '/metrics': make_wsgi_app(registry=metricService.registry) + }) + + #Uncomment the statement below to run a local testing instance. 
+
+    prometheusWebClient.run(host='0.0.0.0', port=9191)
\ No newline at end of file
diff --git a/systemd/pyhss.service b/systemd/pyhss.service
new file mode 100644
index 0000000..39ac30a
--- /dev/null
+++ b/systemd/pyhss.service
@@ -0,0 +1,17 @@
+[Unit]
+Description=PyHSS
+After=network-online.target mysql.service
+Wants=pyhss_diameter.service
+Wants=pyhss_geored.service
+Wants=pyhss_hss.service
+Wants=pyhss_log.service
+Wants=pyhss_metric.service
+
+
+[Service]
+Type=oneshot
+ExecStart=/bin/true
+RemainAfterExit=yes
+
+[Install]
+WantedBy=multi-user.target
\ No newline at end of file
diff --git a/systemd/pyhss_diameter.service b/systemd/pyhss_diameter.service
new file mode 100644
index 0000000..02ceaa4
--- /dev/null
+++ b/systemd/pyhss_diameter.service
@@ -0,0 +1,13 @@
+[Unit]
+Description=PyHSS Diameter Service
+PartOf=pyhss.service
+
+
+[Service]
+User=root
+WorkingDirectory=/etc/pyhss/services/
+ExecStart=python3 diameterService.py
+Restart=always
+
+[Install]
+WantedBy=pyhss.service
\ No newline at end of file
diff --git a/systemd/pyhss_geored.service b/systemd/pyhss_geored.service
new file mode 100644
index 0000000..7f2da02
--- /dev/null
+++ b/systemd/pyhss_geored.service
@@ -0,0 +1,13 @@
+[Unit]
+Description=PyHSS Geored Service
+PartOf=pyhss.service
+
+
+[Service]
+User=root
+WorkingDirectory=/etc/pyhss/services/
+ExecStart=python3 georedService.py
+Restart=always
+
+[Install]
+WantedBy=pyhss.service
\ No newline at end of file
diff --git a/systemd/pyhss_hss.service b/systemd/pyhss_hss.service
new file mode 100644
index 0000000..5d5994c
--- /dev/null
+++ b/systemd/pyhss_hss.service
@@ -0,0 +1,13 @@
+[Unit]
+Description=PyHSS HSS Service
+PartOf=pyhss.service
+
+
+[Service]
+User=root
+WorkingDirectory=/etc/pyhss/services/
+ExecStart=python3 hssService.py
+Restart=always
+
+[Install]
+WantedBy=pyhss.service
\ No newline at end of file
diff --git a/systemd/pyhss_log.service b/systemd/pyhss_log.service
new file mode 100644
index 0000000..11a7e15
--- /dev/null
+++ b/systemd/pyhss_log.service
@@ -0,0 +1,13 @@
+[Unit]
+Description=PyHSS Log Service
+PartOf=pyhss.service
+
+
+[Service]
+User=root
+WorkingDirectory=/etc/pyhss/services/
+ExecStart=python3 logService.py
+Restart=always
+
+[Install]
+WantedBy=pyhss.service
\ No newline at end of file
diff --git a/systemd/pyhss_metric.service b/systemd/pyhss_metric.service
new file mode 100644
index 0000000..4c3995a
--- /dev/null
+++ b/systemd/pyhss_metric.service
@@ -0,0 +1,13 @@
+[Unit]
+Description=PyHSS Metric Service
+PartOf=pyhss.service
+
+
+[Service]
+User=root
+WorkingDirectory=/etc/pyhss/services/
+ExecStart=python3 metricService.py
+Restart=always
+
+[Install]
+WantedBy=pyhss.service
\ No newline at end of file
diff --git a/test_Diameter.py b/tests/test_Diameter.py
similarity index 100%
rename from test_Diameter.py
rename to tests/test_Diameter.py
diff --git a/tests_API.py b/tests/tests_API.py
similarity index 100%
rename from tests_API.py
rename to tests/tests_API.py
diff --git a/tools/databaseUpgrade/README.md b/tools/databaseUpgrade/README.md
new file mode 100644
index 0000000..4376b4a
--- /dev/null
+++ b/tools/databaseUpgrade/README.md
@@ -0,0 +1,21 @@
+# Database Upgrade
+
+Database upgrades are currently limited to semi-automation.
+
+Alembic is used to handle database schema upgrades.
+
+This will not give a foolproof upgrade; ensure you read the generated scripts.
+For best results (and in production environments), read lib/database.py and compare each base object to the table in your database.
+
+# Usage
+
+1. Ensure that `config.yaml` is populated with the correct database credentials.
+
+2. Navigate to `tools/databaseUpgrade`
+
+3. `pip3 install -r requirements.txt`
+
+4. `alembic revision --autogenerate -m "Name your upgrade"`
+
+5. `alembic upgrade head`
\ No newline at end of file
diff --git a/tools/databaseUpgrade/alembic.ini b/tools/databaseUpgrade/alembic.ini
new file mode 100644
index 0000000..7bb0089
--- /dev/null
+++ b/tools/databaseUpgrade/alembic.ini
@@ -0,0 +1,110 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+script_location = alembic
+
+# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
+# Uncomment the line below if you want the files to be prepended with date and time
+# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
+# for all available tokens
+# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+# sys.path path, will be prepended to sys.path if present.
+# defaults to the current working directory.
+prepend_sys_path = .
+
+# timezone to use when rendering the date within the migration file
+# as well as the filename.
+# If specified, requires the python-dateutil library that can be
+# installed by adding `alembic[tz]` to the pip requirements
+# string value is passed to dateutil.tz.gettz()
+# leave blank for localtime
+# timezone =
+
+# max length of characters to apply to the
+# "slug" field
+# truncate_slug_length = 40
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+# set to 'true' to allow .pyc and .pyo files without
+# a source .py file to be detected as revisions in the
+# versions/ directory
+# sourceless = false
+
+# version location specification; This defaults
+# to alembic/versions. When using multiple version
+# directories, initial revisions must be specified with --version-path.
+# The path separator used here should be the separator specified by "version_path_separator" below.
+# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
+
+# version path separator; As mentioned above, this is the character used to split
+# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
+# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
+# Valid values for version_path_separator are:
+#
+# version_path_separator = :
+# version_path_separator = ;
+# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
+
+# set to 'true' to search source files recursively
+# in each "version_locations" directory
+# new in Alembic version 1.10
+# recursive_version_locations = false
+
+# the output encoding used when revision files
+# are written from script.py.mako
+# output_encoding = utf-8
+
+; sqlalchemy.url = driver://user:pass@localhost/dbname
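+# Note: sqlalchemy.url is deliberately left unset here; alembic/env.py builds
+# the database url at runtime from the credentials in config.yaml.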
+
+
+[post_write_hooks]
+# post_write_hooks defines scripts or Python functions that are run
+# on newly generated revision scripts. See the documentation for further
+# detail and examples
+
+# format using "black" - use the console_scripts runner, against the "black" entrypoint
+# hooks = black
+# black.type = console_scripts
+# black.entrypoint = black
+# black.options = -l 79 REVISION_SCRIPT_FILENAME
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S
diff --git a/tools/databaseUpgrade/alembic/README b/tools/databaseUpgrade/alembic/README
new file mode 100644
index 0000000..98e4f9c
--- /dev/null
+++ b/tools/databaseUpgrade/alembic/README
@@ -0,0 +1 @@
+Generic single-database configuration.
\ No newline at end of file
diff --git a/tools/databaseUpgrade/alembic/env.py b/tools/databaseUpgrade/alembic/env.py
new file mode 100644
index 0000000..4cf83b4
--- /dev/null
+++ b/tools/databaseUpgrade/alembic/env.py
@@ -0,0 +1,89 @@
+from logging.config import fileConfig
+from sqlalchemy import create_engine
+from alembic import context
+import yaml
+import sys
+import os
+sys.path.append(os.path.realpath('lib'))
+from database import Base
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+# Interpret the config file for Python logging.
+# This sets up Python logging from the .ini file.
+if config.config_file_name is not None:
+    fileConfig(config.config_file_name)
+
+# add your model's MetaData object here
+# for 'autogenerate' support
+# from myapp import mymodel
+# target_metadata = mymodel.Base.metadata
+target_metadata = Base.metadata
+
+# other values from the config, defined by the needs of env.py,
+# can be acquired:
+# my_important_option = config.get_main_option("my_important_option")
+# ... etc.

+def get_url_from_config() -> str:
+    """
+    Reads the database credentials from the config.yaml in the repository root
+    (the tool is expected to be run from tools/databaseUpgrade) and returns the
+    database url.
+    """
+    with open("../../config.yaml", 'r') as stream:
+        try:
+            pyhssConfig = yaml.safe_load(stream)
+            dbConfig = pyhssConfig['database']
+            return f"mysql://{dbConfig['username']}:{dbConfig['password']}@{dbConfig['server']}/{dbConfig['database']}"
+        except Exception as e:
+            # Fail loudly; returning None here would only surface later as a
+            # confusing create_engine() error.
+            print(f"Unable to build database url from config.yaml: {e}")
+            raise
+
+
+def run_migrations_offline() -> None:
+    """Run migrations in 'offline' mode.
+
+    This configures the context with just a URL
+    and not an Engine, though an Engine is acceptable
+    here as well. By skipping the Engine creation
+    we don't even need a DBAPI to be available.
+
+    Calls to context.execute() here emit the given string to the
+    script output.
+
+    """
+    url = config.get_main_option("sqlalchemy.url")
+    context.configure(
+        url=url,
+        target_metadata=target_metadata,
+        literal_binds=True,
+        dialect_opts={"paramstyle": "named"},
+    )
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+def run_migrations_online() -> None:
+    """Run migrations in 'online' mode.
+
+    In this scenario we need to create an Engine
+    and associate a connection with the context.
+
+    """
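+    # The engine url comes from config.yaml via get_url_from_config() above,
+    # rather than from a sqlalchemy.url entry in alembic.ini.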
+    connectable = create_engine(get_url_from_config())
+
+    with connectable.connect() as connection:
+        context.configure(
+            connection=connection, target_metadata=target_metadata
+        )
+
+        with context.begin_transaction():
+            context.run_migrations()
+
+
+if context.is_offline_mode():
+    run_migrations_offline()
+else:
+    run_migrations_online()
diff --git a/tools/databaseUpgrade/alembic/lib b/tools/databaseUpgrade/alembic/lib
new file mode 120000
index 0000000..a5bc743
--- /dev/null
+++ b/tools/databaseUpgrade/alembic/lib
@@ -0,0 +1 @@
+../../../lib
\ No newline at end of file
diff --git a/tools/databaseUpgrade/alembic/script.py.mako b/tools/databaseUpgrade/alembic/script.py.mako
new file mode 100644
index 0000000..55df286
--- /dev/null
+++ b/tools/databaseUpgrade/alembic/script.py.mako
@@ -0,0 +1,24 @@
+"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from alembic import op
+import sqlalchemy as sa
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision = ${repr(up_revision)}
+down_revision = ${repr(down_revision)}
+branch_labels = ${repr(branch_labels)}
+depends_on = ${repr(depends_on)}
+
+
+def upgrade() -> None:
+    ${upgrades if upgrades else "pass"}
+
+
+def downgrade() -> None:
+    ${downgrades if downgrades else "pass"}
diff --git a/tools/databaseUpgrade/lib b/tools/databaseUpgrade/lib
new file mode 120000
index 0000000..58677dd
--- /dev/null
+++ b/tools/databaseUpgrade/lib
@@ -0,0 +1 @@
+../../lib
\ No newline at end of file
diff --git a/tools/databaseUpgrade/requirements.txt b/tools/databaseUpgrade/requirements.txt
new file mode 100644
index 0000000..691f7b6
--- /dev/null
+++ b/tools/databaseUpgrade/requirements.txt
@@ -0,0 +1,2 @@
+alembic==1.10.3
+zipp==3.17.0