Spaces:
Running
Running
import math | |
import os | |
import tempfile | |
from ctypes import c_int, c_uint, pointer, string_at | |
from datetime import datetime | |
import cv2 | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from gradio_modal import Modal | |
from PIL import Image | |
# from transformers import pipeline | |
import dds | |
from identify_cards import ( | |
SUIT_TEMPLATE_PATH, | |
determine_and_correct_orientation, | |
find_rank_candidates, | |
get_suit_from_image_rules, | |
load_suit_templates, | |
save_img_with_rect, | |
) | |
from utils import ( | |
arrange_hand, | |
convert2dup, | |
convert2pbn, | |
convert2pbn_board, | |
convert2pbn_txt, | |
convert2xhd, | |
is_text_valid, | |
) | |
# --- グローバル変数・設定 --- | |
trocr_pipeline = None | |
SUITS_BY_COLOR = {"black": "S", "green": "C", "red": "H", "orange": "D"} | |
VALID_RANKS = ["A", "K", "Q", "J", "T", "9", "8", "7", "6", "5", "4", "3", "2"] | |
PLAYER_ORDER = ["north", "east", "south", "west"] | |
SUIT_ORDER = {"S": 0, "H": 1, "D": 2, "C": 3} | |
RANK_ORDER = { | |
"A": 14, | |
"K": 13, | |
"Q": 12, | |
"J": 11, | |
"T": 10, | |
"9": 9, | |
"8": 8, | |
"7": 7, | |
"6": 6, | |
"5": 5, | |
"4": 4, | |
"3": 3, | |
"2": 2, | |
} | |
DEFAULT_THRESHOLDS = { | |
"L_black": 65.0, | |
"a_green": 126.0, | |
"a_red": 134.0, | |
"ba_black": -4.5, | |
"ab_black": 250.0, | |
"a_b_red": 9.0, | |
} | |
def load_model(): | |
""" | |
TrOCRモデルをバックグラウンドで読み込む関数。 | |
UIのロード完了後に demo.load() イベントで呼び出される。 | |
""" | |
global trocr_pipeline | |
try: | |
if trocr_pipeline is None: | |
print("バックグラウンドでTrOCRモデルを読み込んでいます...") | |
trocr_pipeline = pipeline( | |
"image-to-text", model="microsoft/trocr-small-printed" | |
) | |
print("TrOCRの準備が完了しました。") | |
# UIコンポーネントを更新するための値を返す | |
return gr.update( | |
value="準備完了。画像を選択して分析を開始してください。" | |
), gr.update(interactive=True) | |
except Exception as e: | |
error_message = f"AIモデルの読み込みエラー: {e}" | |
print(error_message) | |
gr.Warning( | |
f"AIモデルの読み込みに失敗しました。分析機能は利用できません。詳細はログを確認してください。" | |
) | |
# エラーメッセージを表示し、分析ボタンは無効のままにする | |
return gr.update(value=error_message), gr.update(interactive=False) | |
def get_player_regions(img, box, margin): | |
bx, by, bw, bh = box | |
h, w, _ = img.shape | |
player_regions = { | |
"north": img[0:by, :], | |
"south": img[by + bh : h, :], | |
"west": img[by - margin : by + bh + margin, 0:bx], | |
"east": img[by - margin : by + bh + margin, bx + bw : w], | |
} | |
for player, region in player_regions.items(): | |
if region is not None and region.size > 0: | |
if player == "north": | |
player_regions[player] = cv2.rotate(region, cv2.ROTATE_180) | |
elif player == "east": | |
player_regions[player] = cv2.rotate( | |
region, cv2.ROTATE_90_CLOCKWISE | |
) | |
elif player == "west": | |
player_regions[player] = cv2.rotate( | |
region, cv2.ROTATE_90_COUNTERCLOCKWISE | |
) | |
return player_regions | |
def arrange_data(raw_rank_data): | |
# 生データから最終的な手札を作成・表示 | |
all_result = [] | |
temp_hands = {} # ファイルごとの手札を一時保存 | |
for rank_data in raw_rank_data: | |
filename = rank_data["filename"] | |
if filename not in temp_hands: | |
temp_hands[filename] = {p: [] for p in PLAYER_ORDER} | |
color_name = rank_data["color"] | |
suit = SUITS_BY_COLOR[color_name] | |
card_name = f"{suit}{rank_data['name']}" | |
temp_hands[filename][rank_data["player"]].append(card_name) | |
# 整形してall_resultsに格納 | |
for filename, hands in temp_hands.items(): | |
all_result.append( | |
{ | |
"filename": filename, | |
"hands": { | |
player: arrange_hand(cards) | |
for player, cards in hands.items() | |
}, | |
} | |
) | |
return all_result | |
def analyze_image_gradio(image_paths, progress=gr.Progress()): | |
global trocr_pipeline | |
# モデルが読み込まれているか確認 | |
if trocr_pipeline is None: | |
gr.Warning( | |
"AIモデルがまだ読み込まれていません。しばらく待ってから再度お試しください。" | |
) | |
# 空の更新を返すことで、UIの状態を変えずに処理を終了 | |
return (gr.update(),) * 11 | |
all_results = [] | |
num_total_files = len(image_paths) | |
progress(0, desc="テンプレート画像読み込み中...") | |
suit_templates = load_suit_templates(SUIT_TEMPLATE_PATH) | |
if not suit_templates: | |
raise gr.Error( | |
f"エラー: {SUIT_TEMPLATE_PATH} フォルダにスートのテンプレート画像が見つかりません。" | |
) | |
try: | |
all_candidates_global = [] | |
processed_files_info = [] | |
# image_objects = {} | |
for i, image_path in enumerate(image_paths): | |
progress( | |
(i + 1) / num_total_files * 0.15, | |
desc="ステージ1/3: 文字候補を検出中...", | |
) | |
filename = os.path.basename(image_path) | |
progress( | |
(i + 1) / num_total_files * 0.3, | |
f"分析中 ({i+1}/{num_total_files}): {filename}", | |
) | |
try: | |
# ファイルをバイナリモードで安全に読み込む | |
with open(image_path, "rb") as f: | |
# バイトデータをNumPy配列に変換 | |
file_bytes = np.asarray( | |
bytearray(f.read()), dtype=np.uint8 | |
) | |
# NumPy配列(メモリ上のデータ)から画像をデコード | |
image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) | |
if image is None: | |
raise gr.Error( | |
"OpenCVが画像をデコードできませんでした。ファイルが破損しているか、非対応の形式の可能性があります。" | |
) | |
# image_objects[filename] = image | |
except Exception as e: | |
# ファイル読み込み自体のエラーをキャッチ | |
all_results.append( | |
{"filename": filename, "error": f"画像読み込みエラー: {e}"} | |
) | |
# image_objects[filename] = None | |
continue | |
# box = find_center_box(image) | |
print("detect board") | |
rotated_image, box, scale = determine_and_correct_orientation( | |
image, lambda msg: print(msg) | |
) | |
if box is None: | |
all_results.append( | |
{"filename": filename, "error": "中央ボードの検出に失敗"} | |
) | |
continue | |
print(box) | |
save_img_with_rect("debug_rotated.jpg", rotated_image, [box]) | |
MARGIN = 200 | |
player_regions = get_player_regions(rotated_image, box, MARGIN) | |
for player, region in player_regions.items(): | |
candidates = find_rank_candidates( | |
region, suit_templates, player, scale | |
) | |
for cand in candidates: | |
cand["filename"] = filename | |
cand["player"] = player | |
all_candidates_global.append(cand) | |
processed_files_info.append({"filename": filename, "error": None}) | |
progress( | |
0.4, desc="ステージ2/3: 文字認識を実行中... (時間がかかります)" | |
) | |
if not all_candidates_global or not trocr_pipeline: | |
progress(1, desc="認識する文字候補がありませんでした。") | |
print("認識する文字候補がありませんでした。") | |
return all_results # エラーがあった画像の結果だけを返す | |
try: | |
candidates_pil_images = [ | |
Image.fromarray(cv2.cvtColor(c["img"], cv2.COLOR_BGR2RGB)) | |
for c in all_candidates_global | |
] | |
ocr_results = trocr_pipeline(candidates_pil_images) | |
except Exception as e: | |
gr.Warning(f"OCR処理中にエラーが発生しました: {e}") | |
# --- ステージ3: 結果の仕分けと最終的なカードの特定 --- | |
progress(0.9, desc="ステージ3/3: 認識結果を仕分け中...") | |
print([result[0]["generated_text"] for result in ocr_results]) | |
raw_data = [] | |
# blacks = [] | |
# reds = [] | |
for i, result in enumerate(ocr_results): | |
text = result[0]["generated_text"].upper().strip() | |
print(text, is_text_valid(text)) | |
text = is_text_valid(text) | |
if text is not None: | |
candidate_info = all_candidates_global[i] | |
print( | |
f"--- 診断中: ランク '{text}' of {candidate_info['player']} at {candidate_info['pos']} with thick:{candidate_info['thickness']} ---" | |
) | |
color_name, avg_lab = get_suit_from_image_rules( | |
candidate_info["no_pad"], DEFAULT_THRESHOLDS | |
) | |
print(color_name) | |
if color_name == "mark": | |
continue | |
candidate_info["avg_lab"] = avg_lab | |
candidate_info["color"] = color_name | |
candidate_info["name"] = text | |
raw_data.append(candidate_info) | |
# print("\r\n".join(blacks)) | |
# print("\r\n".join(reds)) | |
all_results = arrange_data(raw_data) | |
pbn_content = convert2pbn(all_results) | |
pbn_filename = f"analysis_{datetime.now().strftime('%Y%m%d')}.pbn" | |
# if processed_files_info: | |
# last_result = {"filename": processed_files_info[0]["filename"], 1ands": all_results[0][1ands"]} | |
if all_results: | |
# ダウンロード用にPBNコンテンツを値として設定し、表示状態にする | |
export_update = gr.update(interactive=True) | |
else: | |
export_update = gr.update(interactive=False) | |
final_result = all_results[0]["hands"] | |
filenames = [os.path.basename(p) for p in image_paths] | |
dropdown_update = gr.update( | |
choices=filenames, value=filenames[0], interactive=True, open=True | |
) | |
dataframes = run_dds_analysis(all_results, progress) | |
for result in all_results: | |
if result["filename"] in dataframes.keys(): | |
result["dds"] = dataframes[result["filename"]] | |
return ( | |
*display_selected_result(filenames[0], all_results), | |
all_results, | |
dropdown_update, | |
export_update, | |
) | |
except Exception as e: | |
raise gr.Error(f"致命的なエラー: {e}") | |
def display_selected_result(selected_filename, all_results): | |
"""ドロップダウンで選択されたファイルの結果を表示する""" | |
result = next( | |
(r for r in all_results if r["filename"] == selected_filename), None | |
) | |
output_hands = {p: "" for p in PLAYER_ORDER} | |
dds_df_update = gr.update(value=None) | |
if result and "hands" in result and result["hands"]: | |
for player, hand in result["hands"].items(): | |
output_hands[player] = ", ".join(hand) if hand else "(なし)" | |
dds_visible = True | |
dds_df_update = gr.update(value=result.get("dds"), visible=True) | |
elif result and "error" in result: | |
# エラーがあった場合、最初のTextboxにエラーメッセージを表示 | |
output_hands["north"] = f"エラー: {result['error']}" | |
return ( | |
output_hands["north"], | |
output_hands["south"], | |
output_hands["west"], | |
output_hands["east"], | |
dds_df_update, | |
) | |
def validate_deal(hands): | |
if not hands: | |
return False, "分析対象のカードデータがありません" | |
total_cards = [] | |
# 手札が合計52枚あるかチェック | |
for player, hand in hands.items(): | |
if len(hand) != 13: | |
return ( | |
False, | |
f"エラー: {player.capitalize()}の手札が13枚ではありません", | |
) | |
for card in hand: | |
if card in total_cards: | |
return ( | |
False, | |
f"エラー: 重複したカードが検出されました ({card})", | |
) | |
total_cards.append(card) | |
return True, "デックは正常です" | |
def format_dds_data(table): | |
headers = [ | |
"Declarer", | |
"NT", | |
"Spades ♠", | |
"Hearts ♥", | |
"Diamonds ♦", | |
"Clubs ♣", | |
] | |
rows = [ | |
[ | |
"North", | |
table[4][0], | |
table[0][0], | |
table[1][0], | |
table[2][0], | |
table[3][0], | |
], | |
[ | |
"South", | |
table[4][2], | |
table[0][2], | |
table[1][2], | |
table[2][2], | |
table[3][2], | |
], | |
[ | |
"East", | |
table[4][1], | |
table[0][1], | |
table[1][1], | |
table[2][1], | |
table[3][1], | |
], | |
[ | |
"West", | |
table[4][3], | |
table[0][3], | |
table[1][3], | |
table[2][3], | |
table[3][3], | |
], | |
] | |
return headers, rows | |
def run_dds_analysis(all_results_state, progress=gr.Progress()): | |
"""ダブルダミー分析を実行する""" | |
valid_deals = [] | |
for result in all_results_state: | |
if "hands" in result: | |
is_valid, _ = validate_deal(result["hands"]) | |
if is_valid: | |
valid_deals.append(result) | |
if len(valid_deals) == 0: | |
raise gr.Error( | |
"分析不可", "分析対象となる正常なディールがありません。" | |
) | |
# self.status_var.set(f"{len(valid_deals)}件のディールを分析中...") | |
try: | |
deals = dds.ddTableDealsPBN() | |
deals.noOfTables = len(valid_deals) | |
for i, result in enumerate(valid_deals): | |
pbn_deal_string = convert2pbn_txt(result["hands"], "N") | |
print(pbn_deal_string) | |
# table_deal_pbn = dds.ddTableDealPBN() | |
# table_deal_pbn.cards = pbn_deal_string.encode("utf-8") | |
deals.deals[i].cards = pbn_deal_string.encode("utf-8") | |
dds.SetMaxThreads(0) | |
table_res = dds.ddTablesRes() | |
per_res = dds.allParResults() | |
# table_res_pointer = pointer(table_res) | |
res = dds.CalcAllTablesPBN( | |
pointer(deals), | |
0, | |
(c_int * 5)(0, 0, 0, 0, 0), | |
pointer(table_res), | |
pointer(per_res), | |
) | |
print("dds") | |
if res != dds.RETURN_NO_FAULT: | |
err_char_p = dds.ErrorMessage(res) | |
err_string = ( | |
string_at(err_char_p).decode("utf-8") | |
if err_char_p | |
else "Unknown error" | |
) | |
raise RuntimeError( | |
f"DDS Solver failed with code: {res} ({err_string})" | |
) | |
print("dds") | |
filenames = [d["filename"] for d in valid_deals] | |
dataframes = {} | |
for i, filename in enumerate(filenames): | |
headers, rows = format_dds_data(table_res.results[i].resTable) | |
print(rows) | |
dataframes[filename] = pd.DataFrame(rows, columns=headers) | |
return dataframes | |
# 3. 結果を新しいウィンドウで表示 | |
except Exception as e: | |
raise gr.Error(f"DDS分析エラー: 分析中にエラーが発生しました:\n{e}") | |
# self.status_var.set("DDS分析中にエラーが発生しました。") | |
def prepare_export_files(all_results): | |
"""エクスポートボタンが押されたときに各形式のファイルを生成し、ダウンロードボタンを返す""" | |
if not all_results: | |
gr.Warning("エクスポート対象のデータがありません。") | |
return ( | |
gr.update(visible=False), | |
gr.update(visible=False), | |
gr.update(visible=False), | |
gr.update(visible=True), | |
) | |
filename_base, _ = os.path.splitext(all_results[0]["filename"]) | |
# --- PBN --- | |
pbn_content = convert2pbn(all_results) | |
with tempfile.NamedTemporaryFile( | |
delete=False, mode="w", suffix=".pbn", encoding="utf-8" | |
) as f: | |
f.write(pbn_content) | |
pbn_path = f.name | |
# --- XHD --- | |
xhd_content = convert2xhd(all_results, filename_base) | |
with tempfile.NamedTemporaryFile( | |
delete=False, | |
mode="w", | |
suffix=".xhd", | |
encoding="shift_jis", | |
errors="ignore", | |
) as f: | |
f.write(xhd_content) | |
xhd_path = f.name | |
# --- DUP --- | |
dup_content = convert2dup(all_results, None) | |
with tempfile.NamedTemporaryFile( | |
delete=False, mode="w", suffix=".dup", encoding="utf-8" | |
) as f: | |
f.write(dup_content) | |
dup_path = f.name | |
return ( | |
gr.update(value=pbn_path, visible=True), | |
gr.update(value=xhd_path, visible=True), | |
gr.update(value=dup_path, visible=True), | |
gr.update(visible=True), # モーダルを表示 | |
) | |
# --- Gradio UIの定義 --- | |
with gr.Blocks( | |
theme=gr.themes.Soft(), css="footer {visibility: hidden}" | |
) as demo: | |
# 状態を保持するための非表示コンポーネント | |
all_results_state = gr.State([]) | |
raw_data_state = gr.State([]) | |
current_result_state = gr.State(None) | |
gr.Markdown("# Bridge Card Recognizer") | |
gr.Markdown( | |
"カメラで撮影したブリッジのプレイ中の写真から、各プレイヤーの手札を自動で認識します。" | |
) | |
with gr.Row(): | |
with gr.Column(scale=2): | |
image_input = gr.File( | |
label="画像ファイルを選択", | |
file_count="multiple", | |
file_types=["image"], | |
type="filepath", | |
) | |
analyze_button = gr.Button( | |
"分析開始", variant="primary", interactive=False | |
) | |
export_button = gr.Button( | |
"結果をエクスポート", interactive=False | |
) # 新しいエクスポートボタン | |
status_label = gr.Label( | |
label="ステータス", | |
value="準備完了。画像を選択して分析を開始してください。", | |
) | |
if not trocr_pipeline: | |
gr.Warning( | |
"OCRモデルの読み込みに失敗しました。分析機能は利用できません。" | |
) | |
with gr.Column(scale=3): | |
results_dropdown = gr.Dropdown( | |
label="表示するファイルを選択", interactive=False | |
) | |
gr.Markdown("### 認識結果") | |
with gr.Row(): | |
north_box = gr.Textbox(label="North", interactive=False) | |
south_box = gr.Textbox(label="South", interactive=False) | |
with gr.Row(): | |
west_box = gr.Textbox(label="West", interactive=False) | |
east_box = gr.Textbox(label="East", interactive=False) | |
with gr.Row(): | |
# dds_button = gr.Button("ダブルダミー分析 (DDS)", visible=False) | |
# debugger_button = gr.Button("カラーデバッガー", visible=False) | |
export_file = gr.File( | |
label="ダウンロード", visible=False, interactive=False | |
) | |
with gr.Accordion("ダブルダミー分析 結果", open=False): | |
dds_output_df = gr.DataFrame( | |
label="最適プレイ手数", visible=False | |
) | |
# エクスポート用モーダル | |
with Modal(visible=False) as export_modal: | |
gr.Markdown("### エクスポート形式を選択してください") | |
gr.Markdown( | |
"ボタンをクリックすると、対応する形式のファイルがダウンロードされます。" | |
) | |
with gr.Row(): | |
pbn_dl_btn = gr.DownloadButton("PBN形式 (.pbn)", variant="primary") | |
xhd_dl_btn = gr.DownloadButton( | |
"XHD形式 (.xhd)", variant="secondary" | |
) | |
dup_dl_btn = gr.DownloadButton( | |
"DUP形式 (.dup)", variant="secondary" | |
) | |
# --- イベントリスナー --- | |
demo.load( | |
fn=load_model, inputs=None, outputs=[status_label, analyze_button] | |
) | |
analyze_button.click( | |
fn=analyze_image_gradio, | |
inputs=[image_input], | |
outputs=[ | |
north_box, | |
south_box, | |
west_box, | |
east_box, | |
dds_output_df, | |
all_results_state, | |
results_dropdown, | |
export_button, | |
], | |
) | |
results_dropdown.change( | |
fn=display_selected_result, | |
inputs=[results_dropdown, all_results_state], | |
outputs=[north_box, south_box, west_box, east_box, dds_output_df], | |
) | |
export_button.click( | |
fn=prepare_export_files, | |
inputs=[all_results_state], | |
outputs=[pbn_dl_btn, xhd_dl_btn, dup_dl_btn, export_modal], | |
) | |
# dds_button.click( | |
# fn=run_dds_analysis, | |
# inputs=[all_results_state], | |
# outputs=[dds_output_df], | |
# ) | |
if __name__ == "__main__": | |
demo.launch(debug=True) | |