wai572's picture
Merge commit '04b24527b550f02d4cac5bfb7b811a2e45b5f9aa'
4580319
import math
import os
import tempfile
from ctypes import c_int, c_uint, pointer, string_at
from datetime import datetime
import cv2
import gradio as gr
import numpy as np
import pandas as pd
from gradio_modal import Modal
from PIL import Image
# from transformers import pipeline
import dds
from identify_cards import (
SUIT_TEMPLATE_PATH,
determine_and_correct_orientation,
find_rank_candidates,
get_suit_from_image_rules,
load_suit_templates,
save_img_with_rect,
)
from utils import (
arrange_hand,
convert2dup,
convert2pbn,
convert2pbn_board,
convert2pbn_txt,
convert2xhd,
is_text_valid,
)
# --- グローバル変数・設定 ---
trocr_pipeline = None
SUITS_BY_COLOR = {"black": "S", "green": "C", "red": "H", "orange": "D"}
VALID_RANKS = ["A", "K", "Q", "J", "T", "9", "8", "7", "6", "5", "4", "3", "2"]
PLAYER_ORDER = ["north", "east", "south", "west"]
SUIT_ORDER = {"S": 0, "H": 1, "D": 2, "C": 3}
RANK_ORDER = {
"A": 14,
"K": 13,
"Q": 12,
"J": 11,
"T": 10,
"9": 9,
"8": 8,
"7": 7,
"6": 6,
"5": 5,
"4": 4,
"3": 3,
"2": 2,
}
DEFAULT_THRESHOLDS = {
"L_black": 65.0,
"a_green": 126.0,
"a_red": 134.0,
"ba_black": -4.5,
"ab_black": 250.0,
"a_b_red": 9.0,
}
def load_model():
"""
TrOCRモデルをバックグラウンドで読み込む関数。
UIのロード完了後に demo.load() イベントで呼び出される。
"""
global trocr_pipeline
try:
if trocr_pipeline is None:
print("バックグラウンドでTrOCRモデルを読み込んでいます...")
trocr_pipeline = pipeline(
"image-to-text", model="microsoft/trocr-small-printed"
)
print("TrOCRの準備が完了しました。")
# UIコンポーネントを更新するための値を返す
return gr.update(
value="準備完了。画像を選択して分析を開始してください。"
), gr.update(interactive=True)
except Exception as e:
error_message = f"AIモデルの読み込みエラー: {e}"
print(error_message)
gr.Warning(
f"AIモデルの読み込みに失敗しました。分析機能は利用できません。詳細はログを確認してください。"
)
# エラーメッセージを表示し、分析ボタンは無効のままにする
return gr.update(value=error_message), gr.update(interactive=False)
def get_player_regions(img, box, margin):
bx, by, bw, bh = box
h, w, _ = img.shape
player_regions = {
"north": img[0:by, :],
"south": img[by + bh : h, :],
"west": img[by - margin : by + bh + margin, 0:bx],
"east": img[by - margin : by + bh + margin, bx + bw : w],
}
for player, region in player_regions.items():
if region is not None and region.size > 0:
if player == "north":
player_regions[player] = cv2.rotate(region, cv2.ROTATE_180)
elif player == "east":
player_regions[player] = cv2.rotate(
region, cv2.ROTATE_90_CLOCKWISE
)
elif player == "west":
player_regions[player] = cv2.rotate(
region, cv2.ROTATE_90_COUNTERCLOCKWISE
)
return player_regions
def arrange_data(raw_rank_data):
# 生データから最終的な手札を作成・表示
all_result = []
temp_hands = {} # ファイルごとの手札を一時保存
for rank_data in raw_rank_data:
filename = rank_data["filename"]
if filename not in temp_hands:
temp_hands[filename] = {p: [] for p in PLAYER_ORDER}
color_name = rank_data["color"]
suit = SUITS_BY_COLOR[color_name]
card_name = f"{suit}{rank_data['name']}"
temp_hands[filename][rank_data["player"]].append(card_name)
# 整形してall_resultsに格納
for filename, hands in temp_hands.items():
all_result.append(
{
"filename": filename,
"hands": {
player: arrange_hand(cards)
for player, cards in hands.items()
},
}
)
return all_result
def analyze_image_gradio(image_paths, progress=gr.Progress()):
global trocr_pipeline
# モデルが読み込まれているか確認
if trocr_pipeline is None:
gr.Warning(
"AIモデルがまだ読み込まれていません。しばらく待ってから再度お試しください。"
)
# 空の更新を返すことで、UIの状態を変えずに処理を終了
return (gr.update(),) * 11
all_results = []
num_total_files = len(image_paths)
progress(0, desc="テンプレート画像読み込み中...")
suit_templates = load_suit_templates(SUIT_TEMPLATE_PATH)
if not suit_templates:
raise gr.Error(
f"エラー: {SUIT_TEMPLATE_PATH} フォルダにスートのテンプレート画像が見つかりません。"
)
try:
all_candidates_global = []
processed_files_info = []
# image_objects = {}
for i, image_path in enumerate(image_paths):
progress(
(i + 1) / num_total_files * 0.15,
desc="ステージ1/3: 文字候補を検出中...",
)
filename = os.path.basename(image_path)
progress(
(i + 1) / num_total_files * 0.3,
f"分析中 ({i+1}/{num_total_files}): {filename}",
)
try:
# ファイルをバイナリモードで安全に読み込む
with open(image_path, "rb") as f:
# バイトデータをNumPy配列に変換
file_bytes = np.asarray(
bytearray(f.read()), dtype=np.uint8
)
# NumPy配列(メモリ上のデータ)から画像をデコード
image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
if image is None:
raise gr.Error(
"OpenCVが画像をデコードできませんでした。ファイルが破損しているか、非対応の形式の可能性があります。"
)
# image_objects[filename] = image
except Exception as e:
# ファイル読み込み自体のエラーをキャッチ
all_results.append(
{"filename": filename, "error": f"画像読み込みエラー: {e}"}
)
# image_objects[filename] = None
continue
# box = find_center_box(image)
print("detect board")
rotated_image, box, scale = determine_and_correct_orientation(
image, lambda msg: print(msg)
)
if box is None:
all_results.append(
{"filename": filename, "error": "中央ボードの検出に失敗"}
)
continue
print(box)
save_img_with_rect("debug_rotated.jpg", rotated_image, [box])
MARGIN = 200
player_regions = get_player_regions(rotated_image, box, MARGIN)
for player, region in player_regions.items():
candidates = find_rank_candidates(
region, suit_templates, player, scale
)
for cand in candidates:
cand["filename"] = filename
cand["player"] = player
all_candidates_global.append(cand)
processed_files_info.append({"filename": filename, "error": None})
progress(
0.4, desc="ステージ2/3: 文字認識を実行中... (時間がかかります)"
)
if not all_candidates_global or not trocr_pipeline:
progress(1, desc="認識する文字候補がありませんでした。")
print("認識する文字候補がありませんでした。")
return all_results # エラーがあった画像の結果だけを返す
try:
candidates_pil_images = [
Image.fromarray(cv2.cvtColor(c["img"], cv2.COLOR_BGR2RGB))
for c in all_candidates_global
]
ocr_results = trocr_pipeline(candidates_pil_images)
except Exception as e:
gr.Warning(f"OCR処理中にエラーが発生しました: {e}")
# --- ステージ3: 結果の仕分けと最終的なカードの特定 ---
progress(0.9, desc="ステージ3/3: 認識結果を仕分け中...")
print([result[0]["generated_text"] for result in ocr_results])
raw_data = []
# blacks = []
# reds = []
for i, result in enumerate(ocr_results):
text = result[0]["generated_text"].upper().strip()
print(text, is_text_valid(text))
text = is_text_valid(text)
if text is not None:
candidate_info = all_candidates_global[i]
print(
f"--- 診断中: ランク '{text}' of {candidate_info['player']} at {candidate_info['pos']} with thick:{candidate_info['thickness']} ---"
)
color_name, avg_lab = get_suit_from_image_rules(
candidate_info["no_pad"], DEFAULT_THRESHOLDS
)
print(color_name)
if color_name == "mark":
continue
candidate_info["avg_lab"] = avg_lab
candidate_info["color"] = color_name
candidate_info["name"] = text
raw_data.append(candidate_info)
# print("\r\n".join(blacks))
# print("\r\n".join(reds))
all_results = arrange_data(raw_data)
pbn_content = convert2pbn(all_results)
pbn_filename = f"analysis_{datetime.now().strftime('%Y%m%d')}.pbn"
# if processed_files_info:
# last_result = {"filename": processed_files_info[0]["filename"], 1ands": all_results[0][1ands"]}
if all_results:
# ダウンロード用にPBNコンテンツを値として設定し、表示状態にする
export_update = gr.update(interactive=True)
else:
export_update = gr.update(interactive=False)
final_result = all_results[0]["hands"]
filenames = [os.path.basename(p) for p in image_paths]
dropdown_update = gr.update(
choices=filenames, value=filenames[0], interactive=True, open=True
)
dataframes = run_dds_analysis(all_results, progress)
for result in all_results:
if result["filename"] in dataframes.keys():
result["dds"] = dataframes[result["filename"]]
return (
*display_selected_result(filenames[0], all_results),
all_results,
dropdown_update,
export_update,
)
except Exception as e:
raise gr.Error(f"致命的なエラー: {e}")
def display_selected_result(selected_filename, all_results):
"""ドロップダウンで選択されたファイルの結果を表示する"""
result = next(
(r for r in all_results if r["filename"] == selected_filename), None
)
output_hands = {p: "" for p in PLAYER_ORDER}
dds_df_update = gr.update(value=None)
if result and "hands" in result and result["hands"]:
for player, hand in result["hands"].items():
output_hands[player] = ", ".join(hand) if hand else "(なし)"
dds_visible = True
dds_df_update = gr.update(value=result.get("dds"), visible=True)
elif result and "error" in result:
# エラーがあった場合、最初のTextboxにエラーメッセージを表示
output_hands["north"] = f"エラー: {result['error']}"
return (
output_hands["north"],
output_hands["south"],
output_hands["west"],
output_hands["east"],
dds_df_update,
)
def validate_deal(hands):
if not hands:
return False, "分析対象のカードデータがありません"
total_cards = []
# 手札が合計52枚あるかチェック
for player, hand in hands.items():
if len(hand) != 13:
return (
False,
f"エラー: {player.capitalize()}の手札が13枚ではありません",
)
for card in hand:
if card in total_cards:
return (
False,
f"エラー: 重複したカードが検出されました ({card})",
)
total_cards.append(card)
return True, "デックは正常です"
def format_dds_data(table):
headers = [
"Declarer",
"NT",
"Spades ♠",
"Hearts ♥",
"Diamonds ♦",
"Clubs ♣",
]
rows = [
[
"North",
table[4][0],
table[0][0],
table[1][0],
table[2][0],
table[3][0],
],
[
"South",
table[4][2],
table[0][2],
table[1][2],
table[2][2],
table[3][2],
],
[
"East",
table[4][1],
table[0][1],
table[1][1],
table[2][1],
table[3][1],
],
[
"West",
table[4][3],
table[0][3],
table[1][3],
table[2][3],
table[3][3],
],
]
return headers, rows
def run_dds_analysis(all_results_state, progress=gr.Progress()):
"""ダブルダミー分析を実行する"""
valid_deals = []
for result in all_results_state:
if "hands" in result:
is_valid, _ = validate_deal(result["hands"])
if is_valid:
valid_deals.append(result)
if len(valid_deals) == 0:
raise gr.Error(
"分析不可", "分析対象となる正常なディールがありません。"
)
# self.status_var.set(f"{len(valid_deals)}件のディールを分析中...")
try:
deals = dds.ddTableDealsPBN()
deals.noOfTables = len(valid_deals)
for i, result in enumerate(valid_deals):
pbn_deal_string = convert2pbn_txt(result["hands"], "N")
print(pbn_deal_string)
# table_deal_pbn = dds.ddTableDealPBN()
# table_deal_pbn.cards = pbn_deal_string.encode("utf-8")
deals.deals[i].cards = pbn_deal_string.encode("utf-8")
dds.SetMaxThreads(0)
table_res = dds.ddTablesRes()
per_res = dds.allParResults()
# table_res_pointer = pointer(table_res)
res = dds.CalcAllTablesPBN(
pointer(deals),
0,
(c_int * 5)(0, 0, 0, 0, 0),
pointer(table_res),
pointer(per_res),
)
print("dds")
if res != dds.RETURN_NO_FAULT:
err_char_p = dds.ErrorMessage(res)
err_string = (
string_at(err_char_p).decode("utf-8")
if err_char_p
else "Unknown error"
)
raise RuntimeError(
f"DDS Solver failed with code: {res} ({err_string})"
)
print("dds")
filenames = [d["filename"] for d in valid_deals]
dataframes = {}
for i, filename in enumerate(filenames):
headers, rows = format_dds_data(table_res.results[i].resTable)
print(rows)
dataframes[filename] = pd.DataFrame(rows, columns=headers)
return dataframes
# 3. 結果を新しいウィンドウで表示
except Exception as e:
raise gr.Error(f"DDS分析エラー: 分析中にエラーが発生しました:\n{e}")
# self.status_var.set("DDS分析中にエラーが発生しました。")
def prepare_export_files(all_results):
"""エクスポートボタンが押されたときに各形式のファイルを生成し、ダウンロードボタンを返す"""
if not all_results:
gr.Warning("エクスポート対象のデータがありません。")
return (
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=True),
)
filename_base, _ = os.path.splitext(all_results[0]["filename"])
# --- PBN ---
pbn_content = convert2pbn(all_results)
with tempfile.NamedTemporaryFile(
delete=False, mode="w", suffix=".pbn", encoding="utf-8"
) as f:
f.write(pbn_content)
pbn_path = f.name
# --- XHD ---
xhd_content = convert2xhd(all_results, filename_base)
with tempfile.NamedTemporaryFile(
delete=False,
mode="w",
suffix=".xhd",
encoding="shift_jis",
errors="ignore",
) as f:
f.write(xhd_content)
xhd_path = f.name
# --- DUP ---
dup_content = convert2dup(all_results, None)
with tempfile.NamedTemporaryFile(
delete=False, mode="w", suffix=".dup", encoding="utf-8"
) as f:
f.write(dup_content)
dup_path = f.name
return (
gr.update(value=pbn_path, visible=True),
gr.update(value=xhd_path, visible=True),
gr.update(value=dup_path, visible=True),
gr.update(visible=True), # モーダルを表示
)
# --- Gradio UIの定義 ---
with gr.Blocks(
theme=gr.themes.Soft(), css="footer {visibility: hidden}"
) as demo:
# 状態を保持するための非表示コンポーネント
all_results_state = gr.State([])
raw_data_state = gr.State([])
current_result_state = gr.State(None)
gr.Markdown("# Bridge Card Recognizer")
gr.Markdown(
"カメラで撮影したブリッジのプレイ中の写真から、各プレイヤーの手札を自動で認識します。"
)
with gr.Row():
with gr.Column(scale=2):
image_input = gr.File(
label="画像ファイルを選択",
file_count="multiple",
file_types=["image"],
type="filepath",
)
analyze_button = gr.Button(
"分析開始", variant="primary", interactive=False
)
export_button = gr.Button(
"結果をエクスポート", interactive=False
) # 新しいエクスポートボタン
status_label = gr.Label(
label="ステータス",
value="準備完了。画像を選択して分析を開始してください。",
)
if not trocr_pipeline:
gr.Warning(
"OCRモデルの読み込みに失敗しました。分析機能は利用できません。"
)
with gr.Column(scale=3):
results_dropdown = gr.Dropdown(
label="表示するファイルを選択", interactive=False
)
gr.Markdown("### 認識結果")
with gr.Row():
north_box = gr.Textbox(label="North", interactive=False)
south_box = gr.Textbox(label="South", interactive=False)
with gr.Row():
west_box = gr.Textbox(label="West", interactive=False)
east_box = gr.Textbox(label="East", interactive=False)
with gr.Row():
# dds_button = gr.Button("ダブルダミー分析 (DDS)", visible=False)
# debugger_button = gr.Button("カラーデバッガー", visible=False)
export_file = gr.File(
label="ダウンロード", visible=False, interactive=False
)
with gr.Accordion("ダブルダミー分析 結果", open=False):
dds_output_df = gr.DataFrame(
label="最適プレイ手数", visible=False
)
# エクスポート用モーダル
with Modal(visible=False) as export_modal:
gr.Markdown("### エクスポート形式を選択してください")
gr.Markdown(
"ボタンをクリックすると、対応する形式のファイルがダウンロードされます。"
)
with gr.Row():
pbn_dl_btn = gr.DownloadButton("PBN形式 (.pbn)", variant="primary")
xhd_dl_btn = gr.DownloadButton(
"XHD形式 (.xhd)", variant="secondary"
)
dup_dl_btn = gr.DownloadButton(
"DUP形式 (.dup)", variant="secondary"
)
# --- イベントリスナー ---
demo.load(
fn=load_model, inputs=None, outputs=[status_label, analyze_button]
)
analyze_button.click(
fn=analyze_image_gradio,
inputs=[image_input],
outputs=[
north_box,
south_box,
west_box,
east_box,
dds_output_df,
all_results_state,
results_dropdown,
export_button,
],
)
results_dropdown.change(
fn=display_selected_result,
inputs=[results_dropdown, all_results_state],
outputs=[north_box, south_box, west_box, east_box, dds_output_df],
)
export_button.click(
fn=prepare_export_files,
inputs=[all_results_state],
outputs=[pbn_dl_btn, xhd_dl_btn, dup_dl_btn, export_modal],
)
# dds_button.click(
# fn=run_dds_analysis,
# inputs=[all_results_state],
# outputs=[dds_output_df],
# )
if __name__ == "__main__":
demo.launch(debug=True)