init
Browse files- app.py +182 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import requests
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from mplfinance.original_flavor import candlestick_ohlc
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.linear_model import LinearRegression
|
7 |
+
import os
|
8 |
+
from pathlib import Path
|
9 |
+
import streamlit as st
|
10 |
+
|
11 |
+
# Directory that rendered charts are saved into; created at import time.
PLOT_DIR = Path("./Plots")
# exist_ok=True creates the directory without a separate existence check, so
# there is no window between "check" and "create" in which another process
# could create it and make mkdir fail.
PLOT_DIR.mkdir(parents=True, exist_ok=True)

# Pieces of the candlestick endpoint URL and the headers sent with each call.
host = "https://api.gateio.ws"
prefix = "/api/v4"
headers = {'Accept': 'application/json', 'Content-Type': 'application/json'}
endpoint = '/spot/candlesticks'
url = host + prefix + endpoint
# Paging step: how many intervals a single API request is allowed to span.
max_API_request_allowed = 900
|
22 |
+
|
23 |
+
def lin_reg(data, threshold_channel_len):
    """Split ``data`` into consecutive windows and fit one line per window.

    For every row, ``row[0]`` is used as the x value and the mean of
    ``row[2]`` and ``row[3]`` as the y value. Starting at index 0, each
    window that begins at the current start index and holds at least
    ``threshold_channel_len`` rows is fitted with ``LinearRegression``. The
    window with the highest ``score`` is kept (the shortest one wins ties)
    and the scan restarts on the row right after it. Trailing rows that
    cannot fill a whole window are left out.

    Args:
        data: Sequence of rows indexable as described above.
        threshold_channel_len: Minimum number of rows per window.

    Returns:
        A list of ``[slope, intercept, start_index, end_index]`` entries,
        where both indices are inclusive positions in ``data``.
    """
    # A non-positive threshold never produced any window; keep that contract.
    if threshold_channel_len < 1:
        return []

    # A one-row window has an undefined (NaN) score. Previously that left the
    # "best window" index at -1, so the scan restarted from index -1 and
    # crashed on an empty slice. Two rows is the smallest window a line can
    # be scored on.
    min_len = max(threshold_channel_len, 2)

    n = len(data)
    X = np.array([row[0] for row in data]).reshape(-1, 1)
    y = np.array([(row[2] + row[3]) / 2 for row in data]).reshape(-1, 1)

    channels = []
    start = 0
    stop = min_len
    while stop <= n:
        best_fit = None
        best_score = float("-inf")
        for window_stop in range(stop, n + 1):
            xs = X[start:window_stop]
            ys = y[start:window_stop]
            reg = LinearRegression().fit(xs, ys)
            fit = [reg.coef_[0][0], reg.intercept_[0], start, window_stop - 1]
            # Fall back to the first candidate so the scan always advances,
            # even if every score turns out to be NaN.
            if best_fit is None:
                best_fit = fit
            window_score = reg.score(xs, ys)
            # Strict ">" keeps the earliest window on ties; NaN never wins.
            if window_score > best_score:
                best_score = window_score
                best_fit = fit
        channels.append(best_fit)
        start = best_fit[3] + 1
        stop = start + min_len
    return channels
|
57 |
+
|
58 |
+
def binary_search(data, line_type, m, b, epsilon):
    """Return the intercept of the tightest bounding line with slope ``m``.

    For ``line_type == "top"`` the line ``y = m * x + result`` lies on or
    above every ``row[2]`` value, except that it may dip below one by at
    most ``epsilon``. For ``"bottom"`` it lies on or below every ``row[3]``
    value, except that it may rise above one by at most ``epsilon``.
    ``row[0]`` is the x value.

    The answer is computed directly: each row yields the intercept at which
    the line passes through its point, and the extreme of those intercepts
    (shifted by ``epsilon``) is the tightest bound. The earlier iterative
    search stepped by whole units over float intercepts, never ran for
    "bottom" (so ``epsilon`` was ignored there), and could loop forever when
    all intercepts were equal, e.g. for a single row.

    Args:
        data: Sequence of rows indexable as described above.
        line_type: Either ``"top"`` or ``"bottom"``.
        m: Slope of the line.
        b: Unused; kept so existing callers keep working.
        epsilon: Allowed overshoot past the extreme row.

    Returns:
        The intercept as a float. For empty ``data`` this is ``-inf`` for
        "top" and ``+inf`` for "bottom".

    Raises:
        ValueError: If ``line_type`` is neither ``"top"`` nor ``"bottom"``.
    """
    if line_type == "top":
        intercepts = [row[2] - m * row[0] for row in data]
        return max(intercepts, default=float("-inf")) - epsilon
    if line_type == "bottom":
        intercepts = [row[3] - m * row[0] for row in data]
        return min(intercepts, default=float("inf")) + epsilon
    raise ValueError(
        "line_type must be 'top' or 'bottom', got {!r}".format(line_type)
    )
