Skip to content

Commit 17ce3eb

Browse files
committed
Minor docstring changes
1 parent aab2519 commit 17ce3eb

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

qadence/ml_tools/train_utils/acclerator.py

+21
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,22 @@ class Accelerator(DistributionStrategy):
2929
Attributes:
3030
spawn (bool): Whether to use multiprocessing spawn mode for process initialization.
3131
nprocs (int): Number of processes to launch for distributed training.
32+
strategy (str): Detected strategy for process launch ("torchrun", "slurm", or "default").
33+
34+
35+
Inherited Attributes:
36+
backend (str): The backend used for distributed communication (e.g., "nccl", "gloo").
37+
compute_setup (str): Desired computation device setup.
38+
log_setup (str): Desired logging device setup.
39+
rank (int | None): Global rank of the process (to be set during environment setup).
40+
world_size (int | None): Total number of processes (to be set during environment setup).
41+
local_rank (int | None): Local rank on the node (to be set during environment setup).
42+
master_addr (str | None): Master node address (to be set during environment setup).
43+
master_port (str | None): Master node port (to be set during environment setup).
44+
device (str | None): Computation device, e.g., "cpu" or "cuda:<local_rank>".
45+
log_device (str | None): Logging device, e.g., "cpu" or "cuda:<local_rank>".
46+
dtype (torch.dtype): Data type for controlling numerical precision (e.g., torch.float32).
47+
data_dtype (torch.dtype): Data type for controlling dataset precision (e.g., torch.float16).
3248
"""
3349

3450
def __init__(
@@ -62,6 +78,11 @@ def __init__(
6278
self.strategy = self.detect_strategy()
6379
self._log_warnings()
6480

81+
# Default values
82+
self.rank = 0
83+
self.local_rank = 0
84+
self.world_size = 1
85+
6586
def setup(self, process_rank: int) -> None:
6687
"""
6788
Sets up the distributed training environment for a given process.

qadence/ml_tools/train_utils/strategy.py

-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ class DistributionStrategy:
2525
backend (str): The backend used for distributed communication (e.g., "nccl", "gloo").
2626
compute_setup (str): Desired computation device setup.
2727
log_setup (str): Desired logging device setup.
28-
strategy (str): Detected strategy for process launch ("torchrun", "slurm", or "default").
2928
rank (int | None): Global rank of the process (to be set during environment setup).
3029
world_size (int | None): Total number of processes (to be set during environment setup).
3130
local_rank (int | None): Local rank on the node (to be set during environment setup).

0 commit comments

Comments
 (0)