Fix some type errors

This commit is contained in:
Matthias 2023-03-15 21:09:25 +01:00
parent b469addffb
commit d45599ca3b
3 changed files with 16 additions and 17 deletions

View File

@ -373,7 +373,7 @@ def load_trades_from_db(db_url: str, strategy: Optional[str] = None) -> pd.DataF
filters = [] filters = []
if strategy: if strategy:
filters.append(Trade.strategy == strategy) filters.append(Trade.strategy == strategy)
trades = trade_list_to_dataframe(Trade.get_trades(filters).all()) trades = trade_list_to_dataframe(list(Trade.get_trades(filters).all()))
return trades return trades

View File

@ -7,10 +7,9 @@ from datetime import datetime, timedelta, timezone
from math import isclose from math import isclose
from typing import Any, ClassVar, Dict, List, Optional, cast from typing import Any, ClassVar, Dict, List, Optional, cast
from sqlalchemy import (Enum, Float, ForeignKey, Integer, Result, Select, String, UniqueConstraint, from sqlalchemy import (Enum, Float, ForeignKey, Integer, ScalarResult, Select, String,
desc, func, select) UniqueConstraint, desc, func, select)
from sqlalchemy.orm import (Mapped, Query, QueryPropertyDescriptor, lazyload, mapped_column, from sqlalchemy.orm import Mapped, QueryPropertyDescriptor, lazyload, mapped_column, relationship
relationship)
from freqtrade.constants import (DATETIME_PRINT_FORMAT, MATH_CLOSE_PREC, NON_OPEN_EXCHANGE_STATES, from freqtrade.constants import (DATETIME_PRINT_FORMAT, MATH_CLOSE_PREC, NON_OPEN_EXCHANGE_STATES,
BuySell, LongShort) BuySell, LongShort)
@ -1154,9 +1153,9 @@ class LocalTrade():
get open trade count get open trade count
""" """
if Trade.use_db: if Trade.use_db:
return Trade._session.scalar( return Trade._session.execute(
select(func.count(Trade.id)).filter(Trade.is_open.is_(True)) select(func.count(Trade.id)).filter(Trade.is_open.is_(True))
) ).scalar_one()
else: else:
return LocalTrade.bt_open_open_trade_count return LocalTrade.bt_open_open_trade_count
@ -1335,7 +1334,7 @@ class Trade(ModelBase, LocalTrade):
) )
@staticmethod @staticmethod
def get_trades_query(trade_filter=None, include_orders: bool = True) -> Select['Trade']: def get_trades_query(trade_filter=None, include_orders: bool = True) -> Select:
""" """
Helper function to query Trades using filters. Helper function to query Trades using filters.
NOTE: Not supported in Backtesting. NOTE: Not supported in Backtesting.
@ -1360,7 +1359,7 @@ class Trade(ModelBase, LocalTrade):
return this_query return this_query
@staticmethod @staticmethod
def get_trades(trade_filter=None, include_orders: bool = True) -> Query['Trade']: def get_trades(trade_filter=None, include_orders: bool = True) -> ScalarResult['Trade']:
""" """
Helper function to query Trades using filters. Helper function to query Trades using filters.
NOTE: Not supported in Backtesting. NOTE: Not supported in Backtesting.
@ -1370,7 +1369,7 @@ class Trade(ModelBase, LocalTrade):
e.g. `(trade_filter=Trade.id == trade_id)` e.g. `(trade_filter=Trade.id == trade_id)`
:return: unsorted query object :return: unsorted query object
""" """
return Trade._session.execute(Trade.get_trades_query(trade_filter, include_orders)) return Trade._session.scalars(Trade.get_trades_query(trade_filter, include_orders))
@staticmethod @staticmethod
def get_open_order_trades() -> List['Trade']: def get_open_order_trades() -> List['Trade']:
@ -1378,7 +1377,7 @@ class Trade(ModelBase, LocalTrade):
Returns all open trades Returns all open trades
NOTE: Not supported in Backtesting. NOTE: Not supported in Backtesting.
""" """
return Trade.get_trades(Trade.open_order_id.isnot(None)).all() return cast(List[Trade], Trade.get_trades(Trade.open_order_id.isnot(None)).all())
@staticmethod @staticmethod
def get_open_trades_without_assigned_fees(): def get_open_trades_without_assigned_fees():
@ -1408,12 +1407,12 @@ class Trade(ModelBase, LocalTrade):
Retrieves total realized profit Retrieves total realized profit
""" """
if Trade.use_db: if Trade.use_db:
total_profit = Trade._session.scalar( total_profit: float = Trade._session.execute(
select(func.sum(Trade.close_profit_abs)).filter(Trade.is_open.is_(False)) select(func.sum(Trade.close_profit_abs)).filter(Trade.is_open.is_(False))
) ).scalar_one()
else: else:
total_profit = sum( total_profit = sum(t.close_profit_abs # type: ignore
t.close_profit_abs for t in LocalTrade.get_trades_proxy(is_open=False)) for t in LocalTrade.get_trades_proxy(is_open=False))
return total_profit or 0 return total_profit or 0
@staticmethod @staticmethod

View File

@ -159,7 +159,7 @@ class RPC:
""" """
# Fetch open trades # Fetch open trades
if trade_ids: if trade_ids:
trades: List[Trade] = Trade.get_trades(trade_filter=Trade.id.in_(trade_ids)).all() trades: Sequence[Trade] = Trade.get_trades(trade_filter=Trade.id.in_(trade_ids)).all()
else: else:
trades = Trade.get_open_trades() trades = Trade.get_open_trades()
@ -449,7 +449,7 @@ class RPC:
""" Returns cumulative profit statistics """ """ Returns cumulative profit statistics """
trade_filter = ((Trade.is_open.is_(False) & (Trade.close_date >= start_date)) | trade_filter = ((Trade.is_open.is_(False) & (Trade.close_date >= start_date)) |
Trade.is_open.is_(True)) Trade.is_open.is_(True))
trades: Sequence[Trade] = Trade._session.execute(Trade.get_trades_query( trades: Sequence[Trade] = Trade._session.scalars(Trade.get_trades_query(
trade_filter, include_orders=False).order_by(Trade.id)).all() trade_filter, include_orders=False).order_by(Trade.id)).all()
profit_all_coin = [] profit_all_coin = []