Skip to content

Commit

Permalink
Signals weighting bug - using wrong key for sorting (#1143)
Browse files Browse the repository at this point in the history
- Use `signal.signal` value instead of `signal.raw_weight` when ranking alpha model results
- Work to support multiple indicator instances in the same backtest with different parameters
  • Loading branch information
miohtama authored Jan 23, 2025
1 parent 8f21fe9 commit 125a777
Show file tree
Hide file tree
Showing 9 changed files with 570 additions and 47 deletions.
194 changes: 190 additions & 4 deletions tests/backtest/test_indicator_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@
import pandas as pd
import pandas_ta
import pytest
from pandas._libs.tslibs.offsets import MonthBegin

from tradeexecutor.backtest.backtest_runner import run_backtest_inline, BacktestResult
from tradeexecutor.state.identifier import AssetIdentifier, TradingPairIdentifier
from tradeexecutor.strategy.cycle import CycleDuration
from tradeexecutor.strategy.execution_context import unit_test_execution_context
from tradeexecutor.strategy.pandas_trader.indicator import (
DiskIndicatorStorage,
IndicatorSource, IndicatorDependencyResolver,
calculate_and_load_indicators_inline,
calculate_and_load_indicators_inline, prepare_indicators, IndicatorDefinition, IndicatorNotFound,
)
from tradeexecutor.strategy.pandas_trader.indicator_decorator import IndicatorRegistry
from tradeexecutor.strategy.pandas_trader.strategy_input import StrategyInputIndicators
from tradeexecutor.strategy.parameters import StrategyParameters
from tradeexecutor.strategy.pandas_trader.indicator_decorator import IndicatorRegistry, flatten_dict_permutations
from tradeexecutor.strategy.pandas_trader.strategy_input import StrategyInputIndicators, StrategyInput, IndicatorWithVariations
from tradeexecutor.strategy.parameters import StrategyParameters, RollingParameter
from tradeexecutor.strategy.trading_strategy_universe import TradingStrategyUniverse, create_pair_universe_from_code
from tradeexecutor.testing.synthetic_ethereum_data import generate_random_ethereum_address
from tradeexecutor.testing.synthetic_exchange_data import generate_exchange
Expand Down Expand Up @@ -244,3 +247,186 @@ def rsi_derivate(rsi_length, pair, dependency_resolver):
)

assert isinstance(indicators, StrategyInputIndicators)


def test_get_indicator_rolling_parameters(strategy_universe):
"""We create multiple indicator parameter variations for rolling indicators."""
indicators = IndicatorRegistry()

rolling_data = pd.Series(
data=[21, 22, 23, 24, 25, 26],
index=pd.Index([
pd.Timestamp("2021-06-01"),
pd.Timestamp("2021-07-01"),
pd.Timestamp("2021-08-01"),
pd.Timestamp("2021-09-01"),
pd.Timestamp("2021-10-01"),
pd.Timestamp("2021-12-01"),
]),
)

other_param_data = pd.Series(
data=[1, 2, 3, 4, 5, 6],
index=pd.Index([
pd.Timestamp("2021-06-01"),
pd.Timestamp("2021-07-01"),
pd.Timestamp("2021-08-01"),
pd.Timestamp("2021-09-01"),
pd.Timestamp("2021-10-01"),
pd.Timestamp("2021-12-01"),
]),
)

class Parameters:

fixed_parameter = 10

rsi_length = RollingParameter(
name="rsi_length",
freq=MonthBegin(1),
values=rolling_data,
)

other_param = RollingParameter(
name="other_param",
freq=MonthBegin(1),
values=other_param_data,
)

backtest_start = datetime.datetime(2021, 6, 1)
backtest_end = datetime.datetime(2022, 1, 1)
initial_cash = 10_000
cycle_duration = CycleDuration.cycle_1d

@indicators.define()
def fixed_rsi(close, fixed_parameter, pair, dependency_resolver):
assert isinstance(close, pd.Series)
assert type(fixed_parameter) == int
assert isinstance(pair, TradingPairIdentifier)
assert isinstance(dependency_resolver, IndicatorDependencyResolver)
return close

