Use .query.session to make sure the scoped session is used properly

This commit is contained in:
Matthias 2021-04-05 07:28:51 +02:00
parent bd5e1c5096
commit 0407bf755f
4 changed files with 19 additions and 22 deletions

View File

@ -187,7 +187,7 @@ class FreqtradeBot(LoggingMixin):
if self.get_free_open_trades(): if self.get_free_open_trades():
self.enter_positions() self.enter_positions()
Trade.session.flush() Trade.query.session.flush()
def process_stopped(self) -> None: def process_stopped(self) -> None:
""" """
@ -621,8 +621,8 @@ class FreqtradeBot(LoggingMixin):
if order_status == 'closed': if order_status == 'closed':
self.update_trade_state(trade, order_id, order) self.update_trade_state(trade, order_id, order)
Trade.session.add(trade) Trade.query.session.add(trade)
Trade.session.flush() Trade.query.session.flush()
# Updating wallets # Updating wallets
self.wallets.update() self.wallets.update()
@ -1205,7 +1205,7 @@ class FreqtradeBot(LoggingMixin):
# In case of market sell orders the order can be closed immediately # In case of market sell orders the order can be closed immediately
if order.get('status', 'unknown') == 'closed': if order.get('status', 'unknown') == 'closed':
self.update_trade_state(trade, trade.open_order_id, order) self.update_trade_state(trade, trade.open_order_id, order)
Trade.session.flush() Trade.query.session.flush()
# Lock pair for one candle to prevent immediate rebuys # Lock pair for one candle to prevent immediate rebuys
self.strategy.lock_pair(trade.pair, datetime.now(timezone.utc), self.strategy.lock_pair(trade.pair, datetime.now(timezone.utc),

View File

@ -61,11 +61,8 @@ def init_db(db_url: str, clean_open_orders: bool = False) -> None:
# We should use the scoped_session object - not a seperately initialized version # We should use the scoped_session object - not a seperately initialized version
Trade.session = scoped_session(sessionmaker(bind=engine, autoflush=True, autocommit=True)) Trade.session = scoped_session(sessionmaker(bind=engine, autoflush=True, autocommit=True))
Trade.query = Trade.session.query_property() Trade.query = Trade.session.query_property()
# Copy session attributes to order object too Order.query = Trade.session.query_property()
Order.session = Trade.session PairLock.query = Trade.session.query_property()
Order.query = Order.session.query_property()
PairLock.session = Trade.session
PairLock.query = PairLock.session.query_property()
previous_tables = inspect(engine).get_table_names() previous_tables = inspect(engine).get_table_names()
_DECL_BASE.metadata.create_all(engine) _DECL_BASE.metadata.create_all(engine)
@ -81,7 +78,7 @@ def cleanup_db() -> None:
Flushes all pending operations to disk. Flushes all pending operations to disk.
:return: None :return: None
""" """
Trade.session.flush() Trade.query.session.flush()
def clean_dry_run_db() -> None: def clean_dry_run_db() -> None:
@ -677,7 +674,7 @@ class LocalTrade():
in stake currency in stake currency
""" """
if Trade.use_db: if Trade.use_db:
total_open_stake_amount = Trade.session.query( total_open_stake_amount = Trade.query.with_entities(
func.sum(Trade.stake_amount)).filter(Trade.is_open.is_(True)).scalar() func.sum(Trade.stake_amount)).filter(Trade.is_open.is_(True)).scalar()
else: else:
total_open_stake_amount = sum( total_open_stake_amount = sum(
@ -689,7 +686,7 @@ class LocalTrade():
""" """
Returns List of dicts containing all Trades, including profit and trade count Returns List of dicts containing all Trades, including profit and trade count
""" """
pair_rates = Trade.session.query( pair_rates = Trade.query.with_entities(
Trade.pair, Trade.pair,
func.sum(Trade.close_profit).label('profit_sum'), func.sum(Trade.close_profit).label('profit_sum'),
func.count(Trade.pair).label('count') func.count(Trade.pair).label('count')
@ -712,7 +709,7 @@ class LocalTrade():
Get best pair with closed trade. Get best pair with closed trade.
:returns: Tuple containing (pair, profit_sum) :returns: Tuple containing (pair, profit_sum)
""" """
best_pair = Trade.session.query( best_pair = Trade.query.with_entities(
Trade.pair, func.sum(Trade.close_profit).label('profit_sum') Trade.pair, func.sum(Trade.close_profit).label('profit_sum')
).filter(Trade.is_open.is_(False)) \ ).filter(Trade.is_open.is_(False)) \
.group_by(Trade.pair) \ .group_by(Trade.pair) \
@ -805,10 +802,10 @@ class Trade(_DECL_BASE, LocalTrade):
def delete(self) -> None: def delete(self) -> None:
for order in self.orders: for order in self.orders:
Order.session.delete(order) Order.query.session.delete(order)
Trade.session.delete(self) Trade.query.session.delete(self)
Trade.session.flush() Trade.query.session.flush()
@staticmethod @staticmethod
def get_trades_proxy(*, pair: str = None, is_open: bool = None, def get_trades_proxy(*, pair: str = None, is_open: bool = None,

View File

@ -48,8 +48,8 @@ class PairLocks():
active=True active=True
) )
if PairLocks.use_db: if PairLocks.use_db:
PairLock.session.add(lock) PairLock.query.session.add(lock)
PairLock.session.flush() PairLock.query.session.flush()
else: else:
PairLocks.locks.append(lock) PairLocks.locks.append(lock)
@ -99,7 +99,7 @@ class PairLocks():
for lock in locks: for lock in locks:
lock.active = False lock.active = False
if PairLocks.use_db: if PairLocks.use_db:
PairLock.session.flush() PairLock.query.session.flush()
@staticmethod @staticmethod
def is_global_lock(now: Optional[datetime] = None) -> bool: def is_global_lock(now: Optional[datetime] = None) -> bool:

View File

@ -558,7 +558,7 @@ class RPC:
# Execute sell for all open orders # Execute sell for all open orders
for trade in Trade.get_open_trades(): for trade in Trade.get_open_trades():
_exec_forcesell(trade) _exec_forcesell(trade)
Trade.session.flush() Trade.query.session.flush()
self._freqtrade.wallets.update() self._freqtrade.wallets.update()
return {'result': 'Created sell orders for all open trades.'} return {'result': 'Created sell orders for all open trades.'}
@ -571,7 +571,7 @@ class RPC:
raise RPCException('invalid argument') raise RPCException('invalid argument')
_exec_forcesell(trade) _exec_forcesell(trade)
Trade.session.flush() Trade.query.session.flush()
self._freqtrade.wallets.update() self._freqtrade.wallets.update()
return {'result': f'Created sell order for trade {trade_id}.'} return {'result': f'Created sell order for trade {trade_id}.'}
@ -696,7 +696,7 @@ class RPC:
lock.lock_end_time = datetime.now(timezone.utc) lock.lock_end_time = datetime.now(timezone.utc)
# session is always the same # session is always the same
PairLock.session.flush() PairLock.query.session.flush()
return self._rpc_locks() return self._rpc_locks()