Merge pull request #4818 from freqtrade/cleanup_models

Move static Trade functions to right class
This commit is contained in:
Matthias 2021-04-28 21:18:55 +02:00 committed by GitHub
commit aab020c9a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 132 additions and 99 deletions

View File

@ -567,23 +567,6 @@ class LocalTrade():
else: else:
return None return None
@staticmethod
def get_trades(trade_filter=None) -> Query:
"""
Helper function to query Trades using filters.
:param trade_filter: Optional filter to apply to trades
Can be either a Filter object, or a List of filters
e.g. `(trade_filter=[Trade.id == trade_id, Trade.is_open.is_(True),])`
e.g. `(trade_filter=Trade.id == trade_id)`
:return: unsorted query object
"""
if trade_filter is not None:
if not isinstance(trade_filter, list):
trade_filter = [trade_filter]
return Trade.query.filter(*trade_filter)
else:
return Trade.query
@staticmethod @staticmethod
def get_trades_proxy(*, pair: str = None, is_open: bool = None, def get_trades_proxy(*, pair: str = None, is_open: bool = None,
open_date: datetime = None, close_date: datetime = None, open_date: datetime = None, close_date: datetime = None,
@ -636,83 +619,7 @@ class LocalTrade():
""" """
Query trades from persistence layer Query trades from persistence layer
""" """
return Trade.get_trades(Trade.is_open.is_(True)).all() return Trade.get_trades_proxy(is_open=True)
@staticmethod
def get_open_order_trades():
"""
Returns all open trades
"""
return Trade.get_trades(Trade.open_order_id.isnot(None)).all()
@staticmethod
def get_open_trades_without_assigned_fees():
"""
Returns all open trades which don't have open fees set correctly
"""
return Trade.get_trades([Trade.fee_open_currency.is_(None),
Trade.orders.any(),
Trade.is_open.is_(True),
]).all()
@staticmethod
def get_sold_trades_without_assigned_fees():
"""
Returns all closed trades which don't have fees set correctly
"""
return Trade.get_trades([Trade.fee_close_currency.is_(None),
Trade.orders.any(),
Trade.is_open.is_(False),
]).all()
@staticmethod
def total_open_trades_stakes() -> float:
"""
Calculates total invested amount in open trades
in stake currency
"""
if Trade.use_db:
total_open_stake_amount = Trade.query.with_entities(
func.sum(Trade.stake_amount)).filter(Trade.is_open.is_(True)).scalar()
else:
total_open_stake_amount = sum(
t.stake_amount for t in Trade.get_trades_proxy(is_open=True))
return total_open_stake_amount or 0
@staticmethod
def get_overall_performance() -> List[Dict[str, Any]]:
"""
Returns List of dicts containing all Trades, including profit and trade count
"""
pair_rates = Trade.query.with_entities(
Trade.pair,
func.sum(Trade.close_profit).label('profit_sum'),
func.count(Trade.pair).label('count')
).filter(Trade.is_open.is_(False))\
.group_by(Trade.pair) \
.order_by(desc('profit_sum')) \
.all()
return [
{
'pair': pair,
'profit': rate,
'count': count
}
for pair, rate, count in pair_rates
]
@staticmethod
def get_best_pair():
"""
Get best pair with closed trade.
:returns: Tuple containing (pair, profit_sum)
"""
best_pair = Trade.query.with_entities(
Trade.pair, func.sum(Trade.close_profit).label('profit_sum')
).filter(Trade.is_open.is_(False)) \
.group_by(Trade.pair) \
.order_by(desc('profit_sum')).first()
return best_pair
@staticmethod @staticmethod
def stoploss_reinitialization(desired_stoploss): def stoploss_reinitialization(desired_stoploss):
@ -810,7 +717,7 @@ class Trade(_DECL_BASE, LocalTrade):
open_date: datetime = None, close_date: datetime = None, open_date: datetime = None, close_date: datetime = None,
) -> List['LocalTrade']: ) -> List['LocalTrade']:
""" """
Helper function to query Trades. Helper function to query Trades.j
Returns a List of trades, filtered on the parameters given. Returns a List of trades, filtered on the parameters given.
In live mode, converts the filter to a database query and returns all rows In live mode, converts the filter to a database query and returns all rows
In Backtest mode, uses filters on Trade.trades to get the result. In Backtest mode, uses filters on Trade.trades to get the result.
@ -835,6 +742,107 @@ class Trade(_DECL_BASE, LocalTrade):
close_date=close_date close_date=close_date
) )
@staticmethod
def get_trades(trade_filter=None) -> Query:
"""
Helper function to query Trades using filters.
NOTE: Not supported in Backtesting.
:param trade_filter: Optional filter to apply to trades
Can be either a Filter object, or a List of filters
e.g. `(trade_filter=[Trade.id == trade_id, Trade.is_open.is_(True),])`
e.g. `(trade_filter=Trade.id == trade_id)`
:return: unsorted query object
"""
if not Trade.use_db:
raise NotImplementedError('`Trade.get_trades()` not supported in backtesting mode.')
if trade_filter is not None:
if not isinstance(trade_filter, list):
trade_filter = [trade_filter]
return Trade.query.filter(*trade_filter)
else:
return Trade.query
@staticmethod
def get_open_order_trades():
"""
Returns all open trades
NOTE: Not supported in Backtesting.
"""
return Trade.get_trades(Trade.open_order_id.isnot(None)).all()
@staticmethod
def get_open_trades_without_assigned_fees():
"""
Returns all open trades which don't have open fees set correctly
NOTE: Not supported in Backtesting.
"""
return Trade.get_trades([Trade.fee_open_currency.is_(None),
Trade.orders.any(),
Trade.is_open.is_(True),
]).all()
@staticmethod
def get_sold_trades_without_assigned_fees():
"""
Returns all closed trades which don't have fees set correctly
NOTE: Not supported in Backtesting.
"""
return Trade.get_trades([Trade.fee_close_currency.is_(None),
Trade.orders.any(),
Trade.is_open.is_(False),
]).all()
@staticmethod
def total_open_trades_stakes() -> float:
"""
Calculates total invested amount in open trades
in stake currency
"""
if Trade.use_db:
total_open_stake_amount = Trade.query.with_entities(
func.sum(Trade.stake_amount)).filter(Trade.is_open.is_(True)).scalar()
else:
total_open_stake_amount = sum(
t.stake_amount for t in LocalTrade.get_trades_proxy(is_open=True))
return total_open_stake_amount or 0
@staticmethod
def get_overall_performance() -> List[Dict[str, Any]]:
"""
Returns List of dicts containing all Trades, including profit and trade count
NOTE: Not supported in Backtesting.
"""
pair_rates = Trade.query.with_entities(
Trade.pair,
func.sum(Trade.close_profit).label('profit_sum'),
func.count(Trade.pair).label('count')
).filter(Trade.is_open.is_(False))\
.group_by(Trade.pair) \
.order_by(desc('profit_sum')) \
.all()
return [
{
'pair': pair,
'profit': rate,
'count': count
}
for pair, rate, count in pair_rates
]
@staticmethod
def get_best_pair():
"""
Get best pair with closed trade.
NOTE: Not supported in Backtesting.
:returns: Tuple containing (pair, profit_sum)
"""
best_pair = Trade.query.with_entities(
Trade.pair, func.sum(Trade.close_profit).label('profit_sum')
).filter(Trade.is_open.is_(False)) \
.group_by(Trade.pair) \
.order_by(desc('profit_sum')).first()
return best_pair
class PairLock(_DECL_BASE): class PairLock(_DECL_BASE):
""" """