@indicators.define()
def rsi(close, rsi_length, pair, dependency_resolver):
assert isinstance(close, pd.Series)
assert type(rsi_length) == int
assert isinstance(pair, TradingPairIdentifier)
assert isinstance(dependency_resolver, IndicatorDependencyResolver)
return close * rsi_length

@indicators.define(source=IndicatorSource.dependencies_only_per_pair, dependencies=[rsi])
def rsi_derivative(rsi_length, other_param, pair, dependency_resolver):
assert type(rsi_length) == int
assert type(other_param) == int
assert isinstance(pair, TradingPairIdentifier)
assert isinstance(dependency_resolver, IndicatorDependencyResolver)
rsi = dependency_resolver.get_indicator_data(
"rsi",
pair=pair,
parameters={
"rsi_length": rsi_length,
}
)
return rsi * other_param

parameters = StrategyParameters.from_class(Parameters)

indicator_set = prepare_indicators(
indicators.create_indicators,
parameters,
strategy_universe,
unit_test_execution_context,
)

for ind in indicator_set.indicators.values():
assert isinstance(ind, IndicatorDefinition)
if not ind.name.startswith("fixed"):
assert ind.variations is True

assert len(indicator_set.indicators.values()) == 43

strategy_input_indicators = calculate_and_load_indicators_inline(
strategy_universe=strategy_universe,
parameters=parameters,
indicator_set=indicator_set,
verbose=False,
)

assert isinstance(strategy_input_indicators, StrategyInputIndicators)

# Make sure we have access to every variation of the indicator
def decide_trades(input: StrategyInput):
timestamp = input.timestamp
indicators = input.indicators

pair = input.strategy_universe.get_pair_by_id(1)

_ = indicators.get_indicator_value("fixed_rsi", pair=pair)

with pytest.raises(IndicatorWithVariations):
indicators.get_indicator_value("rsi", pair=pair)

with pytest.raises(IndicatorWithVariations):
indicators.get_indicator_value("rsi_derivative", pair=pair)

rsi_1 = indicators.get_indicator_value(
"rsi",
pair=pair,
parameters={"rsi_length": 21},
)
assert rsi_1 > 0

rsi_2 = indicators.get_indicator_value(
"rsi",
pair=pair,
parameters={"rsi_length": 22},
)
assert rsi_2 > 0

rsi_3 = indicators.get_indicator_value(
"rsi_derivative",
pair=pair,
parameters={
"rsi_length": 22,
"other_param": 2,
},
)
assert rsi_3 > 3

with pytest.raises(IndicatorNotFound):
_ = indicators.get_indicator_value(
"rsi_derivative",
pair=pair,
parameters={
"rsi_length": 0,
"other_param": 0,
},
)

return []

backtest_result = run_backtest_inline(
start_at=datetime.datetime(2021, 6, 1),
end_at=datetime.datetime(2022, 1, 1),
client=None,
cycle_duration=CycleDuration.cycle_1d,
decide_trades=decide_trades,
universe=strategy_universe,
engine_version="0.5",
create_indicators=indicators.create_indicators,
parameters=parameters,
)

assert isinstance(backtest_result, BacktestResult)


def test_param_permutations():

input = {
"rsi": [21, 22, 23],
"other_param": [1, 2],
"fixed_val": 1,
}

permutations = flatten_dict_permutations(input)
assert len(permutations) == 6
57 changes: 53 additions & 4 deletions tradeexecutor/analysis/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def visualise_weights(
template="plotly_dark",
include_reserves=True,
legend_mode: LegendMode=LegendMode.side,
columns=20,
aave_colour='#9896FF',
reserve_asset_colour='#aaa',
clean=False,
) -> Figure:
"""Draw a chart of weights.
Expand All @@ -97,6 +99,11 @@ def visualise_weights(
:param include_reserves:
Include reserve positions like USDC in the output.
:param clean:
Remove title texts.
Good for screenshots.
:return:
Plotly chart
"""
Expand Down Expand Up @@ -149,14 +156,56 @@ def sort_key_reserve_first(col_name):
for symbol in non_volatile_symbols:
# Aave colour
# https://aave.com/brand
fig.update_traces(fillcolor='#9896FF', selector=dict(name=symbol))
fig.update_traces(fillcolor='#aaa', selector=dict(name=reserve_asset_symbol))
fig.update_traces(fillcolor=aave_colour, selector=dict(name=symbol))
fig.update_traces(fillcolor=reserve_asset_colour, selector=dict(name=reserve_asset_symbol))
fig.update_traces(line_width=0)

