|
import os |
|
import re |
|
import json |
|
import requests |
|
from typing import List, Dict, Optional, Tuple |
|
import gradio as gr |
|
from googlesearch import search |
|
import google.generativeai as genai |
|
from google.generativeai.types import HarmCategory, HarmBlockThreshold |
|
|
|
def initialize_gemini(api_key: str):
    """Configure the Gemini SDK with *api_key* and build the classification model.

    Args:
        api_key: Google Gemini API key, passed straight to ``genai.configure``.

    Returns:
        A ``genai.GenerativeModel`` for ``gemini-1.5-flash`` with a low sampling
        temperature, a 1024-token output cap, and every listed harm-category
        filter set to ``BLOCK_NONE``.
    """
    genai.configure(api_key=api_key)

    # All four listed harm categories are left unfiltered. Presumably this is so
    # companies in sensitive industries are not refused by the model.
    # NOTE(review): confirm this relaxation is intended.
    unfiltered_categories = (
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )

    return genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=dict(
            temperature=0.2,  # low randomness: we want one stable classification
            top_p=0.8,
            top_k=40,
            max_output_tokens=1024,
        ),
        safety_settings={
            category: HarmBlockThreshold.BLOCK_NONE for category in unfiltered_categories
        },
    )
|
|
|
# Two-digit NAICS sector prefixes (2017/2022 editions). Every valid 6-digit
# NAICS code starts with one of these. Any other 6-digit number scraped from a
# page (IDs, timestamps, phone or postal fragments) cannot be a NAICS code.
_NAICS_SECTOR_PREFIXES = frozenset({
    "11", "21", "22", "23", "31", "32", "33", "42", "44", "45", "48", "49",
    "51", "52", "53", "54", "55", "56", "61", "62", "71", "72", "81", "92",
})

# A standalone run of exactly six digits.
_SIX_DIGIT_RUN = re.compile(r'\b\d{6}\b')


def _extract_naics_codes(text: str) -> List[str]:
    """Return the distinct NAICS-shaped codes found in *text*, in first-seen order."""
    plausible = (
        code for code in _SIX_DIGIT_RUN.findall(text)
        if code[:2] in _NAICS_SECTOR_PREFIXES
    )
    # dict.fromkeys de-duplicates while keeping a deterministic order (a set would not).
    return list(dict.fromkeys(plausible))


def google_search_naics(company_name: str, max_candidates: int = 10) -> List[str]:
    """
    Find potential NAICS codes for a company using multiple targeted Google searches.

    Runs several query phrasings and fetches up to three result pages per query,
    downloading each distinct URL at most once. From each page it scrapes every
    6-digit number whose first two digits form a valid NAICS sector. Codes are
    ranked by how many distinct pages mention them, so the most corroborated
    candidates come first.

    Args:
        company_name: Name of the company to search for.
        max_candidates: Maximum number of codes to return (default 10).

    Returns:
        Up to ``max_candidates`` candidate codes, most frequently seen first.
        This is best effort: per-URL and per-query failures are printed and
        skipped. An empty list is returned if nothing usable was found or an
        unexpected error occurs.
    """
    from collections import Counter

    queries = [
        f"NAICS code for {company_name}",
        f"what is {company_name} company NAICS code",
        f"{company_name} business entity NAICS classification",
        f"{company_name} industry classification NAICS",
        f"{company_name} company information NAICS"
    ]

    page_counts = Counter()  # code -> number of distinct pages mentioning it
    fetched_urls = set()

    try:
        print(f"🔎 Searching Google for NAICS codes for '{company_name}'...")

        for query in queries:
            print(f" Query: {query}")
            try:
                search_results = search(query, stop=3, pause=2)

                for result_url in search_results:
                    # Different phrasings often return the same page. Fetching it
                    # again wastes time and would double-count its codes.
                    if result_url in fetched_urls:
                        continue
                    fetched_urls.add(result_url)

                    try:
                        response = requests.get(result_url, timeout=5)
                        if response.status_code == 200:
                            found_codes = _extract_naics_codes(response.text)
                            page_counts.update(found_codes)
                            if found_codes:
                                print(f" Found codes in {result_url}: {found_codes}")
                    except Exception as e:
                        print(f" ⚠️ Error fetching {result_url}: {e}")
            except Exception as e:
                print(f" ⚠️ Error with query '{query}': {e}")

        # most_common breaks ties by first-seen order, so the result is
        # deterministic, unlike slicing an unordered set.
        return [code for code, _ in page_counts.most_common(max_candidates)]
    except Exception as e:
        print(f"❌ Error performing Google search: {str(e)}")
        return []
|
|
|
# Markdown-tolerant section label. Gemini often wraps labels in bold
# ("**REASONING:**", "**REASONING**:") or headings ("## REASONING:"). Without
# this tolerance the "**" markers leak into the captured section text.
_CLASSIFICATION_LABEL = r'[*#]*\s*{}\s*\**\s*:\s*\**'
_RESEARCH_LABEL = _CLASSIFICATION_LABEL.format('RESEARCH')
_REASONING_LABEL = _CLASSIFICATION_LABEL.format('REASONING')
_NAICS_CODE_LABEL = _CLASSIFICATION_LABEL.format('NAICS_CODE')
# Exactly six digits that are not part of a longer number.
_NAICS_CODE_DIGITS = r'(?<!\d)\d{6}(?!\d)'


