lj1995 commited on
Commit
135c82e
·
verified ·
1 Parent(s): 77b4162

Update text/g2pw/onnx_api.py

Browse files
Files changed (1) hide show
  1. text/g2pw/onnx_api.py +51 -1
text/g2pw/onnx_api.py CHANGED
@@ -1,6 +1,52 @@
1
  # This code is modified from https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/g2pw
2
  # This code is modified from https://github.com/GitYCC/g2pW
3
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import warnings
5
 
6
  warnings.filterwarnings("ignore")
@@ -16,6 +62,10 @@ import numpy as np
16
  import onnxruntime
17
 
18
  onnxruntime.set_default_logger_severity(3)
 
 
 
 
19
  from opencc import OpenCC
20
  from transformers import AutoTokenizer
21
  from pypinyin import pinyin
 
1
  # This code is modified from https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/g2pw
2
  # This code is modified from https://github.com/GitYCC/g2pW
3
+ def load_nvrtc():
4
+ import torch,sys,os,ctypes
5
+ from pathlib import Path
6
+
7
+ if not torch.cuda.is_available():
8
+ print("[INFO] CUDA is not available, skipping nvrtc setup.")
9
+ return
10
+
11
+ if sys.platform == "win32":
12
+ torch_lib_dir = Path(torch.__file__).parent / "lib"
13
+ if torch_lib_dir.exists():
14
+ os.add_dll_directory(str(torch_lib_dir))
15
+ print(f"[INFO] Added DLL directory: {torch_lib_dir}")
16
+ matching_files = sorted(torch_lib_dir.glob("nvrtc*.dll"))
17
+ if not matching_files:
18
+ print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
19
+ return
20
+ for dll_path in matching_files:
21
+ dll_name = os.path.basename(dll_path)
22
+ try:
23
+ ctypes.CDLL(dll_name)
24
+ print(f"[INFO] Loaded: {dll_name}")
25
+ except OSError as e:
26
+ print(f"[WARNING] Failed to load {dll_name}: {e}")
27
+ else:
28
+ print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
29
+
30
+ elif sys.platform == "linux":
31
+ site_packages = Path(torch.__file__).resolve().parents[1]
32
+ nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
33
+
34
+ if not nvrtc_dir.exists():
35
+ print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
36
+ return
37
+
38
+ matching_files = sorted(nvrtc_dir.glob("libnvrtc*.so*"))
39
+ if not matching_files:
40
+ print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
41
+ return
42
+
43
+ for so_path in matching_files:
44
+ try:
45
+ ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) # type: ignore
46
+ print(f"[INFO] Loaded: {so_path}")
47
+ except OSError as e:
48
+ print(f"[WARNING] Failed to load {so_path}: {e}")
49
+ load_nvrtc()
50
  import warnings
51
 
52
  warnings.filterwarnings("ignore")
 
62
  import onnxruntime
63
 
64
  onnxruntime.set_default_logger_severity(3)
65
+ try:
66
+ onnxruntime.preload_dlls()
67
+ except:
68
+ traceback.print_exc()
69
  from opencc import OpenCC
70
  from transformers import AutoTokenizer
71
  from pypinyin import pinyin