chandini2595 commited on
Commit
f825473
·
1 Parent(s): 0545f86

Update: added handwritten extraction and product list storage

Browse files
Files changed (1) hide show
  1. app.py +66 -13
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
4
  from PIL import Image
5
  import io
6
  import json
@@ -22,6 +22,9 @@ from torch.utils.tensorboard import SummaryWriter
22
  import matplotlib.pyplot as plt
23
  from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
24
  import matplotlib
 
 
 
25
  matplotlib.use('Agg')
26
 
27
  # Configure logging
@@ -94,32 +97,45 @@ def extract_fields(image_path):
94
  return results
95
 
96
  def extract_products(text):
97
- # This pattern matches lines like: "1076903 PISTACHIO 14.49"
98
- product_pattern = r"\d{6,} ([A-Z0-9 ]+) (\d+\.\d{2})"
 
99
  matches = re.findall(product_pattern, text)
100
- products = [{"name": name.strip(), "price": float(price)} for name, price in matches]
 
 
 
 
 
 
 
 
 
 
 
101
  return products
102
 
103
  def extract_with_perplexity_llm(ocr_text):
104
  prompt = f"""
105
  You are an expert at extracting structured data from receipts.
106
 
107
- From the following OCR text, extract these fields and return them as a flat JSON object with exactly these keys:
108
  - name (customer name)
109
  - date (date of purchase)
110
- - amount_paid (total amount paid, or price if only one product)
111
  - receipt_no (receipt number)
112
- - product (the main product name, as a string; if multiple products, pick the most expensive or the only one)
113
-
114
- **Note:** If the receipt has only one product, set 'product' to its name and 'amount_paid' to its price. If there is a 'price' and an 'amount paid', treat them as the same if they are equal.
115
 
116
  Example output:
117
  {{
118
  "name": "Mrs. Genevieve Lopez",
119
  "date": "12/13/2024",
120
- "amount_paid": 579.18,
121
  "receipt_no": "042085",
122
- "product": "Wireless Airpods"
 
 
 
123
  }}
124
 
125
  Text:
@@ -142,9 +158,32 @@ Text:
142
  )
143
  return response.choices[0].message.content
144
 
 
 
 
 
 
 
 
 
 
 
145
  def save_to_dynamodb(data, table_name="Receipts"):
146
- # ... existing code ...
147
- # data["products"] is a list of dicts
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  table.put_item(Item=data)
149
 
150
  def merge_extractions(regex_fields, llm_fields):
@@ -312,5 +351,19 @@ def main():
312
  else:
313
  st.info("Confusion matrix not found.")
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  if __name__ == "__main__":
316
  main()
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification, TrOCRProcessor, VisionEncoderDecoderModel
4
  from PIL import Image
5
  import io
6
  import json
 
22
  import matplotlib.pyplot as plt
23
  from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
24
  import matplotlib
25
+ import boto3
26
+ from decimal import Decimal
27
+ import uuid
28
  matplotlib.use('Agg')
29
 
30
  # Configure logging
 
97
  return results
98
 
99
  def extract_products(text):
100
+ # Pattern to match product lines with quantity, name, and price
101
+ # Example: "2 PISTACHIO 14.49" or "1076903 PISTACHIO 14.49"
102
+ product_pattern = r"(?:(\d+)\s+)?([A-Z0-9 ]+)\s+(\d+\.\d{2})"
103
  matches = re.findall(product_pattern, text)
104
+
105
+ products = []
106
+ for match in matches:
107
+ quantity, name, price = match
108
+ product = {
109
+ "name": name.strip(),
110
+ "price": float(price),
111
+ "quantity": int(quantity) if quantity else 1,
112
+ "total": float(price) * (int(quantity) if quantity else 1)
113
+ }
114
+ products.append(product)
115
+
116
  return products
117
 
118
  def extract_with_perplexity_llm(ocr_text):
119
  prompt = f"""
120
  You are an expert at extracting structured data from receipts.
121
 
122
+ From the following OCR text, extract these fields and return them as a JSON object with exactly these keys:
123
  - name (customer name)
124
  - date (date of purchase)
125
+ - amount_paid (total amount paid)
126
  - receipt_no (receipt number)
127
+ - products (a list of all products, each with name, price, and quantity if available)
 
 
128
 
129
  Example output:
130
  {{
131
  "name": "Mrs. Genevieve Lopez",
132
  "date": "12/13/2024",
133
+ "amount_paid": 29.69,
134
  "receipt_no": "042085",
135
+ "products": [
136
+ {{"name": "Orange Juice", "price": 2.15, "quantity": 1}},
137
+ {{"name": "Apples", "price": 3.50, "quantity": 1}}
138
+ ]
139
  }}
140
 
141
  Text:
 
158
  )
159
  return response.choices[0].message.content
160
 
161
+ def convert_floats_to_decimal(obj):
162
+ if isinstance(obj, float):
163
+ return Decimal(str(obj))
164
+ elif isinstance(obj, dict):
165
+ return {k: convert_floats_to_decimal(v) for k, v in obj.items()}
166
+ elif isinstance(obj, list):
167
+ return [convert_floats_to_decimal(i) for i in obj]
168
+ else:
169
+ return obj
170
+
171
  def save_to_dynamodb(data, table_name="Receipts"):
172
+ dynamodb = boto3.resource('dynamodb')
173
+ table = dynamodb.Table(table_name)
174
+
175
+ # Calculate total amount if not provided
176
+ if "products" in data and not data.get("amount_paid"):
177
+ total = sum(product["total"] for product in data["products"])
178
+ data["amount_paid"] = total
179
+
180
+ # Convert all float values to Decimal for DynamoDB
181
+ data = convert_floats_to_decimal(data)
182
+
183
+ # Generate receipt number if not present
184
+ if not data.get("receipt_no"):
185
+ data["receipt_no"] = str(uuid.uuid4())
186
+
187
  table.put_item(Item=data)
188
 
189
  def merge_extractions(regex_fields, llm_fields):
 
351
  else:
352
  st.info("Confusion matrix not found.")
353
 
354
+ # Load model and processor
355
+ processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
356
+ model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
357
+
358
+ # Load your image (crop to handwritten region if possible)
359
+ image = Image.open('handwritten_sample.jpg').convert("RGB")
360
+
361
+ # Preprocess and predict
362
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values
363
+ generated_ids = model.generate(pixel_values)
364
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
365
+
366
+ print("Handwritten text:", generated_text)
367
+
368
  if __name__ == "__main__":
369
  main()