oopere's picture
Upload 11 files
cb7223a verified
# schemas/visualize.py
from pydantic import BaseModel, field_validator
from typing import List, Optional, Union, Tuple
class VisualizePCARequest(BaseModel):
"""
Schema for the /visualize-pca endpoint.
"""
model_name: str
prompt_pair: List[str]
layer_key: str
highlight_diff: bool = True
figure_format: str = "png"
pair_index: int = 0
output_dir: Optional[str] = None
@field_validator("prompt_pair")
def must_be_two_prompts(cls, v):
if len(v) != 2:
raise ValueError("prompt_pair must be a list of exactly two strings")
return v
class VisualizeMeanDiffRequest(BaseModel):
model_name: str
prompt_pair: List[str]
layer_type: str # Changed from layer_key to layer_type
figure_format: str = "png"
output_dir: Optional[str] = None
pair_index: int = 0
@field_validator("prompt_pair")
def must_be_two_prompts(cls, v):
if len(v) != 2:
raise ValueError("prompt_pair must be a list of exactly two strings")
return v
class VisualizeHeatmapRequest(BaseModel):
"""
Schema for the /visualize/heatmap endpoint.
"""
model_name: str
prompt_pair: List[str]
layer_key: str
figure_format: str = "png"
output_dir: Optional[str] = None
@field_validator("prompt_pair")
def must_be_two_prompts(cls, v):
if len(v) != 2:
raise ValueError("prompt_pair must be a list of exactly two strings")
return v