zxymimi23451's picture
Upload 258 files
78360e7 verified
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib.pyplot as plt
from flask import Flask, request, jsonify, render_template
import os
import io
import numpy as np
import torch
import yaml
import matplotlib
import argparse
matplotlib.use('Agg')  # non-interactive backend: figures here are only saved to files
app = Flask(__name__, static_folder='static', template_folder='templates')
# ——— Arguments ———————————————————————————————————
parser = argparse.ArgumentParser(
    description='Trajectory editor: upload images and store point tracks drawn on them.')
parser.add_argument('--save_dir', type=str, default='videos_example',
                    help='Output directory for tracks/, images/ and test.yaml; '
                         'uploads and overlays are also written under static/<save_dir>/.')
args = parser.parse_args()
# ——— Configuration ———————————————————————————————
BASE_DIR = args.save_dir
STATIC_BASE = os.path.join('static', BASE_DIR)
IMAGES_DIR = os.path.join(STATIC_BASE, 'images')          # uploads, served via static URLs
OVERLAY_DIR = os.path.join(STATIC_BASE, 'images_tracks')  # rendered track-overlay PNGs
TRACKS_DIR = os.path.join(BASE_DIR, 'tracks')             # one .pth track file per image
YAML_PATH = os.path.join(BASE_DIR, 'test.yaml')           # index of image/track entries
IMAGES_DIR_OUT = os.path.join(BASE_DIR, 'images')         # second copy of each upload
FIXED_LENGTH = 121  # every track is padded/truncated to this many points
COLOR_CYCLE = ['r', 'g', 'b', 'c', 'm', 'y', 'k']  # matplotlib color codes for overlays
QUANT_MULTI = 8  # scale factor applied to track arrays before they are saved
for d in (IMAGES_DIR, TRACKS_DIR, OVERLAY_DIR, IMAGES_DIR_OUT):
    os.makedirs(d, exist_ok=True)
# ——— Helpers ———————————————————————————————————————
def array_to_npz_bytes(arr, path, compressed=True, quant_multi=QUANT_MULTI):
    """Scale ``arr`` by ``quant_multi`` and save it to ``path`` as ``.npz`` bytes.

    The array is multiplied by ``quant_multi`` and cast to float32. It is then
    serialised into an in-memory ``.npz`` archive under the key ``'array'``,
    and the raw archive bytes are written to ``path`` with ``torch.save``.

    Args:
        arr: Array or array-like to store.
        path: Destination file (this module uses ``.pth`` paths).
        compressed: Use ``np.savez_compressed`` when true, else ``np.savez``.
        quant_multi: Scale factor applied before saving; readers must divide
            it back out to recover the original values.

    Returns:
        The ``.npz`` archive bytes that were written to ``path``.
    """
    # Stored as float32, not a quantised integer type. np.asarray makes plain
    # lists scale numerically instead of being repeated by ``int * list``.
    scaled = (quant_multi * np.asarray(arr)).astype(np.float32)
    buffer = io.BytesIO()
    save = np.savez_compressed if compressed else np.savez
    save(buffer, array=scaled)
    payload = buffer.getvalue()
    torch.save(payload, path)
    return payload
def load_existing_tracks(path, quant_multi=None):
    """Load a track array saved by ``array_to_npz_bytes`` and undo its scaling.

    ``array_to_npz_bytes`` stores ``quant_multi * arr``. Dividing the factor
    back out returns values in the same units as freshly built tracks, so
    merging and re-saving them does not compound the scale.

    Args:
        path: File holding raw ``.npz`` bytes written with ``torch.save``.
        quant_multi: Factor the array was multiplied by when saved. ``None``
            (default) uses the module-level ``QUANT_MULTI``; pass ``1`` to get
            the raw stored values.

    Returns:
        The stored ``'array'`` divided by ``quant_multi``, as float32.
    """
    if quant_multi is None:
        quant_multi = QUANT_MULTI
    raw = torch.load(path)
    with np.load(io.BytesIO(raw)) as npz:
        stored = npz['array']
    return (stored / quant_multi).astype(np.float32)
