import torch


def gradient_penalty(discriminator, real_node, real_edge, fake_node, fake_edge, batch_size, device):
    """
    Calculate the gradient penalty for WGAN-GP.

    Args:
        discriminator: The discriminator (critic) model
        real_node: Real node features
        real_edge: Real edge features
        fake_node: Generated node features
        fake_edge: Generated edge features
        batch_size: Batch size
        device: Device to compute on

    Returns:
        Gradient penalty term
    """
    # Sample per-example interpolation factors; the singleton dimensions
    # broadcast over the node and edge feature axes
    eps_edge = torch.rand(batch_size, 1, 1, 1, device=device)
    eps_node = torch.rand(batch_size, 1, 1, device=device)

    # Interpolate between real and fake samples, tracking gradients
    int_node = (eps_node * real_node + (1 - eps_node) * fake_node).requires_grad_(True)
    int_edge = (eps_edge * real_edge + (1 - eps_edge) * fake_edge).requires_grad_(True)
    logits_interpolated = discriminator(int_edge, int_node)
    # Differentiate the critic output with respect to both interpolated inputs
    weight = torch.ones_like(logits_interpolated)
    gradients = torch.autograd.grad(
        outputs=logits_interpolated,
        inputs=[int_node, int_edge],
        grad_outputs=weight,
        create_graph=True,
        retain_graph=True,
    )
    # Flatten and concatenate the gradients from both inputs
    gradients_node = gradients[0].view(batch_size, -1)
    gradients_edge = gradients[1].view(batch_size, -1)
    gradients = torch.cat([gradients_node, gradients_edge], dim=1)

    # Penalize deviations of the joint gradient norm from 1
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty
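

# A minimal smoke-test sketch for gradient_penalty. The tensor shapes
# (batch, n_atoms, n_atoms, edge_dim) for edges and (batch, n_atoms, node_dim)
# for nodes, and the toy critic below, are illustrative assumptions rather
# than part of the original model.
def _gradient_penalty_smoke_test():
    import torch.nn as nn

    batch, n_atoms, node_dim, edge_dim = 4, 9, 5, 4

    class ToyCritic(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(n_atoms * n_atoms * edge_dim + n_atoms * node_dim, 1)

        def forward(self, edge, node):
            # Same (edge, node) argument order as the calls above
            x = torch.cat([edge.flatten(1), node.flatten(1)], dim=1)
            return self.lin(x)

    device = torch.device("cpu")
    critic = ToyCritic().to(device)
    real_node = torch.rand(batch, n_atoms, node_dim, device=device)
    real_edge = torch.rand(batch, n_atoms, n_atoms, edge_dim, device=device)
    fake_node = torch.rand(batch, n_atoms, node_dim, device=device)
    fake_edge = torch.rand(batch, n_atoms, n_atoms, edge_dim, device=device)

    gp = gradient_penalty(critic, real_node, real_edge, fake_node, fake_edge, batch, device)
    assert gp.dim() == 0 and gp.item() >= 0.0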


def discriminator_loss(generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device, lambda_gp):
    # Critic scores for real drugs; the WGAN critic maximizes these,
    # so their mean enters the loss negated
    logits_real_disc = discriminator(drug_adj, drug_annot)
    prediction_real = -torch.mean(logits_real_disc)

    # Critic scores for generated molecules; detach so this update
    # does not backpropagate into the generator
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)
    logits_fake_disc = discriminator(edge_sample.detach(), node_sample.detach())
    prediction_fake = torch.mean(logits_fake_disc)

    # Gradient penalty on interpolations between real and generated samples
    gp = gradient_penalty(discriminator, drug_annot, drug_adj, node_sample.detach(), edge_sample.detach(), batch_size, device)

    # Total critic loss
    d_loss = prediction_fake + prediction_real + lambda_gp * gp
    return node, edge, d_loss
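

# A sketch of one critic update built on discriminator_loss. The generator,
# discriminator, optimizer, and batch tensors are assumed to exist with the
# shapes expected above, and lambda_gp = 10.0 follows the WGAN-GP paper's
# recommended setting; treat these as assumptions, not fixed choices.
def critic_step(generator, discriminator, d_optimizer,
                drug_adj, drug_annot, mol_adj, mol_annot, device, lambda_gp=10.0):
    d_optimizer.zero_grad()
    # Assumes the first dimension of drug_adj is the batch dimension
    _, _, d_loss = discriminator_loss(
        generator, discriminator,
        drug_adj, drug_annot, mol_adj, mol_annot,
        drug_adj.size(0), device, lambda_gp,
    )
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()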


def generator_loss(generator, discriminator, mol_adj, mol_annot, batch_size):
    # Generate fake molecules
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)

    # The generator maximizes the critic's score on fakes,
    # i.e. minimizes the negated mean
    logits_fake_disc = discriminator(edge_sample, node_sample)
    g_loss = -torch.mean(logits_fake_disc)
    return g_loss, node, edge, node_sample, edge_sample
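

# A sketch of the generator update and the usual WGAN-GP alternation of
# several critic steps per generator step (n_critic = 5 in the WGAN-GP
# paper). g_optimizer and the batch tensors are illustrative assumptions.
def generator_step(generator, discriminator, g_optimizer, mol_adj, mol_annot):
    g_optimizer.zero_grad()
    g_loss, node, edge, node_sample, edge_sample = generator_loss(
        generator, discriminator, mol_adj, mol_annot, mol_adj.size(0)
    )
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()

# Typical alternation: for each generator_step, run critic_step n_critic
# times on fresh batches before updating the generator.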