Skip to content

Commit

Permalink
Clean up and write more unit tests for Data. (#57)
Browse files Browse the repository at this point in the history
* initial commit

* more cleanups

* small bugfix

* small fix

* small stuff

* minor improvements and start setup for db tests

* fix test

* minor refactors and add tests

* minor refactors and add more tests

* lints

* lints

* add more tests and cleanup

* add more tests and cleanup

* lints

* add more tests

* lints

* lints

* minor improvements

* resolve comments

* add final test for this PR

* lints

* simplify
  • Loading branch information
ZENALC authored Jul 24, 2021
1 parent 1487855 commit d21d4df
Show file tree
Hide file tree
Showing 4 changed files with 302 additions and 92 deletions.
160 changes: 79 additions & 81 deletions algobot/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import binance

from algobot.helpers import ROOT_DIR, get_logging_object, get_normalized_data, get_ups_and_downs
from algobot.helpers import ROOT_DIR, SHORT_INTERVAL_MAP, get_logging_object, get_normalized_data, get_ups_and_downs
from algobot.typing_hints import DATA_TYPE


Expand Down Expand Up @@ -87,7 +87,7 @@ def validate_interval(interval: str):
Validates interval. If incorrect interval, raises ValueError.
:param interval: Interval to be checked in short form -> e.g. 12h for 12 hours
"""
available_intervals = ('12h', '15m', '1d', '1h', '1m', '2h', '30m', '3d', '3m', '4h', '5m', '6h', '8h')
available_intervals = SHORT_INTERVAL_MAP.keys()
if interval not in available_intervals:
raise ValueError(f'Invalid interval {interval} given. Available intervals are: \n{available_intervals}')

Expand All @@ -101,19 +101,6 @@ def validate_symbol(self, symbol: str):
if not self.is_valid_symbol(symbol):
raise ValueError(f'Invalid symbol/ticker {symbol} provided.')

def load_data(self, update: bool = True):
    """
    Populate this Data object from the local database, optionally syncing first.
    :param update: When True, refresh the database before use if it is stale.
    """
    self.get_data_from_database()
    if not update:
        return

    # Only reach out for fresh data when the stored rows are behind.
    if self.database_is_updated():
        self.output_message("Database is up-to-date.")
    else:
        self.output_message("Updating data...")
        self.update_database_and_data()

def output_message(self, message: str, level: int = 2, printMessage: bool = False):
"""
This function will log and optionally print the message provided.
Expand Down Expand Up @@ -166,20 +153,30 @@ def create_table(self):
);''')
connection.commit()

def dump_to_table(self, totalData: List[dict] = None) -> bool:
def dump_to_table(self, total_data: List[dict] = None) -> bool:
"""
Dumps date and price information to database.
:return: A boolean whether data entry was successful or not.
"""
if totalData is None:
totalData = self.data
if total_data is None:
total_data = self.data

query = f'''INSERT INTO {self.databaseTable} (
date_utc,
open_price,
high_price,
low_price,
close_price,
volume,
quote_asset_volume,
number_of_trades,
taker_buy_base_asset,
taker_buy_quote_asset
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);'''

query = f'''INSERT INTO {self.databaseTable} (date_utc, open_price, high_price, low_price, close_price,
volume, quote_asset_volume, number_of_trades, taker_buy_base_asset, taker_buy_quote_asset)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);'''
with closing(sqlite3.connect(self.databaseFile)) as connection:
with closing(connection.cursor()) as cursor:
for data in totalData:
for data in total_data:
try:
cursor.execute(query,
(data['date_utc'].strftime('%Y-%m-%d %H:%M:%S'),
Expand Down Expand Up @@ -220,16 +217,24 @@ def get_data_from_database(self):
with closing(sqlite3.connect(self.databaseFile)) as connection:
with closing(connection.cursor()) as cursor:
rows = cursor.execute(f'''
SELECT "date_utc", "open_price", "high_price", "low_price", "close_price", "volume",
"quote_asset_volume", "number_of_trades", "taker_buy_base_asset", "taker_buy_quote_asset"
SELECT
"date_utc",
"open_price",
"high_price",
"low_price",
"close_price",
"volume",
"quote_asset_volume",
"number_of_trades",
"taker_buy_base_asset",
"taker_buy_quote_asset"
FROM {self.databaseTable} ORDER BY date_utc
''').fetchall()

if len(rows) > 0:
self.output_message("Retrieving data from database...")
else:
self.output_message("No data found in database.")
return

for row in rows:
date_utc = datetime.strptime(row[0], '%Y-%m-%d %H:%M:%S').replace(tzinfo=timezone.utc)
Expand Down Expand Up @@ -258,8 +263,21 @@ def get_latest_timestamp(self) -> int:
# pylint: disable=protected-access
return self.binanceClient._get_earliest_valid_timestamp(self.symbol, self.interval)
else:
latestDate = datetime.strptime(result[0], '%Y-%m-%d %H:%M:%S').replace(tzinfo=timezone.utc)
return int(latestDate.timestamp()) * 1000 + 1 # Converting timestamp to milliseconds
latest_date = datetime.strptime(result[0], '%Y-%m-%d %H:%M:%S').replace(tzinfo=timezone.utc)
return int(latest_date.timestamp()) * 1000 + 1 # Converting timestamp to milliseconds

def load_data(self, update: bool = True):
"""
Loads data to Data object.

Reads previously stored rows via get_data_from_database(), then, if requested
and the stored data is stale, refreshes it via update_database_and_data().
:param update: Boolean that determines whether data is updated or not.
"""
self.get_data_from_database()
if update:
# Only refresh from the remote source when the local database is stale.
if not self.database_is_updated():
self.output_message("Updating data...")
self.update_database_and_data()
else:
self.output_message("Database is up-to-date.")

# noinspection PyProtectedMember
def update_database_and_data(self):
Expand Down Expand Up @@ -382,14 +400,14 @@ def get_new_data(self, timestamp: int, limit: int = 1000, get_current: bool = Fa
else:
return new_data[:-1] # Up to -1st index, because we don't want current period data.

def is_latest_date(self, latestDate: datetime) -> bool:
def is_latest_date(self, latest_date: datetime) -> bool:
"""
Checks whether the latest date available is the latest period available.
:param latestDate: Datetime object.
:param latest_date: Datetime object.
:return: True or false whether date is latest period or not.
"""
minutes = self.get_interval_minutes()
current_date = latestDate + timedelta(minutes=minutes) + timedelta(seconds=5) # 5s leeway for server update
current_date = latest_date + timedelta(minutes=minutes) + timedelta(seconds=5) # 5s leeway for server update
return current_date >= datetime.now(timezone.utc) - timedelta(minutes=minutes)

def data_is_updated(self) -> bool:
Expand All @@ -400,12 +418,12 @@ def data_is_updated(self) -> bool:
latest_date = self.data[-1]['date_utc']
return self.is_latest_date(latest_date)

def insert_data(self, newData: List[List[str]]):
def insert_data(self, new_data: List[List[str]]):
"""
Inserts data from newData to run-time data.
:param newData: List with new data values.
:param new_data: List with new data values.
"""
for data in newData:
for data in new_data:
parsed_date = datetime.fromtimestamp(int(data[0]) / 1000, tz=timezone.utc)
current_dict = get_normalized_data(data=data, date_in_utc=parsed_date)
self.data.append(current_dict)
Expand Down Expand Up @@ -456,14 +474,16 @@ def get_current_data(self, counter: int = 0) -> Dict[str, Union[str, float]]:

next_interval = current_interval + timedelta(minutes=self.get_interval_minutes())
next_timestamp = int(next_interval.timestamp() * 1000) - 1
currentData = self.binanceClient.get_klines(symbol=self.symbol,
interval=self.interval,
startTime=current_timestamp,
endTime=next_timestamp,
)[0]
self.current_values = get_normalized_data(data=currentData, date_in_utc=current_interval)
current_data = self.binanceClient.get_klines(symbol=self.symbol,
interval=self.interval,
startTime=current_timestamp,
endTime=next_timestamp,
)[0]
self.current_values = get_normalized_data(data=current_data, date_in_utc=current_interval)

if counter > 0:
self.try_callback("Successfully reconnected.")

return self.current_values
except Exception as e:
sleep_time = 5 + counter * 2
Expand Down Expand Up @@ -519,20 +539,6 @@ def get_interval_minutes(self) -> int:
else:
raise ValueError("Invalid interval.", 4)

def create_folders_and_change_path(self, folder_name: str):
    """
    Creates appropriate folders for data storage then changes current working directory to it.

    Side effect: the process working directory is left at ROOT_DIR/folder_name/<symbol>.
    :param folder_name: Folder to create under the application root (e.g. "CSV").
    """
    os.chdir(ROOT_DIR)
    # makedirs(..., exist_ok=True) creates both the storage folder and the
    # per-symbol subfolder in one call and avoids the exists()/mkdir()
    # check-then-act race the previous two-step version had.
    target = os.path.join(folder_name, self.symbol)
    os.makedirs(target, exist_ok=True)
    os.chdir(target)  # Go inside the symbol folder, matching the original behavior.

def write_csv_data(self, total_data: list, file_name: str, army_time: bool = True) -> str:
"""
Writes CSV data to CSV folder in root directory of application.
Expand All @@ -541,10 +547,11 @@ def write_csv_data(self, total_data: list, file_name: str, army_time: bool = Tru
:param file_name: Filename to name CSV in.
:return: Absolute path to CSV file.
"""
current_path = os.getcwd()
self.create_folders_and_change_path(folder_name="CSV")
dir_path = os.path.join(ROOT_DIR, "CSV", self.symbol)
os.makedirs(dir_path, exist_ok=True)

with open(file_name, 'w') as f:
file_path = os.path.join(dir_path, file_name)
with open(file_path, 'w') as f:
f.write("Date_UTC, Open, High, Low, Close, Volume, Quote_Asset_Volume, Number_of_Trades, "
"Taker_Buy_Base_Asset, Taker_Buy_Quote_Asset\n")
for data in total_data:
Expand All @@ -556,10 +563,7 @@ def write_csv_data(self, total_data: list, file_name: str, army_time: bool = Tru
f'{data["volume"]}, {data["quote_asset_volume"]}, {data["number_of_trades"]}, '
f'{data["taker_buy_base_asset"]}, {data["taker_buy_quote_asset"]}\n')

path = os.path.join(os.getcwd(), file_name)
os.chdir(current_path)

return path
return file_path

def create_csv_file(self, descending: bool = True, army_time: bool = True, start_date: datetime = None) -> str:
"""
Expand All @@ -568,11 +572,10 @@ def create_csv_file(self, descending: bool = True, army_time: bool = True, start
:param descending: Boolean that decides whether values in CSV are in descending format or not.
:param army_time: Boolean that dictates whether dates will be written in army-time format or not.
"""
self.update_database_and_data() # Update data if updates exist.
file_name = f'{self.symbol}_data_{self.interval}.csv'

data = self.data
if start_date is not None:
if start_date is not None: # Getting data to start from.
for index, period in enumerate(data):
if period['date_utc'].date() <= start_date:
data = self.data[index:]
Expand All @@ -597,45 +600,40 @@ def is_valid_symbol(self, symbol: str) -> bool:
return True
return False

def is_valid_average_input(self, shift: int, prices: int, extraShift: int = 0) -> bool:
def is_valid_average_input(self, shift: int, prices: int, extra_shift: int = 0) -> bool:
"""
Checks whether shift, prices, and (optional) extraShift are valid.
:param shift: Periods from current period.
:param prices: Amount of prices to iterate over.
:param extraShift: Extra shift for EMA.
:param extra_shift: Extra shift for EMA.
:return: A boolean whether shift, prices, and extraShift are logical or not.
TODO: Deprecate along with helper get EMA and RSI.
"""
if shift < 0:
self.output_message("Shift cannot be less than 0.")
return False
elif prices <= 0:
self.output_message("Prices cannot be 0 or less than 0.")
return False
elif shift + extraShift + prices > len(self.data) + 1:
elif shift + extra_shift + prices > len(self.data) + 1:
self.output_message("Shift + prices period cannot be more than data available.")
return False
return True

def verify_integrity(self) -> bool:
@staticmethod
def verify_integrity(total_data: List[Dict[str, Union[float, datetime]]]) -> DATA_TYPE:
"""
Verifies integrity of data by checking if there's any repeated data.
:return: A boolean whether the data contains no repeated data or not.
:param total_data: Total data to verify integrity of.
:return: List of duplicate data found.
"""
if len(self.data) < 1:
self.output_message("No data found.", 4)
return False
errored_data = []
for index, data in enumerate(total_data[:-1]):
next_data = total_data[index + 1]
if next_data['date_utc'] == data['date_utc']:
errored_data.append(data)

previous_data = self.data[0]
for data in self.data[1:]:
if data['date_utc'] == previous_data['date_utc']:
self.output_message("Repeated data detected.", 4)
self.output_message(f'Previous data: {previous_data}', 4)
self.output_message(f'Next data: {data}', 4)
return False
previous_data = data

self.output_message("Data has been verified to be correct.")
return True
return errored_data

def get_total_non_updated_data(self) -> DATA_TYPE:
"""
Expand Down
11 changes: 8 additions & 3 deletions algobot/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import re
import subprocess
import time
from datetime import datetime
from datetime import datetime, timezone
from typing import Dict, List, Optional, Tuple, Union

import requests
Expand Down Expand Up @@ -358,14 +358,19 @@ def parse_strategy_name(name: str) -> str:
return parsed_name


def get_normalized_data(data: List[str], date_in_utc: Union[str, datetime] = None) -> Dict[str, Union[str, float]]:
def get_normalized_data(data: List[str], date_in_utc: Union[str, datetime] = None, parse_date: bool = False) \
-> Dict[str, Union[str, float]]:
"""
Normalize data provided and return as an appropriate dictionary.
:param data: Data to normalize into a dictionary.
:param date_in_utc: Optional date to use (if provided). If not provided, it'll use the first element from data.
:param parse_date: Boolean whether to parse date or not if date in UTC is not provided.
"""
if date_in_utc is None:
date_in_utc = parser.parse(data[0]).replace(tzinfo=timezone.utc) if parse_date else data[0]

return {
'date_utc': date_in_utc if date_in_utc is not None else data[0],
'date_utc': date_in_utc,
'open': float(data[1]),
'high': float(data[2]),
'low': float(data[3]),
Expand Down
1 change: 1 addition & 0 deletions tests/binance_client_mocker.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def get_all_tickers() -> List[Dict[str, str]]:
{"symbol": "LUNAUSDT", "price": '6.940'},
{"symbol": "XRPUSDT", "price": '0.5710'},
{"symbol": "DOGEUSDT", "price": '0.19326'},
{"symbol": "ALGOBOTUSDT", "price": "1209.54"}
]

def get_symbol_ticker(self, symbol: str = None) -> Union[Dict[str, str], List[Dict[str, str]]]:
Expand Down
Loading

0 comments on commit d21d4df

Please sign in to comment.