Skip to content

Commit

Permalink
Support for PCSCF Restoration via IDR
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkneipp committed Oct 31, 2023
1 parent 4e3f4f5 commit 642953c
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 204 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Reduced verbosity of failing subscriber lookups to debug
- Added CORS headers: [Zarya/171](https://github.com/nickvsnetworking/pyhss/pull/171)
- Gx RAR now dynamically creates TFT up to 512k based on UE request.

### Fixed

- Geored failing when multiple peers defined and socket closes.
- Error in Update_Serving_MME when supplied with a NoneType timestamp.

### Added

- Support for IDR-based PCSCF restoration via /pcrf/pcscf_restoration and /pcrf/pcscf_restoration_subscriber in API.

## [1.0.0] - 2023-09-27

### Added
Expand Down
3 changes: 2 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ benchmarking:
eir:
imsi_imei_logging: True #Store current IMEI / IMSI pair in backend
no_match_response: 2 #Greylist
tac_database_csv: '/etc/pyhss/tac_database_Nov2022.csv'
# Define an optional TAC csv file path
#tac_database_csv: '/etc/pyhss/tac_database.csv'

logging:
level: INFO
Expand Down
25 changes: 25 additions & 0 deletions lib/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,30 @@ def Get_Subscriber(self, **kwargs):
self.safe_close(session)
return result

def Get_Subscribers_By_Pcscf(self, pcscf: str) -> list:
    """
    Return all IMS_SUBSCRIBER rows whose pcscf column matches *pcscf*.

    Each row is returned as a plain dict with SQLAlchemy's internal
    '_sa_instance_state' key stripped, so the result is safe to serialise.
    Raises ValueError (wrapping the original exception) if the query fails,
    matching the error contract of the sibling Get_* helpers.
    """
    Session = sessionmaker(bind = self.engine)
    session = Session()
    self.logTool.log(service='Database', level='debug', message=f"[database.py] [Get_Subscribers_By_Pcscf] Get_Subscribers_By_Pcscf for PCSCF: {pcscf}", redisClient=self.redisMessaging)
    try:
        result = session.query(IMS_SUBSCRIBER).filter_by(pcscf=pcscf).all()
    except Exception as E:
        self.safe_close(session)
        # Preserve the original error for callers that catch ValueError.
        raise ValueError(E)
    # Single pass: copy each ORM object's attribute dict (so we never mutate
    # live instance state) and drop the SQLAlchemy bookkeeping key.
    returnList = []
    for item in result:
        try:
            itemDict = dict(item.__dict__)
            itemDict.pop('_sa_instance_state', None)
            returnList.append(itemDict)
        except Exception:
            self.logTool.log(service='Database', level='warning', message=f"[database.py] [Get_Subscribers_By_Pcscf] Error getting ims_subscriber: {traceback.format_exc()}", redisClient=self.redisMessaging)
    self.safe_close(session)
    return returnList

def Get_SUBSCRIBER_ROUTING(self, subscriber_id, apn_id):
Session = sessionmaker(bind = self.engine)
session = Session()
Expand Down Expand Up @@ -2497,3 +2521,4 @@ def get_device_info_from_TAC(self, imei) -> dict:




268 changes: 123 additions & 145 deletions lib/diameter.py

Large diffs are not rendered by default.

118 changes: 118 additions & 0 deletions services/apiService.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@
'charging_rule_id' : fields.Integer(required=True, description='charging_rule_id to push'),
})

# Request model for /pcrf/pcscf_restoration_subscriber.
# The endpoint accepts either identifier (the handler enforces
# "at least one of imsi / msisdn" itself and returns 400 otherwise),
# so neither field is individually required in the schema.
PCRF_PCSCF_Restoration_Subscriber_model = api.model('PCRF_PCSCF_Restoration_Subscriber', {
    'imsi': fields.String(required=False, description='IMSI of IMS Subscriber (supply imsi or msisdn)'),
    'msisdn': fields.String(required=False, description='MSISDN of IMS Subscriber (supply imsi or msisdn)'),
})