# ——— Routes ————————————————————————————————————————
@app.route('/')
def index():
    """Render the editor's main page (``index.html`` from the templates folder)."""
    page = 'index.html'
    return render_template(page)
def _next_image_id(directory):
    """Return one more than the largest numeric file stem in ``directory`` (1 if none)."""
    stems = (os.path.splitext(name)[0] for name in os.listdir(directory))
    return max((int(stem) for stem in stems if stem.isdigit()), default=0) + 1


def _upload_extension(filename, detected_format):
    """Return a safe extension for an upload: its own suffix if plain ASCII alphanumeric,
    otherwise the format PIL detected (lower-cased), otherwise ``'png'``."""
    _, dot, suffix = (filename or '').rpartition('.')
    if dot and suffix.isascii() and suffix.isalnum():
        return suffix
    return (detected_format or 'png').lower()


@app.route('/upload_image', methods=['POST'])
def upload_image():
    """Store an uploaded image under the next free numeric id.

    The image is saved as ``<id:02d>.<ext>`` in both ``IMAGES_DIR`` (served
    statically) and ``IMAGES_DIR_OUT``.

    Returns:
        JSON with the static ``image_url``, the assigned ``image_id``, the
        ``ext`` used, and the original ``orig_width``/``orig_height``.
    """
    from PIL import Image
    upload = request.files['image']
    img = Image.open(upload.stream)
    orig_w, orig_h = img.size
    # max(existing ids) + 1 rather than a file count: after a deletion a
    # count would collide with, and overwrite, an existing image.
    idx = _next_image_id(IMAGES_DIR)
    # The client-supplied filename is untrusted. Only a plain alphanumeric
    # suffix is kept, so no path separators end up in the saved file name.
    ext = _upload_extension(upload.filename, img.format)
    fname = f"{idx:02d}.{ext}"
    img.save(os.path.join(IMAGES_DIR, fname))
    img.save(os.path.join(IMAGES_DIR_OUT, fname))
    return jsonify({
        'image_url': f"{STATIC_BASE}/images/{fname}",
        'image_id': idx,
        'ext': ext,
        'orig_width': orig_w,
        'orig_height': orig_h
    })
def _pad_track(points):
    """Convert a list of ``{'x', 'y'}`` dicts to a ``(FIXED_LENGTH, 1, 3)`` float32 array.

    Each point becomes ``(x, y, 1)``. Rows beyond ``FIXED_LENGTH`` are
    dropped, and missing rows are zero-filled. The overlay treats a zero third
    channel as "not visible".
    """
    pts = np.array([[p['x'], p['y'], 1] for p in points],
                   dtype=np.float32).reshape(-1, 3)
    pts = pts[:FIXED_LENGTH]
    if pts.shape[0] < FIXED_LENGTH:
        pad = np.zeros((FIXED_LENGTH - pts.shape[0], 3), dtype=np.float32)
        pts = np.vstack((pts, pad))
    return pts.reshape(FIXED_LENGTH, 1, 3)


def _widen_channels(arr, width):
    """Zero-pad the last axis of ``arr`` up to ``width`` channels."""
    missing = width - arr.shape[-1]
    if missing <= 0:
        return arr
    pad = np.zeros(arr.shape[:-1] + (missing,), dtype=arr.dtype)
    return np.concatenate((arr, pad), axis=-1)


def _merge_with_existing(new_tracks, track_path):
    """Return the tracks saved at ``track_path`` (if any) followed by ``new_tracks``.

    New tracks have 3 channels, and older files may carry a 4th. Both arrays
    are zero-padded to the wider channel count because ``np.concatenate``
    needs matching trailing shapes.
    """
    if not os.path.exists(track_path):
        return new_tracks
    old = load_existing_tracks(track_path)
    width = max(old.shape[-1], new_tracks.shape[-1])
    return np.concatenate(
        [_widen_channels(old, width), _widen_channels(new_tracks, width)], axis=0)


