Spaces:
Running
Running
# main.py | |
import io | |
import os | |
from ctypes import c_int, pointer, string_at | |
from datetime import datetime | |
from typing import List | |
import cv2 | |
import numpy as np | |
from fastapi import FastAPI, HTTPException, Request, UploadFile | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from paddleocr import PaddleOCR | |
from PIL import Image, ImageEnhance | |
import dds | |
from app import ( | |
DEFAULT_THRESHOLDS, | |
arrange_data, | |
format_dds_data, | |
get_player_regions, | |
validate_deal, | |
) | |
from identify_cards import ( | |
SUIT_TEMPLATE_PATH, | |
determine_and_correct_orientation, | |
find_rank_candidates, | |
get_not_white_mask, | |
get_suit_from_image_rules, | |
load_suit_templates, | |
save_img_with_rect, | |
) | |
from utils import convert2pbn, convert2pbn_txt, is_text_valid | |
# FastAPI application object; the CORS middleware below is attached to it.
app = FastAPI()

# Origins whose browsers may call this API (CORS allow-list).
origins = [
    # Local development: bare localhost and the Vite dev server default port.
    "http://localhost",
    "http://localhost:5173",
    # Deployed frontends. NOTE(review): the first entry looks like a
    # deployment-specific Vercel URL — confirm it is still the one in use.
    "https://board-recognizer-30ib6veo9-wai572s-projects.vercel.app",
    "https://board-recognizer.vercel.app",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],  # any HTTP method
    allow_headers=["*"],  # any request header
    allow_credentials=True,  # browsers may include cookies with requests
)

# Process-wide state. `reader` holds the PaddleOCR instance once
# load_ocr_model() has run. `suit_templates` is initialised to None here
# (analyze_image loads its own local copy of the templates).
reader: "PaddleOCR | None" = None
suit_templates = None
def load_ocr_model():
    """Create the PaddleOCR reader and store it in the module-global ``reader``.

    Meant to be called once when the application starts. ``analyze_image``
    checks that ``reader`` is not ``None`` before doing any work.

    NOTE(review): nothing visible in this file registers this function as a
    startup hook (an earlier ``@app.on_event("startup")`` hook had been
    commented out). Confirm it is invoked somewhere; otherwise ``reader``
    stays ``None`` for every request.
    """
    global reader
    # English recognition only; the document-orientation, document-unwarping
    # and text-line-orientation sub-models are all switched off.
    reader = PaddleOCR(
        lang="en",
        use_doc_orientation_classify=False,
        use_doc_unwarping=False,
        use_textline_orientation=False,
    )
    print("PaddleOCR model loaded successfully.")
