# Author: naderasadi — commit 5b2ab1c ("Initial commit")
from typing import List, Tuple, Dict, Union, Any, Optional
import argparse
from PIL import Image
import wandb
class WandBLogger:
    """Thin wrapper around a Weights & Biases run for logging scalars and images.

    Args:
        config: Experiment configuration, either a dict or an attribute-style
            object such as the ``argparse.Namespace`` returned by ``parser()``.
            It must contain ``wandb_project``; ``exp_name`` is optional and is
            used as the run name when present. The whole object is forwarded to
            ``wandb.init`` as the run config.

    Raises:
        ValueError: If ``config`` does not contain ``wandb_project``.
    """

    def __init__(self, config: Union[Dict[str, Any], argparse.Namespace]):
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if "wandb_project" not in config:
            raise ValueError("Missing `wandb_project` in config")
        # The run handle returned by wandb.init; all logging goes through it.
        self.wandb = wandb.init(
            project=self._option(config, "wandb_project"),
            name=self._option(config, "exp_name"),
            config=config,
        )

    @staticmethod
    def _option(config: Union[Dict[str, Any], argparse.Namespace], key: str) -> Any:
        """Read ``key`` from a dict or attribute-style config; ``None`` if absent."""
        # Dicts need item access; Namespace-like objects need attribute access.
        if isinstance(config, dict):
            return config.get(key)
        return getattr(config, key, None)

    def log_scalars(self, logs: Dict[str, Union[int, float]]):
        """Log a mapping of metric name -> scalar value to the run."""
        self.wandb.log(logs)

    def log_images(self, logs: Dict[str, List[Image.Image]]):
        """Log lists of PIL images, one W&B media panel per key.

        Each image is captioned with its key.
        """
        # Use the stored run (like log_scalars) rather than the global wandb.log.
        self.wandb.log(
            {
                key: [wandb.Image(image, caption=key) for image in images]
                for key, images in logs.items()
            }
        )
def parser(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser for the pipeline and parse arguments.

    Args:
        argv: Argument list to parse. Defaults to ``None``, which makes
            argparse read ``sys.argv[1:]`` (the original behaviour); pass an
            explicit list for programmatic use or testing.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    # Named `arg_parser` so it does not shadow this function's own name.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--segmentation_model", type=str, default="mask2former")
    arg_parser.add_argument("--controlnet_name", type=str, default="hed")
    arg_parser.add_argument(
        "--sd_model", type=str, default="runwayml/stable-diffusion-v1-5"
    )
    # NOTE(review): these defaults are machine-specific absolute paths;
    # override them on the command line when running elsewhere.
    arg_parser.add_argument(
        "--images_path",
        type=str,
        default="/home/nader/DesignGenie/assets/images/",
    )
    arg_parser.add_argument(
        "--prompts_path",
        type=str,
        default="/home/nader/DesignGenie/assets/prompts.txt",
    )
    arg_parser.add_argument(
        "--negative_prompt",
        type=str,
        default="monochrome, lowres, bad anatomy, worst quality, low quality",
    )
    arg_parser.add_argument("--num_inference_steps", type=int, default=20)
    arg_parser.add_argument("--n_outputs", type=int, default=4)
    arg_parser.add_argument("--wandb_project", type=str, default="DesignGenie")
    # Integer flag (e.g. 1 / 0) rather than a boolean switch.
    arg_parser.add_argument("--wandb", type=int, default=1)
    arg_parser.add_argument("--exp_name", type=str, default="demo")
    return arg_parser.parse_args(argv)