# File size: 2,790 bytes; revision f96995c
import numpy as np
from io import BytesIO
def quat_mult(q1, q2):
    """Return the Hamilton product ``q1 * q2`` as a float32 array.

    Both operands are unpacked as four components in ``(w, x, y, z)`` order
    (scalar part first). The result uses the same ordering.
    """
    a_w, a_x, a_y, a_z = q1
    b_w, b_x, b_y, b_z = q2
    # Terms are kept in the same order as the textbook expansion so the
    # floating-point rounding is identical term by term.
    product = (
        a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z,
        a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y,
        a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x,
        a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w,
    )
    return np.array(product, dtype=np.float32)
def rot_mat_to_quat(rot_mat):
    """Convert a 3x3 rotation matrix to a quaternion ``(w, x, y, z)`` (float32).

    When the trace is positive, the original trace-based formula is used
    unchanged. Otherwise (e.g. rotations near or at 180 degrees, where
    ``1 + trace`` approaches zero and the trace formula divides by ~0), the
    component with the largest diagonal term is recovered first (Shepperd's
    method), which keeps every division well-conditioned.

    The result is sign-normalised so that ``w >= 0``, matching the sign the
    trace formula always produces.

    Args:
        rot_mat: Indexable as ``rot_mat[i, j]`` for ``i, j`` in ``0..2``;
            assumed to be a proper rotation matrix (orthonormal, det +1).

    Returns:
        ``np.ndarray`` of shape (4,), dtype float32, in ``(w, x, y, z)`` order.
    """
    m = rot_mat
    if m[0, 0] + m[1, 1] + m[2, 2] > 0:
        # Well-conditioned case: identical arithmetic to the original formula.
        w = np.sqrt(1 + m[0, 0] + m[1, 1] + m[2, 2]) / 2
        x = (m[2, 1] - m[1, 2]) / (4 * w)
        y = (m[0, 2] - m[2, 0]) / (4 * w)
        z = (m[1, 0] - m[0, 1]) / (4 * w)
    elif m[0, 0] >= m[1, 1] and m[0, 0] >= m[2, 2]:
        s = 2 * np.sqrt(1 + m[0, 0] - m[1, 1] - m[2, 2])  # s == 4 * |x|
        w = (m[2, 1] - m[1, 2]) / s
        x = s / 4
        y = (m[0, 1] + m[1, 0]) / s
        z = (m[0, 2] + m[2, 0]) / s
    elif m[1, 1] >= m[2, 2]:
        s = 2 * np.sqrt(1 + m[1, 1] - m[0, 0] - m[2, 2])  # s == 4 * |y|
        w = (m[0, 2] - m[2, 0]) / s
        x = (m[0, 1] + m[1, 0]) / s
        y = s / 4
        z = (m[1, 2] + m[2, 1]) / s
    else:
        s = 2 * np.sqrt(1 + m[2, 2] - m[0, 0] - m[1, 1])  # s == 4 * |z|
        w = (m[1, 0] - m[0, 1]) / s
        x = (m[0, 2] + m[2, 0]) / s
        y = (m[1, 2] + m[2, 1]) / s
        z = s / 4
    quat = np.array([w, x, y, z], dtype=np.float32)
    # q and -q encode the same rotation; keep the w >= 0 convention.
    if quat[0] < 0:
        quat = -quat
    return quat
