Draft
Changes from all commits
87 commits
1a18929
Dataset interface
jlamypoirier Oct 15, 2025
fd63846
misc
jlamypoirier Oct 15, 2025
2486caf
fix
jlamypoirier Oct 15, 2025
92e93e8
Language model sample
jlamypoirier Oct 16, 2025
d6f6944
fix
jlamypoirier Oct 16, 2025
5c802fa
fixes
jlamypoirier Oct 16, 2025
95d1840
test
jlamypoirier Oct 16, 2025
eafd9cb
fixes
jlamypoirier Oct 17, 2025
c56df69
cleanup
jlamypoirier Oct 17, 2025
7f437e1
misc
jlamypoirier Oct 17, 2025
dfd27f5
misc
jlamypoirier Oct 17, 2025
90cd009
Memmap dataset
jlamypoirier Oct 18, 2025
acfd30e
fixes
jlamypoirier Oct 29, 2025
34939e9
fixes
jlamypoirier Oct 29, 2025
c5fa072
int64
jlamypoirier Oct 29, 2025
cd28676
Test and fix preparator
jlamypoirier Nov 5, 2025
435d214
fix
jlamypoirier Nov 5, 2025
f6bef55
fix
jlamypoirier Nov 6, 2025
e05d9a1
fix
jlamypoirier Nov 6, 2025
9ba8d1b
fix
jlamypoirier Nov 6, 2025
b35b297
fixes
jlamypoirier Nov 6, 2025
abe2357
misc
jlamypoirier Nov 11, 2025
1801d87
fix
jlamypoirier Nov 11, 2025
2223b85
fix right stage mode
bigximik Nov 13, 2025
a9a4ace
newer transformers fixes
bigximik Nov 13, 2025
97f2b60
fix distributed tests skip on single gpu
bigximik Nov 13, 2025
0fdc978
set mamba 2 style model conversions to broken
bigximik Nov 13, 2025
665deb5
Merge branch 'jlp/dataset_interface' of github.com:ServiceNow/Fast-LL…
bigximik Nov 17, 2025
4d03889
Merge branch 'jlp/lm_sample' of github.com:ServiceNow/Fast-LLM into d…
bigximik Nov 17, 2025
224c2ec
mamba2 enable conversion tests
bigximik Nov 17, 2025
f1afbf2
Merge branch 'jlp/memmap_dataset' of github.com:ServiceNow/Fast-LLM i…
bigximik Nov 17, 2025
00bba27
added model_and_sequence_data_group
bigximik Nov 23, 2025
5b20276
added Iterable dataset base classes
bigximik Nov 23, 2025
978a68f
added naive sampled iterable dataset
bigximik Nov 23, 2025
066a0bf
added iterable dataset configs, streaming dataset and PipelineRL samp…
bigximik Nov 23, 2025
68b3d65
added distributed data loader wrapper
bigximik Nov 23, 2025
2fbfe99
added iterable dataset to gpt data
bigximik Nov 23, 2025
0892523
appended comment
bigximik Nov 23, 2025
54fadb4
changed base classes for iterable dataset configs
bigximik Nov 24, 2025
4e11bf3
fix batch type
bigximik Nov 24, 2025
8428df8
fix added name property to the class
bigximik Nov 24, 2025
04ee4d7
add eof for tests
bigximik Nov 24, 2025
1217998
change base class to torch iterable
bigximik Nov 24, 2025
c542dac
added streaming dataset, sampling and base data tests
bigximik Nov 24, 2025
3999a8e
merge from main
bigximik Nov 24, 2025
c6ef780
merge from main
bigximik Nov 24, 2025
a1556f8
change import
bigximik Nov 24, 2025
63737b1
fix iterable sampler for spawn, add fake redis server to multi proces…
bigximik Nov 25, 2025
e843c8e
preparation for multi gpu tests
bigximik Nov 25, 2025
d5ce3f2
added multi gpu gptdata streaming test
bigximik Nov 26, 2025
c13c6df
added streaming dataset requirements
bigximik Nov 27, 2025
e6d8f49
added streaming dataset installation to tests
bigximik Nov 27, 2025
1e92dd4
removed checking for max samples
bigximik Nov 27, 2025
3ac4882
removed test eof, reduced timeout
bigximik Nov 27, 2025
46db991
changed tests to work without eof or max_samples_count
bigximik Nov 27, 2025
187055b
fix qwen2 converter to accept qkv biases properly
bigximik Nov 28, 2025
21833a0
fix import errors
rafapi Dec 4, 2025
2f5f848
changes to config
bigximik Dec 8, 2025
1e07bad
Merge branch 'denis/new_datasets' of github.com:ServiceNow/Fast-LLM i…
bigximik Dec 8, 2025
c8cb9fd
added tensor iterator
bigximik Dec 10, 2025
e367998
added trainer events
bigximik Dec 10, 2025
5230b74
update test for changed config
bigximik Dec 10, 2025
1a94de5
added 2 gpus trainer events test
bigximik Dec 10, 2025
6cfd445
fix for multiple gpus
bigximik Dec 10, 2025
333665d
updated test to multiple gpus
bigximik Dec 10, 2025
5d1f474
added not implemented for pp streaming
bigximik Dec 12, 2025
5f7cb29
removed PipelineRL sample and batch
bigximik Dec 12, 2025
d07a900
base redis and streaming dataset config class refactoring
bigximik Dec 12, 2025
3a7ba92
refactoring of redis config, trainer event config, corresponding tests
bigximik Dec 12, 2025
59f6f7d
removed eof message which is not supported
bigximik Dec 12, 2025
2c20ebd
added implementation for initial_weights_step_message_type event
bigximik Dec 12, 2025
f4107c3
removed explicit msg ack
bigximik Dec 16, 2025
c32ef89
fix of training finished event
bigximik Dec 16, 2025
f637649
alternative streaming implementations: one stream and n streams witho…
bigximik Dec 16, 2025
e43ce95
Merge remote-tracking branch 'origin/main' into denis/new_datasets
jlamypoirier Dec 16, 2025
5545598
merge from main
bigximik Dec 16, 2025
0d198ff
fix after merge added preprocessing empty configs
bigximik Dec 16, 2025
70ef5c4
fix for tests with no import
bigximik Dec 16, 2025
058c93c
fixes
jlamypoirier Dec 16, 2025
d34d39a
Merge remote-tracking branch 'origin/denis/new_datasets' into denis/n…
jlamypoirier Dec 16, 2025
ffb0a5f
Merge remote-tracking branch 'origin/main' into denis/new_datasets
jlamypoirier Dec 16, 2025
359231f
removed cloudpickle
bigximik Dec 16, 2025
ca9e94e
Simplify distributed
jlamypoirier Dec 16, 2025
9f0704c
Simplified pipeline RL
jlamypoirier Dec 17, 2025
992f447
stuff
jlamypoirier Dec 17, 2025
4144317
misc
jlamypoirier Dec 22, 2025
6cf1e70
Merge remote-tracking branch 'origin/main' into jlp_pipeline_rl
jlamypoirier Dec 22, 2025
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -32,7 +32,7 @@ jobs:
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
-              pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]"
+              pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV,DOCS]"
- name: Run tests
run: pytest -v -ra .

2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
@@ -34,7 +34,7 @@ jobs:
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
-              pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]"
+              pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]"
- name: Build the documentation
run: mkdocs build

2 changes: 1 addition & 1 deletion Dockerfile
@@ -39,7 +39,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

# Install dependencies within the virtual environment.
-RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.5.1
+RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV]" triton==3.5.1

# Copy the remaining source code with universal write permissions.
COPY --chmod=777 ./Megatron-LM Megatron-LM
72 changes: 72 additions & 0 deletions fast_llm/data/data/data_loader.py
@@ -0,0 +1,72 @@
import itertools
import typing

import torch.utils.data

from fast_llm.core.distributed import broadcast_object


class SampledDatasetIterator(torch.utils.data.Sampler):
"""
A distributed sampler generating the indices (i.e., the natural numbers) for a `SampledDataset`.
To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`.
"""

def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel):
super().__init__()
self._total_samples = total_samples
self._begin_index = begin_index
self._batch_size = micro_batch_size * data_parallel
self._start_idx = data_rank * micro_batch_size
self._end_idx = (data_rank + 1) * micro_batch_size

def __len__(self) -> int:
return self._total_samples

def __iter__(self) -> typing.Iterator[list[int]]:
for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size):
yield list(range(idx + self._start_idx, idx + self._end_idx))


class DistributedDataLoaderWrapper:
"""
Wraps a regular dataloader so that only the process group leader
loads data, and then broadcasts the batch to other ranks in the group.
"""

def __init__(
self,
data_loader: torch.utils.data.dataloader.DataLoader,
process_group: torch.distributed.ProcessGroup | None,
):
self._data_loader = data_loader
self._rank = 0 if process_group is None else process_group.rank()
self._process_group = process_group

def __iter__(self):
if self._rank == 0:
self._iterator = iter(self._data_loader)
else:
self._iterator = itertools.repeat(None)
if self._process_group is None:
return self._iterator
return self

def __next__(self):
# TODO:
# Instead of broadcasting a general object, make this iterator yield an actual Batch class.
# Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can
# efficiently broadcast tensors directly. This avoids using `broadcast_object` on the
# entire Batch object, which is inefficient for tensors because it serializes
# (pickles) them before sending.

try:
data = next(self._iterator) # may raise StopIteration
except Exception as e:
data = e
data = broadcast_object(data, self._process_group, 0)

if isinstance(data, Exception):
raise data

return data
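For a sense of what `SampledDatasetIterator` yields, here is a minimal pure-Python rerun of its index arithmetic (the toy numbers below are illustrative, not defaults): with 8 total samples, `micro_batch_size=2` and `data_parallel=2`, each step consumes a window of 4 consecutive indices, split between the two data ranks.

```python
# Sketch of SampledDatasetIterator's index arithmetic, without torch.
def iter_indices(total_samples, begin_index, micro_batch_size, data_rank, data_parallel):
    batch_size = micro_batch_size * data_parallel
    start, end = data_rank * micro_batch_size, (data_rank + 1) * micro_batch_size
    for idx in range(begin_index, total_samples - batch_size + 1, batch_size):
        # Each rank takes its own contiguous slice of the global batch window.
        yield list(range(idx + start, idx + end))

print(list(iter_indices(8, 0, 2, 0, 2)))  # rank 0: [[0, 1], [4, 5]]
print(list(iter_indices(8, 0, 2, 1, 2)))  # rank 1: [[2, 3], [6, 7]]
```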
37 changes: 20 additions & 17 deletions fast_llm/data/data/gpt/data.py
@@ -8,12 +8,12 @@

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
+from fast_llm.data.data.data_loader import DistributedDataLoaderWrapper, SampledDatasetIterator
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.dataset.gpt.config import GPTSamplingData
from fast_llm.data.dataset.monitor import DatasetMonitor
-from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.engine.config_utils.run import log_main_rank
@@ -116,20 +116,23 @@ def get_iterator(
Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")

-        return iter(
-            torch.utils.data.DataLoader(
-                self._datasets[dataset_name],  # noqa
-                batch_sampler=SampledDatasetIterator(
-                    total_samples=len(self._datasets[dataset_name]),
-                    begin_index=consumed_samples,
-                    micro_batch_size=batch_config.micro_batch_size,
-                    data_rank=self._distributed.config.batch_data_rank,
-                    data_parallel=self._distributed.config.batch_data_parallel,
-                ),
-                num_workers=num_workers,
-                prefetch_factor=prefetch_factor,
-                pin_memory=True,
-                collate_fn=LanguageModelBatch.from_samples,
-                multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
-            )
-        )
+        data_loader = torch.utils.data.DataLoader(
+            self._datasets[dataset_name],  # noqa
+            batch_sampler=SampledDatasetIterator(
+                total_samples=len(self._datasets[dataset_name]),
+                begin_index=consumed_samples,
+                micro_batch_size=batch_config.micro_batch_size,
+                data_rank=self._distributed.config.batch_data_rank,
+                data_parallel=self._distributed.config.batch_data_parallel,
+            ),
+            num_workers=num_workers,
+            prefetch_factor=prefetch_factor,
+            pin_memory=True,
+            collate_fn=LanguageModelBatch.from_samples,
+            multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
+        )
+
+        if self._datasets[dataset_name].requires_broadcast:
+            data_loader = DistributedDataLoaderWrapper(data_loader, self.distributed.model_and_sequence_data_group)
+
+        return iter(data_loader)
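The TODO in `DistributedDataLoaderWrapper.__next__` notes that `broadcast_object` pickles the whole batch, tensors included. A rough sketch of the proposed alternative, broadcasting a flat tensor state dict with plain `torch.distributed` collectives — the helper name and the dict layout are assumptions, not part of this PR:

```python
import torch
import torch.distributed as dist

def broadcast_tensor_state(
    state: dict[str, torch.Tensor] | None,  # populated on src, None elsewhere
    src: int = 0,
    group: dist.ProcessGroup | None = None,
) -> dict[str, torch.Tensor]:
    # One small object broadcast for metadata (names, shapes, dtypes), then a
    # direct broadcast per tensor, so the bulk payload is never pickled.
    meta = [[(k, v.shape, v.dtype) for k, v in state.items()] if state is not None else None]
    dist.broadcast_object_list(meta, src=src, group=group)
    out = {}
    for name, shape, dtype in meta[0]:
        tensor = state[name] if state is not None else torch.empty(shape, dtype=dtype)
        dist.broadcast(tensor, src=src, group=group)
        out[name] = tensor
    return out
```

The `get_state_dict`/`from_state_dict` methods the TODO proposes for the Batch class would then just convert to and from such a flat dict.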
20 changes: 20 additions & 0 deletions fast_llm/data/dataset/abstract.py
@@ -5,6 +5,7 @@

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.config import SamplingData
from fast_llm.data.dataset.sampled import SampledIterableDataset


class Dataset[SampleType: Sample](abc.ABC):
@@ -27,6 +28,14 @@ def __getstate__(self):
del state["__orig_class__"]
return state

@property
def requires_broadcast(self) -> bool:
"""
Some dataset schemes load data only on the batch-data-parallel group leader,
then broadcast it to the other devices in the group.
"""
return False


class SampledDataset[SampleType: Sample](Dataset[SampleType]):
"""
@@ -48,3 +57,14 @@ class SamplableDataset[SampleType: Sample](Dataset[SampleType]):
@abc.abstractmethod
def sample(self, config: "SamplingData") -> SampledDataset[SampleType]:
pass


class SamplableIterableDataset[SampleType: Sample](SamplableDataset[SampleType]):
@abc.abstractmethod
def __iter__(self) -> typing.Iterator[SampleType]:
pass

def sample(self, config: "SamplingData") -> "SampledIterableDataset[SampleType]":
from fast_llm.data.dataset.sampled import SampledIterableDataset

return SampledIterableDataset(self, config)
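The contract introduced here is deliberately small: a `SamplableIterableDataset` only has to yield documents, and the default `sample()` delegates packing and length bookkeeping to `SampledIterableDataset`. A toy subclass, purely illustrative (the class name and in-memory source are hypothetical):

```python
import typing

from fast_llm.data.dataset.abstract import SamplableIterableDataset
from fast_llm.data.sample.language_model import LanguageModelSample


class InMemoryDocumentStream(SamplableIterableDataset[LanguageModelSample]):
    """Toy source that replays a fixed list of documents."""

    def __init__(self, documents: list[LanguageModelSample]):
        self._documents = documents

    @property
    def name(self) -> str:
        return "in_memory_stream"

    def __iter__(self) -> typing.Iterator[LanguageModelSample]:
        yield from self._documents

# `dataset.sample(sampling_data)` then returns a SampledIterableDataset that
# packs the yielded documents into fixed-length training samples on the fly.
```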
44 changes: 44 additions & 0 deletions fast_llm/data/dataset/config.py
@@ -15,6 +15,7 @@

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.engine.distributed.distributed import Distributed

logger = logging.getLogger(__name__)
@@ -298,3 +299,46 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp
return LegacyMemmapDataset[SampleType](name, self.path, preprocessing)
else:
raise FileNotFoundError(self.path)


@config_class()
class RedisConfig(Config):
# TODO: Move elsewhere? (Also used in trainer) Get it from the trainer in sampling config?
host: str = Field(
default="localhost",
desc="Hostname or IP address of the Redis server.",
hint=FieldHint.core,
)

port: int = Field(
default=6379,
desc="Port number on which the Redis server is running.",
hint=FieldHint.core,
)

def get_client(self):
import redis

return redis.Redis(self.host, self.port)


@config_class(dynamic_type={SampledDatasetConfig: "streaming"})
class StreamingDatasetConfig[SampleType: LanguageModelSample](RedisConfig, SamplableDatasetConfig[SampleType]):
"""
Configuration for a streaming dataset that reads training data from a Redis stream.
"""

_abstract = False

acknowledge_interval: int = Field(
default=10,
desc="Number of messages after which the consumer acknowledges received IDs back to the Redis hash.",
hint=FieldHint.core,
)

def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]:
from fast_llm.data.dataset.streaming import RedisStreamingDataset

return RedisStreamingDataset[StreamingDatasetConfig, SampleType](self, sampling.distributed.config).sample(
sampling
)
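`StreamingDatasetConfig` only describes the consumer side; something upstream still has to push samples into Redis. A minimal producer sketch using the stock redis-py client — the stream key and field name are illustrative, since the actual keys are defined by `RedisStreamingDataset`, which is not shown in this hunk:

```python
import redis

# Matches the RedisConfig defaults: localhost:6379.
client = redis.Redis("localhost", 6379)

# Push one serialized training sample onto a stream. Consumers read entries
# with XREAD/XREADGROUP and acknowledge received IDs back every
# `acknowledge_interval` messages.
client.xadd("fast_llm:samples", {"payload": b"<serialized sample bytes>"})
```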
8 changes: 8 additions & 0 deletions fast_llm/data/dataset/monitor.py
@@ -51,3 +51,11 @@ def __getitem__(self, index: int) -> SampleType:
@property
def name(self) -> str:
return self._dataset.name

@property
def requires_broadcast(self) -> bool:
"""
Some dataset schemes load data only on the batch-data-parallel group leader,
then broadcast it to the other devices in the group.
"""
return self._dataset.requires_broadcast
64 changes: 63 additions & 1 deletion fast_llm/data/dataset/sampled.py
@@ -8,7 +8,7 @@
import torch
import yaml

-from fast_llm.data.dataset.abstract import SampledDataset
+from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledDataset
from fast_llm.data.dataset.config import SamplingData, ShufflingType
from fast_llm.data.dataset.indexed import IndexedDataset
from fast_llm.data.sample.abstract import Sample
@@ -111,6 +111,10 @@ def __init__(
# No barrier yet to allow running in parallel.
# There needs to be one before calling `__getitem__`, normally handled through `Data`.

@property
def requires_broadcast(self) -> bool:
return self._indexed_dataset.requires_broadcast

def _sample(self) -> None:
"""
Create a `SampledDataset` with the requested parameters.
@@ -429,3 +433,61 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:

self._unshuffled_tokens = data["unshuffled_tokens"]
self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch


class SampledIterableDataset[SampleType: Sample](SampledDataset[SampleType]):
def __init__(
self,
dataset: SamplableIterableDataset[SampleType],
sampling: SamplingData,
):
self._dataset = dataset
self._config = sampling.config
self._parameters = sampling.parameters
self._documents: list[SampleType] = []
self._current_length = 0
self._sample_length = self._parameters.sequence_length + self._parameters.extra_tokens
# Delay iterator creation to avoid pickling issues.
self._iterator: typing.Iterator[SampleType] | None = None

@property
def requires_broadcast(self) -> bool:
# TODO: ====== fix ======
# return self._iterator.requires_broadcast
return True

def __getitem__(self, index: int) -> SampleType:
if self._iterator is None:
self._iterator = iter(self._dataset)
while self._current_length < self._sample_length:
document = next(self._iterator)
if len(document) > self._sample_length:
logging.warning(f"Dropping document with length {len(document)} > {self._sample_length}.")
continue
self._documents.append(document)
self._current_length += len(document)

if self._current_length == self._sample_length:
documents = self._documents
self._documents = []
self._current_length = 0
else:
last_length = len(self._documents[-1])
remaining_length = last_length - (self._current_length - self._sample_length)
if self._parameters.truncate_documents:
documents = self._documents[:-1] + [self._documents[-1].crop(0, remaining_length)]
self._documents = [self._documents[-1].crop(remaining_length, last_length)]
else:
documents = self._documents[:-1] + [self._documents[0].get_padding(remaining_length)]
self._documents = [self._documents[-1]]
self._current_length = len(self._documents[0])
sample = documents[0].from_documents(documents)
Assert.eq(len(sample), self._sample_length)
return sample

def __len__(self) -> int:
return self._parameters.num_samples

@property
def name(self) -> str:
return self._dataset.name
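A worked example of the packing bookkeeping in `SampledIterableDataset.__getitem__`, in truncating mode: with `sample_length = 8` and incoming documents of lengths 5 and 5, the buffer overshoots by 2, so `remaining_length = 5 - 2 = 3`; the second document contributes its first 3 tokens and carries its last 2 into the next sample. A standalone rerun with plain lists standing in for samples (the real class keeps `self._documents` and the iterator as state across calls):

```python
def pack_one_sample(doc_iter, carry, sample_length):
    # `carry` holds the tail of the previously cropped document.
    buffer, length = list(carry), sum(len(d) for d in carry)
    while length < sample_length:
        document = next(doc_iter)
        if len(document) > sample_length:
            continue  # oversized documents are dropped, with a warning
        buffer.append(document)
        length += len(document)
    if length == sample_length:
        return [t for d in buffer for t in d], []
    remaining = len(buffer[-1]) - (length - sample_length)
    sample = [t for d in buffer[:-1] for t in d] + buffer[-1][:remaining]
    return sample, [buffer[-1][remaining:]]  # crop, keep the tail for later

docs = iter([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
sample, carry = pack_one_sample(docs, [], 8)
assert sample == [1, 2, 3, 4, 5, 6, 7, 8] and carry == [[9, 10]]
```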