|
95 |
+
|
96 |
+
def plot_lines(lines, plt, converted_data):
    """Draw each ``(slope, intercept, start, end)`` line on ``plt``.

    The x range of a line runs from ``converted_data[start][0]`` to
    ``converted_data[end][0]`` and is sampled at 10 evenly spaced points;
    the matching y values are ``slope * x + intercept``. Every line is drawn
    with one ``plt.plot(x_values, y_values)`` call.
    """
    for slope, intercept, first, last in lines:
        x_from = converted_data[first][0]
        x_to = converted_data[last][0]
        xs = list(np.linspace(x_from, x_to, 10))
        ys = []
        for x in xs:
            ys.append(slope * x + intercept)
        plt.plot(xs, ys)
|
101 |
+
|
102 |
+
def get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime):
    """Fetch candlestick rows for ``<currency>_USDT`` between two datetimes.

    The range is walked in pages of ``max_API_request_allowed`` intervals and
    the JSON payloads of all pages are concatenated in request order.

    Args:
        currency: Symbol placed in front of ``_USDT`` to form the pair.
        interval_timedelta: Length of one candle as a ``timedelta``.
        interval: Interval string forwarded to the API as ``interval``.
        start_datetime: First candle time (inclusive).
        end_datetime: Last candle time (inclusive).

    Returns:
        The concatenated rows, or ``[]`` when the range is empty or any
        request fails (an error is shown through ``st.error`` on failure).
    """
    # Number of interval steps that fit in [start, end], counted inclusively.
    total_candles = (end_datetime - start_datetime) // interval_timedelta + 1
    if total_candles <= 0:
        return []

    data = []
    for first in range(0, total_candles, max_API_request_allowed):
        # Clamp the final page to the requested end; without this the last
        # request reached up to a full page past end_datetime and returned
        # candles the caller never asked for.
        last = min(first + max_API_request_allowed, total_candles) - 1
        query_param = {
            "currency_pair": "{}_USDT".format(currency),
            "from": int((start_datetime + first * interval_timedelta).timestamp()),
            "to": int((start_datetime + last * interval_timedelta).timestamp()),
            "interval": interval,
        }
        try:
            # requests waits indefinitely unless a timeout is given.
            r = requests.get(url=url, headers=headers, params=query_param, timeout=30)
        except requests.RequestException as exc:
            st.error("API request failed: {}".format(exc))
            return []
        if r.status_code != 200:
            st.error("Invalid API Request")
            return []
        data += r.json()
    return data
