Add test to demonstrate that the dataframe is not changed
This commit is contained in:
parent
4ee0cbb575
commit
ddf37ef059
@ -213,7 +213,7 @@ def calculate_max_drawdown(trades: pd.DataFrame, *, date_col: str = 'close_time'
|
|||||||
"""
|
"""
|
||||||
if len(trades) == 0:
|
if len(trades) == 0:
|
||||||
raise ValueError("Trade dataframe empty.")
|
raise ValueError("Trade dataframe empty.")
|
||||||
profit_results = trades.sort_values(date_col).reset_index()
|
profit_results = trades.sort_values(date_col).reset_index(drop=True)
|
||||||
max_drawdown_df = pd.DataFrame()
|
max_drawdown_df = pd.DataFrame()
|
||||||
max_drawdown_df['cumulative'] = profit_results[value_col].cumsum()
|
max_drawdown_df['cumulative'] = profit_results[value_col].cumsum()
|
||||||
max_drawdown_df['high_value'] = max_drawdown_df['cumulative'].cummax()
|
max_drawdown_df['high_value'] = max_drawdown_df['cumulative'].cummax()
|
||||||
|
@ -201,7 +201,13 @@ def test_calculate_max_drawdown2():
|
|||||||
|
|
||||||
dates = [Arrow(2020, 1, 1).shift(days=i) for i in range(len(values))]
|
dates = [Arrow(2020, 1, 1).shift(days=i) for i in range(len(values))]
|
||||||
df = DataFrame(zip(values, dates), columns=['profit', 'open_time'])
|
df = DataFrame(zip(values, dates), columns=['profit', 'open_time'])
|
||||||
|
# sort by profit and reset index
|
||||||
|
df = df.sort_values('profit').reset_index(drop=True)
|
||||||
|
df1 = df.copy()
|
||||||
drawdown, h, low = calculate_max_drawdown(df, date_col='open_time', value_col='profit')
|
drawdown, h, low = calculate_max_drawdown(df, date_col='open_time', value_col='profit')
|
||||||
|
# Ensure df has not been altered.
|
||||||
|
assert df.equals(df1)
|
||||||
|
|
||||||
assert isinstance(drawdown, float)
|
assert isinstance(drawdown, float)
|
||||||
# High must be before low
|
# High must be before low
|
||||||
assert h < low
|
assert h < low
|
||||||
|
Loading…
Reference in New Issue
Block a user