def _decode_upload(file_bytes):
    """Decode the raw bytes of an uploaded file into an OpenCV image.

    Args:
        file_bytes: the complete file contents, as returned by ``UploadFile.read()``.

    Returns:
        The image decoded with ``cv2.IMREAD_COLOR``.

    Raises:
        ValueError: if the upload is empty or OpenCV cannot decode it.
    """
    if not file_bytes:
        # Reject empty uploads with a clear message instead of relying on how
        # cv2.imdecode treats an empty buffer.
        raise ValueError("ファイルが空です。")
    file_array = np.asarray(bytearray(file_bytes), dtype=np.uint8)
    image = cv2.imdecode(file_array, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(
            "OpenCVが画像をデコードできませんでした。ファイルが破損しているか、非対応の形式の可能性があります。"
        )
    return image


def _read_rank_text(candidate_img):
    """OCR one rank-candidate crop and return its first recognized string.

    The crop's contrast is boosted by a factor of 2.0 before it is passed to
    the module-global PaddleOCR ``reader``.

    Returns:
        The first entry of ``rec_texts`` from the first prediction, ``""`` if
        that prediction recognized no text, or ``None`` if ``predict``
        returned no prediction at all.
    """
    enhanced = ImageEnhance.Contrast(Image.fromarray(candidate_img)).enhance(2.0)
    predictions = reader.predict(np.asarray(enhanced))
    if len(predictions) == 0:
        return None
    texts = predictions[0]["rec_texts"]
    return texts[0] if len(texts) > 0 else ""


async def analyze_image(image_paths: list[UploadFile]):
    """Recognize the bridge deal in each uploaded photo and run DDS on it.

    Stages:
      1. Decode every upload, correct its orientation, locate the central
         board and collect rank-character candidates per player region.
      2. OCR each candidate with the module-global PaddleOCR ``reader``.
      3. Keep candidates whose text passes ``is_text_valid`` and whose colour
         class is not ``"mark"``, group them with ``arrange_data`` and attach
         the double-dummy table for every deal the analysis covered.

    NOTE(review): no route decorator is visible on this function, and the
    OpenCV / OCR calls below are synchronous, so they run on the event-loop
    thread; consider ``starlette.concurrency.run_in_threadpool``.

    Args:
        image_paths: the uploaded image files (``UploadFile`` objects, despite
            the parameter name).

    Returns:
        ``JSONResponse`` whose content is a list of dicts: the entries produced
        by ``arrange_data`` (with a ``"dds"`` key when the double-dummy
        analysis covered that file), followed by one
        ``{"filename": ..., "error": ...}`` entry per file that could not be
        decoded or whose board was not found.

    Raises:
        HTTPException: 503 while the OCR model is not loaded; 500 when the
            suit templates are missing or processing fails unexpectedly.
    """
    board_margin = 200  # passed to get_player_regions; presumably pixels — confirm

    def progress(fraction, desc):
        # Progress is only logged to stdout.
        print(fraction, desc)

    print(image_paths)
    if reader is None:
        raise HTTPException(
            status_code=503,
            detail="AIモデルがまだ読み込まれていません。しばらく待ってから再度お試しください。",
        )

    num_total_files = len(image_paths)
    progress(0, desc="テンプレート画像読み込み中...")
    suit_templates = load_suit_templates(SUIT_TEMPLATE_PATH)
    if not suit_templates:
        raise HTTPException(
            status_code=500,
            detail=f"エラー: {SUIT_TEMPLATE_PATH} フォルダにスートのテンプレート画像が見つかりません。",
        )

    try:
        # --- Stage 1: detect rank-character candidates in every image ---
        file_errors = []
        all_candidates = []
        for i, upload in enumerate(image_paths):
            progress(
                (i + 1) / num_total_files * 0.15,
                desc="ステージ1/3: 文字候補を検出中...",
            )
            # UploadFile.filename may be None; fall back to a positional name
            # so every per-file result still has a key.
            filename = os.path.basename(upload.filename or f"image_{i + 1}")
            progress(
                (i + 1) / num_total_files * 0.3,
                f"分析中 ({i+1}/{num_total_files}): {filename}",
            )
            try:
                image = _decode_upload(await upload.read())
            except Exception as e:
                # Report the unreadable file and keep processing the others.
                print(e)
                file_errors.append(
                    {"filename": filename, "error": f"画像読み込みエラー: {e}"}
                )
                continue

            print("detect board")
            rotated_image, box, scale = determine_and_correct_orientation(
                image, lambda msg: print(msg)
            )
            if box is None:
                file_errors.append(
                    {"filename": filename, "error": "中央ボードの検出に失敗"}
                )
                continue
            print(box)
            # NOTE(review): every request overwrites this single debug file.
            save_img_with_rect("debug_rotated.jpg", rotated_image, [box])

            player_regions = get_player_regions(rotated_image, box, board_margin)
            for player, region in player_regions.items():
                for cand in find_rank_candidates(
                    region, suit_templates, player, scale
                ):
                    cand["filename"] = filename
                    cand["player"] = player
                    all_candidates.append(cand)

        # --- Stage 2: OCR every candidate ---
        progress(0.4, desc="ステージ2/3: 文字認識を実行中... (時間がかかります)")
        if not all_candidates:
            progress(1, desc="認識する文字候補がありませんでした。")
            print("認識する文字候補がありませんでした。")
            # Only the per-file errors (if any) can be reported.
            return JSONResponse(content=file_errors)

        # Keep each OCR result paired with the candidate it came from, so that
        # skipping one candidate can never shift the others out of alignment.
        recognized = []
        for candidate in all_candidates:
            try:
                raw_text = _read_rank_text(candidate["img"])
            except Exception as e:
                print(f"OCR処理中にエラーが発生しました: {e}")
                continue
            if raw_text is not None:
                recognized.append((candidate, raw_text))

        # --- Stage 3: validate ranks, classify colours, assemble hands ---
        progress(0.9, desc="ステージ3/3: 認識結果を仕分け中...")
        raw_data = []
        for candidate_info, raw_text in recognized:
            text = is_text_valid(raw_text)
            print(raw_text, text)
            if text is None:
                continue
            print(
                f"--- 診断中: ランク '{text}' of {candidate_info['player']} at {candidate_info['pos']} with thick:{candidate_info['thickness']} ---"
            )
            color_name, avg_lab = get_suit_from_image_rules(
                candidate_info["no_pad"], DEFAULT_THRESHOLDS
            )
            print(color_name)
            if color_name == "mark":
                continue
            candidate_info["avg_lab"] = avg_lab
            candidate_info["color"] = color_name
            candidate_info["name"] = text
            raw_data.append(candidate_info)

        results = list(arrange_data(raw_data))
        try:
            dataframes = run_dds_analysis(results)
        except Exception as e:
            # DDS is an add-on: still return the recognized hands when no
            # deal is valid or the solver fails.
            print(f"DDS分析をスキップしました: {e}")
            dataframes = {}
        for result in results:
            if result["filename"] in dataframes:
                result["dds"] = dataframes[result["filename"]]

        # Without this, the per-file failures from stage 1 would be lost.
        results.extend(file_errors)
        return JSONResponse(content=results)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"致命的なエラー: {e}") from e
