Add direction column to pairlocks

This commit is contained in:
Matthias 2022-04-23 15:26:58 +02:00
parent 25c6c5e326
commit 6ff3b178b0
3 changed files with 97 additions and 9 deletions

View File

@ -9,7 +9,7 @@ from freqtrade.exceptions import OperationalException
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_table_names_for_table(inspector, tabletype): def get_table_names_for_table(inspector, tabletype) -> List[str]:
return [t for t in inspector.get_table_names() if t.startswith(tabletype)] return [t for t in inspector.get_table_names() if t.startswith(tabletype)]
@ -21,7 +21,7 @@ def get_column_def(columns: List, column: str, default: str) -> str:
return default if not has_column(columns, column) else column return default if not has_column(columns, column) else column
def get_backup_name(tabs, backup_prefix: str): def get_backup_name(tabs: List[str], backup_prefix: str):
table_back_name = backup_prefix table_back_name = backup_prefix
for i, table_back_name in enumerate(tabs): for i, table_back_name in enumerate(tabs):
table_back_name = f'{backup_prefix}{i}' table_back_name = f'{backup_prefix}{i}'
@ -56,6 +56,16 @@ def set_sequence_ids(engine, order_id, trade_id):
connection.execute(text(f"ALTER SEQUENCE trades_id_seq RESTART WITH {trade_id}")) connection.execute(text(f"ALTER SEQUENCE trades_id_seq RESTART WITH {trade_id}"))
def drop_index_on_table(engine, inspector, table_bak_name):
with engine.begin() as connection:
# drop indexes on backup table in new session
for index in inspector.get_indexes(table_bak_name):
if engine.name == 'mysql':
connection.execute(text(f"drop index {index['name']} on {table_bak_name}"))
else:
connection.execute(text(f"drop index {index['name']}"))
def migrate_trades_and_orders_table( def migrate_trades_and_orders_table(
decl_base, inspector, engine, decl_base, inspector, engine,
trade_back_name: str, cols: List, trade_back_name: str, cols: List,
@ -116,13 +126,7 @@ def migrate_trades_and_orders_table(
with engine.begin() as connection: with engine.begin() as connection:
connection.execute(text(f"alter table trades rename to {trade_back_name}")) connection.execute(text(f"alter table trades rename to {trade_back_name}"))
with engine.begin() as connection: drop_index_on_table(engine, inspector, trade_back_name)
# drop indexes on backup table in new session
for index in inspector.get_indexes(trade_back_name):
if engine.name == 'mysql':
connection.execute(text(f"drop index {index['name']} on {trade_back_name}"))
else:
connection.execute(text(f"drop index {index['name']}"))
order_id, trade_id = get_last_sequence_ids(engine, trade_back_name, order_back_name) order_id, trade_id = get_last_sequence_ids(engine, trade_back_name, order_back_name)
@ -205,6 +209,31 @@ def migrate_orders_table(engine, table_back_name: str, cols_order: List):
""")) """))
def migrate_pairlocks_table(
decl_base, inspector, engine,
pairlock_back_name: str, cols: List):
# Schema migration necessary
with engine.begin() as connection:
connection.execute(text(f"alter table pairlocks rename to {pairlock_back_name}"))
drop_index_on_table(engine, inspector, pairlock_back_name)
direction = get_column_def(cols, 'direction', "'*'")
# let SQLAlchemy create the schema as required
decl_base.metadata.create_all(engine)
# Copy data back - following the correct schema
with engine.begin() as connection:
connection.execute(text(f"""insert into pairlocks
(id, pair, direction, reason, lock_time,
lock_end_time, active)
select id, pair, {direction} direction, reason, lock_time,
lock_end_time, active
from {pairlock_back_name}
"""))
def set_sqlite_to_wal(engine): def set_sqlite_to_wal(engine):
if engine.name == 'sqlite' and str(engine.url) != 'sqlite://': if engine.name == 'sqlite' and str(engine.url) != 'sqlite://':
# Set Mode to # Set Mode to
@ -220,10 +249,13 @@ def check_migrate(engine, decl_base, previous_tables) -> None:
cols_trades = inspector.get_columns('trades') cols_trades = inspector.get_columns('trades')
cols_orders = inspector.get_columns('orders') cols_orders = inspector.get_columns('orders')
cols_pairlocks = inspector.get_columns('pairlocks')
tabs = get_table_names_for_table(inspector, 'trades') tabs = get_table_names_for_table(inspector, 'trades')
table_back_name = get_backup_name(tabs, 'trades_bak') table_back_name = get_backup_name(tabs, 'trades_bak')
order_tabs = get_table_names_for_table(inspector, 'orders') order_tabs = get_table_names_for_table(inspector, 'orders')
order_table_bak_name = get_backup_name(order_tabs, 'orders_bak') order_table_bak_name = get_backup_name(order_tabs, 'orders_bak')
pairlock_tabs = get_table_names_for_table(inspector, 'pairlocks')
pairlock_table_bak_name = get_backup_name(pairlock_tabs, 'pairlocks_bak')
# Check if migration necessary # Check if migration necessary
# Migrates both trades and orders table! # Migrates both trades and orders table!
@ -236,6 +268,13 @@ def check_migrate(engine, decl_base, previous_tables) -> None:
decl_base, inspector, engine, table_back_name, cols_trades, decl_base, inspector, engine, table_back_name, cols_trades,
order_table_bak_name, cols_orders) order_table_bak_name, cols_orders)
if not has_column(cols_pairlocks, 'direction'):
logger.info(f"Running database migration for pairlocks - "
f"backup: {pairlock_table_bak_name}")
migrate_pairlocks_table(
decl_base, inspector, engine, pairlock_table_bak_name, cols_pairlocks
)
if 'orders' not in previous_tables and 'trades' in previous_tables: if 'orders' not in previous_tables and 'trades' in previous_tables:
raise OperationalException( raise OperationalException(
"Your database seems to be very old. " "Your database seems to be very old. "

View File

@ -1428,6 +1428,8 @@ class PairLock(_DECL_BASE):
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
pair = Column(String(25), nullable=False, index=True) pair = Column(String(25), nullable=False, index=True)
# lock direction - long, short or * (for both)
direction = Column(String(25), nullable=False, default="*")
reason = Column(String(255), nullable=True) reason = Column(String(255), nullable=True)
# Time the pair was locked (start time) # Time the pair was locked (start time)
lock_time = Column(DateTime, nullable=False) lock_time = Column(DateTime, nullable=False)

View File

@ -15,6 +15,7 @@ from freqtrade.enums import TradingMode
from freqtrade.exceptions import DependencyException, OperationalException from freqtrade.exceptions import DependencyException, OperationalException
from freqtrade.persistence import LocalTrade, Order, Trade, clean_dry_run_db, init_db from freqtrade.persistence import LocalTrade, Order, Trade, clean_dry_run_db, init_db
from freqtrade.persistence.migrations import get_last_sequence_ids, set_sequence_ids from freqtrade.persistence.migrations import get_last_sequence_ids, set_sequence_ids
from freqtrade.persistence.models import PairLock
from tests.conftest import create_mock_trades, create_mock_trades_with_leverage, log_has, log_has_re from tests.conftest import create_mock_trades, create_mock_trades_with_leverage, log_has, log_has_re
@ -1427,6 +1428,52 @@ def test_migrate_set_sequence_ids():
assert engine.begin.call_count == 0 assert engine.begin.call_count == 0
def test_migrate_pairlocks(mocker, default_conf, fee, caplog):
"""
Test Database migration (starting with new pairformat)
"""
caplog.set_level(logging.DEBUG)
# Always create all columns apart from the last!
create_table_old = """CREATE TABLE pairlocks (
id INTEGER NOT NULL,
pair VARCHAR(25) NOT NULL,
reason VARCHAR(255),
lock_time DATETIME NOT NULL,
lock_end_time DATETIME NOT NULL,
active BOOLEAN NOT NULL,
PRIMARY KEY (id)
)
"""
create_index1 = "CREATE INDEX ix_pairlocks_pair ON pairlocks (pair)"
create_index2 = "CREATE INDEX ix_pairlocks_lock_end_time ON pairlocks (lock_end_time)"
create_index3 = "CREATE INDEX ix_pairlocks_active ON pairlocks (active)"
insert_table_old = """INSERT INTO pairlocks (
id, pair, reason, lock_time, lock_end_time, active)
VALUES (1, 'ETH/BTC', 'Auto lock', '2021-07-12 18:41:03', '2021-07-11 18:45:00', 1)
"""
insert_table_old2 = """INSERT INTO pairlocks (
id, pair, reason, lock_time, lock_end_time, active)
VALUES (2, '*', 'Lock all', '2021-07-12 18:41:03', '2021-07-12 19:00:00', 1)
"""
engine = create_engine('sqlite://')
mocker.patch('freqtrade.persistence.models.create_engine', lambda *args, **kwargs: engine)
# Create table using the old format
with engine.begin() as connection:
connection.execute(text(create_table_old))
connection.execute(text(insert_table_old))
connection.execute(text(insert_table_old2))
connection.execute(text(create_index1))
connection.execute(text(create_index2))
connection.execute(text(create_index3))
init_db(default_conf['db_url'], default_conf['dry_run'])
assert len(PairLock.query.all()) == 2
assert len(PairLock.query.filter(PairLock.pair == '*').all()) == 1
assert len(PairLock.query.filter(PairLock.pair == 'ETH/BTC').all()) == 1
def test_adjust_stop_loss(fee): def test_adjust_stop_loss(fee):
trade = Trade( trade = Trade(
pair='ADA/USDT', pair='ADA/USDT',