File size: 9,130 Bytes
2493442
 
a4924d0
2493442
 
a4924d0
 
33c1ed4
a4924d0
2493442
 
 
 
 
a4924d0
2493442
a4924d0
 
 
 
 
 
2493442
 
 
a4924d0
68991f5
5395785
 
a14f9a2
a4924d0
2493442
 
 
 
 
 
a4924d0
 
 
 
2493442
 
a4924d0
 
 
 
 
67f29a0
 
 
 
2493442
67f29a0
 
 
a4924d0
2493442
a4924d0
2493442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33c1ed4
a4924d0
 
67f29a0
 
 
 
2493442
 
 
 
 
a4924d0
2493442
a4924d0
2493442
33c1ed4
a4924d0
3bdc2b8
a4924d0
 
 
2493442
 
a4924d0
33c1ed4
 
67f29a0
 
 
 
 
 
 
33c1ed4
 
a4924d0
2493442
 
a4924d0
 
2493442
 
 
 
 
 
 
 
 
67f29a0
 
 
 
 
2493442
 
a4924d0
 
 
2493442
a4924d0
 
 
 
 
2493442
a4924d0
 
33c1ed4
 
a4924d0
2493442
 
 
 
 
 
 
a4924d0
 
 
 
 
 
 
 
 
2493442
 
 
67f29a0
 
 
2493442
 
 
a4924d0
 
 
 
 
 
 
 
2493442
a4924d0
 
 
 
 
 
2493442
a4924d0
 
2493442
91c55e5
a4924d0
 
 
2493442
a4924d0
 
 
 
 
 
 
 
 
 
 
 
 
 
2493442
a4924d0
 
 
2493442
a4924d0
 
 
 
 
 
 
 
 
 
 
 
 
67f29a0
 
 
 
 
a4924d0
 
 
2493442
a4924d0
 
 
67f29a0
 
a4924d0
 
 
2493442
 
 
 
 
 
 
a4924d0
 
 
 
 
 
 
 
2493442
a4924d0
 
 
 
2493442
 
 
 
 
a4924d0
2493442
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# pylint: disable=line-too-long, missing-module-docstring, missing-function-docstring,broad-exception-caught, too-many-statements

import json
import os
import re
import time
from datetime import datetime, timedelta
from textwrap import dedent

import requests

# from bs4 import BeautifulSoup
from flask import Flask, Response, request, stream_with_context
from loguru import logger
from ycecream import y

# Configure the ycecream debug printer.
# NOTE(review): `sln` presumably toggles line numbers in its output — confirm
# against the ycecream documentation.
y.configure(sln=1)

# The Flask application; the route handlers below register themselves on it.
app = Flask(__name__)

# How long a cached appSession cookie / access token is reused before
# fetch_tokens() requests a fresh one.
APP_SESSION_VALIDITY = timedelta(days=3)
ACCESS_TOKEN_VALIDITY = timedelta(hours=1)

# reka.ai login credentials (posted by fetch_tokens) and the bearer key that
# callers of /hf/v1/chat/completions must present. All three come from the
# environment and default to the empty string.
USERNAME = os.environ.get("USERNAME", "")
PASSWORD = os.environ.get("PASSWORD", "")
AUTHKEY = os.environ.get("AUTHKEY", "")

# Refuse to start without upstream credentials.
# NOTE(review): `assert` is stripped when Python runs with -O, and AUTHKEY is
# not checked here, so an empty AUTHKEY is accepted.
_ = "Set USERNAME and PASSWORD somewhere (e.g., set/export or secrets in hf space, or you wont be able to fetch sfuff from reka.ai)"
assert USERNAME and PASSWORD, _

# Debug-print the first few characters of each secret.
# NOTE(review): this writes partial credentials to the debug output.
y(USERNAME[:3], PASSWORD[:2], AUTHKEY[:2])

# In-process credential cache read and written by fetch_tokens(): the
# appSession cookie and access token, each with the time it was stored.
cache = {
    "app_session": None,
    "app_session_time": None,
    "access_token": None,
    "access_token_time": None,
}

y(cache)

# Configure logging
# logging.basicConfig(level=logging.DEBUG)