def run_dds_analysis(all_results_state):
    """Run a double-dummy (DDS) analysis on every valid deal.

    Args:
        all_results_state: iterable of per-file result dicts. Only entries
            that have a ``"hands"`` key and pass ``validate_deal`` are
            analysed; those entries must also carry a ``"filename"``.

    Returns:
        dict mapping each analysed entry's filename to the rows returned by
        ``format_dds_data`` for its DDS result table. If two entries share a
        filename, the later one wins.

    Raises:
        ValueError: if no entry holds a valid deal.
        RuntimeError: if the DDS solver reports an error or anything else
            fails during the analysis.
    """
    from ctypes import create_string_buffer

    valid_deals = [
        result
        for result in all_results_state
        if "hands" in result and validate_deal(result["hands"])[0]
    ]
    if not valid_deals:
        raise ValueError("分析不可: 分析対象となる正常なディールがありません。")

    try:
        deals = dds.ddTableDealsPBN()
        deals.noOfTables = len(valid_deals)
        for i, result in enumerate(valid_deals):
            pbn_deal_string = convert2pbn_txt(result["hands"], "N")
            print(pbn_deal_string)
            deals.deals[i].cards = pbn_deal_string.encode("utf-8")

        dds.SetMaxThreads(0)
        table_res = dds.ddTablesRes()
        par_res = dds.allParResults()
        res = dds.CalcAllTablesPBN(
            pointer(deals),
            0,  # mode 0: par scores with no side vulnerable (per the DDS API docs)
            (c_int * 5)(0, 0, 0, 0, 0),  # trumpFilter: 0 keeps every strain
            pointer(table_res),
            pointer(par_res),
        )
        if res != dds.RETURN_NO_FAULT:
            # The DDS C API is `void ErrorMessage(int code, char line[80])`:
            # it fills a caller-supplied buffer rather than returning a string.
            line = create_string_buffer(80)
            dds.ErrorMessage(res, line)
            err_string = line.value.decode("utf-8", errors="replace") or "Unknown error"
            raise RuntimeError(f"DDS Solver failed with code: {res} ({err_string})")

        # table_res.results[i] belongs to valid_deals[i] (same order as loaded).
        return {
            result["filename"]: format_dds_data(table_res.results[i].resTable)[1]
            for i, result in enumerate(valid_deals)
        }
    except Exception as e:
        raise RuntimeError(f"DDS分析エラー: 分析中にエラーが発生しました:\n{e}") from e