Skip to content

Commit

Permalink
Formatted with ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
Anton-Shutik committed Nov 15, 2024
1 parent 733d368 commit f03335d
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 81 deletions.
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from setuptools import setup, find_packages
from setuptools import setup


def get_version():
with open("VERSION") as f:
return f.read().strip()


setup(
name="shopify-client",
version=get_version(),
Expand All @@ -25,7 +27,7 @@ def get_version():
"flake8",
"black",
"sphinx",
"pre-commit"
"pre-commit",
]
},
python_requires=">=3.9",
Expand Down
2 changes: 0 additions & 2 deletions shopify_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import time
from urllib.parse import urljoin

import requests
Expand All @@ -15,7 +14,6 @@


class ShopifyClient(requests.Session):

def __init__(
self,
api_url,
Expand Down
70 changes: 38 additions & 32 deletions shopify_client/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,37 @@

logger = logging.getLogger(__name__)


class Endpoint(object):
def __init__(self, client, endpoint, sub_endpoint=None, metafields=False):
self.client = client
self.endpoint = endpoint
self.sub_endpoint = sub_endpoint

if metafields:
self.metafields = Endpoint(client=client, endpoint=self.endpoint, sub_endpoint="metafields")
self.metafields = Endpoint(
client=client, endpoint=self.endpoint, sub_endpoint="metafields"
)

def __prepare_params(self, **params):
flatted_params = []
for key, value in params.items():
if isinstance(value, dict):
for k, v in value.items():
flatted_params.append(
(f"{key}[{k}]", v)
)
flatted_params.append((f"{key}[{k}]", v))
elif isinstance(value, list):
for v in value:
flatted_params.append(
(f"{key}[]", v)
)
flatted_params.append((f"{key}[]", v))
else:
flatted_params.append((key, value))

return flatted_params

def __build_url(self, resource_id=None, sub_resource_id=None, action=None, **params):
def __build_url(
self, resource_id=None, sub_resource_id=None, action=None, **params
):
url = self.endpoint

if resource_id:
url = f"{url}/{resource_id}"

Expand All @@ -46,11 +47,11 @@ def __build_url(self, resource_id=None, sub_resource_id=None, action=None, **par

if action:
url = f"{url}/{action}"

flatted_params = self.__prepare_params(**params)

return f"{url}.json{'?' + urlencode(flatted_params) if flatted_params else ''}"

def __paginate(self, url):
next_url = url
while next_url:
Expand All @@ -62,7 +63,7 @@ def __paginate(self, url):
def get(self, resource_id, **params):
url = self.__build_url(resource_id=resource_id, **params)
return self.client.parse_response(self.client.get(url))

def create(self, json: dict, **params):
url = self.__build_url(**params)
return self.client.parse_response(self.client.post(url, json=json))
Expand All @@ -75,56 +76,61 @@ def delete(self, resource_id, **params):
url = self.__build_url(resource_id=resource_id, **params)
resp = self.client.delete(url)
return resp.ok

def all(self, paginate=False, **params):
url = self.__build_url(**params)
if paginate:
return self.__paginate(url)
else:
return self.client.parse_response(self.client.get(url))

def action(self, action, resource_id, method="GET", **params):
url = self.__build_url(resource_id=resource_id, action=action, **params)
return self.client.parse_response(self.client.request(method, url, **params))

def count(self, resource_id=None, **params):
return self.action("count", resource_id=resource_id, **params)


class OrdersEndpoint(Endpoint):

def __init__(self, client, endpoint):
super().__init__(client, endpoint, metafields=True)

self.transactions = Endpoint(client=client, endpoint=endpoint, sub_endpoint="transactions")

self.transactions = Endpoint(
client=client, endpoint=endpoint, sub_endpoint="transactions"
)
self.risks = Endpoint(client=client, endpoint=endpoint, sub_endpoint="risks")
self.refunds = Endpoint(client=client, endpoint=endpoint, sub_endpoint="refunds")
self.refunds = Endpoint(
client=client, endpoint=endpoint, sub_endpoint="refunds"
)

def cancel(self, resource_id, **params):
return self.action("cancel", resource_id=resource_id, method="POST", **params)

def close(self, resource_id, **params):
return self.action("close", resource_id=resource_id, method="POST", **params)

def open(self, resource_id, **params):
return self.action("open", resource_id=resource_id, method="POST", **params)


class DraftOrdersEndpoint(Endpoint):

class DraftOrdersEndpoint(Endpoint):
def complete(self, resource_id, **params):
return self.action("complete", resource_id=resource_id, method="PUT", **params)

def send_invoice(self, resource_id, **params):
return self.action("send_invoice", resource_id=resource_id, method="PUT", **params)

return self.action(
"send_invoice", resource_id=resource_id, method="PUT", **params
)

class FulfillmentOrdersEndpoint(Endpoint):

class FulfillmentOrdersEndpoint(Endpoint):
def __init__(self, client, endpoint):
super().__init__(client, endpoint, metafields=True)

self.fulfillment_request = Endpoint(client=client, endpoint=endpoint, sub_endpoint="fulfillment_request")
self.fulfillment_request = Endpoint(
client=client, endpoint=endpoint, sub_endpoint="fulfillment_request"
)

def cancel(self, resource_id, **params):
return self.action("cancel", resource_id=resource_id, method="POST", **params)
return self.action("cancel", resource_id=resource_id, method="POST", **params)
42 changes: 33 additions & 9 deletions shopify_client/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


class GraphQL:

def __init__(self, client, graphql_queries_dir=None):
self.client = client
self.endpoint = "graphql.json"
Expand All @@ -27,18 +26,35 @@ def query_from_name(self, name):
with open(query_path, "r") as f:
return f.read()

def __query(self, query=None, query_name=None, variables=None, operation_name=None, paginate=False, page_size=100):
def __query(
self,
query=None,
query_name=None,
variables=None,
operation_name=None,
paginate=False,
page_size=100,
):
assert query or query_name, "Either 'query' or 'query_name' must be provided"

if query is None and query_name:
query = self.query_from_name(query_name)

if paginate:
return self.__paginate(query=query, variables=variables, operation_name=operation_name, page_size=page_size)
return self.__paginate(
query=query,
variables=variables,
operation_name=operation_name,
page_size=page_size,
)
try:
response = self.client.post(
self.__build_url(),
json={"query": query, "variables": variables, "operationName": operation_name},
json={
"query": query,
"variables": variables,
"operationName": operation_name,
},
)
return self.client.parse_response(response)
except requests.exceptions.HTTPError as e:
Expand All @@ -49,9 +65,15 @@ def __query(self, query=None, query_name=None, variables=None, operation_name=No
raise e

def __paginate(self, query, variables=None, operation_name=None, page_size=100):
assert "pageInfo" in query, "Query must contain a 'pageInfo' object to be paginated"
assert "hasNextPage" in query[query.find("pageInfo"):], "Query must contain a 'hasNextPage' field in 'pageInfo' object"
assert "endCursor" in query[query.find("pageInfo"):], "Query must contain a 'endCursor' field in 'pageInfo' object"
assert (
"pageInfo" in query
), "Query must contain a 'pageInfo' object to be paginated"
assert (
"hasNextPage" in query[query.find("pageInfo") :]
), "Query must contain a 'hasNextPage' field in 'pageInfo' object"
assert (
"endCursor" in query[query.find("pageInfo") :]
), "Query must contain a 'endCursor' field in 'pageInfo' object"

variables = variables or {}
variables["page_size"] = page_size
Expand All @@ -61,7 +83,9 @@ def __paginate(self, query, variables=None, operation_name=None, page_size=100):

while has_next_page:
variables["cursor"] = cursor
response = self.__query(query=query, variables=variables, operation_name=operation_name)
response = self.__query(
query=query, variables=variables, operation_name=operation_name
)
page_info = self.__find_page_info(response)
has_next_page = page_info.get("hasNextPage", False)
cursor = page_info.get("endCursor", None)
Expand All @@ -78,4 +102,4 @@ def __find_page_info(self, response):
if isinstance(v, dict):
result = self.__find_page_info(v)
if result:
return result
return result
14 changes: 10 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,28 @@
@pytest.fixture
def mock_client(mocker):
client = mocker.Mock()
client.parse_response.side_effect = lambda x: x # Just return the response data as-is
client.parse_response.side_effect = (
lambda x: x
) # Just return the response data as-is
return client


@pytest.fixture
def endpoint(mock_client):
return Endpoint(client=mock_client, endpoint="test_endpoint")


@pytest.fixture
def shopify_client(mocker):
return ShopifyClient(api_url="https://test-shop.myshopify.com", api_token="test-token")
return ShopifyClient(
api_url="https://test-shop.myshopify.com", api_token="test-token"
)


# Create a new mock that will deepcopy the arguments passed to it
# https://docs.python.org/3.7/library/unittest.mock-examples.html#coping-with-mutable-arguments
# https://docs.python.org/3.7/library/unittest.mock-examples.html#coping-with-mutable-arguments
class CopyingMock(Mock):
def __call__(self, *args, **kwargs):
args = deepcopy(args)
kwargs = deepcopy(kwargs)
return super().__call__(*args, **kwargs)
return super().__call__(*args, **kwargs)
7 changes: 6 additions & 1 deletion tests/test_draft_orders_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import pytest
from shopify_client.endpoint import DraftOrdersEndpoint


@pytest.fixture
def draft_orders_endpoint(mock_client):
return DraftOrdersEndpoint(client=mock_client, endpoint="draft_orders")


def test_complete_draft_order(draft_orders_endpoint, mock_client):
mock_client.request.return_value = {"result": "completed"}
response = draft_orders_endpoint.complete(1)
mock_client.request.assert_called_once_with("PUT", "draft_orders/1/complete.json")
assert response == {"result": "completed"}


def test_send_invoice(draft_orders_endpoint, mock_client):
mock_client.request.return_value = {"result": "invoice_sent"}
response = draft_orders_endpoint.send_invoice(1)
mock_client.request.assert_called_once_with("PUT", "draft_orders/1/send_invoice.json")
mock_client.request.assert_called_once_with(
"PUT", "draft_orders/1/send_invoice.json"
)
assert response == {"result": "invoice_sent"}
Loading

0 comments on commit f03335d

Please sign in to comment.