Spaces:
Running
Running
File size: 1,483 Bytes
cb7223a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
# schemas/visualize.py
from pydantic import BaseModel, field_validator
from typing import List, Optional, Union, Tuple
class VisualizePCARequest(BaseModel):
    """
    Schema for the /visualize-pca endpoint.

    ``prompt_pair`` must contain exactly two strings; this is enforced by
    ``must_be_two_prompts`` when the model is constructed.
    """

    model_name: str
    # Exactly two prompts (checked by must_be_two_prompts).
    prompt_pair: List[str]
    # NOTE(review): presumably selects which layer's activations are
    # visualized; the accepted key format is defined by the handler — confirm.
    layer_key: str
    highlight_diff: bool = True
    # Not validated here; presumably a format accepted by the plotting
    # backend — confirm against the handler.
    figure_format: str = "png"
    pair_index: int = 0
    # NOTE(review): if this arrives from untrusted clients, the handler
    # should sanitize it before writing files.
    output_dir: Optional[str] = None

    @field_validator("prompt_pair")
    @classmethod
    def must_be_two_prompts(cls, v: List[str]) -> List[str]:
        """Accept ``prompt_pair`` only when it has exactly two elements.

        Raises:
            ValueError: if ``v`` does not contain exactly two items.
        """
        if len(v) != 2:
            raise ValueError("prompt_pair must be a list of exactly two strings")
        return v
class VisualizeMeanDiffRequest(BaseModel):
    """
    Request schema for the mean-difference visualization.

    ``prompt_pair`` must contain exactly two strings; this is enforced by
    ``must_be_two_prompts`` when the model is constructed.
    """

    model_name: str
    # Exactly two prompts (checked by must_be_two_prompts).
    prompt_pair: List[str]
    # NOTE: this schema uses ``layer_type`` where the sibling request
    # schemas use ``layer_key``; clients must send this field name.
    layer_type: str
    # Not validated here; presumably a format accepted by the plotting
    # backend — confirm against the handler.
    figure_format: str = "png"
    # NOTE(review): if this arrives from untrusted clients, the handler
    # should sanitize it before writing files.
    output_dir: Optional[str] = None
    pair_index: int = 0

    @field_validator("prompt_pair")
    @classmethod
    def must_be_two_prompts(cls, v: List[str]) -> List[str]:
        """Accept ``prompt_pair`` only when it has exactly two elements.

        Raises:
            ValueError: if ``v`` does not contain exactly two items.
        """
        if len(v) != 2:
            raise ValueError("prompt_pair must be a list of exactly two strings")
        return v
class VisualizeHeatmapRequest(BaseModel):
    """
    Schema for the /visualize/heatmap endpoint.

    ``prompt_pair`` must contain exactly two strings; this is enforced by
    ``must_be_two_prompts`` when the model is constructed.
    """

    model_name: str
    # Exactly two prompts (checked by must_be_two_prompts).
    prompt_pair: List[str]
    # NOTE(review): presumably selects which layer's activations are
    # visualized; the accepted key format is defined by the handler — confirm.
    layer_key: str
    # Not validated here; presumably a format accepted by the plotting
    # backend — confirm against the handler.
    figure_format: str = "png"
    # NOTE(review): if this arrives from untrusted clients, the handler
    # should sanitize it before writing files.
    output_dir: Optional[str] = None

    @field_validator("prompt_pair")
    @classmethod
    def must_be_two_prompts(cls, v: List[str]) -> List[str]:
        """Accept ``prompt_pair`` only when it has exactly two elements.

        Raises:
            ValueError: if ``v`` does not contain exactly two items.
        """
        if len(v) != 2:
            raise ValueError("prompt_pair must be a list of exactly two strings")
        return v
|