match legend_mode:
case LegendMode.bottom:
# Adjust legend properties
pass
fig.update_layout(
# Move legend to bottom
legend=dict(
yanchor="top",
y=-0.1, # Adjust this value to move legend up/down
xanchor="center",
x=0.5,
# Arrange items in 4 rows
orientation="h",
traceorder="normal",
# nrows=4
itemwidth=40, # Adjust the multiplier as needed
title_text="",
font=dict(
size=20 # Adjust this value to make legend text bigger/smaller
),
)
)

if clean:
fig.update_layout(
title=None,
xaxis=dict(
title=None,
# other x-axis properties...
nticks=4,
# Increase font size (default is usually 12)
tickfont=dict(
size=22 # Adjust this value to make font bigger/smaller
)
),
yaxis=dict(
title=None,
# other y-axis properties...
nticks=5,
# Optionally specify tick labels
# ticktext=['0%', '50%', '100%'],
tickfont=dict(
size=22,
),

)
)

return fig

Expand Down
32 changes: 25 additions & 7 deletions tradeexecutor/backtest/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,12 @@ class GridSearchResult:
#:
delivered_to_main_thread_at: datetime.datetime | None = None

#: Include first trade timestamp
#:
#: Useful for quick debugging.
#:
first_trade_at: datetime.datetime | None = None

def __hash__(self):
return self.combination.__hash__()

Expand Down Expand Up @@ -1320,14 +1326,19 @@ def run_grid_search_backtest(
if cycle_debug_data is None:
cycle_debug_data = {}

backtest_start = datetime.datetime.utcnow()
duration_start = datetime.datetime.utcnow()

universe_range = universe.data_universe.candles.get_timestamp_range()
if not start_at:
start_at = universe_range[0]
if parameters and parameters.get("backtest_start"):
start_at = parameters["backtest_start"]
end_at = parameters["backtest_end"]
else:

if not end_at:
end_at = universe_range[1]
universe_range = universe.data_universe.candles.get_timestamp_range()
if not start_at:
start_at = universe_range[0]

if not end_at:
end_at = universe_range[1]

if isinstance(start_at, datetime.datetime):
start_at = pd.Timestamp(start_at)
Expand Down Expand Up @@ -1426,6 +1437,12 @@ def run_grid_search_backtest(

period = state.get_trading_time_range()

try:
first_trade = next(iter(state.portfolio.get_all_trades()))
first_trade_at = first_trade.executed_at
except StopIteration:
first_trade_at = None

res = GridSearchResult(
combination=combination,
state=state,
Expand All @@ -1435,11 +1452,12 @@ def run_grid_search_backtest(
equity_curve=equity,
returns=returns,
initial_cash=state.portfolio.get_initial_cash(),
run_start_at=backtest_start,
run_start_at=duration_start,
run_end_at=backtest_end,
analysis_end_at=analysis_end,
backtest_start=period[0],
backtest_end=period[1],
first_trade_at=first_trade_at,
)

# Double check we have not broken QuantStats again
Expand Down
2 changes: 1 addition & 1 deletion tradeexecutor/strategy/alpha_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def select_top_signals(
`0.01 = 1%` signal strenght.
"""
filtered_signals = [s for s in self.raw_signals.values() if abs(s.signal) >= threshold]
top_signals = heapq.nlargest(count, filtered_signals, key=lambda s: s.raw_weight)
top_signals = heapq.nlargest(count, filtered_signals, key=lambda s: s.signal)
self.signals = {s.pair.internal_id: s for s in top_signals}

def _normalise_weights_simple(
Expand Down
Loading

0 comments on commit 125a777

Please sign in to comment.