import os
import tempfile
import unittest

from transformers import TrainingArguments


class TestTrainingArguments(unittest.TestCase):
    def test_default_output_dir(self):
        """Test that output_dir defaults to 'trainer_output' when not specified."""
        args = TrainingArguments(output_dir=None)
        self.assertEqual(args.output_dir, "trainer_output")

    def test_custom_output_dir(self):
        """Test that output_dir is respected when specified."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(output_dir=tmp_dir)
            self.assertEqual(args.output_dir, tmp_dir)

    def test_output_dir_creation(self):
        """Test that output_dir is created only when needed."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_dir = os.path.join(tmp_dir, "test_output")

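            # Sanity check: the directory does not exist yet.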
            self.assertFalse(os.path.exists(output_dir))

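            # Instantiating TrainingArguments alone should not create the directory.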
            args = TrainingArguments(
                output_dir=output_dir,
                do_train=True,
                save_strategy="no",
                report_to=None,
            )
            self.assertFalse(os.path.exists(output_dir))

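            # Enabling checkpoint saving afterwards still should not create the
            # directory; it is only created when a checkpoint is actually written.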
            args.save_strategy = "steps"
            args.save_steps = 1
            self.assertFalse(os.path.exists(output_dir))

    def test_torch_empty_cache_steps_requirements(self):
        """Test that torch_empty_cache_steps is a positive integer or None."""
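        # None is accepted and disables periodic cache emptying.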
        args = TrainingArguments(torch_empty_cache_steps=None)
        self.assertIsNone(args.torch_empty_cache_steps)

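        # Non-integer values raise a ValueError.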
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=1.0)
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps="none")

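        # Negative values raise a ValueError.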
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=-1)

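        # Zero raises a ValueError as well: the value must be strictly positive.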
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=0)

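        # A positive integer is accepted and stored as-is.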
        args = TrainingArguments(torch_empty_cache_steps=1)
        self.assertEqual(args.torch_empty_cache_steps, 1)