remove usage of .query from regular models

This commit is contained in:
Matthias
2023-03-15 21:00:30 +01:00
parent 47ab285252
commit b469addffb
3 changed files with 114 additions and 75 deletions

View File

@@ -7,7 +7,8 @@ from datetime import datetime, timedelta, timezone
from math import isclose
from typing import Any, ClassVar, Dict, List, Optional, cast
from sqlalchemy import Enum, Float, ForeignKey, Integer, String, UniqueConstraint, desc, func
from sqlalchemy import (Enum, Float, ForeignKey, Integer, Result, Select, String, UniqueConstraint,
desc, func, select)
from sqlalchemy.orm import (Mapped, Query, QueryPropertyDescriptor, lazyload, mapped_column,
relationship)
@@ -1153,7 +1154,9 @@ class LocalTrade():
get open trade count
"""
if Trade.use_db:
return Trade.query.filter(Trade.is_open.is_(True)).count()
return Trade._session.scalar(
select(func.count(Trade.id)).filter(Trade.is_open.is_(True))
)
else:
return LocalTrade.bt_open_open_trade_count
@@ -1287,18 +1290,18 @@ class Trade(ModelBase, LocalTrade):
def delete(self) -> None:
for order in self.orders:
Order.query.session.delete(order)
Order._session.delete(order)
Trade.query.session.delete(self)
Trade._session.delete(self)
Trade.commit()
@staticmethod
def commit():
Trade.query.session.commit()
Trade._session.commit()
@staticmethod
def rollback():
Trade.query.session.rollback()
Trade._session.rollback()
@staticmethod
def get_trades_proxy(*, pair: Optional[str] = None, is_open: Optional[bool] = None,
@@ -1332,7 +1335,7 @@ class Trade(ModelBase, LocalTrade):
)
@staticmethod
def get_trades(trade_filter=None, include_orders: bool = True) -> Query['Trade']:
def get_trades_query(trade_filter=None, include_orders: bool = True) -> Select['Trade']:
"""
Helper function to query Trades using filters.
NOTE: Not supported in Backtesting.
@@ -1347,15 +1350,28 @@ class Trade(ModelBase, LocalTrade):
if trade_filter is not None:
if not isinstance(trade_filter, list):
trade_filter = [trade_filter]
this_query = Trade.query.filter(*trade_filter)
this_query = select(Trade).filter(*trade_filter)
else:
this_query = Trade.query
this_query = select(Trade)
if not include_orders:
# Don't load order relations
# Consider using noload or raiseload instead of lazyload
this_query = this_query.options(lazyload(Trade.orders))
return this_query
@staticmethod
def get_trades(trade_filter=None, include_orders: bool = True) -> Query['Trade']:
"""
Helper function to query Trades using filters.
NOTE: Not supported in Backtesting.
:param trade_filter: Optional filter to apply to trades
Can be either a Filter object, or a List of filters
e.g. `(trade_filter=[Trade.id == trade_id, Trade.is_open.is_(True),])`
e.g. `(trade_filter=Trade.id == trade_id)`
:return: unsorted query object
"""
return Trade._session.execute(Trade.get_trades_query(trade_filter, include_orders))
@staticmethod
def get_open_order_trades() -> List['Trade']:
"""
@@ -1392,8 +1408,9 @@ class Trade(ModelBase, LocalTrade):
Retrieves total realized profit
"""
if Trade.use_db:
total_profit = Trade.query.with_entities(
func.sum(Trade.close_profit_abs)).filter(Trade.is_open.is_(False)).scalar()
total_profit = Trade._session.scalar(
select(func.sum(Trade.close_profit_abs)).filter(Trade.is_open.is_(False))
)
else:
total_profit = sum(
t.close_profit_abs for t in LocalTrade.get_trades_proxy(is_open=False))
@@ -1406,8 +1423,9 @@ class Trade(ModelBase, LocalTrade):
in stake currency
"""
if Trade.use_db:
total_open_stake_amount = Trade.query.with_entities(
func.sum(Trade.stake_amount)).filter(Trade.is_open.is_(True)).scalar()
total_open_stake_amount = Trade._session.scalar(
select(func.sum(Trade.stake_amount)).filter(Trade.is_open.is_(True))
)
else:
total_open_stake_amount = sum(
t.stake_amount for t in LocalTrade.get_trades_proxy(is_open=True))
@@ -1423,15 +1441,18 @@ class Trade(ModelBase, LocalTrade):
if minutes:
start_date = datetime.now(timezone.utc) - timedelta(minutes=minutes)
filters.append(Trade.close_date >= start_date)
pair_rates = Trade.query.with_entities(
Trade.pair,
func.sum(Trade.close_profit).label('profit_sum'),
func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
func.count(Trade.pair).label('count')
).filter(*filters)\
.group_by(Trade.pair) \
.order_by(desc('profit_sum_abs')) \
.all()
pair_rates = Trade._session.execute(
select(
Trade.pair,
func.sum(Trade.close_profit).label('profit_sum'),
func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
func.count(Trade.pair).label('count')
).filter(*filters)
.group_by(Trade.pair)
.order_by(desc('profit_sum_abs'))
).all()
return [
{
'pair': pair,
@@ -1456,15 +1477,16 @@ class Trade(ModelBase, LocalTrade):
if (pair is not None):
filters.append(Trade.pair == pair)
enter_tag_perf = Trade.query.with_entities(
Trade.enter_tag,
func.sum(Trade.close_profit).label('profit_sum'),
func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
func.count(Trade.pair).label('count')
).filter(*filters)\
.group_by(Trade.enter_tag) \
.order_by(desc('profit_sum_abs')) \
.all()
enter_tag_perf = Trade._session.execute(
select(
Trade.enter_tag,
func.sum(Trade.close_profit).label('profit_sum'),
func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
func.count(Trade.pair).label('count')
).filter(*filters)
.group_by(Trade.enter_tag)
.order_by(desc('profit_sum_abs'))
).all()
return [
{
@@ -1488,16 +1510,16 @@ class Trade(ModelBase, LocalTrade):
filters: List = [Trade.is_open.is_(False)]
if (pair is not None):
filters.append(Trade.pair == pair)
sell_tag_perf = Trade.query.with_entities(
Trade.exit_reason,
func.sum(Trade.close_profit).label('profit_sum'),
func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
func.count(Trade.pair).label('count')
).filter(*filters)\
.group_by(Trade.exit_reason) \
.order_by(desc('profit_sum_abs')) \
.all()
sell_tag_perf = Trade._session.execute(
select(
Trade.exit_reason,
func.sum(Trade.close_profit).label('profit_sum'),
func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
func.count(Trade.pair).label('count')
).filter(*filters)
.group_by(Trade.exit_reason)
.order_by(desc('profit_sum_abs'))
).all()
return [
{
@@ -1521,18 +1543,18 @@ class Trade(ModelBase, LocalTrade):
filters: List = [Trade.is_open.is_(False)]
if (pair is not None):
filters.append(Trade.pair == pair)
mix_tag_perf = Trade.query.with_entities(
Trade.id,
Trade.enter_tag,
Trade.exit_reason,
func.sum(Trade.close_profit).label('profit_sum'),
func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
func.count(Trade.pair).label('count')
).filter(*filters)\
.group_by(Trade.id) \
.order_by(desc('profit_sum_abs')) \
.all()
mix_tag_perf = Trade._session.execute(
select(
Trade.id,
Trade.enter_tag,
Trade.exit_reason,
func.sum(Trade.close_profit).label('profit_sum'),
func.sum(Trade.close_profit_abs).label('profit_sum_abs'),
func.count(Trade.pair).label('count')
).filter(*filters)
.group_by(Trade.id)
.order_by(desc('profit_sum_abs'))
).all()
return_list: List[Dict] = []
for id, enter_tag, exit_reason, profit, profit_abs, count in mix_tag_perf:
@@ -1568,11 +1590,15 @@ class Trade(ModelBase, LocalTrade):
NOTE: Not supported in Backtesting.
:returns: Tuple containing (pair, profit_sum)
"""
best_pair = Trade.query.with_entities(
Trade.pair, func.sum(Trade.close_profit).label('profit_sum')
).filter(Trade.is_open.is_(False) & (Trade.close_date >= start_date)) \
.group_by(Trade.pair) \
.order_by(desc('profit_sum')).first()
best_pair = Trade._session.execute(
select(
Trade.pair,
func.sum(Trade.close_profit).label('profit_sum')
).filter(Trade.is_open.is_(False) & (Trade.close_date >= start_date))
.group_by(Trade.pair)
.order_by(desc('profit_sum'))
).first()
return best_pair
@staticmethod