BladeSzaSza Claude commited on
Commit
5ed6938
Β·
1 Parent(s): 4cfc1e9

integrate Hunyuan3D API via gradio_client

Browse files

- Replace local model loading with API calls to tencent/Hunyuan3D-2.1
- Use generation_all endpoint for both shape and texture generation
- Add proper image handling and temporary file management
- Maintain fallback 3D generation for reliability
- Remove deprecated transformers-based approach

πŸ€– Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. models/model_3d_generator.py +111 -275
models/model_3d_generator.py CHANGED
@@ -7,6 +7,7 @@ from typing import Union, Optional, Dict, Any
7
  from pathlib import Path
8
  import os
9
  import logging
 
10
 
11
  # Set up detailed logging for 3D generation
12
  logging.basicConfig(level=logging.INFO)
@@ -62,151 +63,38 @@ class Hunyuan3DGenerator:
62
  return False
63
 
64
  def load_model(self):
65
- """Lazy load the 3D generation model"""
66
  if self.model is None:
67
- logger.info("πŸš€ Starting 3D model loading process...")
68
 
69
  try:
70
- # Try to import Hunyuan3D components
71
- logger.info("πŸ“¦ Attempting to import Hunyuan3D components...")
72
  try:
73
- from hy3dshape.pipelines import Hunyuan3DDiTFlowMatchingPipeline
74
- from hy3dshape.rembg import BackgroundRemover
75
- logger.info("βœ… Hunyuan3D components imported successfully")
76
 
77
- # Load the pipeline
78
- model_id = self.lite_model_id if self.use_lite else self.model_id
79
- logger.info(f"πŸ“¦ Loading Hunyuan3D pipeline: {model_id}")
 
 
80
 
81
- self.model = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_id)
82
- self.bg_remover = BackgroundRemover()
83
-
84
- logger.info("βœ… Hunyuan3D pipeline loaded successfully")
85
 
86
  except ImportError as import_error:
87
- logger.error(f"❌ Failed to import Hunyuan3D components: {import_error}")
88
- logger.info("πŸ”„ Hunyuan3D not installed, trying alternative approach...")
89
-
90
- # Fallback: Try using transformers AutoModel
91
- logger.info("πŸ“¦ Importing transformers components...")
92
- from transformers import AutoModel, AutoProcessor
93
-
94
- model_id = self.lite_model_id if self.use_lite else self.model_id
95
- logger.info(f"πŸ“¦ Loading model: {model_id}")
96
-
97
- # Check if model exists on HuggingFace
98
- try:
99
- from huggingface_hub import model_info
100
- info = model_info(model_id)
101
- logger.info(f"βœ… Model found on HuggingFace: {info.modelId}")
102
- except Exception as hub_error:
103
- logger.error(f"❌ Model not found on HuggingFace: {hub_error}")
104
- logger.info("πŸ”„ Using fallback 3D generation")
105
- self.model = "fallback"
106
- return
107
-
108
- # Load preprocessor
109
- logger.info("πŸ“¦ Loading preprocessor...")
110
- try:
111
- self.preprocessor = AutoProcessor.from_pretrained(model_id)
112
- logger.info("βœ… Preprocessor loaded successfully")
113
- except Exception as proc_error:
114
- logger.error(f"❌ Preprocessor loading failed: {proc_error}")
115
- logger.info("πŸ”„ Using fallback mode")
116
- self.model = "fallback"
117
- return
118
-
119
- # Load model with optimizations
120
- torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
121
- logger.info(f"πŸ“¦ Using torch dtype: {torch_dtype}")
122
-
123
- # Disable torch.compile to avoid dynamo issues
124
- logger.info("πŸ“¦ Disabling torch compile to avoid dynamo issues...")
125
- torch._dynamo.config.suppress_errors = True
126
 