def _render_overlay(image_id, ext, tracks):
    """Plot each track's visible points over the image, save a PNG to OVERLAY_DIR,
    and return the PNG's file name."""
    img = plt.imread(os.path.join(IMAGES_DIR, f"{image_id:02d}.{ext}"))
    overlay_file = f"{image_id:02d}.png"
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        ax.imshow(img)
        for i, track in enumerate(tracks):
            coords = track[:, 0, :]  # (FIXED_LENGTH, channels)
            visible = coords[:, 2] > 0.5
            ax.plot(coords[visible, 0], coords[visible, 1], marker='o',
                    color=COLOR_CYCLE[i % len(COLOR_CYCLE)])
        ax.axis('off')
        fig.savefig(os.path.join(OVERLAY_DIR, overlay_file),
                    bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure; pyplot otherwise keeps it alive in this
        # long-running server process if savefig raises.
        plt.close(fig)
    return overlay_file


def _update_yaml(image_id, ext):
    """Insert or refresh the YAML_PATH entry pointing at this image and its track file."""
    image_name = f"{image_id:02d}.{ext}"
    entry = {
        "image": f"tools/trajectory_editor/{BASE_DIR}/images/{image_name}",
        "text": None,
        "track": f"tools/trajectory_editor/{BASE_DIR}/tracks/{image_id:02d}.pth"
    }
    docs = []
    if os.path.exists(YAML_PATH):
        with open(YAML_PATH) as yf:
            docs = yaml.safe_load(yf) or []
    for e in docs:
        # Exact file-name match: a suffix test would let "01.png" match "101.png".
        if os.path.basename(e.get("image") or "") == image_name:
            e.update(entry)
            break
    else:
        docs.append(entry)
    with open(YAML_PATH, 'w') as yf:
        yaml.dump(docs, yf, default_flow_style=False)


@app.route('/store_tracks', methods=['POST'])
def store_tracks():
    """Append freehand and circle/static trajectories to an image's saved tracks.

    Expects JSON with ``image_id``, ``ext`` and optional ``tracks`` /
    ``circle_trajectories`` lists of ``{x, y}`` point lists. The route then:
    merges the new tracks into ``TRACKS_DIR/<id>.pth``, re-renders the overlay
    PNG, and updates ``YAML_PATH``. Empty trajectories are skipped.

    Returns:
        JSON ``{'status': 'ok', 'overlay_url': ...}``.
    """
    data = request.get_json()
    image_id = int(data['image_id'])  # tolerate "3" as well as 3
    ext = data['ext']
    free_tracks = data.get('tracks') or []
    circ_trajs = data.get('circle_trajectories') or []
    # Debug lengths
    for i, tr in enumerate(free_tracks, 1):
        print(f"Freehand Track {i}: {len(tr)} points")
    for i, tr in enumerate(circ_trajs, 1):
        print(f"Circle/Static Traj {i}: {len(tr)} points")

    overlay_file = f"{image_id:02d}.png"
    arrs = [_pad_track(tr) for tr in (*free_tracks, *circ_trajs) if tr]
    if not arrs:
        return jsonify({
            'status': 'ok',
            'overlay_url': f"{STATIC_BASE}/images_tracks/{overlay_file}"
        })

    new_tracks = np.stack(arrs, axis=0)  # (T_new, FIXED_LENGTH, 1, 3)
    track_path = os.path.join(TRACKS_DIR, f"{image_id:02d}.pth")
    all_tracks = _merge_with_existing(new_tracks, track_path)
    array_to_npz_bytes(all_tracks, track_path, compressed=True)
    overlay_file = _render_overlay(image_id, ext, all_tracks)
    _update_yaml(image_id, ext)
    return jsonify({
        'status': 'ok',
        'overlay_url': f"{STATIC_BASE}/images_tracks/{overlay_file}"
    })
if __name__ == "__main__":
    # Flask development server in debug mode; intended for local use only.
    app.run(debug=True)