from pathlib import Path
import argparse
import random
import time
import os
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm import tqdm, trange
import hydra
from omegaconf import DictConfig, OmegaConf
import yaml
from datetime import datetime
import numpy as np
from PIL import Image
import warp as wp
import multiprocess as mp
import torch
import torch.backends.cudnn
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
import logging
|
from pgnd.utils import get_root, mkdir
from modules_planning.planning_env import RobotPlanningEnv
|
root: Path = get_root(__file__)
logging.basicConfig(level=logging.WARNING)
|
|
def main(args):
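    # start child processes with 'spawn' (needed when subprocesses use CUDA)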
    mp.set_start_method('spawn')
|
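    # load the training-time Hydra config into an OmegaConf object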
    with open(root / args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.CLoader)
    cfg = OmegaConf.create(config)
|
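    # override simulation settings for planning-time rollouts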
    cfg.sim.num_steps = 1000
    cfg.sim.gripper_forcing = False
    cfg.sim.uniform = True
|
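    # select the checkpoint iteration and compute device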
    iteration = args.iteration
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
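    # the checkpoint sits next to the config: <config dir>/ckpt/<iteration>.pt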
    ckpt_path = (root / args.config).parent / 'ckpt' / f'{iteration:06d}.pt'
|
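    # seed python, numpy and torch for reproducibility (note: cfg.seed is used here,
    # not the --seed CLI flag)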
    seed = cfg.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
|
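    # each planning run writes to a fresh timestamped directory under log/<name>/plan/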
    datetime_now = datetime.now().strftime('%y%m%d-%H%M%S')
    exp_root: Path = root / 'log' / cfg.train.name / 'plan' / datetime_now
    mkdir(exp_root, overwrite=cfg.overwrite, resume=cfg.resume)
|
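    # construct the planning environment from the config and checkpoint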
    env = RobotPlanningEnv(
        cfg,
        exp_root=exp_root,
        ckpt_path=ckpt_path,
        resolution=(848, 480),
        capture_fps=30,
        record_fps=0,
        text_prompts=args.text_prompts,
        show_annotation=(not args.no_annotation),
        use_robot=True,
        bimanual=args.bimanual,
        gripper_enable=True,
        debug=True,
        construct_target=args.construct_target,
    )
|
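    # start the environment loop and block until it exits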
    env.start()
    env.join()
|
|
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--config', type=str, default='log/cloth/train/hydra.yaml')
    arg_parser.add_argument('--iteration', type=int, default=100000)
    arg_parser.add_argument('--text_prompts', type=str, default='green towel.')
    arg_parser.add_argument('--seed', type=int, default=42)
    arg_parser.add_argument('--no_annotation', action='store_true')
    arg_parser.add_argument('--bimanual', action='store_true')
    arg_parser.add_argument('--construct_target', action='store_true')
    args = arg_parser.parse_args()
|
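    # planning is inference-only, so run the whole pipeline with autograd disabled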
    with torch.no_grad():
        main(args)