diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d9f89997c..64ecbde1f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -32,7 +32,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV,DOCS]" - name: Run tests run: pytest -v -ra . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 5f5e59288..0893de47d 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -34,7 +34,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]" - name: Build the documentation run: mkdocs build diff --git a/Dockerfile b/Dockerfile index 5804d0e47..7ff5d7a74 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,7 +39,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.5.1 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV]" triton==3.5.1 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/fast_llm/data/data/data_loader.py b/fast_llm/data/data/data_loader.py new file mode 100644 index 000000000..ba7e5e612 --- /dev/null +++ b/fast_llm/data/data/data_loader.py @@ -0,0 +1,72 @@ +import itertools +import typing + +import torch.utils.data + +from fast_llm.core.distributed import broadcast_object + + +class SampledDatasetIterator(torch.utils.data.Sampler): + """ + A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers). + To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`. + """ + + def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel): + super().__init__() + self._total_samples = total_samples + self._begin_index = begin_index + self._batch_size = micro_batch_size * data_parallel + self._start_idx = data_rank * micro_batch_size + self._end_idx = (data_rank + 1) * micro_batch_size + + def __len__(self) -> int: + return self._total_samples + + def __iter__(self) -> typing.Iterator[list[int]]: + for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size): + yield list(range(idx + self._start_idx, idx + self._end_idx)) + + +class DistributedDataLoaderWrapper: + """ + Wraps a regular dataloader so that only the process group leader + loads data, and then broadcasts the batch to other ranks in the group. 
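+
+    A rough usage sketch (the names below are illustrative, not part of this module):
+
+        wrapped = DistributedDataLoaderWrapper(data_loader, process_group)
+        for batch in wrapped:
+            ...  # every rank in the group receives the same batch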
+ """ + + def __init__( + self, + data_loader: torch.utils.data.dataloader.DataLoader, + process_group: torch.distributed.ProcessGroup | None, + ): + self._data_loader = data_loader + self._rank = 0 if process_group is None else process_group.rank() + self._process_group = process_group + + def __iter__(self): + if self._rank == 0: + self._iterator = iter(self._data_loader) + else: + self._iterator = itertools.repeat(None) + if self._process_group is None: + return self._iterator + return self + + def __next__(self): + # TODO: + # Instead of broadcasting a general object, make this iterator yield an actual Batch class. + # Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can + # efficiently broadcast tensors directly. This avoids using `broadcast_object` on the + # entire Batch object, which is inefficient for tensors because it serializes + # (pickles) them before sending. + + try: + data = next(self._iterator) # may raise StopIteration + except Exception as e: + data = e + data = broadcast_object(data, self._process_group, 0) + + if isinstance(data, Exception): + raise data + + return data diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index dbd770895..3af86652a 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -8,12 +8,12 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data +from fast_llm.data.data.data_loader import DistributedDataLoaderWrapper, SampledDatasetIterator from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.monitor import DatasetMonitor -from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.run import log_main_rank @@ -116,20 +116,23 @@ def get_iterator( Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") - return iter( - torch.utils.data.DataLoader( - self._datasets[dataset_name], # noqa - batch_sampler=SampledDatasetIterator( - total_samples=len(self._datasets[dataset_name]), - begin_index=consumed_samples, - micro_batch_size=batch_config.micro_batch_size, - data_rank=self._distributed.config.batch_data_rank, - data_parallel=self._distributed.config.batch_data_parallel, - ), - num_workers=num_workers, - prefetch_factor=prefetch_factor, - pin_memory=True, - collate_fn=LanguageModelBatch.from_samples, - multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, - ) + data_loader = torch.utils.data.DataLoader( + self._datasets[dataset_name], # noqa + batch_sampler=SampledDatasetIterator( + total_samples=len(self._datasets[dataset_name]), + begin_index=consumed_samples, + micro_batch_size=batch_config.micro_batch_size, + data_rank=self._distributed.config.batch_data_rank, + data_parallel=self._distributed.config.batch_data_parallel, + ), + num_workers=num_workers, + prefetch_factor=prefetch_factor, + pin_memory=True, + collate_fn=LanguageModelBatch.from_samples, + multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) + + if 
self._datasets[dataset_name].requires_broadcast: + data_loader = DistributedDataLoaderWrapper(data_loader, self.distributed.model_and_sequence_data_group) + + return iter(data_loader) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index 33942708b..1df24e92b 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -5,6 +5,7 @@ if typing.TYPE_CHECKING: from fast_llm.data.dataset.config import SamplingData + from fast_llm.data.dataset.sampled import SampledIterableDataset class Dataset[SampleType: Sample](abc.ABC): @@ -27,6 +28,14 @@ def __getstate__(self): del state["__orig_class__"] return state + @property + def requires_broadcast(self) -> bool: + """ + Some dataset schemes load the dataset on a batch-data-parallel group leaders, + then broadcast to the other devices. + """ + return False + class SampledDataset[SampleType: Sample](Dataset[SampleType]): """ @@ -48,3 +57,14 @@ class SamplableDataset[SampleType: Sample](Dataset[SampleType]): @abc.abstractmethod def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass + + +class SamplableIterableDataset[SampleType: Sample](SamplableDataset[SampleType]): + @abc.abstractmethod + def __iter__(self) -> typing.Iterator[SampleType]: + pass + + def sample(self, config: "SamplingData") -> "SampledIterableDataset[SampleType]": + from fast_llm.data.dataset.sampled import SampledIterableDataset + + return SampledIterableDataset(self, config) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 2858d8d18..003b1dfb0 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -15,6 +15,7 @@ if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset + from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -298,3 +299,46 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp return LegacyMemmapDataset[SampleType](name, self.path, preprocessing) else: raise FileNotFoundError(self.path) + + +@config_class() +class RedisConfig(Config): + # TODO: Move elsewhere? (Also used in trainer) Get it from the trainer in sampling config? + host: str = Field( + default="localhost", + desc="Hostname or IP address of the Redis server.", + hint=FieldHint.core, + ) + + port: int = Field( + default=6379, + desc="Port number on which the Redis server is running.", + hint=FieldHint.core, + ) + + def get_client(self): + import redis + + return redis.Redis(self.host, self.port) + + +@config_class(dynamic_type={SampledDatasetConfig: "streaming"}) +class StreamingDatasetConfig[SampleType: LanguageModelSample](RedisConfig, SamplableDatasetConfig[SampleType]): + """ + Configuration for a streaming dataset that reads training data from a Redis stream. 
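+
+    A minimal configuration sketch (the values shown are illustrative defaults):
+
+        datasets:
+          train:
+            type: streaming
+            host: localhost
+            port: 6379
+            acknowledge_interval: 10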
+ """ + + _abstract = False + + acknowledge_interval: int = Field( + default=10, + desc="Number of messages after which the consumer acknowledges received IDs back to the Redis hash.", + hint=FieldHint.core, + ) + + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + from fast_llm.data.dataset.streaming import RedisStreamingDataset + + return RedisStreamingDataset[StreamingDatasetConfig, SampleType](self, sampling.distributed.config).sample( + sampling + ) diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 01f3195e4..27070f674 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -51,3 +51,11 @@ def __getitem__(self, index: int) -> SampleType: @property def name(self) -> str: return self._dataset.name + + @property + def requires_broadcast(self) -> bool: + """ + Some dataset schemes load the dataset on a batch-data-parallel group leaders, + then broadcast to the other devices. + """ + return self._dataset.requires_broadcast diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index d51a68746..8cf7d938a 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -8,7 +8,7 @@ import torch import yaml -from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledDataset from fast_llm.data.dataset.config import SamplingData, ShufflingType from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.abstract import Sample @@ -111,6 +111,10 @@ def __init__( # No barrier yet to allow running in parallel. # There needs to be one before calling `__getitem__`, normally handled through `Data`. + @property + def requires_broadcast(self) -> bool: + return self._indexed_dataset.requires_broadcast + def _sample(self) -> None: """ Create a `SampledDataset` with the requested parameters. @@ -429,3 +433,61 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch + + +class SampledIterableDataset[SampleType: Sample](SampledDataset[SampleType]): + def __init__( + self, + dataset: SamplableIterableDataset[SampleType], + sampling: SamplingData, + ): + self._dataset = dataset + self._config = sampling.config + self._parameters = sampling.parameters + self._documents: list[SampleType] = [] + self._current_length = 0 + self._sample_length = self._parameters.sequence_length + self._parameters.extra_tokens + # Delay iterator creation to avoid pickling issues. 
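+        # (Generators are not picklable, so the iterator is created lazily on the first
+        # `__getitem__` call instead of here.)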
+ self._iterator: typing.Iterator[SampleType] | None = None + + @property + def requires_broadcast(self) -> bool: + # TODO: ====== fix ====== + # return self._iterator.requires_broadcast + return True + + def __getitem__(self, index: int) -> SampleType: + if self._iterator is None: + self._iterator = iter(self._dataset) + while self._current_length < self._sample_length: + document = next(self._iterator) + if len(document) > self._sample_length: + logging.warning(f"Dropping document with length {len(document)} > {self._sample_length}.") + continue + self._documents.append(document) + self._current_length += len(document) + + if self._current_length == self._sample_length: + documents = self._documents + self._documents = [] + self._current_length = 0 + else: + last_length = len(self._documents[-1]) + remaining_length = last_length - (self._current_length - self._sample_length) + if self._parameters.truncate_documents: + documents = self._documents[:-1] + [self._documents[-1].crop(0, remaining_length)] + self._documents = [self._documents[-1].crop(remaining_length, last_length)] + else: + documents = self._documents[:-1] + [self._documents[0].get_padding(remaining_length)] + self._documents = [self._documents[-1]] + self._current_length = len(self._documents[0]) + sample = documents[0].from_documents(documents) + Assert.eq(len(sample), self._sample_length) + return sample + + def __len__(self) -> int: + return self._parameters.num_samples + + @property + def name(self) -> str: + return self._dataset.name diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py new file mode 100644 index 000000000..cff028f62 --- /dev/null +++ b/fast_llm/data/dataset/streaming.py @@ -0,0 +1,127 @@ +import json +import typing + +import redis +import torch.utils.data + +from fast_llm.config import Configurable +from fast_llm.data.dataset.abstract import SamplableIterableDataset +from fast_llm.data.dataset.config import StreamingDatasetConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames + + +def dtype_from_string(name: str) -> torch.dtype: + try: + return getattr(torch, name) + except AttributeError: + raise ValueError(f"Unknown torch dtype: {name}") + + +REDIS_DATA_KEY = "fast_llm_streaming" +REDIS_GROUP_NAME = "fast_llm_group" + + +class RedisStreamingDataset[ConfigType: StreamingDatasetConfig, SampleType: LanguageModelSample]( + Configurable[ConfigType], SamplableIterableDataset[SampleType] +): + def __init__(self, config: ConfigType, distributed_config: DistributedConfig): + super().__init__(config) + # if distributed_config.pipeline_parallel > 1: + # NOTE: It is not yet clear whether the issue comes from the streaming dataset + # itself or from the distributed data-loader wrappers, but currently it + # interferes with pipeline-parallel training and causes a timeout during + # the training step. + # raise NotImplementedError("Streaming dataset support is not implemented for pipeline-parallel training.") + + self._name = f"redis[{config.host}:{config.port}]({REDIS_DATA_KEY}|{REDIS_GROUP_NAME})[data]" + self._config = config + self._rank = distributed_config.batch_data_rank + self.is_batch_data_group_leader = ( + distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).rank == 0 + ) + + # TODO: Not needed? 
+ # if distributed_config.rank == 0: + # redis_client = redis.Redis(host=self._config.host, port=self._config.port) + # redis_client.hset(f"{REDIS_DATA_KEY}:consumer_count", "0", str(distributed_config.batch_data_parallel)) + + @property + def requires_broadcast(self) -> bool: + return True + + @property + def name(self) -> str: + return self._name + + def __iter__(self) -> typing.Iterator[LanguageModelSample]: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None and worker_info.num_workers > 1: + raise RuntimeError("StreamingDataset can work only with one instance per rank") + + if not self.is_batch_data_group_leader: + raise RuntimeError("Must be only called on the batch data group leader") + + client = redis.Redis(host=self._config.host, port=self._config.port) + + # Create the consumer group at the start of the stream ("0") + # If the stream already exists, XGROUP CREATE will fail unless we add mkstream=True + try: + client.xgroup_create(name=REDIS_DATA_KEY, groupname=REDIS_GROUP_NAME, id="0", mkstream=True) + except redis.exceptions.ResponseError as e: + if "BUSYGROUP" in str(e): + # Consumer group already exists + pass + else: + raise + + processed = 0 + while True: + # XREADGROUP reads from the consumer group + # COUNT: max number of messages to fetch at once + # BLOCK: wait for new messages (milliseconds) + messages = client.xreadgroup( + groupname=REDIS_GROUP_NAME, + consumername=f"fast_llm_consumer_{self._rank}", + # ">" reads only new messages that have not been delivered to any consumer + streams={REDIS_DATA_KEY: ">"}, + count=1, + block=1000, + # No explicit ACK: messages are processed immediately; on rank failure the job restarts, + # so message loss is acceptable and simplifies coordination + noack=True, + ) + if messages: + for stream_key, msgs in messages: + assert stream_key == REDIS_DATA_KEY.encode() + for msg_id, msg_data in msgs: + processed += 1 + # TODO: or do it after processing all received messaged then count > 1? + if processed % self._config.acknowledge_interval == 0: + client.hset(f"{REDIS_DATA_KEY}:ack", str(self._rank), msg_id) + + yield self._read_document(json.loads(msg_data[b"data"])) + + def _read_document(self, data: dict) -> LanguageModelSample: + tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"])) + sample_size = len(tokens) + if "loss_masking_spans" in data: + loss_masking_spans = RangeSample([(begin, end) for begin, end in data["loss_masking_spans"]], sample_size) + else: + loss_masking_spans = None + if "chosen_spans" in data: + chosen_spans = RangeSample([(begin, end) for begin, end in data["chosen_spans"]], sample_size) + else: + chosen_spans = None + if "rejected_spans" in data: + rejected_spans = RangeSample([(begin, end) for begin, end in data["rejected_spans"]], sample_size) + else: + rejected_spans = None + return LanguageModelSample( + TokenSample(tokens, [sample_size]), + loss_masking_spans, + chosen_spans, + rejected_spans, + ) diff --git a/fast_llm/data/iterator.py b/fast_llm/data/iterator.py deleted file mode 100644 index a407c0258..000000000 --- a/fast_llm/data/iterator.py +++ /dev/null @@ -1,25 +0,0 @@ -import typing - -import torch.utils.data - - -class SampledDatasetIterator(torch.utils.data.Sampler): - """ - A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers). - To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`. 
- """ - - def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel): - super().__init__() - self._total_samples = total_samples - self._begin_index = begin_index - self._batch_size = micro_batch_size * data_parallel - self._start_idx = data_rank * micro_batch_size - self._end_idx = (data_rank + 1) * micro_batch_size - - def __len__(self) -> int: - return self._total_samples - - def __iter__(self) -> typing.Iterator[list[int]]: - for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size): - yield list(range(idx + self._start_idx, idx + self._end_idx)) diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 7a257a5fa..bbb0fa34b 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -71,6 +71,31 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No if self._model.config.distributed.rank == 0: self._save_serialized_metadata(config, serialized_metadata, index) + def iter_tensors( + self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata" + ) -> typing.Iterator[tuple[str, str, torch.Tensor]]: + # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from + # `state_dict` that are ready for conversion, + # and return a dict containing the converted tensors(s). + # If converting a tensor requires another one that is not yet available (e.g. for concatenation), + # it will remain in `state_dict` until that tensor is available. + state_dict = {} + for parameter_name, shard_name, tensor in self._model.get_state_tensor_iterator( + self.get_shard_names(config), config.data_type + ): + if shard_name not in state_dict: + state_dict[shard_name] = {} + shard_state_dict = state_dict[shard_name] + assert parameter_name not in shard_state_dict + shard_state_dict[parameter_name] = tensor + for exported_name, exported_tensor in self._convert_state_dict(shard_state_dict, True).items(): + yield shard_name, self._get_key(exported_name, shard_name), exported_tensor + + for shard_name, shard_state_dict in state_dict.items(): + assert ( + not shard_state_dict + ), f"Un-handled entries after conversion: {({k: list(v) for k, v in state_dict.items()})}" + @classmethod @abc.abstractmethod def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 7f4b7bc38..c2d6d1405 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -97,6 +97,31 @@ def setup(self, group: "ProcessGroup|None"): def check_ranks_in_range(self, start, stop): check_ranks_in_range(self.global_ranks, start, stop) + @classmethod + def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: tuple[int, int]) -> typing.Self: + start = global_rank + rank = 0 + world_size = 1 + for i, (size, stride) in enumerate(sizes_and_strides): + if i > 0: + Assert.multiple(stride, sizes_and_strides[i - 1][1]) + rank_ = global_rank // stride % size + start -= rank_ * stride + rank += world_size * rank_ + world_size *= size + global_ranks = [start] + for size, stride in sizes_and_strides: + if size == 1: + continue + if len(global_ranks) == 1: + global_ranks = range(start, start + size * stride, stride) + elif isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start: + global_ranks = range(start, start + size * 
stride, global_ranks.step) + else: + global_ranks = [rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks] + Assert.eq(len(global_ranks), world_size) + return DistributedDim(name=name, size=world_size, rank=rank, global_ranks=global_ranks) + def check_ranks_in_range(global_ranks, start, stop): Assert.geq(min(global_ranks), start) @@ -112,6 +137,7 @@ class DistributedDimNames: sequence_data = "sequence_data" batch_data = "batch_data" tensor_and_sequence_data = "tensor_and_sequence_data" + model_and_sequence_data = "model_and_sequence_data" tensor_and_data = "tensor_and_data" @@ -300,88 +326,68 @@ def _validate(self) -> None: else: self.distributed_dims = {} - data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) + tensor_stride = 1 + sequence_data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) + batch_data_stride = sequence_data_stride * self.sequence_data_parallel pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.world, - size=self.world_size, - rank=self.rank, - global_ranks=range(self.world_size), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.world, + (self.world_size, 1), + ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.data, + (self.sequence_data_parallel, sequence_data_stride), + (self.batch_data_parallel, batch_data_stride), + ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.pipeline, (self.pipeline_parallel, pipeline_stride) ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.data, - size=self.data_parallel, - rank=self.data_rank, - global_ranks=self._get_global_ranks(self.data_parallel, data_stride), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.tensor, (self.tensor_parallel, tensor_stride) ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.pipeline, - size=self.pipeline_parallel, - rank=self.pipeline_rank, - global_ranks=self._get_global_ranks(self.pipeline_parallel, pipeline_stride), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.sequence_data, + (self.sequence_data_parallel, sequence_data_stride), ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor, - size=self.tensor_parallel, - rank=self.tensor_rank, - global_ranks=self._get_global_ranks(self.tensor_parallel, 1), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.batch_data, (self.batch_data_parallel, batch_data_stride) ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.sequence_data, - size=self.sequence_data_parallel, - rank=self.sequence_data_rank, - global_ranks=self._get_global_ranks(self.sequence_data_parallel, data_stride), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.tensor_and_sequence_data, + (self.tensor_parallel, tensor_stride), + (self.sequence_data_parallel, sequence_data_stride), ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.batch_data, - size=self.batch_data_parallel, - rank=self.batch_data_rank, - global_ranks=self._get_global_ranks( - self.batch_data_parallel, data_stride * self.sequence_data_parallel - ), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.tensor_and_data, + (self.tensor_parallel, tensor_stride), + 
(self.sequence_data_parallel, sequence_data_stride), + (self.batch_data_parallel, batch_data_stride), ) - # Global ranks wrong with pipeline first, so we hide the dims as a safety check. - if not self.pipeline_first: - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor_and_sequence_data, - size=self.sequence_data_parallel * self.tensor_parallel, - rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel, - global_ranks=self._get_global_ranks(self.sequence_data_parallel * self.tensor_parallel, 1), - ) - ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor_and_data, - size=self.data_parallel * self.tensor_parallel, - rank=self.tensor_rank + self.data_rank * self.tensor_parallel, - global_ranks=self._get_global_ranks(self.data_parallel * self.tensor_parallel, 1), - ) - ) - super()._validate() + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.model_and_sequence_data, + (self.tensor_parallel, tensor_stride), + ( + (self.pipeline_parallel, pipeline_stride) + if self.pipeline_first + else (self.sequence_data_parallel, sequence_data_stride) + ), + ( + (self.sequence_data_parallel, sequence_data_stride) + if self.pipeline_first + else (self.pipeline_parallel, pipeline_stride) + ), + ) + super()._validate() if self.reference_config is not None: self.compare(self.reference_config, ValueError) Assert.in_range(self.rank, 0, self.world_size) Assert.in_range(self.local_rank, 0, self.local_world_size) - def _get_global_ranks(self, size: int, stride: int) -> range: - start = self.rank // (size * stride) * size * stride + self.rank % stride - return range(start, start + size * stride, stride) + def _add_distributed_dim_from_sizes_and_strides(self, name: str, *sizes_and_strides: tuple[int, int]) -> None: + self._add_distributed_dim(DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides)) def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index aa2be6ce7..d93e17d1c 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -28,6 +28,7 @@ def __init__( local_world_size: int | None = None, timeout: float = 60, use_cpu: bool = False, + init_method: str = "env://", backend: DistributedBackend = DistributedBackend.nccl, ): @@ -58,7 +59,7 @@ def __init__( # TODO: Allow other init methods? self.store, _, _ = next( torch.distributed.rendezvous( - "env://", + init_method, self._rank, self._world_size, timeout=datetime.timedelta(seconds=timeout), @@ -180,14 +181,13 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor]) self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data]) self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data]) - # Global ranks wrong with pipeline first, so we hide the dims as a safety check. 
- if not self._config.pipeline_first: - self.tensor_and_sequence_data_group = self.add_group( - self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data] - ) - self.tensor_and_data_group = self.add_group( - self._config.distributed_dims[DistributedDimNames.tensor_and_data] - ) + self.tensor_and_sequence_data_group = self.add_group( + self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data] + ) + self.tensor_and_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor_and_data]) + self.model_and_sequence_data_group = self.add_group( + self._config.distributed_dims[DistributedDimNames.model_and_sequence_data] + ) self._config.log_first_rank(f"Setting random seeds...") diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 6a6223cb7..ed6835140 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -1,6 +1,8 @@ import logging import typing +import torch + from fast_llm.config import UpdateType from fast_llm.core.distributed import broadcast from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig @@ -30,6 +32,20 @@ def save_checkpoint( ) converter.save(config, fast_llm_metadata) + def iter_checkpoint( + self, + config: CheckpointSaveConfig, + extra_metadata: dict | None = None, + ) -> typing.Iterator[tuple[str, str, torch.Tensor]]: + # TODO: Handle barriers, ok file, mkdir, etc. here + converter = config.format.get_handler_class()(self) + fast_llm_metadata = self._config.to_metadata( + config, + shards=converter.get_shard_names(config), + metadata={} if extra_metadata is None else extra_metadata, + ) + yield from converter.iter_tensors(config, fast_llm_metadata) + def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: # TODO: Simplify branching. # TODO: Test with more distributed configs. diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 867cca984..02829c580 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -16,6 +16,7 @@ skip_valid_if_none, ) from fast_llm.data.data.config import DataConfig +from fast_llm.data.dataset.config import RedisConfig from fast_llm.engine.checkpoint.config import ( CheckpointLoadConfig, CheckpointSaveConfig, @@ -321,6 +322,101 @@ def _validate(self) -> None: self.wandb.alert.assert_sub_interval(self.logs) +@config_class() +class TrainerEvent(Config): + enabled: bool = Field( + default=False, + desc="Flag indicating whether this event is enabled. If False, the event will be skipped.", + hint=FieldHint.feature, + ) + + +@config_class() +class WeightsBroadcastEventConfig(TrainerEvent): + """ + Event sent to indicate that updated weights are ready for broadcast. + """ + + initial_weights_step_message_type: str = Field( + default="initial_weights_step", + desc="Message indicating that weights the training starting/ continuing from.", + hint=FieldHint.feature, + ) + + initial_weights_step_message_includes_weights: bool = Field( + default=False, + desc=( + "Whether to include the loaded model weights in the initial event message. " + "Useful when training restarts from an internal checkpoint format that " + "which does not have an exported checkpoint for that step." 
+ ), + hint=FieldHint.feature, + ) + + weights_ready_message_type: str = Field( + default="weights_ready", + desc="Message indicating that weights are ready to be broadcast.", + hint=FieldHint.feature, + ) + + # NCCL rendezvous details + rdvz_master_address: str | None = Field( + default=None, + desc="Master address for the external NCCL process group.", + hint=FieldHint.feature, + ) + + rdvz_master_port: int | None = Field( + default=None, + desc="Master port for the external NCCL process group.", + hint=FieldHint.feature, + ) + + world_size: int | None = Field( + default=None, + desc="World size of the external NCCL process group.", + hint=FieldHint.feature, + ) + + rank: int | None = Field( + default=None, + desc="Rank of this process in the external NCCL process group.", + hint=FieldHint.feature, + ) + + +@config_class() +class TrainingFinishedEventConfig(TrainerEvent): + """ + Event sent to indicate that training has completed. + """ + + training_finished_message_type: str = Field( + default="training_finished", + desc="Message indicating that weights the training starting/ continuing from.", + hint=FieldHint.feature, + ) + + +@config_class() +class TrainerEventsConfig(RedisConfig): + """ + Aggregates all trainer-side Redis-based event configurations. + """ + + weights_broadcast: WeightsBroadcastEventConfig = Field( + default=None, + desc="Configuration for signaling weight-ready events via Redis.", + hint=FieldHint.feature, + ) + + training_finished: TrainingFinishedEventConfig = Field( + default=None, + desc="Configuration for signaling training-finished events via Redis.", + hint=FieldHint.feature, + ) + + @config_class(registry=True, dynamic_type={RunnableConfig: "train"}) class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): _abstract = True @@ -352,6 +448,12 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): hint=FieldHint.feature, ) + events: TrainerEventsConfig = Field( + default=None, + desc="Optional Trainer event configurations (weight broadcast, training finished, etc.).", + hint=FieldHint.feature, + ) + def _validate(self) -> None: self.training.export.setup(self.model) for reference_model in self.reference_models.values(): diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 7225ed20a..a2f98c05f 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -36,6 +36,7 @@ TrainingCheckpointConfig, TrainingEvaluatorConfig, ) +from fast_llm.engine.training.trainer_events import TrainerEvents from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, log_memory_usage from fast_llm.utils import Assert, Interrupter, get_and_reset_memory_usage_mib @@ -131,6 +132,8 @@ def __init__(self, config: TrainerConfig): self._is_evaluation_only = config.training.train_iters == 0 + self.trainer_events = TrainerEvents(config.events) + self._data = self._get_data() log_main_rank("Creating model...") self._multi_stage = self._config.model.get_model_class()( @@ -286,6 +289,7 @@ def run(self) -> None: assert self._is_setup with self._wandb: self._run_training() + self.trainer_events.send_training_finished() def _run_training(self) -> None: self._prepare_training_state() @@ -358,6 +362,11 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Synchronization is probably unnecessary. 
safe_barrier(self._distributed.world_group, "train begin") + + self.trainer_events.send_initial_weights_step( + self._completed_steps, self._multi_stage, self._config.training.export + ) + torch.cuda.synchronize() start_time = time.perf_counter() last_time = start_time @@ -384,6 +393,9 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: advanced_iters += 1 for name, value in reduced_losses.items(): total_losses[name] += value + self.trainer_events.send_weights( + self._completed_steps, self._multi_stage, self._config.training.export + ) else: skipped_iters += 1 nan_iters += not all(math.isfinite(loss) for loss in reduced_losses.values()) diff --git a/fast_llm/engine/training/trainer_events.py b/fast_llm/engine/training/trainer_events.py new file mode 100644 index 000000000..0937999fc --- /dev/null +++ b/fast_llm/engine/training/trainer_events.py @@ -0,0 +1,109 @@ +import json +import logging + +import redis +import torch.distributed + +from fast_llm.data.dataset.config import RedisConfig +from fast_llm.engine.config_utils.run import is_main_rank +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.training.config import TrainerEventsConfig, TrainingExportConfig + +logger = logging.getLogger(__name__) + + +REDIS_TRAINING_KEY = "fast_llm_events" + + +class RedisEventSender: + def __init__(self, config: RedisConfig): + self.config = config + self.client = None + + if is_main_rank(): + self.client = redis.Redis( + host=config.host, + port=config.port, + ) + + def send(self, msg_type: str, payload: dict | None = None): + if not is_main_rank(): + return + + if not payload: + payload = {} + payload.update({"type": msg_type}) + + self.client.xadd(REDIS_TRAINING_KEY, {"event": json.dumps(payload)}) + + +class TrainerEvents: + """ + Main helper class holding all event channels. + Each event may have its own RedisConfig. 
+ + Usage: + events = TrainerEvents(cfg.events) + events.weights_broadcast.send({"step": 100}) + events.training_finished.send() + """ + + def __init__(self, config: TrainerEventsConfig): + self.config = config + + if config.weights_broadcast.enabled or config.training_finished.enabled: + self.sender = RedisEventSender(config.redis) + else: + self.sender = None + + if config.weights_broadcast.enabled and is_main_rank(): + init_method = ( + f"tcp://{config.weights_broadcast.rdvz_master_address}:{config.weights_broadcast.rdvz_master_port}" + ) + logger.info(f"Waiting for weights broadcast rendezvous at {init_method} ...") + self.weights_pg = torch.distributed.init_process_group( + backend="nccl", + init_method=init_method, + world_size=config.weights_broadcast.world_size, + rank=config.weights_broadcast.rank, + ) + logger.info(f"Weights broadcast rendezvous at {init_method} connected") + else: + self.weights_pg = None + + def send_initial_weights_step(self, step: int, model: FastLLMModel, export_config: TrainingExportConfig): + if self.config.weights_broadcast.enabled: + self.sender.send( + msg_type=self.config.weights_broadcast.initial_weights_step_message_type, payload={"step": step} + ) + if self.config.weights_broadcast.initial_weights_step_message_includes_weights: + self._broadcast_weights(model, export_config) + + def send_weights(self, step: int, model: FastLLMModel, export_config: TrainingExportConfig): + if self.config.weights_broadcast.enabled: + self.sender.send(msg_type=self.config.weights_broadcast.weights_ready_message_type, payload={"step": step}) + self._broadcast_weights(model, export_config) + + def send_training_finished(self): + if self.config.training_finished.enabled: + self.sender.send(msg_type=self.config.training_finished.training_finished_message_type) + + if is_main_rank() and self.config.weights_broadcast.enabled: + torch.distributed.destroy_process_group() + + def _broadcast_weights(self, model: FastLLMModel, export_config: TrainingExportConfig): + for shard_name, layer_name, tensor in model.iter_checkpoint(export_config.get_save_config("", 10), {}): + if is_main_rank(): + meta = [(shard_name, layer_name, tensor.shape, tensor.dtype)] + torch.distributed.broadcast_object_list( + meta, group=self.weights_pg, group_src=self.config.weights_broadcast.rank + ) + torch.distributed.broadcast( + tensor, group=self.weights_pg, group_src=self.config.weights_broadcast.rank + ) + # Broadcast end of weights broadcast + if is_main_rank(): + meta = [None] + torch.distributed.broadcast_object_list( + meta, group=self.weights_pg, group_src=self.config.weights_broadcast.rank + ) diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index a8bc33454..57c9614bd 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -2,6 +2,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, @@ -10,6 +11,7 @@ LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, + LlamaMLPConverter, ) from fast_llm.utils import Assert @@ -17,6 +19,22 @@ class Qwen2AttentionConverter(LlamaAttentionConverter): # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) 
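+    # Qwen2 checkpoints use biases on the query/key/value projections but not on the dense
+    # projection, and the exported config carries no `attention_bias` flag, so the flag is
+    # injected on import and stripped again on export.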
+ @classmethod + def import_config(cls, config: dict) -> dict: + config["attention_bias"] = True + out = super().import_config(config) + out["query_layer"] = {"bias": {"enabled": True}} + out["key_layer"] = {"bias": {"enabled": True}} + out["value_layer"] = {"bias": {"enabled": True}} + out["dense_layer"] = {"bias": {"enabled": False}} + return out + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + out = super().export_config(config) + del out["attention_bias"] + return out + @classmethod def _check_config(cls, config: AttentionConfig) -> None: Assert.is_(type(config), AttentionConfig) @@ -33,8 +51,22 @@ def _check_config(cls, config: AttentionConfig) -> None: Assert.incl(config.dense_layer.bias.enabled, (None, False)) +class Qwen2MLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["mlp_bias"] = False + return super().import_config(config) + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + out = super().export_config(config) + del out["mlp_bias"] + return out + + class Qwen2BlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter + mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter class Qwen2DecoderConverter(LlamaDecoderConverter): diff --git a/fast_llm/redis/config.py b/fast_llm/redis/config.py new file mode 100644 index 000000000..e69de29bb diff --git a/setup.cfg b/setup.cfg index 005ae5a8a..f4ad02c43 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,6 +60,8 @@ SSM = GENERATION = lm_eval>=0.4.9 +STREAMING = + redis>=7.1.0 # Required for supporting vision inputs VISION = @@ -78,6 +80,7 @@ DEV = setuptools>=80.9.0 # Dependency manager needs colorama to show colors. colorama>=0.4.6 + fakeredis>=2.32.1 # Required for building the documentation DOCS = diff --git a/tests/conftest.py b/tests/conftest.py index ba2927c64..e3a2df9a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,7 +32,8 @@ ) from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip -from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip +from tests.utils.utils import result_path # isort: skip +from tests.utils.subtest import format_resource_report, report_subtest, run_parallel_script # isort: skip # Import all dynamic classes. 
import fast_llm.cli # isort: skip diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py new file mode 100644 index 000000000..e16583e8f --- /dev/null +++ b/tests/data/test_streaming.py @@ -0,0 +1,236 @@ +import contextlib +import logging +import pathlib +import typing + +import fakeredis +import pytest +import redis +import torch + +from fast_llm.config import NoAutoValidate +from fast_llm.core.distributed import safe_barrier +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.config import RedisConfig, SamplingParameters, StreamingDatasetConfig +from fast_llm.data.dataset.streaming import RedisStreamingDataset +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.utils import Assert +from tests.utils.redis import find_free_port, make_sampling, push_msg, redis_batch_producer +from tests.utils.subtest import DistributedTestContext +from tests.utils.utils import requires_cuda + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def fake_redis(monkeypatch): + """Monkeypatch redis.Redis globally.""" + fake_redis = fakeredis.FakeRedis() + monkeypatch.setattr(redis, "Redis", lambda *args, **kwargs: fake_redis) + try: + yield fake_redis + finally: + fake_redis.close() + + +@pytest.mark.parametrize( + "messages", + [ + (range(3),), + (range(3), range(3, 7)), + (range(3), range(5), [], [9, 4]), + ], +) +def test_streaming_dataset( + fake_redis: fakeredis.FakeRedis, + messages: tuple[list[int], ...], +): + """StreamingDataset should read a message and convert it into LanguageModelSample.""" + stream_config = StreamingDatasetConfig(port=find_free_port()) + dataset_iterator = iter(RedisStreamingDataset(stream_config, DistributedConfig())) + for message in messages: + push_msg(fake_redis, list(message)) + for message in messages: + sample = next(dataset_iterator) + assert isinstance(sample, LanguageModelSample) + Assert.eq(sample.tokens.tokens.tolist(), list(message)) + Assert.eq(sample.tokens.lengths, [len(message)]) + assert sample.loss_masking_spans is None + assert sample.chosen_spans is None + assert sample.rejected_spans is None + + +@pytest.mark.parametrize( + ("messages", "expected_samples", "expected_lengths"), + [ + ((range(5),), (range(5),), ([5],)), # Single message, exact fit. + ((range(3), [3, 4]), (range(5),), ([3, 2],)), # Two messages, exact fit. + ((range(6), range(5)), (range(5),), ([5],)), # Two messages, one dropped. + ( + (range(3), range(5)), + ( + [0, 1, 2, -100, -100], + range(5), + ), + ( + [3, 2], + [5], + ), + ), # Two messages, one padded. 
+ ], +) +def test_streaming_sampled_dataset( + fake_redis: fakeredis.FakeRedis, + messages: tuple[list[int], ...], + expected_samples: tuple[list[int], ...], + expected_lengths: tuple[int, ...], +): + """StreamingDataset should read a message and convert it into LanguageModelSample.""" + stream_config = StreamingDatasetConfig(port=find_free_port()) + distributed = Distributed(DistributedConfig(), use_cpu=True) + dataset_iterator = iter( + RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(5, 1, distributed)) + ) + for message in messages: + push_msg(fake_redis, list(message)) + for expected_sample, expected_lengths_ in zip(expected_samples, expected_lengths, strict=True): + sample = next(dataset_iterator) + assert isinstance(sample, LanguageModelSample) + Assert.eq(sample.tokens.tokens.tolist(), list(expected_sample)) + Assert.eq(sample.tokens.lengths, expected_lengths_) + assert sample.loss_masking_spans is None + assert sample.chosen_spans is None + assert sample.rejected_spans is None + + +_NUM_BATCHES = 1 + + +def _get_distributed_and_batch_config( + distributed_config_dict: dict[str, typing.Any], world_size: int = 1 +) -> tuple[DistributedConfig, GPTBatchConfig]: + distributed_config = DistributedConfig.from_dict( + distributed_config_dict, {"world_size": world_size, "local_world_size": world_size} + ) + with NoAutoValidate(): + batch_config = GPTBatchConfig(micro_batch_size=2, sequence_length=10) + batch_config.setup(distributed_config=distributed_config) + batch_config.validate() + return distributed_config, batch_config + + +def _run_test_data_streaming( + path: pathlib.Path, distributed_config: DistributedConfig, batch_config: GPTBatchConfig, port: int +): + redis_config = RedisConfig(port=port + 100) + + data = GPTData(GPTDataConfig(datasets={"train": {"type": "streaming", "port": port + 100}}), distributed_config) + distributed = Distributed(distributed_config) + with ( + redis_batch_producer(redis_config, batch_config) if distributed_config.rank == 0 else contextlib.nullcontext() + ): + data.setup( + distributed=distributed, + sampling_parameters={ + "train": SamplingParameters( + sequence_length=batch_config.sequence_length, + extra_tokens=0, + num_samples=batch_config.batch_size * _NUM_BATCHES, + truncate_documents=False, + ) + }, + preprocessing=LanguageModelPreprocessingConfig(), + cache_directory=path / "cache", + timeout=5, + ) + + data_iter = data.get_iterator(batch_config, "train", consumed_samples=0, num_workers=0, prefetch_factor=None) + batches = [next(data_iter) for _ in range(_NUM_BATCHES)] + path.mkdir(parents=True, exist_ok=True) + torch.save( + torch.stack([batch.tokens.tokens[:, 0] for batch in batches]), + path / f"rank_{distributed_config.batch_data_rank}_" + f"{distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).rank}.pt", + ) + # Wait for other processes to finish before shutting down the server. 
+ safe_barrier(distributed.world_group, "streaming test end") + + +def check_data_streaming_results( + path: pathlib.Path, + distributed_config: DistributedConfig, + batch_config: GPTBatchConfig, +): + sample_indexes = set() + for batch_data_rank in range(distributed_config.batch_data_parallel): + batches_tokens = torch.load(path / f"rank_{batch_data_rank}_0.pt") + Assert.eq(batches_tokens.shape, (_NUM_BATCHES, batch_config.micro_batch_size)) + for model_and_sequence_data_rank in range( + 1, distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).size + ): + Assert.all_equal( + torch.load(path / f"rank_{batch_data_rank}_{model_and_sequence_data_rank}.pt"), batches_tokens + ) + sample_indexes.update(batches_tokens.flatten().tolist()) + Assert.eq(len(sample_indexes), _NUM_BATCHES * batch_config.batch_size) + + +def _run_test_data_streaming_distributed( + test_context: DistributedTestContext, base_path: pathlib.Path, port: int +) -> None: + # Import all dynamic classes. TODO: needed? + import fast_llm.cli # noqa + + for name, num_gpus, distributed_config_dict in _DISTRIBUTED_TESTING_CONFIGS: + with test_context.subtest(base_path, name, num_gpus) as subtest: + if subtest.do_run: + distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus) + _run_test_data_streaming(base_path / name, distributed_config, batch_config, port) + + +@requires_cuda +def test_data_streaming(result_path, worker_resources): + distributed_config, batch_config = _get_distributed_and_batch_config({}) + path = result_path / "data_streaming/single_gpu" + _run_test_data_streaming(path, distributed_config, batch_config, worker_resources.torchrun_port) + check_data_streaming_results(path, distributed_config, batch_config) + + +_DISTRIBUTED_TESTING_CONFIGS = [ + ("dp2", 2, {}), + ("sdp2", 2, {"sequence_data_parallel": 2}), + ("tp2", 2, {"tensor_parallel": 2}), + ("pp2", 2, {"pipeline_parallel": 2}), + ("dp2_sdp2", 4, {"sequence_data_parallel": 2}), + ("dp2_tp2", 4, {"tensor_parallel": 2}), + ("dp2_pp2", 4, {"pipeline_parallel": 2}), + ("sdp2_tp2", 4, {"sequence_data_parallel": 2, "tensor_parallel": 2}), + ("sdp2_pp2", 4, {"sequence_data_parallel": 2, "pipeline_parallel": 2}), + ("tp2_pp2", 4, {"tensor_parallel": 2, "pipeline_parallel": 2}), +] + + +@requires_cuda +@pytest.mark.depends_on(on=["test_data_streaming"]) +def test_run_data_streaming_distributed(run_parallel_script, result_path, worker_resources): + if torch.cuda.device_count() < 2: + pytest.skip(f"Not enough GPUs") + run_parallel_script( + _run_test_data_streaming_distributed, + (result_path / "data_streaming", worker_resources.torchrun_port), + world_size=torch.cuda.device_count(), + ) + + +@requires_cuda +@pytest.mark.depends_on(on=["test_data_streaming"]) +@pytest.mark.parametrize(("name", "num_gpus", "distributed_config_dict"), _DISTRIBUTED_TESTING_CONFIGS) +def test_run_streaming_distributed(result_path, name, num_gpus, distributed_config_dict, report_subtest): + report_subtest(path := result_path / f"data_streaming/{name}", num_gpus) + distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus) + check_data_streaming_results(path, distributed_config, batch_config) diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py deleted file mode 100644 index 407946545..000000000 --- a/tests/models/distributed_test_checkpoint.py +++ /dev/null @@ -1,93 +0,0 @@ -import gc -import logging - -import torch - -from 
fast_llm.cli import fast_llm_main_wrapper -from fast_llm.config import NoAutoValidate -from fast_llm.core.distributed import safe_barrier -from fast_llm.engine.checkpoint.config import ( - CheckpointLoadConfig, - CheckpointSaveConfig, - DistributedCheckpointFormat, - FastLLMCheckpointFormat, -) -from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig -from fast_llm.engine.distributed.distributed import ProcessGroupPool -from fast_llm.engine.multi_stage.config import StageMode -from fast_llm.utils import Assert, header -from tests.utils.model_configs import ModelTestingConfig -from tests.utils.run_test_script import parse_run_distributed_script -from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig -from tests.utils.utils import DistributedSubtestContext - -logger = logging.getLogger(__name__) - - -def _test_load_and_save_parallel( - model_testing_config: ModelTestingConfig, - config: DistributedSaveLoadConfig, -): - logger.info(header(config.name)) - logger.info(f"Loading {config.load_format} checkpoint from {config.load_path}") - with NoAutoValidate(): - load_config = CheckpointLoadConfig(path=config.load_path, format=config.load_format) - load_config.setup(model_testing_config.model_config_class) - load_config.validate() - model = model_testing_config.model_class.from_pretrained( - load_config, - # The world size and rank are already set through environment variable. - {"distributed": {**config.distributed, "backend": model_testing_config.distributed_backend}}, - mode=StageMode.inference, - ) - for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): - logger.info(f"Saving {save_format.name} checkpoint to {config.save_path / save_format.name}") - model.save_checkpoint(CheckpointSaveConfig(path=config.save_path / save_format.name, format=save_format)) - del model - gc.collect() - torch.cuda.empty_cache() - - -def main(args: list[str] | None = None) -> None: - base_path, model_testing_config, do_capture = parse_run_distributed_script(args) - - if do_capture: - logger.warning( - "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." - ) - - with ProcessGroupPool( - timeout=20, - backend=DistributedBackend(model_testing_config.distributed_backend), - ) as pool: - failures = [] - world_size = DistributedConfig.default_world_size - rank = DistributedConfig.default_rank - group = pool.get_process_group(range(world_size), rank) - - for config in DISTRIBUTED_SAVE_LOAD_CONFIGS.values(): - if config.load_format == "{checkpoint_format}" and model_testing_config.checkpoint_format is None: - continue - config = config.resolve(base_path, model_testing_config) - Assert.eq(world_size, config.num_gpus) - with DistributedSubtestContext(base_path, config.name, group, world_size, enabled=do_capture) as subtest: - _test_load_and_save_parallel( - model_testing_config=model_testing_config, - config=config, - ) - if not subtest.success: - failures.append(config.name) - - # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(group, "testing end") - # Let pytest know how things went. - # These should already be reported above, we repeat for convenience. 
- if failures: - raise RuntimeError(f"The following subtests failed: {", ".join(failures)}") - else: - logger.warning("All tests passed") - - -if __name__ == "__main__": - with fast_llm_main_wrapper(): - main() diff --git a/tests/models/distributed_test_model.py b/tests/models/distributed_test_model.py deleted file mode 100644 index 29b68366d..000000000 --- a/tests/models/distributed_test_model.py +++ /dev/null @@ -1,57 +0,0 @@ -import logging - -from fast_llm.cli import fast_llm_main_wrapper -from fast_llm.core.distributed import safe_barrier -from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig -from fast_llm.engine.distributed.distributed import ProcessGroupPool -from tests.utils.distributed_configs import DISTRIBUTED_TESTING_CONFIGS -from tests.utils.run_test_script import do_run_test_script_for_all_models, parse_run_distributed_script -from tests.utils.utils import DistributedSubtestContext - -logger = logging.getLogger(__name__) - - -def main(args: list[str] | None = None) -> None: - base_path, model_testing_config, do_capture = parse_run_distributed_script(args) - - if do_capture: - logger.warning( - "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." - ) - - # TODO: Why are barriers needed? - with ProcessGroupPool( - timeout=60, - backend=DistributedBackend(model_testing_config.distributed_backend), - ) as pool: - failures = [] - world_size = DistributedConfig.default_world_size - rank = DistributedConfig.default_rank - group = pool.get_process_group(range(world_size), rank) - safe_barrier(group, "start") - - for name, config in DISTRIBUTED_TESTING_CONFIGS.items(): - if model_testing_config.should_skip(config): - continue - if world_size < config.num_gpus: - logger.warning(f"{name} {f"SKIPPED (not enough GPUs: {world_size} < {config.num_gpus})"})") - continue - with DistributedSubtestContext(base_path, name, group, config.num_gpus, enabled=do_capture) as subtest: - if rank < config.num_gpus: - do_run_test_script_for_all_models(config, model_testing_config, base_path) - if not subtest.success: - failures.append(name) - - # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(group, "testing end") - # Let pytest know how things went. - # These should already be reported above, we repeat for convenience. 
- if failures: - raise RuntimeError(f"The following subtests failed: {", ".join(failures)}") - else: - logger.warning("All tests passed") - - -if __name__ == "__main__": - with fast_llm_main_wrapper(): - main() diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index bb53de29e..6f164a33e 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -1,3 +1,4 @@ +import gc import logging import pathlib import shutil @@ -7,6 +8,7 @@ import torch import yaml +from fast_llm.config import NoAutoValidate from fast_llm.engine.checkpoint.config import ( CheckpointFormat, CheckpointLoadConfig, @@ -16,12 +18,13 @@ ModelConfigType, ) from fast_llm.engine.checkpoint.convert import ConvertConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName -from fast_llm.utils import Assert +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode +from fast_llm.utils import Assert, header from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig +from tests.utils.subtest import DistributedTestContext from tests.utils.utils import requires_cuda logger = logging.getLogger(__name__) @@ -391,31 +394,55 @@ def test_huggingface_model(model_testing_config, get_convert_path): raise ValueError(f"Comparison failed ({len(errors)} errors)") +def _save_and_load_in_parallel( + test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig +) -> None: + # Import all dynamic classes. + import fast_llm.cli # noqa + + for config in DISTRIBUTED_SAVE_LOAD_CONFIGS.values(): + if config.load_format == "{checkpoint_format}" and model_testing_config.checkpoint_format is None: + continue + config = config.resolve(base_path, model_testing_config) + with test_context.subtest(base_path, config.name, config.num_gpus) as subtest: + if subtest.do_run: + logger.info(header(config.name)) + logger.info(f"Loading {config.load_format} checkpoint from {config.load_path}") + with NoAutoValidate(): + load_config = CheckpointLoadConfig(path=config.load_path, format=config.load_format) + load_config.setup(model_testing_config.model_config_class) + load_config.validate() + model = model_testing_config.model_class.from_pretrained( + load_config, + # The world size and rank are already set through environment variable. + {"distributed": config.distributed}, + mode=StageMode.inference, + ) + for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): + logger.info(f"Saving {save_format.name} checkpoint to {config.save_path / save_format.name}") + model.save_checkpoint( + CheckpointSaveConfig(path=config.save_path / save_format.name, format=save_format) + ) + del model + gc.collect() + torch.cuda.empty_cache() + + @requires_cuda @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_path, model_testing_config, request): +def test_save_and_load_in_parallel(run_parallel_script, run_test_script_base_path, model_testing_config): # Save and load checkpoints to and from various distributed configurations. # Combined in a single test to mitigate process creation overhead. 
     # TODO: Test beyond 2 gpu configs?
-    import tests.models.distributed_test_checkpoint
-
     if torch.cuda.device_count() < 2:
-        pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2")
-
-    script = [
-        "-m",
-        tests.models.distributed_test_checkpoint.__name__,
-        str(run_test_script_base_path),
-        model_testing_config.name,
-    ]
-    if request.config.getoption("distributed_capture"):
-        logger.warning(
-            "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable."
-        )
-    else:
-        script.append("--no-distributed-capture")
-    run_distributed_script(script, num_gpus=2)
+        pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2")
+    run_parallel_script(
+        _save_and_load_in_parallel,
+        (run_test_script_base_path, model_testing_config),
+        world_size=2,
+        backend=model_testing_config.distributed_backend,
+    )
 
 
 @pytest.fixture(scope="module")
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index d14721142..58768bc52 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -1,4 +1,5 @@
 import logging
+import pathlib
 
 import pytest
 import torch
@@ -8,8 +9,10 @@
     SIMPLE_TESTING_CONFIG,
     SINGLE_GPU_TESTING_CONFIGS,
 )
-from tests.utils.model_configs import ModelTestingGroup
-from tests.utils.utils import check_subtest_success, requires_cuda, set_subtest_success
+from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup
+from tests.utils.run_test_script import do_run_test_script_for_all_models
+from tests.utils.subtest import DistributedTestContext, check_subtest_success, set_subtest_success
+from tests.utils.utils import requires_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -49,27 +52,34 @@ def test_and_compare_model(
     compare_results_for_all_models(config)
 
 
+def _run_model_distributed(
+    test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig
+) -> None:
+    # Import all dynamic classes.
+    import fast_llm.cli  # noqa
+
+    for name, config in DISTRIBUTED_TESTING_CONFIGS.items():
+        if model_testing_config.should_skip(config):
+            continue
+        with test_context.subtest(base_path, name, config.num_gpus) as subtest:
+            if subtest.do_run:
+                do_run_test_script_for_all_models(config, model_testing_config, base_path)
+
+
 @requires_cuda
 @pytest.mark.depends_on(on=["test_model_simple[{model_testing_config}]"])
 @pytest.mark.model_testing_group(
     ModelTestingGroup.distributed,
 )
-def test_run_model_distributed(run_distributed_script, model_testing_config, run_test_script_base_path, request):
-    import tests.models.distributed_test_model
-
-    script = [
-        "-m",
-        tests.models.distributed_test_model.__name__,
-        str(run_test_script_base_path),
-        model_testing_config.name,
-    ]
-    if request.config.getoption("distributed_capture"):
-        logger.warning(
-            "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable."
-        )
-    else:
-        script.append("--no-distributed-capture")
-    run_distributed_script(script, num_gpus=torch.cuda.device_count())
+def test_run_model_distributed(run_parallel_script, model_testing_config, run_test_script_base_path):
+    if torch.cuda.device_count() < 2:
+        pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2")
+    run_parallel_script(
+        _run_model_distributed,
+        (run_test_script_base_path, model_testing_config),
+        world_size=torch.cuda.device_count(),
+        backend=model_testing_config.distributed_backend,
+    )
 
 
 # We don't want to depend on `test_model_distributed` because we still want to run this in cas of failure.
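Note (illustrative, not part of the change set): the new `run_parallel_script` fixture expects a worker of the shape sketched below. It is spawned once per rank and receives the shared `DistributedTestContext` followed by the extra arguments supplied by the test; `_EXAMPLE_CASES` and the case body are placeholders, not real Fast-LLM helpers.

# Sketch of a run_parallel_script worker (assumptions noted above).
import pathlib

import torch

from tests.utils.subtest import DistributedTestContext

# Hypothetical sub-cases: (name, number of GPUs the case needs).
_EXAMPLE_CASES = [("single_gpu_case", 1), ("two_gpu_case", 2)]


def _example_worker(test_context: DistributedTestContext, base_path: pathlib.Path) -> None:
    # Register dynamic classes in the spawned process, as the real workers do.
    import fast_llm.cli  # noqa

    for name, num_gpus in _EXAMPLE_CASES:
        # Each sub-case captures its own stdout/stderr and resource report under `base_path / name`.
        with test_context.subtest(base_path, name, num_gpus) as subtest:
            if subtest.do_run:
                torch.cuda.synchronize()  # the actual test body would go here


# Inside a test: run_parallel_script(_example_worker, (result_path,), world_size=2)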
diff --git a/tests/trainer/events_fake_consumer.py b/tests/trainer/events_fake_consumer.py new file mode 100644 index 000000000..9692134db --- /dev/null +++ b/tests/trainer/events_fake_consumer.py @@ -0,0 +1,105 @@ +import json +import sys +from pathlib import Path + +import redis +import safetensors.torch +import torch.distributed +import yaml + + +def main(): + if len(sys.argv) != 2: + print("Usage: python -m tests.trainer.events_fake_consumer ") + sys.exit(1) + + config_path = Path(sys.argv[1]) + if not config_path.exists(): + print(f"Config file {config_path} does not exist") + sys.exit(1) + + with config_path.open("rt") as f: + config = yaml.safe_load(f) + + consumer_cfg = config["consumer"] + world_size = consumer_cfg["world_size"] + rank = consumer_cfg["rank"] + results_path = Path(consumer_cfg["results_path"]) + results_path.mkdir(parents=True, exist_ok=True) + + consumer_id = f"[Consumer {rank}/{world_size}]" + + print(f"{consumer_id} Started with config:") + print(yaml.safe_dump(config)) + + assert config["events"]["weights_broadcast"]["enabled"] + assert config["events"]["training_finished"]["enabled"] + + redis_client = redis.Redis(host=config["events"]["redis"]["host"], port=config["events"]["redis"]["port"]) + + print(f"{consumer_id} waiting for pg rendezvous...") + weights_pg = torch.distributed.init_process_group( + backend="nccl", + init_method=f'tcp://{config["events"]["weights_broadcast"]["rdvz_master_address"]}:' + f'{config["events"]["weights_broadcast"]["rdvz_master_port"]}', + world_size=world_size, + rank=rank, + ) + broadcast_source_rank = config["events"]["weights_broadcast"]["rank"] + + last_id = "0-0" + msg_key = config["events"]["redis"]["payload_key"].encode() + stream_key = config["events"]["redis"]["stream_key"] + + print(f"{consumer_id} waiting for messages...") + while True: + result = redis_client.xread( + streams={stream_key: last_id}, + count=1, + block=200, + ) + + if not result: + continue + + _, events = result[0] + + for event_id, msg in events: + last_id = event_id + assert msg_key in msg + msg = json.loads(msg[msg_key].decode()) + print(f"{consumer_id} msg received: {msg}") + if msg["type"] == config["events"]["weights_broadcast"]["weights_ready_message_type"] or ( + msg["type"] == config["events"]["weights_broadcast"]["initial_weights_step_message_type"] + and config["events"]["weights_broadcast"]["initial_weights_step_message_includes_weights"] + ): + weights = {} + while True: + meta = [None] + torch.distributed.broadcast_object_list(meta, group=weights_pg, group_src=broadcast_source_rank) + meta = meta[0] + if meta is None: + print(f"{consumer_id} weight broadcast finished") + break + shard_name, layer_name, tensor_size, tensor_type = meta + tensor = torch.zeros( + tuple(tensor_size), dtype=tensor_type, device="cuda" + ) # so far consumer is single gpu only + torch.distributed.broadcast(tensor, group=weights_pg, group_src=broadcast_source_rank) + print(f"{consumer_id} {shard_name} layer {layer_name} {tensor_size} {tensor_type} received") + if shard_name == "weights": + weights[layer_name] = tensor + safetensors.torch.save_file(weights, results_path / f"{msg["step"]}.safetensors") + + elif msg["type"] == config["events"]["training_finished"]["training_finished_message_type"]: + torch.distributed.destroy_process_group() + (results_path / "training_finished").touch() + return + else: + raise RuntimeError(f"{consumer_id} Received unknown message type {msg}") + if msg["type"] == 
config["events"]["weights_broadcast"]["initial_weights_step_message_type"]: + (results_path / "initial_weights_step").touch() + + +if __name__ == "__main__": + main() diff --git a/tests/trainer/test_events.py b/tests/trainer/test_events.py new file mode 100644 index 000000000..baa0526e3 --- /dev/null +++ b/tests/trainer/test_events.py @@ -0,0 +1,407 @@ +import contextlib +import copy +import os +import pathlib +import subprocess +import time +import typing + +import pytest +import safetensors +import torch +import yaml + +from fast_llm.data.dataset.config import StreamingDatasetConfig +from tests.utils.model_configs import MODEL_CONFIGS +from tests.utils.redis import redis_batch_producer +from tests.utils.utils import requires_cuda + + +@contextlib.contextmanager +def run_fake_events_consumers( + model_config: dict, + test_result_path: pathlib.Path, + broadcast_world_size: int, + fake_consumers_broadcast_ranks: list[int], + assigned_gpus: list[str], + timeout: float = 30.0, # seconds +): + """ + Context manager to run fake event consumer subprocesses for testing. + + Each subprocess gets a separate config and CUDA_VISIBLE_DEVICES. + + After exiting the context, all subprocesses are ensured to terminate. + Raises RuntimeError if any subprocess exits with non-zero code. + """ + import tests.trainer.events_fake_consumer + + assert len(assigned_gpus) > 0 + assert len(assigned_gpus) == len(fake_consumers_broadcast_ranks) + + processes = [] + + try: + for i, gpu in enumerate(assigned_gpus): + consumer_path = test_result_path / str(i) + consumer_path.mkdir(parents=True, exist_ok=True) + + # Deep copy config and update per consumer + this_config = copy.deepcopy(model_config) + this_config["consumer"] = { + "idx": i, + "results_path": consumer_path / "results", + "world_size": broadcast_world_size, + "rank": fake_consumers_broadcast_ranks[i], + } + this_config_path = consumer_path / "config.yaml" + + # Save config as YAML + with open(this_config_path, "w") as f: + yaml.safe_dump(convert_paths(this_config), f) + + # Build subprocess command + script = [ + "python", + "-m", + tests.trainer.events_fake_consumer.__name__, + str(this_config_path), + ] + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(gpu) + + # Start subprocess + proc = subprocess.Popen(script, env=env) + processes.append(proc) + + # Yield control to the caller while subprocesses run + yield + + finally: + # Wait for processes to exit or kill after timeout + start_time = time.time() + for proc in processes: + try: + remaining = max(0, timeout - (time.time() - start_time)) + proc.wait(timeout=remaining) + except subprocess.TimeoutExpired: + proc.kill() + + # Check exit codes + errors = [(i, p.returncode) for i, p in enumerate(processes) if p.returncode != 0] + if errors: + raise RuntimeError(f"Some fake consumer subprocesses failed: {errors}") + + +def run_fast_llm_training(model_config, run_distributed_script, assigned_gpus): + import fast_llm.cli + + config_path = model_config["run"]["experiment_dir"] / "load_config.yaml" + config_path.parent.mkdir(parents=True, exist_ok=True) + with config_path.open("wt") as f: + yaml.safe_dump(convert_paths(model_config), f) + + script = [ + "-m", + fast_llm.cli.__name__, + "train", + "gpt", + "--config", + str(config_path), + ] + + env = os.environ.copy() + env["PYTHONHASHSEED"] = "42" + env["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in assigned_gpus) + run_distributed_script(script, num_gpus=len(assigned_gpus), env=env) + + +def 
compare_test_tensors_to_checkpoint(test_safetensor_path: str, checkpoint_dir: str): + """ + Compare a test-saved safetensor file (a dict of tensors) + to all safetensors in a checkpoint directory. + + Checks: + - tensor names must match + - shapes must match + - dtypes must match + - values must match (exact) + """ + + # ------------------------- + # Load test tensor file + # ------------------------- + test_tensors = {} + with safetensors.safe_open(test_safetensor_path, framework="pt", device="cpu") as f: + for key in f.keys(): + test_tensors[key] = f.get_tensor(key) + + assert len(test_tensors) > 0, f"No tensors found in {test_safetensor_path}." + + # ------------------------- + # Load checkpoint tensors + # ------------------------- + checkpoint_tensors = {} + + for file in os.listdir(checkpoint_dir): + if file.endswith(".safetensors"): + path = os.path.join(checkpoint_dir, file) + with safetensors.safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key in checkpoint_tensors: + raise AssertionError( + f"Duplicate tensor name '{key}' across checkpoint {checkpoint_dir} files." + ) + checkpoint_tensors[key] = f.get_tensor(key) + + assert len(checkpoint_tensors) > 0, f"No safetensors found in checkpoint directory: {checkpoint_dir}" + + # ------------------------- + # Compare tensor sets + # ------------------------- + test_names = set(test_tensors.keys()) + ckpt_names = set(checkpoint_tensors.keys()) + + unexpected_in_test = test_names - ckpt_names + missing_in_test = ckpt_names - test_names + + assert not missing_in_test, "Tensors missing in {test_safetensor_path}:\n" + "\n".join(sorted(missing_in_test)) + assert not unexpected_in_test, "Unexpected tensors in {test_safetensor_path}:\n" + "\n".join( + sorted(unexpected_in_test) + ) + + # ------------------------- + # Compare individual tensors + # ------------------------- + for name in sorted(test_names): + t_test = test_tensors[name] + t_ckpt = checkpoint_tensors[name] + + # dtype + assert t_test.dtype == t_ckpt.dtype, f"Mismatch in dtype for '{name}': " f"{t_test.dtype} != {t_ckpt.dtype}" + + # shape + assert t_test.shape == t_ckpt.shape, ( + f"Mismatch in shape for '{name}': " f"{tuple(t_test.shape)} != {tuple(t_ckpt.shape)}" + ) + + # values + if not torch.equal(t_test, t_ckpt): + diff = (t_test - t_ckpt).abs() + max_diff = diff.max().item() + idx = (diff > 0).nonzero(as_tuple=False) + example = idx[0].tolist() if idx.numel() > 0 else "unknown" + + raise AssertionError( + f"Tensor content mismatch for '{name}'.\n" + f"Max difference: {max_diff}\n" + f"Example differing index: {example}" + ) + + # If we reached here → all is good + return True + + +def check_events_results( + test_results_path_fast_llm, + test_results_path_consumers, + consumer_count, + training_steps, + model_checkpoint_format, +): + for consumer_idx in range(consumer_count): + consumer_test_results_path = test_results_path_consumers / str(consumer_idx) / "results" + assert (consumer_test_results_path / "training_finished").is_file() + assert (consumer_test_results_path / "initial_weights_step").is_file() + # NOTE: We do not test the initial weights broadcast result when enabled, + # because it is identical to subsequent broadcasts. 
+ for training_step in range(1, training_steps + 1): + compare_test_tensors_to_checkpoint( + consumer_test_results_path / f"{training_step}.safetensors", + test_results_path_fast_llm / "export" / model_checkpoint_format / str(training_step), + ) + + +def convert_paths(obj): + if isinstance(obj, dict): + return {k: convert_paths(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_paths(v) for v in obj] + elif isinstance(obj, tuple): + return tuple(convert_paths(v) for v in obj) + elif isinstance(obj, pathlib.Path): + return str(obj) + else: + return obj + + +def parallelism_variants(num_gpus: int) -> list[dict[str, int]]: + if num_gpus == 1: + return [{"tp": 1, "pp": 1, "sp": 1}] + + if num_gpus == 2: + return [ + # NOTE: Streaming dataset is currently not compatible with pipeline parallelism. + {"tp": 2, "pp": 1, "sp": 1}, + # {"tp": 1, "pp": 2, "sp": 1}, + {"tp": 1, "pp": 1, "sp": 2}, + ] + + if num_gpus == 4: + return [ + # NOTE: Streaming dataset is currently not compatible with pipeline parallelism. + {"tp": 4, "pp": 1, "sp": 1}, + # {"tp": 1, "pp": 4, "sp": 1}, + {"tp": 1, "pp": 1, "sp": 4}, + # {"tp": 2, "pp": 2, "sp": 1}, + # {"tp": 1, "pp": 2, "sp": 2}, + {"tp": 2, "pp": 1, "sp": 2}, + ] + + raise ValueError(f"Invalid gpu count for fast_llm parallelism {num_gpus}") + + +def consumer_counts(num_gpus: int) -> int: + if num_gpus == 2: + return 1 + if num_gpus == 3: + return 1 + if num_gpus == 4: + return 2 + if num_gpus == 5: + return 1 + if num_gpus == 6: + return 2 + if num_gpus == 7: + return 3 + if num_gpus >= 8: + return 4 + + +def generate_variants(num_gpus: int) -> list[dict[str, typing.Any]]: + """ + Generate all (consumer_count, tp/pp/sp) variants for given GPU count. + """ + results = [] + + if num_gpus < 2: + return results + if num_gpus == 2: + num_gpus = [2] + elif num_gpus <= 4: + num_gpus = [2, num_gpus] + else: + num_gpus = [2, 4, min(num_gpus, 8)] + + for gpus in num_gpus: + consumers = consumer_counts(gpus) + remaining = gpus - consumers + par_vars = parallelism_variants(remaining) + for pv in par_vars: + results.append( + { + "total_gpus": gpus, + "consumers_gpu_count": consumers, + "fast_llm_gpus_count": remaining, + "consumers_gpus": list(range(consumers)), + "fast_llm_gpus": list(range(consumers, gpus)), + "tensor_parallel": pv["tp"], + "pipeline_parallel": pv["pp"], + "sequence_data_parallel": pv["sp"], + } + ) + + return results + + +variants = generate_variants(torch.cuda.device_count()) + + +@pytest.mark.slow +@requires_cuda +@pytest.mark.parametrize( + "variant", + variants, + ids=[ + f"gpu{v['total_gpus']}_cgpus{v['consumers_gpu_count']}_fgpus{v['fast_llm_gpus_count']}" + f"_tp{v['tensor_parallel']}_pp{v['pipeline_parallel']}_sp{v['sequence_data_parallel']}" + for v in variants + ], +) +def test_trainer_events_with_streaming(variant, run_distributed_script, result_path, request): + stream_config = StreamingDatasetConfig(port=port) + test_result_path = result_path / request.node.name + test_result_path_fast_llm = test_result_path / "fast_llm" + test_result_path_consumers = test_result_path / "consumers" + + broadcast_world_size = variant["consumers_gpu_count"] + 1 + fake_consumers_broadcast_ranks = list(range(variant["consumers_gpu_count"])) + fake_consumers_assigned_gpus = variant["consumers_gpus"] + fast_llm_broadcast_rank = variant["consumers_gpu_count"] + fast_llm_assigned_gpus = variant["fast_llm_gpus"] + train_iters = 2 + + model_config = copy.deepcopy(MODEL_CONFIGS["mistral"].config_dict) + model_config["data"]["datasets"] = 
{"training": stream_config.to_dict()} + model_config["data"]["sampling"] = {"shuffle": "disabled"} + model_config["training"]["train_iters"] = train_iters + model_config["training"]["export"] = {"interval": 1, "format": MODEL_CONFIGS["mistral"].checkpoint_format.name} + model_config["batch"]["micro_batch_size"] = 1 + model_config["batch"]["truncate_documents"] = False + model_config["run"]["experiment_dir"] = test_result_path_fast_llm + model_config["model"]["distributed"]["tensor_parallel"] = variant["tensor_parallel"] + model_config["model"]["distributed"]["pipeline_parallel"] = variant["pipeline_parallel"] + model_config["model"]["distributed"]["sequence_data_parallel"] = variant["sequence_data_parallel"] + + # We use same stream for messages in the test. Also make all fields explicit, + # so fake consumers can read them as well from this dict config + model_config["events"] = { + "redis": { + "host": stream_config.host, + "port": stream_config.port, + "stream_key": "fast_llm_events", + "payload_key": "event", + }, + "weights_broadcast": { + "enabled": True, + "initial_weights_step_message_type": "initial_weights_step", + "initial_weights_step_message_includes_weights": True, + "weights_ready_message_type": "weights_ready", + "rdvz_master_address": "127.0.0.1", + "rdvz_master_port": 19999, + "world_size": broadcast_world_size, + "rank": fast_llm_broadcast_rank, + }, + "training_finished": { + "enabled": True, + "training_finished_message_type": "training_finished", + }, + } + + batch_size = model_config["batch"]["batch_size"] + sequence_length = model_config["batch"]["sequence_length"] + with redis_batch_producer( + redis_client=fake_redis_client, + fake_redis_server_killer=fake_redis_server_killer, + batch_size=batch_size, + sequence_length=sequence_length, + ): + with run_fake_events_consumers( + model_config=model_config, + test_result_path=test_result_path_consumers, + broadcast_world_size=broadcast_world_size, + fake_consumers_broadcast_ranks=fake_consumers_broadcast_ranks, + assigned_gpus=fake_consumers_assigned_gpus, + ): + run_fast_llm_training( + model_config=model_config, + run_distributed_script=run_distributed_script, + assigned_gpus=fast_llm_assigned_gpus, + ) + check_events_results( + test_results_path_fast_llm=test_result_path_fast_llm, + test_results_path_consumers=test_result_path_consumers, + consumer_count=len(fake_consumers_assigned_gpus), + training_steps=train_iters, + model_checkpoint_format=MODEL_CONFIGS["mistral"].checkpoint_format.name, + ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1248a1117..f14472194 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -12,6 +12,7 @@ from fast_llm.config import set_nested_dict_value from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.distributed.config import DistributedBackend from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.conversion.config import ( @@ -149,7 +150,7 @@ def base_model_config_class(self): @functools.cached_property def distributed_backend(self): - return self.config_dict["model"]["distributed"]["backend"] + return DistributedBackend(self.config_dict["model"]["distributed"]["backend"]) def should_skip(self, distributed_config: DistributedTestingConfig) -> bool: return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests) diff --git a/tests/utils/redis.py b/tests/utils/redis.py new file mode 
100644 index 000000000..7e8072aab --- /dev/null +++ b/tests/utils/redis.py @@ -0,0 +1,137 @@ +import contextlib +import itertools +import json +import pathlib +import socket +import threading +import time + +import fakeredis + +from fast_llm.data.dataset.config import ( + RedisConfig, + SamplingConfig, + SamplingData, + SamplingParameters, + StreamingDatasetConfig, +) +from fast_llm.data.dataset.streaming import REDIS_DATA_KEY, REDIS_GROUP_NAME +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.models.gpt.config import GPTBatchConfig + + +def find_free_port(): + """Find a free TCP port and return it.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def push_msg(redis_client, tokens): + """Push a message into FakeRedis stream.""" + redis_client.xadd(REDIS_DATA_KEY, {"data": json.dumps({"tokens": tokens, "tokens_dtype": "int64"})}) + + +def wait_until_stream_empty( + redis_client, + stream_key, + consumer_group, + stop_event, +): + """ + Wait until lag == 0, meaning all messages have been delivered AND acknowledged. + Absence of group mean test has not started yet, so we wait + """ + consumer_group = consumer_group.encode() + while not stop_event.is_set(): + groups = redis_client.xinfo_groups(stream_key) + + g = next((g for g in groups if g["name"] == consumer_group), None) + if g is not None: + lag = g.get("lag", 0) + if lag == 0: + return + + time.sleep(0.05) + + +def get_consumer_count(redis_client, stop_event, config: StreamingDatasetConfig): + while not stop_event.is_set(): + res = redis_client.hget(f"{REDIS_DATA_KEY}:consumer_count", "0") + if res is None: + time.sleep(0.05) + continue + return int(res) + + +@contextlib.contextmanager +def redis_batch_producer(config: RedisConfig, batch_config: GPTBatchConfig): + with fake_redis_server(config): + stop_event = threading.Event() + client = config.get_client() + + def producer_loop(): + for sample_index in itertools.count(): + if stop_event.is_set(): + break + push_msg(client, [sample_index] * batch_config.sequence_length) + if sample_index % 5 == 0: + wait_until_stream_empty(client, REDIS_DATA_KEY, REDIS_GROUP_NAME, stop_event) + + thread = threading.Thread(target=producer_loop, daemon=True) + thread.start() + + try: + yield + finally: + stop_event.set() + thread.join(timeout=1) + client.close() + + +def make_sampling(sequence_length, num_samples, distributed): + return SamplingData( + parameters=SamplingParameters( + sequence_length=sequence_length, + extra_tokens=0, + num_samples=num_samples, + truncate_documents=False, + ), + config=SamplingConfig(), + distributed=distributed, + dataset_name="test", + cache_directory=pathlib.Path("/tmp"), + preprocessing=LanguageModelPreprocessingConfig(), + ) + + +@contextlib.contextmanager +def fake_redis_server(config: RedisConfig): + # We search for free port as port from previous test can still be not free even after server shutdown + + # ----- Monkey-patch handler to suppress broken pipes ----- + orig_handle = fakeredis._tcp_server.TCPFakeRequestHandler.handle + + def safe_handle(self): + try: + orig_handle(self) + except (ConnectionResetError, BrokenPipeError): + # Client disconnected abruptly (e.g., when a PyTorch DataLoader iterator is deleted). + # These errors occur only with fake Redis and can be safely ignored. 
+ pass + except Exception as e: + print(f"Unexpected exception in fake Redis handler: {e}") + + fakeredis._tcp_server.TCPFakeRequestHandler.handle = safe_handle + + server = fakeredis.TcpFakeServer((config.host, config.port), server_type="redis") + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + try: + yield + finally: + # ----- Teardown ----- + server.shutdown() + server.server_close() + thread.join() diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index 5c07324cf..d43855789 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -47,11 +47,7 @@ def do_run_distributed_script( @pytest.fixture(scope="session") -def run_distributed_script( - worker_resources: "WorkerResources", - run_test_script_base_path: pathlib.Path, - model_testing_config: ModelTestingConfig, -): +def run_distributed_script(worker_resources: "WorkerResources"): return functools.partial( do_run_distributed_script, rendezvous_port=worker_resources.rendezvous_port, diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py new file mode 100644 index 000000000..9d7d319a1 --- /dev/null +++ b/tests/utils/subtest.py @@ -0,0 +1,265 @@ +import functools +import json +import logging +import math +import pathlib +import sys +import time +import traceback +import typing + +import pytest +import torch + +from fast_llm.core.distributed import allreduce_scalar, safe_barrier +from fast_llm.engine.config_utils.logging import configure_logging +from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig +from fast_llm.engine.distributed.distributed import ProcessGroupPool +from fast_llm.utils import Assert, get_and_reset_memory_usage_mib, header + +logger = logging.getLogger(__name__) + + +class DistributedTestContext: + def __init__( + self, + do_capture: bool, + timeout: float = 20.0, + init_method: str = "env://", + backend: DistributedBackend = DistributedBackend.nccl, + ) -> None: + self._do_capture = do_capture + self._timeout = timeout + self._init_method = init_method + self._backend = backend + + def __enter__(self): + if self._do_capture: + logger.warning( + "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." + ) + + self._pool = ProcessGroupPool( + timeout=self._timeout, init_method=self._init_method, backend=self._backend + ).__enter__() + self._rank = self._pool.rank + self._world_size = self._pool.world_size + self._failures = [] + self._configure_logging() + self._group = self._pool.get_process_group(range(self._world_size), self._rank) + # TODO: Barriers needed? + safe_barrier(self._group, "start") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Final barrier to ensure everything is done before torchrun potentially kills workers. + safe_barrier(self._group, "testing end") + # Let pytest know how things went. + # These should already be reported above, we repeat for convenience. 
+ if self._failures: + raise RuntimeError(f"The following subtests failed: {", ".join(self._failures)}") + else: + logger.warning("All tests passed") + + def subtest(self, base_path: pathlib.Path, name: str, num_gpus: int): + return self.DistributedSubtestContext(self, base_path, name, num_gpus) + + def _configure_logging(self): + configure_logging(rank=self._rank, world_size=self._world_size) + + class DistributedSubtestContext: + def __init__( + self, test_context: "DistributedTestContext", base_path: pathlib.Path, name: str, num_gpus: int + ) -> None: + self._test_context = test_context + self._path = base_path / name + self._name = name + self._num_gpus = num_gpus + self._skip = self._test_context._world_size < self._num_gpus + self._do_run = self._test_context._rank < num_gpus and not self._skip + self._do_capture = self._test_context._do_capture and self._do_run + self._success = False + + def __enter__(self) -> typing.Self: + if self._do_capture: + self._sys_stdout = sys.stdout + self._sys_stderr = sys.stderr + self._path.mkdir(parents=True, exist_ok=True) + sys.stdout = self._path.joinpath(f"pytest_stdout_{self._test_context._rank}").open("w") + sys.stderr = self._path.joinpath(f"pytest_stderr_{self._test_context._rank}").open("w") + self._test_context._configure_logging() + # Logging is set to log to the old stdout, so we need to reconfigure. + self._start = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._skip: + # Skipped tests should exit right away. + Assert.none(exc_val) + logger.warning( + f"{self._name} {f"SKIPPED (not enough GPUs: {self._test_context._world_size} < {self._num_gpus})"})" + ) + return + + if self._do_capture: + try: + stdout_handle = sys.stdout + stderr_handle = sys.stderr + sys.stdout = self._sys_stdout + sys.stderr = self._sys_stderr + stdout_handle.close() + stderr_handle.close() + finally: + assert DistributedConfig.default_world_size > 1 + self._test_context._configure_logging() + + if exc_type is None: + self._success = True + else: + self._path.mkdir(parents=True, exist_ok=True) + self._path.joinpath(f"pytest_traceback_{self._test_context._rank}").write_text(traceback.format_exc()) + + logger.warning(f"{self._name} done, waiting for other ranks ({"PASSED" if self._success else "FAILED"})") + + if (group := self._test_context._group) is not None: + # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. + safe_barrier(group, self._name) + self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size() + + if self._do_capture: + # Free resources to limit memory usage. 
+ report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True) + report["duration"] = time.perf_counter() - self._start + + json.dump(report, self._path.joinpath(f"pytest_report_{self._test_context._rank}").open("w")) + + if self._test_context._rank == 0: + set_subtest_success(self._path, self._success) + logger.warning(f"{self._name} {"PASSED" if self._success else "FAILED"}") + if not self._success: + self._test_context._failures.append(self._name) + + return True + + @property + def do_run(self) -> bool: + return self._do_run and not self._skip + + +def set_subtest_success(path: pathlib.Path, success: bool = True): + path.joinpath("pytest_success").write_text(str(int(success))) + + +def check_subtest_success(path: pathlib, fail: bool = True) -> bool: + if not path.is_dir(): + if fail: + pytest.fail(f"Test {path.name} did not run", pytrace=False) + else: + return False + try: + return bool(int(path.joinpath("pytest_success").read_text())) + except OSError: + return False + + +def format_resource_report(title: str, report: dict[str, float]) -> str: + return "".join( + [ + f"{title}:\n ", + f"Max Reserved: {report.get("max_reserved", math.nan):.0f} MiB", + f"| Max Allocated: {report.get("max_allocated", math.nan):.0f} MiB".ljust(26), + f"| End Reserved: {report.get("reserved", math.nan):.0f} MiB".ljust(25), + f"| End Allocated: {report.get("allocated", math.nan):.0f} MiB".ljust(26), + f"| Duration: {report.get("duration", math.nan):.2f}".ljust(18), + f"| GPUs: {report["gpus"]:.0f}" if "gpus" in report else "", + ] + ) + + +@pytest.fixture(scope="function") +def report_subtest(request: pytest.FixtureRequest): + verbose = request.config.getoption("verbose") + do_capture = request.config.getoption("distributed_capture") + + def do_report_subtest(path: pathlib.Path, world_size: int) -> None: + success = check_subtest_success(path) + if not do_capture: + logger.warning("Distributed capture is disabled. 
See distributed test for run output.") + elif verbose > 1 or not success: + for rank in range(world_size): + for fd, file_ in (("stdout", sys.stdout), ("stderr", sys.stdout), ("traceback", sys.stderr)): + print(header(f"{fd} rank {rank}", 80), file=file_) + file_path = path / f"pytest_{fd}_{rank}" + try: + print(file_path.read_text(), file=file_) + except OSError: + print(f"<<< not found {file_path}>>>", file=file_) + else: + print("Set verbose > 1 to show run output.") + + reports = {} + for rank in range(world_size): + try: + reports[f"rank_{rank}"] = json.load(path.joinpath(f"pytest_report_{rank}").open("r")) + except OSError: + reports[rank] = {} + keys = {key for report in reports.values() for key in report} + report = {key: max(report[key] for report in reports.values() if key in report) for key in keys} + report["gpus"] = world_size + reports["global"] = report + + print(header(f"Resource usage", 80), file=sys.stderr) + for name, report in reports.items(): + print(format_resource_report(name, report), file=sys.stderr) + setattr(request.node, "fast_llm_resource_report", report) + + if not success: + raise RuntimeError(f"test {path.name} failed") + + return do_report_subtest + + +def parallel_worker( + rank: int, + world_size: int, + init_method: str, + backend: DistributedBackend, + do_capture: bool, + fn: typing.Callable, + fn_args: typing.Sequence[typing.Any], +): + DistributedConfig.default_rank = rank + DistributedConfig.default_world_size = world_size + DistributedConfig.default_local_world_size = world_size + with DistributedTestContext(do_capture, 60, init_method, backend) as test_context: + fn(test_context, *fn_args) + + +def do_run_parallel_script( + fn: typing.Callable, + fn_args: typing.Sequence[typing.Any], + port: int, + do_capture: bool, + world_size: int, + timeout: float = 240, + backend: DistributedBackend = DistributedBackend.nccl, +): + if do_capture: + logger.warning( + "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." 
+ ) + torch.multiprocessing.spawn( + parallel_worker, + args=(world_size, f"tcp://localhost:{port}", backend, do_capture, fn, fn_args), + nprocs=world_size, + join=False, + ).join(timeout, grace_period=5) + + +@pytest.fixture(scope="session") +def run_parallel_script(worker_resources: "WorkerResources", request: pytest.FixtureRequest): + return functools.partial( + do_run_parallel_script, + port=worker_resources.rendezvous_port, + do_capture=request.config.getoption("distributed_capture"), + ) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 3b79f7607..f0ca20db8 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -1,23 +1,14 @@ -import json import logging -import math -import pathlib -import sys -import time -import traceback import typing import pytest import torch -from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import set_model_names -from fast_llm.engine.config_utils.logging import configure_logging from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage -from fast_llm.utils import get_and_reset_memory_usage_mib, header from tests.utils.global_variables import TEST_RESULTS_PATH logger = logging.getLogger(__name__) @@ -65,137 +56,3 @@ def get_stage( stage.restore_parameters() stage.reset_gradients() return stage - - -class DistributedSubtestContext: - def __init__( - self, base_path: pathlib.Path, name: str, group: ProcessGroup | None, num_gpus: int, enabled: bool = True - ) -> None: - self._path = base_path / name - self._name = name - self._group = group - self._rank = 0 if group is None else group.rank() - self._rank_enabled = self._rank < num_gpus - self._enabled = enabled and self._rank_enabled - self.success = False - - def __enter__(self) -> typing.Self: - if self._enabled: - self._sys_stdout = sys.stdout - self._sys_stderr = sys.stderr - self._path.mkdir(parents=True, exist_ok=True) - sys.stdout = self._path.joinpath(f"pytest_stdout_{self._rank}").open("w") - sys.stderr = self._path.joinpath(f"pytest_stderr_{self._rank}").open("w") - # Logging is set to log to the old stdout, so we need to reconfigure. - configure_logging() - self._start = time.perf_counter() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._enabled: - try: - stdout_handle = sys.stdout - stderr_handle = sys.stderr - sys.stdout = self._sys_stdout - sys.stderr = self._sys_stderr - stdout_handle.close() - stderr_handle.close() - finally: - configure_logging() - - if exc_type is None: - self.success = True - else: - self._path.mkdir(parents=True, exist_ok=True) - self._path.joinpath(f"pytest_traceback_{self._rank}").write_text(traceback.format_exc()) - - if self._group is not None: - # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. - safe_barrier(self._group, self._name) - self.success = allreduce_scalar(self.success, dtype=torch.int64, group=self._group) == self._group.size() - - if self._rank_enabled: - # Free resources to limit memory usage. 
- report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True) - report["duration"] = time.perf_counter() - self._start - - json.dump(report, self._path.joinpath(f"pytest_report_{self._rank}").open("w")) - - logger.warning(f"{self._name} {"PASSED" if self.success else "FAILED"})") - if self._rank == 0: - set_subtest_success(self._path, self.success) - - return True - - -def set_subtest_success(path: pathlib.Path, success: bool = True): - path.joinpath("pytest_success").write_text(str(int(success))) - - -def check_subtest_success(path: pathlib, fail: bool = True) -> bool: - if not path.is_dir(): - if fail: - pytest.fail(f"Test {path.name} did not run", pytrace=False) - else: - return False - try: - return bool(int(path.joinpath("pytest_success").read_text())) - except OSError: - return False - - -def format_resource_report(title: str, report: dict[str, float]) -> str: - return "".join( - [ - f"{title}:\n ", - f"Max Reserved: {report.get("max_reserved", math.nan):.0f} MiB", - f"| Max Allocated: {report.get("max_allocated", math.nan):.0f} MiB".ljust(26), - f"| End Reserved: {report.get("reserved", math.nan):.0f} MiB".ljust(25), - f"| End Allocated: {report.get("allocated", math.nan):.0f} MiB".ljust(26), - f"| Duration: {report.get("duration", math.nan):.2f}".ljust(18), - f"| GPUs: {report["gpus"]:.0f}" if "gpus" in report else "", - ] - ) - - -@pytest.fixture(scope="function") -def report_subtest(request: pytest.FixtureRequest): - verbose = request.config.getoption("verbose") - do_capture = request.config.getoption("distributed_capture") - - def do_report_subtest(path: pathlib.Path, world_size: int) -> None: - success = check_subtest_success(path) - if not do_capture: - logger.warning("Distributed capture is disabled. See distributed test for run output.") - elif verbose > 1 or not success: - for rank in range(world_size): - for fd, file_ in (("stdout", sys.stdout), ("stderr", sys.stdout), ("traceback", sys.stderr)): - print(header(f"{fd} rank {rank}", 80), file=file_) - file_path = path / f"pytest_{fd}_{rank}" - try: - print(file_path.read_text(), file=file_) - except OSError: - print(f"<<< not found {file_path}>>>", file=file_) - else: - print("Set verbose > 1 to show run output.") - - reports = {} - for rank in range(world_size): - try: - reports[f"rank_{rank}"] = json.load(path.joinpath(f"pytest_report_{rank}").open("r")) - except OSError: - reports[rank] = {} - keys = {key for report in reports.values() for key in report} - report = {key: max(report[key] for report in reports.values() if key in report) for key in keys} - report["gpus"] = world_size - reports["global"] = report - - print(header(f"Resource usage", 80), file=sys.stderr) - for name, report in reports.items(): - print(format_resource_report(name, report), file=sys.stderr) - setattr(request.node, "fast_llm_resource_report", report) - - if not success: - raise RuntimeError(f"test {path.name} failed") - - return do_report_subtest
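Note (illustrative, not part of the change set): tests/trainer/events_fake_consumer.py decodes a simple weight-broadcast protocol: for each tensor, a metadata tuple `(shard_name, layer_name, size, dtype)` is broadcast first, then the tensor itself, and a `None` metadata object terminates the stream. A minimal sketch of the sending side, inferred from that receive loop, is shown below; `broadcast_weights` and `named_tensors` are illustrative names, not the actual Fast-LLM API.

# Sketch of the sender side of the weight-broadcast protocol (assumptions noted above).
import torch.distributed


def broadcast_weights(named_tensors, group, source_rank) -> None:
    # `named_tensors` is an iterable of (shard_name, layer_name, tensor) available on the source rank.
    for shard_name, layer_name, tensor in named_tensors:
        meta = (shard_name, layer_name, tuple(tensor.size()), tensor.dtype)
        torch.distributed.broadcast_object_list([meta], group=group, group_src=source_rank)
        torch.distributed.broadcast(tensor.cuda(), group=group, group_src=source_rank)
    # Sentinel telling consumers that the broadcast for this step is finished.
    torch.distributed.broadcast_object_list([None], group=group, group_src=source_rank)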