generalize database url path for any db type

This commit is contained in:
robcaulk 2022-08-05 12:19:29 +02:00
parent 51a6b4289f
commit 05ec5c5e54
1 changed files with 10 additions and 6 deletions

View File

@ -87,9 +87,13 @@ class FreqaiDataKitchen:
)
if self.live:
db_url = self.config.get('db_url', 'sqlite://')
self.database_path = '' if db_url == 'sqlite://' else str(db_url).split('///')[1]
self.trade_database_df: DataFrame = pd.DataFrame()
db_url = self.config.get('db_url', 'None')
self.database_path = Path(db_url)
self.database_name = self.database_path.parts[-1]
else:
self.database_path = Path('None')
self.trade_database_df: DataFrame = pd.DataFrame()
self.data['extra_returns_per_train'] = self.freqai_config.get('extra_returns_per_train', {})
self.thread_count = self.freqai_config.get("data_kitchen_thread_count", -1)
@ -1038,11 +1042,11 @@ class FreqaiDataKitchen:
def get_current_trade_database(self) -> None:
if self.database_path == '':
logger.warning('No trade databse found. Skipping analysis.')
if str(self.database_path) == 'None':
logger.warning('No trade database found. Skipping analysis.')
return
data = sqlite3.connect(self.database_path)
data = sqlite3.connect(self.database_name)
query = data.execute("SELECT * From trades")
cols = [column[0] for column in query.description]
df = pd.DataFrame.from_records(data=query.fetchall(), columns=cols)