algohunt commited on
Commit
415f10f
·
1 Parent(s): 029882a
Files changed (1) hide show
  1. app.py +3 -6
app.py CHANGED
@@ -42,6 +42,7 @@ from src.data import DemoData
42
  from src.models import LiNo_UniPS
43
  from torch.utils.data import DataLoader
44
  import pytorch_lightning as pl
 
45
 
46
  MAX_SEED = np.iinfo(np.int32).max
47
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
@@ -49,12 +50,8 @@ WEIGHTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights'
49
  os.makedirs(TMP_DIR, exist_ok=True)
50
  os.makedirs(WEIGHTS_DIR, exist_ok=True)
51
 
52
- is_gpu_available = torch.cuda.is_available()
53
 
54
 
55
- if is_gpu_available:
56
- print("✅ NVIDIA GPU。")
57
-
58
 
59
  def cache_weights(weights_dir: str) -> dict:
60
  import os
@@ -88,7 +85,7 @@ def preprocess_mesh(mesh_prompt):
88
  trimesh_mesh = trimesh.load_mesh(mesh_prompt)
89
  trimesh_mesh.export(mesh_prompt+'.glb')
90
  return mesh_prompt+'.glb'
91
-
92
  def generate_3d(image, seed=-1,
93
  ss_guidance_strength=3, ss_sampling_steps=50,
94
  slat_guidance_strength=3, slat_sampling_steps=6,normal_bridge=None):
@@ -136,7 +133,7 @@ def generate_3d(image, seed=-1,
136
  trimesh_mesh.export(mesh_path)
137
 
138
  return mesh_path, mesh_path
139
-
140
  def predict_normal(input_images,input_mask):
141
  test_dataset = DemoData(input_imgs_list=input_images,input_mask=input_mask)
142
  test_loader = DataLoader(test_dataset, batch_size=1)
 
42
  from src.models import LiNo_UniPS
43
  from torch.utils.data import DataLoader
44
  import pytorch_lightning as pl
45
+ import spaces
46
 
47
  MAX_SEED = np.iinfo(np.int32).max
48
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
50
  os.makedirs(TMP_DIR, exist_ok=True)
51
  os.makedirs(WEIGHTS_DIR, exist_ok=True)
52
 
 
53
 
54
 
 
 
 
55
 
56
  def cache_weights(weights_dir: str) -> dict:
57
  import os
 
85
  trimesh_mesh = trimesh.load_mesh(mesh_prompt)
86
  trimesh_mesh.export(mesh_prompt+'.glb')
87
  return mesh_prompt+'.glb'
88
+ @spaces.GPU
89
  def generate_3d(image, seed=-1,
90
  ss_guidance_strength=3, ss_sampling_steps=50,
91
  slat_guidance_strength=3, slat_sampling_steps=6,normal_bridge=None):
 
133
  trimesh_mesh.export(mesh_path)
134
 
135
  return mesh_path, mesh_path
136
+ @spaces.GPU
137
  def predict_normal(input_images,input_mask):
138
  test_dataset = DemoData(input_imgs_list=input_images,input_mask=input_mask)
139
  test_loader = DataLoader(test_dataset, batch_size=1)