Commit

enable online data processing without forcing precomputation
a-r-r-o-w committed Mar 7, 2025
1 parent cad195e commit fdcae95
Showing 7 changed files with 322 additions and 57 deletions.
4 changes: 4 additions & 0 deletions finetrainers/args.py
@@ -316,6 +316,7 @@ class BaseArgs:
# Dataset arguments
dataset_config: str = None
dataset_shuffle_buffer_size: int = 1
enable_precomputation: bool = False
precomputation_items: int = 512
precomputation_dir: Optional[str] = None
precomputation_once: bool = False
@@ -420,6 +421,7 @@ def to_dict(self) -> Dict[str, Any]:
dataset_arguments = {
"dataset_config": self.dataset_config,
"dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size,
"enable_precomputation": self.enable_precomputation,
"precomputation_items": self.precomputation_items,
"precomputation_dir": self.precomputation_dir,
"precomputation_once": self.precomputation_once,
@@ -625,6 +627,7 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--dataset_config", type=str, required=True)
parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1)
parser.add_argument("--enable_precomputation", action="store_true")
parser.add_argument("--precomputation_items", type=int, default=512)
parser.add_argument("--precomputation_dir", type=str, default=None)
parser.add_argument("--precomputation_once", action="store_true")
@@ -761,6 +764,7 @@ def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs:
# Dataset arguments
result_args.dataset_config = args.dataset_config
result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size
result_args.enable_precomputation = args.enable_precomputation
result_args.precomputation_items = args.precomputation_items
result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed")
result_args.precomputation_once = args.precomputation_once
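With this change, precomputation becomes opt-in: `--enable_precomputation` defaults to False, so the trainer processes data online unless the flag is passed. A minimal standalone sketch of the flag behaviour (this is not the finetrainers entry point, and the `dataset.json` path is a placeholder):

```python
import argparse

# Standalone sketch: only the new flag and the required dataset config are mirrored here.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_config", type=str, required=True)
parser.add_argument("--enable_precomputation", action="store_true")

# Default: data is processed online, nothing is precomputed to disk.
args = parser.parse_args(["--dataset_config", "dataset.json"])
assert args.enable_precomputation is False

# Opting in switches the trainer to the precomputation path.
args = parser.parse_args(["--dataset_config", "dataset.json", "--enable_precomputation"])
assert args.enable_precomputation is True
```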
10 changes: 9 additions & 1 deletion finetrainers/data/__init__.py
@@ -14,6 +14,14 @@
initialize_dataset,
wrap_iterable_dataset_for_preprocessing,
)
- from .precomputation import DistributedDataPreprocessor, PreprocessedDataIterable
from .precomputation import (
InMemoryDataIterable,
InMemoryDistributedDataPreprocessor,
InMemoryOnceDataIterable,
PrecomputedDataIterable,
PrecomputedDistributedDataPreprocessor,
PrecomputedOnceDataIterable,
initialize_preprocessor,
)
from .sampler import ResolutionSampler
from .utils import find_files
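The new preprocessor variants and the `initialize_preprocessor` factory are now part of the public `finetrainers.data` namespace; a minimal import sketch:

```python
# Both preprocessor variants and the factory are re-exported by finetrainers.data.
from finetrainers.data import (
    InMemoryDistributedDataPreprocessor,
    PrecomputedDistributedDataPreprocessor,
    initialize_preprocessor,
)
```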
3 changes: 1 addition & 2 deletions finetrainers/data/dataset.py
@@ -758,11 +758,10 @@ def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool
elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)

has_tar_files = any(file.endswith(".tar") for file in repo_file_list)
has_tar_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in repo_file_list)
if has_tar_files:
return _initialize_webdataset(dataset_name, dataset_type, infinite)

- # TODO(aryan): handle parquet
# TODO(aryan): This should be improved
caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
233 changes: 224 additions & 9 deletions finetrainers/data/precomputation.py
@@ -1,27 +1,148 @@
import pathlib
- from typing import Any, Callable, Dict, Iterable, Optional
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import torch
from tqdm.auto import tqdm

from .. import utils
from ..logging import get_logger


- class DistributedDataPreprocessor:
logger = get_logger()


def initialize_preprocessor(
rank: int,
num_items: int,
processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
save_dir: Optional[str] = None,
enable_precomputation: bool = False,
) -> Union["InMemoryDistributedDataPreprocessor", "PrecomputedDistributedDataPreprocessor"]:
if enable_precomputation:
return PrecomputedDistributedDataPreprocessor(rank, num_items, processor_fn, save_dir)
return InMemoryDistributedDataPreprocessor(rank, num_items, processor_fn)


class DistributedDataProcessorMixin:
def consume(self, *args, **kwargs):
raise NotImplementedError("DistributedDataProcessorMixin::consume must be implemented by the subclass.")

def consume_once(self, *args, **kwargs):
raise NotImplementedError("DistributedDataProcessorMixin::consume_once must be implemented by the subclass.")

@property
def requires_data(self):
raise NotImplementedError("DistributedDataProcessorMixin::requires_data must be implemented by the subclass.")


class InMemoryDistributedDataPreprocessor(DistributedDataProcessorMixin):
def __init__(
self, rank: int, num_items: int, processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]]
) -> None:
super().__init__()

self._rank = rank
self._num_items = num_items
self._processor_fn = processor_fn

self._cached_samples = []
self._buffer = InMemoryDataBuffer(num_items)
self._preprocessed_iterator: Union["InMemoryDataIterable", "InMemoryOnceDataIterable"] = None

def consume(
self,
data_type: str,
components: Dict[str, Any],
data_iterator,
generator: Optional[torch.Generator] = None,
cache_samples: bool = False,
use_cached_samples: bool = False,
drop_samples: bool = False,
) -> Iterable[Dict[str, Any]]:
if data_type not in self._processor_fn.keys():
raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
if cache_samples:
if use_cached_samples:
raise ValueError("Cannot cache and use cached samples at the same time.")
if drop_samples:
raise ValueError("Cannot cache and drop samples at the same time.")

for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items):
if use_cached_samples:
item = self._cached_samples[i]
else:
item = next(data_iterator)
if cache_samples:
self._cached_samples.append(item)
item = self._processor_fn[data_type](**item, **components, generator=generator)
self._buffer.add(data_type, item)

if drop_samples:
del self._cached_samples
self._cached_samples = []

self._preprocessed_iterator = InMemoryDataIterable(self._rank, data_type, self._buffer)
return iter(self._preprocessed_iterator)

def consume_once(
self,
data_type: str,
components: Dict[str, Any],
data_iterator,
generator: Optional[torch.Generator] = None,
cache_samples: bool = False,
use_cached_samples: bool = False,
drop_samples: bool = False,
) -> Iterable[Dict[str, Any]]:
if data_type not in self._processor_fn.keys():
raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
if cache_samples:
if use_cached_samples:
raise ValueError("Cannot cache and use cached samples at the same time.")
if drop_samples:
raise ValueError("Cannot cache and drop samples at the same time.")

for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items):
if use_cached_samples:
item = self._cached_samples[i]
else:
item = next(data_iterator)
if cache_samples:
self._cached_samples.append(item)
item = self._processor_fn[data_type](**item, **components, generator=generator)
self._buffer.add(data_type, item)

if drop_samples:
del self._cached_samples
self._cached_samples = []

self._preprocessed_iterator = InMemoryOnceDataIterable(self._rank, data_type, self._buffer)
return iter(self._preprocessed_iterator)

@property
def requires_data(self):
if self._preprocessed_iterator is None:
return True
return self._preprocessed_iterator.requires_data


class PrecomputedDistributedDataPreprocessor(DistributedDataProcessorMixin):
def __init__(
self,
rank: int,
num_items: int,
processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
save_dir: str,
) -> None:
super().__init__()

self._rank = rank
self._num_items = num_items
self._processor_fn = processor_fn
self._save_dir = pathlib.Path(save_dir)

self._cached_samples = []
self._preprocessed_iterator: "PreprocessedDataIterable" = None
self._preprocessed_iterator: Union["PrecomputedDataIterable", "PrecomputedOnceDataIterable"] = None

self._save_dir.mkdir(parents=True, exist_ok=True)

@@ -59,9 +180,8 @@ def consume(
if drop_samples:
del self._cached_samples
self._cached_samples = []
- utils.free_memory()

- self._preprocessed_iterator = PreprocessedDataIterable(self._rank, self._save_dir, data_type)
self._preprocessed_iterator = PrecomputedDataIterable(self._rank, self._save_dir, data_type)
return iter(self._preprocessed_iterator)

def consume_once(
@@ -95,9 +215,8 @@ def consume_once(
if drop_samples:
del self._cached_samples
self._cached_samples = []
- utils.free_memory()

- self._preprocessed_iterator = PreprocessedOnceDataIterable(self._rank, self._save_dir, data_type)
self._preprocessed_iterator = PrecomputedOnceDataIterable(self._rank, self._save_dir, data_type)
return iter(self._preprocessed_iterator)

@property
@@ -107,7 +226,72 @@ def requires_data(self):
return self._preprocessed_iterator.requires_data


- class PreprocessedDataIterable:
class InMemoryDataIterable:
"""
An iterator that loads data items from an in-memory buffer. Once all the data is consumed,
`requires_data` is set to True, indicating that more data is required and the preprocessor's
consume method should be called again.
"""

def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
self._rank = rank
self._data_type = data_type
self._buffer = buffer

self._requires_data = False

def __iter__(self) -> Iterable[Dict[str, Any]]:
while (length := self._buffer.get_length(self._data_type)) > 0:
yield self._buffer.get(self._data_type)
if length == 1:
self._requires_data = True

def __len__(self) -> int:
return self._buffer.get_length(self._data_type)

@property
def requires_data(self):
return self._requires_data


class InMemoryOnceDataIterable:
"""
An iterator that loads data items from an in-memory buffer. This iterator will never set
`requires_data` to True, as it is assumed that all the data was configured to be preprocessed
by the user. The data is cycled from the buffer indefinitely.
"""

def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
self._rank = rank
self._data_type = data_type
self._buffer = buffer

self._requires_data = False

def __iter__(self) -> Iterable[Dict[str, Any]]:
while True:
if self._buffer.get_length(self._data_type) == 0:
self._requires_data = True
break
item = self._buffer.get(self._data_type)
yield item
self._buffer.add(self._data_type, item)

def __len__(self) -> int:
return self._buffer.get_length(self._data_type)

@property
def requires_data(self):
return self._requires_data


class PrecomputedDataIterable:
"""
An iterator that loads a preconfigured number of data items from disk. Once all the data is
loaded, `requires_data` is set to True, indicating that more data is required and
the preprocessor's consume method should be called again.
"""

def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
self._rank = rank
self._save_dir = pathlib.Path(save_dir)
@@ -130,7 +314,13 @@ def requires_data(self):
return self._requires_data


- class PreprocessedOnceDataIterable:
class PrecomputedOnceDataIterable:
"""
An infinite iterator that loads preprocessed data from disk. Once initialized, this iterator
will never set `requires_data` to True, as it is assumed that all the data was configured to
be preprocessed by the user.
"""

def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
self._rank = rank
self._save_dir = pathlib.Path(save_dir)
@@ -153,6 +343,31 @@ def requires_data(self):
return self._requires_data


class InMemoryDataBuffer:
def __init__(self, max_limit: int = -1) -> None:
self.max_limit = max_limit
self.buffer: Dict[str, List[Dict[str, Any]]] = {}

def add(self, data_type: str, item: Dict[str, Any]) -> None:
if data_type not in self.buffer:
self.buffer[data_type] = []
if self.max_limit != -1 and len(self.buffer[data_type]) >= self.max_limit:
logger.log_freq(
"WARN",
"IN_MEMORY_DATA_BUFFER_FULL",
"Buffer is full. Dropping the oldest item. This message will be logged every 64th time this happens.",
64,
)
self.buffer[data_type].pop(0)
self.buffer[data_type].append(item)

def get(self, data_type: str) -> Dict[str, Any]:
return self.buffer[data_type].pop(0)

def get_length(self, data_type: str) -> int:
return len(self.buffer[data_type])


def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None:
filename = directory / f"{data_type}-{rank}-{index}.pt"
torch.save(item, filename.as_posix())
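Putting the pieces together, a hedged end-to-end sketch of the new factory (the processor function, captions, and tensor shapes below are illustrative stand-ins, not finetrainers internals): with the default `enable_precomputation=False`, items are processed online and buffered in memory; passing `True` together with a `save_dir` selects the on-disk precomputation path instead.

```python
import torch
from finetrainers.data import initialize_preprocessor

# Illustrative stand-in for a conditioning processor, used only for this sketch.
def encode_text(caption=None, generator=None, **kwargs):
    return {"prompt_embeds": torch.randn(1, 8)}

preprocessor = initialize_preprocessor(
    rank=0,
    num_items=4,
    processor_fn={"condition": encode_text},
    save_dir=None,                 # only used by the precomputed variant
    enable_precomputation=False,   # default: online, in-memory processing
)

# Hypothetical raw-data iterator standing in for the dataset pipeline.
data_iterator = iter([{"caption": f"sample {i}"} for i in range(4)])
condition_iterator = preprocessor.consume("condition", components={}, data_iterator=data_iterator)

for item in condition_iterator:
    pass  # the training loop would use item["prompt_embeds"] here

# Once the in-memory buffer is drained, the trainer is expected to call consume() again.
assert preprocessor.requires_data
```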