suryanshs16103 commited on
Commit
6f69e3c
·
verified ·
1 Parent(s): 056749f
Files changed (2) hide show
  1. app.py +182 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import requests
3
+ import matplotlib.pyplot as plt
4
+ from mplfinance.original_flavor import candlestick_ohlc
5
+ import numpy as np
6
+ from sklearn.linear_model import LinearRegression
7
+ import os
8
+ from pathlib import Path
9
+ import streamlit as st
10
+
11
+ PLOT_DIR = Path("./Plots")
12
+
13
+ if not os.path.exists(PLOT_DIR):
14
+ os.mkdir(PLOT_DIR)
15
+
16
+ host = "https://api.gateio.ws"
17
+ prefix = "/api/v4"
18
+ headers = {'Accept': 'application/json', 'Content-Type': 'application/json'}
19
+ endpoint = '/spot/candlesticks'
20
+ url = host + prefix + endpoint
21
+ max_API_request_allowed = 900
22
+
23
+ def lin_reg(data, threshold_channel_len):
24
+ list_f = []
25
+ X = []
26
+ y = []
27
+ for i in range(0, len(data)):
28
+ X.append(data[i][0])
29
+ avg = (data[i][2] + data[i][3]) / 2
30
+ y.append(avg)
31
+ X = np.array(X).reshape(-1, 1)
32
+ y = np.array(y).reshape(-1, 1)
33
+ l = 0
34
+ j = threshold_channel_len
35
+ while l < j and j <= len(data):
36
+ score = []
37
+ list_pf = []
38
+ while j <= len(data):
39
+ reg = LinearRegression().fit(X[l:j], y[l:j])
40
+ temp_coeff = list(reg.coef_)
41
+ temp_intercept = list(reg.intercept_)
42
+ list_pf.append([temp_coeff[0][0], temp_intercept[0], l, j - 1])
43
+ score.append([reg.score(X[l:j], y[l:j]), j])
44
+ j = j + 1
45
+ req_score = float("-inf")
46
+ ind = -1
47
+ temp_ind = -1
48
+ for i in range(len(score)):
49
+ if req_score < score[i][0]:
50
+ ind = score[i][1]
51
+ req_score = score[i][0]
52
+ temp_ind = i
53
+ list_f.append(list_pf[temp_ind])
54
+ l = ind
55
+ j = ind + threshold_channel_len
56
+ return list_f
57
+
58
+ def binary_search(data, line_type, m, b, epsilon):
59
+ right = float("-inf")
60
+ left = float("inf")
61
+ get_y_intercept = lambda x, y: y - m * x
62
+ for i in range(len(data)):
63
+ d = data[i]
64
+ curr_y = d[2]
65
+ if line_type == "bottom":
66
+ curr_y = d[3]
67
+ curr = get_y_intercept(d[0], curr_y)
68
+ right = max(right, curr)
69
+ left = min(left, curr)
70
+
71
+ sign = -1
72
+ if line_type == "bottom":
73
+ left, right = right, left
74
+ sign = 1
75
+ ans = right
76
+ while left <= right:
77
+ mid = left + (right - left) // 2
78
+ intersection_count = 0
79
+ for i in range(len(data)):
80
+ d = data[i]
81
+ curr_y = m * d[0] + mid
82
+ candle_y = d[2]
83
+ if line_type == "bottom":
84
+ candle_y = d[3]
85
+ if line_type == "bottom" and (curr_y > candle_y and (curr_y - candle_y > epsilon)):
86
+ intersection_count += 1
87
+ if line_type == "top" and (curr_y < candle_y and (candle_y - curr_y > epsilon)):
88
+ intersection_count += 1
89
+ if intersection_count == 0:
90
+ right = mid + 1 * sign
91
+ ans = mid
92
+ else:
93
+ left = mid - 1 * sign
94
+ return ans
95
+
96
+ def plot_lines(lines, plt, converted_data):
97
+ for m, b, start, end in lines:
98
+ x_data = list(np.linspace(converted_data[start][0], converted_data[end][0], 10))
99
+ y_data = [m * x + b for x in x_data]
100
+ plt.plot(x_data, y_data)
101
+
102
+ def get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime):
103
+ curr_datetime = start_datetime
104
+ total_dates = 0
105
+ while curr_datetime <= end_datetime:
106
+ total_dates += 1
107
+ curr_datetime += interval_timedelta
108
+ data = []
109
+ for i in range(0, total_dates, max_API_request_allowed):
110
+ query_param = {
111
+ "currency_pair": "{}_USDT".format(currency),
112
+ "from": int((start_datetime + i * interval_timedelta).timestamp()),
113
+ "to": int((start_datetime + (i + max_API_request_allowed - 1) * interval_timedelta).timestamp()),
114
+ "interval": interval,
115
+ }
116
+ r = requests.get(url=url, headers=headers, params=query_param)
117
+ if r.status_code != 200:
118
+ st.error("Invalid API Request")
119
+ return []
120
+ data += r.json()
121
+ return data
122
+
123
+ def testcasecase(currency, interval, startdate, enddate, threshold_channel_len, testcasecase_id):
124
+ start_date_month, start_date_day, start_date_year = [int(x) for x in startdate.strip().split("/")]
125
+ end_date_month, end_date_day, end_date_year = [int(x) for x in enddate.strip().split("/")]
126
+
127
+ if interval == "1h":
128
+ interval_timedelta = datetime.timedelta(hours=1)
129
+ elif interval == "4h":
130
+ interval_timedelta = datetime.timedelta(hours=4)
131
+ elif interval == "1d":
132
+ interval_timedelta = datetime.timedelta(days=1)
133
+ else:
134
+ interval_timedelta = datetime.timedelta(weeks=1)
135
+
136
+ start_datetime = datetime.datetime(year=start_date_year, month=start_date_month, day=start_date_day)
137
+ end_datetime = datetime.datetime(year=end_date_year, month=end_date_month, day=end_date_day)
138
+
139
+ data = get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime)
140
+ if len(data) == 0:
141
+ return
142
+ converted_data = []
143
+ for d in data:
144
+ converted_data.append([matplotlib.dates.date2num(datetime.datetime.utcfromtimestamp(float(d[0]))), float(d[5]), float(d[3]), float(d[4]), float(d[2])])
145
+
146
+ fig, ax = plt.subplots()
147
+ candlestick_ohlc(ax, converted_data, width=0.4, colorup='#77d879', colordown='#db3f3f')
148
+
149
+ fitting_lines_data = lin_reg(converted_data, threshold_channel_len)
150
+ top_fitting_lines_data = []
151
+ bottom_fitting_lines_data = []
152
+ epsilon = 0
153
+ for i in range(len(fitting_lines_data)):
154
+ m, b, start, end = fitting_lines_data[i]
155
+ top_b = binary_search(converted_data[start:end + 1], "top", m, b, epsilon)
156
+ bottom_b = binary_search(converted_data[start:end + 1], "bottom", m, b, epsilon)
157
+ top_fitting_lines_data.append([m, top_b, start, end])
158
+ bottom_fitting_lines_data.append([m, bottom_b, start, end])
159
+
160
+ plot_lines(top_fitting_lines_data, plt, converted_data)
161
+ plot_lines(bottom_fitting_lines_data, plt, converted_data)
162
+ plt.title("{}_USDT".format(currency))
163
+ file_name = "figure_{}_{}_USDT.png".format(testcasecase_id, currency)
164
+ file_location = os.path.join(PLOT_DIR, file_name)
165
+ plt.savefig(file_location)
166
+ st.pyplot(fig)
167
+
168
+ def main():
169
+ st.title("Cryptocurrency Regression Analysis")
170
+ st.write("Enter details to generate regression lines on cryptocurrency candlesticks.")
171
+
172
+ currency = st.text_input("Currency", "BTC")
173
+ interval = st.selectbox("Interval", ["1h", "4h", "1d", "1w"])
174
+ startdate = st.text_input("Start Date (MM/DD/YYYY)", "01/01/2022")
175
+ enddate = st.text_input("End Date (MM/DD/YYYY)", "12/31/2022")
176
+ threshold_channel_len = st.number_input("Threshold Channel Length", min_value=1, max_value=1000, value=10)
177
+
178
+ if st.button("Generate Plot"):
179
+ testcasecase(currency, interval, startdate, enddate, threshold_channel_len, 1)
180
+
181
+ if __name__ == "__main__":
182
+ main()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ requests
3
+ matplotlib
4
+ mplfinance
5
+ numpy
6
+ scikit-learn