# board-recognizer / main.py
# wai572's picture
# init
# 27db1bc
# raw
# history blame
# 10.9 kB
# main.py
import io
import os
from ctypes import c_int, pointer, string_at
from datetime import datetime
import cv2
import numpy as np
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
import dds
from app import (
DEFAULT_THRESHOLDS,
arrange_data,
format_dds_data,
get_player_regions,
validate_deal,
)
from identify_cards import (
SUIT_TEMPLATE_PATH,
determine_and_correct_orientation,
find_rank_candidates,
get_suit_from_image_rules,
load_suit_templates,
save_img_with_rect,
)
from utils import convert2pbn, convert2pbn_txt, is_text_valid
# from app import arrange_data, run_dds_analysis # logic ported from the Gradio app.py
# Create the FastAPI instance.
app = FastAPI()
# The AI model and the suit templates are loaded when the application starts.
trocr_pipeline = None # filled in by the startup hook (was load_model() in the Gradio app)
suit_templates = None
@app.on_event("startup")
def load_dependencies():
    """Fill the module-level ``trocr_pipeline`` and ``suit_templates``.

    Registered as a FastAPI startup hook. If building the TrOCR pipeline
    fails, the error is printed and ``trocr_pipeline`` is left as ``None``;
    the suit templates are loaded afterwards either way.
    """
    global trocr_pipeline, suit_templates

    # Imported inside the hook (as in the Gradio load_model function this was
    # ported from), so transformers is only imported when the hook runs.
    from transformers import pipeline

    model_name = "microsoft/trocr-small-printed"
    ocr = None
    try:
        print("Loading TrOCR model...")
        ocr = pipeline("image-to-text", model=model_name)
        print("TrOCR model loaded.")
    except Exception as exc:
        print(f"Failed to load TrOCR model: {exc}")
        ocr = None
    trocr_pipeline = ocr

    # Load the suit templates.
    suit_templates = load_suit_templates("templates/suits/")
@app.post("/analyze/")
async def analyze_image(image_paths, progress=None):
    """Recognize the cards in the given board images and attach DDS tables.

    The work runs in three stages:

    1. For each image, detect the board, split it into player regions and
       collect rank-candidate crops.
    2. Run every crop through the TrOCR pipeline in one batch.
    3. Keep the crops whose text passes ``is_text_valid``, classify their
       colour, and hand them to ``arrange_data``.

    A double-dummy analysis is then attempted on the arranged results. If it
    fails, the recognition results are still returned, just without ``"dds"``.

    Args:
        image_paths: A path, or an iterable of paths, of image files to
            analyse. A single string/path is treated as one file instead of
            being iterated character by character.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``
            or ``progress(fraction, text)``. ``None`` or any non-callable
            value disables progress reporting.

    Returns:
        A ``JSONResponse`` wrapping the output of ``arrange_data``, where each
        entry whose ``"filename"`` has a DDS table carries it under ``"dds"``.
        When no rank candidates were found at all, the plain list of per-file
        error entries collected so far is returned instead.

    Raises:
        HTTPException: 503 when the TrOCR model is not loaded; 500 when the
            suit templates are missing, OCR fails, or any other unexpected
            error occurs.
    """
    # Make sure the model has been loaded by the startup hook.
    if trocr_pipeline is None:
        message = "AIモデルがまだ読み込まれていません。しばらく待ってから再度お試しください。"
        print(message)
        raise HTTPException(status_code=503, detail=message)

    report = progress if callable(progress) else _ignore_progress
    if isinstance(image_paths, (str, bytes, os.PathLike)):
        image_paths = [image_paths]
    else:
        image_paths = list(image_paths)

    all_results = []
    num_total_files = len(image_paths)
    report(0, desc="テンプレート画像読み込み中...")
    templates = load_suit_templates(SUIT_TEMPLATE_PATH)
    if not templates:
        raise HTTPException(
            status_code=500,
            detail=(
                f"エラー: {SUIT_TEMPLATE_PATH} "
                "フォルダにスートのテンプレート画像が見つかりません。"
            ),
        )

    margin = 200  # passed to get_player_regions for every image
    try:
        # --- Stage 1: collect rank candidates from every image ---
        all_candidates_global = []
        for i, image_path in enumerate(image_paths):
            report(
                (i + 1) / num_total_files * 0.15,
                desc="ステージ1/3: 文字候補を検出中...",
            )
            filename = os.path.basename(image_path)
            report(
                (i + 1) / num_total_files * 0.3,
                f"分析中 ({i+1}/{num_total_files}): {filename}",
            )
            try:
                image = _decode_image_file(image_path)
            except Exception as e:
                # A file that cannot be read or decoded only fails that file.
                all_results.append(
                    {"filename": filename, "error": f"画像読み込みエラー: {e}"}
                )
                continue

            print("detect board")
            rotated_image, box, scale = determine_and_correct_orientation(
                image, lambda msg: print(msg)
            )
            if box is None:
                all_results.append(
                    {"filename": filename, "error": "中央ボードの検出に失敗"}
                )
                continue
            print(box)
            save_img_with_rect("debug_rotated.jpg", rotated_image, [box])

            player_regions = get_player_regions(rotated_image, box, margin)
            for player, region in player_regions.items():
                for cand in find_rank_candidates(
                    region, templates, player, scale
                ):
                    cand["filename"] = filename
                    cand["player"] = player
                    all_candidates_global.append(cand)

        # --- Stage 2: OCR on all candidates in one batch ---
        report(
            0.4, desc="ステージ2/3: 文字認識を実行中... (時間がかかります)"
        )
        if not all_candidates_global:
            report(1, desc="認識する文字候補がありませんでした。")
            print("認識する文字候補がありませんでした。")
            return all_results  # only the entries for images that failed

        try:
            candidates_pil_images = [
                Image.fromarray(cv2.cvtColor(c["img"], cv2.COLOR_BGR2RGB))
                for c in all_candidates_global
            ]
            ocr_results = trocr_pipeline(candidates_pil_images)
        except Exception as e:
            # Without OCR output there is nothing to sort in stage 3, so stop
            # here instead of continuing with an unbound ``ocr_results``.
            message = f"OCR処理中にエラーが発生しました: {e}"
            print(message)
            raise HTTPException(status_code=500, detail=message) from e

        # --- Stage 3: sort the OCR output and identify the final cards ---
        report(0.9, desc="ステージ3/3: 認識結果を仕分け中...")
        print([result[0]["generated_text"] for result in ocr_results])
        raw_data = []
        for candidate_info, result in zip(all_candidates_global, ocr_results):
            raw_text = result[0]["generated_text"].upper().strip()
            text = is_text_valid(raw_text)
            print(raw_text, text)
            if text is None:
                continue
            print(
                f"--- 診断中: ランク '{text}' of {candidate_info['player']} at {candidate_info['pos']} with thick:{candidate_info['thickness']} ---"
            )
            color_name, avg_lab = get_suit_from_image_rules(
                candidate_info["no_pad"], DEFAULT_THRESHOLDS
            )
            print(color_name)
            if color_name == "mark":
                continue
            candidate_info["avg_lab"] = avg_lab
            candidate_info["color"] = color_name
            candidate_info["name"] = text
            raw_data.append(candidate_info)

        # NOTE(review): this replaces the per-file error entries collected in
        # stage 1, as the original code did - confirm whether they should be
        # merged into the response.
        all_results = arrange_data(raw_data)

        try:
            dataframes = run_dds_analysis(all_results)
        except Exception as e:
            # DDS is an add-on; keep the recognition result if it fails.
            print(f"DDS analysis skipped: {e}")
            dataframes = {}
        for result in all_results:
            if result["filename"] in dataframes:
                result["dds"] = dataframes[result["filename"]]
        return JSONResponse(content=all_results)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"致命的なエラー: {e}"
        ) from e