def fetch_tokens():
    """Return a reka.ai access token, refreshing the cached credentials if stale.

    Two credentials live in the module-level ``cache``: the ``appSession``
    cookie (reused for ``APP_SESSION_VALIDITY``) and the access token obtained
    with it (reused for ``ACCESS_TOKEN_VALIDITY``). A credential and its
    timestamp are stored only once the credential has actually been obtained,
    so a failed login or token request is retried on the next call instead of
    being remembered until the validity window runs out.

    Returns:
        The access token, or ``None`` when the login page cannot be loaded,
        the login yields no ``appSession`` cookie, the access-token request
        fails or carries no token, or a network error occurs.
    """
    # Seconds; without a timeout a stalled upstream would block the worker
    # indefinitely.
    request_timeout = 30

    try:
        # The context manager releases the session's pooled connections.
        with requests.Session() as session:
            # Check and, if needed, obtain the appSession cookie.
            if (
                not cache["app_session"]
                or not cache["app_session_time"]
                or datetime.now() - cache["app_session_time"] >= APP_SESSION_VALIDITY
            ):
                logger.info("Fetching new appSession")
                login_page_response = session.get(
                    "https://chat.reka.ai/bff/auth/login",
                    allow_redirects=True,
                    timeout=request_timeout,
                )
                if login_page_response.status_code != 200:
                    logger.error("Failed to load login page")
                    return None

                # The first run of 20+ word characters in the login page is
                # used as the form's "state" value; empty if none is found.
                state_match = re.search(r"\w{20,}", login_page_response.text)
                state_value = state_match.group() if state_match else ""

                session.post(
                    "https://auth.reka.ai/u/login",
                    data={
                        "state": state_value,
                        "username": USERNAME,
                        "password": PASSWORD,
                        "action": "default",
                    },
                    timeout=request_timeout,
                )
                app_session = session.cookies.get("appSession")
                if not app_session:
                    # Do not stamp the cache: a missing cookie must not be
                    # treated as a valid session for the next three days.
                    logger.error("Login did not yield an appSession cookie")
                    return None
                cache["app_session"] = app_session
                cache["app_session_time"] = datetime.now()  # type: ignore

            # Check and, if needed, obtain the access token.
            if (
                not cache["access_token"]
                or not cache["access_token_time"]
                or datetime.now() - cache["access_token_time"] >= ACCESS_TOKEN_VALIDITY
            ):
                logger.info("Fetching new accessToken")
                response = session.get(
                    "https://chat.reka.ai/bff/auth/access_token",
                    headers={"Cookie": f'appSession={cache["app_session"]}'},
                    timeout=request_timeout,
                )
                if response.status_code != 200:
                    logger.error("Failed to get access token")
                    return None

                try:
                    payload = response.json()
                except ValueError:
                    logger.error("Access token response was not valid JSON")
                    return None

                access_token = (
                    payload.get("accessToken") if isinstance(payload, dict) else None
                )
                if not access_token:
                    logger.error("Access token response did not contain accessToken")
                    return None
                cache["access_token"] = access_token
                cache["access_token_time"] = datetime.now()  # type: ignore
    except requests.RequestException as exc:
        logger.error(f"Network error while fetching tokens: {exc}")
        return None

    return cache["access_token"]


@app.route("/")
def landing():
    """Serve a short HTML usage note with example curl calls for the chat endpoint."""
    usage_html = """
        <p>
        query /hf/v1/chat/completions for a spin, e.g. <br/> curl -XPOST 127.0.0.1:7860/hf/v1/chat/completions -H "Authorization: Bearer Your_AUTHKEY"
        </p>
        <p>
        or hf-space-url e.g.,<br/>
        curl -XPOST https://mikeee-reka.hf.space/hf/v1/chat/completions -H "Authorization: Bearer Your_AUTHKEY"  -H "Content-Type: application/json" --data "{\\"model\\": \\"reka-core\\", \\"messages\\": [{\\"role\\": \\"user\\", \\"content\\": \\"Say this is a test!\\"}]}"
        </p>
        """
    # Strip the common source indentation before handing the markup to Flask.
    return dedent(usage_html)


def _build_conversation_history(messages):
    """Convert OpenAI-style messages into an alternating human/model history.

    "user" and "system" roles become "human" turns; every other role becomes
    a "model" turn. Empty filler turns are inserted so that the history starts
    with a human turn, ends with a human turn, and never has two consecutive
    turns of the same type. An empty ``messages`` list yields an empty history.

    Raises:
        KeyError, TypeError: when a message is not a mapping with "role" and
            "content" keys.
    """
    history = []
    for msg in messages:
        turn = {
            "type": "human" if msg["role"] in ("user", "system") else "model",
            "text": msg["content"],
        }
        if not history:
            if turn["type"] != "human":
                history.append({"type": "human", "text": ""})
        elif history[-1]["type"] == turn["type"]:
            filler_type = "model" if turn["type"] == "human" else "human"
            history.append({"type": filler_type, "text": ""})
        history.append(turn)

    if history and history[-1]["type"] != "human":
        history.append({"type": "human", "text": ""})
    return history


