ThorbenFroehlking committed on
Commit
78b2c3a
·
1 Parent(s): af13564
Files changed (2) hide show
  1. app.py +197 -112
  2. model_loader.py +26 -19
app.py CHANGED
@@ -19,6 +19,11 @@ from torch.utils.data import DataLoader
19
  import re
20
  import pandas as pd
21
  import copy
 
 
 
 
 
22
 
23
  import transformers
24
  from transformers import AutoTokenizer, DataCollatorForTokenClassification
@@ -27,13 +32,26 @@ from datasets import Dataset
27
 
28
  from scipy.special import expit
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Load model and move to device
31
- #checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
32
- #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_cryptic'
33
- #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_database'
34
- #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_full'
35
- #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_0925'
36
- #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_0925_v2'
37
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50_full_v2'
38
  max_length = 1500
39
  model, tokenizer = load_model(checkpoint, max_length)
@@ -41,21 +59,33 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
  model.to(device)
42
  model.eval()
43
 
 
 
 
 
 
 
 
 
 
44
  def normalize_scores(scores):
45
  min_score = np.min(scores)
46
  max_score = np.max(scores)
47
- return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
 
48
 
49
  def read_mol(pdb_path):
50
  """Read PDB file and return its content as a string"""
51
  with open(pdb_path, 'r') as f:
52
  return f.read()
53
 
54
- def fetch_structure(pdb_id: str, output_dir: str = ".") -> str:
55
  """
56
  Fetch the structure file for a given PDB ID. Prioritizes CIF files.
57
  If a structure file already exists locally, it uses that.
58
  """
 
 
59
  file_path = download_structure(pdb_id, output_dir)
60
  return file_path
61
 
@@ -76,23 +106,29 @@ def download_structure(pdb_id: str, output_dir: str) -> str:
76
  return file_path
77
  return None
78
 
79
- def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
80
  """
81
  Convert a CIF file to PDB format using BioPython and return the PDB file path.
82
  """
 
 
83
  pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))
84
  parser = MMCIFParser(QUIET=True)
85
  structure = parser.get_structure('protein', cif_path)
86
  io = PDBIO()
87
  io.set_structure(structure)
88
  io.save(pdb_path)
 
 
 
 
89
  return pdb_path
90
 
91
  def fetch_pdb(pdb_id):
92
- pdb_path = fetch_structure(pdb_id)
93
  _, ext = os.path.splitext(pdb_path)
94
  if ext == '.cif':
95
- pdb_path = convert_cif_to_pdb(pdb_path)
96
  return pdb_path
97
 
98
  def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:
@@ -102,7 +138,7 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
102
  parser = PDBParser(QUIET=True)
103
  structure = parser.get_structure('protein', input_pdb)
104
 
105
- output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
106
 
107
  # Create scores dictionary for easy lookup
108
  scores_dict = {resi: score for resi, score in residue_scores}
@@ -132,6 +168,9 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
132
  io.set_structure(structure[0])
133
  io.save(output_pdb, selector)
134
 
 
 
 
135
  return output_pdb
136
 
137
  def generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, score_type):
@@ -157,7 +196,7 @@ def generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time,
157
 
158
  # Add PyMOL commands for each score bracket
159
  for bracket, residues in residues_by_bracket.items():
160
- if residues: # Only add commands if there are residues in this bracket
161
  color = bracket_colors[bracket]
162
  resi_list = '+'.join(map(str, residues))
