Spaces:
Runtime error
Runtime error
| import argparse | |
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import torch | |
| import subprocess | |
| import output | |
| from rdkit import Chem | |
| from src import const | |
| from src.visualizer import save_xyz_file | |
| from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule | |
| from src.lightning import DDPM | |
| from src.linker_size_lightning import SizeClassifier | |
# Number of candidate molecules generated per request.
N_SAMPLES = 5


def _download_checkpoint(link, path):
    """Fetch *link* into *path* with wget, leaving no partial file behind.

    wget creates the file given to ``-O`` before the transfer finishes, so a
    failed download written straight to *path* would make the
    ``os.path.exists`` checks below skip the download on every later start.
    The data is therefore written to a sibling ``.part`` file and moved into
    place only after wget exits successfully.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
        OSError: if wget cannot be started or the file cannot be moved.
    """
    partial_path = f'{path}.part'
    try:
        # Argument list instead of a shell string: the URL contains `?`,
        # which a shell may try to interpret.
        subprocess.run(['wget', link, '-O', partial_path], check=True)
        os.replace(partial_path, path)
    finally:
        # After a successful move the partial file is gone; otherwise drop it.
        if os.path.exists(partial_path):
            os.remove(partial_path)


parser = argparse.ArgumentParser()
parser.add_argument('--ip', type=str, default=None)
args = parser.parse_args()

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

os.makedirs("results", exist_ok=True)
os.makedirs("models", exist_ok=True)

# Linker-size classifier checkpoint (used by sample_fn).
size_gnn_path = 'models/geom_size_gnn.ckpt'
if not os.path.exists(size_gnn_path):
    print('Downloading SizeGNN model...')
    link = 'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1'
    _download_checkpoint(link, size_gnn_path)
size_nn = SizeClassifier.load_from_checkpoint(size_gnn_path, map_location=device).eval().to(device)
print('Loaded SizeGNN model')

# Diffusion model checkpoint (used by generate).
diffusion_path = 'models/geom_difflinker.ckpt'
if not os.path.exists(diffusion_path):
    print('Downloading Diffusion model...')
    link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
    _download_checkpoint(link, diffusion_path)
ddpm = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
print('Loaded diffusion model')
def sample_fn(_data):
    """Sample one linker size per batch element from the size classifier.

    The classifier's output is turned into probabilities with a softmax over
    dimension 1 (assumes the output is laid out as (batch, classes) - TODO
    confirm), one class label is drawn per row, and each label is mapped to a
    size through ``size_nn.linker_id2size``.

    Returns:
        A ``torch.long`` tensor of sizes on the same device as the sampled
        labels.
    """
    logits, _ = size_nn.forward(_data, return_loss=False)
    class_probs = torch.softmax(logits, dim=1)
    labels = torch.distributions.Categorical(probs=class_probs).sample()

    # Look every sampled label up in the classifier's id -> size table.
    linker_sizes = [
        size_nn.linker_id2size[label]
        for label in labels.detach().cpu().numpy()
    ]
    return torch.tensor(linker_sizes, device=labels.device, dtype=torch.long)
