remove usage of .query from regular models

This commit is contained in:
Matthias 2023-03-15 21:00:30 +01:00
parent 47ab285252
commit b469addffb
3 changed files with 114 additions and 75 deletions

View File

@ -7,7 +7,8 @@ from datetime import datetime, timedelta, timezone
from math import isclose from math import isclose
from typing import Any, ClassVar, Dict, List, Optional, cast from typing import Any, ClassVar, Dict, List, Optional, cast
from sqlalchemy import Enum, Float, ForeignKey, Integer, String, UniqueConstraint, desc, func from sqlalchemy import (Enum, Float, ForeignKey, Integer, Result, Select, String, UniqueConstraint,
desc, func, select)
from sqlalchemy.orm import (Mapped, Query, QueryPropertyDescriptor, lazyload, mapped_column, from sqlalchemy.orm import (Mapped, Query, QueryPropertyDescriptor, lazyload, mapped_column,
relationship) relationship)
@ -1153,7 +1154,9 @@ class LocalTrade():
get open trade count get open trade count
""" """
if Trade.use_db: if Trade.use_db:
return Trade.query.filter(Trade.is_open.is_(True)).count() return Trade._session.scalar(
select(func.count(Trade.id)).filter(Trade.is_open.is_(True))
)
else: else:
return LocalTrade.bt_open_open_trade_count return LocalTrade.bt_open_open_trade_count
@ -1287,18 +1290,18 @@ class Trade(ModelBase, LocalTrade):
def delete(self) -> None: def delete(self) -> None:
for order in self.orders: for order in self.orders:
Order.query.session.delete(order) Order._session.delete(order)
Trade.query.session.delete(self) Trade._session.delete(self)
Trade.commit() Trade.commit()
@staticmethod @staticmethod
def commit(): def commit():
Trade.query.session.commit() Trade._session.commit()
@staticmethod @staticmethod
def rollback(): def rollback():
Trade.query.session.rollback() Trade._session.rollback()
@staticmethod @staticmethod
def get_trades_proxy(*, pair: Optional[str] = None, is_open: Optional[bool] = None, def get_trades_proxy(*, pair: Optional[str] = None, is_open: Optional[bool] = None,
@ -1332,7 +1335,7 @@ class Trade(ModelBase, LocalTrade):
) )
@staticmethod @staticmethod
def get_trades(trade_filter=None, include_orders: bool = True) -> Query['Trade']: def get_trades_query(trade_filter=None, include_orders: bool = True) -> Select['Trade']:
""" """
Helper function to query Trades using filters. Helper function to query Trades using filters.
NOTE: Not supported in Backtesting. NOTE: Not supported in Backtesting.
@ -1347,15 +1350,28 @@ class Trade(ModelBase, LocalTrade):
if trade_filter is not None: if trade_filter is not None:
if not isinstance(trade_filter, list): if not isinstance(trade_filter, list):
trade_filter = [trade_filter] trade_filter = [trade_filter]
this_query = Trade.query.filter(*trade_filter) this_query = select(Trade).filter(*trade_filter)
else: else:
this_query = Trade.query this_query = select(Trade)
if not include_orders: if not include_orders:
# Don't load order relations # Don't load order relations
# Consider using noload or raiseload instead of lazyload # Consider using noload or raiseload instead of lazyload
this_query = this_query.options(lazyload(Trade.orders)) this_query = this_query.options(lazyload(Trade.orders))
return this_query return this_query
@staticmethod
def get_trades(trade_filter=None, include_orders: bool = True) -> Query['Trade']:
"""
Helper function to query Trades using filters.
NOTE: Not supported in Backtesting.
:param trade_filter: Optional filter to apply to trades
Can be either a Filter object, or a List of filters
e.g. `(trade_filter=[Trade.id == trade_id, Trade.is_open.is_(True),])`
e.g. `(trade_filter=Trade.id == trade_id)`
:return: unsorted query object
"""
return Trade._session.execute(Trade.get_trades_query(trade_filter, include_orders))
@staticmethod @staticmethod
def get_open_order_trades() -> List['Trade']: def get_open_order_trades() -> List['Trade']:
""" """
@ -1392,8 +1408,9 @@ class Trade(ModelBase, LocalTrade):
Retrieves total realized profit Retrieves total realized profit
""" """
if Trade.use_db: if Trade.use_db:
total_profit = Trade.query.with_entities( total_profit = Trade._session.scalar(
func.sum(Trade.close_profit_abs)).filter(Trade.is_open.is_(False)).scalar() select(func.sum(Trade.close_profit_abs)).filter(Trade.is_open.is_(False))
)
else: else:
total_profit = sum( total_profit = sum(
t.close_profit_abs for t in LocalTrade.get_trades_proxy(is_open=False)) t.close_profit_abs for t in LocalTrade.get_trades_proxy(is_open=False))
@ -1406,8 +1423,9 @@ class Trade(ModelBase, LocalTrade):
in stake currency in stake currency
""" """
if Trade.use_db: if Trade.use_db:
total_open_stake_amount = Trade.query.with_entities( total_open_stake_amount = Trade._session.scalar(
func.sum(Trade.stake_amount)).filter(Trade.is_open.is_(True)).scalar() select(func.sum(Trade.stake_amount)).filter(Trade.is_open.is_(True))
)
else: else:
total_open_stake_amount = sum( total_open_stake_amount = sum(
t.stake_amount for t in LocalTrade.get_trades_proxy(is_open=True)) t.stake_amount for t in LocalTrade.get_trades_proxy(is_open=True))
@ -1423,15 +1441,18 @@ class Trade(ModelBase, LocalTrade):
if minutes: if minutes:
start_date = datetime.now(timezone.utc) - timedelta(minutes=minutes) start_date = datetime.now(timezone.utc) - timedelta(minutes=minutes)
filters.append(Trade.close_date >= start_date) filters.append(Trade.close_date >= start_date)
pair_rates = Trade.query.with_entities(
Trade.pair, pair_rates = Trade._session.execute(
func.sum(Trade.close_profit).label('profit_sum'), select(
func.sum(Trade.close_profit_abs).label('profit_sum_abs'), Trade.pair,
func.count(Trade.pair).label('count') func.sum(Trade.close_profit).label('profit_sum'),
).filter(*filters)\ func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
.group_by(Trade.pair) \ func.count(Trade.pair).label('count')
.order_by(desc('profit_sum_abs')) \ ).filter(*filters)
.all() .group_by(Trade.pair)
.order_by(desc('profit_sum_abs'))
).all()
return [ return [
{ {
'pair': pair, 'pair': pair,
@ -1456,15 +1477,16 @@ class Trade(ModelBase, LocalTrade):
if (pair is not None): if (pair is not None):
filters.append(Trade.pair == pair) filters.append(Trade.pair == pair)
enter_tag_perf = Trade.query.with_entities( enter_tag_perf = Trade._session.execute(
Trade.enter_tag, select(
func.sum(Trade.close_profit).label('profit_sum'), Trade.enter_tag,
func.sum(Trade.close_profit_abs).label('profit_sum_abs'), func.sum(Trade.close_profit).label('profit_sum'),
func.count(Trade.pair).label('count') func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
).filter(*filters)\ func.count(Trade.pair).label('count')
.group_by(Trade.enter_tag) \ ).filter(*filters)
.order_by(desc('profit_sum_abs')) \ .group_by(Trade.enter_tag)
.all() .order_by(desc('profit_sum_abs'))
).all()
return [ return [
{ {
@ -1488,16 +1510,16 @@ class Trade(ModelBase, LocalTrade):
filters: List = [Trade.is_open.is_(False)] filters: List = [Trade.is_open.is_(False)]
if (pair is not None): if (pair is not None):
filters.append(Trade.pair == pair) filters.append(Trade.pair == pair)
sell_tag_perf = Trade._session.execute(
sell_tag_perf = Trade.query.with_entities( select(
Trade.exit_reason, Trade.exit_reason,
func.sum(Trade.close_profit).label('profit_sum'), func.sum(Trade.close_profit).label('profit_sum'),
func.sum(Trade.close_profit_abs).label('profit_sum_abs'), func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
func.count(Trade.pair).label('count') func.count(Trade.pair).label('count')
).filter(*filters)\ ).filter(*filters)
.group_by(Trade.exit_reason) \ .group_by(Trade.exit_reason)
.order_by(desc('profit_sum_abs')) \ .order_by(desc('profit_sum_abs'))
.all() ).all()
return [ return [
{ {
@ -1521,18 +1543,18 @@ class Trade(ModelBase, LocalTrade):
filters: List = [Trade.is_open.is_(False)] filters: List = [Trade.is_open.is_(False)]
if (pair is not None): if (pair is not None):
filters.append(Trade.pair == pair) filters.append(Trade.pair == pair)
mix_tag_perf = Trade._session.execute(
mix_tag_perf = Trade.query.with_entities( select(
Trade.id, Trade.id,
Trade.enter_tag, Trade.enter_tag,
Trade.exit_reason, Trade.exit_reason,
func.sum(Trade.close_profit).label('profit_sum'), func.sum(Trade.close_profit).label('profit_sum'),
func.sum(Trade.close_profit_abs).label('profit_sum_abs'), func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
func.count(Trade.pair).label('count') func.count(Trade.pair).label('count')
).filter(*filters)\ ).filter(*filters)
.group_by(Trade.id) \ .group_by(Trade.id)
.order_by(desc('profit_sum_abs')) \ .order_by(desc('profit_sum_abs'))
.all() ).all()
return_list: List[Dict] = [] return_list: List[Dict] = []
for id, enter_tag, exit_reason, profit, profit_abs, count in mix_tag_perf: for id, enter_tag, exit_reason, profit, profit_abs, count in mix_tag_perf:
@ -1568,11 +1590,15 @@ class Trade(ModelBase, LocalTrade):
NOTE: Not supported in Backtesting. NOTE: Not supported in Backtesting.
:returns: Tuple containing (pair, profit_sum) :returns: Tuple containing (pair, profit_sum)
""" """
best_pair = Trade.query.with_entities( best_pair = Trade._session.execute(
Trade.pair, func.sum(Trade.close_profit).label('profit_sum') select(
).filter(Trade.is_open.is_(False) & (Trade.close_date >= start_date)) \ Trade.pair,
.group_by(Trade.pair) \ func.sum(Trade.close_profit).label('profit_sum')
.order_by(desc('profit_sum')).first() ).filter(Trade.is_open.is_(False) & (Trade.close_date >= start_date))
.group_by(Trade.pair)
.order_by(desc('profit_sum'))
).first()
return best_pair return best_pair
@staticmethod @staticmethod

View File

@ -5,7 +5,7 @@ import logging
from abc import abstractmethod from abc import abstractmethod
from datetime import date, datetime, timedelta, timezone from datetime import date, datetime, timedelta, timezone
from math import isnan from math import isnan
from typing import Any, Dict, Generator, List, Optional, Tuple, Union from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union
import arrow import arrow
import psutil import psutil
@ -13,6 +13,7 @@ from dateutil.relativedelta import relativedelta
from dateutil.tz import tzlocal from dateutil.tz import tzlocal
from numpy import NAN, inf, int64, mean from numpy import NAN, inf, int64, mean
from pandas import DataFrame, NaT from pandas import DataFrame, NaT
from sqlalchemy import func, select
from freqtrade import __version__ from freqtrade import __version__
from freqtrade.configuration.timerange import TimeRange from freqtrade.configuration.timerange import TimeRange
@ -339,11 +340,18 @@ class RPC:
for day in range(0, timescale): for day in range(0, timescale):
profitday = start_date - time_offset(day) profitday = start_date - time_offset(day)
# Only query for necessary columns for performance reasons. # Only query for necessary columns for performance reasons.
trades = Trade.query.session.query(Trade.close_profit_abs).filter( trades = Trade._session.execute(
Trade.is_open.is_(False), select(Trade.close_profit_abs)
Trade.close_date >= profitday, .filter(Trade.is_open.is_(False),
Trade.close_date < (profitday + time_offset(1)) Trade.close_date >= profitday,
).order_by(Trade.close_date).all() Trade.close_date < (profitday + time_offset(1)))
.order_by(Trade.close_date)
).all()
# trades = Trade.query.session.query(Trade.close_profit_abs).filter(
# Trade.is_open.is_(False),
# Trade.close_date >= profitday,
# Trade.close_date < (profitday + time_offset(1))
# ).order_by(Trade.close_date).all()
curdayprofit = sum( curdayprofit = sum(
trade.close_profit_abs for trade in trades if trade.close_profit_abs is not None) trade.close_profit_abs for trade in trades if trade.close_profit_abs is not None)
@ -381,14 +389,19 @@ class RPC:
""" Returns the X last trades """ """ Returns the X last trades """
order_by: Any = Trade.id if order_by_id else Trade.close_date.desc() order_by: Any = Trade.id if order_by_id else Trade.close_date.desc()
if limit: if limit:
trades = Trade.get_trades([Trade.is_open.is_(False)]).order_by( trades = Trade._session.execute(
order_by).limit(limit).offset(offset) Trade.get_trades_query([Trade.is_open.is_(False)])
.order_by(order_by)
.limit(limit)
.offset(offset))
else: else:
trades = Trade.get_trades([Trade.is_open.is_(False)]).order_by( trades = Trade._session.execute(
Trade.close_date.desc()) Trade.get_trades_query([Trade.is_open.is_(False)])
.order_by(Trade.close_date.desc()))
output = [trade.to_json() for trade in trades] output = [trade.to_json() for trade in trades]
total_trades = Trade.get_trades([Trade.is_open.is_(False)]).count() total_trades = Trade._session.scalar(
select(func.count(Trade.id)).filter(Trade.is_open.is_(False)))
return { return {
"trades": output, "trades": output,
@ -436,8 +449,8 @@ class RPC:
""" Returns cumulative profit statistics """ """ Returns cumulative profit statistics """
trade_filter = ((Trade.is_open.is_(False) & (Trade.close_date >= start_date)) | trade_filter = ((Trade.is_open.is_(False) & (Trade.close_date >= start_date)) |
Trade.is_open.is_(True)) Trade.is_open.is_(True))
trades: List[Trade] = Trade.get_trades( trades: Sequence[Trade] = Trade._session.execute(Trade.get_trades_query(
trade_filter, include_orders=False).order_by(Trade.id).all() trade_filter, include_orders=False).order_by(Trade.id)).all()
profit_all_coin = [] profit_all_coin = []
profit_all_ratio = [] profit_all_ratio = []

View File

@ -1793,17 +1793,17 @@ def test_get_trades_proxy(fee, use_db, is_short):
@pytest.mark.usefixtures("init_persistence") @pytest.mark.usefixtures("init_persistence")
@pytest.mark.parametrize('is_short', [True, False]) @pytest.mark.parametrize('is_short', [True, False])
def test_get_trades__query(fee, is_short): def test_get_trades__query(fee, is_short):
query = Trade.get_trades([]) query = Trade.get_trades_query([])
# without orders there should be no join issued. # without orders there should be no join issued.
query1 = Trade.get_trades([], include_orders=False) query1 = Trade.get_trades_query([], include_orders=False)
# Empty "with-options -> default - selectin" # Empty "with-options -> default - selectin"
assert query._with_options == () assert query._with_options == ()
assert query1._with_options != () assert query1._with_options != ()
create_mock_trades(fee, is_short) create_mock_trades(fee, is_short)
query = Trade.get_trades([]) query = Trade.get_trades_query([])
query1 = Trade.get_trades([], include_orders=False) query1 = Trade.get_trades_query([], include_orders=False)
assert query._with_options == () assert query._with_options == ()
assert query1._with_options != () assert query1._with_options != ()