163
  pymol_commands += f"""
@@ -184,9 +223,6 @@ def generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues
184
 
185
  return result_str
186
 
187
-
188
-
189
-
190
  def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
191
  # Determine if input is a PDB ID or file path
192
  if pdb_id_or_file.endswith('.pdb'):
@@ -211,13 +247,23 @@ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
211
  sequence_id = [res.id[1] for res in protein_residues]
212
 
213
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
 
214
  with torch.no_grad():
215
- outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
 
 
 
 
 
 
216
 
217
  # Calculate scores and normalize them
218
- raw_scores = expit(outputs[:, 1] - outputs[:, 0])
219
  normalized_scores = normalize_scores(raw_scores)
220
 
 
 
 
221
  # Choose which scores to use based on score_type
222
  display_scores = normalized_scores if score_type == 'normalized' else raw_scores
223
 
@@ -263,13 +309,17 @@ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
263
  mol_vis = molecule(pdb_path, residue_scores, segment)
264
 
265
  # Create prediction file
266
- prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
267
  with open(prediction_file, "w") as f:
268
  f.write(result_str)
269
 
270
- scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
271
  os.rename(scored_pdb, scored_pdb_name)
272
 
 
 
 
 
273
  return pymol_commands, mol_vis, [prediction_file, scored_pdb_name], raw_residue_scores, norm_residue_scores, pdb_id, segment
274
 
275
  def molecule(input_pdb, residue_scores=None, segment='A'):
@@ -411,6 +461,9 @@ def molecule(input_pdb, residue_scores=None, segment='A'):
411
  </html>
412
  """
413
 
 
 
 
414
  # Return the HTML content within an iframe safely encoded for special characters
415
  return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
416
 
