File size: 2,617 Bytes
e0be88b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import tempfile
import unittest

from transformers import TrainingArguments


class TestTrainingArguments(unittest.TestCase):
    def test_default_output_dir(self):
        """Test that output_dir defaults to 'trainer_output' when not specified."""
        args = TrainingArguments(output_dir=None)
        self.assertEqual(args.output_dir, "trainer_output")

    def test_custom_output_dir(self):
        """Test that output_dir is respected when specified."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(output_dir=tmp_dir)
            self.assertEqual(args.output_dir, tmp_dir)

    def test_output_dir_creation(self):
        """Test that output_dir is created only when needed."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_dir = os.path.join(tmp_dir, "test_output")

            # Directory should not exist before creating args
            self.assertFalse(os.path.exists(output_dir))

            # Create args with save_strategy="no" - should not create directory
            args = TrainingArguments(
                output_dir=output_dir,
                do_train=True,
                save_strategy="no",
                report_to=None,
            )
            self.assertFalse(os.path.exists(output_dir))

            # Now set save_strategy="steps" - should create directory when needed
            args.save_strategy = "steps"
            args.save_steps = 1
            self.assertFalse(os.path.exists(output_dir))  # Still shouldn't exist

            # Directory should be created when actually needed (e.g. in Trainer)

    def test_torch_empty_cache_steps_requirements(self):
        """Test that torch_empty_cache_steps is a positive integer or None."""

        # None is acceptable (feature is disabled):
        args = TrainingArguments(torch_empty_cache_steps=None)
        self.assertIsNone(args.torch_empty_cache_steps)

        # non-int is unacceptable:
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=1.0)
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps="none")

        # negative int is unacceptable:
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=-1)

        # zero is unacceptable:
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=0)

        # positive int is acceptable:
        args = TrainingArguments(torch_empty_cache_steps=1)
        self.assertEqual(args.torch_empty_cache_steps, 1)