# Request model for /pcrf/pcscf_restoration (restore all subscribers on a P-CSCF).
PCRF_PCSCF_Restoration_model = api.model('PCRF_PCSCF_Restoration', {
    'pcscf': fields.String(required=True, description='Serving PCSCF to send restoration for'),
})

Push_CLR_Model = api.model('CLR', {
'DestinationRealm': fields.String(required=True, description='Destination Realm to set'),
'DestinationHost': fields.String(required=False, description='Destination Host (Optional)'),
Expand Down Expand Up @@ -201,6 +210,10 @@ def decorated_function(*args, **kwargs):
def auth_before_request():
if request.path.startswith('/docs') or request.path.startswith('/swagger') or request.path.startswith('/metrics'):
return None
if request.method == "OPTIONS":
res = Response()
res.headers['X-Content-Type-Options'] = '*'
return res
if request.endpoint and 'static' not in request.endpoint:
view_function = apiService.view_functions[request.endpoint]
if hasattr(view_function, 'view_class'):
Expand Down Expand Up @@ -253,6 +266,9 @@ def page_not_found(e):
@apiService.after_request
def apply_caching(response):
    # Stamp every outgoing response with the HSS origin host and the
    # permissive CORS headers the web UI relies on.
    responseHeaders = {
        "HSS": str(config['hss']['OriginHost']),
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET,PUT,POST,DELETE,PATCH,OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type, Authorization, Content-Length, X-Requested-With, Provisioning-Key",
    }
    for headerName, headerValue in responseHeaders.items():
        response.headers[headerName] = headerValue
    return response

@ns_apn.route('/<string:apn_id>')
Expand Down Expand Up @@ -1501,6 +1517,108 @@ def put(self):
result = {"Result": "Successfully sent Gx RAR", "destinationClients": str(servingPgw)}
return result, 200

@ns_pcrf.route('/pcscf_restoration_subscriber')
class PyHSS_PCRF_PSCSF_Restoration_Subscriber(Resource):
    # NOTE(review): class name misspells PCSCF as PSCSF, and the
    # /pcscf_restoration resource below reuses this same class name,
    # shadowing this one at module level — consider renaming one of them.
    @ns_pcrf.doc('Trigger PCSCF Restoration for an IMS Subscriber')
    @ns_pcrf.expect(PCRF_PCSCF_Restoration_Subscriber_model)
    def put(self):
        '''Trigger PCSCF Restoration for an IMS Subscriber'''

        try:
            jsonData = request.get_json(force=True)
            # Either identifier may be supplied; at least one is mandatory.

            imsi = jsonData.get('imsi', None)
            msisdn = jsonData.get('msisdn', None)

            if not imsi and not msisdn:
                result = {"Result": "Error: IMSI or MSISDN Required"}
                return result, 400

            # Resolve both the IMS profile and the EPC subscriber record;
            # when only an MSISDN is given, the IMSI comes from the IMS row.
            if imsi:
                subscriberData = databaseClient.Get_Subscriber(imsi=imsi)
                imsSubscriberData = databaseClient.Get_IMS_Subscriber(imsi=imsi)
            else:
                imsSubscriberData = databaseClient.Get_IMS_Subscriber(msisdn=msisdn)
                subscriberData = databaseClient.Get_Subscriber(imsi=imsSubscriberData.get('imsi', None))

            # A subscriber with no serving MME cannot receive the request
            # (serving_mme_peer may be None, making .split raise).
            try:
                servingMmePeer = subscriberData.get('serving_mme_peer').split(';')[0]
            except Exception as e:
                result = {"Result": "Error: Subscriber is not currently served by an MME"}
                return result, 400

            imsi = imsSubscriberData.get('imsi', None)
            servingMmeRealm = subscriberData.get('serving_mme_realm', None)
            servingMme = subscriberData.get('serving_mme', None)

            # Send an Insert-Subscriber-Data request (IDR) with the
            # PcscfRestoration flag set, routed via the serving MME peer.
            diameterResponse = diameterClient.sendDiameterRequest(
                requestType='ISD',
                hostname=servingMmePeer,
                imsi=imsi,
                DestinationRealm=servingMmeRealm,
                DestinationHost=servingMme,
                PcscfRestoration=True
            )

            result = {"Result": f"Successfully sent PCSCF Restoration request via {servingMmePeer} for IMSI {imsi}"}
            return result, 200

        except Exception as E:
            print("Flask Exception: " + str(E))
            return handle_exception(E)

@ns_pcrf.route('/pcscf_restoration')
class PyHSS_PCRF_PCSCF_Restoration(Resource):
    # Renamed from PyHSS_PCRF_PSCSF_Restoration_Subscriber: the old name
    # duplicated the /pcscf_restoration_subscriber resource class and
    # shadowed it at module level (it also misspelt PCSCF). Registration
    # happens via the route decorator, so no caller is affected.
    @ns_pcrf.doc('Trigger PCSCF Restoration for all IMS Subscribers attached to PCSCF')
    @ns_pcrf.expect(PCRF_PCSCF_Restoration_model)
    def put(self):
        '''Trigger PCSCF Restoration for all IMS Subscribers attached to PCSCF'''

        try:
            jsonData = request.get_json(force=True)

            pcscf = jsonData.get('pcscf', None)
            if not pcscf:
                result = {"Result": "Error: PCSCF Required"}
                return result, 400

            # All IMS subscribers currently registered through this P-CSCF.
            activeSubscribers = databaseClient.Get_Subscribers_By_Pcscf(pcscf=pcscf)
            logTool.log(service='API', level='debug', message=f"[API] Active Subscribers for {pcscf}: {activeSubscribers}", redisClient=redisMessaging)

            for imsSubscriber in activeSubscribers:
                # Best-effort fan-out: one subscriber failing (e.g. no
                # serving MME, so serving_mme_peer is None) must not stop
                # restoration for the rest.
                try:
                    imsi = imsSubscriber.get('imsi', None)
                    if not imsi:
                        continue
                    subscriberData = databaseClient.Get_Subscriber(imsi=imsi)
                    servingMmePeer = subscriberData.get('serving_mme_peer').split(';')[0]

                    imsi = subscriberData.get('imsi', None)
                    servingMmeRealm = subscriberData.get('serving_mme_realm', None)
                    servingMme = subscriberData.get('serving_mme', None)

                    # Insert-Subscriber-Data request (IDR) with the
                    # PcscfRestoration flag set, routed via the serving MME peer.
                    diameterResponse = diameterClient.sendDiameterRequest(
                        requestType='ISD',
                        hostname=servingMmePeer,
                        imsi=imsi,
                        DestinationRealm=servingMmeRealm,
                        DestinationHost=servingMme,
                        PcscfRestoration=True
                    )

                except Exception as e:
                    continue

            result = {"Result": f"Successfully sent PCSCF Restoration request for PCSCF: {pcscf}"}
            return result, 200

        except Exception as E:
            print("Flask Exception: " + str(E))
            return handle_exception(E)

@ns_pcrf.route('/<string:charging_rule_id>')
class PyHSS_PCRF_Complete(Resource):
def get(self, charging_rule_id):
Expand Down
2 changes: 1 addition & 1 deletion services/diameterService.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def handleActiveDiameterPeers(self):

await(asyncio.sleep(1))
except Exception as e:
print(e)
await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [handleActiveDiameterPeers] Exception: {e}\n{traceback.format_exc()}"))
await(asyncio.sleep(1))
continue

Expand Down
114 changes: 57 additions & 57 deletions services/georedService.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,91 +251,91 @@ async def handleAsymmetricGeoredQueue(self):
"""
Collects and processes asymmetric geored messages.
"""
while True:
try:
if self.benchmarking:
startTime = time.perf_counter()
georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='asymmetric-geored')))[1])
await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Message: {georedMessage}"))

