File size: 1,544 Bytes
39b5447
28e12d0
39b5447
 
 
28e12d0
39b5447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28e12d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39b5447
030bd38
39b5447
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import { get } from "svelte/store";
import { OPENAI_API_KEY, HF_ENDPOINT } from "$lib/store";
import { Configuration, OpenAIApi } from "openai";
import { getInference } from "./getInference";
import type { LLM } from "$lib/types";
import type { TextGenerationOutput } from "@huggingface/inference";

/**
 * Sends `prompt` to the OpenAI completions API and returns the generated text.
 *
 * The API key is read from the `OPENAI_API_KEY` store on every call, so a new
 * client is constructed each time. Rejections from the client call are not
 * caught here and propagate to the caller.
 *
 * @param prompt - Text forwarded unchanged as the completion prompt.
 * @returns The text of the first completion choice, or `""` when the response
 *   contains no choices or the first choice has no text.
 */
async function OpenAILLMCall(prompt: string): Promise<string> {
  const openai = new OpenAIApi(
    new Configuration({ apiKey: get(OPENAI_API_KEY) })
  );

  const response = await openai.createCompletion({
    model: "text-davinci-003",
    prompt,
    max_tokens: 1000,
  });

  // `choices` can be empty; optional chaining lets the empty-string fallback
  // apply instead of throwing on `undefined.text`.
  return response.data.choices[0]?.text ?? "";
}

/**
 * Sends `prompt` to a Hugging Face text-generation backend and returns the
 * model's reply.
 *
 * The prompt is wrapped as `<|user|>…<|end|><|assistant|>` before sending.
 * When the `HF_ENDPOINT` store holds a truthy value the request goes to that
 * endpoint (up to 1400 new tokens); otherwise it goes to the
 * `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5` model (up to 900 new
 * tokens). Rejections from the inference client propagate to the caller.
 *
 * @param prompt - User text to wrap in the chat markers.
 * @returns The generated text with the echoed prompt removed when present.
 */
async function HFLLMCall(prompt: string): Promise<string> {
  const formattedPrompt = "<|user|>" + prompt + "<|end|><|assistant|>";

  // Read the store once so the check and the call use the same value.
  const endpoint = get(HF_ENDPOINT);

  let output: TextGenerationOutput;

  if (endpoint) {
    output = await getInference()
      .endpoint(endpoint)
      .textGeneration({
        inputs: formattedPrompt,
        parameters: {
          max_new_tokens: 1400,
        },
      });
  } else {
    output = await getInference().textGeneration({
      inputs: formattedPrompt,
      model: "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
      parameters: {
        max_new_tokens: 900,
      },
    });
  }

  const generated = output.generated_text;

  // Only strip the prompt when the backend actually echoed it; a backend that
  // returns just the continuation would otherwise lose the start of its reply.
  // NOTE(review): whether the prompt is echoed presumably depends on the
  // backend's `return_full_text` behavior — confirm for the configured endpoint.
  return generated.startsWith(formattedPrompt)
    ? generated.slice(formattedPrompt.length)
    : generated;
}

/** `LLM` entry named "OpenAI" whose `call` is `OpenAILLMCall`. */
export const OpenAILLM: LLM = {
  name: "OpenAI",
  call: OpenAILLMCall,
};

/** `LLM` entry named "Hugging Face" whose `call` is `HFLLMCall`. */
export const HFLLM: LLM = {
  name: "Hugging Face",
  call: HFLLMCall,
};

/** The `LLM` entries defined in this module, in order: OpenAI, then Hugging Face. */
export const LLMs = [OpenAILLM, HFLLM];