Fix some type errors

This commit is contained in:
Matthias 2023-03-15 21:09:25 +01:00
parent b469addffb
commit d45599ca3b
3 changed files with 16 additions and 17 deletions

View File

@ -373,7 +373,7 @@ def load_trades_from_db(db_url: str, strategy: Optional[str] = None) -> pd.DataF
filters = []
if strategy:
filters.append(Trade.strategy == strategy)
trades = trade_list_to_dataframe(Trade.get_trades(filters).all())
trades = trade_list_to_dataframe(list(Trade.get_trades(filters).all()))
return trades

View File

@ -7,10 +7,9 @@ from datetime import datetime, timedelta, timezone
from math import isclose
from typing import Any, ClassVar, Dict, List, Optional, cast
from sqlalchemy import (Enum, Float, ForeignKey, Integer, Result, Select, String, UniqueConstraint,
desc, func, select)
from sqlalchemy.orm import (Mapped, Query, QueryPropertyDescriptor, lazyload, mapped_column,
relationship)
from sqlalchemy import (Enum, Float, ForeignKey, Integer, ScalarResult, Select, String,
UniqueConstraint, desc, func, select)
from sqlalchemy.orm import Mapped, QueryPropertyDescriptor, lazyload, mapped_column, relationship
from freqtrade.constants import (DATETIME_PRINT_FORMAT, MATH_CLOSE_PREC, NON_OPEN_EXCHANGE_STATES,
BuySell, LongShort)
@ -1154,9 +1153,9 @@ class LocalTrade():
get open trade count
"""
if Trade.use_db:
return Trade._session.scalar(
return Trade._session.execute(
select(func.count(Trade.id)).filter(Trade.is_open.is_(True))
)
).scalar_one()
else:
return LocalTrade.bt_open_open_trade_count
@ -1335,7 +1334,7 @@ class Trade(ModelBase, LocalTrade):
)
@staticmethod
def get_trades_query(trade_filter=None, include_orders: bool = True) -> Select['Trade']:
def get_trades_query(trade_filter=None, include_orders: bool = True) -> Select:
"""
Helper function to query Trades using filters.
NOTE: Not supported in Backtesting.
@ -1360,7 +1359,7 @@ class Trade(ModelBase, LocalTrade):
return this_query
@staticmethod
def get_trades(trade_filter=None, include_orders: bool = True) -> Query['Trade']:
def get_trades(trade_filter=None, include_orders: bool = True) -> ScalarResult['Trade']:
"""
Helper function to query Trades using filters.
NOTE: Not supported in Backtesting.
@ -1370,7 +1369,7 @@ class Trade(ModelBase, LocalTrade):
e.g. `(trade_filter=Trade.id == trade_id)`
:return: unsorted query object
"""
return Trade._session.execute(Trade.get_trades_query(trade_filter, include_orders))
return Trade._session.scalars(Trade.get_trades_query(trade_filter, include_orders))
@staticmethod
def get_open_order_trades() -> List['Trade']:
@ -1378,7 +1377,7 @@ class Trade(ModelBase, LocalTrade):
Returns all open trades
NOTE: Not supported in Backtesting.
"""
return Trade.get_trades(Trade.open_order_id.isnot(None)).all()
return cast(List[Trade], Trade.get_trades(Trade.open_order_id.isnot(None)).all())
@staticmethod
def get_open_trades_without_assigned_fees():
@ -1408,12 +1407,12 @@ class Trade(ModelBase, LocalTrade):
Retrieves total realized profit
"""
if Trade.use_db:
total_profit = Trade._session.scalar(
total_profit: float = Trade._session.execute(
select(func.sum(Trade.close_profit_abs)).filter(Trade.is_open.is_(False))
)
).scalar_one()
else:
total_profit = sum(
t.close_profit_abs for t in LocalTrade.get_trades_proxy(is_open=False))
total_profit = sum(t.close_profit_abs # type: ignore
for t in LocalTrade.get_trades_proxy(is_open=False))
return total_profit or 0
@staticmethod

View File

@ -159,7 +159,7 @@ class RPC:
"""
# Fetch open trades
if trade_ids:
trades: List[Trade] = Trade.get_trades(trade_filter=Trade.id.in_(trade_ids)).all()
trades: Sequence[Trade] = Trade.get_trades(trade_filter=Trade.id.in_(trade_ids)).all()
else:
trades = Trade.get_open_trades()
@ -449,7 +449,7 @@ class RPC:
""" Returns cumulative profit statistics """
trade_filter = ((Trade.is_open.is_(False) & (Trade.close_date >= start_date)) |
Trade.is_open.is_(True))
trades: Sequence[Trade] = Trade._session.execute(Trade.get_trades_query(
trades: Sequence[Trade] = Trade._session.scalars(Trade.get_trades_query(
trade_filter, include_orders=False).order_by(Trade.id)).all()
profit_all_coin = []