Skip to content

Commit

Permalink
simplify a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
KazuCocoa committed Nov 2, 2024
1 parent ae4b248 commit 9e95458
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
11 changes: 4 additions & 7 deletions appium/webdriver/appium_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class AppiumConnection(RemoteConnection):
"""

user_agent = f'{PREFIX_HEADER}{library_version()} ({RemoteConnection.user_agent})'
extra_headers = {}

@classmethod
def get_remote_connection_headers(cls, parsed_url: 'ParseResult', keep_alive: bool = True) -> Dict[str, Any]:
Expand All @@ -49,12 +50,8 @@ def get_remote_connection_headers(cls, parsed_url: 'ParseResult', keep_alive: bo

if parsed_url.path.endswith('/session'):
# https://github.com/appium/appium-base-driver/pull/400
if cls.extra_headers is None:
cls.extra_headers = {_HEADER_IDEMOTENCY_KEY: str(uuid.uuid4())}
else:
cls.extra_headers[_HEADER_IDEMOTENCY_KEY] = str(uuid.uuid4())
elif cls.extra_headers is not None and _HEADER_IDEMOTENCY_KEY in cls.extra_headers:
cls.extra_headers[_HEADER_IDEMOTENCY_KEY] = str(uuid.uuid4())
elif _HEADER_IDEMOTENCY_KEY in cls.extra_headers:
del cls.extra_headers[_HEADER_IDEMOTENCY_KEY]

base_headers = super().get_remote_connection_headers(parsed_url, keep_alive=keep_alive)
return base_headers if cls.extra_headers is None else {**base_headers, **cls.extra_headers}
return {**super().get_remote_connection_headers(parsed_url, keep_alive=keep_alive), **cls.extra_headers}
31 changes: 31 additions & 0 deletions test/unit/webdriver/appium_connection_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest
from urllib import parse

from appium.webdriver import appium_connection


class AppiumConnectionTest(unittest.TestCase):
def test_get_remote_connection_headers(self):
headers = appium_connection.AppiumConnection.get_remote_connection_headers(
parse.urlparse('http://http://127.0.0.1:4723/session')
)
self.assertIsNotNone(headers.get('X-Idempotency-Key'))

headers = appium_connection.AppiumConnection.get_remote_connection_headers(
parse.urlparse('http://http://127.0.0.1:4723/session/session_id')
)
self.assertIsNone(headers.get('X-Idempotency-Key'))

appium_connection.AppiumConnection.extra_headers = {'custom': 'header'}

headers = appium_connection.AppiumConnection.get_remote_connection_headers(
parse.urlparse('http://http://127.0.0.1:4723/session')
)
self.assertIsNotNone(headers.get('X-Idempotency-Key'))
self.assertEqual(headers.get('custom'), 'header')

headers = appium_connection.AppiumConnection.get_remote_connection_headers(
parse.urlparse('http://http://127.0.0.1:4723/session/session_id')
)
self.assertIsNone(headers.get('X-Idempotency-Key'))
self.assertEqual(headers.get('custom'), 'header')

0 comments on commit 9e95458

Please sign in to comment.