@@ -487,98 +540,114 @@ with gr.Blocks(css="""
487
  last_pdb_id = gr.State(None)
488
 
489
  def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
490
- selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
491
-
492
- # First get the actual PDB file path
493
- if mode == "PDB ID":
494
- pdb_path = fetch_pdb(pdb_id) # Get the actual file path
495
-
496
- pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
497
- # Store the actual file path, not just the PDB ID
498
- return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
499
- elif mode == "Upload File":
500
- _, ext = os.path.splitext(pdb_file.name)
501
- file_path = os.path.join('./', f"{_}{ext}")
502
- if ext == '.cif':
503
- pdb_path = convert_cif_to_pdb(file_path)
504
- else:
505
- pdb_path = file_path
506
 
507
- pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
508
- return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
 
510
  def update_visualization_and_files(score_type_val, raw_scores, norm_scores, pdb_path, segment, pdb_id):
511
  if raw_scores is None or norm_scores is None or pdb_path is None or segment is None or pdb_id is None:
512
  return None, None, None
513
 
514
- # Choose scores based on radio button selection
515
- selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
516
- selected_scores = norm_scores if selected_score_type == 'normalized' else raw_scores
517
-
518
- # Generate visualization with selected scores
519
- mol_vis = molecule(pdb_path, selected_scores, segment)
520
-
521
- # Generate PyMOL commands and downloadable files
522
- # Get structure for residue info
523
- _, ext = os.path.splitext(pdb_path)
524
- parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
525
- structure = parser.get_structure('protein', pdb_path)
526
- chain = structure[0][segment]
527
- protein_residues = [res for res in chain if is_aa(res)]
528
- sequence = "".join(seq1(res.resname) for res in protein_residues)
529
-
530
- # Define score brackets
531
- score_brackets = {
532
- "0.0-0.2": (0.0, 0.2),
533
- "0.2-0.4": (0.2, 0.4),
534
- "0.4-0.6": (0.4, 0.6),
535
- "0.6-0.8": (0.6, 0.8),
536
- "0.8-1.0": (0.8, 1.0)
537
- }
538
-
539
- # Initialize a dictionary to store residues by bracket
540
- residues_by_bracket = {bracket: [] for bracket in score_brackets}
541
-
542
- # Categorize residues into brackets
543
- for resi, score in selected_scores:
544
- for bracket, (lower, upper) in score_brackets.items():
545
- if lower <= score < upper:
546
- residues_by_bracket[bracket].append(resi)
547
- break
548
-
549
- # Generate timestamp
550
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
551
-
552
- # Generate result text and PyMOL commands based on score type
553
- display_score_type = "Normalized" if selected_score_type == 'normalized' else "Raw"
554
- scores_array = [score for _, score in selected_scores]
555
- result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
556
- scores_array, current_time, display_score_type)
557
- pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
558
-
559
- # Create chain-specific PDB with scores in B-factor
560
- scored_pdb = create_chain_specific_pdb(pdb_path, segment, selected_scores, protein_residues)
561
-
562
- # Create prediction file
563
- prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
564
- with open(prediction_file, "w") as f:
565
- f.write(result_str)
566
-
567
- scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
568
- os.rename(scored_pdb, scored_pdb_name)
569
-
570
- return mol_vis, pymol_commands, [prediction_file, scored_pdb_name]
 
 
 
 
 
 
571
 
572
  def fetch_interface(mode, pdb_id, pdb_file):
573
  if mode == "PDB ID":
574
  return fetch_pdb(pdb_id)
575
  elif mode == "Upload File":
576
  _, ext = os.path.splitext(pdb_file.name)
577
- file_path = os.path.join('./', f"{_}{ext}")
 
578
  if ext == '.cif':
579
- pdb_path = convert_cif_to_pdb(file_path)
580
  else:
581
- pdb_path= file_path
582
  return pdb_path
583
 
584
  def toggle_mode(selected_mode):
@@ -586,8 +655,6 @@ with gr.Blocks(css="""
586
  return gr.update(visible=True), gr.update(visible=False)
587
  else:
588
  return gr.update(visible=False), gr.update(visible=True)
589
-
590
-
591
 
592
  mode.change(
593
  toggle_mode,
@@ -628,17 +695,35 @@ with gr.Blocks(css="""
628
  )
629
 
630
  def predict_utils(sequence):
631
- input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
632
- with torch.no_grad():
633
- outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
634
-
635
- raw_scores = expit(outputs[:, 1] - outputs[:, 0])
636
- normalized_scores = normalize_scores(raw_scores)
637
-
638
- return {
639
- "raw_scores": raw_scores.tolist(),
640
- "normalized_scores": normalized_scores.tolist()
641
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642
 
643
  dummy_input = gr.Textbox(visible=False)
644
  dummy_output = gr.Textbox(visible=False)
@@ -650,4 +735,4 @@ with gr.Blocks(css="""
650
  outputs=[dummy_output]
651
  )
652
 
653
- demo.launch(share=True)
 
19
  import re
20
  import pandas as pd
21
  import copy
22
+ import gc
23
+ import tempfile
24
+ import shutil
25
+ import atexit
26
+ import weakref
27
 
28
  import transformers
29
  from transformers import AutoTokenizer, DataCollatorForTokenClassification
 
32
 
33
  from scipy.special import expit
34
 
# Create a temporary directory for this session
TEMP_DIR = tempfile.mkdtemp(prefix="protein_binding_")
print(f"Using temporary directory: {TEMP_DIR}")

# Registry to track created files for cleanup.
# NOTE(review): nothing visible in this file adds to this set, and a WeakSet
# cannot hold plain ``str`` paths (str objects are not weak-referenceable) --
# confirm what is meant to be registered here.
_file_registry = weakref.WeakSet()


def cleanup_temp_files():
    """Remove the session temporary directory ``TEMP_DIR``.

    Registered with ``atexit`` below. The removal is best-effort: a failure is
    reported on stdout and never raised, and a directory that is already gone
    is silently accepted.
    """
    try:
        # Attempt the removal directly instead of checking for existence
        # first, so there is no window between the check and the delete.
        shutil.rmtree(TEMP_DIR)
    except FileNotFoundError:
        # Already removed (e.g. by an earlier call); nothing left to do.
        return
    except OSError as e:
        print(f"Error cleaning up temp directory: {e}")
    else:
        print(f"Cleaned up temporary directory: {TEMP_DIR}")


# Register cleanup function
atexit.register(cleanup_temp_files)
53
+
54
  # Load model and move to device
 
 
 
 
 
 
55
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50_full_v2'
56
  max_length = 1500
57
  model, tokenizer = load_model(checkpoint, max_length)
 
59
  model.to(device)
60
  model.eval()
61
 
def cleanup_files(*file_paths):
    """Delete each of the given files, best-effort.

    Args:
        *file_paths: Paths to remove. Falsy entries (``None``, ``""``) are
            skipped, and paths that do not exist are ignored.

    A path that cannot be removed is reported on stdout; the remaining paths
    are still processed and no exception is raised.
    """
    for path in file_paths:
        if not path:
            continue
        try:
            # Try the removal directly rather than testing for existence
            # first: the file may disappear between the check and the delete.
            os.remove(path)
        except FileNotFoundError:
            # Already gone, which is the outcome the caller wanted.
            pass
        except OSError as e:
            print(f"Could not remove {path}: {e}")
70
+
def normalize_scores(scores):
    """Min-max scale ``scores`` so the minimum maps to 0 and the maximum to 1.

    Args:
        scores: Array-like of numeric scores (NumPy array or plain sequence).

    Returns:
        A NumPy array of the scaled values. When the input is empty, or all
        values are equal, there is no range to scale by and the values are
        returned unscaled.
    """
    scores = np.asarray(scores)
    if scores.size == 0:
        # np.min / np.max raise on an empty array; nothing to scale anyway.
        return scores

    min_score = np.min(scores)
    max_score = np.max(scores)
    if max_score > min_score:
        return (scores - min_score) / (max_score - min_score)
    return scores
76
 
def read_mol(pdb_path):
    """Return the complete text content of the PDB file at ``pdb_path``."""
    with open(pdb_path, 'r') as handle:
        contents = handle.read()
    return contents
81
 
def fetch_structure(pdb_id: str, output_dir: str = None) -> str:
    """
    Fetch the structure file for a given PDB ID and return its path.

    The work is delegated to ``download_structure``; when ``output_dir`` is
    not given, the session temporary directory ``TEMP_DIR`` is used.
    NOTE(review): CIF prioritisation and reuse of an already-downloaded file
    are presumably handled inside ``download_structure`` -- confirm there.
    """
    target_dir = TEMP_DIR if output_dir is None else output_dir
    return download_structure(pdb_id, target_dir)
91
 
 
106
  return file_path
107
  return None
108
 
def convert_cif_to_pdb(cif_path: str, output_dir: str = None) -> str:
    """
    Convert a CIF file to PDB format using BioPython and return the PDB file path.

    Args:
        cif_path: Path of the mmCIF file to convert.
        output_dir: Directory to write the PDB file into; defaults to
            ``TEMP_DIR`` when omitted.

    Returns:
        Path of the written PDB file: the input's base name with its final
        extension replaced by ``.pdb``, inside ``output_dir``.

    Note:
        The source CIF file is deleted once the PDB file has been written.
    """
    if output_dir is None:
        output_dir = TEMP_DIR

    # Swap only the final extension; a plain str.replace would also rewrite
    # any other '.cif' occurring inside the file name.
    stem, _ = os.path.splitext(os.path.basename(cif_path))
    pdb_path = os.path.join(output_dir, stem + '.pdb')

    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure('protein', cif_path)
    io = PDBIO()
    io.set_structure(structure)
    io.save(pdb_path)

    # Clean up CIF file after conversion -- but never delete the file that was
    # just written, should input and output resolve to the same path.
    if os.path.abspath(pdb_path) != os.path.abspath(cif_path):
        cleanup_files(cif_path)

    return pdb_path
126
 
def fetch_pdb(pdb_id):
    """Fetch the structure for ``pdb_id`` into ``TEMP_DIR`` and return a PDB path.

    A downloaded file with a ``.cif`` extension is converted to PDB format
    first; any other file is returned as fetched.

    Args:
        pdb_id: Identifier passed through to ``fetch_structure``.

    Returns:
        Path of the PDB file inside ``TEMP_DIR``.

    Raises:
        ValueError: If no structure file could be obtained for ``pdb_id``.
    """
    pdb_path = fetch_structure(pdb_id, TEMP_DIR)
    if pdb_path is None:
        # The download helper signals failure with None; fail here with a
        # clear message instead of a TypeError from os.path.splitext below.
        raise ValueError(f"Could not fetch a structure file for PDB ID {pdb_id!r}")

    _, ext = os.path.splitext(pdb_path)
    if ext == '.cif':
        pdb_path = convert_cif_to_pdb(pdb_path, TEMP_DIR)
    return pdb_path
133
 
134
  def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:
 
138
  parser = PDBParser(QUIET=True)
139
  structure = parser.get_structure('protein', input_pdb)
140
 
141
+ output_pdb = os.path.join(TEMP_DIR, f"{os.path.splitext(os.path.basename(input_pdb))[0]}_{chain_id}_predictions_scores.pdb")
142
 
143
  # Create scores dictionary for easy lookup
144
  scores_dict = {resi: score for resi, score in residue_scores}
 
168
  io.set_structure(structure[0])
169
  io.save(output_pdb, selector)
170
 
171
+ # Clear references
172
+ del structure, io, selector, scores_dict
173
+
174
  return output_pdb
175
 
176
  def generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, score_type):
 
196
 
197
  # Add PyMOL commands for each score bracket
198
  for bracket, residues in residues_by_bracket.items():
199
+ if residues:
200
  color = bracket_colors[bracket]
201
  resi_list = '+'.join(map(str, residues))
202
  pymol_commands += f"""
 
223
 
224
  return result_str
225
 
 
 
 
226
  def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
227
  # Determine if input is a PDB ID or file path
228
  if pdb_id_or_file.endswith('.pdb'):
 
247
  sequence_id = [res.id[1] for res in protein_residues]
248
 
249
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
250
+
251
  with torch.no_grad():
252
+ outputs = model(input_ids).logits
253
+ outputs_cpu = outputs.detach().cpu().numpy().squeeze()
254
+
255
+ # Explicitly delete GPU tensors
256
+ del outputs, input_ids
257
+ if torch.cuda.is_available():
258
+ torch.cuda.empty_cache()
259
 
260
  # Calculate scores and normalize them
261
+ raw_scores = expit(outputs_cpu[:, 1] - outputs_cpu[:, 0])
262
  normalized_scores = normalize_scores(raw_scores)
263
 
264
+ # Clear outputs_cpu
265
+ del outputs_cpu
266
+
267
  # Choose which scores to use based on score_type
268
  display_scores = normalized_scores if score_type == 'normalized' else raw_scores
269
 
 
309
  mol_vis = molecule(pdb_path, residue_scores, segment)
310
 
311
  # Create prediction file
312
+ prediction_file = os.path.join(TEMP_DIR, f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt")
313
  with open(prediction_file, "w") as f:
314
  f.write(result_str)
315
 
316
+ scored_pdb_name = os.path.join(TEMP_DIR, f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb")
317
  os.rename(scored_pdb, scored_pdb_name)
318
 
319
+ # Clear large objects from memory
320
+ del structure, chain, protein_residues, raw_scores, normalized_scores, display_scores
321
+ gc.collect()
322
+
323
  return pymol_commands, mol_vis, [prediction_file, scored_pdb_name], raw_residue_scores, norm_residue_scores, pdb_id, segment
324
 
325
  def molecule(input_pdb, residue_scores=None, segment='A'):
 
461
  </html>
462
  """
463
 
464
+ # Clear mol from memory after use
465
+ del mol
466
+
467
  # Return the HTML content within an iframe safely encoded for special characters
468
  return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
469
 
 
540
  last_pdb_id = gr.State(None)
541
 
def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
    """Resolve the selected structure to a PDB file and run the prediction on it.

    Args:
        mode: ``"PDB ID"`` to fetch by identifier, or ``"Upload File"`` to use
            the uploaded file.
        pdb_id: Identifier to fetch (used in ``"PDB ID"`` mode).
        pdb_file: Uploaded file object whose ``.name`` is its path on disk
            (used in ``"Upload File"`` mode).
        chain_id: Chain passed on to ``process_pdb``.
        score_type_val: ``"Normalized Scores"`` selects normalized scores;
            any other value selects raw scores.

    Returns:
        An 8-tuple: the first five values returned by ``process_pdb`` (PyMOL
        commands, visualization HTML, output file list, raw residue scores,
        normalized residue scores), then the PDB path used, ``chain_id``, and
        the PDB id reported by ``process_pdb``.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values, or no
            file was provided in ``"Upload File"`` mode.
    """
    try:
        selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'

        # First get the actual PDB file path
        if mode == "PDB ID":
            pdb_path = fetch_pdb(pdb_id)
        elif mode == "Upload File":
            if pdb_file is None:
                raise ValueError("No structure file was uploaded")
            _, ext = os.path.splitext(pdb_file.name)
            file_path = os.path.join(TEMP_DIR, os.path.basename(pdb_file.name))

            # Copy uploaded file to temp directory
            shutil.copy(pdb_file.name, file_path)

            pdb_path = convert_cif_to_pdb(file_path, TEMP_DIR) if ext == '.cif' else file_path
        else:
            # Falling through would return None, which cannot be unpacked
            # into the eight values this function is expected to return.
            raise ValueError(f"Unsupported input mode: {mode!r}")

        pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, _segment = process_pdb(
            pdb_path, chain_id, selected_score_type
        )
        return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
    finally:
        # Force garbage collection after processing, on success and failure alike
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
571
 
def update_visualization_and_files(score_type_val, raw_scores, norm_scores, pdb_path, segment, pdb_id):
    """Rebuild the visualization, PyMOL commands and output files for a score type.

    Args:
        score_type_val: ``"Normalized Scores"`` selects ``norm_scores``; any
            other value selects ``raw_scores``.
        raw_scores: Iterable of ``(residue id, score)`` pairs with raw scores.
        norm_scores: Iterable of ``(residue id, score)`` pairs with normalized scores.
        pdb_path: Path of the structure file (``.cif`` is parsed as mmCIF,
            anything else as PDB).
        segment: Chain identifier to look up in the first model.
        pdb_id: Identifier used in the report text and output file names.

    Returns:
        ``(visualization HTML, PyMOL commands, [prediction file, scored PDB file])``,
        or ``(None, None, None)`` when any of the stored inputs is ``None``.
    """
    if raw_scores is None or norm_scores is None or pdb_path is None or segment is None or pdb_id is None:
        return None, None, None

    try:
        # Choose scores based on radio button selection
        selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
        selected_scores = norm_scores if selected_score_type == 'normalized' else raw_scores

        # Generate visualization with selected scores
        mol_vis = molecule(pdb_path, selected_scores, segment)

        # Get structure for residue info
        _, ext = os.path.splitext(pdb_path)
        parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
        structure = parser.get_structure('protein', pdb_path)
        chain = structure[0][segment]
        protein_residues = [res for res in chain if is_aa(res)]
        sequence = "".join(seq1(res.resname) for res in protein_residues)

        # Define score brackets
        score_brackets = {
            "0.0-0.2": (0.0, 0.2),
            "0.2-0.4": (0.2, 0.4),
            "0.4-0.6": (0.4, 0.6),
            "0.6-0.8": (0.6, 0.8),
            "0.8-1.0": (0.8, 1.0)
        }

        # Initialize a dictionary to store residues by bracket
        residues_by_bracket = {bracket: [] for bracket in score_brackets}

        # Categorize residues into brackets. Brackets are half-open
        # [lower, upper), except that the top bracket also includes 1.0:
        # min-max normalization maps the best residue to exactly 1.0, and it
        # would otherwise fall into no bracket at all.
        for resi, score in selected_scores:
            for bracket, (lower, upper) in score_brackets.items():
                if lower <= score < upper or score == upper == 1.0:
                    residues_by_bracket[bracket].append(resi)
                    break

        # Generate timestamp
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Generate result text and PyMOL commands based on score type
        display_score_type = "Normalized" if selected_score_type == 'normalized' else "Raw"
        scores_array = [score for _, score in selected_scores]
        result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
                                           scores_array, current_time, display_score_type)
        pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)

        # Create chain-specific PDB with scores in B-factor
        scored_pdb = create_chain_specific_pdb(pdb_path, segment, selected_scores, protein_residues)

        # Create prediction file
        prediction_file = os.path.join(TEMP_DIR, f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt")
        with open(prediction_file, "w") as f:
            f.write(result_str)

        scored_pdb_name = os.path.join(TEMP_DIR, f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb")
        # os.replace overwrites an existing target on every platform; os.rename
        # fails on Windows when the file from a previous run is still there.
        os.replace(scored_pdb, scored_pdb_name)

        # Drop the references before the collection in ``finally`` runs, so
        # the parsed structure can actually be reclaimed there.
        del structure, chain, protein_residues, scores_array

        return mol_vis, pymol_commands, [prediction_file, scored_pdb_name]
    finally:
        gc.collect()
639
 
def fetch_interface(mode, pdb_id, pdb_file):
    """Return the path of a PDB file for the current input selection.

    In "PDB ID" mode the structure is obtained through ``fetch_pdb``. In
    "Upload File" mode the upload is copied into ``TEMP_DIR`` and, when its
    extension is ``.cif``, converted to PDB. Any other mode yields ``None``.
    """
    if mode == "PDB ID":
        return fetch_pdb(pdb_id)

    if mode == "Upload File":
        uploaded_path = pdb_file.name
        extension = os.path.splitext(uploaded_path)[1]
        local_copy = os.path.join(TEMP_DIR, f"{os.path.basename(uploaded_path)}")
        shutil.copy(uploaded_path, local_copy)
        if extension == '.cif':
            return convert_cif_to_pdb(local_copy, TEMP_DIR)
        return local_copy
652
 
653
  def toggle_mode(selected_mode):
 
655
  return gr.update(visible=True), gr.update(visible=False)
656
  else:
657
  return gr.update(visible=False), gr.update(visible=True)
 
 
658
 
659
  mode.change(
660
  toggle_mode,
 
695
  )
696
 
def predict_utils(sequence):
    """Score a sequence with the loaded model and return raw and normalized scores.

    The sequence elements are joined with spaces, tokenized, and passed through
    the model; each position's raw score is the logistic sigmoid of
    (logit column 1 - logit column 0).

    Args:
        sequence: String or iterable of strings; its elements are joined with
            single spaces before tokenization.

    Returns:
        Dict with two lists of floats: ``"raw_scores"`` and
        ``"normalized_scores"`` (the raw scores passed through
        ``normalize_scores``).

    Any exception from tokenization or inference propagates unchanged; memory
    cleanup runs in either case.
    """
    try:
        input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            outputs = model(input_ids).logits
            outputs_cpu = outputs.detach().cpu().numpy().squeeze()

        # Explicitly drop the device tensors so the cache flush below can
        # release their memory.
        del outputs, input_ids

        raw_scores = expit(outputs_cpu[:, 1] - outputs_cpu[:, 0])
        normalized_scores = normalize_scores(raw_scores)

        return {
            "raw_scores": raw_scores.tolist(),
            "normalized_scores": normalized_scores.tolist()
        }
    finally:
        # One cleanup path for success and failure; the original exception
        # (if any) keeps its traceback because nothing is re-raised by hand.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
727
 
728
  dummy_input = gr.Textbox(visible=False)
729
  dummy_output = gr.Textbox(visible=False)
 
735
  outputs=[dummy_output]
736
  )
737
 
738
+ demo.launch(share=True)
model_loader.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
11
  import os
12
  import pandas as pd
13
  import copy
 
14
 
15
  import transformers, datasets
16
  from transformers.modeling_outputs import TokenClassifierOutput
@@ -279,27 +280,25 @@ def load_T5_model_classification(checkpoint, num_labels, half_precision, full =
279
  # Load model and tokenizer
280
 
281
  if "ankh" in checkpoint :
282
- model = T5EncoderModel.from_pretrained(checkpoint,resume_download=True)
283
- tokenizer = AutoTokenizer.from_pretrained(checkpoint,resume_download=True)
284
 
285
  elif "prot_t5" in checkpoint:
286
  # possible to load the half precision model (thanks to @pawel-rezo for pointing that out)
287
  if half_precision and deepspeed:
288
- #tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
289
- #model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16)#.to(torch.device('cuda')
290
- tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False,resume_download=True)
291
- model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'),resume_download=True)
292
  else:
293
- model = T5EncoderModel.from_pretrained(checkpoint)
294
- tokenizer = T5Tokenizer.from_pretrained(checkpoint)
295
 
296
  elif "ProstT5" in checkpoint:
297
  if half_precision and deepspeed:
298
- tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False,resume_download=True)
299
- model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'),resume_download=True)
300
  else:
