Rename persistant storage infrastructure.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from freqtrade.persistence.keyvalue_middleware import KeyValues
|
||||
from freqtrade.persistence.keyvalue_middleware import CustomDataWrapper
|
||||
from freqtrade.persistence.models import cleanup_db, init_db
|
||||
from freqtrade.persistence.pairlock_middleware import PairLocks
|
||||
from freqtrade.persistence.trade_model import LocalTrade, Order, Trade
|
||||
|
||||
@@ -8,28 +8,28 @@ from freqtrade.constants import DATETIME_PRINT_FORMAT
|
||||
from freqtrade.persistence.base import _DECL_BASE
|
||||
|
||||
|
||||
class KeyValue(_DECL_BASE):
|
||||
class CustomData(_DECL_BASE):
|
||||
"""
|
||||
KeyValue database model
|
||||
CustomData database model
|
||||
Keeps records of metadata as key/value store
|
||||
for trades or global persistant values
|
||||
One to many relationship with Trades:
|
||||
- One trade can have many metadata entries
|
||||
- One metadata entry can only be associated with one Trade
|
||||
"""
|
||||
__tablename__ = 'keyvalue'
|
||||
__tablename__ = 'trade_custom_data'
|
||||
# Uniqueness should be ensured over pair, order_id
|
||||
# its likely that order_id is unique per Pair on some exchanges.
|
||||
__table_args__ = (UniqueConstraint('ft_trade_id', 'kv_key', name="_trade_id_kv_key"),)
|
||||
__table_args__ = (UniqueConstraint('ft_trade_id', 'cd_key', name="_trade_id_cd_key"),)
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
ft_trade_id = Column(Integer, ForeignKey('trades.id'), index=True, default=0)
|
||||
|
||||
trade = relationship("Trade", back_populates="keyvalues")
|
||||
trade = relationship("Trade", back_populates="custom_data")
|
||||
|
||||
kv_key = Column(String(255), nullable=False)
|
||||
kv_type = Column(String(25), nullable=False)
|
||||
kv_value = Column(Text, nullable=False)
|
||||
cd_key = Column(String(255), nullable=False)
|
||||
cd_type = Column(String(25), nullable=False)
|
||||
cd_value = Column(Text, nullable=False)
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, nullable=True)
|
||||
|
||||
@@ -38,20 +38,20 @@ class KeyValue(_DECL_BASE):
|
||||
if self.created_at is not None else None)
|
||||
update_time = (self.updated_at.strftime(DATETIME_PRINT_FORMAT)
|
||||
if self.updated_at is not None else None)
|
||||
return (f'KeyValue(id={self.id}, key={self.kv_key}, type={self.kv_type}, ' +
|
||||
f'value={self.kv_value}, trade_id={self.ft_trade_id}, created={create_time}, ' +
|
||||
return (f'CustomData(id={self.id}, key={self.cd_key}, type={self.cd_type}, ' +
|
||||
f'value={self.cd_value}, trade_id={self.ft_trade_id}, created={create_time}, ' +
|
||||
f'updated={update_time})')
|
||||
|
||||
@staticmethod
|
||||
def query_kv(key: Optional[str] = None, trade_id: Optional[int] = None) -> Query:
|
||||
def query_cd(key: Optional[str] = None, trade_id: Optional[int] = None) -> Query:
|
||||
"""
|
||||
Get all keyvalues, if trade_id is not specified
|
||||
Get all CustomData, if trade_id is not specified
|
||||
return will be for generic values not tied to a trade
|
||||
:param trade_id: id of the Trade
|
||||
"""
|
||||
filters = []
|
||||
filters.append(KeyValue.ft_trade_id == trade_id if trade_id is not None else 0)
|
||||
filters.append(CustomData.ft_trade_id == trade_id if trade_id is not None else 0)
|
||||
if key is not None:
|
||||
filters.append(KeyValue.kv_key.ilike(key))
|
||||
filters.append(CustomData.cd_key.ilike(key))
|
||||
|
||||
return KeyValue.query.filter(*filters)
|
||||
return CustomData.query.filter(*filters)
|
||||
|
||||
@@ -3,57 +3,63 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from freqtrade.persistence.keyvalue import KeyValue
|
||||
from freqtrade.persistence.keyvalue import CustomData
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyValues():
|
||||
class CustomDataWrapper():
|
||||
"""
|
||||
KeyValues middleware class
|
||||
CustomData middleware class
|
||||
Abstracts the database layer away so it becomes optional - which will be necessary to support
|
||||
backtesting and hyperopt in the future.
|
||||
"""
|
||||
|
||||
use_db = True
|
||||
kvals: List[KeyValue] = []
|
||||
custom_data: List[CustomData] = []
|
||||
unserialized_types = ['bool', 'float', 'int', 'str']
|
||||
|
||||
@staticmethod
|
||||
def reset_keyvalues() -> None:
|
||||
def reset_custom_data() -> None:
|
||||
"""
|
||||
Resets all key-value pairs. Only active for backtesting mode.
|
||||
"""
|
||||
if not KeyValues.use_db:
|
||||
KeyValues.kvals = []
|
||||
if not CustomDataWrapper.use_db:
|
||||
CustomDataWrapper.custom_data = []
|
||||
|
||||
@staticmethod
|
||||
def get_kval(key: Optional[str] = None, trade_id: Optional[int] = None) -> List[KeyValue]:
|
||||
def get_custom_data(key: Optional[str] = None,
|
||||
trade_id: Optional[int] = None) -> List[CustomData]:
|
||||
if trade_id is None:
|
||||
trade_id = 0
|
||||
|
||||
if KeyValues.use_db:
|
||||
filtered_kvals = KeyValue.query_kv(trade_id=trade_id, key=key).all()
|
||||
for index, kval in enumerate(filtered_kvals):
|
||||
if kval.kv_type not in KeyValues.unserialized_types:
|
||||
kval.kv_value = json.loads(kval.kv_value)
|
||||
filtered_kvals[index] = kval
|
||||
return filtered_kvals
|
||||
if CustomDataWrapper.use_db:
|
||||
filtered_custom_data = CustomData.query_cd(trade_id=trade_id, key=key).all()
|
||||
for index, data_entry in enumerate(filtered_custom_data):
|
||||
if data_entry.cd_type not in CustomDataWrapper.unserialized_types:
|
||||
data_entry.cd_value = json.loads(data_entry.cd_value)
|
||||
filtered_custom_data[index] = data_entry
|
||||
return filtered_custom_data
|
||||
else:
|
||||
filtered_kvals = [kval for kval in KeyValues.kvals if (kval.ft_trade_id == trade_id)]
|
||||
filtered_custom_data = [
|
||||
data_entry for data_entry in CustomDataWrapper.custom_data
|
||||
if (data_entry.ft_trade_id == trade_id)
|
||||
]
|
||||
if key is not None:
|
||||
filtered_kvals = [
|
||||
kval for kval in filtered_kvals if (kval.kv_key.casefold() == key.casefold())]
|
||||
return filtered_kvals
|
||||
filtered_custom_data = [
|
||||
data_entry for data_entry in filtered_custom_data
|
||||
if (data_entry.cd_key.casefold() == key.casefold())
|
||||
]
|
||||
return filtered_custom_data
|
||||
|
||||
@staticmethod
|
||||
def set_kval(key: str, value: Any, trade_id: Optional[int] = None) -> None:
|
||||
def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None:
|
||||
|
||||
value_type = type(value).__name__
|
||||
value_db = None
|
||||
|
||||
if value_type not in KeyValues.unserialized_types:
|
||||
if value_type not in CustomDataWrapper.unserialized_types:
|
||||
try:
|
||||
value_db = json.dumps(value)
|
||||
except TypeError as e:
|
||||
@@ -64,44 +70,44 @@ class KeyValues():
|
||||
if trade_id is None:
|
||||
trade_id = 0
|
||||
|
||||
kvals = KeyValues.get_kval(key=key, trade_id=trade_id)
|
||||
if kvals:
|
||||
kv = kvals[0]
|
||||
kv.kv_value = value
|
||||
kv.updated_at = datetime.utcnow()
|
||||
custom_data = CustomDataWrapper.get_custom_data(key=key, trade_id=trade_id)
|
||||
if custom_data:
|
||||
data_entry = custom_data[0]
|
||||
data_entry.cd_value = value
|
||||
data_entry.updated_at = datetime.utcnow()
|
||||
else:
|
||||
kv = KeyValue(
|
||||
ft_trade_id=trade_id,
|
||||
kv_key=key,
|
||||
kv_type=value_type,
|
||||
kv_value=value,
|
||||
created_at=datetime.utcnow()
|
||||
data_entry = CustomData(
|
||||
ft_trade_id=trade_id,
|
||||
cd_key=key,
|
||||
cd_type=value_type,
|
||||
cd_value=value,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
if KeyValues.use_db and value_db is not None:
|
||||
kv.kv_value = value_db
|
||||
KeyValue.query.session.add(kv)
|
||||
KeyValue.query.session.commit()
|
||||
elif not KeyValues.use_db:
|
||||
kv_index = -1
|
||||
for index, kval in enumerate(KeyValues.kvals):
|
||||
if kval.ft_trade_id == trade_id and kval.kv_key == key:
|
||||
kv_index = index
|
||||
if CustomDataWrapper.use_db and value_db is not None:
|
||||
data_entry.cd_value = value_db
|
||||
CustomData.query.session.add(data_entry)
|
||||
CustomData.query.session.commit()
|
||||
elif not CustomDataWrapper.use_db:
|
||||
cd_index = -1
|
||||
for index, data_entry in enumerate(CustomDataWrapper.custom_data):
|
||||
if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key:
|
||||
cd_index = index
|
||||
break
|
||||
|
||||
if kv_index >= 0:
|
||||
kval.kv_type = value_type
|
||||
kval.value = value
|
||||
kval.updated_at = datetime.utcnow()
|
||||
if cd_index >= 0:
|
||||
data_entry.cd_type = value_type
|
||||
data_entry.value = value
|
||||
data_entry.updated_at = datetime.utcnow()
|
||||
|
||||
KeyValues.kvals[kv_index] = kval
|
||||
CustomDataWrapper.custom_data[cd_index] = data_entry
|
||||
else:
|
||||
KeyValues.kvals.append(kv)
|
||||
CustomDataWrapper.custom_data.append(data_entry)
|
||||
|
||||
@staticmethod
|
||||
def get_all_kvals() -> List[KeyValue]:
|
||||
def get_all_custom_data() -> List[CustomData]:
|
||||
|
||||
if KeyValues.use_db:
|
||||
return KeyValue.query.all()
|
||||
if CustomDataWrapper.use_db:
|
||||
return CustomData.query.all()
|
||||
else:
|
||||
return KeyValues.kvals
|
||||
return CustomDataWrapper.custom_data
|
||||
|
||||
@@ -10,7 +10,7 @@ from sqlalchemy.pool import StaticPool
|
||||
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.persistence.base import _DECL_BASE
|
||||
from freqtrade.persistence.keyvalue import KeyValue
|
||||
from freqtrade.persistence.keyvalue import CustomData
|
||||
from freqtrade.persistence.migrations import check_migrate
|
||||
from freqtrade.persistence.pairlock import PairLock
|
||||
from freqtrade.persistence.trade_model import Order, Trade
|
||||
@@ -58,8 +58,8 @@ def init_db(db_url: str) -> None:
|
||||
Trade.query = Trade._session.query_property()
|
||||
Order.query = Trade._session.query_property()
|
||||
PairLock.query = Trade._session.query_property()
|
||||
KeyValue._session = scoped_session(sessionmaker(bind=engine, autoflush=True))
|
||||
KeyValue.query = KeyValue._session.query_property()
|
||||
CustomData._session = scoped_session(sessionmaker(bind=engine, autoflush=True))
|
||||
CustomData.query = CustomData._session.query_property()
|
||||
|
||||
previous_tables = inspect(engine).get_table_names()
|
||||
_DECL_BASE.metadata.create_all(engine)
|
||||
|
||||
@@ -15,8 +15,8 @@ from freqtrade.enums import ExitType, TradingMode
|
||||
from freqtrade.exceptions import DependencyException, OperationalException
|
||||
from freqtrade.leverage import interest
|
||||
from freqtrade.persistence.base import _DECL_BASE
|
||||
from freqtrade.persistence.keyvalue import KeyValue
|
||||
from freqtrade.persistence.keyvalue_middleware import KeyValues
|
||||
from freqtrade.persistence.keyvalue import CustomData
|
||||
from freqtrade.persistence.keyvalue_middleware import CustomDataWrapper
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -240,7 +240,7 @@ class LocalTrade():
|
||||
id: int = 0
|
||||
|
||||
orders: List[Order] = []
|
||||
keyvalues: List[KeyValue] = []
|
||||
custom_data: List[CustomData] = []
|
||||
|
||||
exchange: str = ''
|
||||
pair: str = ''
|
||||
@@ -880,11 +880,11 @@ class LocalTrade():
|
||||
or (o.ft_is_open is True and o.status is not None)
|
||||
]
|
||||
|
||||
def set_kval(self, key: str, value: Any) -> None:
|
||||
KeyValues.set_kval(key=key, value=value, trade_id=self.id)
|
||||
def set_custom_data(self, key: str, value: Any) -> None:
|
||||
CustomDataWrapper.set_custom_data(key=key, value=value, trade_id=self.id)
|
||||
|
||||
def get_kvals(self, key: Optional[str]) -> List[KeyValue]:
|
||||
return KeyValues.get_kval(key=key, trade_id=self.id)
|
||||
def get_custom_data(self, key: Optional[str]) -> List[CustomData]:
|
||||
return CustomDataWrapper.get_custom_data(key=key, trade_id=self.id)
|
||||
|
||||
@property
|
||||
def nr_of_successful_entries(self) -> int:
|
||||
@@ -1016,7 +1016,7 @@ class Trade(_DECL_BASE, LocalTrade):
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
orders = relationship("Order", order_by="Order.id", cascade="all, delete-orphan", lazy="joined")
|
||||
keyvalues = relationship("KeyValue", order_by="KeyValue.id", cascade="all, delete-orphan")
|
||||
custom_data = relationship("CustomData", order_by="CustomData.id", cascade="all, delete-orphan")
|
||||
|
||||
exchange = Column(String(25), nullable=False)
|
||||
pair = Column(String(25), nullable=False, index=True)
|
||||
@@ -1090,9 +1090,9 @@ class Trade(_DECL_BASE, LocalTrade):
|
||||
Trade.query.session.delete(self)
|
||||
Trade.commit()
|
||||
|
||||
for kval in self.keyvalues:
|
||||
KeyValue.query.session.delete(kval)
|
||||
KeyValue.query.session.commit()
|
||||
for entry in self.custom_data:
|
||||
CustomData.query.session.delete(entry)
|
||||
CustomData.query.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def commit():
|
||||
@@ -1367,11 +1367,11 @@ class Trade(_DECL_BASE, LocalTrade):
|
||||
.order_by(desc('profit_sum')).first()
|
||||
return best_pair
|
||||
|
||||
def set_kval(self, key: str, value: Any) -> None:
|
||||
super().set_kval(key=key, value=value)
|
||||
def set_custom_data(self, key: str, value: Any) -> None:
|
||||
super().set_custom_data(key=key, value=value)
|
||||
|
||||
def get_kvals(self, key: Optional[str]) -> List[KeyValue]:
|
||||
return super().get_kvals(key=key)
|
||||
def get_custom_data(self, key: Optional[str]) -> List[CustomData]:
|
||||
return super().get_custom_data(key=key)
|
||||
|
||||
@staticmethod
|
||||
def get_trading_volume(start_date: datetime = datetime.fromtimestamp(0)) -> float:
|
||||
|
||||
Reference in New Issue
Block a user