from sqlalchemy import create_engine, Table, Column, Integer, Float, Text, TIMESTAMP, MetaData, text
from sqlalchemy.dialects.postgresql import UUID
from llama_index.core import SQLDatabase
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.llms.huggingface import HuggingFaceLLM
import logging
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# PostgreSQL DB connection (converted from JDBC)
engine = create_engine("postgresql+psycopg2://postgres:password@0.tcp.ngrok.io:5434/postgres")
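# The ngrok host/port and credentials above come from the original script; in a
# deployed Space they should be read from environment variables or secrets
# rather than hard-coded.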
metadata_obj = MetaData()
# Define the machine_current_log table
machine_current_log_table = Table(
    "machine_current_log",
    metadata_obj,
    Column("mac", Text, primary_key=True),
    Column("created_at", TIMESTAMP(timezone=True), primary_key=True),
    Column("CT1", Float),
    Column("CT2", Float),
    Column("CT3", Float),
    Column("CT_Avg", Float),
    Column("total_current", Float),
    Column("state", Text),
    Column("state_duration", Integer),
    Column("fault_status", Text),
    Column("fw_version", Text),
    Column("machineId", UUID),
    Column("hi", Text),
)
# Create the table
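# Note: create_all() skips tables that already exist (checkfirst defaults to
# True), so re-running this script will not recreate or clobber the table.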
metadata_obj.create_all(engine)
# Convert to TimescaleDB hypertable
with engine.connect() as conn:
    conn.execute(text("SELECT create_hypertable('machine_current_log', 'created_at', if_not_exists => TRUE);"))
    print("TimescaleDB hypertable created")
    conn.commit()
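# Optional sanity check, a minimal sketch with illustrative values (not real
# telemetry): insert one sample reading so the queries below return at least
# one row. Leave the flag False against a production database.
INSERT_SAMPLE_ROW = False
if INSERT_SAMPLE_ROW:
    import uuid
    from datetime import datetime, timezone
    with engine.begin() as conn:
        conn.execute(
            machine_current_log_table.insert().values(
                mac="00:11:22:33:44:55",
                created_at=datetime.now(timezone.utc),
                CT1=1.2, CT2=1.3, CT3=1.1, CT_Avg=1.2,
                total_current=3.6,
                state="running",
                state_duration=60,
                fault_status="none",
                fw_version="1.0.0",
                machineId=str(uuid.uuid4()),
                hi="test",
            )
        )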
# Query 1: Get all MAC addresses
print("\nQuerying all MAC addresses:")
with engine.connect() as con:
    rows = con.execute(text("SELECT mac FROM machine_current_log"))
    for row in rows:
        print(row)
# Query 2: Get all data and count
print("\nQuerying all data and count:")
# Mixed-case columns were created quoted by SQLAlchemy, so they must also be
# double-quoted in raw SQL.
stmt = text("""
    SELECT mac, created_at, "CT1", "CT2", "CT3", "CT_Avg",
           total_current, state, state_duration, fault_status,
           fw_version, "machineId"
    FROM machine_current_log
""")
with engine.connect() as connection:
    count_stmt = text("SELECT COUNT(*) FROM machine_current_log")
    count = connection.execute(count_stmt).scalar()
    print(f"Total number of rows in table: {count}")
    results = connection.execute(stmt).fetchall()
    print(results)
# Set up LlamaIndex natural language querying
sql_database = SQLDatabase(engine)
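# HuggingFaceLLM loads the model weights locally; zephyr-7b-beta needs roughly
# 15 GB of GPU memory in fp16 (rough estimate), so make sure the host has a
# suitable GPU or swap in a smaller instruct model.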
llm = HuggingFaceLLM(
    model_name="HuggingFaceH4/zephyr-7b-beta",
    tokenizer_name="HuggingFaceH4/zephyr-7b-beta",
    context_window=2048,
    max_new_tokens=256,
    # do_sample is required for temperature/top_p to take effect in HF generate
    generate_kwargs={"do_sample": True, "temperature": 0.7, "top_p": 0.95},
)
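# NLSQLTableQueryEngine translates a natural-language question into SQL over
# the listed table(s), executes it, and has the LLM synthesize a text answer.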
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["machine_current_log"],
    llm=llm,
)
def natural_language_query(question: str):
    try:
        response = query_engine.query(question)
        return str(response)
    except Exception as e:
        logger.error(f"Query error: {e}")
        return f"Error processing query: {str(e)}"
if __name__ == "__main__":
    # Natural language query examples
    print("\nNatural Language Query Examples:")
    questions = [
        "What is the average CT1 reading?",
        "Which machine has the highest total current?",
        "Show me the latest fault status for each machine",
    ]
    for question in questions:
        print(f"\nQuestion: {question}")
        print("Answer:", natural_language_query(question))