301
- model = T5EncoderModel.from_pretrained(checkpoint,resume_download=True)
302
- tokenizer = T5Tokenizer.from_pretrained(checkpoint,resume_download=True)
303
 
304
  # Create new Classifier model with PT5 dimensions
305
  class_config=ClassConfig(num_labels=num_labels)
@@ -309,8 +308,13 @@ def load_T5_model_classification(checkpoint, num_labels, half_precision, full =
309
  class_model.shared=model.shared
310
  class_model.encoder=model.encoder
311
 
312
- # Delete the checkpoint model
313
- model=class_model
 
 
 
 
 
314
  del class_model
315
 
316
  if full == True:
@@ -613,9 +617,7 @@ def load_esm_model_classification(checkpoint, num_labels, half_precision, full=F
613
 
614
  return model, tokenizer
615
 
616
- def load_model(checkpoint,max_length):
617
- #checkpoint='ThorbenF/prot_t5_xl_uniref50'
618
- #best_model_path='ThorbenF/prot_t5_xl_uniref50/cpt.pth'
619
  full=False
620
  deepspeed=False
621
  mixed=False
@@ -629,12 +631,17 @@ def load_model(checkpoint,max_length):
629
  else:
630
  model, tokenizer = load_T5_model_classification(checkpoint, num_labels, mixed, full, deepspeed)
631
 
632
-
633
  # Download the file
634
  local_file = hf_hub_download(repo_id=checkpoint, filename="cpt.pth")
635
 
636
- # Load the best model state
637
  state_dict = torch.load(local_file, map_location=torch.device('cpu'), weights_only=True)
638
  model.load_state_dict(state_dict)
 
 
 
 
 
 
639
 
640
  return model, tokenizer
 
11
  import os
12
  import pandas as pd
13
  import copy
14
+ import gc
15
 
16
  import transformers, datasets
17
  from transformers.modeling_outputs import TokenClassifierOutput
 
280
  # Load model and tokenizer
281
 
282
  if "ankh" in checkpoint :
283
+ model = T5EncoderModel.from_pretrained(checkpoint, resume_download=True)
284
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint, resume_download=True)
285
 
286
  elif "prot_t5" in checkpoint:
287
  # possible to load the half precision model (thanks to @pawel-rezo for pointing that out)
288
  if half_precision and deepspeed:
289
+ tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False, resume_download=True)
290
+ model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16, resume_download=True).to(torch.device('cuda'))
 
 
291
  else:
