kd.data.py.Mix
- class kauldron.data.py.Mix(datasets: list[base.PyGrainPipeline], *, _fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>, batch_size: int | None = None, seed: int | typing.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = _FakeRootCfg('cfg.seed'), transforms: tr_normalize.Transformations = <factory>, num_epochs: Optional[int] = None, batch_drop_remainder: bool = True, num_workers: int = 16, read_options: grain.ReadOptions | None = None, enable_profiling: bool = False, per_worker_buffer_size: int = 1, weights: None | list[float | int] = None, shuffle: bool = True)
Bases: kauldron.data.py.base.PyGrainPipeline
Create a dataset mixture.
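To make the idea of a weighted mixture concrete, the following is a minimal, library-independent sketch. It is not Kauldron's implementation and does not call any Kauldron or Grain API. It assumes that `weights` act as relative sampling proportions and that `weights=None` means all datasets are sampled equally; neither assumption is stated in this reference entry. The function name `mix_sketch` and its stopping rule (stop when any source runs out) are hypothetical choices made for illustration.

```python
import itertools
import random
from typing import Iterable, Iterator, Optional, Sequence, TypeVar

T = TypeVar("T")


def mix_sketch(
    datasets: Sequence[Iterable[T]],
    weights: Optional[Sequence[float]] = None,
    seed: int = 0,
) -> Iterator[T]:
    """Hypothetical helper: interleave `datasets` in proportion to `weights`."""
    if weights is None:
        # Assumption: no weights means every dataset is equally likely.
        weights = [1.0] * len(datasets)
    if len(weights) != len(datasets):
        raise ValueError("Expected exactly one weight per dataset.")

    rng = random.Random(seed)  # Seeded so the interleaving is reproducible.
    iterators = [iter(ds) for ds in datasets]
    indices = range(len(iterators))
    while True:
        # Pick which source to draw from, proportionally to its weight.
        (idx,) = rng.choices(indices, weights=weights)
        try:
            yield next(iterators[idx])
        except StopIteration:
            # Illustrative choice: end the mixture once any source is exhausted.
            return


# Usage: draw 400 examples from two toy sources with a 3:1 weighting.
sample = list(
    itertools.islice(mix_sketch([["a"] * 1000, ["b"] * 1000], weights=[3, 1]), 400)
)
```

In this sketch roughly three quarters of `sample` should come from the first source. How the real class handles exhausted datasets, shuffling, and per-process sharding may differ.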
- datasets: list[base.PyGrainPipeline]
- weights: None | list[float | int] = None
- shuffle: bool = True
- ds_for_current_process(
- rng: kauldron.random.random.PRNGKey,