127
- logger.info("πŸ“¦ Loading 3D model with safe device handling...")
128
-
129
- # Try loading with different strategies
130
- loading_successful = False
131
-
132
- # Strategy 1: Load directly to device
133
- try:
134
- logger.info("πŸ“¦ Strategy 1: Direct device loading...")
135
- self.model = AutoModel.from_pretrained(
136
- model_id,
137
- torch_dtype=torch_dtype,
138
- device_map={"": self.device},
139
- low_cpu_mem_usage=True,
140
- trust_remote_code=True
141
- )
142
- loading_successful = True
143
- logger.info("βœ… Direct device loading successful")
144
- except Exception as e1:
145
- logger.error(f"❌ Strategy 1 failed: {e1}")
146
-
147
- # Strategy 2: Load to CPU first
148
- if not loading_successful:
149
- try:
150
- logger.info("πŸ“¦ Strategy 2: CPU-first loading...")
151
- # Load model to CPU first to avoid meta tensor issues
152
- self.model = AutoModel.from_pretrained(
153
- model_id,
154
- torch_dtype=torch.float32, # Use float32 for CPU loading
155
- low_cpu_mem_usage=True,
156
- device_map=None, # No device mapping initially
157
- trust_remote_code=True
158
- )
159
- logger.info("βœ… 3D model loaded to CPU")
160
-
161
- # Now safely move to target device
162
- logger.info(f"πŸ“¦ Moving model to target device: {self.device}")
163
- try:
164
- if self.device == "cuda":
165
- # Convert to appropriate dtype for GPU
166
- self.model = self.model.to(device=self.device, dtype=torch.float16)
167
- logger.info("βœ… Model moved to CUDA with fp16")
168
- else:
169
- # Keep on CPU
170
- self.model = self.model.to(device="cpu", dtype=torch.float32)
171
- logger.info("βœ… Model kept on CPU with fp32")
172
- loading_successful = True
173
-
174
- except Exception as device_error:
175
- logger.error(f"❌ Device movement failed: {device_error}")
176
- logger.info("πŸ”„ Falling back to CPU...")
177
- self.device = "cpu"
178
- if self.model is not None:
179
- self.model = self.model.to("cpu", dtype=torch.float32)
180
- loading_successful = True
181
- else:
182
- logger.error("❌ Model is None, using fallback mode")
183
- self.model = "fallback"
184
- except Exception as e2:
185
- logger.error(f"❌ Strategy 2 failed: {e2}")
186
-
187
- # If all strategies failed, use fallback
188
- if not loading_successful:
189
- logger.error("❌ All loading strategies failed")
190
- logger.info("πŸ”„ Using fallback 3D generation")
191
- self.model = "fallback"
192
- return
193
-
194
- # Enable optimizations safely
195
- logger.info("πŸ“¦ Applying model optimizations...")
196
- if self.model != "fallback" and hasattr(self.model, 'enable_attention_slicing'):
197
- self.model.enable_attention_slicing()
198
- logger.info("βœ… Attention slicing enabled")
199
- else:
200
- logger.info("⚠️ Attention slicing not available")
201
 
202
- logger.info("πŸŽ‰ 3D model loading completed successfully!")
203
-
204
  except Exception as e:
205
- logger.error(f"❌ Failed to load Hunyuan3D model: {e}")
206
- logger.error(f"❌ Error type: {type(e).__name__}")
207
  logger.info("πŸ”„ Falling back to simple 3D generation...")
208
- # Model loading failed, will use fallback
209
- self.model = "fallback"
210
 
