"""Pydantic request schemas for the visualization endpoints."""
# schemas/visualize.py
from typing import List, Optional, Tuple, Union

from pydantic import BaseModel, field_validator
class VisualizePCARequest(BaseModel):
    """
    Schema for the /visualize-pca endpoint.

    ``prompt_pair`` must contain exactly two strings; any other length is
    rejected by ``must_be_two_prompts`` (pydantic reports it as a
    ``ValidationError``).

    Defaults: ``highlight_diff=True``, ``figure_format="png"``,
    ``pair_index=0``, ``output_dir=None``.
    """

    model_name: str
    prompt_pair: List[str]  # exactly two prompts, enforced below
    layer_key: str
    highlight_diff: bool = True
    figure_format: str = "png"
    pair_index: int = 0
    output_dir: Optional[str] = None

    # The decorator registers this hook with pydantic; without it the method
    # is never invoked. Default mode is "after", so ``v`` has already been
    # validated as List[str] when this runs.
    @field_validator("prompt_pair")
    @classmethod
    def must_be_two_prompts(cls, v: List[str]) -> List[str]:
        """Return ``v`` unchanged if it holds exactly two entries.

        Raises:
            ValueError: if ``v`` does not contain exactly two items.
        """
        if len(v) != 2:
            raise ValueError("prompt_pair must be a list of exactly two strings")
        return v
class VisualizeMeanDiffRequest(BaseModel):
    """
    Request schema for the mean-difference visualization.

    NOTE(review): the endpoint path is not visible in this module — confirm
    against the router that consumes this schema.

    ``prompt_pair`` must contain exactly two strings; any other length is
    rejected by ``must_be_two_prompts`` (pydantic reports it as a
    ``ValidationError``).

    Defaults: ``figure_format="png"``, ``output_dir=None``, ``pair_index=0``.
    """

    model_name: str
    prompt_pair: List[str]  # exactly two prompts, enforced below
    # Unlike the sibling schemas, this one is keyed by ``layer_type`` rather
    # than ``layer_key`` (the field was renamed from ``layer_key``).
    layer_type: str
    figure_format: str = "png"
    output_dir: Optional[str] = None
    pair_index: int = 0

    # The decorator registers this hook with pydantic; without it the method
    # is never invoked. Default mode is "after", so ``v`` has already been
    # validated as List[str] when this runs.
    @field_validator("prompt_pair")
    @classmethod
    def must_be_two_prompts(cls, v: List[str]) -> List[str]:
        """Return ``v`` unchanged if it holds exactly two entries.

        Raises:
            ValueError: if ``v`` does not contain exactly two items.
        """
        if len(v) != 2:
            raise ValueError("prompt_pair must be a list of exactly two strings")
        return v
class VisualizeHeatmapRequest(BaseModel):
    """
    Schema for the /visualize/heatmap endpoint.

    ``prompt_pair`` must contain exactly two strings; any other length is
    rejected by ``must_be_two_prompts`` (pydantic reports it as a
    ``ValidationError``).

    Defaults: ``figure_format="png"``, ``output_dir=None``.
    """

    model_name: str
    prompt_pair: List[str]  # exactly two prompts, enforced below
    layer_key: str
    figure_format: str = "png"
    output_dir: Optional[str] = None

    # The decorator registers this hook with pydantic; without it the method
    # is never invoked. Default mode is "after", so ``v`` has already been
    # validated as List[str] when this runs.
    @field_validator("prompt_pair")
    @classmethod
    def must_be_two_prompts(cls, v: List[str]) -> List[str]:
        """Return ``v`` unchanged if it holds exactly two entries.

        Raises:
            ValueError: if ``v`` does not contain exactly two items.
        """
        if len(v) != 2:
            raise ValueError("prompt_pair must be a list of exactly two strings")
        return v