georedOperation = georedMessage['operation']
georedBody = georedMessage['body']
georedUrls = georedMessage['urls']
georedTasks = []

async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session:
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session:
while True:
try:
if self.benchmarking:
startTime = time.perf_counter()
georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='asymmetric-geored')))[1])
await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Message: {georedMessage}"))

georedOperation = georedMessage['operation']
georedBody = georedMessage['body']
georedUrls = georedMessage['urls']
georedTasks = []

for georedEndpoint in georedUrls:
georedTasks.append(self.sendGeored(asyncSession=session, url=georedEndpoint, operation=georedOperation, body=georedBody))
await asyncio.gather(*georedTasks)
if self.benchmarking:
await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleAsymmetricGeoredQueue] Time taken to send asymmetric geored message to specified peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms"))
if self.benchmarking:
await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleAsymmetricGeoredQueue] Time taken to send asymmetric geored message to specified peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms"))

await(asyncio.sleep(0))
await(asyncio.sleep(0))

except Exception as e:
await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Error handling asymmetric geored queue: {e}"))
await(asyncio.sleep(0))
continue
except Exception as e:
await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Error handling asymmetric geored queue: {e}"))
await(asyncio.sleep(0))
continue

async def handleGeoredQueue(self):
    """
    Collects and processes queued geored messages.

    NOTE(review): the page-flattened diff interleaved the pre- and
    post-change bodies; this is the reconstructed post-change version,
    where the aiohttp ClientSession is created once outside the consume
    loop instead of per message.
    """
    async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session:
        while True:
            try:
                if self.benchmarking:
                    startTime = time.perf_counter()
                # Block until a geored message is available on the queue.
                georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='geored')))[1])
                await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Message: {georedMessage}"))

                georedOperation = georedMessage['operation']
                georedBody = georedMessage['body']
                georedTasks = []

                # Replicate the message to every configured geored peer concurrently.
                for remotePeer in self.georedPeers:
                    georedTasks.append(self.sendGeored(asyncSession=session, url=remotePeer+'/geored/', operation=georedOperation, body=georedBody))
                await asyncio.gather(*georedTasks)
                if self.benchmarking:
                    await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleGeoredQueue] Time taken to send geored message to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms"))

                await(asyncio.sleep(0))

            except Exception as e:
                await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Error handling geored queue: {e}"))
                await(asyncio.sleep(0))
                continue

