ddh0 committed on
Commit 57b6bc3 · verified · 1 Parent(s): 0a930e1

Upload tensor_type_testing.py

Files changed (1)
  1. tensor_type_testing.py +128 -0
tensor_type_testing.py ADDED
@@ -0,0 +1,128 @@
+ # tensor_type_testing.py
+ # Python 3.11.2
+
+ import os
+ import sys
+
+ import numpy as np
+ import easy_llama as ez
+
+ from typing import Union
+
+ INPUT_TEXTS_AS_TEXT: list[str] = []
+
+ for i in range(10):
+     with open(f'./inputs/{i}.txt', 'r') as file:
+         INPUT_TEXTS_AS_TEXT.append(file.read())
+
+ BASELINE_MODEL_PATH = '/opt/workspace/gguf/Qwen2.5-14B-BF16.gguf'
+ BASELINE_MODEL_FILENAME = os.path.basename(BASELINE_MODEL_PATH)
+ QUANT_MODEL_DIR = '/opt/workspace/gguf/'
+ QUANT_MODEL_FILES = [
+     'Qwen2.5-14B-Q2_K.gguf',
+     'Qwen2.5-14B-EQ2_K-FQ8_0-AQ8_0-OQ8_0.gguf',
+     'Qwen2.5-14B-EQ8_0-FQ2_K-AQ8_0-OQ8_0.gguf',
+     'Qwen2.5-14B-EQ8_0-FQ8_0-AQ2_K-OQ8_0.gguf',
+     'Qwen2.5-14B-EQ8_0-FQ8_0-AQ8_0-OQ2_K.gguf',
+     'Qwen2.5-14B-Q8_0.gguf'
+ ]
+
+ def msd(a: np.ndarray, b: np.ndarray) -> np.floating:
+     return np.mean((a - b) ** 2)
+
+ def tokenize_prompt(llama: ez.Llama, prompt: str) -> list[int]:
+     return llama.tokenize(
+         text_bytes=prompt.encode('utf-8', 'strict'),
+         add_special=True,
+         parse_special=False
+     )
+
+ def eval_text(llama: ez.Llama, text_toks: list[int]) -> np.ndarray:
+     llama.reset()
+     logits = llama.eval(input_tokens=text_toks, logits_all=True)
+     return logits
+
+ def load_llama(model_file: str) -> ez.Llama:
+     return ez.Llama(
+         path_model=model_file,
+         n_gpu_layers=10,
+         use_mmap=False,
+         use_mlock=False,
+         n_ctx=5120,
+         offload_kqv=True,
+         warmup=False,
+         verbose=False
+     )
+
+ def get_model_results(model_path: str) -> list[np.ndarray]:
+     print('Load model...')
+     Llama = load_llama(model_path)
+     print('Evaluate prompts...')
+     results = [eval_text(Llama, prompt) for prompt in input_texts_as_tokens]
+     print('Unload model...')
+     Llama.free()
+     return results
+
+ def main() -> int:
+
+     global input_texts_as_tokens
+
+     results: dict[str, list[Union[list[np.floating], np.floating]]] = {}
+
+     baseline_llama = load_llama(BASELINE_MODEL_PATH)
+     input_texts_as_tokens = [
+         tokenize_prompt(baseline_llama, text) for text in INPUT_TEXTS_AS_TEXT
+     ]
+     n_inputs = len(input_texts_as_tokens)
+     max_len_input = max(len(toks) for toks in input_texts_as_tokens)
+     min_len_input = min(len(toks) for toks in input_texts_as_tokens)
+     avg_len_input = sum(len(toks) for toks in input_texts_as_tokens) / n_inputs
+     n_input_tokens = sum(len(toks) for toks in input_texts_as_tokens)
+     print(f'          Number of input texts: {len(input_texts_as_tokens)}')
+     print(f'Shortest input length in tokens: {min_len_input}')
+     print(f' Longest input length in tokens: {max_len_input}')
+     print(f' Average input length in tokens: {avg_len_input}')
+     print(f'   Total number of input tokens: {n_input_tokens}')
+     print('-' * 80)
+     baseline_llama.free()
+
+     print(f'Evaluating baseline model {BASELINE_MODEL_FILENAME}...')
+     baseline_results = get_model_results(BASELINE_MODEL_PATH)
+
+     for quant_file in QUANT_MODEL_FILES:
+         quant_path = os.path.join(QUANT_MODEL_DIR, quant_file)
+         if not os.path.exists(quant_path):
+             print(f"Error: {quant_path} not found. Skipping.")
+             continue
+
+         print('-' * 80)
+         print(f'Now processing: {quant_file}')
+         quant_results = get_model_results(quant_path)
+
+         print('Compute MSD...')
+         deviations = [
+             msd(baseline_results[i], quant_results[i]) for i in range(len(quant_results))
+         ]
+         avg = np.mean(deviations)
+
+         results[quant_file] = [deviations, avg]
+
+         print(
+             f'Mean-Squared Deviation - '
+             f'{BASELINE_MODEL_FILENAME} vs. {os.path.basename(quant_path)}:'
+         )
+         for i in range(len(input_texts_as_tokens)):
+             print(f'-- Prompt {i}: {deviations[i]}')
+         print(f'Average MSD: {avg}')
+
+     print('-' * 80)
+     print(f'Average Mean-Squared Deviation compared to {BASELINE_MODEL_FILENAME}:')
+     print('-' * 80)
+     for k, v in results.items():
+         print(f'{k:>60} -- {v[1]}')
+     print('-' * 80)
+
+     return 0
+
+ if __name__ == '__main__':
+     sys.exit(main())