import math import os import tempfile from ctypes import c_int, c_uint, pointer, string_at from datetime import datetime import cv2 import gradio as gr import numpy as np import pandas as pd from gradio_modal import Modal from PIL import Image # from transformers import pipeline import dds from identify_cards import ( SUIT_TEMPLATE_PATH, determine_and_correct_orientation, find_rank_candidates, get_suit_from_image_rules, load_suit_templates, save_img_with_rect, ) from utils import ( arrange_hand, convert2dup, convert2pbn, convert2pbn_board, convert2pbn_txt, convert2xhd, is_text_valid, ) # --- グローバル変数・設定 --- trocr_pipeline = None SUITS_BY_COLOR = {"black": "S", "green": "C", "red": "H", "orange": "D"} VALID_RANKS = ["A", "K", "Q", "J", "T", "9", "8", "7", "6", "5", "4", "3", "2"] PLAYER_ORDER = ["north", "east", "south", "west"] SUIT_ORDER = {"S": 0, "H": 1, "D": 2, "C": 3} RANK_ORDER = { "A": 14, "K": 13, "Q": 12, "J": 11, "T": 10, "9": 9, "8": 8, "7": 7, "6": 6, "5": 5, "4": 4, "3": 3, "2": 2, } DEFAULT_THRESHOLDS = { "L_black": 65.0, "a_green": 126.0, "a_red": 134.0, "ba_black": -4.5, "ab_black": 250.0, "a_b_red": 9.0, } def load_model(): """ TrOCRモデルをバックグラウンドで読み込む関数。 UIのロード完了後に demo.load() イベントで呼び出される。 """ global trocr_pipeline try: if trocr_pipeline is None: print("バックグラウンドでTrOCRモデルを読み込んでいます...") trocr_pipeline = pipeline( "image-to-text", model="microsoft/trocr-small-printed" ) print("TrOCRの準備が完了しました。") # UIコンポーネントを更新するための値を返す return gr.update( value="準備完了。画像を選択して分析を開始してください。" ), gr.update(interactive=True) except Exception as e: error_message = f"AIモデルの読み込みエラー: {e}" print(error_message) gr.Warning( f"AIモデルの読み込みに失敗しました。分析機能は利用できません。詳細はログを確認してください。" ) # エラーメッセージを表示し、分析ボタンは無効のままにする return gr.update(value=error_message), gr.update(interactive=False) def get_player_regions(img, box, margin): bx, by, bw, bh = box h, w, _ = img.shape player_regions = { "north": img[0:by, :], "south": img[by + bh : h, :], "west": img[by - margin : by + bh + margin, 0:bx], "east": img[by - margin : by + bh + margin, bx + bw : w], } for player, region in player_regions.items(): if region is not None and region.size > 0: if player == "north": player_regions[player] = cv2.rotate(region, cv2.ROTATE_180) elif player == "east": player_regions[player] = cv2.rotate( region, cv2.ROTATE_90_CLOCKWISE ) elif player == "west": player_regions[player] = cv2.rotate( region, cv2.ROTATE_90_COUNTERCLOCKWISE ) return player_regions def arrange_data(raw_rank_data): # 生データから最終的な手札を作成・表示 all_result = [] temp_hands = {} # ファイルごとの手札を一時保存 for rank_data in raw_rank_data: filename = rank_data["filename"] if filename not in temp_hands: temp_hands[filename] = {p: [] for p in PLAYER_ORDER} color_name = rank_data["color"] suit = SUITS_BY_COLOR[color_name] card_name = f"{suit}{rank_data['name']}" temp_hands[filename][rank_data["player"]].append(card_name) # 整形してall_resultsに格納 for filename, hands in temp_hands.items(): all_result.append( { "filename": filename, "hands": { player: arrange_hand(cards) for player, cards in hands.items() }, } ) return all_result def analyze_image_gradio(image_paths, progress=gr.Progress()): global trocr_pipeline # モデルが読み込まれているか確認 if trocr_pipeline is None: gr.Warning( "AIモデルがまだ読み込まれていません。しばらく待ってから再度お試しください。" ) # 空の更新を返すことで、UIの状態を変えずに処理を終了 return (gr.update(),) * 11 all_results = [] num_total_files = len(image_paths) progress(0, desc="テンプレート画像読み込み中...") suit_templates = load_suit_templates(SUIT_TEMPLATE_PATH) if not suit_templates: raise gr.Error( f"エラー: {SUIT_TEMPLATE_PATH} フォルダにスートのテンプレート画像が見つかりません。" ) try: all_candidates_global = [] processed_files_info = [] # image_objects = {} for i, image_path in enumerate(image_paths): progress( (i + 1) / num_total_files * 0.15, desc="ステージ1/3: 文字候補を検出中...", ) filename = os.path.basename(image_path) progress( (i + 1) / num_total_files * 0.3, f"分析中 ({i+1}/{num_total_files}): {filename}", ) try: # ファイルをバイナリモードで安全に読み込む with open(image_path, "rb") as f: # バイトデータをNumPy配列に変換 file_bytes = np.asarray( bytearray(f.read()), dtype=np.uint8 ) # NumPy配列(メモリ上のデータ)から画像をデコード image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if image is None: raise gr.Error( "OpenCVが画像をデコードできませんでした。ファイルが破損しているか、非対応の形式の可能性があります。" ) # image_objects[filename] = image except Exception as e: # ファイル読み込み自体のエラーをキャッチ all_results.append( {"filename": filename, "error": f"画像読み込みエラー: {e}"} ) # image_objects[filename] = None continue # box = find_center_box(image) print("detect board") rotated_image, box, scale = determine_and_correct_orientation( image, lambda msg: print(msg) ) if box is None: all_results.append( {"filename": filename, "error": "中央ボードの検出に失敗"} ) continue print(box) save_img_with_rect("debug_rotated.jpg", rotated_image, [box]) MARGIN = 200 player_regions = get_player_regions(rotated_image, box, MARGIN) for player, region in player_regions.items(): candidates = find_rank_candidates( region, suit_templates, player, scale ) for cand in candidates: cand["filename"] = filename cand["player"] = player all_candidates_global.append(cand) processed_files_info.append({"filename": filename, "error": None}) progress( 0.4, desc="ステージ2/3: 文字認識を実行中... (時間がかかります)" ) if not all_candidates_global or not trocr_pipeline: progress(1, desc="認識する文字候補がありませんでした。") print("認識する文字候補がありませんでした。") return all_results # エラーがあった画像の結果だけを返す try: candidates_pil_images = [ Image.fromarray(cv2.cvtColor(c["img"], cv2.COLOR_BGR2RGB)) for c in all_candidates_global ] ocr_results = trocr_pipeline(candidates_pil_images) except Exception as e: gr.Warning(f"OCR処理中にエラーが発生しました: {e}") # --- ステージ3: 結果の仕分けと最終的なカードの特定 --- progress(0.9, desc="ステージ3/3: 認識結果を仕分け中...") print([result[0]["generated_text"] for result in ocr_results]) raw_data = [] # blacks = [] # reds = [] for i, result in enumerate(ocr_results): text = result[0]["generated_text"].upper().strip() print(text, is_text_valid(text)) text = is_text_valid(text) if text is not None: candidate_info = all_candidates_global[i] print( f"--- 診断中: ランク '{text}' of {candidate_info['player']} at {candidate_info['pos']} with thick:{candidate_info['thickness']} ---" ) color_name, avg_lab = get_suit_from_image_rules( candidate_info["no_pad"], DEFAULT_THRESHOLDS ) print(color_name) if color_name == "mark": continue candidate_info["avg_lab"] = avg_lab candidate_info["color"] = color_name candidate_info["name"] = text raw_data.append(candidate_info) # print("\r\n".join(blacks)) # print("\r\n".join(reds)) all_results = arrange_data(raw_data) pbn_content = convert2pbn(all_results) pbn_filename = f"analysis_{datetime.now().strftime('%Y%m%d')}.pbn" # if processed_files_info: # last_result = {"filename": processed_files_info[0]["filename"], 1ands": all_results[0][1ands"]} if all_results: # ダウンロード用にPBNコンテンツを値として設定し、表示状態にする export_update = gr.update(interactive=True) else: export_update = gr.update(interactive=False) final_result = all_results[0]["hands"] filenames = [os.path.basename(p) for p in image_paths] dropdown_update = gr.update( choices=filenames, value=filenames[0], interactive=True, open=True ) dataframes = run_dds_analysis(all_results, progress) for result in all_results: if result["filename"] in dataframes.keys(): result["dds"] = dataframes[result["filename"]] return ( *display_selected_result(filenames[0], all_results), all_results, dropdown_update, export_update, ) except Exception as e: raise gr.Error(f"致命的なエラー: {e}") def display_selected_result(selected_filename, all_results): """ドロップダウンで選択されたファイルの結果を表示する""" result = next( (r for r in all_results if r["filename"] == selected_filename), None ) output_hands = {p: "" for p in PLAYER_ORDER} dds_df_update = gr.update(value=None) if result and "hands" in result and result["hands"]: for player, hand in result["hands"].items(): output_hands[player] = ", ".join(hand) if hand else "(なし)" dds_visible = True dds_df_update = gr.update(value=result.get("dds"), visible=True) elif result and "error" in result: # エラーがあった場合、最初のTextboxにエラーメッセージを表示 output_hands["north"] = f"エラー: {result['error']}" return ( output_hands["north"], output_hands["south"], output_hands["west"], output_hands["east"], dds_df_update, ) def validate_deal(hands): if not hands: return False, "分析対象のカードデータがありません" total_cards = [] # 手札が合計52枚あるかチェック for player, hand in hands.items(): if len(hand) != 13: return ( False, f"エラー: {player.capitalize()}の手札が13枚ではありません", ) for card in hand: if card in total_cards: return ( False, f"エラー: 重複したカードが検出されました ({card})", ) total_cards.append(card) return True, "デックは正常です" def format_dds_data(table): headers = [ "Declarer", "NT", "Spades ♠", "Hearts ♥", "Diamonds ♦", "Clubs ♣", ] rows = [ [ "North", table[4][0], table[0][0], table[1][0], table[2][0], table[3][0], ], [ "South", table[4][2], table[0][2], table[1][2], table[2][2], table[3][2], ], [ "East", table[4][1], table[0][1], table[1][1], table[2][1], table[3][1], ], [ "West", table[4][3], table[0][3], table[1][3], table[2][3], table[3][3], ], ] return headers, rows def run_dds_analysis(all_results_state, progress=gr.Progress()): """ダブルダミー分析を実行する""" valid_deals = [] for result in all_results_state: if "hands" in result: is_valid, _ = validate_deal(result["hands"]) if is_valid: valid_deals.append(result) if len(valid_deals) == 0: raise gr.Error( "分析不可", "分析対象となる正常なディールがありません。" ) # self.status_var.set(f"{len(valid_deals)}件のディールを分析中...") try: deals = dds.ddTableDealsPBN() deals.noOfTables = len(valid_deals) for i, result in enumerate(valid_deals): pbn_deal_string = convert2pbn_txt(result["hands"], "N") print(pbn_deal_string) # table_deal_pbn = dds.ddTableDealPBN() # table_deal_pbn.cards = pbn_deal_string.encode("utf-8") deals.deals[i].cards = pbn_deal_string.encode("utf-8") dds.SetMaxThreads(0) table_res = dds.ddTablesRes() per_res = dds.allParResults() # table_res_pointer = pointer(table_res) res = dds.CalcAllTablesPBN( pointer(deals), 0, (c_int * 5)(0, 0, 0, 0, 0), pointer(table_res), pointer(per_res), ) print("dds") if res != dds.RETURN_NO_FAULT: err_char_p = dds.ErrorMessage(res) err_string = ( string_at(err_char_p).decode("utf-8") if err_char_p else "Unknown error" ) raise RuntimeError( f"DDS Solver failed with code: {res} ({err_string})" ) print("dds") filenames = [d["filename"] for d in valid_deals] dataframes = {} for i, filename in enumerate(filenames): headers, rows = format_dds_data(table_res.results[i].resTable) print(rows) dataframes[filename] = pd.DataFrame(rows, columns=headers) return dataframes # 3. 結果を新しいウィンドウで表示 except Exception as e: raise gr.Error(f"DDS分析エラー: 分析中にエラーが発生しました:\n{e}") # self.status_var.set("DDS分析中にエラーが発生しました。") def prepare_export_files(all_results): """エクスポートボタンが押されたときに各形式のファイルを生成し、ダウンロードボタンを返す""" if not all_results: gr.Warning("エクスポート対象のデータがありません。") return ( gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), ) filename_base, _ = os.path.splitext(all_results[0]["filename"]) # --- PBN --- pbn_content = convert2pbn(all_results) with tempfile.NamedTemporaryFile( delete=False, mode="w", suffix=".pbn", encoding="utf-8" ) as f: f.write(pbn_content) pbn_path = f.name # --- XHD --- xhd_content = convert2xhd(all_results, filename_base) with tempfile.NamedTemporaryFile( delete=False, mode="w", suffix=".xhd", encoding="shift_jis", errors="ignore", ) as f: f.write(xhd_content) xhd_path = f.name # --- DUP --- dup_content = convert2dup(all_results, None) with tempfile.NamedTemporaryFile( delete=False, mode="w", suffix=".dup", encoding="utf-8" ) as f: f.write(dup_content) dup_path = f.name return ( gr.update(value=pbn_path, visible=True), gr.update(value=xhd_path, visible=True), gr.update(value=dup_path, visible=True), gr.update(visible=True), # モーダルを表示 ) # --- Gradio UIの定義 --- with gr.Blocks( theme=gr.themes.Soft(), css="footer {visibility: hidden}" ) as demo: # 状態を保持するための非表示コンポーネント all_results_state = gr.State([]) raw_data_state = gr.State([]) current_result_state = gr.State(None) gr.Markdown("# Bridge Card Recognizer") gr.Markdown( "カメラで撮影したブリッジのプレイ中の写真から、各プレイヤーの手札を自動で認識します。" ) with gr.Row(): with gr.Column(scale=2): image_input = gr.File( label="画像ファイルを選択", file_count="multiple", file_types=["image"], type="filepath", ) analyze_button = gr.Button( "分析開始", variant="primary", interactive=False ) export_button = gr.Button( "結果をエクスポート", interactive=False ) # 新しいエクスポートボタン status_label = gr.Label( label="ステータス", value="準備完了。画像を選択して分析を開始してください。", ) if not trocr_pipeline: gr.Warning( "OCRモデルの読み込みに失敗しました。分析機能は利用できません。" ) with gr.Column(scale=3): results_dropdown = gr.Dropdown( label="表示するファイルを選択", interactive=False ) gr.Markdown("### 認識結果") with gr.Row(): north_box = gr.Textbox(label="North", interactive=False) south_box = gr.Textbox(label="South", interactive=False) with gr.Row(): west_box = gr.Textbox(label="West", interactive=False) east_box = gr.Textbox(label="East", interactive=False) with gr.Row(): # dds_button = gr.Button("ダブルダミー分析 (DDS)", visible=False) # debugger_button = gr.Button("カラーデバッガー", visible=False) export_file = gr.File( label="ダウンロード", visible=False, interactive=False ) with gr.Accordion("ダブルダミー分析 結果", open=False): dds_output_df = gr.DataFrame( label="最適プレイ手数", visible=False ) # エクスポート用モーダル with Modal(visible=False) as export_modal: gr.Markdown("### エクスポート形式を選択してください") gr.Markdown( "ボタンをクリックすると、対応する形式のファイルがダウンロードされます。" ) with gr.Row(): pbn_dl_btn = gr.DownloadButton("PBN形式 (.pbn)", variant="primary") xhd_dl_btn = gr.DownloadButton( "XHD形式 (.xhd)", variant="secondary" ) dup_dl_btn = gr.DownloadButton( "DUP形式 (.dup)", variant="secondary" ) # --- イベントリスナー --- demo.load( fn=load_model, inputs=None, outputs=[status_label, analyze_button] ) analyze_button.click( fn=analyze_image_gradio, inputs=[image_input], outputs=[ north_box, south_box, west_box, east_box, dds_output_df, all_results_state, results_dropdown, export_button, ], ) results_dropdown.change( fn=display_selected_result, inputs=[results_dropdown, all_results_state], outputs=[north_box, south_box, west_box, east_box, dds_output_df], ) export_button.click( fn=prepare_export_files, inputs=[all_results_state], outputs=[pbn_dl_btn, xhd_dl_btn, dup_dl_btn, export_modal], ) # dds_button.click( # fn=run_dds_analysis, # inputs=[all_results_state], # outputs=[dds_output_df], # ) if __name__ == "__main__": demo.launch(debug=True)