292
+ model = T5EncoderModel.from_pretrained(checkpoint, resume_download=True)
293
+ tokenizer = T5Tokenizer.from_pretrained(checkpoint, resume_download=True)
294
 
295
  elif "ProstT5" in checkpoint:
296
  if half_precision and deepspeed:
297
+ tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False, resume_download=True)
298
+ model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16, resume_download=True).to(torch.device('cuda'))
299
  else:
300
+ model = T5EncoderModel.from_pretrained(checkpoint, resume_download=True)
301
+ tokenizer = T5Tokenizer.from_pretrained(checkpoint, resume_download=True)
302
 
303
  # Create new Classifier model with PT5 dimensions
304
  class_config=ClassConfig(num_labels=num_labels)
 
308
  class_model.shared=model.shared
309
  class_model.encoder=model.encoder
310
 
311
+ # Delete the checkpoint model and clear memory
312
+ del model
313
+ gc.collect()
314
+ if torch.cuda.is_available():
315
+ torch.cuda.empty_cache()
316
+
317
+ model = class_model
318
  del class_model
319
 
320
  if full == True:
 
617
 
618
  return model, tokenizer
619
 
620
+ def load_model(checkpoint, max_length):
 
 
621
  full=False
622
  deepspeed=False
623
  mixed=False
 
631
  else:
632
  model, tokenizer = load_T5_model_classification(checkpoint, num_labels, mixed, full, deepspeed)
633
 
 
634
  # Download the file
635
  local_file = hf_hub_download(repo_id=checkpoint, filename="cpt.pth")
636
 
637
+ # Load the best model state onto the CPU (weights only)
638
  state_dict = torch.load(local_file, map_location=torch.device('cpu'), weights_only=True)
639
  model.load_state_dict(state_dict)
640
+
641
+ # Clear state_dict from memory immediately after loading
642
+ del state_dict
643
+ gc.collect()
644
+ if torch.cuda.is_available():
645
+ torch.cuda.empty_cache()
646
 
647
  return model, tokenizer