async def handleWebhookQueue(self):
    """
    Collects and processes queued webhook messages.

    NOTE(review): the page-flattened diff interleaved the pre- and
    post-change bodies; this is the reconstructed post-change version,
    where the aiohttp ClientSession is created once outside the consume
    loop instead of per message.
    """
    async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session:
        while True:
            try:
                if self.benchmarking:
                    startTime = time.perf_counter()
                # Block until a webhook message is available on the queue.
                webhookMessage = json.loads((await(self.redisWebhookMessaging.awaitMessage(key='webhook')))[1])

                await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Message: {webhookMessage}"))

                webhookHeaders = webhookMessage['headers']
                webhookOperation = webhookMessage['operation']
                webhookBody = webhookMessage['body']
                webhookTasks = []

                # Deliver the webhook to every configured peer concurrently.
                for remotePeer in self.webhookPeers:
                    webhookTasks.append(self.sendWebhook(asyncSession=session, url=remotePeer, operation=webhookOperation, body=webhookBody, headers=webhookHeaders))
                await asyncio.gather(*webhookTasks)
                if self.benchmarking:
                    await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleWebhookQueue] Time taken to send webhook to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms"))

                await(asyncio.sleep(0.001))

            except Exception as e:
                await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Error handling webhook queue: {e}"))
                await(asyncio.sleep(0.001))
                continue

async def startService(self):
"""
Expand Down

0 comments on commit 642953c

Please sign in to comment.