From 18b3a9339cf49a447c6f99d4e2affca8b0c4d3e5 Mon Sep 17 00:00:00 2001 From: Junchao-Mellanox <57339448+Junchao-Mellanox@users.noreply.github.com> Date: Thu, 14 Dec 2023 18:01:11 +0800 Subject: [PATCH] [Mellanox] Fix race condition while creating SFP (#17441) - Why I did it Fix issue xcvrd crashes due to cannot import name 'initialize_sfp_thermal': Nov 27 09:47:16.388639 sonic ERR pmon#xcvrd: Exception occured at CmisManagerTask thread due to ImportError("cannot import name 'initialize_sfp_thermal' from partially initialized module 'sonic_platform.thermal' (most likely due to a circular import) (/usr/local/lib/python3.9/dist-packages/sonic_platform/thermal.py)") - How I did it Add lock for creating SFP object - How to verify it Unit test Manual Test --- .../sonic_platform/chassis.py | 70 +++++++++++-------- .../mlnx-platform-api/tests/test_chassis.py | 26 +++++++ 2 files changed, 67 insertions(+), 29 deletions(-) diff --git a/platform/mellanox/mlnx-platform-api/sonic_platform/chassis.py b/platform/mellanox/mlnx-platform-api/sonic_platform/chassis.py index f0b73de66c34..5870d7e6b602 100644 --- a/platform/mellanox/mlnx-platform-api/sonic_platform/chassis.py +++ b/platform/mellanox/mlnx-platform-api/sonic_platform/chassis.py @@ -124,6 +124,7 @@ def __init__(self): self.reboot_cause_initialized = False self.sfp_module = None + self.sfp_lock = threading.Lock() # Build the RJ45 port list from platform.json and hwsku.json self._RJ45_port_inited = False @@ -277,38 +278,49 @@ def _import_sfp_module(self): def initialize_single_sfp(self, index): sfp_count = self.get_num_sfps() + # Use double checked locking mechanism for: + # 1. protect shared resource self._sfp_list + # 2. performance (avoid locking every time) if index < sfp_count: - if not self._sfp_list: - self._sfp_list = [None] * sfp_count - - if not self._sfp_list[index]: - sfp_module = self._import_sfp_module() - if self.RJ45_port_list and index in self.RJ45_port_list: - self._sfp_list[index] = sfp_module.RJ45Port(index) - else: - self._sfp_list[index] = sfp_module.SFP(index) - self.sfp_initialized_count += 1 + if not self._sfp_list or not self._sfp_list[index]: + with self.sfp_lock: + if not self._sfp_list: + self._sfp_list = [None] * sfp_count + + if not self._sfp_list[index]: + sfp_module = self._import_sfp_module() + if self.RJ45_port_list and index in self.RJ45_port_list: + self._sfp_list[index] = sfp_module.RJ45Port(index) + else: + self._sfp_list[index] = sfp_module.SFP(index) + self.sfp_initialized_count += 1 def initialize_sfp(self): - if not self._sfp_list: - sfp_module = self._import_sfp_module() - sfp_count = self.get_num_sfps() - for index in range(sfp_count): - if self.RJ45_port_list and index in self.RJ45_port_list: - sfp_object = sfp_module.RJ45Port(index) - else: - sfp_object = sfp_module.SFP(index) - self._sfp_list.append(sfp_object) - self.sfp_initialized_count = sfp_count - elif self.sfp_initialized_count != len(self._sfp_list): - sfp_module = self._import_sfp_module() - for index in range(len(self._sfp_list)): - if self._sfp_list[index] is None: - if self.RJ45_port_list and index in self.RJ45_port_list: - self._sfp_list[index] = sfp_module.RJ45Port(index) - else: - self._sfp_list[index] = sfp_module.SFP(index) - self.sfp_initialized_count = len(self._sfp_list) + sfp_count = self.get_num_sfps() + # Use double checked locking mechanism for: + # 1. protect shared resource self._sfp_list + # 2. performance (avoid locking every time) + if sfp_count != self.sfp_initialized_count: + with self.sfp_lock: + if sfp_count != self.sfp_initialized_count: + if not self._sfp_list: + sfp_module = self._import_sfp_module() + for index in range(sfp_count): + if self.RJ45_port_list and index in self.RJ45_port_list: + sfp_object = sfp_module.RJ45Port(index) + else: + sfp_object = sfp_module.SFP(index) + self._sfp_list.append(sfp_object) + self.sfp_initialized_count = sfp_count + elif self.sfp_initialized_count != len(self._sfp_list): + sfp_module = self._import_sfp_module() + for index in range(len(self._sfp_list)): + if self._sfp_list[index] is None: + if self.RJ45_port_list and index in self.RJ45_port_list: + self._sfp_list[index] = sfp_module.RJ45Port(index) + else: + self._sfp_list[index] = sfp_module.SFP(index) + self.sfp_initialized_count = len(self._sfp_list) def get_num_sfps(self): """ diff --git a/platform/mellanox/mlnx-platform-api/tests/test_chassis.py b/platform/mellanox/mlnx-platform-api/tests/test_chassis.py index fce9bd00b0ee..ffe86aaf3d08 100644 --- a/platform/mellanox/mlnx-platform-api/tests/test_chassis.py +++ b/platform/mellanox/mlnx-platform-api/tests/test_chassis.py @@ -16,8 +16,10 @@ # import os +import random import sys import subprocess +import threading from mock import MagicMock if sys.version_info.major == 3: @@ -167,6 +169,30 @@ def test_sfp(self): assert len(sfp_list) == 3 assert chassis.sfp_initialized_count == 3 + def test_create_sfp_in_multi_thread(self): + DeviceDataManager.get_sfp_count = mock.MagicMock(return_value=3) + + iteration_num = 100 + while iteration_num > 0: + chassis = Chassis() + assert chassis.sfp_initialized_count == 0 + t1 = threading.Thread(target=lambda: chassis.get_sfp(1)) + t2 = threading.Thread(target=lambda: chassis.get_sfp(1)) + t3 = threading.Thread(target=lambda: chassis.get_all_sfps()) + t4 = threading.Thread(target=lambda: chassis.get_all_sfps()) + threads = [t1, t2, t3, t4] + random.shuffle(threads) + for t in threads: + t.start() + for t in threads: + t.join() + assert len(chassis.get_all_sfps()) == 3 + assert chassis.sfp_initialized_count == 3 + for index, s in enumerate(chassis.get_all_sfps()): + assert s.sdk_index == index + iteration_num -= 1 + + @mock.patch('sonic_platform.device_data.DeviceDataManager.get_sfp_count', MagicMock(return_value=3)) def test_change_event(self): chassis = Chassis()