import io
from itertools import product

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

###############################################################################
# Model Definition
###############################################################################

class VirusClassifier(nn.Module):
    def __init__(self, input_shape: int):
        super(VirusClassifier, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2)
        )

    def forward(self, x):
        return self.network(x)

    def get_feature_importance(self, x):
        """
        Calculate gradient-based feature importance for the 'human' class
        (index=1) by computing the gradient of that probability w.r.t. x.
        """
        x.requires_grad_(True)
        output = self.network(x)
        probs = torch.softmax(output, dim=1)

        # Probability of the 'human' class (index=1)
        human_prob = probs[..., 1]
        if x.grad is not None:
            x.grad.zero_()
        # Summing makes backward() valid for any batch size; with a single
        # sequence this is identical to human_prob.backward().
        human_prob.sum().backward()

        importance = x.grad  # shape: (batch_size, n_features)
        return importance, float(human_prob[0])

###############################################################################
# Utility Functions
###############################################################################

def parse_fasta(text):
    """Parse FASTA-formatted text into a list of (header, sequence) tuples."""
    sequences = []
    current_header = None
    current_sequence = []

    for line in text.split('\n'):
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            if current_header:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line.upper())

    if current_header:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences


def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Convert a single nucleotide sequence to a k-mer frequency vector."""
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    vec = np.zeros(len(kmers), dtype=np.float32)

    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i+k]
        if kmer in kmer_dict:
            vec[kmer_dict[kmer]] += 1

    total_kmers = len(sequence) - k + 1
    if total_kmers > 0:
        vec = vec / total_kmers  # normalize counts to frequencies
    return vec
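
# Quick sanity check (illustrative): for k=4 the vector has 4**4 = 256 entries,
# and the frequencies of all observed k-mers sum to 1. For example:
#
#     vec = sequence_to_kmer_vector("ACGTACGT", k=4)  # 5 overlapping 4-mers
#     vec.shape         # (256,)
#     float(vec.sum())  # 1.0 (every window contains only A/C/G/T)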

###############################################################################
# Visualization
###############################################################################

def create_shap_waterfall_plot(important_kmers, all_kmer_importance, human_prob, title):
    """
    Create a SHAP-like waterfall plot:
      - Start at a baseline of 0.5.
      - Add one bar for "Other": the combined effect of all less-important k-mers.
      - Then apply each top k-mer in descending order of absolute importance.
      - Show the final accumulated value as the endpoint.
    """
    # 1) Sort 'important_kmers' by absolute impact, descending
    sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)

    # 2) Compute the total effect of the "other" k-mers.
    #    There are 256 features in total and we selected the top 10; sum the rest.
    top_ids = set(km['idx'] for km in sorted_kmers)
    other_contributions = [
        val for i, val in enumerate(all_kmer_importance) if i not in top_ids
    ]
    other_sum = np.sum(other_contributions)

    # 3) Build the list of bars: first "Other", then each top k-mer.
    #    Each bar is a (label, signed_contribution) pair; the sign gives the direction.
    bars = [("Other", other_sum)]  # lumps the leftover k-mers together
    for km in sorted_kmers:
        # Re-inject the sign on the raw gradient: 'impact' stores only the
        # absolute value, so recreate a signed contribution from 'direction'.
        signed_val = km['impact'] if km['direction'] == 'human' else -km['impact']
        bars.append((km['kmer'], signed_val))

    # 4) Waterfall data: accumulate partial sums from the baseline of 0.5.
    #    Green marks positive contributions (pushing toward 'human'),
    #    red marks negative ones (pushing away from 'human').
    baseline = 0.5
    running_val = baseline
    x_labels = []
    y_vals = []
    bar_colors = []

    for (label, contrib) in bars:
        x_labels.append(label)
        # New value after adding this contribution, scaled by 0.05 for
        # better display. Adjust as desired.
        new_val = running_val + (0.05 * contrib)
        y_vals.append((running_val, new_val))
        running_val = new_val
        bar_colors.append("green" if contrib >= 0 else "red")

    final_prob = running_val
    # The endpoint approximates the model's predicted probability; it is not
    # exact, but this keeps the SHAP-like additive reading. To force the
    # endpoint to equal human_prob, one could add a final correction term
    # (human_prob - running_val); for now the waterfall stays purely additive
    # from the gradients.

    # 5) Draw the bars manually.
    fig, ax = plt.subplots(figsize=(10, 6))
    x_positions = np.arange(len(x_labels))

    for i, ((start_val, end_val), color) in enumerate(zip(y_vals, bar_colors)):
        height = end_val - start_val
        ax.bar(i, height, bottom=start_val, color=color,
               edgecolor='black', alpha=0.7)
        ax.text(i, (start_val + end_val) / 2, f"{height:+.3f}",
                ha='center', va='center', color='white', fontsize=8)

    ax.axhline(y=baseline, color='black', linestyle='--', linewidth=1)
    ax.set_xticks(x_positions)
    ax.set_xticklabels(x_labels, rotation=45, ha='right')
    ax.set_ylim(0, 1)
    ax.set_ylabel("Running Probability (Human)")
    ax.set_title(
        f"SHAP-like Waterfall — Final Probability: {final_prob:.3f} "
        f"(Model Probability: {human_prob:.3f})"
    )
    plt.tight_layout()
    return fig
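
# Worked example of the accumulation above (illustrative numbers): with the
# baseline at 0.5 and the 0.05 display scaling, contributions of +2.0 and -1.0
# produce 0.5 -> 0.5 + 0.05*2.0 = 0.60 -> 0.60 - 0.05*1.0 = 0.55, so the
# endpoint tracks the gradients' net pull rather than the exact model output.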

def create_frequency_sigma_plot(important_kmers, title):
    """Bar plot of the top k-mers (by importance) showing frequency (%) and σ from the mean."""
    # Sort by absolute impact
    sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
    kmers = [k["kmer"] for k in sorted_kmers]
    frequencies = [k["occurrence"] for k in sorted_kmers]  # in %
    sigmas = [k["sigma"] for k in sorted_kmers]
    directions = [k["direction"] for k in sorted_kmers]

    x = np.arange(len(kmers))
    width = 0.4

    fig, ax_bar = plt.subplots(figsize=(10, 6))

    # Bars for frequency
    ax_bar.bar(
        x - width/2, frequencies, width, alpha=0.7,
        color=["green" if d == "human" else "red" for d in directions],
        label="Frequency (%)"
    )
    ax_bar.set_ylabel("Frequency (%)")
    ax_bar.set_ylim(0, max(frequencies) * 1.2 if frequencies else 1)

    # Twin axis for σ
    ax_bar_twin = ax_bar.twinx()
    ax_bar_twin.bar(
        x + width/2, sigmas, width, alpha=0.5, color="gray",
        label="σ from Mean"
    )
    ax_bar_twin.set_ylabel("Standard Deviations (σ)")

    ax_bar.set_title(f"Frequency & σ from Mean for Top k-mers — {title}")
    ax_bar.set_xticks(x)
    ax_bar.set_xticklabels(kmers, rotation=45, ha='right')

    # Combined legend for both axes
    lines1, labels1 = ax_bar.get_legend_handles_labels()
    lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
    ax_bar.legend(lines1 + lines2, labels1 + labels2, loc="upper right")

    plt.tight_layout()
    return fig


def create_importance_bar_plot(important_kmers, title):
    """
    Simple bar chart of the absolute gradient magnitude for the top k-mers,
    sorted in descending order.
    """
    sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
    kmers = [k['kmer'] for k in sorted_kmers]
    impacts = [k['impact'] for k in sorted_kmers]
    directions = [k["direction"] for k in sorted_kmers]

    x = np.arange(len(kmers))
    fig, ax = plt.subplots(figsize=(10, 6))

    bar_colors = ["green" if d == "human" else "red" for d in directions]
    ax.bar(x, impacts, color=bar_colors, alpha=0.7)
    ax.set_xticks(x)
    ax.set_xticklabels(kmers, rotation=45, ha='right')
    ax.set_title(f"Absolute Feature Importance (Top k-mers) — {title}")
    ax.set_ylabel("Gradient Magnitude")
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    return fig
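
# The three helpers above all consume the same record schema, built in
# predict() below. A synthetic record (hypothetical values) looks like:
#
#     example_kmer = {
#         "kmer": "ACGT",        # the 4-mer itself
#         "idx": 27,             # index into the 256-dim feature vector
#         "impact": 0.9,         # absolute gradient magnitude
#         "direction": "human",  # sign of the gradient
#         "occurrence": 1.25,    # frequency in the sequence, in %
#         "sigma": 2.1,          # standard deviations from the training mean
#     }
#     fig = create_importance_bar_plot([example_kmer], "synthetic example")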

###############################################################################
# Prediction Function
###############################################################################

def predict(file_obj):
    """
    Main function for Gradio:
      1. Reads the uploaded FASTA file or raw text.
      2. Loads the model and scaler.
      3. Generates predictions, probabilities, and top k-mers.
      4. Returns multiple outputs:
         - A textual summary (Markdown).
         - Waterfall plot.
         - Frequency & σ plot.
         - Absolute importance bar plot.
    """
    # 0. Basic input handling
    if file_obj is None:
        return "Please upload a FASTA file.", None, None, None

    try:
        if isinstance(file_obj, str):
            # The user provided raw text
            text = file_obj
        else:
            # The user uploaded a file; decode it
            text = file_obj.decode('utf-8')
    except Exception as e:
        return f"Error reading file: {str(e)}", None, None, None

    # 1. Parse FASTA
    sequences = parse_fasta(text)
    if len(sequences) == 0:
        return (
            "No valid FASTA sequences found. Please check your input.",
            None, None, None
        )

    # Classify only the first sequence, for demonstration
    header, seq = sequences[0]

    # 2. Create the k-mer vector and load the model
    k = 4
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Raw k-mer frequency vector
        raw_freq_vector = sequence_to_kmer_vector(seq, k=k)

        # Load model & scaler (reloaded on every request for simplicity)
        model = VirusClassifier(input_shape=4**k).to(device)
        state_dict = torch.load('model.pt', map_location=device)
        model.load_state_dict(state_dict)
        scaler = joblib.load('scaler.pkl')
        model.eval()

        scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
        X_tensor = torch.FloatTensor(scaled_vector).to(device)

        # 3. Inference
        with torch.no_grad():
            logits = model(X_tensor)
            probs = torch.softmax(logits, dim=1)
            human_prob = float(probs[0][1])
            non_human_prob = float(probs[0][0])
            pred_class = 1 if human_prob >= non_human_prob else 0
            pred_label = "human" if pred_class == 1 else "non-human"
            confidence = float(max(probs[0]))

        # 4. Feature importance (gradients w.r.t. the scaled input)
        importance, _ = model.get_feature_importance(X_tensor)  # shape: [1, 256]
        kmer_importances = importance[0].cpu().numpy()

        # Map feature indices back to k-mer strings
        kmers_list = [''.join(p) for p in product("ACGT", repeat=k)]

        # 5. Top 10 k-mers by absolute importance
        abs_importance = np.abs(kmer_importances)
        top_k = 10
        top_idxs = np.argsort(abs_importance)[-top_k:][::-1]  # descending

        important_kmers = []
        for idx in top_idxs:
            kmer_str = kmers_list[idx]
            direction = "human" if kmer_importances[idx] > 0 else "non-human"
            freq_percent = float(raw_freq_vector[idx] * 100)  # frequency in %
            sigma_val = float(scaled_vector[0][idx])  # σ from the scaled vector
            important_kmers.append({
                'kmer': kmer_str,
                'idx': int(idx),
                'impact': float(abs_importance[idx]),
                'direction': direction,
                'occurrence': freq_percent,
                'sigma': sigma_val
            })

        # 6. Text summary
        summary_text = (
            f"**Sequence Header**: {header}\n\n"
            f"**Predicted Label**: {pred_label}\n"
            f"**Confidence**: {confidence:.4f}\n\n"
            f"**Human Probability**: {human_prob:.4f}\n"
            f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
            "### Most Influential k-mers:\n"
        )
        for km in important_kmers:
            direction_text = f"(pushes toward {km['direction']})"
            freq_text = f"{km['occurrence']:.2f}%"
            sigma_text = (
                f"{abs(km['sigma']):.2f}σ "
                + ("above" if km['sigma'] > 0 else "below")
                + " mean"
            )
            summary_text += (
                f"- **{km['kmer']}**: impact={km['impact']:.4f}, {direction_text}, "
                f"occurrence={freq_text}, ({sigma_text})\n"
            )
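
        # Each figure below is rendered into an in-memory PNG buffer and
        # reopened as a PIL image, which gr.Image can display directly;
        # plt.close() then frees the figure so repeated requests do not
        # accumulate open Matplotlib state.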
        # 7. Plots
        # a) SHAP-like waterfall plot
        fig_waterfall = create_shap_waterfall_plot(
            important_kmers, kmer_importances, human_prob, f"{header}"
        )
        buf1 = io.BytesIO()
        fig_waterfall.savefig(buf1, format='png', bbox_inches='tight', dpi=120)
        buf1.seek(0)
        waterfall_img = Image.open(buf1)
        plt.close(fig_waterfall)

        # b) Frequency & σ plot (top 10 k-mers)
        fig_freq_sigma = create_frequency_sigma_plot(important_kmers, f"{header}")
        buf2 = io.BytesIO()
        fig_freq_sigma.savefig(buf2, format='png', bbox_inches='tight', dpi=120)
        buf2.seek(0)
        freq_sigma_img = Image.open(buf2)
        plt.close(fig_freq_sigma)

        # c) Absolute importance bar plot
        fig_imp = create_importance_bar_plot(important_kmers, f"{header}")
        buf3 = io.BytesIO()
        fig_imp.savefig(buf3, format='png', bbox_inches='tight', dpi=120)
        buf3.seek(0)
        importance_img = Image.open(buf3)
        plt.close(fig_imp)

        return summary_text, waterfall_img, freq_sigma_img, importance_img

    except Exception as e:
        return (
            f"Error during prediction or visualization: {str(e)}",
            None, None, None
        )

###############################################################################
# Gradio Interface
###############################################################################

with gr.Blocks(title="Advanced Virus Host Classifier") as demo:
    gr.Markdown(
        """
        # Advanced Virus Host Classifier
        **Upload a FASTA file** containing a single nucleotide sequence.
        The model will predict whether the sequence is **human** or **non-human**,
        provide a confidence score, and highlight the most influential k-mers
        (using a SHAP-like waterfall plot) along with two additional plots.
        """
    )

    with gr.Row():
        file_in = gr.File(label="Upload FASTA", type="binary")
        btn = gr.Button("Run Prediction")

    # One tab per output
    with gr.Tabs():
        with gr.Tab("Prediction Results"):
            md_out = gr.Markdown()
        with gr.Tab("SHAP-like Waterfall Plot"):
            water_out = gr.Image()
        with gr.Tab("Frequency & σ Plot"):
            freq_out = gr.Image()
        with gr.Tab("Importance Bar Plot"):
            imp_out = gr.Image()

    # Wire the button to the prediction function
    btn.click(
        fn=predict,
        inputs=[file_in],
        outputs=[md_out, water_out, freq_out, imp_out]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
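
# Example usage (illustrative): save the lines below as example.fasta, start
# this script (model.pt and scaler.pkl must sit next to it), and upload the
# file at http://localhost:7860.
#
#     >test_sequence
#     ATGCGTACGTTAGCCGGATCCGATCGATCGTACGATCGATCGATCGGCTA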