View File

@ -774,11 +774,16 @@ def test_adjust_min_max_rates(fee):
@pytest.mark.usefixtures("init_persistence") @pytest.mark.usefixtures("init_persistence")
def test_get_open(fee): @pytest.mark.parametrize('use_db', [True, False])
def test_get_open(fee, use_db):
Trade.use_db = use_db
Trade.reset_trades()
create_mock_trades(fee) create_mock_trades(fee, use_db)
assert len(Trade.get_open_trades()) == 4 assert len(Trade.get_open_trades()) == 4
Trade.use_db = True
@pytest.mark.usefixtures("init_persistence") @pytest.mark.usefixtures("init_persistence")
def test_to_json(default_conf, fee): def test_to_json(default_conf, fee):
@ -1083,6 +1088,13 @@ def test_get_trades_proxy(fee, use_db):
Trade.use_db = True Trade.use_db = True
def test_get_trades_backtest():
Trade.use_db = False
with pytest.raises(NotImplementedError, match=r"`Trade.get_trades\(\)` not .*"):
Trade.get_trades([])
Trade.use_db = True
@pytest.mark.usefixtures("init_persistence") @pytest.mark.usefixtures("init_persistence")
def test_get_overall_performance(fee): def test_get_overall_performance(fee):
@ -1216,11 +1228,24 @@ def test_Trade_object_idem():
trade = vars(Trade) trade = vars(Trade)
localtrade = vars(LocalTrade) localtrade = vars(LocalTrade)
excludes = (
'delete',
'session',
'query',
'open_date',
'get_best_pair',
'get_overall_performance',
'total_open_trades_stakes',
'get_sold_trades_without_assigned_fees',
'get_open_trades_without_assigned_fees',
'get_open_order_trades',
'get_trades',
)
# Parent (LocalTrade) should have the same attributes # Parent (LocalTrade) should have the same attributes
for item in trade: for item in trade:
# Exclude private attributes and open_date (as it's not assigned a default) # Exclude private attributes and open_date (as it's not assigned a default)
if (not item.startswith('_') if (not item.startswith('_') and item not in excludes):
and item not in ('delete', 'session', 'query', 'open_date')):
assert item in localtrade assert item in localtrade
# Fails if only a column is added without corresponding parent field # Fails if only a column is added without corresponding parent field