def _sse_chunk(created, model, delta, finish_reason, id_seed):
    """Format one ``chat.completion.chunk`` payload as a server-sent event."""
    payload = {
        "id": "chatcmpl-" + "".join([str(time.time()), str(hash(id_seed))]),
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    return f"data: {json.dumps(payload)}\n\n"


@app.route("/hf/v1/chat/completions", methods=["POST", "OPTIONS"])
def chat_completions():
    """Proxy an OpenAI-style chat completion request to chat.reka.ai as a stream.

    OPTIONS requests get a 204 with permissive CORS headers. POST requests
    must carry ``Authorization: Bearer <AUTHKEY>`` and a JSON object body with
    a ``messages`` list (and optionally ``model``, which is only echoed back in
    the response chunks; the upstream request always uses "reka-core").

    Returns:
        A ``text/event-stream`` response of ``chat.completion.chunk`` events
        terminated by ``data: [DONE]``, or an error response: 401 for a bad
        bearer key, 400 for a malformed body, 500 when no access token can be
        obtained, 502 when the upstream request cannot be made, or the
        upstream status and body when upstream answers with a non-200.
    """
    if request.method == "OPTIONS":
        return Response(
            "",
            status=204,
            headers={
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "*",
            },
        )

    # The route only admits POST and OPTIONS on this path, so after the
    # preflight branch the bearer key is the only thing left to verify.
    if request.headers.get("Authorization") != f"Bearer {AUTHKEY}":
        logger.error("Unauthorized access attempt")
        return Response("Unauthorized", status=401)

    # Validate the request before doing any upstream work.
    try:
        request_body = request.json
    except Exception as e:
        logger.error(f"Error parsing JSON body: {e}")
        return Response("Error parsing JSON body", status=400)

    if not isinstance(request_body, dict):
        logger.error("JSON body is not an object")
        return Response("Error parsing JSON body", status=400)

    messages = request_body.get("messages", [])
    model = request_body.get("model", "reka-core")

    if not isinstance(messages, list):
        logger.error("'messages' is not a list")
        return Response("Invalid 'messages': expected a list", status=400)

    try:
        conversation_history = _build_conversation_history(messages)
    except (KeyError, TypeError) as e:
        logger.error(f"Malformed message in request: {e!r}")
        return Response(
            "Invalid 'messages': each message needs 'role' and 'content'",
            status=400,
        )

    access_token = fetch_tokens()
    if not access_token:
        logger.error("Failed to obtain access token")
        return Response("Failed to obtain access token.", status=500)

    new_request_body = {
        "conversation_history": conversation_history,
        "stream": True,
        "use_search_engine": False,
        "use_code_interpreter": False,
        "model_name": "reka-core",
        "random_seed": int(time.time()),
    }

    try:
        response = requests.post(
            "https://chat.reka.ai/api/chat",
            headers={
                "authorization": f"bearer {access_token}",
                "content-type": "application/json",
            },
            data=json.dumps(new_request_body),
            stream=True,
            timeout=600,  # timeout 10 min.
        )
    except requests.RequestException as e:
        logger.error(f"Error contacting external API: {e}")
        return Response("Error contacting upstream API", status=502)

    if response.status_code != 200:
        error_text = response.text
        response.close()
        logger.error(f"Error from external API: {response.status_code} {error_text}")
        return Response(error_text, status=response.status_code)

    created = int(time.time())

    def generate_stream():
        """Yield OpenAI-style SSE chunks derived from the upstream event stream."""
        prev_content = ""
        last_four_texts = []

        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode("utf-8")
                if not line.startswith("data:"):
                    continue

                try:
                    data = json.loads(line[5:])
                except json.JSONDecodeError:
                    continue

                # Ignore events that carry no usable text instead of
                # aborting the stream halfway through.
                text = data.get("text") if isinstance(data, dict) else None
                if not isinstance(text, str):
                    continue

                last_four_texts.append(text)
                if len(last_four_texts) > 4:
                    last_four_texts.pop(0)

                # Skip an event whose text shrank relative to the previous
                # one or that ends in a partial "<sep" marker.
                if len(last_four_texts) == 4 and (
                    len(last_four_texts[3]) < len(last_four_texts[2])
                    or last_four_texts[3].endswith(("<sep", "<"))
                ):
                    continue

                # Each event's text is treated as the full text so far;
                # forward only the part not yet sent.
                new_content = text[len(prev_content) :]
                prev_content = text

                yield _sse_chunk(
                    created, model, {"content": new_content}, None, new_content
                )
        finally:
            # Release the upstream connection even if the client disconnects
            # or the stream fails midway.
            response.close()

        yield _sse_chunk(created, model, {}, "stop", "done")
        yield "data: [DONE]\n\n"

    return Response(
        stream_with_context(generate_stream()),
        headers={"Content-Type": "text/event-stream"},
    )


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)