Skip to content

Commit

Permalink
fix: Add validation to ReplaceOrderRequest (#445)
Browse files Browse the repository at this point in the history
* fix: add unit test for ReplaceOrderRequest

* fix: add validation for ReplaceOrderRequest

* fix: fix lint

* fix: remove unused parameters
  • Loading branch information
hiohiohio authored Apr 26, 2024
1 parent ce4a515 commit dda0523
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 21 deletions.
42 changes: 30 additions & 12 deletions alpaca/trading/requests.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
from datetime import date, datetime, timedelta
from typing import Optional, Any, List, Union
from typing import Any, List, Optional, Union

import pandas as pd
from pydantic import model_validator
from alpaca.common.models import ModelWithID

from alpaca.common.requests import NonEmptyRequest
from alpaca.common.enums import Sort
from alpaca.common.models import ModelWithID
from alpaca.common.requests import NonEmptyRequest
from alpaca.trading.enums import (
ContractType,
ExerciseStyle,
OrderType,
AssetStatus,
AssetClass,
AssetExchange,
TimeInForce,
OrderSide,
OrderClass,
CorporateActionType,
AssetStatus,
ContractType,
CorporateActionDateType,
QueryOrderStatus,
CorporateActionType,
ExerciseStyle,
OrderClass,
OrderSide,
OrderType,
PositionIntent,
QueryOrderStatus,
TimeInForce,
)


Expand Down Expand Up @@ -209,6 +209,24 @@ class ReplaceOrderRequest(NonEmptyRequest):
trail: Optional[float] = None
client_order_id: Optional[str] = None

@model_validator(mode="before")
def root_validator(cls, values: dict) -> dict:
qty = values.get("qty", None)
limit_price = values.get("limit_price", None)
stop_price = values.get("stop_price", None)
trail = values.get("trail", None)

if (qty is not None) and (qty <= 0):
raise ValueError("qty must be greater than 0")
if (limit_price is not None) and (limit_price <= 0):
raise ValueError("limit_price must be greater than 0")
if (stop_price is not None) and (stop_price <= 0):
raise ValueError("stop_price must be greater than 0")
if (trail is not None) and (trail <= 0):
raise ValueError("trail must be greater than 0")

return values


class CancelOrderResponse(ModelWithID):
"""
Expand Down
46 changes: 37 additions & 9 deletions tests/trading/trading_client/test_order_routes.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import pytest

from alpaca.common.enums import BaseURL
from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, OrderStatus, PositionIntent, TimeInForce
from alpaca.trading.models import Order
from alpaca.trading.requests import (
CancelOrderResponse,
GetOrderByIdRequest,
GetOrdersRequest,
ReplaceOrderRequest,
CancelOrderResponse,
MarketOrderRequest,
LimitOrderRequest,
MarketOrderRequest,
ReplaceOrderRequest,
)
from alpaca.trading.models import Order
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, OrderStatus, TimeInForce, PositionIntent
from alpaca.common.enums import BaseURL

import pytest


def test_market_order(reqmock, trading_client):
Expand Down Expand Up @@ -271,6 +271,34 @@ def test_replace_order(reqmock, trading_client: TradingClient):
assert type(order) is Order


def test_replace_order_validate_replace_request() -> None:
# qty
ReplaceOrderRequest(qty=1)
with pytest.raises(ValueError):
ReplaceOrderRequest(qty=0)
ReplaceOrderRequest(qty=0, limit_price=0.1)
ReplaceOrderRequest(qty=0, stop_price=0.1)
ReplaceOrderRequest(qty=0, trail=0.1)

# limit_price
ReplaceOrderRequest(limit_price=0.1)
ReplaceOrderRequest(qty=1, limit_price=0.1)
with pytest.raises(ValueError):
ReplaceOrderRequest(limit_price=0)

# stop_price
ReplaceOrderRequest(stop_price=0.1)
ReplaceOrderRequest(qty=1, stop_price=0.1)
with pytest.raises(ValueError):
ReplaceOrderRequest(stop_price=0)

# trail
ReplaceOrderRequest(trail=0.1)
ReplaceOrderRequest(qty=1, trail=0.1)
with pytest.raises(ValueError):
ReplaceOrderRequest(trail=0)


def test_cancel_order_by_id(reqmock, trading_client: TradingClient):
order_id = "61e69015-8549-4bfd-b9c3-01e75843f47d"
status_code = 204
Expand Down

0 comments on commit dda0523

Please sign in to comment.