def _build_classification_prompt(company_name: str, context: str, candidates: List[str]) -> str:
    """Return the Gemini prompt. The candidate variant also asks for a RESEARCH section."""
    if candidates:
        return f"""
You are a NAICS code classification expert. Based on the company information provided and the NAICS code candidates found from Google search, determine the most appropriate NAICS code.

Company Name: {company_name}
Context Information: {context}

NAICS Code Candidates from Google Search: {candidates}

First, research what these NAICS codes represent:
1. For each NAICS code candidate, briefly explain what industry or business activity it corresponds to.
2. Then explain which industry classification best matches this company based on the name and context provided.
3. Finally, select the single most appropriate NAICS code from the candidates, or suggest a different one if none match.

Your response should be in this format:
RESEARCH: [Brief explanation of what each NAICS code represents]
REASONING: [Your detailed reasoning about why the chosen industry classification is most appropriate for this company]
NAICS_CODE: [6-digit NAICS code]
"""
    return f"""
You are a NAICS code classification expert. Based on the company information provided, determine the most appropriate NAICS code.

Company Name: {company_name}
Context Information: {context}

First, analyze what industry this company likely belongs to based on its name and the provided context.
Consider standard business classifications and determine the most appropriate category.
Then provide the single most appropriate 6-digit NAICS code.

Your response should be in this format:
REASONING: [Your detailed reasoning about the company's industry classification, including what business activities it likely performs]
NAICS_CODE: [6-digit NAICS code]
"""


def _parse_classification_response(response_text: str) -> dict:
    """Split a Gemini reply into ``research`` (optional), ``reasoning`` and ``naics_code``."""
    flags = re.DOTALL | re.IGNORECASE
    result = {}

    research_match = re.search(_RESEARCH_LABEL + r'(.*?)' + _REASONING_LABEL, response_text, flags)
    if research_match:
        result["research"] = research_match.group(1).strip()

    # Reasoning runs up to the NAICS_CODE label. If the model omitted that
    # label, it runs to the end of the reply instead of being discarded.
    reasoning_match = re.search(
        _REASONING_LABEL + r'(.*?)(?:' + _NAICS_CODE_LABEL + r'|\Z)', response_text, flags
    )
    reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
    result["reasoning"] = reasoning or "No reasoning provided."

    naics_code = None
    code_labels = list(re.finditer(_NAICS_CODE_LABEL, response_text, re.IGNORECASE))
    if code_labels:
        code_match = re.search(_NAICS_CODE_DIGITS, response_text[code_labels[-1].end():])
        if code_match:
            naics_code = code_match.group(0)
    if naics_code is None:
        # No labelled answer. The requested format puts the answer last, while
        # earlier numbers are usually candidates discussed in RESEARCH, so take
        # the last 6-digit number rather than the first.
        all_codes = re.findall(_NAICS_CODE_DIGITS, response_text)
        naics_code = all_codes[-1] if all_codes else "000000"
    result["naics_code"] = naics_code

    return result


def get_naics_classification(model, company_name: str, context: str, candidates: List[str]) -> dict:
    """
    Use Gemini AI to determine the most appropriate NAICS code for a company.

    The model is asked to give its reasoning first and then a labelled
    ``NAICS_CODE``. When ``candidates`` is non-empty, it is also asked to
    explain each candidate in a RESEARCH section.

    Args:
        model: Object exposing ``generate_content(prompt)`` whose result has a
            ``.text`` attribute (e.g. the model from ``initialize_gemini``).
        company_name: Name of the company being classified.
        context: Free-text description of the company (may be empty).
        candidates: Candidate codes found by web search (may be empty).

    Returns:
        Dict with ``naics_code`` (``"000000"`` when none could be found) and
        ``reasoning``. ``research`` is added when the reply has a RESEARCH
        section. Any ``Exception`` is caught and reported in ``reasoning``
        instead of being raised.
    """
    try:
        print("🤖 AI is analyzing NAICS classification...")
        prompt = _build_classification_prompt(company_name, context, candidates)
        response = model.generate_content(prompt)
        return _parse_classification_response(response.text.strip())
    except Exception as e:
        print(f"❌ Error getting NAICS classification: {str(e)}")
        return {
            "naics_code": "000000",
            "reasoning": f"Error analyzing company: {str(e)}"
        }
