Skip to content

Commit

Permalink
Minor Docstring changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mlahariya committed Feb 12, 2025
1 parent aab2519 commit 17ce3eb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
21 changes: 21 additions & 0 deletions qadence/ml_tools/train_utils/acclerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ class Accelerator(DistributionStrategy):
Attributes:
spawn (bool): Whether to use multiprocessing spawn mode for process initialization.
nprocs (int): Number of processes to launch for distributed training.
strategy (str): Detected strategy for process launch ("torchrun", "slurm", or "default").
Inherited Attributes:
backend (str): The backend used for distributed communication (e.g., "nccl", "gloo").
compute_setup (str): Desired computation device setup.
log_setup (str): Desired logging device setup.
rank (int | None): Global rank of the process (to be set during environment setup).
world_size (int | None): Total number of processes (to be set during environment setup).
local_rank (int | None): Local rank on the node (to be set during environment setup).
master_addr (str | None): Master node address (to be set during environment setup).
master_port (str | None): Master node port (to be set during environment setup).
device (str | None): Computation device, e.g., "cpu" or "cuda:<local_rank>".
log_device (str | None): Logging device, e.g., "cpu" or "cuda:<local_rank>".
dtype (torch.dtype): Data type for controlling numerical precision (e.g., torch.float32).
data_dtype (torch.dtype): Data type for controlling dataset precision (e.g., torch.float16).
"""

def __init__(
Expand Down Expand Up @@ -62,6 +78,11 @@ def __init__(
self.strategy = self.detect_strategy()
self._log_warnings()

# Default values
self.rank = 0
self.local_rank = 0
self.world_size = 1

def setup(self, process_rank: int) -> None:
"""
Sets up the distributed training environment for a given process.
Expand Down
1 change: 0 additions & 1 deletion qadence/ml_tools/train_utils/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class DistributionStrategy:
backend (str): The backend used for distributed communication (e.g., "nccl", "gloo").
compute_setup (str): Desired computation device setup.
log_setup (str): Desired logging device setup.
strategy (str): Detected strategy for process launch ("torchrun", "slurm", or "default").
rank (int | None): Global rank of the process (to be set during environment setup).
world_size (int | None): Total number of processes (to be set during environment setup).
local_rank (int | None): Local rank on the node (to be set during environment setup).
Expand Down

0 comments on commit 17ce3eb

Please sign in to comment.