from pathlib import Path
import argparse
import random
import time
import os
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm import tqdm, trange
import hydra
from omegaconf import DictConfig, OmegaConf
import yaml
from datetime import datetime
import numpy as np
from PIL import Image
import warp as wp
import multiprocess as mp
import torch
import torch.backends.cudnn
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
import logging
from pgnd.utils import get_root, mkdir
from modules_planning.planning_env import RobotPlanningEnv
root: Path = get_root(__file__)
logging.basicConfig(level=logging.WARNING)
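
# Entry point: load the training config and the checkpoint for the requested
# iteration, then launch a real-robot planning environment (RobotPlanningEnv)
# that plans with the loaded model.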
def main(args):
    # Use the 'spawn' start method so CUDA works correctly in subprocesses.
    mp.set_start_method('spawn')

    # Load the training config and override simulation settings for planning.
    with open(root / args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.CLoader)
    cfg = OmegaConf.create(config)
    cfg.sim.num_steps = 1000
    cfg.sim.gripper_forcing = False
    cfg.sim.uniform = True

    # Resolve the checkpoint stored next to the config at the requested iteration.
    iteration = args.iteration
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ckpt_path = (root / args.config).parent / 'ckpt' / f'{iteration:06d}.pt'

    # Seed all RNGs for reproducibility.
    seed = cfg.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Create a timestamped experiment directory under the training log folder.
    datetime_now = datetime.now().strftime('%y%m%d-%H%M%S')
    exp_root: Path = root / 'log' / cfg.train.name / 'plan' / datetime_now
    mkdir(exp_root, overwrite=cfg.overwrite, resume=cfg.resume)
    # Launch the real-robot planning environment with the trained checkpoint.
    env = RobotPlanningEnv(
        cfg,
        exp_root=exp_root,
        ckpt_path=ckpt_path,
        resolution=(848, 480),
        capture_fps=30,
        record_fps=0,
        text_prompts=args.text_prompts,
        show_annotation=(not args.no_annotation),
        use_robot=True,
        bimanual=args.bimanual,
        gripper_enable=True,
        debug=True,
        construct_target=args.construct_target,
    )
    env.start()
    env.join()
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--config', type=str, default='log/cloth/train/hydra.yaml')
    arg_parser.add_argument('--iteration', type=int, default=100000)  # int, used for the zero-padded checkpoint name
    arg_parser.add_argument('--text_prompts', type=str, default='green towel.')
    arg_parser.add_argument('--seed', type=int, default=42)
    arg_parser.add_argument('--no_annotation', action='store_true')
    arg_parser.add_argument('--bimanual', action='store_true')
    arg_parser.add_argument('--construct_target', action='store_true')
    args = arg_parser.parse_args()

    # Planning only queries the learned model, so no gradients are needed.
    with torch.no_grad():
        main(args)