def read_molecule_content(path):
    """Return the full text of the file at *path* as a single string.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    with open(path, "r") as f:
        # One read call instead of building a list of lines and joining it.
        return f.read()
def read_molecule(path):
    """Load the molecule stored at *path* with the RDKit reader for its suffix.

    Supported suffixes are ``.pdb``, ``.mol``, ``.mol2`` and ``.sdf``. Every
    reader is called with ``sanitize=False, removeHs=True``. For ``.sdf``
    files only the first record of the supplier is returned.

    Raises:
        ValueError: if *path* ends with none of the supported suffixes.
            ``ValueError`` is a subclass of ``Exception``, so callers that
            catch ``Exception`` keep working.
    """
    # Readers are wrapped in lambdas so RDKit is only touched for a suffix
    # that actually matches.
    readers = (
        ('.pdb', lambda p: Chem.MolFromPDBFile(p, sanitize=False, removeHs=True)),
        ('.mol', lambda p: Chem.MolFromMolFile(p, sanitize=False, removeHs=True)),
        ('.mol2', lambda p: Chem.MolFromMol2File(p, sanitize=False, removeHs=True)),
        ('.sdf', lambda p: Chem.SDMolSupplier(p, sanitize=False, removeHs=True)[0]),
    )
    for suffix, reader in readers:
        if path.endswith(suffix):
            return reader(path)
    raise ValueError('Unknown file extension')
def show_input(input_file):
    """Build the HTML preview for the uploaded fragments file.

    Args:
        input_file: ``None``, a path string, or an object whose ``.name``
            attribute holds the path.

    Returns:
        An empty string when nothing is uploaded, an iframe with the
        invalid-format message for unsupported extensions, a plain error
        string when the file cannot be read, and otherwise an iframe that
        renders the file content.
    """
    if input_file is None:
        return ''

    path = input_file if isinstance(input_file, str) else input_file.name

    # The text after the last dot doubles as the format name for the renderer.
    extension = path.split('.')[-1]
    if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        invalid_msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return output.IFRAME_TEMPLATE.format(html=invalid_msg)

    try:
        content = read_molecule_content(path)
    except Exception as e:
        return f'Could not read the molecule: {e}'

    rendering = output.INITIAL_RENDERING_TEMPLATE.format(molecule=content, fmt=extension)
    return output.IFRAME_TEMPLATE.format(html=rendering)
def draw_sample(idx, out_files):
    """Render the input fragments together with one generated sample.

    Args:
        idx: zero-based index of the sample to show, or ``None`` when no
            sample is selected.
        out_files: list whose first entry is the input fragments file and
            whose following entries are the generated samples. Each entry is
            a path string or an object whose ``.name`` holds the path.

    Returns:
        The iframe HTML for the selected sample, or an empty string when
        there is nothing to draw (no files yet, no selection, or an index
        outside the available samples).
    """
    # The sample radio can fire while no results exist; answer with an empty
    # view instead of failing on None / a too-short list.
    if not out_files or idx is None or not 0 <= idx < len(out_files) - 1:
        return ''

    in_file = out_files[0]
    in_sdf = in_file if isinstance(in_file, str) else in_file.name
    # Entry 0 is the input, so sample `idx` lives at position idx + 1.
    out_file = out_files[idx + 1]
    out_sdf = out_file if isinstance(out_file, str) else out_file.name

    input_fragments_content = read_molecule_content(in_sdf)
    generated_molecule_content = read_molecule_content(out_sdf)

    html = output.SAMPLES_RENDERING_TEMPLATE.format(
        fragments=input_fragments_content,
        fragments_fmt='sdf',
        molecule=generated_molecule_content,
        molecule_fmt='sdf',
    )
    return output.IFRAME_TEMPLATE.format(html=html)
def _generation_failure(message):
    """Build a complete three-output response for a failed generation.

    The click handler is wired to three outputs (visualization, output files,
    sample selector), so every return path has to supply three values: the
    message to display, no output files, and a hidden sample selector.
    """
    return [message, None, gr.Radio.update(visible=False)]


def generate(input_file, n_steps):
    """Generate N_SAMPLES linkers for the uploaded fragments.

    Args:
        input_file: ``None``, a path string, or an object whose ``.name``
            attribute holds the path of the fragments file.
        n_steps: number of diffusion steps, assigned to ``ddpm.edm.T``.

    Returns:
        A three-element list matching the handler's outputs:
        ``[visualization_html, files, sample_selector_update]``. On success
        ``files`` is the written input SDF followed by the generated SDF
        files and the selector is shown with 'Sample 1' selected. On failure
        the first element carries the message, ``files`` is ``None`` and the
        selector is hidden.
    """
    if input_file is None:
        return _generation_failure('')

    # Accept both a plain path and a file wrapper, like show_input does.
    path = input_file if isinstance(input_file, str) else input_file.name
    extension = path.split('.')[-1]
    if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return _generation_failure(output.IFRAME_TEMPLATE.format(html=msg))

    try:
        molecule = read_molecule(path)
        molecule = Chem.RemoveAllHs(molecule)
    except Exception as e:
        return _generation_failure(f'Could not read the molecule: {e}')

    # File name without directory and without its last extension.
    name = '.'.join(path.split('/')[-1].split('.')[:-1])
    inp_sdf = f'results/input_{name}.sdf'

    if molecule.GetNumAtoms() > 50:
        return _generation_failure('Too large molecule: upper limit is 50 heavy atoms')

    with Chem.SDWriter(inp_sdf) as w:
        w.write(molecule)

    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    # Every input atom is marked as fragment; no anchors or linker atoms are given.
    anchors = np.zeros_like(charges)
    fragment_mask = np.ones_like(charges)
    linker_mask = np.zeros_like(charges)
    print('Read and parsed molecule')

    # The same entry repeated N_SAMPLES times forms a single batch.
    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
        'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
        'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
        'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
        'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
        'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
        'num_atoms': len(positions),
    }] * N_SAMPLES
    dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    ddpm.edm.T = n_steps
    assert ddpm.center_of_mass == 'fragments'

    # Only the first batch is used; it holds all N_SAMPLES copies.
    data = next(iter(dataloader))
    chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
    print('Generated linker')

    # Split the kept frame into coordinates and the remaining features.
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]

    # Shift the samples back by the mean position of the fragment atoms
    # (the original code describes this as restoring the initial orientation).
    pos_masked = data['positions'] * data['fragment_mask']
    N = data['fragment_mask'].sum(1, keepdims=True)
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
    x = x + mean * node_mask

    names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')

    out_files = []
    for i in range(N_SAMPLES):
        out_xyz = f'results/output_{i+1}_{name}_.xyz'
        out_sdf = f'results/output_{i+1}_{name}_.sdf'
        # `name` comes from the uploaded file name, so the command is passed
        # as an argument list and never interpreted by a shell.
        subprocess.run(['obabel', out_xyz, '-O', out_sdf])
        out_files.append(out_sdf)
    print('Converted to SDF')

    return [
        draw_sample(0, out_files),
        [inp_sdf] + out_files,
        gr.Radio.update(visible=True, value='Sample 1')
    ]
def _load_example(idx):
    """Return the updates applied when example number *idx* (zero-based) is clicked.

    The four values go to the input file, the step slider, the visualization
    and the sample selector, in that order.
    """
    example_path = f'examples/example_{idx+1}.sdf'
    return [
        example_path,
        10,
        show_input(example_path),
        gr.Radio(value='Sample 1', visible=False)
    ]


def _reset_view():
    """Return the updates applied when the input file is cleared."""
    return ['', gr.Radio(value='Sample 1', visible=False)]


demo = gr.Blocks()
with demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    gr.Markdown(
        'Given a set of disconnected fragments in 3D, '
        'DiffLinker places missing atoms in between and designs a molecule incorporating all the initial fragments. '
        'Our method can link an arbitrary number of fragments, requires no information on the attachment atoms '
        'and linker size, and can be conditioned on the protein pockets.'
    )
    gr.Markdown(
        '[**[Paper]**](https://arxiv.org/abs/2210.05274) '
        '[**[Code]**](https://github.com/igashov/DiffLinker)'
    )
    with gr.Box():
        with gr.Row():
            # Left column: inputs and downloadable results.
            with gr.Column():
                gr.Markdown('## Input Fragments')
                # The handlers accept sdf, pdb, mol and mol2, so the hint lists all four.
                gr.Markdown(
                    'Upload the file with 3D-coordinates of the input fragments '
                    'in .pdb, .mol, .mol2 or .sdf format:'
                )
                input_file = gr.File(file_count='single', label='Input Fragments')
                n_steps = gr.Slider(minimum=10, maximum=500, label="Number of Diffusion Steps", step=10)
                examples = gr.Dataset(
                    components=[gr.File(visible=False)],
                    samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
                    type='index',
                )
                button = gr.Button('Generate Linker!')
                gr.Markdown('')
                gr.Markdown('## Output Files')
                gr.Markdown('Download files with the generated molecules here:')
                output_files = gr.File(file_count='multiple', label='Output Files', interactive=False)
            # Right column: sample selector and 3D view.
            with gr.Column():
                gr.Markdown('## Visualization')
                samples = gr.Radio(
                    # One choice per generated sample, kept in sync with N_SAMPLES.
                    choices=[f'Sample {i + 1}' for i in range(N_SAMPLES)],
                    value='Sample 1',
                    type='index',
                    show_label=False,
                    visible=False,
                    interactive=True,
                )
                visualization = gr.HTML()

    # Preview the fragments as soon as a file is uploaded.
    input_file.change(
        fn=show_input,
        inputs=[input_file],
        outputs=[visualization],
    )
    # Clicking an example fills in the file, resets the steps and hides the selector.
    examples.click(
        fn=_load_example,
        inputs=[examples],
        outputs=[input_file, n_steps, visualization, samples]
    )
    button.click(
        fn=generate,
        inputs=[input_file, n_steps],
        outputs=[visualization, output_files, samples],
    )
    # Switching the selected sample redraws the view from the output files.
    samples.change(
        fn=draw_sample,
        inputs=[samples, output_files],
        outputs=[visualization],
    )
    input_file.clear(
        fn=_reset_view,
        inputs=[],
        outputs=[visualization, samples],
    )

demo.launch(server_name=args.ip)