|
|
|
def find_naics_code(company_name: str, context: str = "", api_key: Optional[str] = None) -> Dict:
    """
    Core function to find the NAICS code for a company, callable from any interface.

    Pipeline: validate inputs, initialize Gemini, scrape candidate codes from
    Google results, then ask Gemini to pick one.

    Args:
        company_name: Name of the company. Blank or whitespace-only names are rejected.
        context: Brief description of the company (optional).
        api_key: Google Gemini API key. If falsy, the GEMINI_API_KEY
            environment variable is used.

    Returns:
        On success, the classification dict (``naics_code``, ``reasoning``,
        optionally ``research``) plus ``company_name``, ``context`` and
        ``candidates``.
        On a setup failure (missing name, missing key, or Gemini init error),
        a dict with ``error``, ``naics_code`` = ``"000000"`` and ``reasoning``.
    """
    # Reject blank names up front. Otherwise five near-empty Google queries,
    # page scraping and an LLM call would run for nothing.
    if not company_name or not company_name.strip():
        return {
            "error": "No company name provided.",
            "naics_code": "000000",
            "reasoning": "Error: company name missing"
        }

    if not api_key:
        api_key = os.environ.get('GEMINI_API_KEY')
    if not api_key:
        return {
            "error": "No API key provided. Set GEMINI_API_KEY environment variable or pass as parameter.",
            "naics_code": "000000",
            "reasoning": "Error: API key missing"
        }

    try:
        model = initialize_gemini(api_key)
    except Exception as e:
        return {
            "error": f"Failed to initialize Gemini API: {str(e)}",
            "naics_code": "000000",
            "reasoning": f"Error: {str(e)}"
        }

    naics_candidates = google_search_naics(company_name)
    if naics_candidates:
        print(f"Found {len(naics_candidates)} NAICS candidates: {naics_candidates}")
    else:
        print("No NAICS codes found from Google search.")

    # get_naics_classification chooses its prompt based on whether the
    # candidate list is empty, so a single call covers both cases.
    result = get_naics_classification(model, company_name, context, naics_candidates)

    result["company_name"] = company_name
    result["context"] = context
    result["candidates"] = naics_candidates

    return result
|
|
|
|
|
def classify_company(company_name: str, company_description: str, api_key: Optional[str] = None) -> Tuple[str, str, str]:
    """Process inputs from Gradio and return formatted Markdown results.

    Args:
        company_name: Company name from the UI. Blank or whitespace-only names are rejected.
        company_description: Optional free-text description.
        api_key: Key from the UI. An empty textbox arrives as a falsy value, and
            ``find_naics_code`` then falls back to GEMINI_API_KEY.

    Returns:
        ``(naics_markdown, research_markdown, reasoning_markdown)``. On input or
        setup errors, the first element is an ``"Error: ..."`` message and the
        other two are empty.
    """
    if not company_name or not company_name.strip():
        return "Error: Company name is required", "", ""

    result = find_naics_code(company_name, company_description, api_key)

    # Surface setup failures (e.g. missing API key) instead of showing a bare
    # "000000" code that hides the actionable message.
    if result.get("error"):
        return f"Error: {result['error']}", "", ""

    naics_code = f"**NAICS Code: {result.get('naics_code', '000000')}**"

    research = ""
    if result.get("research"):
        research = f"## Research on NAICS Codes\n\n{result['research']}"

    reasoning = f"## Analysis\n\n{result.get('reasoning', 'No reasoning provided.')}"

    return naics_code, research, reasoning
|
|
|
|
|
def create_gradio_interface():
    """Build the NAICS Code Finder Gradio Blocks UI.

    Inputs (company name, optional description, optional API key) go in the left
    column. The right column holds three Markdown outputs (code, research,
    reasoning), filled by ``classify_company``. The API-key box is hidden when
    GEMINI_API_KEY is already set in the environment.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(title="NAICS Code Finder") as demo:
        gr.Markdown("# NAICS Code Finder")
        gr.Markdown("Enter a company name and optional description to find the most appropriate NAICS code.")

        with gr.Row():
            with gr.Column():
                company_name = gr.Textbox(label="Company Name", placeholder="Enter company name")
                company_description = gr.Textbox(label="Company Description (optional)", placeholder="Brief description of the company")
                api_key = gr.Textbox(
                    label="Gemini API Key (optional)",
                    placeholder="Enter your API key or set GEMINI_API_KEY env variable",
                    # Mask the secret so it is not shown in clear text on screen.
                    type="password",
                    visible=not bool(os.environ.get('GEMINI_API_KEY'))
                )
                submit_btn = gr.Button("Find NAICS Code")

            with gr.Column():
                naics_output = gr.Markdown(label="NAICS Code")
                research_output = gr.Markdown(label="Research")
                reasoning_output = gr.Markdown(label="Reasoning")

        submit_btn.click(
            classify_company,
            inputs=[company_name, company_description, api_key],
            outputs=[naics_output, research_output, reasoning_output]
        )

        gr.Examples(
            [
                ["Apple Inc", "Tech company that makes iPhones and computers"],
                ["Starbucks", "Coffee shop chain"],
                ["Bank of America", "Banking and financial services"],
                ["Tesla", "Electric vehicle manufacturer"]
            ],
            inputs=[company_name, company_description]
        )

    return demo
|
|
|
|
|
# Build the app at import time so code that imports this module can reach
# `demo` by name (presumably for hosted or auto-reload runners; confirm).
demo = create_gradio_interface()


def _launch_demo() -> None:
    """Start the Gradio server for the module-level ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    _launch_demo()