From dda0523635b88a66cc0f39144f972130356806ba Mon Sep 17 00:00:00 2001 From: hiohiohio Date: Fri, 26 Apr 2024 15:50:38 +0900 Subject: [PATCH] fix: Add validation to ReplaceOrderRequest (#445) * fix: add unit test for ReplaceOrderRequest * fix: add validation for ReplaceOrderRequest * fix: fix lint * fix: remove unused parameters --- alpaca/trading/requests.py | 42 ++++++++++++----- .../trading_client/test_order_routes.py | 46 +++++++++++++++---- 2 files changed, 67 insertions(+), 21 deletions(-) diff --git a/alpaca/trading/requests.py b/alpaca/trading/requests.py index 552e9c7b..124b157a 100644 --- a/alpaca/trading/requests.py +++ b/alpaca/trading/requests.py @@ -1,26 +1,26 @@ from datetime import date, datetime, timedelta -from typing import Optional, Any, List, Union +from typing import Any, List, Optional, Union import pandas as pd from pydantic import model_validator -from alpaca.common.models import ModelWithID -from alpaca.common.requests import NonEmptyRequest from alpaca.common.enums import Sort +from alpaca.common.models import ModelWithID +from alpaca.common.requests import NonEmptyRequest from alpaca.trading.enums import ( - ContractType, - ExerciseStyle, - OrderType, - AssetStatus, AssetClass, AssetExchange, - TimeInForce, - OrderSide, - OrderClass, - CorporateActionType, + AssetStatus, + ContractType, CorporateActionDateType, - QueryOrderStatus, + CorporateActionType, + ExerciseStyle, + OrderClass, + OrderSide, + OrderType, PositionIntent, + QueryOrderStatus, + TimeInForce, ) @@ -209,6 +209,24 @@ class ReplaceOrderRequest(NonEmptyRequest): trail: Optional[float] = None client_order_id: Optional[str] = None + @model_validator(mode="before") + def root_validator(cls, values: dict) -> dict: + qty = values.get("qty", None) + limit_price = values.get("limit_price", None) + stop_price = values.get("stop_price", None) + trail = values.get("trail", None) + + if (qty is not None) and (qty <= 0): + raise ValueError("qty must be greater than 0") + if (limit_price is not None) and (limit_price <= 0): + raise ValueError("limit_price must be greater than 0") + if (stop_price is not None) and (stop_price <= 0): + raise ValueError("stop_price must be greater than 0") + if (trail is not None) and (trail <= 0): + raise ValueError("trail must be greater than 0") + + return values + class CancelOrderResponse(ModelWithID): """ diff --git a/tests/trading/trading_client/test_order_routes.py b/tests/trading/trading_client/test_order_routes.py index d0ecbb51..5a4b0ce0 100644 --- a/tests/trading/trading_client/test_order_routes.py +++ b/tests/trading/trading_client/test_order_routes.py @@ -1,18 +1,18 @@ +import pytest + +from alpaca.common.enums import BaseURL from alpaca.common.exceptions import APIError +from alpaca.trading.client import TradingClient +from alpaca.trading.enums import OrderSide, OrderStatus, PositionIntent, TimeInForce +from alpaca.trading.models import Order from alpaca.trading.requests import ( + CancelOrderResponse, GetOrderByIdRequest, GetOrdersRequest, - ReplaceOrderRequest, - CancelOrderResponse, - MarketOrderRequest, LimitOrderRequest, + MarketOrderRequest, + ReplaceOrderRequest, ) -from alpaca.trading.models import Order -from alpaca.trading.client import TradingClient -from alpaca.trading.enums import OrderSide, OrderStatus, TimeInForce, PositionIntent -from alpaca.common.enums import BaseURL - -import pytest def test_market_order(reqmock, trading_client): @@ -271,6 +271,34 @@ def test_replace_order(reqmock, trading_client: TradingClient): assert type(order) is Order +def test_replace_order_validate_replace_request() -> None: + # qty + ReplaceOrderRequest(qty=1) + with pytest.raises(ValueError): + ReplaceOrderRequest(qty=0) + ReplaceOrderRequest(qty=0, limit_price=0.1) + ReplaceOrderRequest(qty=0, stop_price=0.1) + ReplaceOrderRequest(qty=0, trail=0.1) + + # limit_price + ReplaceOrderRequest(limit_price=0.1) + ReplaceOrderRequest(qty=1, limit_price=0.1) + with pytest.raises(ValueError): + ReplaceOrderRequest(limit_price=0) + + # stop_price + ReplaceOrderRequest(stop_price=0.1) + ReplaceOrderRequest(qty=1, stop_price=0.1) + with pytest.raises(ValueError): + ReplaceOrderRequest(stop_price=0) + + # trail + ReplaceOrderRequest(trail=0.1) + ReplaceOrderRequest(qty=1, trail=0.1) + with pytest.raises(ValueError): + ReplaceOrderRequest(trail=0) + + def test_cancel_order_by_id(reqmock, trading_client: TradingClient): order_id = "61e69015-8549-4bfd-b9c3-01e75843f47d" status_code = 204