211
  def image_to_3d(self,
212
  image: Union[str, Image.Image, np.ndarray],
@@ -229,7 +117,7 @@ class Hunyuan3DGenerator:
229
  logger.info("βœ… Model already loaded")
230
 
231
  # If model loading failed, use fallback
232
- if self.model == "fallback":
233
  logger.info("πŸ”„ Using fallback 3D generation...")
234
  return self._generate_fallback_3d(image)
235
 
@@ -237,124 +125,78 @@ class Hunyuan3DGenerator:
237
  logger.info("πŸ–ΌοΈ Preparing input image...")
238
  if isinstance(image, str):
239
  logger.info(f"πŸ–ΌοΈ Loading image from path: {image}")
 
240
  image = Image.open(image)
241
  elif isinstance(image, np.ndarray):
242
  logger.info("πŸ–ΌοΈ Converting numpy array to PIL Image")
243
  image = Image.fromarray(image)
 
 
244
  else:
245
  logger.info("πŸ–ΌοΈ Input is already PIL Image")
 
 
246
 
247
- # Ensure RGBA for Hunyuan3D
248
- logger.info(f"πŸ–ΌοΈ Image mode: {image.mode}")
249
- if image.mode != 'RGBA':
250
- logger.info("πŸ–ΌοΈ Converting image to RGBA mode")
251
- image = image.convert('RGBA')
252
-
253
- logger.info(f"πŸ–ΌοΈ Final image size: {image.size}")
254
-
255
- # Remove background if requested
256
- if remove_background and image.mode == 'RGB':
257
- logger.info("🎭 Removing background from image...")
258
- try:
259
- if hasattr(self, 'bg_remover'):
260
- # Use Hunyuan3D's background remover
261
- image = self.bg_remover(image)
262
- logger.info("βœ… Background removed using Hunyuan3D remover")
263
- else:
264
- # Use fallback background removal
265
- image = self._remove_background(image)
266
- logger.info("βœ… Background removed using fallback method")
267
- except Exception as bg_error:
268
- logger.error(f"❌ Background removal failed: {bg_error}")
269
- logger.info("πŸ”„ Continuing with original image...")
270
 
271
- # Check if we have the Hunyuan3D pipeline
272
- if hasattr(self.model, '__call__') and hasattr(self, 'bg_remover'):
273
- # Using Hunyuan3D pipeline
274
- logger.info("🧠 Using Hunyuan3D pipeline for 3D generation...")
275
 
276
  try:
277
- # Generate 3D model using Hunyuan3D
278
- logger.info("πŸš€ Starting Hunyuan3D generation...")
279
- mesh_outputs = self.model(image=image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
- if isinstance(mesh_outputs, list) and len(mesh_outputs) > 0:
282
- mesh = mesh_outputs[0]
283
- logger.info("βœ… 3D mesh generated successfully")
 
 
 
 
 
 
 
 
 
 
284
 
285
- # Save mesh
286
- logger.info("πŸ’Ύ Saving generated mesh...")
287
- mesh_path = self._save_mesh(mesh)
288
- logger.info(f"βœ… Mesh saved to: {mesh_path}")
289
 
290
- return mesh_path
291
  else:
292
- logger.error("❌ No mesh output from Hunyuan3D")
293
- raise Exception("Empty mesh output")
294
 
295
- except Exception as hunyuan_error:
296
- logger.error(f"❌ Hunyuan3D generation failed: {hunyuan_error}")
297
  logger.info("πŸ”„ Falling back to alternative generation...")
298
  return self._generate_fallback_3d(image)
299
 
300
  else:
301
- # Using transformers-based approach (original code)
302
- logger.info("🧠 Using transformers-based 3D generation...")
303
-
304
- # Resize for processing
305
- logger.info("πŸ–ΌοΈ Resizing image for processing (512x512)...")
306
- image = image.resize((512, 512), Image.Resampling.LANCZOS)
307
- logger.info("βœ… Image resized successfully")
308
-
309
- # Process with model
310
- logger.info("🧠 Starting model inference...")
311
- with torch.no_grad():
312
- try:
313
- # Preprocess image
314
- logger.info("πŸ”„ Preprocessing image for model...")
315
- inputs = self.preprocessor(images=image, return_tensors="pt")
316
- logger.info(f"πŸ”„ Input tensor shape: {inputs['pixel_values'].shape if 'pixel_values' in inputs else 'unknown'}")
317
-
318
- # Move inputs to device safely
319
- logger.info(f"πŸ”„ Moving inputs to device: {self.device}")
320
- try:
321
- # Avoid device-related dynamo issues
322
- device_str = str(self.device) # Convert to string to avoid torch.device in dynamo
323
- inputs = {k: v.to(device_str) for k, v in inputs.items() if hasattr(v, 'to')}
324
- logger.info("βœ… Inputs moved to device successfully")
325
- except Exception as device_error:
326
- logger.error(f"❌ Failed to move inputs to device: {device_error}")
327
- raise device_error
328
-
329
- # Generate 3D
330
- logger.info("πŸš€ Starting 3D generation inference...")
331
- logger.info(f"πŸš€ Parameters: steps={self.num_inference_steps}, guidance={self.guidance_scale}")
332
-
333
- outputs = self.model.generate(
334
- **inputs,
335
- num_inference_steps=self.num_inference_steps,
336
- guidance_scale=self.guidance_scale,
337
- texture_resolution=texture_resolution
338
- )
339
- logger.info("βœ… 3D generation completed successfully")
340
-
341
- # Extract mesh
342
- logger.info("πŸ”§ Extracting mesh from model outputs...")
343
- mesh = self._extract_mesh(outputs)
344
- logger.info("βœ… Mesh extraction completed")
345
-
346
- except Exception as inference_error:
347
- logger.error(f"❌ Model inference failed: {inference_error}")
348
- logger.error(f"❌ Inference error type: {type(inference_error).__name__}")
349
- raise inference_error
350
-
351
- # Save mesh
352
- logger.info("πŸ’Ύ Saving generated mesh...")
353
- mesh_path = self._save_mesh(mesh)
354
- logger.info(f"βœ… Mesh saved to: {mesh_path}")
355
-
356
- logger.info("πŸŽ‰ 3D generation process completed successfully!")
357
- return mesh_path
358
 
359
  except Exception as e:
360
  logger.error(f"❌ 3D generation error: {e}")
@@ -387,40 +229,6 @@ class Hunyuan3DGenerator:
387
  image.putdata(new_data)
388
  return image
389
 
390
- def _extract_mesh(self, model_outputs: Dict[str, Any]) -> trimesh.Trimesh:
391
- """Extract mesh from model outputs"""
392
- # This would depend on actual Hunyuan3D output format
393
- # Placeholder implementation
394
-
395
- if 'vertices' in model_outputs and 'faces' in model_outputs:
396
- vertices = model_outputs['vertices'].cpu().numpy()
397
- faces = model_outputs['faces'].cpu().numpy()
398
-
399
- # Create trimesh object
400
- mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
401
-
402
- # Add texture if available
403
- if 'texture' in model_outputs:
404
- # Apply texture to mesh
405
- pass
406
-
407
- return mesh
408
- else:
409
- # Create a simple mesh if outputs are different
410
- return self._create_simple_mesh()
411
-
412
- def _create_simple_mesh(self) -> trimesh.Trimesh:
413
- """Create a simple placeholder mesh"""
414
- # Create a simple sphere as placeholder
415
- mesh = trimesh.creation.icosphere(subdivisions=3, radius=1.0)
416
-
417
- # Add some variation
418
- mesh.vertices += np.random.normal(0, 0.05, mesh.vertices.shape)
419
-
420
- # Smooth the mesh
421
- mesh = mesh.smoothed()
422
-
423
- return mesh
424
 
425
  def _generate_fallback_3d(self, image: Union[Image.Image, np.ndarray]) -> str:
426
  """Generate fallback 3D model when main model fails"""
@@ -494,6 +302,36 @@ class Hunyuan3DGenerator:
494
 
495
  return mesh_path
496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
  def text_to_3d(self, text_prompt: str) -> str:
498
  """Generate 3D model from text description"""
499
  # First generate image, then convert to 3D
@@ -501,15 +339,13 @@ class Hunyuan3DGenerator:
501
  raise NotImplementedError("Text to 3D requires image generation first")
502
 
503
  def to(self, device: str):
504
- """Move model to specified device"""
505
  self.device = device
506
- if self.model and self.model != "fallback":
507
- self.model.to(device)
508
 
509
  def __del__(self):
510
  """Cleanup when object is destroyed"""
511
- if self.model and self.model != "fallback":
512
- del self.model
513
- if self.preprocessor:
514
- del self.preprocessor
515
- torch.cuda.empty_cache()
 
7
  from pathlib import Path
8
  import os
9
  import logging
10
+ import random
11
 
12
  # Set up detailed logging for 3D generation
13
  logging.basicConfig(level=logging.INFO)
 
63
  return False
64
 
65
  def load_model(self):
66
+ """Initialize Gradio client for Hunyuan3D API"""
67
  if self.model is None:
68
+ logger.info("πŸš€ Starting Hunyuan3D API client initialization...")
69
 
70
  try:
71
+ # Try to import gradio_client
72
+ logger.info("πŸ“¦ Attempting to import gradio_client...")
73
  try:
74
+ from gradio_client import Client, handle_file
75
+ logger.info("βœ… gradio_client imported successfully")
 
76
 
77
+ # Initialize Hunyuan3D client
78
+ logger.info("🌐 Connecting to Hunyuan3D API...")
79
+ self.client = Client("tencent/Hunyuan3D-2.1")
80
+ self.handle_file = handle_file
81
+ self.model = "gradio_api"
82
 
83
+ logger.info("βœ… Hunyuan3D API client initialized successfully")
 
 
 
84
 
85
  except ImportError as import_error:
86
+ logger.error(f"❌ Failed to import gradio_client: {import_error}")
87
+ logger.info("πŸ’‘ Please install gradio_client:")
88
+ logger.info(" pip install gradio_client")
89
+ logger.info("πŸ”„ Using fallback mode instead...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ self.model = "fallback_mode"
92
+ return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
 
 
94
  except Exception as e:
95
+ logger.error(f"❌ Failed to initialize Hunyuan3D API client: {e}")
 
96
  logger.info("πŸ”„ Falling back to simple 3D generation...")
97
+ self.model = "fallback_mode"
 
98
 
99
  def image_to_3d(self,
100
  image: Union[str, Image.Image, np.ndarray],
 
117
  logger.info("βœ… Model already loaded")
118
 
119
  # If model loading failed, use fallback
120
+ if self.model == "fallback_mode":
121
  logger.info("πŸ”„ Using fallback 3D generation...")
122
  return self._generate_fallback_3d(image)
123
 
 
125
  logger.info("πŸ–ΌοΈ Preparing input image...")
126
  if isinstance(image, str):
127
  logger.info(f"πŸ–ΌοΈ Loading image from path: {image}")
128
+ image_path = image
129
  image = Image.open(image)
130
  elif isinstance(image, np.ndarray):
131
  logger.info("πŸ–ΌοΈ Converting numpy array to PIL Image")
132
  image = Image.fromarray(image)
133
+ # Save to temp file for gradio client
134
+ image_path = self._save_temp_image(image)
135
  else:
136
  logger.info("πŸ–ΌοΈ Input is already PIL Image")
137
+ # Save to temp file for gradio client
138
+ image_path = self._save_temp_image(image)
139
 
140
+ logger.info(f"πŸ–ΌοΈ Image mode: {image.mode}, size: {image.size}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
+ # Check if we have the Gradio API client
143
+ if self.model == "gradio_api" and hasattr(self, 'client'):
144
+ logger.info("🌐 Using Hunyuan3D Gradio API for 3D generation...")
 
145
 
146
  try:
147
+ # Generate 3D model using Hunyuan3D API
148
+ logger.info("πŸš€ Starting Hunyuan3D API generation...")
149
+
150
+ # Use generation_all for both shape and texture
151
+ logger.info("πŸ“€ Calling generation_all API...")
152
+ result = self.client.predict(
153
+ image=self.handle_file(image_path),
154
+ mv_image_front=None,
155
+ mv_image_back=None,
156
+ mv_image_left=None,
157
+ mv_image_right=None,
158
+ steps=self.num_inference_steps,
159
+ guidance_scale=self.guidance_scale,
160
+ seed=random.randint(1, 10000),
161
+ octree_resolution=self.resolution,
162
+ check_box_rembg=remove_background,
163
+ num_chunks=8000,
164
+ randomize_seed=True,
165
+ api_name="/generation_all"
166
+ )
167
 
168
+ logger.info("βœ… API call completed successfully")
169
+ logger.info(f"πŸ“Š Result type: {type(result)}, length: {len(result) if isinstance(result, (list, tuple)) else 'N/A'}")
170
+
171
+ # Extract mesh file from result
172
+ # Result format: [shape_file, texture_file, html_output, mesh_stats, seed]
173
+ if isinstance(result, (list, tuple)) and len(result) >= 2:
174
+ shape_file = result[0] # Shape file path
175
+ texture_file = result[1] # Textured file path (if available)
176
+
177
+ # Use textured file if available, otherwise use shape file
178
+ mesh_file = texture_file if texture_file else shape_file
179
+
180
+ logger.info(f"βœ… Generated mesh file: {mesh_file}")
181
 
182
+ # Copy to our output location
183
+ output_path = self._save_output_mesh(mesh_file)
184
+ logger.info(f"βœ… Mesh saved to: {output_path}")
 
185
 
186
+ return output_path
187
  else:
188
+ logger.error("❌ Unexpected result format from Hunyuan3D API")
189
+ raise Exception("Invalid API response format")
190
 
191
+ except Exception as api_error:
192
+ logger.error(f"❌ Hunyuan3D API generation failed: {api_error}")
193
  logger.info("πŸ”„ Falling back to alternative generation...")
194
  return self._generate_fallback_3d(image)
195
 
196
  else:
197
+ # Fallback to simple 3D generation
198
+ logger.info("πŸ”„ No API client available, using fallback...")
199
+ return self._generate_fallback_3d(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  except Exception as e:
202
  logger.error(f"❌ 3D generation error: {e}")
 
229
  image.putdata(new_data)
230
  return image
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
  def _generate_fallback_3d(self, image: Union[Image.Image, np.ndarray]) -> str:
234
  """Generate fallback 3D model when main model fails"""
 
302
 
303
  return mesh_path
304
 
305
+ def _save_temp_image(self, image: Image.Image) -> str:
306
+ """Save PIL image to temporary file for gradio client"""
307
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
308
+ image_path = tmp.name
309
+
310
+ # Save image
311
+ image.save(image_path, 'PNG')
312
+ logger.info(f"πŸ’Ύ Saved temp image to: {image_path}")
313
+
314
+ return image_path
315
+
316
+ def _save_output_mesh(self, source_mesh_path: str) -> str:
317
+ """Copy generated mesh to our output location"""
318
+ import shutil
319
+
320
+ # Create output directory if it doesn't exist
321
+ output_dir = "/tmp/hunyuan3d_output"
322
+ os.makedirs(output_dir, exist_ok=True)
323
+
324
+ # Generate unique filename
325
+ timestamp = tempfile.mktemp().split('/')[-1]
326
+ output_filename = f"hunyuan3d_mesh_{timestamp}.glb"
327
+ output_path = os.path.join(output_dir, output_filename)
328
+
329
+ # Copy the file
330
+ shutil.copy2(source_mesh_path, output_path)
331
+ logger.info(f"πŸ“ Copied mesh from {source_mesh_path} to {output_path}")
332
+
333
+ return output_path
334
+
335
  def text_to_3d(self, text_prompt: str) -> str:
336
  """Generate 3D model from text description"""
337
  # First generate image, then convert to 3D
 
339
  raise NotImplementedError("Text to 3D requires image generation first")
340
 
341
  def to(self, device: str):
342
+ """Update device preference"""
343
  self.device = device
344
+ logger.info(f"πŸ”§ Device preference updated to: {device}")
 
345
 
346
  def __del__(self):
347
  """Cleanup when object is destroyed"""
348
+ if hasattr(self, 'client'):
349
+ del self.client
350
+ if torch.cuda.is_available():
351
+ torch.cuda.empty_cache()