# config.py
import torch
import os
# --- Paths ---
# Adjust DATA_PATH to your actual data location
DATA_PATH = './data/synthetic_transactions_samples_5000.csv'
TOKENIZER_PATH = './tokenizer/'
LABEL_ENCODERS_PATH = './label_encoders.pkl'
MODEL_SAVE_DIR = './saved_models/'
PREDICTIONS_SAVE_DIR = './predictions/' # To save predictions for voting ensemble
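
# Hedged sketch: LABEL_ENCODERS_PATH is assumed to hold a pickled dict mapping
# each label column to its fitted sklearn LabelEncoder; adapt this if the
# training script serializes a different structure.
def load_label_encoders(path=LABEL_ENCODERS_PATH):
    """Load the pickled {column_name: LabelEncoder} dict from disk."""
    import pickle
    with open(path, 'rb') as f:
        return pickle.load(f)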
# --- Data Columns ---
TEXT_COLUMN = "Sanction_Context"
# Define all your target label columns
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]
# Metadata columns (empty by default). If your CSV provides numerical or
# categorical metadata, list the column names here; ensure each column exists
# and is numeric or can be encoded.
METADATA_COLUMNS = []  # e.g., ["Risk_Score", "Transaction_Amount"]
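
# Hedged sketch of how metadata might be prepared if METADATA_COLUMNS is
# non-empty: scale numeric columns with sklearn's StandardScaler. The
# DataFrame argument and the scaling choice are illustrative, not something
# this repo mandates.
def encode_metadata(df, columns=None):
    """Return a scaled numpy array for the metadata columns, or None if none."""
    columns = METADATA_COLUMNS if columns is None else columns
    if not columns:
        return None
    from sklearn.preprocessing import StandardScaler
    return StandardScaler().fit_transform(df[columns].astype(float).values)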
# --- Model Hyperparameters ---
MAX_LEN = 128 # Maximum sequence length for transformer tokenizers
BATCH_SIZE = 16 # Batch size for training and evaluation
LEARNING_RATE = 2e-5 # Learning rate for AdamW optimizer
NUM_EPOCHS = 3 # Number of training epochs. Adjust based on convergence.
DROPOUT_RATE = 0.3 # Dropout rate for regularization
# --- Device Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
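
# Hedged sketch showing how the hyperparameters above are typically consumed;
# `model` and `dataset` are placeholders, and the real training script may
# wire these up differently.
def build_training_objects(model, dataset):
    """Return a (dataloader, optimizer) pair wired to the config values."""
    from torch.utils.data import DataLoader
    model.to(DEVICE)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    return dataloader, optimizer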
# --- Specific Model Configurations ---
ROBERTA_MODEL_NAME = 'roberta-base'
BERT_MODEL_NAME = 'bert-base-uncased'
DEBERTA_MODEL_NAME = 'microsoft/deberta-base'
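
# Hedged sketch: any of the checkpoints above can be fetched with Hugging Face
# AutoTokenizer; caching a copy into TOKENIZER_PATH via save_pretrained is an
# assumption about how this repo manages tokenizers locally.
def load_tokenizer(model_name=ROBERTA_MODEL_NAME, save_dir=TOKENIZER_PATH):
    """Download a tokenizer for `model_name` and cache a local copy."""
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.save_pretrained(save_dir)
    return tokenizer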
# TF-IDF
TFIDF_MAX_FEATURES = 5000 # Max features for TF-IDF vectorizer
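
# Hedged sketch: a TF-IDF vectorizer capped at TFIDF_MAX_FEATURES for a
# classical baseline alongside the transformers; ngram_range=(1, 2) is an
# illustrative choice, not something this config mandates.
def build_tfidf_vectorizer(max_features=TFIDF_MAX_FEATURES):
    """Return an unfitted sklearn TfidfVectorizer with the configured cap."""
    from sklearn.feature_extraction.text import TfidfVectorizer
    return TfidfVectorizer(max_features=max_features, ngram_range=(1, 2))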
# --- Field-Specific Strategy (Conceptual) ---
# This dictionary provides conceptual strategies for enhancing specific fields.
# Actual implementation requires adapting the models (e.g., custom loss functions, metadata integration).
FIELD_STRATEGIES = {
    "Maker_Action": {
        "loss": "focal_loss",  # Requires a custom Focal Loss implementation (see the sketch below)
        "enhancements": ["action_templates", "context_prompt_tuning"],  # Advanced NLP concepts
    },
    "Risk_Category": {
        "enhancements": ["numerical_metadata", "transaction_patterns"],  # Integrate METADATA_COLUMNS
    },
    "Escalation_Level": {
        "enhancements": ["class_balancing", "policy_keyword_patterns"],  # Handled via class weights/metadata
    },
    "Investigation_Outcome": {
        "type": "classification_or_generation",  # Generation would require a seq2seq model such as T5/BART.
    },
}
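
# Hedged sketch of the custom Focal Loss referenced by the "Maker_Action"
# strategy above: standard multi-class focal loss (Lin et al., 2017) on top of
# cross-entropy. gamma=2.0 is the common default, not a value this repo fixes;
# the optional `weight` argument doubles as the class-balancing hook mentioned
# under "Escalation_Level".
class FocalLoss(torch.nn.Module):
    """Multi-class focal loss: mean of (1 - p_t)^gamma * CE over the batch."""
    def __init__(self, gamma=2.0, weight=None):
        super().__init__()
        self.gamma = gamma
        self.weight = weight  # optional per-class weights (class balancing)
    def forward(self, logits, targets):
        ce = torch.nn.functional.cross_entropy(
            logits, targets, weight=self.weight, reduction='none')
        p_t = torch.exp(-ce)  # probability assigned to the true class
        return ((1.0 - p_t) ** self.gamma * ce).mean()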
# Ensure output directories (models, predictions, tokenizer) exist
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(PREDICTIONS_SAVE_DIR, exist_ok=True)
os.makedirs(TOKENIZER_PATH, exist_ok=True)