diff --git a/platform/mellanox/mlnx-platform-api/sonic_platform/chassis.py b/platform/mellanox/mlnx-platform-api/sonic_platform/chassis.py index f0b73de66c34..5870d7e6b602 100644 --- a/platform/mellanox/mlnx-platform-api/sonic_platform/chassis.py +++ b/platform/mellanox/mlnx-platform-api/sonic_platform/chassis.py @@ -124,6 +124,7 @@ def __init__(self): self.reboot_cause_initialized = False self.sfp_module = None + self.sfp_lock = threading.Lock() # Build the RJ45 port list from platform.json and hwsku.json self._RJ45_port_inited = False @@ -277,38 +278,49 @@ def _import_sfp_module(self): def initialize_single_sfp(self, index): sfp_count = self.get_num_sfps() + # Use double checked locking mechanism for: + # 1. protect shared resource self._sfp_list + # 2. performance (avoid locking every time) if index < sfp_count: - if not self._sfp_list: - self._sfp_list = [None] * sfp_count - - if not self._sfp_list[index]: - sfp_module = self._import_sfp_module() - if self.RJ45_port_list and index in self.RJ45_port_list: - self._sfp_list[index] = sfp_module.RJ45Port(index) - else: - self._sfp_list[index] = sfp_module.SFP(index) - self.sfp_initialized_count += 1 + if not self._sfp_list or not self._sfp_list[index]: + with self.sfp_lock: + if not self._sfp_list: + self._sfp_list = [None] * sfp_count + + if not self._sfp_list[index]: + sfp_module = self._import_sfp_module() + if self.RJ45_port_list and index in self.RJ45_port_list: + self._sfp_list[index] = sfp_module.RJ45Port(index) + else: + self._sfp_list[index] = sfp_module.SFP(index) + self.sfp_initialized_count += 1 def initialize_sfp(self): - if not self._sfp_list: - sfp_module = self._import_sfp_module() - sfp_count = self.get_num_sfps() - for index in range(sfp_count): - if self.RJ45_port_list and index in self.RJ45_port_list: - sfp_object = sfp_module.RJ45Port(index) - else: - sfp_object = sfp_module.SFP(index) - self._sfp_list.append(sfp_object) - self.sfp_initialized_count = sfp_count - elif self.sfp_initialized_count != len(self._sfp_list): - sfp_module = self._import_sfp_module() - for index in range(len(self._sfp_list)): - if self._sfp_list[index] is None: - if self.RJ45_port_list and index in self.RJ45_port_list: - self._sfp_list[index] = sfp_module.RJ45Port(index) - else: - self._sfp_list[index] = sfp_module.SFP(index) - self.sfp_initialized_count = len(self._sfp_list) + sfp_count = self.get_num_sfps() + # Use double checked locking mechanism for: + # 1. protect shared resource self._sfp_list + # 2. performance (avoid locking every time) + if sfp_count != self.sfp_initialized_count: + with self.sfp_lock: + if sfp_count != self.sfp_initialized_count: + if not self._sfp_list: + sfp_module = self._import_sfp_module() + for index in range(sfp_count): + if self.RJ45_port_list and index in self.RJ45_port_list: + sfp_object = sfp_module.RJ45Port(index) + else: + sfp_object = sfp_module.SFP(index) + self._sfp_list.append(sfp_object) + self.sfp_initialized_count = sfp_count + elif self.sfp_initialized_count != len(self._sfp_list): + sfp_module = self._import_sfp_module() + for index in range(len(self._sfp_list)): + if self._sfp_list[index] is None: + if self.RJ45_port_list and index in self.RJ45_port_list: + self._sfp_list[index] = sfp_module.RJ45Port(index) + else: + self._sfp_list[index] = sfp_module.SFP(index) + self.sfp_initialized_count = len(self._sfp_list) def get_num_sfps(self): """ diff --git a/platform/mellanox/mlnx-platform-api/tests/test_chassis.py b/platform/mellanox/mlnx-platform-api/tests/test_chassis.py index fce9bd00b0ee..ffe86aaf3d08 100644 --- a/platform/mellanox/mlnx-platform-api/tests/test_chassis.py +++ b/platform/mellanox/mlnx-platform-api/tests/test_chassis.py @@ -16,8 +16,10 @@ # import os +import random import sys import subprocess +import threading from mock import MagicMock if sys.version_info.major == 3: @@ -167,6 +169,30 @@ def test_sfp(self): assert len(sfp_list) == 3 assert chassis.sfp_initialized_count == 3 + def test_create_sfp_in_multi_thread(self): + DeviceDataManager.get_sfp_count = mock.MagicMock(return_value=3) + + iteration_num = 100 + while iteration_num > 0: + chassis = Chassis() + assert chassis.sfp_initialized_count == 0 + t1 = threading.Thread(target=lambda: chassis.get_sfp(1)) + t2 = threading.Thread(target=lambda: chassis.get_sfp(1)) + t3 = threading.Thread(target=lambda: chassis.get_all_sfps()) + t4 = threading.Thread(target=lambda: chassis.get_all_sfps()) + threads = [t1, t2, t3, t4] + random.shuffle(threads) + for t in threads: + t.start() + for t in threads: + t.join() + assert len(chassis.get_all_sfps()) == 3 + assert chassis.sfp_initialized_count == 3 + for index, s in enumerate(chassis.get_all_sfps()): + assert s.sdk_index == index + iteration_num -= 1 + + @mock.patch('sonic_platform.device_data.DeviceDataManager.get_sfp_count', MagicMock(return_value=3)) def test_change_event(self): chassis = Chassis()