File size: 3,411 Bytes
b1f2bdd
a5ece8b
aa97025
b1f2bdd
 
d162c32
 
 
 
887b999
d162c32
0b8ba87
b1f2bdd
0e84c33
d162c32
 
 
 
 
 
 
 
b1f2bdd
 
 
 
 
d162c32
 
0b8ba87
b1f2bdd
02d55fb
d162c32
 
 
 
 
 
 
 
 
 
 
 
 
b1f2bdd
d162c32
b1f2bdd
d162c32
 
b1f2bdd
 
d162c32
 
 
 
 
 
b1f2bdd
d162c32
 
 
02d55fb
d162c32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1f2bdd
d162c32
 
b1f2bdd
d162c32
 
 
 
 
 
 
 
79c9d08
e784f1e
d162c32
 
 
b1f2bdd
 
d162c32
 
b1f2bdd
e784f1e
 
79c9d08
b1f2bdd
0e84c33
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import logging
import os

import gradio as gr
import pandas as pd
import tensorflow as tf

# TAPAS imports
from tapas.protos import interaction_pb2
from tapas.utils import number_annotation_utils, tf_example_utils, prediction_utils
from tapas.scripts.run_task_main import get_classifier_model, get_task_config

# 1) Load the CSV with every cell kept as its original text.
#    Reading with dtype=str (rather than parsing and then calling .astype(str))
#    skips pandas' numeric inference, which would otherwise:
#      - render missing cells as the literal string "nan",
#      - turn "100" into "100.0" in any integer column that contains a blank,
#      - drop leading zeros / trailing decimals ("007" -> "7", "1.50" -> "1.5").
#    keep_default_na=False leaves blank cells as "" instead of NaN.
df = pd.read_csv("synthetic_profit.csv", dtype=str, keep_default_na=False)

# 2) Build the "list of lists" table TAPAS consumes:
#    table[0] is the header row, table[1:] are the data rows (all strings).
table = [list(df.columns)]
table.extend(df.values.tolist())

# 3) TAPAS example converter.
#    One limit is shared by sequence length, column ids and row ids.
#    strip_column_names=False keeps the CSV headers verbatim in the model input.
#    NOTE(review): add_aggregation_candidates=True is meant to help SUM/AVG
#    questions, but the SQA checkpoint loaded below is presumably trained
#    without aggregation operators — confirm the flag has any effect here.
_MAX_SEQ_LENGTH = 512

config = tf_example_utils.ClassifierConversionConfig(
    vocab_file="tapas_sqa_base/vocab.txt",
    max_seq_length=_MAX_SEQ_LENGTH,
    max_row_id=_MAX_SEQ_LENGTH,
    max_column_id=_MAX_SEQ_LENGTH,
    add_aggregation_candidates=True,
    strip_column_names=False,
)
converter = tf_example_utils.ToClassifierTensorflowExample(config)

# 4) Load the pretrained SQA checkpoint.
#    The arguments mirror the flags run_task_main.py takes with --mode=predict.
#    The vocab file and sequence length are read from the converter config, so
#    the model and the converter always agree.
#    NOTE(review): confirm get_task_config / get_classifier_model exist with
#    this signature in the installed tapas version, and that
#    "model.ckpt-0" matches the checkpoint files shipped in tapas_sqa_base/.
_task_kwargs = {
    "task": "sqa",
    "init_checkpoint": "tapas_sqa_base/model.ckpt-0",
    "vocab_file": config.vocab_file,
    "bsz": 1,
    "max_seq_length": config.max_seq_length,
}
task_config = get_task_config(**_task_kwargs)
model, tokenizer = get_classifier_model(task_config)

