@@ -29,6 +29,22 @@ class Accelerator(DistributionStrategy):
    Attributes:
        spawn (bool): Whether to use multiprocessing spawn mode for process initialization.
        nprocs (int): Number of processes to launch for distributed training.
+       strategy (str): Detected strategy for process launch ("torchrun", "slurm", or "default").
+
+
+   Inherited Attributes:
+       backend (str): The backend used for distributed communication (e.g., "nccl", "gloo").
+       compute_setup (str): Desired computation device setup.
+       log_setup (str): Desired logging device setup.
+       rank (int | None): Global rank of the process (to be set during environment setup).
+       world_size (int | None): Total number of processes (to be set during environment setup).
+       local_rank (int | None): Local rank on the node (to be set during environment setup).
+       master_addr (str | None): Master node address (to be set during environment setup).
+       master_port (str | None): Master node port (to be set during environment setup).
+       device (str | None): Computation device, e.g., "cpu" or "cuda:<local_rank>".
+       log_device (str | None): Logging device, e.g., "cpu" or "cuda:<local_rank>".
+       dtype (torch.dtype): Data type for controlling numerical precision (e.g., torch.float32).
+       data_dtype (torch.dtype): Data type for controlling dataset precision (e.g., torch.float16).
    """

    def __init__(
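
The new `strategy` attribute documented above is filled in by `detect_strategy()`, whose body is not part of this hunk. As a minimal sketch of how such detection typically works, assuming the standard environment variables exported by `torchrun` and SLURM (illustrative only, not this PR's actual implementation):

```python
import os

def detect_strategy() -> str:
    """Illustrative only: infer the launch strategy from launcher env vars.

    torchrun (torch.distributed.elastic) exports TORCHELASTIC_RUN_ID along
    with RANK/LOCAL_RANK/WORLD_SIZE; SLURM job steps export SLURM_PROCID.
    Anything else falls back to the spawn-based "default" strategy.
    """
    if "TORCHELASTIC_RUN_ID" in os.environ:
        return "torchrun"
    if "SLURM_PROCID" in os.environ:
        return "slurm"
    return "default"
```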
@@ -62,6 +78,11 @@ def __init__(
        self.strategy = self.detect_strategy()
        self._log_warnings()

+       # Default values
+       self.rank = 0
+       self.local_rank = 0
+       self.world_size = 1
+
    def setup(self, process_rank: int) -> None:
        """
        Sets up the distributed training environment for a given process.
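
The `# Default values` block means a freshly constructed `Accelerator` behaves as a single-process run (rank 0, world size 1) until `setup()` resolves the real values from the detected strategy. A hedged usage sketch; the constructor arguments below are guessed from the documented attributes and may not match the real signature:

```python
import torch

# Hypothetical instantiation: argument names mirror the documented
# attributes (spawn, nprocs, backend, dtype); the real signature may differ.
acc = Accelerator(spawn=True, nprocs=1, backend="gloo", dtype=torch.float32)

# Before setup() is called, the new defaults keep single-process paths safe:
assert acc.rank == 0 and acc.local_rank == 0 and acc.world_size == 1

# Under torchrun or SLURM, setup(process_rank) would overwrite these with
# the rank/world-size values resolved for the detected strategy.
```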