def _ignore_progress(*_args, **_kwargs):
    """Do nothing; stands in when no progress callback is supplied."""


def _decode_image_file(image_path):
    """Read *image_path* as raw bytes and decode it with OpenCV.

    Raises:
        OSError: if the file cannot be opened or read.
        ValueError: if OpenCV cannot decode the bytes into an image.
    """
    # Read the bytes ourselves and decode from memory rather than handing the
    # path to OpenCV.
    with open(image_path, "rb") as f:
        file_bytes = np.asarray(bytearray(f.read()), dtype=np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(
            "OpenCVが画像をデコードできませんでした。ファイルが破損しているか、非対応の形式の可能性があります。"
        )
    return image
def run_dds_analysis(all_results_state, progress=None):
    """Run the double-dummy solver on every valid deal in the results.

    Args:
        all_results_state: Iterable of result dicts. Only entries that have a
            ``"hands"`` key and whose hands pass ``validate_deal`` are
            analysed; each of those must also carry ``"filename"``.
        progress: Accepted so callers that pass a progress callback keep
            working; it is not used here.

    Returns:
        A dict mapping each analysed deal's filename to the rows returned by
        ``format_dds_data`` for that deal's result table.

    Raises:
        ValueError: if there is no valid deal to analyse.
        RuntimeError: if building the deals, the solver call, or result
            formatting fails. The original exception is chained as the cause.
    """
    valid_deals = [
        result
        for result in all_results_state
        if "hands" in result and validate_deal(result["hands"])[0]
    ]
    if not valid_deals:
        raise ValueError(
            "分析不可: 分析対象となる正常なディールがありません。"
        )

    try:
        deals = dds.ddTableDealsPBN()
        deals.noOfTables = len(valid_deals)
        for i, result in enumerate(valid_deals):
            pbn_deal_string = convert2pbn_txt(result["hands"], "N")
            print(pbn_deal_string)
            deals.deals[i].cards = pbn_deal_string.encode("utf-8")

        dds.SetMaxThreads(0)
        table_res = dds.ddTablesRes()
        per_res = dds.allParResults()
        # NOTE(review): mode 0 and an all-zero trump filter are passed as in
        # the original; presumably "no strain filtered out" - confirm against
        # the DDS documentation.
        res = dds.CalcAllTablesPBN(
            pointer(deals),
            0,
            (c_int * 5)(0, 0, 0, 0, 0),
            pointer(table_res),
            pointer(per_res),
        )
        print("dds")
        if res != dds.RETURN_NO_FAULT:
            err_char_p = dds.ErrorMessage(res)
            err_string = (
                string_at(err_char_p).decode("utf-8")
                if err_char_p
                else "Unknown error"
            )
            raise RuntimeError(
                f"DDS Solver failed with code: {res} ({err_string})"
            )
        print("dds")

        dataframes = {}
        for i, deal in enumerate(valid_deals):
            _headers, rows = format_dds_data(table_res.results[i].resTable)
            print(rows)
            dataframes[deal["filename"]] = rows
        return dataframes
    except Exception as e:
        raise RuntimeError(
            f"DDS分析エラー: 分析中にエラーが発生しました:\n{e}"
        ) from e