| """ | |
| https://github.com/ProteinDesignLab/protpardelle | |
| License: MIT | |
| Author: Alex Chu | |
| Dataloader from PDB files. | |
| """ | |
import copy
import json
import pickle

import numpy as np
import torch
from einops import rearrange, repeat
from torch.utils import data

from core import protein
from core import residue_constants
from core import utils
FEATURES_1D = (
    "coords_in",
    "torsions_in",
    "b_factors",
    "atom_positions",
    "aatype",
    "atom_mask",
    "residue_index",
    "chain_index",
)
FEATURES_FLOAT = (
    "coords_in",
    "torsions_in",
    "b_factors",
    "atom_positions",
    "atom_mask",
    "seq_mask",
)
FEATURES_LONG = ("aatype", "residue_index", "chain_index", "orig_size")
def make_fixed_size_1d(data, fixed_size=128, start_idx=None):
    """Crop or zero-pad a per-residue feature to fixed_size, returning a mask.

    If start_idx is None, the crop start is sampled uniformly at random. Pass a
    shared start_idx when cropping several features of the same example so that
    they stay aligned with each other.
    """
    data_len = data.shape[0]
    if data_len >= fixed_size:
        extra_len = data_len - fixed_size
        if start_idx is None:
            start_idx = np.random.choice(np.arange(extra_len + 1))
        new_data = data[start_idx : (start_idx + fixed_size)]
        mask = torch.ones(fixed_size)
    else:
        pad_size = fixed_size - data_len
        extra_shape = data.shape[1:]
        new_data = torch.cat([data, torch.zeros(pad_size, *extra_shape)], 0)
        mask = torch.cat([torch.ones(data_len), torch.zeros(pad_size)], 0)
    return new_data, mask
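# Usage sketch (illustrative, not from the original repo), assuming an
# atom37-style coordinate tensor:
#   coords = torch.randn(200, 37, 3)                    # 200 residues
#   fixed, seq_mask = make_fixed_size_1d(coords, fixed_size=128)
#   fixed.shape     # torch.Size([128, 37, 3])
#   seq_mask.shape  # torch.Size([128]); 1 = real residue, 0 = padding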
def apply_random_se3(coords_in, atom_mask=None, translation_scale=1.0):
    """Apply a random rigid-body transform. Unbatched; centers on the mean CA."""
    # atom37 index 1 is CA
    coords_mean = coords_in[:, 1:2].mean(-3, keepdim=True)
    coords_in -= coords_mean
    # QR of a random Gaussian matrix yields a random orthogonal rotation
    random_rot, _ = torch.linalg.qr(torch.randn(3, 3))
    coords_in = coords_in @ random_rot
    random_trans = torch.randn_like(coords_mean) * translation_scale
    coords_in += random_trans
    if atom_mask is not None:
        coords_in = coords_in * atom_mask[..., None]
    return coords_in
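# Usage sketch (illustrative only): augment a single example's coordinates.
#   coords = torch.randn(128, 37, 3)
#   mask = torch.ones(128, 37)
#   coords_aug = apply_random_se3(coords, atom_mask=mask, translation_scale=1.0)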
def get_masked_coords_array(coords, atom_mask):
    """Wrap coords in a numpy masked array so absent atoms are ignored in stats."""
    ma_mask = repeat(1 - atom_mask[..., None].cpu().numpy(), "... 1 -> ... 3")
    return np.ma.array(coords.cpu().numpy(), mask=ma_mask)
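# Note: numpy masked arrays treat mask=1 as "invalid", hence the (1 - atom_mask)
# inversion above; downstream means then skip atoms missing from the structure.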
def make_crop_cond_mask_and_recenter_coords(
    atom_mask,
    atom_coords,
    contiguous_prob=0.05,
    discontiguous_prob=0.9,
    sidechain_only_prob=0.8,
    max_span_len=10,
    max_discontiguous_res=8,
    dist_threshold=8.0,
    recenter_coords=True,
):
    b, n, a = atom_mask.shape
    device = atom_mask.device
    seq_mask = atom_mask[..., 1]  # CA presence marks real residues
    n_res = seq_mask.sum(-1)
    masks = []
    for i, nr in enumerate(n_res):
        nr = nr.int().item()
        mask = torch.zeros((n, a), device=device)
        # Sample the conditioning type: a contiguous span, a discontiguous
        # set of spatially close residues, or no conditioning
        conditioning_type = torch.distributions.Categorical(
            torch.tensor(
                [
                    contiguous_prob,
                    discontiguous_prob,
                    1.0 - contiguous_prob - discontiguous_prob,
                ]
            )
        ).sample()
        conditioning_type = ["contiguous", "discontiguous", "none"][conditioning_type]
        if conditioning_type == "contiguous":
            # Reveal one random contiguous span of residues
            span_len = torch.randint(
                1, min(max_span_len, nr), (1,), device=device
            ).item()
            span_start = torch.randint(0, nr - span_len, (1,), device=device)
            mask[span_start : span_start + span_len, :] = 1
        elif conditioning_type == "discontiguous":
            # CB coordinates for the i-th example (atom37 index 3 is CB)
            cb_atoms = atom_coords[i, :, 3]
            # Pairwise distances between CB atoms
            cb_distances = torch.cdist(cb_atoms, cb_atoms)
            # Pick a random center residue; inflate distances to padded
            # positions so they are never selected
            random_residue = torch.randint(0, nr, (1,), device=device).squeeze()
            cb_dist_i = cb_distances[random_residue] + 1e3 * (1 - seq_mask[i])
            close_mask = cb_dist_i <= dist_threshold
            n_neighbors = close_mask.sum().int()
            # Pick how many close residues to reveal (up to max_discontiguous_res)
            n_sele = torch.randint(
                2,
                n_neighbors.clamp(min=3, max=max_discontiguous_res + 1),
                (1,),
                device=device,
            )
            # Randomly choose that many residues from the close set
            idxs = torch.arange(n, device=device)[close_mask.bool()]
            idxs = idxs[torch.randperm(len(idxs))[:n_sele]]
            if len(idxs) > 0:
                mask[idxs] = 1
        # Optionally condition on sidechains only by zeroing the backbone
        # block (atom37 indices 0-4: N, CA, C, CB, O)
        if np.random.uniform() < sidechain_only_prob:
            mask[:, :5] = 0
        masks.append(mask)
| masks.append(mask) | |
| crop_cond_mask = torch.stack(masks) | |
| crop_cond_mask = crop_cond_mask * atom_mask | |
| if recenter_coords: | |
| motif_masked_array = get_masked_coords_array(atom_coords, crop_cond_mask) | |
| cond_coords_center = motif_masked_array.mean((1, 2)) | |
| motif_mask = torch.Tensor(1 - cond_coords_center.mask).to(crop_cond_mask) | |
| means = torch.Tensor(cond_coords_center.data).to(atom_coords) * motif_mask | |
| coords_out = atom_coords - rearrange(means, "b c -> b 1 1 c") | |
| else: | |
| coords_out = atom_coords | |
| return coords_out, crop_cond_mask | |
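# Usage sketch (illustrative; shapes assumed): build a crop-conditioning mask
# for a batch of atom37 coordinates and recenter on the revealed motif.
#   atom_mask = torch.ones(4, 128, 37)   # (batch, residues, atoms)
#   coords = torch.randn(4, 128, 37, 3)
#   coords_out, cond_mask = make_crop_cond_mask_and_recenter_coords(
#       atom_mask, coords
#   )
#   cond_mask.shape  # torch.Size([4, 128, 37]); 1 = atom revealed as conditioning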
class Dataset(data.Dataset):
    """Loads and processes PDBs into tensors."""

    def __init__(
        self,
        pdb_path,
        fixed_size,
        mode="train",
        overfit=-1,
        short_epoch=False,
        se3_data_augment=True,
    ):
        self.pdb_path = pdb_path
        self.fixed_size = fixed_size
        self.mode = mode
        self.overfit = overfit
        self.short_epoch = short_epoch
        self.se3_data_augment = se3_data_augment
        with open(f"{self.pdb_path}/{mode}_pdb_keys.list") as f:
            self.pdb_keys = np.array(f.read().split("\n")[:-1])
        if overfit > 0:
            # Overfit mode: sample a small subset of keys without replacement,
            # tiled so the epoch length stays roughly unchanged
            n_data = len(self.pdb_keys)
            self.pdb_keys = np.random.choice(
                self.pdb_keys, min(n_data, overfit), replace=False
            ).repeat(n_data // overfit)

    def __len__(self):
        if self.short_epoch:
            return min(len(self.pdb_keys), 256)
        return len(self.pdb_keys)
    def __getitem__(self, idx):
        pdb_key = self.pdb_keys[idx]
        data = self.get_item(pdb_key)
        # For now, replace dataloading errors with a random pdb; 10 tries
        for _ in range(10):
            if data is not None:
                return data
            pdb_key = self.pdb_keys[np.random.randint(len(self.pdb_keys))]
            data = self.get_item(pdb_key)
        raise Exception("Failed to load data example after 10 tries.")
    def get_item(self, pdb_key):
        example = {}
        if self.pdb_path.endswith("cath_s40_dataset"):  # CATH pdbs
            data_file = f"{self.pdb_path}/dompdb/{pdb_key}"
        elif self.pdb_path.endswith("ingraham_cath_dataset"):  # Ingraham splits
            data_file = f"{self.pdb_path}/pdb_store/{pdb_key}"
        else:
            raise Exception("Invalid pdb path.")
        try:
            example = utils.load_feats_from_pdb(data_file)
            coords_in = example["atom_positions"]
        except FileNotFoundError:
            raise Exception(f"File {pdb_key} not found. Check if dataset is corrupted?")
        except RuntimeError:
            return None

        # Apply data augmentation
        if self.se3_data_augment:
            coords_in = apply_random_se3(coords_in, atom_mask=example["atom_mask"])

        orig_size = coords_in.shape[0]
        example["coords_in"] = coords_in
        example["orig_size"] = torch.ones(1) * orig_size

        # Crop/pad every per-residue feature to fixed_size, sharing one crop
        # start so the features stay aligned with each other
        crop_start = np.random.randint(max(orig_size - self.fixed_size, 0) + 1)
        fixed_size_example = {}
        seq_mask = None
        for k, v in example.items():
            if k in FEATURES_1D:
                fixed_size_example[k], seq_mask = make_fixed_size_1d(
                    v, fixed_size=self.fixed_size, start_idx=crop_start
                )
            else:
                fixed_size_example[k] = v
        if seq_mask is not None:
            fixed_size_example["seq_mask"] = seq_mask

        # Cast features to their expected dtypes; keys not listed are dropped
        example_out = {}
        for k, v in fixed_size_example.items():
            if k in FEATURES_FLOAT:
                example_out[k] = v.float()
            elif k in FEATURES_LONG:
                example_out[k] = v.long()
        return example_out
    def collate(self, example_list):
        out = {}
        for ex in example_list:
            for k, v in ex.items():
                out.setdefault(k, []).append(v)
        return {k: torch.stack(v) for k, v in out.items()}
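    # Usage sketch (paths are hypothetical): wire the dataset into a PyTorch
    # DataLoader with the custom collate above.
    #   ds = Dataset("data/ingraham_cath_dataset", fixed_size=128, mode="train")
    #   dl = torch.utils.data.DataLoader(ds, batch_size=8, collate_fn=ds.collate)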
    def sample(self, n=1, return_data=True, return_keys=False):
        keys = self.pdb_keys[torch.randperm(self.__len__())[:n].long()]
        if return_keys and not return_data:
            return keys
        # Iterate even when n == 1 so each key is passed as a scalar, not an array
        data = self.collate([self.get_item(key) for key in keys])
        if return_data and return_keys:
            return data, keys
        return data
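# Minimal smoke test (a sketch, not part of the original repo). It exercises the
# tensor utilities on synthetic data; Dataset itself needs PDB files on disk and
# is not constructed here.
if __name__ == "__main__":
    torch.manual_seed(0)
    np.random.seed(0)

    # Synthetic atom37-style example: 200 residues, all atoms present
    coords, seq_mask = make_fixed_size_1d(torch.randn(200, 37, 3), fixed_size=128)
    assert coords.shape == (128, 37, 3) and seq_mask.sum() == 128

    atom_mask = torch.ones(128, 37)
    coords = apply_random_se3(coords, atom_mask=atom_mask)

    # Batched crop-conditioning over 4 copies of the example
    batch_coords = coords[None].repeat(4, 1, 1, 1)
    batch_mask = atom_mask[None].repeat(4, 1, 1)
    coords_out, cond_mask = make_crop_cond_mask_and_recenter_coords(
        batch_mask, batch_coords
    )
    print("coords:", coords_out.shape, "cond mask:", cond_mask.shape)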