File size: 1,679 Bytes
fbf88c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from typing import Optional
import gradio as gr
from hfutils.repository import hf_hub_repo_url
from imgutils.generic import MultiLabelTIMMModel
# Hugging Face model repositories that each get a demo tab with the default
# title (the last path segment of the repository id), in this order.
KNOWN_MODELS = [
    'animetimm/swinv2_base_window8_256.e621v1-full',
]

# Mapping of tag -> repository id. These models are rendered before
# KNOWN_MODELS, and their tab title gets the tag appended in parentheses.
SPECIAL_MODELS = dict()
def render_model_demo(repo_id, label: Optional[str] = None):
    """
    Render one ``gr.Tab`` holding the quick-demo UI of a tagger model.

    The tab contains a Markdown line linking to the model repository, followed
    by the UI produced by ``MultiLabelTIMMModel.make_ui()``.

    :param repo_id: Hugging Face model repository id (``'owner/name'`` form).
    :param label: Tab title. When falsy (``None`` or ``''``), the last
        ``/``-separated segment of ``repo_id`` is used instead.
    """
    tab_title = label if label else repo_id.rsplit('/', 1)[-1]
    with gr.Tab(tab_title):
        tagger = MultiLabelTIMMModel(repo_id=repo_id)

        # Header row: a short description linking to the model page.
        with gr.Row():
            with gr.Column():
                model_page = hf_hub_repo_url(repo_id=repo_id, repo_type='model')
                intro = f'This is the quick demo for tagger model [{repo_id}]({model_page}).'
                gr.Markdown(intro)

        # Body row: the model's own interactive demo widgets.
        with gr.Row():
            tagger.make_ui()
if __name__ == '__main__':
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
gr.HTML(f'<h2 style="text-align: center;">Tagger Playground For E621V1 Full</h2>')
gr.Markdown(f'This is the playground for taggers trained on [animetimm/e621-wdtagger-v1-w640-ws-full](https://huggingface.co/datasets/animetimm/e621-wdtagger-v1-w640-ws-full).'
f'Powered by `dghs-imgutils`\'s quick demo module.')
with gr.Row():
with gr.Tabs():
_exist_models = set()
for t, repo_id in SPECIAL_MODELS.items():
render_model_demo(repo_id, f'{repo_id.split("/")[-1]} ({t})')
_exist_models.add(repo_id)
for repo_id in KNOWN_MODELS:
if repo_id not in _exist_models:
render_model_demo(repo_id)
_exist_models.add(repo_id)
demo.launch()
|