seanpedrickcase committed
Commit 138286c · 1 Parent(s): 5fa40a6

Added framework of support for Azure models (although untested)

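For orientation before the per-file diffs: the commit wires the Azure AI Inference chat-completions API into the existing model plumbing. A minimal sketch of that call pattern, assuming the `azure-ai-inference` pin added below; the endpoint URL and deployment name are placeholders, and in the repository this goes through `construct_azure_client` and `send_request` in `tools/llm_funcs.py` rather than being called directly:

```python
# Minimal sketch of the Azure AI Inference call this commit targets (untested; endpoint/model are placeholders)
import os
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential

endpoint = os.environ["AZURE_INFERENCE_ENDPOINT"]       # e.g. https://<resource>.services.ai.azure.com/models (assumed format)
key = os.environ["AZURE_INFERENCE_CREDENTIAL"]          # set from the new Gradio textbox or from AZURE_API_KEY

client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(key))
response = client.complete(
    messages=[
        SystemMessage(content="You are a topic-extraction assistant."),
        UserMessage(content="Summarise the main topics in this table of responses."),
    ],
    model="gpt-5-mini",  # must match an Azure deployment name registered in tools/config.py
)
print(response.choices[0].message.content)
```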
app.py CHANGED
@@ -12,7 +12,7 @@ from tools.custom_csvlogger import CSVLogger_custom
from tools.auth import authenticate_user
from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, verify_titles_prompt, verify_titles_system_prompt, two_para_summary_format_prompt, single_para_summary_format_prompt
from tools.verify_titles import verify_titles
- from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, RUN_LOCAL_MODEL, FILE_INPUT_HEIGHT, GEMINI_API_KEY, model_full_names, BATCH_SIZE_DEFAULT, CHOSEN_LOCAL_MODEL_TYPE, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, model_name_map, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES, HF_TOKEN
+ from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, RUN_LOCAL_MODEL, FILE_INPUT_HEIGHT, GEMINI_API_KEY, model_full_names, BATCH_SIZE_DEFAULT, CHOSEN_LOCAL_MODEL_TYPE, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, model_name_map, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES, HF_TOKEN, AZURE_API_KEY

def ensure_folder_exists(output_folder:str):
"""Checks if the specified folder exists, creates it if not."""
@@ -307,6 +307,9 @@ with app:
with gr.Accordion("Gemini API keys", open = False):
google_api_key_textbox = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")

+ with gr.Accordion("Azure AI Inference", open = False):
+ azure_api_key_textbox = gr.Textbox(value = AZURE_API_KEY, label="Enter Azure AI Inference API key (only if using Azure models)", lines=1, type="password")
+
with gr.Accordion("Hugging Face API keys", open = False):
hf_api_key_textbox = gr.Textbox(value = HF_TOKEN, label="Enter Hugging Face API key (only if using Hugging Face models)", lines=1, type="password")

@@ -369,7 +372,7 @@ with app:
success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
success(load_in_data_file,
inputs = [in_data_files, in_colnames, batch_size_number, in_excel_sheets], outputs = [file_data_state, working_data_file_name_textbox, total_number_of_batches], api_name="load_data").\
- success(fn=wrapper_extract_topics_per_column_value,
+ success(fn=wrapper_extract_topics_per_column_value,
inputs=[in_group_col,
in_data_files,
file_data_state,
@@ -405,6 +408,7 @@ with app:
aws_access_key_textbox,
aws_secret_key_textbox,
hf_api_key_textbox,
+ azure_api_key_textbox,
output_folder_state],
outputs=[display_topic_table_markdown,
master_topic_df_state,
@@ -467,10 +471,10 @@ with app:
success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
success(load_in_data_file,
inputs = [in_data_files, in_colnames, batch_size_number, in_excel_sheets], outputs = [file_data_state, working_data_file_name_textbox, total_number_of_batches], api_name="load_data").\
- success(fn=wrapper_extract_topics_per_column_value,
+ success(fn=wrapper_extract_topics_per_column_value,
inputs=[in_group_col,
in_data_files,
- file_data_state,
+ file_data_state,
master_topic_df_state,
master_reference_df_state,
master_unique_topics_df_state,
@@ -503,6 +507,7 @@ with app:
aws_access_key_textbox,
aws_secret_key_textbox,
hf_api_key_textbox,
+ azure_api_key_textbox,
output_folder_state],
outputs=[display_topic_table_markdown,
master_topic_df_state,
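The UI change is just a new password textbox inside an accordion, threaded through the existing event chains as one more input. A stripped-down sketch of the same pattern (not the app's actual layout; `run_with_key` is a hypothetical stand-in for `wrapper_extract_topics_per_column_value`):

```python
import gradio as gr

def run_with_key(azure_api_key: str) -> str:
    # Hypothetical stand-in for the real extraction wrapper, which now receives the key as an input
    return "received a key" if azure_api_key else "no key provided"

with gr.Blocks() as demo:
    with gr.Accordion("Azure AI Inference", open=False):
        azure_api_key_textbox = gr.Textbox(
            label="Enter Azure AI Inference API key (only if using Azure models)",
            lines=1, type="password")
    run_btn = gr.Button("Extract topics")
    status = gr.Markdown()
    # The commit simply appends azure_api_key_textbox to the inputs list of the existing click/success chain
    run_btn.click(run_with_key, inputs=[azure_api_key_textbox], outputs=[status])

# demo.launch()
```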
requirements.txt CHANGED
@@ -9,7 +9,9 @@ openpyxl==3.1.5
markdown==3.7
tabulate==0.9.0
lxml==5.3.0
- google-genai==1.32.0
+ google-genai==1.33.0
+ azure-ai-inference==1.0.0b9
+ azure-core==1.35.0
html5lib==1.1
beautifulsoup4==4.12.3
rapidfuzz==3.13.0
@@ -24,5 +26,9 @@ accelerate==1.10.1
#torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/cpu
# For Hugging Face, need a python 3.10 compatible wheel for llama-cpp-python to avoid build timeouts
#https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp310-cp310-linux_x86_64.whl
+ # CPU only (for e.g. Hugging Face CPU instances)
+ #torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/cpu
+ # For Hugging Face, need a python 3.10 compatible wheel for llama-cpp-python to avoid build timeouts
+ #https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp310-cp310-linux_x86_64.whl
requirements_cpu.txt CHANGED
@@ -8,7 +8,9 @@ openpyxl==3.1.5
markdown==3.7
tabulate==0.9.0
lxml==5.3.0
- google-genai==1.32.0
+ google-genai==1.33.0
+ azure-ai-inference==1.0.0b9
+ azure-core==1.35.0
html5lib==1.1
beautifulsoup4==4.12.3
rapidfuzz==3.13.0
requirements_gpu.txt CHANGED
@@ -8,7 +8,9 @@ openpyxl==3.1.5
markdown==3.7
tabulate==0.9.0
lxml==5.3.0
- google-genai==1.32.0
+ google-genai==1.33.0
+ azure-ai-inference==1.0.0b9
+ azure-core==1.35.0
html5lib==1.1
beautifulsoup4==4.12.3
rapidfuzz==3.13.0
requirements_no_local.txt CHANGED
@@ -9,7 +9,9 @@ openpyxl==3.1.5
markdown==3.7
tabulate==0.9.0
lxml==5.3.0
- google-genai==1.32.0
+ google-genai==1.33.0
+ azure-ai-inference==1.0.0b9
+ azure-core==1.35.0
html5lib==1.1
beautifulsoup4==4.12.3
rapidfuzz==3.13.0
tools/aws_functions.py CHANGED
@@ -11,34 +11,42 @@ def connect_to_bedrock_runtime(model_name_map:dict, model_choice:str, aws_access
# If running an anthropic model, assume that running an AWS Bedrock model, load in Bedrock
model_source = model_name_map[model_choice]["source"]

- if "AWS" in model_source:
- if aws_access_key_textbox and aws_secret_key_textbox:
+ if "AWS" in model_source:
+ if RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
+ print("Connecting to Bedrock via existing SSO connection")
+ bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
+ elif RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
+ print("Connecting to Bedrock via existing SSO connection")
+ bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
+ elif aws_access_key_textbox and aws_secret_key_textbox:
print("Connecting to Bedrock using AWS access key and secret keys from user input.")
bedrock_runtime = boto3.client('bedrock-runtime',
aws_access_key_id=aws_access_key_textbox,
aws_secret_access_key=aws_secret_key_textbox, region_name=AWS_REGION)
- elif RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
- print("Connecting to Bedrock via existing SSO connection")
- bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
print("Getting Bedrock credentials from environment variables")
bedrock_runtime = boto3.client('bedrock-runtime',
aws_access_key_id=AWS_ACCESS_KEY,
aws_secret_access_key=AWS_SECRET_KEY,
- region_name=AWS_REGION)
+ region_name=AWS_REGION)
+ elif RUN_AWS_FUNCTIONS == "1":
+ print("Connecting to Bedrock via existing SSO connection")
+ bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
else:
bedrock_runtime = ""
out_message = "Cannot connect to AWS Bedrock service. Please provide access keys under LLM settings, or choose another model type."
print(out_message)
raise Exception(out_message)
else:
- bedrock_runtime = []
+ bedrock_runtime = list()
+
+ print("Bedrock runtime connected:", bedrock_runtime)

return bedrock_runtime

def connect_to_s3_client(aws_access_key_textbox:str="", aws_secret_key_textbox:str=""):
# If running an anthropic model, assume that running an AWS s3 model, load in s3
- s3_client = []
+ s3_client = list()

if aws_access_key_textbox and aws_secret_key_textbox:
print("Connecting to s3 using AWS access key and secret keys from user input.")
@@ -148,7 +156,7 @@ def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=buck
"""
if RUN_AWS_FUNCTIONS == "1":

- final_out_message = []
+ final_out_message = list()

s3_client = connect_to_s3_client(aws_access_key_textbox, aws_secret_key_textbox)
#boto3.client('s3')
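Note that in the new `connect_to_bedrock_runtime` the first two branches are identical, so the second `elif` can never fire. A minimal sketch of the apparent intended credential priority (SSO when prioritised, then user-supplied keys, then environment keys, then SSO as a fallback), written as an assumption rather than as the committed code:

```python
import boto3

def pick_bedrock_client(aws_access_key_textbox, aws_secret_key_textbox,
                        AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION,
                        RUN_AWS_FUNCTIONS, PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS):
    # Sketch of the priority order the duplicated branches appear to aim for (assumed intent)
    if RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
        return boto3.client('bedrock-runtime', region_name=AWS_REGION)   # existing SSO session
    if aws_access_key_textbox and aws_secret_key_textbox:
        return boto3.client('bedrock-runtime',
                            aws_access_key_id=aws_access_key_textbox,
                            aws_secret_access_key=aws_secret_key_textbox,
                            region_name=AWS_REGION)                      # keys typed into the UI
    if AWS_ACCESS_KEY and AWS_SECRET_KEY:
        return boto3.client('bedrock-runtime',
                            aws_access_key_id=AWS_ACCESS_KEY,
                            aws_secret_access_key=AWS_SECRET_KEY,
                            region_name=AWS_REGION)                      # keys from environment variables
    if RUN_AWS_FUNCTIONS == "1":
        return boto3.client('bedrock-runtime', region_name=AWS_REGION)   # fall back to ambient credentials
    raise Exception("Cannot connect to AWS Bedrock service. Please provide access keys under LLM settings, or choose another model type.")
```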
tools/config.py CHANGED
@@ -206,6 +206,33 @@ RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
RUN_AWS_BEDROCK_MODELS = get_or_create_env_var("RUN_AWS_BEDROCK_MODELS", "1")
GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')

+ # Build up options for models
+ ###
+ # LLM variables
+ ###
+
+ MAX_TOKENS = int(get_or_create_env_var('MAX_TOKENS', '4096')) # Maximum number of output tokens
+ TIMEOUT_WAIT = int(get_or_create_env_var('TIMEOUT_WAIT', '30')) # AWS now seems to have a 60 second minimum wait between API calls
+ NUMBER_OF_RETRY_ATTEMPTS = int(get_or_create_env_var('NUMBER_OF_RETRY_ATTEMPTS', '5'))
+ # Try up to 3 times to get a valid markdown table response with LLM calls, otherwise retry with temperature changed
+ MAX_OUTPUT_VALIDATION_ATTEMPTS = int(get_or_create_env_var('MAX_OUTPUT_VALIDATION_ATTEMPTS', '3'))
+ MAX_TIME_FOR_LOOP = int(get_or_create_env_var('MAX_TIME_FOR_LOOP', '99999'))
+ BATCH_SIZE_DEFAULT = int(get_or_create_env_var('BATCH_SIZE_DEFAULT', '5'))
+ DEDUPLICATION_THRESHOLD = int(get_or_create_env_var('DEDUPLICATION_THRESHOLD', '90'))
+ MAX_COMMENT_CHARS = int(get_or_create_env_var('MAX_COMMENT_CHARS', '14000'))
+
+ RUN_LOCAL_MODEL = get_or_create_env_var("RUN_LOCAL_MODEL", "1")
+
+ RUN_AWS_BEDROCK_MODELS = get_or_create_env_var("RUN_AWS_BEDROCK_MODELS", "1")
+
+ RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
+ GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
+
+ # Azure AI Inference settings
+ RUN_AZURE_MODELS = get_or_create_env_var("RUN_AZURE_MODELS", "0")
+ AZURE_API_KEY = get_or_create_env_var('AZURE_API_KEY', '')
+ AZURE_INFERENCE_ENDPOINT = get_or_create_env_var('AZURE_INFERENCE_ENDPOINT', '')
+
# Build up options for models

model_full_names = list()
@@ -225,12 +252,16 @@ if RUN_AWS_BEDROCK_MODELS == "1":
model_source.extend(["AWS", "AWS", "AWS", "AWS", "AWS"])

if RUN_GEMINI_MODELS == "1":
- model_full_names.extend(["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.5-pro"]) # , # Gemini pro No longer available on free tier
+ model_full_names.extend(["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.5-pro"])
model_short_names.extend(["gemini_flash_lite_2.5", "gemini_flash_2.5", "gemini_pro"])
model_source.extend(["Gemini", "Gemini", "Gemini"])

- #print("model_short_names:", model_short_names)
- #print("model_full_names:", model_full_names)
+ # Register Azure AI models (model names must match your Azure deployments)
+ if RUN_AZURE_MODELS == "1":
+ # Example deployments; adjust to the deployments you actually create in Azure
+ model_full_names.extend(["gpt-5-mini"])
+ model_short_names.extend(["gpt-5-mini"])
+ model_source.extend(["Azure"])

model_name_map = {
full: {"short_name": short, "source": source}
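With these settings, Azure deployments only appear in the model dropdown when `RUN_AZURE_MODELS` is switched on and their names are registered. A minimal sketch of the resulting configuration, assuming the `gpt-5-mini` example deployment from the diff and placeholder key/endpoint values:

```python
import os

# Opt in to Azure models before tools.config is imported (values are placeholders)
os.environ["RUN_AZURE_MODELS"] = "1"
os.environ["AZURE_API_KEY"] = "<your-azure-ai-inference-key>"
os.environ["AZURE_INFERENCE_ENDPOINT"] = "https://<resource>.services.ai.azure.com/models"

# After the registration block runs, model_name_map would contain an entry like:
example_entry = {"gpt-5-mini": {"short_name": "gpt-5-mini", "source": "Azure"}}
# Client selection elsewhere keys off model_name_map[model_choice]["source"] == "Azure".
```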
tools/dedup_summaries.py CHANGED
@@ -8,12 +8,13 @@ import time
import markdown
import boto3
from tqdm import tqdm
+ import os

from tools.prompts import summarise_topic_descriptions_prompt, summarise_topic_descriptions_system_prompt, system_prompt, summarise_everything_prompt, comprehensive_summary_format_prompt, summarise_everything_system_prompt, comprehensive_summary_format_prompt_by_group, summary_assistant_prefill
- from tools.llm_funcs import construct_gemini_generative_model, process_requests, ResponseObject, load_model, calculate_tokens_from_metadata
+ from tools.llm_funcs import construct_gemini_generative_model, process_requests, ResponseObject, load_model, calculate_tokens_from_metadata, construct_azure_client
from tools.helper_functions import create_topic_summary_df_from_reference_table, load_in_data_file, get_basic_response_data, convert_reference_table_to_pivot_table, wrap_text, clean_column_name, get_file_name_no_ext, create_batch_file_path_details
- from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, REASONING_SUFFIX
from tools.aws_functions import connect_to_bedrock_runtime
+ from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT

max_tokens = MAX_TOKENS
timeout_wait = TIMEOUT_WAIT
@@ -440,10 +441,13 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
google_client = list()
google_config = {}

- # Prepare Gemini models before query
+ # Prepare Gemini models before query
if "Gemini" in model_source:
#print("Using Gemini model:", model_choice)
google_client, config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=system_prompt, max_tokens=max_tokens)
+ elif "Azure" in model_source:
+ # Azure client (endpoint from env/config)
+ google_client, config = construct_azure_client(in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""), endpoint=AZURE_INFERENCE_ENDPOINT)
elif "Local" in model_source:
pass
#print("Using local model: ", model_choice)
@@ -594,6 +598,8 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
if "Local" in model_source and reasoning_suffix: formatted_summarise_topic_descriptions_system_prompt = formatted_summarise_topic_descriptions_system_prompt + "\n" + reasoning_suffix

try:
+ print("formatted_summarise_topic_descriptions_system_prompt:", formatted_summarise_topic_descriptions_system_prompt)
+
response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_topic_descriptions_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer)
summarised_output = response
summarised_output = re.sub(r'\n{2,}', '\n', summarised_output) # Replace multiple line breaks with a single line break
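Note the summarisation path does not take the Azure key as an argument; it relies on `AZURE_INFERENCE_CREDENTIAL` already being present in the environment (set by `extract_topics` from the new textbox, or exported beforehand). A short sketch of that handoff, with placeholder values:

```python
import os

# The Azure branch in summarise_output_topics_query reads the key from the environment,
# so it must be populated before the summariser runs (extract_topics does this from the UI textbox).
os.environ.setdefault("AZURE_INFERENCE_CREDENTIAL", "<your-azure-ai-inference-key>")
os.environ.setdefault("AZURE_INFERENCE_ENDPOINT", "https://<resource>.services.ai.azure.com/models")

# Equivalent to the new call in tools/dedup_summaries.py (sketch):
# google_client, config = construct_azure_client(
#     in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""),
#     endpoint=AZURE_INFERENCE_ENDPOINT)
```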
tools/llm_api_call.py CHANGED
@@ -16,8 +16,8 @@ GradioFileData = gr.FileData

from tools.prompts import initial_table_prompt, prompt2, prompt3, initial_table_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt, add_existing_topics_assistant_prefill, initial_table_assistant_prefill, structured_summary_prompt
from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files, create_batch_file_path_details
- from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata
- from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX
+ from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata, construct_azure_client
+ from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT
from tools.aws_functions import connect_to_bedrock_runtime

if RUN_LOCAL_MODEL == "1":
@@ -477,8 +477,6 @@ def write_llm_output_and_logs(response_text: str,
new_reference_df = pd.DataFrame(reference_data)
else:
new_reference_df = pd.DataFrame(columns=["Response References", "General topic", "Subtopic", "Sentiment", "Summary", "Start row of group"])
-
- print("new_reference_df:", new_reference_df)

# Append on old reference data
if not new_reference_df.empty:
@@ -689,6 +687,7 @@ def extract_topics(in_data_file: GradioFileData,
aws_access_key_textbox:str='',
aws_secret_key_textbox:str='',
hf_api_key_textbox:str='',
+ azure_api_key_textbox:str='',
max_tokens:int=max_tokens,
model_name_map:dict=model_name_map,
max_time_for_loop:int=max_time_for_loop,
@@ -708,7 +707,7 @@ def extract_topics(in_data_file: GradioFileData,
- unique_table_df_display_table_markdown (str): Table for display in markdown format.
- file_name (str): File name of the data file.
- num_batches (int): Number of batches required to go through all the response rows.
- - in_api_key (str): The API key for authentication.
+ - in_api_key (str): The API key for authentication (Google Gemini).
- temperature (float): The temperature parameter for the model.
- chosen_cols (List[str]): A list of chosen columns to process.
- candidate_topics (gr.FileData): A Gradio FileData object of existing candidate topics submitted by the user.
@@ -843,7 +842,7 @@ def extract_topics(in_data_file: GradioFileData,

for i in topics_loop:
reported_batch_no = latest_batch_completed + 1
- print("Running batch:", reported_batch_no)
+ print("Running response batch:", reported_batch_no)

# Call the function to prepare the input table
simplified_csv_table_path, normalised_simple_markdown_table, start_row, end_row, batch_basic_response_df = data_file_to_markdown_table(file_data, file_name, chosen_cols, latest_batch_completed, batch_size)
@@ -859,10 +858,16 @@ def extract_topics(in_data_file: GradioFileData,

formatted_system_prompt = add_existing_topics_system_prompt.format(consultation_context=context_textbox, column_name=chosen_cols)

- # Prepare Gemini models before query
+ # Prepare clients before query
if "Gemini" in model_source:
print("Using Gemini model:", model_choice)
google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_system_prompt, max_tokens=max_tokens)
+ elif "Azure" in model_source:
+ print("Using Azure AI Inference model:", model_choice)
+ # If provided, set env for downstream calls too
+ if azure_api_key_textbox:
+ os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
+ google_client, google_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=AZURE_INFERENCE_ENDPOINT)
elif "anthropic.claude" in model_choice:
print("Using AWS Bedrock model:", model_choice)
else:
@@ -1034,6 +1039,11 @@ def extract_topics(in_data_file: GradioFileData,
if model_source == "Gemini":
print("Using Gemini model:", model_choice)
google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_initial_table_system_prompt, max_tokens=max_tokens)
+ elif model_source == "Azure":
+ print("Using Azure AI Inference model:", model_choice)
+ if azure_api_key_textbox:
+ os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
+ google_client, google_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=AZURE_INFERENCE_ENDPOINT)
elif model_choice == CHOSEN_LOCAL_MODEL_TYPE:
print("Using local model:", model_choice)
else:
@@ -1118,7 +1128,7 @@ def extract_topics(in_data_file: GradioFileData,

# Increase latest file completed count unless we are over the last batch number
if latest_batch_completed <= num_batches:
- print("Completed batch number:", str(reported_batch_no))
+ #print("Completed batch number:", str(reported_batch_no))
latest_batch_completed += 1

toc = time.perf_counter()
@@ -1242,19 +1252,16 @@ def wrapper_extract_topics_per_column_value(
initial_existing_topic_summary_df: pd.DataFrame,
initial_unique_table_df_display_table_markdown: str,
original_file_name: str, # Original file name, to be modified per segment
- # Initial state parameters (wrapper will use these for the very first call)
total_number_of_batches:int,
in_api_key: str,
temperature: float,
chosen_cols: List[str],
model_choice: str,
candidate_topics: GradioFileData = None,
-
initial_first_loop_state: bool = True,
initial_whole_conversation_metadata_str: str = '',
initial_latest_batch_completed: int = 0,
initial_time_taken: float = 0,
-
initial_table_prompt: str = initial_table_prompt,
prompt2: str = prompt2,
prompt3: str = prompt3,
@@ -1273,14 +1280,67 @@ def wrapper_extract_topics_per_column_value(
aws_access_key_textbox:str="",
aws_secret_key_textbox:str="",
hf_api_key_textbox:str="",
+ azure_api_key_textbox:str="",
output_folder: str = OUTPUT_FOLDER,
force_single_topic_prompt: str = force_single_topic_prompt,
max_tokens: int = max_tokens,
model_name_map: dict = model_name_map,
max_time_for_loop: int = max_time_for_loop, # This applies per call to extract_topics
+ reasoning_suffix: str = reasoning_suffix,
CHOSEN_LOCAL_MODEL_TYPE: str = CHOSEN_LOCAL_MODEL_TYPE,
progress=Progress(track_tqdm=True) # type: ignore
) -> Tuple: # Mimicking the return tuple structure of extract_topics
+ """
+ A wrapper function that iterates through unique values in a specified grouping column
+ and calls the `extract_topics` function for each segment of the data.
+ It accumulates results from each call and returns a consolidated output.
+
+ :param grouping_col: The name of the column to group the data by.
+ :param in_data_file: The input data file object (e.g., Gradio FileData).
+ :param file_data: The full DataFrame containing all data.
+ :param initial_existing_topics_table: Initial DataFrame of existing topics.
+ :param initial_existing_reference_df: Initial DataFrame mapping responses to topics.
+ :param initial_existing_topic_summary_df: Initial DataFrame summarizing topics.
+ :param initial_unique_table_df_display_table_markdown: Initial markdown string for topic display.
+ :param original_file_name: The original name of the input file.
+ :param total_number_of_batches: The total number of batches across all data.
+ :param in_api_key: API key for the chosen LLM.
+ :param temperature: Temperature setting for the LLM.
+ :param chosen_cols: List of columns from `file_data` to be processed.
+ :param model_choice: The chosen LLM model (e.g., "Gemini", "AWS Claude").
+ :param candidate_topics: Optional Gradio FileData for candidate topics (zero-shot).
+ :param initial_first_loop_state: Boolean indicating if this is the very first loop iteration.
+ :param initial_whole_conversation_metadata_str: Initial metadata string for the whole conversation.
+ :param initial_latest_batch_completed: The batch number completed in the previous run.
+ :param initial_time_taken: Initial time taken for processing.
+ :param initial_table_prompt: The initial prompt for table summarization.
+ :param prompt2: The second prompt for LLM interaction.
+ :param prompt3: The third prompt for LLM interaction.
+ :param initial_table_system_prompt: The initial system prompt for table summarization.
+ :param add_existing_topics_system_prompt: System prompt for adding existing topics.
+ :param add_existing_topics_prompt: Prompt for adding existing topics.
+ :param number_of_prompts_used: Number of prompts used in the LLM call.
+ :param batch_size: Number of rows to process in each batch for the LLM.
+ :param context_textbox: Additional context provided by the user.
+ :param sentiment_checkbox: Choice for sentiment assessment (e.g., "Negative, Neutral, or Positive").
+ :param force_zero_shot_radio: Option to force responses into zero-shot topics.
+ :param in_excel_sheets: List of Excel sheet names if applicable.
+ :param force_single_topic_radio: Option to force a single topic per response.
+ :param produce_structures_summary_radio: Option to produce a structured summary.
+ :param aws_access_key_textbox: AWS access key for Bedrock.
+ :param aws_secret_key_textbox: AWS secret key for Bedrock.
+ :param hf_api_key_textbox: Hugging Face API key for local models.
+ :param azure_api_key_textbox: Azure API key for Azure AI Inference.
+ :param output_folder: The folder where output files will be saved.
+ :param force_single_topic_prompt: Prompt for forcing a single topic.
+ :param max_tokens: Maximum tokens for LLM generation.
+ :param model_name_map: Dictionary mapping model names to their properties.
+ :param max_time_for_loop: Maximum time allowed for the processing loop.
+ :param reasoning_suffix: Suffix to append for reasoning.
+ :param CHOSEN_LOCAL_MODEL_TYPE: Type of local model chosen.
+ :param progress: Gradio Progress object for tracking progress.
+ :return: A tuple containing consolidated results, mimicking the return structure of `extract_topics`.
+ """

acc_input_tokens = 0
acc_output_tokens = 0
@@ -1380,7 +1440,7 @@ def wrapper_extract_topics_per_column_value(
seg_join_files,
seg_reference_df_pivot,
seg_missing_df
- ) = extract_topics(
+ ) = extract_topics(
in_data_file=in_data_file,
file_data=filtered_file_data,
existing_topics_table=pd.DataFrame(), #acc_topics_table.copy(), # Pass the accumulated table
@@ -1389,19 +1449,17 @@ def wrapper_extract_topics_per_column_value(
unique_table_df_display_table_markdown="", # extract_topics will generate this
file_name=segment_file_name,
num_batches=current_num_batches,
- latest_batch_completed=current_latest_batch_completed, # Reset for each new segment's internal batching
- first_loop_state=current_first_loop_state, # True only for the very first iteration of wrapper
- out_message= list(), # Fresh for each call
- out_file_paths= list(),# Fresh for each call
- log_files_output_paths= list(),# Fresh for each call
- whole_conversation_metadata_str="", # Fresh for each call
- time_taken=0, # Time taken for this specific call, wrapper sums it.
- # Pass through other parameters
in_api_key=in_api_key,
temperature=temperature,
chosen_cols=chosen_cols,
model_choice=model_choice,
candidate_topics=candidate_topics,
+ latest_batch_completed=current_latest_batch_completed, # Reset for each new segment's internal batching
+ out_message= list(), # Fresh for each call
+ out_file_paths= list(),# Fresh for each call
+ log_files_output_paths= list(),# Fresh for each call
+ first_loop_state=current_first_loop_state, # True only for the very first iteration of wrapper
+ whole_conversation_metadata_str="", # Fresh for each call
initial_table_prompt=initial_table_prompt,
prompt2=prompt2,
prompt3=prompt3,
@@ -1411,6 +1469,7 @@ def wrapper_extract_topics_per_column_value(
number_of_prompts_used=number_of_prompts_used,
batch_size=batch_size,
context_textbox=context_textbox,
+ time_taken=0, # Time taken for this specific call, wrapper sums it.
sentiment_checkbox=sentiment_checkbox,
force_zero_shot_radio=force_zero_shot_radio,
in_excel_sheets=in_excel_sheets,
@@ -1422,11 +1481,13 @@ def wrapper_extract_topics_per_column_value(
aws_access_key_textbox=aws_access_key_textbox,
aws_secret_key_textbox=aws_secret_key_textbox,
hf_api_key_textbox=hf_api_key_textbox,
+ azure_api_key_textbox=azure_api_key_textbox,
max_tokens=max_tokens,
model_name_map=model_name_map,
max_time_for_loop=max_time_for_loop,
CHOSEN_LOCAL_MODEL_TYPE=CHOSEN_LOCAL_MODEL_TYPE,
- progress=progress,
+ reasoning_suffix=reasoning_suffix,
+ progress=progress
)

# Aggregate results
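In `extract_topics`, the client is now chosen from `model_name_map[model_choice]["source"]`, with the Gemini-named variables reused to hold the Azure client. A condensed sketch of that dispatch under those assumptions (not the full function; the constructors are passed in so the snippet stands alone):

```python
import os

def pick_llm_client(model_choice, model_name_map, in_api_key, azure_api_key_textbox,
                    temperature, system_prompt, max_tokens,
                    construct_gemini_generative_model, construct_azure_client,
                    azure_inference_endpoint):
    """Sketch of the per-source client dispatch added to extract_topics (names follow the diff)."""
    model_source = model_name_map[model_choice]["source"]
    if "Gemini" in model_source:
        return construct_gemini_generative_model(
            in_api_key=in_api_key, temperature=temperature, model_choice=model_choice,
            system_prompt=system_prompt, max_tokens=max_tokens)
    if "Azure" in model_source:
        if azure_api_key_textbox:
            # Make the key visible to later steps (e.g. summarisation) via the environment
            os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
        return construct_azure_client(in_api_key=azure_api_key_textbox,
                                      endpoint=azure_inference_endpoint)
    return None, {}  # Bedrock and local models are handled by separate code paths
```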
tools/llm_funcs.py CHANGED
@@ -13,12 +13,16 @@ from google.genai import types
import gradio as gr
from gradio import Progress

+ from azure.ai.inference import ChatCompletionsClient
+ from azure.core.credentials import AzureKeyCredential
+ from azure.ai.inference.models import SystemMessage, UserMessage
+
model_type = None # global variable setup
full_text = "" # Define dummy source text (full text) just to enable highlight function to load
model = list() # Define empty list for model functions to run
tokenizer = list() #[] # Define empty list for model functions to run

- from tools.config import AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_MIN_P, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS, SPECULATIVE_DECODING, NUM_PRED_TOKENS, USE_LLAMA_CPP, COMPILE_MODE, MODEL_DTYPE, USE_BITSANDBYTES, COMPILE_TRANSFORMERS, INT8_WITH_OFFLOAD_TO_CPU
+ from tools.config import AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_MIN_P, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS, SPECULATIVE_DECODING, NUM_PRED_TOKENS, USE_LLAMA_CPP, COMPILE_MODE, MODEL_DTYPE, USE_BITSANDBYTES, COMPILE_TRANSFORMERS, INT8_WITH_OFFLOAD_TO_CPU, AZURE_INFERENCE_ENDPOINT
from tools.prompts import initial_table_assistant_prefill

if SPECULATIVE_DECODING == "True": SPECULATIVE_DECODING = True
@@ -500,16 +504,7 @@ def llama_cpp_streaming(history, full_prompt, temperature=temperature):
def construct_gemini_generative_model(in_api_key: str, temperature: float, model_choice: str, system_prompt: str, max_tokens: int, random_seed=seed) -> Tuple[object, dict]:
"""
Constructs a GenerativeModel for Gemini API calls.
-
- Parameters:
- - in_api_key (str): The API key for authentication.
- - temperature (float): The temperature parameter for the model, controlling the randomness of the output.
- - model_choice (str): The choice of model to use for generation.
- - system_prompt (str): The system prompt to guide the generation.
- - max_tokens (int): The maximum number of tokens to generate.
-
- Returns:
- - Tuple[object, dict]: A tuple containing the constructed GenerativeModel and its configuration.
+ ...
"""
# Construct a GenerativeModel
try:
@@ -532,6 +527,31 @@ def construct_gemini_generative_model(in_api_key: str, temperature: float, model

return client, config

+ def construct_azure_client(in_api_key: str, endpoint: str) -> Tuple[object, dict]:
+ """
+ Constructs a ChatCompletionsClient for Azure AI Inference.
+ """
+ try:
+ key = None
+ if in_api_key:
+ key = in_api_key
+ elif os.environ.get("AZURE_INFERENCE_CREDENTIAL"):
+ key = os.environ["AZURE_INFERENCE_CREDENTIAL"]
+ elif os.environ.get("AZURE_API_KEY"):
+ key = os.environ["AZURE_API_KEY"]
+ if not key:
+ raise Warning("No Azure API key found.")
+
+ if not endpoint:
+ endpoint = os.environ.get("AZURE_INFERENCE_ENDPOINT", "")
+ if not endpoint:
+ raise Warning("No Azure inference endpoint found.")
+ client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(key))
+ return client, {}
+ except Exception as e:
+ print("Error constructing Azure ChatCompletions client:", e)
+ raise
+
def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice:str, bedrock_runtime:boto3.Session.client, assistant_prefill:str="") -> ResponseObject:
"""
This function sends a request to AWS Claude with the following parameters:
@@ -667,15 +687,10 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
duration = end_time - start_time
tokens_per_second = num_generated_tokens / duration

- # print("\n--- Inference Results ---")
- # print(f"System Prompt: {conversation[0]['content']}")
- # print(f"User Prompt: {conversation[1]['content']}")
- # print("---")
- # print(f"Assistant's Reply: {assistant_reply}")
- # print("\n--- Performance ---")
- # print(f"Time taken: {duration:.2f} seconds")
- # print(f"Generated tokens: {num_generated_tokens}")
- # print(f"Tokens per second: {tokens_per_second:.2f}")
+ print("\n--- Performance ---")
+ print(f"Time taken: {duration:.2f} seconds")
+ print(f"Generated tokens: {num_generated_tokens}")
+ print(f"Tokens per second: {tokens_per_second:.2f}")

return assistant_reply, num_input_tokens, num_generated_tokens

@@ -725,6 +740,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a

if i == number_of_api_retry_attempts:
return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
+
elif "AWS" in model_source:
for i in progress_bar:
try:
@@ -740,6 +756,35 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a

if i == number_of_api_retry_attempts:
return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
+ elif "Azure" in model_source:
+ for i in progress_bar:
+ try:
+ print("Calling Azure AI Inference model, attempt", i + 1)
+ # Use structured messages for Azure
+ response_raw = google_client.complete(
+ messages=[
+ SystemMessage(content=system_prompt),
+ UserMessage(content=prompt),
+ ],
+ model=model_choice
+ )
+ response_text = response_raw.choices[0].message.content
+ usage = getattr(response_raw, "usage", None)
+ input_tokens = 0
+ output_tokens = 0
+ if usage is not None:
+ input_tokens = getattr(usage, "input_tokens", getattr(usage, "prompt_tokens", 0))
+ output_tokens = getattr(usage, "output_tokens", getattr(usage, "completion_tokens", 0))
+ response = ResponseObject(
+ text=response_text,
+ usage_metadata={'inputTokens': input_tokens, 'outputTokens': output_tokens}
+ )
+ break
+ except Exception as e:
+ print("Call to Azure model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
+ time.sleep(timeout_wait)
+ if i == number_of_api_retry_attempts:
+ return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
elif "Local" in model_source:
# This is the local model
for i in progress_bar:
@@ -776,28 +821,29 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
# Check if is a LLama.cpp model response
if isinstance(response, ResponseObject):
response_text = response.text
- conversation_history.append({'role': 'assistant', 'parts': [response_text]})
elif 'choices' in response: # LLama.cpp model response
if "gpt-oss" in model_choice:
response_text = response['choices'][0]['message']['content'].split('<|start|>assistant<|channel|>final<|message|>')[1]
else:
response_text = response['choices'][0]['message']['content']
response_text = response_text.strip()
- conversation_history.append({'role': 'assistant', 'parts': [response_text]}) #response['choices'][0]['text']]})
elif model_source == "Gemini":
response_text = response.text
response_text = response_text.strip()
- conversation_history.append({'role': 'assistant', 'parts': [response_text]})
else: # Assume transformers model response
if "gpt-oss" in model_choice:
response_text = response.split('<|start|>assistant<|channel|>final<|message|>')[1]
else:
response_text = response
- conversation_history.append({'role': 'assistant', 'parts': [response_text]})
+
+ conversation_history.append({'role': 'assistant', 'parts': [response_text]})

return response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens

- def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model = list(), tokenizer=tokenizer, master:bool = False, assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
+ def process_requests(prompts: List[str],
+ system_prompt: str,
+ conversation_history: List[dict],
+ whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model = list(), tokenizer=tokenizer, master:bool = False, assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
"""
Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.

@@ -836,28 +882,27 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
whole_conversation.append(response_text)

# Create conversation metadata
- if master == False:
- whole_conversation_metadata.append(f"Batch {batch_no}:")
- else:
- #whole_conversation_metadata.append(f"Query summary metadata:")
- whole_conversation_metadata.append(f"Batch {batch_no}:")
+ # if master == False:
+ # whole_conversation_metadata.append(f"Batch {batch_no}:")
+ # else:
+ # #whole_conversation_metadata.append(f"Query summary metadata:")
+
+ whole_conversation_metadata.append(f"Batch {batch_no}:")

# if not isinstance(response, str):
try:
if "AWS" in model_source:
- #print("Extracting usage metadata from Converse API response...")
-
- # Using .get() is safer than direct access, in case a key is missing.
output_tokens = response.usage_metadata.get('outputTokens', 0)
input_tokens = response.usage_metadata.get('inputTokens', 0)
-
- #print(f"Extracted Token Counts - Input: {input_tokens}, Output: {output_tokens}")
-
- elif "Gemini" in model_source:
-
+
+ elif "Gemini" in model_source:
output_tokens = response.usage_metadata.candidates_token_count
input_tokens = response.usage_metadata.prompt_token_count

+ elif "Azure" in model_source:
+ input_tokens = response.usage_metadata.get('inputTokens', 0)
+ output_tokens = response.usage_metadata.get('outputTokens', 0)
+
elif "Local" in model_source:
if USE_LLAMA_CPP == "True":
output_tokens = response['usage'].get('completion_tokens', 0)
@@ -1012,20 +1057,8 @@ def calculate_tokens_from_metadata(metadata_string:str, model_choice:str, model_

# Regex to find the numbers following the keys in the "Query summary metadata" section
# This ensures we get the final, aggregated totals for the whole query.
- #if "Gemini" in model_source:
input_regex = r"input_tokens: (\d+)"
output_regex = r"output_tokens: (\d+)"
- # elif "AWS" in model_source:
- # input_regex = r"inputTokens: (\d+)"
- # output_regex = r"outputTokens: (\d+)"
- # elif "Local" in model_source:
- # print("Local model source")
- # input_regex = r"\'prompt_tokens\': (\d+)"
- # output_regex = r"\'completion_tokens\': (\d+)"
-
- #print("Metadata string:", metadata_string)
- #print("Input regex:", input_regex)
- #print("Output regex:", output_regex)

# re.findall returns a list of all matching strings (the captured groups).
input_token_strings = re.findall(input_regex, metadata_string)
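The Azure branch in `send_request` normalises the SDK's usage object into the same `inputTokens`/`outputTokens` keys the Bedrock path already uses, so `process_requests` can parse both responses identically. A minimal sketch of that normalisation, assuming an `azure-ai-inference` `ChatCompletions` response object:

```python
def azure_usage_to_metadata(response_raw) -> dict:
    """Sketch: map an azure-ai-inference ChatCompletions response onto the
    usage_metadata keys ('inputTokens'/'outputTokens') used by the AWS path."""
    usage = getattr(response_raw, "usage", None)
    input_tokens = output_tokens = 0
    if usage is not None:
        # Fall back between the two naming conventions, as the diff does
        input_tokens = getattr(usage, "input_tokens", getattr(usage, "prompt_tokens", 0))
        output_tokens = getattr(usage, "output_tokens", getattr(usage, "completion_tokens", 0))
    return {"inputTokens": input_tokens, "outputTokens": output_tokens}

# process_requests then reads Azure and Bedrock responses the same way:
# input_tokens = response.usage_metadata.get('inputTokens', 0)
# output_tokens = response.usage_metadata.get('outputTokens', 0)
```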