Commit 9071816: Training improvements (allenai#239)

Co-authored-by: ananyahjha93 <[email protected]>
epwalsh and ananyahjha93 authored Aug 24, 2023
1 parent 642d0fa commit 9071816
Showing 29 changed files with 18,917 additions and 268 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -16,6 +16,7 @@ pip-wheel-metadata/
.venv/
.vscode/
/*.iml
pyrightconfig.json


# jupyter notebooks
4,434 changes: 4,434 additions & 0 deletions configs/v1-mix-medium-mcli.yaml

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions configs/v1-mix-medium.yaml
@@ -29,22 +29,23 @@ model:
eos_token_id: 0
pad_token_id: 1
init_device: meta
init_std: 0.02
init_fn: normal

compile: null # causes instability on AMD GPUs

optimizer:
name: lionw
learning_rate: 1.0e-4
weight_decay: 0.01
name: adamw
learning_rate: 3.0e-4
weight_decay: 0.1
betas:
- 0.9
- 0.95
metrics_log_interval: 10

scheduler:
name: cosine_with_warmup
t_warmup: 2000
# t_max: 47684 # 200B tokens, after which we decay linearly to 1/100th of initial LR
t_warmup: 5000
alpha_f: 0.1

data:
paths: ${path.glob:${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/books/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/c4/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/common-crawl/*/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/s2/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/stack/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/wiki/*.npy}
@@ -66,7 +67,7 @@ save_overwrite: false
save_interval: 5000
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: 1000000 # getting errors on LUMI right now
save_interval_unsharded: null # getting errors on LUMI right now
save_num_unsharded_checkpoints_to_keep: -1

load_path: null
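For reference, a `cosine_with_warmup` schedule like the one configured above conventionally ramps the learning rate up linearly over `t_warmup` steps and then follows a cosine curve down to `alpha_f` times the peak rate. The sketch below illustrates that math only; it is not this repository's scheduler code, and `t_max` (the decay horizon) is an assumed parameter here, since the config leaves it unset or commented out.

import math

def lr_at_step(step: int, base_lr: float, t_warmup: int, t_max: int, alpha_f: float) -> float:
    # Linear warmup from 0 up to base_lr over the first t_warmup steps.
    if step < t_warmup:
        return base_lr * step / t_warmup
    # Cosine decay from base_lr down to alpha_f * base_lr by step t_max.
    progress = min(1.0, (step - t_warmup) / max(1, t_max - t_warmup))
    final_lr = alpha_f * base_lr
    return final_lr + 0.5 * (base_lr - final_lr) * (1.0 + math.cos(math.pi * progress))

# With the values above (learning_rate=3.0e-4, t_warmup=5000, alpha_f=0.1), the rate
# peaks at 3.0e-4 after 5000 steps and decays toward 3.0e-5 over the rest of training.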
4,433 changes: 4,433 additions & 0 deletions configs/v1-mix-small-mcli.yaml

Large diffs are not rendered by default.

21 changes: 11 additions & 10 deletions configs/v1-mix-small.yaml
@@ -4,7 +4,8 @@ dry_run: false

wandb:
name: ${run_name}
project: c4-small
project: olmo-small
group: v1-mix

model:
d_model: 2048
@@ -28,22 +29,22 @@ model:
eos_token_id: 0
pad_token_id: 1
init_device: meta
init_std: 0.02
init_fn: mitchell

compile: null # causes instability on AMD GPUs

optimizer:
name: lionw
learning_rate: 2.0e-4
weight_decay: 0.01
name: adamw
learning_rate: 1.0e-3
weight_decay: 0.1
betas:
- 0.9
- 0.95

scheduler:
name: cosine_with_warmup
t_warmup: 2000
t_max: null
alpha_f: 0.1

data:
paths: ${path.glob:${oc.env:FLASH_DIR}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/books/*.npy,${oc.env:FLASH_DIR}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/c4/*.npy,${oc.env:FLASH_DIR}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/common-crawl/*/*.npy,${oc.env:FLASH_DIR}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/s2/*.npy,${oc.env:FLASH_DIR}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/stack/*.npy,${oc.env:FLASH_DIR}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/wiki/*.npy}
@@ -59,13 +60,13 @@ tokenizer:
identifier: allenai/eleuther-ai-gpt-neox-20b-pii-special
truncate_direction: right

save_folder: ${path.choose:${oc.env:SCRATCH_DIR,no_exist}/checkpoints,/results}/${oc.env:SLURM_JOB_ID,${run_name}}
save_folder: ${path.choose:${oc.env:FLASH_DIR,no_exist}/checkpoints,/results}/${oc.env:SLURM_JOB_ID,${run_name}}
save_overwrite: false
# Sharded checkpoints (best for restarts)
save_interval: 1000
save_num_checkpoints_to_keep: 9
save_interval: 5000
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: 10000
save_interval_unsharded: null
save_num_unsharded_checkpoints_to_keep: -1

load_path: null
2 changes: 2 additions & 0 deletions conftest.py
@@ -4,6 +4,7 @@

from olmo.config import (
DataConfig,
InitFnType,
ModelConfig,
OptimizerConfig,
PaddingDirection,
@@ -49,6 +50,7 @@ def model_config() -> ModelConfig:
n_heads=2,
n_layers=3,
max_sequence_length=512,
init_fn=InitFnType.normal,
)


60 changes: 52 additions & 8 deletions olmo/config.py
@@ -31,6 +31,7 @@
"BlockType",
"CompilerConfig",
"LayerNormType",
"InitFnType",
"ModelConfig",
"OptimizerType",
"OptimizerConfig",
@@ -180,6 +181,32 @@ class BlockType(StrEnum):
parallel = "parallel"


class InitFnType(StrEnum):
mitchell = "mitchell"
"""
The strategy suggested to us by Mitchell Wortsman from UW.
This uses a truncated normal distribution with an adaptive standard deviation that depends
on the size of the weights as well as the depth of the layer.
"""

normal = "normal"
"""
All weights are initialized from the same normal distribution.
"""

kaiming_normal = "kaiming_normal"
"""
All weights are initialized with the Kaiming method from a normal distribution.
Note this currently won't work with FSDP.
"""

fan_in = "fan_in"
"""
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
is the input dimensionality of the kernel.
"""


@dataclass
class ModelConfig(BaseConfig):
"""
@@ -282,6 +309,11 @@ class ModelConfig(BaseConfig):
models tend to have near 0 bias terms anyway.
"""

scale_logits: bool = False
"""
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
"""

vocab_size: int = 50257
"""
Vocabulary size of the model.
@@ -310,9 +342,15 @@ class ModelConfig(BaseConfig):
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
"""

init_fn: InitFnType = InitFnType.normal
"""
The weight initialization strategy.
"""

init_std: float = 0.02
"""
Standard deviation used when initializing parameters.
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
as "normal".
"""

precision: Optional[str] = None
@@ -324,7 +362,6 @@

class OptimizerType(StrEnum):
lionw = "lionw"
adam = "adam"
adamw = "adamw"


@@ -336,6 +373,12 @@ class OptimizerConfig(BaseConfig):
betas: Tuple[float, float] = (0.9, 0.95)
no_decay_norm_and_bias: bool = True
"""Do not apply weight decay to norms and biases."""
metrics_log_interval: Optional[int] = None
"""
The interval with which to collect and log optimizer-specific metrics.
This only applies when logging to W&B, since these metrics won't be logged to the console.
If not set, defaults to the wandb `log_interval`.
"""

def __post_init__(self):
self.betas = tuple(self.betas) # type: ignore[assignment]
@@ -344,6 +387,7 @@ def __post_init__(self):
class SchedulerType(StrEnum):
cosine_with_warmup = "cosine_with_warmup"
inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
max_scheduler = "max_scheduler"


@dataclass
@@ -440,7 +484,7 @@ class CompilerConfig(BaseConfig):
class FSDPConfig(BaseConfig):
use_orig_params: bool = True
"""
This must be ``True`` if using ``compile``.
This must be ``True`` if using ``compile`` or you want to track the parameter norm during training.
"""

sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
@@ -487,11 +531,6 @@ class TrainConfig(BaseConfig):
Learning rate scheduler configuration.
"""

restore_base_learning_rate: bool = True
"""
Set to ``False`` if you want to restart with the base learning rate from the config, not the checkpoint.
"""

data: DataConfig = field(default_factory=DataConfig)
"""
Training data configuration.
@@ -535,6 +574,11 @@
A folder in a cloud bucket to upload saved checkpoints to.
"""

canceled_check_interval: int = 5
"""
How often (in batches) to check if the run has been canceled or reached its time limit.
"""

save_interval: int = 1000
"""
How often (in terms of batches) to save training state checkpoints that can be used for restarts.
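As a usage sketch (assembled for illustration, not taken from this commit), the new config fields slot into programmatic config construction roughly as below, mirroring the conftest.py fixture earlier in the diff. Only the fields visible in this diff and in conftest.py are confirmed; `d_model` and `learning_rate` are assumed here from the YAML configs.

from olmo.config import InitFnType, ModelConfig, OptimizerConfig, OptimizerType

model_cfg = ModelConfig(
    d_model=128,                   # assumed field; appears in the YAML configs
    n_heads=2,
    n_layers=3,
    max_sequence_length=512,
    init_fn=InitFnType.mitchell,   # new: choose a weight-initialization strategy
    init_std=0.02,                 # used only by fixed-distribution init_fns such as "normal"
    scale_logits=True,             # new: scale output logits by 1 / sqrt(d_model)
)

optim_cfg = OptimizerConfig(
    name=OptimizerType.adamw,
    learning_rate=3.0e-4,          # assumed field; appears in the YAML configs
    betas=(0.9, 0.95),
    metrics_log_interval=10,       # new: how often to log optimizer-specific metrics (W&B only)
)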
16 changes: 12 additions & 4 deletions olmo/data/__init__.py
@@ -13,7 +13,9 @@
__all__ = ["MemMapDataset", "DataCollator", "IterableDataset", "build_eval_dataloader", "build_train_dataloader"]


def build_memmap_dataset(train_config: TrainConfig, data_config: DataConfig) -> MemMapDataset:
def build_memmap_dataset(
train_config: TrainConfig, data_config: DataConfig, include_instance_metadata: bool = True
) -> MemMapDataset:
paths: List[str]
metadata: List[Dict[str, Any]] = []
if data_config.paths:
@@ -30,7 +32,12 @@ def build_memmap_dataset(train_config: TrainConfig, data_config: DataConfig) ->
metadata.extend([{"label": label}] * len(label_paths))
else:
raise OlmoConfigurationError("One of DataConfig.paths or DataConfig.datasets is required")
return MemMapDataset(*paths, chunk_size=train_config.model.max_sequence_length, metadata=metadata)
return MemMapDataset(
*paths,
chunk_size=train_config.model.max_sequence_length,
metadata=metadata,
include_instance_metadata=include_instance_metadata,
)


def build_eval_dataloader(
@@ -39,7 +46,7 @@ def build_eval_dataloader(
batch_size: int,
shuffle: bool = True,
) -> DataLoader:
dataset = build_memmap_dataset(train_config, data_config)
dataset = build_memmap_dataset(train_config, data_config, include_instance_metadata=True)
collator = DataCollator(pad_direction=data_config.pad_direction, pad_token_id=train_config.model.pad_token_id)
if data_config.drop_last:
# Make sure batch size is small enough.
@@ -72,7 +79,7 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
collator = DataCollator(
pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id
)
dataset = build_memmap_dataset(train_config, train_config.data)
dataset = build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False)
work_dir = Path(train_config.save_folder) / "train_data"
if get_global_rank() == 0:
if work_dir.is_dir() and not train_config.save_overwrite:
@@ -90,6 +97,7 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
drop_last=train_config.data.drop_last,
max_examples=train_config.global_train_batch_size * train_config.max_duration,
work_dir=work_dir,
global_batch_size=train_config.global_train_batch_size,
),
batch_size=train_config.device_train_batch_size,
drop_last=train_config.data.drop_last,
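A brief usage note on the `include_instance_metadata` flag added above: the training dataloader now builds its dataset without per-instance metadata, while the evaluation dataloader keeps it. A hedged sketch of the two call patterns, assuming a `TrainConfig` instance `cfg` built elsewhere and an evaluation `DataConfig` named `eval_data_cfg` (both hypothetical names):

from olmo.data import build_memmap_dataset

# Training path: metadata skipped, matching build_train_dataloader above.
train_dataset = build_memmap_dataset(cfg, cfg.data, include_instance_metadata=False)

# Evaluation path: metadata (e.g. per-source labels) retained, matching build_eval_dataloader above.
eval_dataset = build_memmap_dataset(cfg, eval_data_cfg, include_instance_metadata=True)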
33 changes: 29 additions & 4 deletions olmo/data/iterable_dataset.py
@@ -8,7 +8,7 @@
import torch.utils.data

from ..aliases import PathOrStr
from ..util import barrier, get_global_rank, get_world_size
from ..util import barrier, get_fs_local_rank, get_global_rank, get_world_size

__all__ = ["IterableDataset"]

@@ -35,7 +35,9 @@ def __init__(
drop_last: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
fs_local_rank: Optional[int] = None,
work_dir: Optional[PathOrStr] = None,
global_batch_size: Optional[int] = None,
):
self.dataset = dataset
self.seed = seed
@@ -44,6 +46,7 @@ def __init__(
self.shuffle = shuffle
self.drop_last = drop_last
self.rank = rank if rank is not None else get_global_rank()
self.fs_local_rank = fs_local_rank if fs_local_rank is not None else get_fs_local_rank()
self.world_size = world_size if world_size is not None else get_world_size()
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
@@ -56,10 +59,14 @@ def __init__(
else:
num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type]
self.total_size = num_samples * self.world_size
self.device_batch_size: Optional[int] = None
if global_batch_size is not None:
assert global_batch_size % self.world_size == 0
self.device_batch_size = global_batch_size // self.world_size
self.global_indices_file: Optional[Path] = None
if work_dir is not None:
self.global_indices_file = Path(work_dir) / "global_indices.npy"
if self.rank == 0:
if self.fs_local_rank == 0:
log.info("Saving global data order indices...")
self.global_indices_file.parent.mkdir(parents=True, exist_ok=True)
global_indices = self._build_global_indices()
@@ -121,10 +128,28 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
# Lastly, slice the indices by data loader worker rank to avoid duplicates.
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
indices = indices[worker_info.id :: worker_info.num_workers]
# Note that each data loading worker gathers a whole batch at a time, and the workers
# are called round-robin by rank. So to slice these up in a way that preserves order, regardless
# of the number of workers, we should give worker 0 the first chunk of `device_batch_size` indices,
# worker 1 the 2nd chunk of `device_train_batch_size` indices, etc...
if self.device_batch_size is not None:
if not isinstance(indices, (np.memmap, np.ndarray)):
indices = np.array(indices, dtype=np.uint64)
truncated_size = self.device_batch_size * (len(indices) // self.device_batch_size)
left_overs = indices[
truncated_size + worker_info.id : truncated_size + worker_info.id + worker_info.num_workers
]
indices = (
indices[:truncated_size]
.reshape((-1, self.device_batch_size))[worker_info.id :: worker_info.num_workers] # type: ignore
.reshape((-1,))
)
indices = np.concatenate([indices, left_overs])
else:
indices = indices[worker_info.id :: worker_info.num_workers]

# Convert to a list at this point so we don't have to rely on memory-mapping.
if isinstance(indices, np.memmap):
if isinstance(indices, (np.memmap, np.ndarray)):
indices_list = indices.tolist() # type: ignore
else:
indices_list = indices
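To make the batch-preserving slicing above concrete, here is a small standalone example of the same reshaping; the variable names and numbers are made up for illustration and are not part of this commit. Because each DataLoader worker yields a whole device batch at a time and workers are drained round-robin, handing worker 0 the first `device_batch_size` indices, worker 1 the next chunk, and so on keeps the overall batch order identical regardless of the number of workers; the tail that does not fill a complete batch is handled separately, as in the `left_overs` slice above.

import numpy as np

indices = np.arange(22, dtype=np.uint64)  # hypothetical global data order for one rank
device_batch_size = 4
num_workers = 2

truncated_size = device_batch_size * (len(indices) // device_batch_size)  # 20
for worker_id in range(num_workers):
    per_worker = (
        indices[:truncated_size]
        .reshape((-1, device_batch_size))[worker_id::num_workers]
        .reshape((-1,))
    )
    print(worker_id, per_worker)

# worker 0 -> [ 0  1  2  3  8  9 10 11 16 17 18 19]
# worker 1 -> [ 4  5  6  7 12 13 14 15]
# Drained round-robin in batches of 4, these reproduce the original order 0..19.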