Commit · 138286c
Parent(s): 5fa40a6

Added framework of support for Azure models (although untested)

Files changed:
- app.py +9 -4
- requirements.txt +7 -1
- requirements_cpu.txt +3 -1
- requirements_gpu.txt +3 -1
- requirements_no_local.txt +3 -1
- tools/aws_functions.py +17 -9
- tools/config.py +34 -3
- tools/dedup_summaries.py +9 -3
- tools/llm_api_call.py +82 -21
- tools/llm_funcs.py +83 -50
app.py CHANGED

@@ -12,7 +12,7 @@ from tools.custom_csvlogger import CSVLogger_custom
 from tools.auth import authenticate_user
 from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, verify_titles_prompt, verify_titles_system_prompt, two_para_summary_format_prompt, single_para_summary_format_prompt
 from tools.verify_titles import verify_titles
-from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, RUN_LOCAL_MODEL, FILE_INPUT_HEIGHT, GEMINI_API_KEY, model_full_names, BATCH_SIZE_DEFAULT, CHOSEN_LOCAL_MODEL_TYPE, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, model_name_map, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES, HF_TOKEN
+from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, RUN_LOCAL_MODEL, FILE_INPUT_HEIGHT, GEMINI_API_KEY, model_full_names, BATCH_SIZE_DEFAULT, CHOSEN_LOCAL_MODEL_TYPE, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, model_name_map, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES, HF_TOKEN, AZURE_API_KEY
 
 def ensure_folder_exists(output_folder:str):
     """Checks if the specified folder exists, creates it if not."""
@@ -307,6 +307,9 @@ with app:
        with gr.Accordion("Gemini API keys", open = False):
            google_api_key_textbox = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")
 
+       with gr.Accordion("Azure AI Inference", open = False):
+           azure_api_key_textbox = gr.Textbox(value = AZURE_API_KEY, label="Enter Azure AI Inference API key (only if using Azure models)", lines=1, type="password")
+
        with gr.Accordion("Hugging Face API keys", open = False):
            hf_api_key_textbox = gr.Textbox(value = HF_TOKEN, label="Enter Hugging Face API key (only if using Hugging Face models)", lines=1, type="password")
 
@@ -369,7 +372,7 @@ with app:
    success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
    success(load_in_data_file,
            inputs = [in_data_files, in_colnames, batch_size_number, in_excel_sheets], outputs = [file_data_state, working_data_file_name_textbox, total_number_of_batches], api_name="load_data").\
-
+   success(fn=wrapper_extract_topics_per_column_value,
           inputs=[in_group_col,
                   in_data_files,
                   file_data_state,
@@ -405,6 +408,7 @@ with app:
                   aws_access_key_textbox,
                   aws_secret_key_textbox,
                   hf_api_key_textbox,
+                  azure_api_key_textbox,
                   output_folder_state],
          outputs=[display_topic_table_markdown,
                   master_topic_df_state,
@@ -467,10 +471,10 @@ with app:
    success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
    success(load_in_data_file,
            inputs = [in_data_files, in_colnames, batch_size_number, in_excel_sheets], outputs = [file_data_state, working_data_file_name_textbox, total_number_of_batches], api_name="load_data").\
-
+   success(fn=wrapper_extract_topics_per_column_value,
           inputs=[in_group_col,
                   in_data_files,
-                  file_data_state,
+                  file_data_state,
                   master_topic_df_state,
                   master_reference_df_state,
                   master_unique_topics_df_state,
@@ -503,6 +507,7 @@ with app:
                   aws_access_key_textbox,
                   aws_secret_key_textbox,
                   hf_api_key_textbox,
+                  azure_api_key_textbox,
                   output_folder_state],
          outputs=[display_topic_table_markdown,
                   master_topic_df_state,
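For reference, a minimal sketch (not the app's real layout; handler and variable names here are illustrative) of the UI pattern this hunk adds: a password textbox inside an accordion whose value is passed to an event handler, mirroring how azure_api_key_textbox is threaded into the extraction chain above.

import gradio as gr

def run_extraction(azure_api_key: str) -> str:
    # Placeholder handler; the real app forwards the key to extract_topics via the event inputs.
    return "Azure key provided" if azure_api_key else "No Azure key provided"

with gr.Blocks() as demo:
    with gr.Accordion("Azure AI Inference", open=False):
        azure_api_key_textbox = gr.Textbox(
            label="Enter Azure AI Inference API key (only if using Azure models)",
            lines=1,
            type="password",
        )
    run_btn = gr.Button("Run")
    status = gr.Textbox(label="Status")
    # The textbox is listed in inputs so its current value reaches the handler on click.
    run_btn.click(fn=run_extraction, inputs=[azure_api_key_textbox], outputs=[status])

if __name__ == "__main__":
    demo.launch()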
requirements.txt CHANGED

@@ -9,7 +9,9 @@ openpyxl==3.1.5
 markdown==3.7
 tabulate==0.9.0
 lxml==5.3.0
-google-genai==1.
+google-genai==1.33.0
+azure-ai-inference==1.0.0b9
+azure-core==1.35.0
 html5lib==1.1
 beautifulsoup4==4.12.3
 rapidfuzz==3.13.0
@@ -24,5 +26,9 @@ accelerate==1.10.1
 #torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/cpu
 # For Hugging Face, need a python 3.10 compatible wheel for llama-cpp-python to avoid build timeouts
 #https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp310-cp310-linux_x86_64.whl
+# CPU only (for e.g. Hugging Face CPU instances)
+#torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/cpu
+# For Hugging Face, need a python 3.10 compatible wheel for llama-cpp-python to avoid build timeouts
+#https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp310-cp310-linux_x86_64.whl
 
 
requirements_cpu.txt CHANGED

@@ -8,7 +8,9 @@ openpyxl==3.1.5
 markdown==3.7
 tabulate==0.9.0
 lxml==5.3.0
-google-genai==1.
+google-genai==1.33.0
+azure-ai-inference==1.0.0b9
+azure-core==1.35.0
 html5lib==1.1
 beautifulsoup4==4.12.3
 rapidfuzz==3.13.0
requirements_gpu.txt CHANGED

@@ -8,7 +8,9 @@ openpyxl==3.1.5
 markdown==3.7
 tabulate==0.9.0
 lxml==5.3.0
-google-genai==1.
+google-genai==1.33.0
+azure-ai-inference==1.0.0b9
+azure-core==1.35.0
 html5lib==1.1
 beautifulsoup4==4.12.3
 rapidfuzz==3.13.0
requirements_no_local.txt CHANGED

@@ -9,7 +9,9 @@ openpyxl==3.1.5
 markdown==3.7
 tabulate==0.9.0
 lxml==5.3.0
-google-genai==1.
+google-genai==1.33.0
+azure-ai-inference==1.0.0b9
+azure-core==1.35.0
 html5lib==1.1
 beautifulsoup4==4.12.3
 rapidfuzz==3.13.0
tools/aws_functions.py CHANGED

@@ -11,34 +11,42 @@ def connect_to_bedrock_runtime(model_name_map:dict, model_choice:str, aws_access
     # If running an anthropic model, assume that running an AWS Bedrock model, load in Bedrock
     model_source = model_name_map[model_choice]["source"]
 
-    if "AWS" in model_source:
-        if
+    if "AWS" in model_source:
+        if RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
+            print("Connecting to Bedrock via existing SSO connection")
+            bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
+        elif RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
+            print("Connecting to Bedrock via existing SSO connection")
+            bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
+        elif aws_access_key_textbox and aws_secret_key_textbox:
            print("Connecting to Bedrock using AWS access key and secret keys from user input.")
            bedrock_runtime = boto3.client('bedrock-runtime',
                            aws_access_key_id=aws_access_key_textbox,
                            aws_secret_access_key=aws_secret_key_textbox, region_name=AWS_REGION)
-        elif RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
-            print("Connecting to Bedrock via existing SSO connection")
-            bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
        elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
            print("Getting Bedrock credentials from environment variables")
            bedrock_runtime = boto3.client('bedrock-runtime',
                            aws_access_key_id=AWS_ACCESS_KEY,
                            aws_secret_access_key=AWS_SECRET_KEY,
-                            region_name=AWS_REGION)
+                            region_name=AWS_REGION)
+        elif RUN_AWS_FUNCTIONS == "1":
+            print("Connecting to Bedrock via existing SSO connection")
+            bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
        else:
            bedrock_runtime = ""
            out_message = "Cannot connect to AWS Bedrock service. Please provide access keys under LLM settings, or choose another model type."
            print(out_message)
            raise Exception(out_message)
    else:
-        bedrock_runtime =
+        bedrock_runtime = list()
+
+    print("Bedrock runtime connected:", bedrock_runtime)
 
    return bedrock_runtime
 
 def connect_to_s3_client(aws_access_key_textbox:str="", aws_secret_key_textbox:str=""):
    # If running an anthropic model, assume that running an AWS s3 model, load in s3
-    s3_client =
+    s3_client = list()
 
    if aws_access_key_textbox and aws_secret_key_textbox:
        print("Connecting to s3 using AWS access key and secret keys from user input.")
@@ -148,7 +156,7 @@ def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=buck
    """
    if RUN_AWS_FUNCTIONS == "1":
 
-        final_out_message =
+        final_out_message = list()
 
        s3_client = connect_to_s3_client(aws_access_key_textbox, aws_secret_key_textbox)
        #boto3.client('s3')
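A hedged sketch of the credential-selection order the revised connect_to_bedrock_runtime appears to implement: SSO/default session first when configured, then keys typed into the UI, then environment keys, then a plain default-credentials client. Names such as RUN_AWS_FUNCTIONS and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS come from tools/config.py; the values below are illustrative only.

import boto3

AWS_REGION = "eu-west-2"                       # illustrative
RUN_AWS_FUNCTIONS = "1"                        # illustrative
PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS = "1"  # illustrative
AWS_ACCESS_KEY = ""                            # illustrative
AWS_SECRET_KEY = ""                            # illustrative

def pick_bedrock_client(user_key: str = "", user_secret: str = ""):
    # 1. Prefer an existing SSO/default session when told to prioritise it
    if RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
        return boto3.client("bedrock-runtime", region_name=AWS_REGION)
    # 2. Otherwise fall back to keys supplied through the UI
    if user_key and user_secret:
        return boto3.client("bedrock-runtime",
                            aws_access_key_id=user_key,
                            aws_secret_access_key=user_secret,
                            region_name=AWS_REGION)
    # 3. Then environment-supplied keys
    if AWS_ACCESS_KEY and AWS_SECRET_KEY:
        return boto3.client("bedrock-runtime",
                            aws_access_key_id=AWS_ACCESS_KEY,
                            aws_secret_access_key=AWS_SECRET_KEY,
                            region_name=AWS_REGION)
    # 4. Finally, a default-credentials client when AWS functions are enabled at all
    if RUN_AWS_FUNCTIONS == "1":
        return boto3.client("bedrock-runtime", region_name=AWS_REGION)
    raise Exception("Cannot connect to AWS Bedrock service. Please provide access keys under LLM settings, or choose another model type.")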
tools/config.py CHANGED

@@ -206,6 +206,33 @@ RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
 RUN_AWS_BEDROCK_MODELS = get_or_create_env_var("RUN_AWS_BEDROCK_MODELS", "1")
 GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
 
+# Build up options for models
+###
+# LLM variables
+###
+
+MAX_TOKENS = int(get_or_create_env_var('MAX_TOKENS', '4096')) # Maximum number of output tokens
+TIMEOUT_WAIT = int(get_or_create_env_var('TIMEOUT_WAIT', '30')) # AWS now seems to have a 60 second minimum wait between API calls
+NUMBER_OF_RETRY_ATTEMPTS = int(get_or_create_env_var('NUMBER_OF_RETRY_ATTEMPTS', '5'))
+# Try up to 3 times to get a valid markdown table response with LLM calls, otherwise retry with temperature changed
+MAX_OUTPUT_VALIDATION_ATTEMPTS = int(get_or_create_env_var('MAX_OUTPUT_VALIDATION_ATTEMPTS', '3'))
+MAX_TIME_FOR_LOOP = int(get_or_create_env_var('MAX_TIME_FOR_LOOP', '99999'))
+BATCH_SIZE_DEFAULT = int(get_or_create_env_var('BATCH_SIZE_DEFAULT', '5'))
+DEDUPLICATION_THRESHOLD = int(get_or_create_env_var('DEDUPLICATION_THRESHOLD', '90'))
+MAX_COMMENT_CHARS = int(get_or_create_env_var('MAX_COMMENT_CHARS', '14000'))
+
+RUN_LOCAL_MODEL = get_or_create_env_var("RUN_LOCAL_MODEL", "1")
+
+RUN_AWS_BEDROCK_MODELS = get_or_create_env_var("RUN_AWS_BEDROCK_MODELS", "1")
+
+RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
+GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
+
+# Azure AI Inference settings
+RUN_AZURE_MODELS = get_or_create_env_var("RUN_AZURE_MODELS", "0")
+AZURE_API_KEY = get_or_create_env_var('AZURE_API_KEY', '')
+AZURE_INFERENCE_ENDPOINT = get_or_create_env_var('AZURE_INFERENCE_ENDPOINT', '')
+
 # Build up options for models
 
 model_full_names = list()
@@ -225,12 +252,16 @@ if RUN_AWS_BEDROCK_MODELS == "1":
     model_source.extend(["AWS", "AWS", "AWS", "AWS", "AWS"])
 
 if RUN_GEMINI_MODELS == "1":
-    model_full_names.extend(["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.5-pro"])
+    model_full_names.extend(["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.5-pro"])
     model_short_names.extend(["gemini_flash_lite_2.5", "gemini_flash_2.5", "gemini_pro"])
     model_source.extend(["Gemini", "Gemini", "Gemini"])
 
-#
-
+# Register Azure AI models (model names must match your Azure deployments)
+if RUN_AZURE_MODELS == "1":
+    # Example deployments; adjust to the deployments you actually create in Azure
+    model_full_names.extend(["gpt-5-mini"])
+    model_short_names.extend(["gpt-5-mini"])
+    model_source.extend(["Azure"])
 
 model_name_map = {
     full: {"short_name": short, "source": source}
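A hedged sketch of how the new Azure options could be switched on before tools/config.py is imported. The get_or_create_env_var helper below is a stand-in for the one in tools/config.py (assumed to return the env value or set the supplied default), and the key/endpoint values are placeholders, not real credentials.

import os

def get_or_create_env_var(var_name: str, default_value: str) -> str:
    # Stand-in for the helper in tools/config.py
    value = os.environ.get(var_name)
    if value is None:
        os.environ[var_name] = default_value
        value = default_value
    return value

# Placeholder values for illustration only
os.environ["RUN_AZURE_MODELS"] = "1"
os.environ["AZURE_API_KEY"] = "<your-azure-key>"
os.environ["AZURE_INFERENCE_ENDPOINT"] = "https://<your-resource>.services.ai.azure.com/models"

RUN_AZURE_MODELS = get_or_create_env_var("RUN_AZURE_MODELS", "0")
AZURE_API_KEY = get_or_create_env_var("AZURE_API_KEY", "")
AZURE_INFERENCE_ENDPOINT = get_or_create_env_var("AZURE_INFERENCE_ENDPOINT", "")

model_full_names, model_short_names, model_source = list(), list(), list()
if RUN_AZURE_MODELS == "1":
    # Deployment names must match what exists in your Azure resource (the diff uses "gpt-5-mini" as an example)
    model_full_names.extend(["gpt-5-mini"])
    model_short_names.extend(["gpt-5-mini"])
    model_source.extend(["Azure"])

model_name_map = {full: {"short_name": short, "source": source}
                  for full, short, source in zip(model_full_names, model_short_names, model_source)}
print(model_name_map)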
tools/dedup_summaries.py CHANGED

@@ -8,12 +8,13 @@ import time
 import markdown
 import boto3
 from tqdm import tqdm
+import os
 
 from tools.prompts import summarise_topic_descriptions_prompt, summarise_topic_descriptions_system_prompt, system_prompt, summarise_everything_prompt, comprehensive_summary_format_prompt, summarise_everything_system_prompt, comprehensive_summary_format_prompt_by_group, summary_assistant_prefill
-from tools.llm_funcs import construct_gemini_generative_model, process_requests, ResponseObject, load_model, calculate_tokens_from_metadata
+from tools.llm_funcs import construct_gemini_generative_model, process_requests, ResponseObject, load_model, calculate_tokens_from_metadata, construct_azure_client
 from tools.helper_functions import create_topic_summary_df_from_reference_table, load_in_data_file, get_basic_response_data, convert_reference_table_to_pivot_table, wrap_text, clean_column_name, get_file_name_no_ext, create_batch_file_path_details
-from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, REASONING_SUFFIX
 from tools.aws_functions import connect_to_bedrock_runtime
+from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT
 
 max_tokens = MAX_TOKENS
 timeout_wait = TIMEOUT_WAIT
@@ -440,10 +441,13 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
     google_client = list()
     google_config = {}
 
-    # Prepare Gemini models before query
+    # Prepare Gemini models before query
     if "Gemini" in model_source:
         #print("Using Gemini model:", model_choice)
         google_client, config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=system_prompt, max_tokens=max_tokens)
+    elif "Azure" in model_source:
+        # Azure client (endpoint from env/config)
+        google_client, config = construct_azure_client(in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""), endpoint=AZURE_INFERENCE_ENDPOINT)
     elif "Local" in model_source:
         pass
         #print("Using local model: ", model_choice)
@@ -594,6 +598,8 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
     if "Local" in model_source and reasoning_suffix: formatted_summarise_topic_descriptions_system_prompt = formatted_summarise_topic_descriptions_system_prompt + "\n" + reasoning_suffix
 
     try:
+        print("formatted_summarise_topic_descriptions_system_prompt:", formatted_summarise_topic_descriptions_system_prompt)
+
         response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_topic_descriptions_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer)
         summarised_output = response
         summarised_output = re.sub(r'\n{2,}', '\n', summarised_output) # Replace multiple line breaks with a single line break
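A hedged sketch of the source-based client dispatch that summarise_output_topics_query now performs. The two construct_* functions below are placeholders standing in for the real helpers in tools/llm_funcs.py; only the branching logic is the point.

import os
from typing import Tuple

def construct_gemini_generative_model(in_api_key: str, **kwargs) -> Tuple[object, dict]:
    return None, {}   # placeholder for the real Gemini client builder

def construct_azure_client(in_api_key: str, endpoint: str) -> Tuple[object, dict]:
    return None, {}   # placeholder for the real Azure client builder

def prepare_client(model_source: str, in_api_key: str, azure_endpoint: str):
    client, config = list(), {}
    if "Gemini" in model_source:
        client, config = construct_gemini_generative_model(in_api_key=in_api_key)
    elif "Azure" in model_source:
        # The Azure key is read from the environment if it was not passed through the UI
        client, config = construct_azure_client(
            in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""),
            endpoint=azure_endpoint)
    elif "Local" in model_source:
        pass  # the local llama.cpp / transformers model is loaded elsewhere
    return client, config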
tools/llm_api_call.py CHANGED

@@ -16,8 +16,8 @@ GradioFileData = gr.FileData
 
 from tools.prompts import initial_table_prompt, prompt2, prompt3, initial_table_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt, add_existing_topics_assistant_prefill, initial_table_assistant_prefill, structured_summary_prompt
 from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files, create_batch_file_path_details
-from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata
-from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX
+from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata, construct_azure_client
+from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT
 from tools.aws_functions import connect_to_bedrock_runtime
 
 if RUN_LOCAL_MODEL == "1":
@@ -477,8 +477,6 @@ def write_llm_output_and_logs(response_text: str,
         new_reference_df = pd.DataFrame(reference_data)
     else:
         new_reference_df = pd.DataFrame(columns=["Response References", "General topic", "Subtopic", "Sentiment", "Summary", "Start row of group"])
-
-    print("new_reference_df:", new_reference_df)
 
     # Append on old reference data
     if not new_reference_df.empty:
@@ -689,6 +687,7 @@ def extract_topics(in_data_file: GradioFileData,
                    aws_access_key_textbox:str='',
                    aws_secret_key_textbox:str='',
                    hf_api_key_textbox:str='',
+                   azure_api_key_textbox:str='',
                    max_tokens:int=max_tokens,
                    model_name_map:dict=model_name_map,
                    max_time_for_loop:int=max_time_for_loop,
@@ -708,7 +707,7 @@ def extract_topics(in_data_file: GradioFileData,
    - unique_table_df_display_table_markdown (str): Table for display in markdown format.
    - file_name (str): File name of the data file.
    - num_batches (int): Number of batches required to go through all the response rows.
-   - in_api_key (str): The API key for authentication.
+   - in_api_key (str): The API key for authentication (Google Gemini).
    - temperature (float): The temperature parameter for the model.
    - chosen_cols (List[str]): A list of chosen columns to process.
    - candidate_topics (gr.FileData): A Gradio FileData object of existing candidate topics submitted by the user.
@@ -843,7 +842,7 @@ def extract_topics(in_data_file: GradioFileData,
 
     for i in topics_loop:
         reported_batch_no = latest_batch_completed + 1
-        print("Running batch:", reported_batch_no)
+        print("Running response batch:", reported_batch_no)
 
         # Call the function to prepare the input table
         simplified_csv_table_path, normalised_simple_markdown_table, start_row, end_row, batch_basic_response_df = data_file_to_markdown_table(file_data, file_name, chosen_cols, latest_batch_completed, batch_size)
@@ -859,10 +858,16 @@ def extract_topics(in_data_file: GradioFileData,
 
             formatted_system_prompt = add_existing_topics_system_prompt.format(consultation_context=context_textbox, column_name=chosen_cols)
 
-            # Prepare
+            # Prepare clients before query
             if "Gemini" in model_source:
                 print("Using Gemini model:", model_choice)
                 google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_system_prompt, max_tokens=max_tokens)
+            elif "Azure" in model_source:
+                print("Using Azure AI Inference model:", model_choice)
+                # If provided, set env for downstream calls too
+                if azure_api_key_textbox:
+                    os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
+                google_client, google_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=AZURE_INFERENCE_ENDPOINT)
             elif "anthropic.claude" in model_choice:
                 print("Using AWS Bedrock model:", model_choice)
             else:
@@ -1034,6 +1039,11 @@ def extract_topics(in_data_file: GradioFileData,
             if model_source == "Gemini":
                 print("Using Gemini model:", model_choice)
                 google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_initial_table_system_prompt, max_tokens=max_tokens)
+            elif model_source == "Azure":
+                print("Using Azure AI Inference model:", model_choice)
+                if azure_api_key_textbox:
+                    os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
+                google_client, google_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=AZURE_INFERENCE_ENDPOINT)
             elif model_choice == CHOSEN_LOCAL_MODEL_TYPE:
                 print("Using local model:", model_choice)
             else:
@@ -1118,7 +1128,7 @@ def extract_topics(in_data_file: GradioFileData,
 
         # Increase latest file completed count unless we are over the last batch number
         if latest_batch_completed <= num_batches:
-            print("Completed batch number:", str(reported_batch_no))
+            #print("Completed batch number:", str(reported_batch_no))
             latest_batch_completed += 1
 
         toc = time.perf_counter()
@@ -1242,19 +1252,16 @@ def wrapper_extract_topics_per_column_value(
    initial_existing_topic_summary_df: pd.DataFrame,
    initial_unique_table_df_display_table_markdown: str,
    original_file_name: str, # Original file name, to be modified per segment
-   # Initial state parameters (wrapper will use these for the very first call)
    total_number_of_batches:int,
    in_api_key: str,
    temperature: float,
    chosen_cols: List[str],
    model_choice: str,
    candidate_topics: GradioFileData = None,
-
    initial_first_loop_state: bool = True,
    initial_whole_conversation_metadata_str: str = '',
    initial_latest_batch_completed: int = 0,
    initial_time_taken: float = 0,
-
    initial_table_prompt: str = initial_table_prompt,
    prompt2: str = prompt2,
    prompt3: str = prompt3,
@@ -1273,14 +1280,67 @@ def wrapper_extract_topics_per_column_value(
    aws_access_key_textbox:str="",
    aws_secret_key_textbox:str="",
    hf_api_key_textbox:str="",
+   azure_api_key_textbox:str="",
    output_folder: str = OUTPUT_FOLDER,
    force_single_topic_prompt: str = force_single_topic_prompt,
    max_tokens: int = max_tokens,
    model_name_map: dict = model_name_map,
    max_time_for_loop: int = max_time_for_loop, # This applies per call to extract_topics
+   reasoning_suffix: str = reasoning_suffix,
    CHOSEN_LOCAL_MODEL_TYPE: str = CHOSEN_LOCAL_MODEL_TYPE,
    progress=Progress(track_tqdm=True) # type: ignore
    ) -> Tuple: # Mimicking the return tuple structure of extract_topics
+    """
+    A wrapper function that iterates through unique values in a specified grouping column
+    and calls the `extract_topics` function for each segment of the data.
+    It accumulates results from each call and returns a consolidated output.
+
+    :param grouping_col: The name of the column to group the data by.
+    :param in_data_file: The input data file object (e.g., Gradio FileData).
+    :param file_data: The full DataFrame containing all data.
+    :param initial_existing_topics_table: Initial DataFrame of existing topics.
+    :param initial_existing_reference_df: Initial DataFrame mapping responses to topics.
+    :param initial_existing_topic_summary_df: Initial DataFrame summarizing topics.
+    :param initial_unique_table_df_display_table_markdown: Initial markdown string for topic display.
+    :param original_file_name: The original name of the input file.
+    :param total_number_of_batches: The total number of batches across all data.
+    :param in_api_key: API key for the chosen LLM.
+    :param temperature: Temperature setting for the LLM.
+    :param chosen_cols: List of columns from `file_data` to be processed.
+    :param model_choice: The chosen LLM model (e.g., "Gemini", "AWS Claude").
+    :param candidate_topics: Optional Gradio FileData for candidate topics (zero-shot).
+    :param initial_first_loop_state: Boolean indicating if this is the very first loop iteration.
+    :param initial_whole_conversation_metadata_str: Initial metadata string for the whole conversation.
+    :param initial_latest_batch_completed: The batch number completed in the previous run.
+    :param initial_time_taken: Initial time taken for processing.
+    :param initial_table_prompt: The initial prompt for table summarization.
+    :param prompt2: The second prompt for LLM interaction.
+    :param prompt3: The third prompt for LLM interaction.
+    :param initial_table_system_prompt: The initial system prompt for table summarization.
+    :param add_existing_topics_system_prompt: System prompt for adding existing topics.
+    :param add_existing_topics_prompt: Prompt for adding existing topics.
+    :param number_of_prompts_used: Number of prompts used in the LLM call.
+    :param batch_size: Number of rows to process in each batch for the LLM.
+    :param context_textbox: Additional context provided by the user.
+    :param sentiment_checkbox: Choice for sentiment assessment (e.g., "Negative, Neutral, or Positive").
+    :param force_zero_shot_radio: Option to force responses into zero-shot topics.
+    :param in_excel_sheets: List of Excel sheet names if applicable.
+    :param force_single_topic_radio: Option to force a single topic per response.
+    :param produce_structures_summary_radio: Option to produce a structured summary.
+    :param aws_access_key_textbox: AWS access key for Bedrock.
+    :param aws_secret_key_textbox: AWS secret key for Bedrock.
+    :param hf_api_key_textbox: Hugging Face API key for local models.
+    :param azure_api_key_textbox: Azure API key for Azure AI Inference.
+    :param output_folder: The folder where output files will be saved.
+    :param force_single_topic_prompt: Prompt for forcing a single topic.
+    :param max_tokens: Maximum tokens for LLM generation.
+    :param model_name_map: Dictionary mapping model names to their properties.
+    :param max_time_for_loop: Maximum time allowed for the processing loop.
+    :param reasoning_suffix: Suffix to append for reasoning.
+    :param CHOSEN_LOCAL_MODEL_TYPE: Type of local model chosen.
+    :param progress: Gradio Progress object for tracking progress.
+    :return: A tuple containing consolidated results, mimicking the return structure of `extract_topics`.
+    """
 
    acc_input_tokens = 0
    acc_output_tokens = 0
@@ -1380,7 +1440,7 @@ def wrapper_extract_topics_per_column_value(
            seg_join_files,
            seg_reference_df_pivot,
            seg_missing_df
-
+           ) = extract_topics(
            in_data_file=in_data_file,
            file_data=filtered_file_data,
            existing_topics_table=pd.DataFrame(), #acc_topics_table.copy(), # Pass the accumulated table
@@ -1389,19 +1449,17 @@ def wrapper_extract_topics_per_column_value(
            unique_table_df_display_table_markdown="", # extract_topics will generate this
            file_name=segment_file_name,
            num_batches=current_num_batches,
-           latest_batch_completed=current_latest_batch_completed, # Reset for each new segment's internal batching
-           first_loop_state=current_first_loop_state, # True only for the very first iteration of wrapper
-           out_message= list(), # Fresh for each call
-           out_file_paths= list(),# Fresh for each call
-           log_files_output_paths= list(),# Fresh for each call
-           whole_conversation_metadata_str="", # Fresh for each call
-           time_taken=0, # Time taken for this specific call, wrapper sums it.
-           # Pass through other parameters
            in_api_key=in_api_key,
            temperature=temperature,
            chosen_cols=chosen_cols,
            model_choice=model_choice,
            candidate_topics=candidate_topics,
+           latest_batch_completed=current_latest_batch_completed, # Reset for each new segment's internal batching
+           out_message= list(), # Fresh for each call
+           out_file_paths= list(),# Fresh for each call
+           log_files_output_paths= list(),# Fresh for each call
+           first_loop_state=current_first_loop_state, # True only for the very first iteration of wrapper
+           whole_conversation_metadata_str="", # Fresh for each call
            initial_table_prompt=initial_table_prompt,
            prompt2=prompt2,
            prompt3=prompt3,
@@ -1411,6 +1469,7 @@ def wrapper_extract_topics_per_column_value(
            number_of_prompts_used=number_of_prompts_used,
            batch_size=batch_size,
            context_textbox=context_textbox,
+           time_taken=0, # Time taken for this specific call, wrapper sums it.
            sentiment_checkbox=sentiment_checkbox,
            force_zero_shot_radio=force_zero_shot_radio,
            in_excel_sheets=in_excel_sheets,
@@ -1422,11 +1481,13 @@ def wrapper_extract_topics_per_column_value(
            aws_access_key_textbox=aws_access_key_textbox,
            aws_secret_key_textbox=aws_secret_key_textbox,
            hf_api_key_textbox=hf_api_key_textbox,
+           azure_api_key_textbox=azure_api_key_textbox,
            max_tokens=max_tokens,
            model_name_map=model_name_map,
            max_time_for_loop=max_time_for_loop,
            CHOSEN_LOCAL_MODEL_TYPE=CHOSEN_LOCAL_MODEL_TYPE,
-
+           reasoning_suffix=reasoning_suffix,
+           progress=progress
            )
 
        # Aggregate results
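A hedged sketch of the per-group pattern the new wrapper_extract_topics_per_column_value docstring describes: filter the data once per unique value of the grouping column, run the extraction step on each segment, and accumulate the results. extract_topics_stub and the sample frame below are placeholders, not the real extract_topics signature.

import pandas as pd

def extract_topics_stub(segment: pd.DataFrame, chosen_col: str) -> pd.DataFrame:
    # Placeholder for the real extract_topics call made per segment
    return pd.DataFrame({"Group": [segment["Group"].iloc[0]], "Responses": [len(segment)]})

file_data = pd.DataFrame({
    "Group": ["A", "A", "B"],
    "Response": ["Good service", "Too slow", "Friendly staff"],
})

results = list()
for group_value in file_data["Group"].unique():
    # Each segment gets a fresh call, mirroring the "Fresh for each call" arguments in the diff
    filtered_file_data = file_data[file_data["Group"] == group_value]
    results.append(extract_topics_stub(filtered_file_data, chosen_col="Response"))

# Aggregate results across segments, as the wrapper does
summary = pd.concat(results, ignore_index=True)
print(summary)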
tools/llm_funcs.py
CHANGED
|
@@ -13,12 +13,16 @@ from google.genai import types
|
|
| 13 |
import gradio as gr
|
| 14 |
from gradio import Progress
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
model_type = None # global variable setup
|
| 17 |
full_text = "" # Define dummy source text (full text) just to enable highlight function to load
|
| 18 |
model = list() # Define empty list for model functions to run
|
| 19 |
tokenizer = list() #[] # Define empty list for model functions to run
|
| 20 |
|
| 21 |
-
from tools.config import AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_MIN_P, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS, SPECULATIVE_DECODING, NUM_PRED_TOKENS, USE_LLAMA_CPP, COMPILE_MODE, MODEL_DTYPE, USE_BITSANDBYTES, COMPILE_TRANSFORMERS, INT8_WITH_OFFLOAD_TO_CPU
|
| 22 |
from tools.prompts import initial_table_assistant_prefill
|
| 23 |
|
| 24 |
if SPECULATIVE_DECODING == "True": SPECULATIVE_DECODING = True
|
|
@@ -500,16 +504,7 @@ def llama_cpp_streaming(history, full_prompt, temperature=temperature):
|
|
| 500 |
def construct_gemini_generative_model(in_api_key: str, temperature: float, model_choice: str, system_prompt: str, max_tokens: int, random_seed=seed) -> Tuple[object, dict]:
|
| 501 |
"""
|
| 502 |
Constructs a GenerativeModel for Gemini API calls.
|
| 503 |
-
|
| 504 |
-
Parameters:
|
| 505 |
-
- in_api_key (str): The API key for authentication.
|
| 506 |
-
- temperature (float): The temperature parameter for the model, controlling the randomness of the output.
|
| 507 |
-
- model_choice (str): The choice of model to use for generation.
|
| 508 |
-
- system_prompt (str): The system prompt to guide the generation.
|
| 509 |
-
- max_tokens (int): The maximum number of tokens to generate.
|
| 510 |
-
|
| 511 |
-
Returns:
|
| 512 |
-
- Tuple[object, dict]: A tuple containing the constructed GenerativeModel and its configuration.
|
| 513 |
"""
|
| 514 |
# Construct a GenerativeModel
|
| 515 |
try:
|
|
@@ -532,6 +527,31 @@ def construct_gemini_generative_model(in_api_key: str, temperature: float, model
|
|
| 532 |
|
| 533 |
return client, config
|
| 534 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice:str, bedrock_runtime:boto3.Session.client, assistant_prefill:str="") -> ResponseObject:
|
| 536 |
"""
|
| 537 |
This function sends a request to AWS Claude with the following parameters:
|
|
@@ -667,15 +687,10 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
|
|
| 667 |
duration = end_time - start_time
|
| 668 |
tokens_per_second = num_generated_tokens / duration
|
| 669 |
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
# print(f"Assistant's Reply: {assistant_reply}")
|
| 675 |
-
# print("\n--- Performance ---")
|
| 676 |
-
# print(f"Time taken: {duration:.2f} seconds")
|
| 677 |
-
# print(f"Generated tokens: {num_generated_tokens}")
|
| 678 |
-
# print(f"Tokens per second: {tokens_per_second:.2f}")
|
| 679 |
|
| 680 |
return assistant_reply, num_input_tokens, num_generated_tokens
|
| 681 |
|
|
@@ -725,6 +740,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
|
|
| 725 |
|
| 726 |
if i == number_of_api_retry_attempts:
|
| 727 |
return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
|
|
|
|
| 728 |
elif "AWS" in model_source:
|
| 729 |
for i in progress_bar:
|
| 730 |
try:
|
|
@@ -740,6 +756,35 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
|
|
| 740 |
|
| 741 |
if i == number_of_api_retry_attempts:
|
| 742 |
return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 743 |
elif "Local" in model_source:
|
| 744 |
# This is the local model
|
| 745 |
for i in progress_bar:
|
|
@@ -776,28 +821,29 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
|
|
| 776 |
# Check if is a LLama.cpp model response
|
| 777 |
if isinstance(response, ResponseObject):
|
| 778 |
response_text = response.text
|
| 779 |
-
conversation_history.append({'role': 'assistant', 'parts': [response_text]})
|
| 780 |
elif 'choices' in response: # LLama.cpp model response
|
| 781 |
if "gpt-oss" in model_choice:
|
| 782 |
response_text = response['choices'][0]['message']['content'].split('<|start|>assistant<|channel|>final<|message|>')[1]
|
| 783 |
else:
|
| 784 |
response_text = response['choices'][0]['message']['content']
|
| 785 |
response_text = response_text.strip()
|
| 786 |
-
conversation_history.append({'role': 'assistant', 'parts': [response_text]}) #response['choices'][0]['text']]})
|
| 787 |
elif model_source == "Gemini":
|
| 788 |
response_text = response.text
|
| 789 |
response_text = response_text.strip()
|
| 790 |
-
conversation_history.append({'role': 'assistant', 'parts': [response_text]})
|
| 791 |
else: # Assume transformers model response
|
| 792 |
if "gpt-oss" in model_choice:
|
| 793 |
response_text = response.split('<|start|>assistant<|channel|>final<|message|>')[1]
|
| 794 |
else:
|
| 795 |
response_text = response
|
| 796 |
-
|
|
|
|
| 797 |
|
| 798 |
return response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens
|
| 799 |
|
| 800 |
-
def process_requests(prompts: List[str],
|
|
|
|
|
|
|
|
|
|
| 801 |
"""
|
| 802 |
Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
|
| 803 |
|
|
@@ -836,28 +882,27 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
|
|
| 836 |
whole_conversation.append(response_text)
|
| 837 |
|
| 838 |
# Create conversation metadata
|
| 839 |
-
if master == False:
|
| 840 |
-
|
| 841 |
-
else:
|
| 842 |
-
|
| 843 |
-
|
|
|
|
| 844 |
|
| 845 |
# if not isinstance(response, str):
|
| 846 |
try:
|
| 847 |
if "AWS" in model_source:
|
| 848 |
-
#print("Extracting usage metadata from Converse API response...")
|
| 849 |
-
|
| 850 |
-
# Using .get() is safer than direct access, in case a key is missing.
|
| 851 |
output_tokens = response.usage_metadata.get('outputTokens', 0)
|
| 852 |
input_tokens = response.usage_metadata.get('inputTokens', 0)
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
elif "Gemini" in model_source:
|
| 857 |
-
|
| 858 |
output_tokens = response.usage_metadata.candidates_token_count
|
| 859 |
input_tokens = response.usage_metadata.prompt_token_count
|
| 860 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 861 |
elif "Local" in model_source:
|
| 862 |
if USE_LLAMA_CPP == "True":
|
| 863 |
output_tokens = response['usage'].get('completion_tokens', 0)
|
|
@@ -1012,20 +1057,8 @@ def calculate_tokens_from_metadata(metadata_string:str, model_choice:str, model_
| 1012 |
| 1013 |     # Regex to find the numbers following the keys in the "Query summary metadata" section
| 1014 |     # This ensures we get the final, aggregated totals for the whole query.
| 1015 | -   #if "Gemini" in model_source:
| 1016 |     input_regex = r"input_tokens: (\d+)"
| 1017 |     output_regex = r"output_tokens: (\d+)"
| 1018 | -   # elif "AWS" in model_source:
| 1019 | -   # input_regex = r"inputTokens: (\d+)"
| 1020 | -   # output_regex = r"outputTokens: (\d+)"
| 1021 | -   # elif "Local" in model_source:
| 1022 | -   # print("Local model source")
| 1023 | -   # input_regex = r"\'prompt_tokens\': (\d+)"
| 1024 | -   # output_regex = r"\'completion_tokens\': (\d+)"
| 1025 | -
| 1026 | -   #print("Metadata string:", metadata_string)
| 1027 | -   #print("Input regex:", input_regex)
| 1028 | -   #print("Output regex:", output_regex)
| 1029 |
| 1030 |     # re.findall returns a list of all matching strings (the captured groups).
| 1031 |     input_token_strings = re.findall(input_regex, metadata_string)
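A short sketch of how the two retained patterns behave on a metadata string. The string below is invented for illustration, and summing the matches is an assumption about how the aggregated totals are used downstream.

import re

metadata_string = "Batch 1: input_tokens: 1200 output_tokens: 350 Batch 2: input_tokens: 900 output_tokens: 410"
input_regex = r"input_tokens: (\d+)"
output_regex = r"output_tokens: (\d+)"
# re.findall returns the captured digit groups as strings; convert and sum them.
total_input = sum(int(n) for n in re.findall(input_regex, metadata_string))
total_output = sum(int(n) for n in re.findall(output_regex, metadata_string))
print(total_input, total_output)  # 2100 760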
| 13 |   import gradio as gr
| 14 |   from gradio import Progress
| 15 |
| 16 | + from azure.ai.inference import ChatCompletionsClient
| 17 | + from azure.core.credentials import AzureKeyCredential
| 18 | + from azure.ai.inference.models import SystemMessage, UserMessage
| 19 | +
| 20 |   model_type = None # global variable setup
| 21 |   full_text = "" # Define dummy source text (full text) just to enable highlight function to load
| 22 |   model = list() # Define empty list for model functions to run
| 23 |   tokenizer = list() #[] # Define empty list for model functions to run
| 24 |
| 25 | + from tools.config import AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_MIN_P, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS, SPECULATIVE_DECODING, NUM_PRED_TOKENS, USE_LLAMA_CPP, COMPILE_MODE, MODEL_DTYPE, USE_BITSANDBYTES, COMPILE_TRANSFORMERS, INT8_WITH_OFFLOAD_TO_CPU, AZURE_INFERENCE_ENDPOINT
| 26 |   from tools.prompts import initial_table_assistant_prefill
| 27 |
| 28 |   if SPECULATIVE_DECODING == "True": SPECULATIVE_DECODING = True
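The three added azure imports assume the azure-ai-inference and azure-core packages are available at import time. A hedged alternative sketch, not what the commit does, would guard them so the module still imports without the SDK installed; AZURE_SDK_AVAILABLE is a hypothetical flag, not a name from the repository.

try:
    from azure.ai.inference import ChatCompletionsClient
    from azure.ai.inference.models import SystemMessage, UserMessage
    from azure.core.credentials import AzureKeyCredential
    AZURE_SDK_AVAILABLE = True   # hypothetical flag, not in the commit
except ImportError:
    # Fall back to None placeholders so the rest of the module can still be imported.
    ChatCompletionsClient = SystemMessage = UserMessage = AzureKeyCredential = None
    AZURE_SDK_AVAILABLE = False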
| 504 |   def construct_gemini_generative_model(in_api_key: str, temperature: float, model_choice: str, system_prompt: str, max_tokens: int, random_seed=seed) -> Tuple[object, dict]:
| 505 |       """
| 506 |       Constructs a GenerativeModel for Gemini API calls.
| 507 | +     ...
| 508 |       """
| 509 |       # Construct a GenerativeModel
| 510 |       try:
| 527 |
| 528 |           return client, config
| 529 |
| 530 | + def construct_azure_client(in_api_key: str, endpoint: str) -> Tuple[object, dict]:
| 531 | +     """
| 532 | +     Constructs a ChatCompletionsClient for Azure AI Inference.
| 533 | +     """
| 534 | +     try:
| 535 | +         key = None
| 536 | +         if in_api_key:
| 537 | +             key = in_api_key
| 538 | +         elif os.environ.get("AZURE_INFERENCE_CREDENTIAL"):
| 539 | +             key = os.environ["AZURE_INFERENCE_CREDENTIAL"]
| 540 | +         elif os.environ.get("AZURE_API_KEY"):
| 541 | +             key = os.environ["AZURE_API_KEY"]
| 542 | +         if not key:
| 543 | +             raise Warning("No Azure API key found.")
| 544 | +
| 545 | +         if not endpoint:
| 546 | +             endpoint = os.environ.get("AZURE_INFERENCE_ENDPOINT", "")
| 547 | +         if not endpoint:
| 548 | +             raise Warning("No Azure inference endpoint found.")
| 549 | +         client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(key))
| 550 | +         return client, {}
| 551 | +     except Exception as e:
| 552 | +         print("Error constructing Azure ChatCompletions client:", e)
| 553 | +         raise
| 554 | +
| 555 |   def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice:str, bedrock_runtime:boto3.Session.client, assistant_prefill:str="") -> ResponseObject:
| 556 |       """
| 557 |       This function sends a request to AWS Claude with the following parameters:
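A hedged usage sketch for the new construct_azure_client helper above, assuming it and the azure imports from this file are in scope. The endpoint URL, API key, and model name are placeholders; ChatCompletionsClient.complete() is invoked the same way the Azure branch of send_request does further down.

import os

# Placeholder endpoint; in the app this comes from AZURE_INFERENCE_ENDPOINT in tools/config.py.
os.environ["AZURE_INFERENCE_ENDPOINT"] = "https://example-resource.services.ai.azure.com/models"

client, _config = construct_azure_client(in_api_key="placeholder-key", endpoint="")  # empty endpoint falls back to the env var
response = client.complete(
    messages=[
        SystemMessage(content="You are a concise assistant."),
        UserMessage(content="Summarise the main themes in these consultation comments."),
    ],
    model="my-deployment-name",  # placeholder deployment/model name
)
print(response.choices[0].message.content)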
| 687 |       duration = end_time - start_time
| 688 |       tokens_per_second = num_generated_tokens / duration
| 689 |
| 690 | +     print("\n--- Performance ---")
| 691 | +     print(f"Time taken: {duration:.2f} seconds")
| 692 | +     print(f"Generated tokens: {num_generated_tokens}")
| 693 | +     print(f"Tokens per second: {tokens_per_second:.2f}")
| 694 |
| 695 |       return assistant_reply, num_input_tokens, num_generated_tokens
| 696 |
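The arithmetic behind the new performance lines, shown standalone with the generation call stubbed out; the sleep and token count are invented values.

import time

start_time = time.time()
time.sleep(0.25)                 # stand-in for the actual generation call
num_generated_tokens = 128       # invented count
end_time = time.time()

duration = end_time - start_time
tokens_per_second = num_generated_tokens / duration
print(f"Time taken: {duration:.2f} seconds")
print(f"Tokens per second: {tokens_per_second:.2f}")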
| 740 |
| 741 |               if i == number_of_api_retry_attempts:
| 742 |                   return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
| 743 | +
| 744 |       elif "AWS" in model_source:
| 745 |           for i in progress_bar:
| 746 |               try:
| 756 |
| 757 |               if i == number_of_api_retry_attempts:
| 758 |                   return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
| 759 | +     elif "Azure" in model_source:
| 760 | +         for i in progress_bar:
| 761 | +             try:
| 762 | +                 print("Calling Azure AI Inference model, attempt", i + 1)
| 763 | +                 # Use structured messages for Azure
| 764 | +                 response_raw = google_client.complete(
| 765 | +                     messages=[
| 766 | +                         SystemMessage(content=system_prompt),
| 767 | +                         UserMessage(content=prompt),
| 768 | +                     ],
| 769 | +                     model=model_choice
| 770 | +                 )
| 771 | +                 response_text = response_raw.choices[0].message.content
| 772 | +                 usage = getattr(response_raw, "usage", None)
| 773 | +                 input_tokens = 0
| 774 | +                 output_tokens = 0
| 775 | +                 if usage is not None:
| 776 | +                     input_tokens = getattr(usage, "input_tokens", getattr(usage, "prompt_tokens", 0))
| 777 | +                     output_tokens = getattr(usage, "output_tokens", getattr(usage, "completion_tokens", 0))
| 778 | +                 response = ResponseObject(
| 779 | +                     text=response_text,
| 780 | +                     usage_metadata={'inputTokens': input_tokens, 'outputTokens': output_tokens}
| 781 | +                 )
| 782 | +                 break
| 783 | +             except Exception as e:
| 784 | +                 print("Call to Azure model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
| 785 | +                 time.sleep(timeout_wait)
| 786 | +             if i == number_of_api_retry_attempts:
| 787 | +                 return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
| 788 |       elif "Local" in model_source:
| 789 |           # This is the local model
| 790 |           for i in progress_bar:
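A hedged sketch of the retry-and-token-fallback pattern added in the Azure branch above, with the SDK call replaced by a stub so it runs offline. flaky_complete, _Usage, MAX_RETRIES, and TIMEOUT_WAIT are invented names for illustration only.

import time

MAX_RETRIES = 3
TIMEOUT_WAIT = 1  # seconds between attempts

class _Usage:                      # mimics a usage object that exposes prompt/completion token names
    prompt_tokens = 250
    completion_tokens = 80

def flaky_complete(attempt: int):
    if attempt == 0:
        raise RuntimeError("transient error")   # first call fails, like a throttled API
    return _Usage()

for attempt in range(MAX_RETRIES):
    try:
        usage = flaky_complete(attempt)
        # Same double-getattr fallback as the diff: prefer input/output_tokens, else prompt/completion_tokens.
        input_tokens = getattr(usage, "input_tokens", getattr(usage, "prompt_tokens", 0))
        output_tokens = getattr(usage, "output_tokens", getattr(usage, "completion_tokens", 0))
        print(f"attempt {attempt + 1}: input={input_tokens}, output={output_tokens}")
        break
    except Exception as e:
        print(f"attempt {attempt + 1} failed: {e}; retrying in {TIMEOUT_WAIT}s")
        time.sleep(TIMEOUT_WAIT)
else:
    print("all attempts failed")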
| 821 |       # Check if is a LLama.cpp model response
| 822 |       if isinstance(response, ResponseObject):
| 823 |           response_text = response.text
| 824 |       elif 'choices' in response: # LLama.cpp model response
| 825 |           if "gpt-oss" in model_choice:
| 826 |               response_text = response['choices'][0]['message']['content'].split('<|start|>assistant<|channel|>final<|message|>')[1]
| 827 |           else:
| 828 |               response_text = response['choices'][0]['message']['content']
| 829 |           response_text = response_text.strip()
| 830 |       elif model_source == "Gemini":
| 831 |           response_text = response.text
| 832 |           response_text = response_text.strip()
| 833 |       else: # Assume transformers model response
| 834 |           if "gpt-oss" in model_choice:
| 835 |               response_text = response.split('<|start|>assistant<|channel|>final<|message|>')[1]
| 836 |           else:
| 837 |               response_text = response
| 838 | +
| 839 | +     conversation_history.append({'role': 'assistant', 'parts': [response_text]})
| 840 |
| 841 |       return response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens
| 842 |
| 843 | + def process_requests(prompts: List[str],
| 844 | +                      system_prompt: str,
| 845 | +                      conversation_history: List[dict],
| 846 | +                      whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model = list(), tokenizer=tokenizer, master:bool = False, assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
| 847 |       """
| 848 |       Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
| 849 |
| 882 |       whole_conversation.append(response_text)
| 883 |
| 884 |       # Create conversation metadata
| 885 | +     # if master == False:
| 886 | +     # whole_conversation_metadata.append(f"Batch {batch_no}:")
| 887 | +     # else:
| 888 | +     # #whole_conversation_metadata.append(f"Query summary metadata:")
| 889 | +
| 890 | +     whole_conversation_metadata.append(f"Batch {batch_no}:")
| 891 |
| 892 |       # if not isinstance(response, str):
| 893 |       try:
| 894 |           if "AWS" in model_source:
| 895 |               output_tokens = response.usage_metadata.get('outputTokens', 0)
| 896 |               input_tokens = response.usage_metadata.get('inputTokens', 0)
| 897 | +
| 898 | +         elif "Gemini" in model_source:
| 899 |               output_tokens = response.usage_metadata.candidates_token_count
| 900 |               input_tokens = response.usage_metadata.prompt_token_count
| 901 |
| 902 | +         elif "Azure" in model_source:
| 903 | +             input_tokens = response.usage_metadata.get('inputTokens', 0)
| 904 | +             output_tokens = response.usage_metadata.get('outputTokens', 0)
| 905 | +
| 906 |           elif "Local" in model_source:
| 907 |               if USE_LLAMA_CPP == "True":
| 908 |                   output_tokens = response['usage'].get('completion_tokens', 0)
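A hedged sketch of how the branching above maps each model_source's usage metadata onto a common (input_tokens, output_tokens) pair; extract_token_counts is an invented helper, not a function in the repository. In this code path the AWS and Azure branches both read Bedrock-style dict keys, while Gemini exposes attributes on its usage object.

def extract_token_counts(model_source: str, usage_metadata) -> tuple:
    if "AWS" in model_source or "Azure" in model_source:
        # Plain dicts with Bedrock-style keys, as built by the Azure branch of send_request.
        return usage_metadata.get('inputTokens', 0), usage_metadata.get('outputTokens', 0)
    if "Gemini" in model_source:
        return usage_metadata.prompt_token_count, usage_metadata.candidates_token_count
    return 0, 0  # Local / transformers paths count tokens elsewhere

print(extract_token_counts("AWS Claude", {'inputTokens': 120, 'outputTokens': 45}))  # (120, 45)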
| 1057 |
| 1058 |     # Regex to find the numbers following the keys in the "Query summary metadata" section
| 1059 |     # This ensures we get the final, aggregated totals for the whole query.
| 1060 |     input_regex = r"input_tokens: (\d+)"
| 1061 |     output_regex = r"output_tokens: (\d+)"
| 1062 |
| 1063 |     # re.findall returns a list of all matching strings (the captured groups).
| 1064 |     input_token_strings = re.findall(input_regex, metadata_string)