def save_to_splat(pts, colors, scales, quats, opacities, output_file, center=True, rotate=True, rot_rev=False):
    """Write splats to ``output_file`` as fixed 32-byte binary records.

    Each record, in order:
      * position: 3 x float32 (``v[0..2]``)
      * scale:    3 x float32 (``s[0..2]``, written unchanged)
      * colour:   4 x uint8 RGBA = ``(c[0], c[1], c[2], o[0]) * 255``,
        clipped to [0, 255] and truncated
      * rotation: 4 x uint8 = unit-normalised quaternion ``* 128 + 128``,
        clipped to [0, 255] and truncated
    This is the layout ``read_splat`` in this module parses.

    NOTE: colours are used as-is. If the caller holds spherical-harmonic DC
    coefficients, convert them first (e.g. ``0.5 + SH_C0 * f_dc`` with
    ``SH_C0 = 0.28209479177387814``).

    Args:
        pts: Per-splat positions; each item indexed ``v[0..2]``.
        colors: Per-splat RGB; each item indexed ``c[0..2]``. Presumably in
            [0, 1], since values are scaled by 255.
        scales: Per-splat 3-vectors indexed ``s[0..2]``.
        quats: Per-splat quaternions indexed ``q[0..3]``. When ``rotate`` is
            set they are combined via ``quat_mult``, which treats them as
            ``(w, x, y, z)``.
        opacities: Per-splat values; only element ``o[0]`` is used, so each
            item must be indexable (e.g. an array of shape (N, 1)).
        output_file: Path of the file to create/overwrite.
        center: If true, subtract the mean position from all positions.
        rotate: If true, rotate positions and orientations 90 degrees about
            the x axis.
        rot_rev: Selects the rotation direction. False applies the inverse
            of the matrix that maps +y to +z; True applies that matrix.

    Raises:
        ValueError: If the per-splat inputs have different lengths. Unequal
            lengths previously caused silent truncation by ``zip``.
    """
    try:
        counts = {len(arr) for arr in (pts, colors, scales, quats, opacities)}
    except TypeError:
        counts = set()  # some input is a bare iterator; cannot check up front
    if len(counts) > 1:
        raise ValueError(
            f"save_to_splat: per-splat inputs have mismatched lengths {sorted(counts)}"
        )

    if center:
        pts = pts - np.mean(pts, axis=0)

    # Loop-invariant: build the x-axis rotation (and its quaternion) once,
    # not once per splat.
    rot_x_90 = None
    rot_x_90_quat = None
    if rotate:
        rot_x_90 = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32)
        if not rot_rev:
            rot_x_90 = np.linalg.inv(rot_x_90)
        rot_x_90_quat = rot_mat_to_quat(rot_x_90)

    # Assemble in memory and write once, so a failure mid-loop leaves no
    # partially written file.
    buffer = BytesIO()
    for v, c, s, q, o in zip(pts, colors, scales, quats, opacities):
        position = np.array([v[0], v[1], v[2]], dtype=np.float32)
        # Distinct name: do not rebind the ``scales`` parameter being iterated.
        scale = np.array([s[0], s[1], s[2]], dtype=np.float32)
        rot = np.array([q[0], q[1], q[2], q[3]], dtype=np.float32)
        color = np.array([c[0], c[1], c[2], o[0]])
        if rotate:
            position = np.dot(rot_x_90, position)
            rot = quat_mult(rot_x_90_quat, rot)
        buffer.write(position.tobytes())
        buffer.write(scale.tobytes())
        buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
        buffer.write(
            ((rot / np.linalg.norm(rot)) * 128 + 128)
            .clip(0, 255)
            .astype(np.uint8)
            .tobytes()
        )

    with open(output_file, "wb") as f:
        f.write(buffer.getvalue())
def read_splat(splat_file):
    """Read a file of 32-byte splat records, as written by ``save_to_splat``.

    Record layout (native byte order, matching how ``save_to_splat`` writes
    it with ``tobytes()``):
      * bytes 0-11:  position, 3 x float32
      * bytes 12-23: scale,    3 x float32
      * bytes 24-27: RGBA,     4 x uint8 (decoded as ``value / 255``)
      * bytes 28-31: rotation, 4 x uint8 (decoded as ``(value - 128) / 128``)

    The whole buffer is decoded in one vectorised pass through a structured
    dtype, instead of a Python loop per record.

    Args:
        splat_file: Path of the file to read.

    Returns:
        Tuple ``(pts, colors, scales, quats, opacities)`` of new, writable
        arrays with N records each:
          pts (N, 3) float32, colors (N, 3) float64, scales (N, 3) float32,
          quats (N, 4) float64, opacities (N, 1) float64.

    Raises:
        ValueError: If the file size is not a multiple of 32 bytes
            (truncated or not a splat file).
    """
    record = np.dtype([
        ("position", np.float32, (3,)),
        ("scale", np.float32, (3,)),
        ("rgba", np.uint8, (4,)),
        ("rot", np.uint8, (4,)),
    ])  # packed, itemsize == 32

    with open(splat_file, "rb") as f:
        data = f.read()
    if len(data) % record.itemsize != 0:
        raise ValueError(
            f"{splat_file!r}: size {len(data)} is not a multiple of "
            f"{record.itemsize}-byte splat records"
        )

    # frombuffer on bytes is read-only; every output below is a fresh copy.
    records = np.frombuffer(data, dtype=record)
    pts = records["position"].astype(np.float32)
    scales = records["scale"].astype(np.float32)
    rgba = records["rgba"] / 255
    quats = (records["rot"] * 1.0 - 128) / 128
    colors = np.ascontiguousarray(rgba[:, :3])
    opacities = np.ascontiguousarray(rgba[:, 3:])
    return pts, colors, scales, quats, opacities
|