Skip to content

Commit

Permalink
new classes, backtesting func and reports (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
AnastasiyaB authored Oct 9, 2020
1 parent 08d7c95 commit eabf293
Show file tree
Hide file tree
Showing 37 changed files with 8,052 additions and 2,756 deletions.
3 changes: 2 additions & 1 deletion gs_quant/api/gs/countries.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def delete_country(cls, country_id: str) -> dict:

@classmethod
def get_many_subdivisions(cls, limit: int = 100) -> Tuple[Subdivision, ...]:
return GsSession.current._get('/subdivisions?limit={limit}'.format(limit=limit), cls=Subdivision)['results']
return GsSession.current._get('/countries/subdivisions?limit={limit}'.format(limit=limit),
cls=Subdivision)['results']

@classmethod
def get_subdivision(cls, subdivision_id: str) -> Subdivision:
Expand Down
10 changes: 6 additions & 4 deletions gs_quant/api/gs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ class QueryType(Enum):
ES_DISCLOSURE_PERCENTAGE = "ES Disclosure Percentage"
RATING = "Rating"
CONVICTION_LIST = "Conviction List"
GSDEER = "Gsdeer"
GSFEER = "Gsfeer"
GIR_GSDEER_GSFEER = "Gir Gsdeer Gsfeer"
GIR_FX_FORECAST = "Gir Fx Forecast"
GROWTH_SCORE = "Growth Score"
FINANCIAL_RETURNS_SCORE = "Financial Returns Score"
MULTIPLE_SCORE = "Multiple Score"
Expand Down Expand Up @@ -652,10 +652,12 @@ def construct_dataframe_with_types(cls, dataset_id: str, data: Union[Base, List,
"""
if len(data):
dataset_types = cls.get_types(dataset_id)
df = pd.DataFrame(data)

df = pd.DataFrame(data, columns=dataset_types)

for field_name, type_name in dataset_types.items():
if df.get(field_name) is not None and type_name in ('date', 'date-time'):
if df.get(field_name) is not None and type_name in ('date', 'date-time') and \
len(df.get(field_name).value_counts()) > 0:
df = df.astype({field_name: numpy.datetime64})

field_names = dataset_types.keys()
Expand Down
4 changes: 2 additions & 2 deletions gs_quant/api/gs/indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, marquee_id: str = None):

def create(
self,
inputs: IndicesCreateInputs,
inputs: Union[CustomBasketsCreateInputs, IndicesDynamicConstructInputs]
) -> CustomBasketsResponse:
"""
Create a custom basket of equity stocks or ETFs
Expand Down Expand Up @@ -68,7 +68,7 @@ def rebalance(

def cancel_rebalance(
self,
inputs: ApprovalAction,
inputs: Union[CustomBasketsRebalanceAction, ISelectActionRequest],
):
"""
Cancel current pending rebalance of an index
Expand Down
90 changes: 90 additions & 0 deletions gs_quant/api/gs/reports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Copyright 2019 Goldman Sachs.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""
import datetime as dt
import logging
from typing import Tuple

from gs_quant.session import GsSession
from gs_quant.target.reports import Report, ReportScheduleRequest, ReportJob

_logger = logging.getLogger(__name__)


class GsReportApi:
"""GS Reports API client implementation"""

@classmethod
def create_report(cls, report: Report) -> Report:
return GsSession.current._post('/reports', report, cls=Report)

@classmethod
def get_report(cls, report_id: str) -> Report:
return GsSession.current._get('/reports/{id}'.format(id=report_id), cls=Report)

@classmethod
def get_reports(cls, limit: int = 100, offset: int = None, position_source_type: str = None,
position_source_id: str = None, status: str = None, report_type: str = None) \
-> Tuple[Report, ...]:
url = '/reports?limit={limit}'.format(limit=limit)
if offset is not None:
url += '&offset={offset}'.format(offset=offset)
if position_source_type is not None:
url += '&positionSourceType={pst}'.format(pst=position_source_type)
if position_source_id is not None:
url += '&positionSourceId={psi}'.format(psi=position_source_id)
if status is not None:
url += '&status={status}'.format(status=status)
if report_type is not None:
url += '&type={report_type}'.format(report_type=report_type)
return GsSession.current._get(url, cls=Report)['results']

@classmethod
def update_report(cls, report: Report) -> dict:
return GsSession.current._put('/reports/{id}'.format(id=report.id), report, cls=Report)

@classmethod
def delete_report(cls, report_id: str) -> dict:
return GsSession.current._delete('/reports/{id}'.format(id=report_id))

@classmethod
def schedule_report(cls, report_id: str, start_date: dt.date, end_date: dt.date) -> dict:
report_schedule_request = ReportScheduleRequest(startDate=start_date, endDate=end_date)
return GsSession.current._post('/reports/{id}/schedule'.format(id=report_id), report_schedule_request,
cls=ReportScheduleRequest)

@classmethod
def get_report_status(cls, report_id: str) -> Tuple[dict, ...]:
return GsSession.current._get('/reports/{id}/status'.format(id=report_id))

@classmethod
def get_report_jobs(cls, report_id: str) -> Tuple[ReportJob, ...]:
return GsSession.current._get('/reports/{id}/jobs'.format(id=report_id))['results']

@classmethod
def get_report_job(cls, report_job_id: str) -> dict:
return GsSession.current._get('/reports/jobs/{report_job_id}'.format(report_job_id=report_job_id))

@classmethod
def cancel_report_job(cls, report_job_id: str) -> dict:
return GsSession.current._post('/reports/jobs/{report_job_id}/cancel'.format(report_job_id=report_job_id))

@classmethod
def update_report_job(cls, report_job_id: str, status: str) -> dict:
status_body = {
"status": '{status}'.format(status=status)
}
return GsSession.current._post('/reports/jobs/{report_job_id}/update'.format(report_job_id=report_job_id),
status_body)
15 changes: 9 additions & 6 deletions gs_quant/backtests/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(self, priceables: Union[Priceable, Iterable[Priceable]], trade_dura
# TODO: make trade_duration capable of being a tenor or a date as well as a trade attribute
super().__init__(name)
self._priceables = make_list(priceables)
self._dated_priceables = {} # a trigger may inject the portfolio at a trigger date
self._trade_duration = trade_duration
for i, p in enumerate(self._priceables):
if p.name is None:
Expand All @@ -99,15 +100,17 @@ def trade_duration(self):
return self._trade_duration

def apply_action(self, state: Union[datetime.date, Iterable[datetime.date]], backtest: BackTest):
with HistoricalPricingContext(dates=make_list(state)):
backtest.calc_calls += 1
backtest.calculations += len(make_list(state)) * len(self._priceables)
f = Portfolio(self._priceables).resolve(in_place=False)
with PricingContext(is_batch=True):
f = {}
for s in state:
active_portfolio = self._dated_priceables.get(s) or self._priceables
with PricingContext(pricing_date=s):
f[s] = Portfolio(active_portfolio).resolve(in_place=False)

for s in backtest.states:
pos = []
for create_date, portfolio in f.result().items():
pos += [inst for inst in portfolio.instruments
for create_date, portfolio in f.items():
pos += [inst for inst in portfolio.result().instruments
if get_final_date(inst, create_date, self.trade_duration) >= s >= create_date]
backtest.portfolio_dict[s].append(pos)

Expand Down
9 changes: 6 additions & 3 deletions gs_quant/backtests/data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import datetime
from typing import Union

from gs_quant.data import Dataset

Expand All @@ -30,15 +31,17 @@ def get_data(self, state):


class GsDataSource(DataSource):
def __init__(self, data_set: str, asset_id: str, min_date: datetime.date = None, max_date: datetime.date = None):
def __init__(self, data_set: str, asset_id: str, min_date: datetime.date = None, max_date: datetime.date = None,
value_header: str = 'rate'):
self._data_set = data_set
self._asset_id = asset_id
self._min_date = min_date
self._max_date = max_date
self._value_header = value_header
self._loaded_data = None

def get_data(self, state):
def get_data(self, state: Union[datetime.date, datetime.datetime] = None):
if self._loaded_data is None:
ds = Dataset(self._data_set)
self._loaded_data = ds.get_data(self._min_date or state, self._max_date or state, assetId=(self._asset_id,))
return self._loaded_data[self]
return self._loaded_data[self._value_header]
2 changes: 1 addition & 1 deletion gs_quant/backtests/generic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def run_backtest(cls, strategy, start=None, end=None, frequency='BM', window=Non

for trigger in strategy.triggers:
if trigger.deterministic:
triggered_dates = [date for date in dates if trigger.has_triggered(date)]
triggered_dates = [date for date in dates if trigger.has_triggered(date, backtest)]
for action in trigger.actions:
if action.deterministic:
action.apply_action(triggered_dates, backtest)
Expand Down
20 changes: 20 additions & 0 deletions gs_quant/backtests/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def __init__(self, strategy_results, risk, trigger_level, direction):
self.direction = direction


class AggregateTriggerRequirements(TriggerRequirements):
def __init__(self, triggers: Iterable[object]):
super().__init__()
self.triggers = triggers


class Trigger(object):

def __init__(self, trigger_requirements: Optional[TriggerRequirements], actions: Union[Action, Iterable[Action]]):
Expand Down Expand Up @@ -156,3 +162,17 @@ def has_triggered(self, state: datetime.date, backtest: BackTest = None) -> bool
if risk_value == self._trigger_requirements.trigger_level:
return True
return False


class AggregateTrigger(Trigger):
def __init__(self, triggers: Iterable[Trigger]):
actions = []
for t in triggers:
actions += [action for action in t.actions]
super().__init__(AggregateTriggerRequirements(triggers), actions)

def has_triggered(self, state: datetime.date, backtest: BackTest = None) -> bool:
for trigger in self._trigger_requirements.triggers:
if not trigger.has_triggered(state, backtest):
return False
return True
2 changes: 2 additions & 0 deletions gs_quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def __hash__(self) -> int:
value = super().__getattribute__(prop)
if isinstance(value, dict):
value = tuple(value.items())
elif isinstance(value, list):
value = tuple(value)
calced_hash ^= hash(value)

self.__calced_hash = calced_hash
Expand Down
3 changes: 2 additions & 1 deletion gs_quant/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"""

from gs_quant.target.common import BusinessDayConvention, BuySell, Currency, DayCountFraction, AssetClass, AssetType,\
OptionStyle, OptionType, PayReceive, PricingLocation, SwapClearingHouse, SwapSettlement, XRef
OptionStyle, OptionSettlementMethod, OptionType, PayReceive, PricingLocation, SwapClearingHouse, SwapSettlement, \
XRef
from gs_quant.target.risk import CountryCode
from enum import Enum

Expand Down
9 changes: 6 additions & 3 deletions gs_quant/instrument/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from gs_quant.session import GsSession

from abc import ABCMeta
import datetime as dt
import logging
from typing import Iterable, Optional, Tuple, Union
import inspect
Expand Down Expand Up @@ -106,20 +107,22 @@ def resolve(self, in_place: bool = True) -> Optional[Union[PriceableImpl, Pricin
rates is now the solved fixed rate
"""

is_historical = isinstance(PricingContext.current, HistoricalPricingContext)

def handle_result(result: Optional[Union[ErrorValue, InstrumentBase]]) -> Optional[PriceableImpl]:
ret = None if in_place else result
if isinstance(result, ErrorValue):
_logger.error('Failed to resolve instrument fields: ' + result.error)
ret = self
ret = {result.risk_key.date: self} if is_historical else self
elif result is None:
_logger.error('Unknown error resolving instrument fields')
ret = self
ret = {dt.date.today(): self} if is_historical else self
elif in_place:
self.from_instance(result)

return ret

if in_place and isinstance(PricingContext.current, HistoricalPricingContext):
if in_place and is_historical:
raise RuntimeError('Cannot resolve in place under a HistoricalPricingContext')

return self.calc(ResolvedInstrumentValues, fn=handle_result)
Expand Down
36 changes: 0 additions & 36 deletions gs_quant/instrument/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,3 @@
specific language governing permissions and limitations
under the License.
"""
from gs_quant.priceable import PriceableImpl
from gs_quant.markets import MarketDataCoordinate
from gs_quant.risk import FloatWithInfo, DataFrameWithInfo, SeriesWithInfo, RiskMeasure
from gs_quant.risk.results import ErrorValue, PricingFuture
from gs_quant.target.instrument import EqOption as __EqOption

from typing import Iterable, Tuple, Union


class EqOption(__EqOption):

def calc(self, risk_measure: Union[RiskMeasure, Iterable[RiskMeasure]], fn=None) ->\
Union[DataFrameWithInfo, ErrorValue, FloatWithInfo, PriceableImpl, PricingFuture, SeriesWithInfo,
Tuple[MarketDataCoordinate, ...]]:
# Tactical fix until the equities pricing service notion of IRDelta is changed to match FICC

from gs_quant.markets import PricingContext
from gs_quant.risk import IRDelta, IRDeltaParallel

error_result = None

if risk_measure == IRDeltaParallel:
risk_measure = IRDelta
elif risk_measure == IRDelta:
error_result = ErrorValue(None, 'IRDelta not supported for EqOption')
elif not isinstance(risk_measure, RiskMeasure):
if IRDelta in risk_measure:
risk_measure = tuple(r for r in risk_measure if r != IRDelta)
else:
risk_measure = tuple(IRDelta if r == IRDeltaParallel else r for r in risk_measure)

if error_result:
return PricingFuture(error_result) if PricingContext.current.is_entered or PricingContext.current.is_async\
else error_result
else:
return super().calc(risk_measure, fn=fn)
Loading

0 comments on commit eabf293

Please sign in to comment.