|
122 |
+
|
123 |
+
def testcasecase(currency, interval, startdate, enddate, threshold_channel_len, testcasecase_id):
    """Fetch candles, fit channel lines, then save and display the chart.

    Args:
        currency: Symbol passed to ``get_API_data`` and used in the title
            and file name.
        interval: ``"1h"``, ``"4h"`` or ``"1d"``; any other value is treated
            as one week.
        startdate: Start date as ``MM/DD/YYYY``.
        enddate: End date as ``MM/DD/YYYY``.
        threshold_channel_len: Minimum window length passed to ``lin_reg``.
        testcasecase_id: Identifier embedded in the saved file name.

    Returns:
        None. The figure is written to ``PLOT_DIR`` and rendered with
        ``st.pyplot``. Nothing is drawn when the dates are malformed or no
        data comes back.
    """
    # The module imports only ``matplotlib.pyplot as plt``, which does not
    # bind the name ``matplotlib``; import the dates helper explicitly.
    import matplotlib.dates as mdates

    try:
        start_datetime = datetime.datetime.strptime(startdate.strip(), "%m/%d/%Y")
        end_datetime = datetime.datetime.strptime(enddate.strip(), "%m/%d/%Y")
    except ValueError:
        st.error("Dates must be valid and in MM/DD/YYYY format.")
        return

    interval_steps = {
        "1h": datetime.timedelta(hours=1),
        "4h": datetime.timedelta(hours=4),
        "1d": datetime.timedelta(days=1),
    }
    # Unknown intervals fall back to one week, as before.
    interval_timedelta = interval_steps.get(interval, datetime.timedelta(weeks=1))

    data = get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime)
    if not data:
        return

    # Reorder each row into the (time, open, high, low, close) layout that
    # candlestick_ohlc consumes.
    # NOTE(review): assumes d[5], d[3], d[4], d[2] are open/high/low/close in
    # the API response -- confirm against the API documentation.
    converted_data = [
        [
            mdates.date2num(
                datetime.datetime.fromtimestamp(float(d[0]), tz=datetime.timezone.utc)
            ),
            float(d[5]),
            float(d[3]),
            float(d[4]),
            float(d[2]),
        ]
        for d in data
    ]

    fig, ax = plt.subplots()
    try:
        candlestick_ohlc(ax, converted_data, width=0.4, colorup='#77d879', colordown='#db3f3f')

        epsilon = 0
        top_fitting_lines_data = []
        bottom_fitting_lines_data = []
        for m, b, start, end in lin_reg(converted_data, threshold_channel_len):
            window = converted_data[start:end + 1]
            top_b = binary_search(window, "top", m, b, epsilon)
            bottom_b = binary_search(window, "bottom", m, b, epsilon)
            top_fitting_lines_data.append([m, top_b, start, end])
            bottom_fitting_lines_data.append([m, bottom_b, start, end])

        # plot_lines only calls ``.plot`` on its second argument, so the axes
        # object is passed to draw on this figure explicitly instead of on
        # whatever pyplot considers current.
        plot_lines(top_fitting_lines_data, ax, converted_data)
        plot_lines(bottom_fitting_lines_data, ax, converted_data)
        ax.set_title("{}_USDT".format(currency))

        file_name = "figure_{}_{}_USDT.png".format(testcasecase_id, currency)
        fig.savefig(PLOT_DIR / file_name)
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each button press would leave another figure in memory.
        plt.close(fig)
|
167 |
+
|
168 |
+
def main():
    """Render the input form and generate the chart on request.

    Collects the currency, interval, date range and threshold through
    Streamlit widgets and calls ``testcasecase`` when the button is pressed.
    """
    st.title("Cryptocurrency Regression Analysis")
    st.write("Enter details to generate regression lines on cryptocurrency candlesticks.")

    currency = st.text_input("Currency", "BTC")
    interval = st.selectbox("Interval", ["1h", "4h", "1d", "1w"])
    startdate = st.text_input("Start Date (MM/DD/YYYY)", "01/01/2022")
    enddate = st.text_input("End Date (MM/DD/YYYY)", "12/31/2022")
    threshold_channel_len = st.number_input("Threshold Channel Length", min_value=1, max_value=1000, value=10)

    if st.button("Generate Plot"):
        # The symbol is used to build the "<SYMBOL>_USDT" pair name, so stray
        # whitespace or lowercase input would produce a different pair string
        # than the user intended.
        symbol = currency.strip().upper()
        if not symbol:
            st.error("Please enter a currency symbol.")
            return
        testcasecase(symbol, interval, startdate, enddate, threshold_channel_len, 1)


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
requests
|
3 |
+
matplotlib
|
4 |
+
mplfinance
|
5 |
+
numpy
|
6 |
+
scikit-learn
|