# 5) Convert a single (table, query) pair into a TAPAS example.
def make_tf_example(table, query):
    """Build a TAPAS classifier example for one question over one table.

    Args:
        table: Header row followed by data rows: ``table[0]`` holds the column
            names and ``table[1:]`` the cell values. Every header and cell is
            passed through ``str()``, because the Interaction proto's ``text``
            fields only accept strings.
        query: The natural-language question.

    Returns:
        The result of ``converter.convert`` for question 0 of the built
        Interaction. In TAPAS this is a ``tf.train.Example`` object, not a
        serialized byte string — confirm against the installed version.

    Raises:
        ValueError: If ``table`` is empty (no header row). The converter may
            raise as well, e.g. when the table does not fit the configured
            sequence length.
    """
    if not table:
        raise ValueError("table must contain at least a header row")

    interaction = interaction_pb2.Interaction()
    # a) The single question. The id follows the "<example>-<annotator>_<position>"
    #    shape used in the TAPAS SQA prediction notebook. It is assumed that
    #    TAPAS' question-id handling expects this shape — confirm.
    question = interaction.questions.add()
    question.id = "0-0_0"
    question.original_text = query
    # b) Columns (header row).
    for header in table[0]:
        interaction.table.columns.add().text = str(header)
    # c) Data rows.
    for row_vals in table[1:]:
        row = interaction.table.rows.add()
        for cell in row_vals:
            row.cells.add().text = str(cell)
    # d) Numeric annotation of the interaction (intended to help SUM/AVG-style
    #    questions).
    number_annotation_utils.add_numeric_values(interaction)
    # e) convert() takes the index of the question to convert. Calling it
    #    without the index raises TypeError.
    return converter.convert(interaction, 0)

# 6) Run TAPAS on one question and map predicted coordinates back to cells.
def predict_answer(query):
    """Answer *query* against the module-level ``table``.

    Returns:
        The selected cell values joined with ", " in row-major order. Returns
        "No answer found." when the model yields no prediction or selects no
        cells. No aggregation (sum/average) is applied to the selected cells.

    Raises:
        Whatever conversion or prediction raises. ``answer_fn`` is the error
        boundary for the UI.
    """
    example = make_tf_example(table, query)
    # NOTE(review): confirm tf_example_utils.input_fn_builder exists in the
    # installed tapas version and accepts these keyword arguments.
    input_fn = tf_example_utils.input_fn_builder(
        [example],
        is_training=False,
        drop_remainder=False,
        batch_size=1,
        seq_length=config.max_seq_length,
    )
    # tf.estimator.Estimator.predict returns a lazy generator, and a generator
    # cannot be indexed with [0]. next(iter(...)) works for both generators and
    # lists.
    prediction = next(iter(model.predict(input_fn)), None)
    if prediction is None:
        return "No answer found."

    # Coordinates are (row, column) pairs over the data rows only. Sorting
    # reports the answer in table order regardless of how they were parsed.
    coords = sorted(
        prediction_utils.parse_coordinates(prediction["answer_coordinates"])
    )
    # table[0] is the header row, so data row r lives at table[r + 1].
    answers = [table[r + 1][c] for r, c in coords]
    return ", ".join(answers) if answers else "No answer found."

# 7) Gradio callback: errors are reported in the UI instead of crashing it.
def answer_fn(question: str) -> str:
    """Return the model's answer for *question*, or a readable message.

    Blank or missing input is rejected without running the model. Any
    exception from prediction is logged with its full traceback, so it stays
    visible in the server logs. It is then shown to the user as
    "❌ Error: <message>".
    """
    if not question or not question.strip():
        return "Please enter a question."
    try:
        return predict_answer(question)
    except Exception as e:  # UI boundary: report instead of propagating
        logging.getLogger(__name__).exception("TAPAS prediction failed")
        return f"❌ Error: {e}"

# Gradio UI wiring: one multi-line question box in, one answer box out.
# NOTE(review): the description promises sum/average answers, but
# predict_answer only joins the selected cell values and applies no
# aggregation — consider rewording the description.
# NOTE(review): newer Gradio releases replace allow_flagging with
# flagging_mode — confirm against the pinned Gradio version.
_DESCRIPTION = (
    "Uses TAPAS’s Interaction + Converter APIs with aggregation candidates "
    "and numeric annotations to reliably answer sum/average queries."
)
_question_box = gr.Textbox(lines=2, label="Your question")
_answer_box = gr.Textbox(label="Answer")

iface = gr.Interface(
    fn=answer_fn,
    inputs=_question_box,
    outputs=_answer_box,
    title="SAP Profitability Q&A (TAPAS Low-Level)",
    description=_DESCRIPTION,
    allow_flagging="never",
)

def _serve() -> None:
    """Launch the Gradio app bound to all interfaces (0.0.0.0) on port 7860."""
    iface.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    _serve()