Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
59 changes: 45 additions & 14 deletions 59 finetune_t0.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import torch

from pretrain_gpt import get_batch_pipe as get_batch_pipe_gpt
from megatron import get_args, get_tokenizer, print_rank_0, mpu
from megatron.data.gpt_dataset import build_dataset_group as build_dataset_group_gpt
from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
from megatron.enums import PositionEmbeddingType, AttnMaskType
from megatron.model import GPTModelPipe
Expand Down Expand Up @@ -48,6 +50,14 @@ def model_provider(pre_process=True, post_process=True):
return model


def fast_normalize(loss_mask: torch.Tensor):
    """
    Rescale each consecutive run of ones in `loss_mask` so that every run sums to 1.

    Example: [0,0,0,1,1,0,0,1,0,0,1,1,1] -> [0,0,0,0.5,0.5,0,0,1,0,0,1/3,1/3,1/3]

    Zeros are unaffected (0 divided by its run length is still 0).
    """
    _, run_ids, run_lengths = torch.unique_consecutive(
        loss_mask, return_inverse=True, return_counts=True
    )
    # Scatter each run's length back onto its member elements, then divide so
    # every element of a run of ones carries weight 1/run_length.
    per_element_lengths = run_lengths.gather(0, run_ids)
    return loss_mask / per_element_lengths

def get_batch_pipe(data):
"""
Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
Expand All @@ -57,6 +67,9 @@ def get_batch_pipe(data):
decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]]
"""
if 'text' in data:
return get_batch_pipe_gpt(data)

args = get_args()
tokenizer = get_tokenizer()

Expand Down Expand Up @@ -95,6 +108,10 @@ def get_batch_pipe(data):
segment_ids=segment_ids.long(),
)

if args.norm_target_loss:
loss_mask = loss_mask.view(-1)
loss_mask = fast_normalize(loss_mask)

if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")

Expand Down Expand Up @@ -142,20 +159,34 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
eval(f"args.{s}_weighted_split_splits"),
eval(f"args.{s}_weighted_split_names"))
for paths, weights, splits, name in data_groups:
d = build_dataset_group(
dataset_group_name=name,
paths=paths,
weights=weights,
splits=splits,
data_impl=args.data_impl,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length + 1,
pad_token=tokenizer.pad,
eos_token=tokenizer.eos,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
train_valid_test=s
)
if "merged-meg-ds_v3_pii" in paths[0]:
d = build_dataset_group_gpt(
dataset_group_name=name,
paths=paths,
weights=weights,
splits=splits,
data_impl=args.data_impl,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
train_valid_test=s
)
else:
d = build_dataset_group(
dataset_group_name=name,
paths=paths,
weights=weights,
splits=splits,
data_impl=args.data_impl,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length + 1,
pad_token=tokenizer.pad,
eos_token=tokenizer.eos,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
train_valid_test=s
)
eval(f"{s}_ds").append(d)
else:
raise NotImplementedError("No dataloading argument passed")
Expand Down
2 changes: 2 additions & 0 deletions 2 megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,8 @@ def __call__(self, parser, args, values, option_string=None):
help='Mask loss for the end of document tokens.')
group.add_argument('--loss-on-targets-only', action='store_true',
help='Mask loss on input sequence.')
group.add_argument('--norm-target-loss', action='store_true',
help='Normalize the loss per target. Used for multi-task finetuning with packing.')
group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
help='Some objectives require us to sample loss_mask. This might introduce bias towards '
'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
Expand Down
5 changes: 4 additions & 1 deletion 5 megatron/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

"""GPT-2 model."""

from functools import partial
import torch

from megatron import get_args
Expand Down Expand Up @@ -186,6 +185,10 @@ def CrossEntropy(output, labels):
else:
average_tokens_per_sample = sequence_length
expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
elif args.norm_target_loss and (loss_mask.dim() == 1):
expected_num_of_target_seqs = loss_mask.sum()
loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs
return loss
else:
expected_number_of_tokens = loss_mask.sum()

Expand Down
5 changes: 3 additions & 2 deletions 5 megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def pretrain(train_valid_test_dataset_provider,
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
print_rank_0('training ...')

iteration = 0
iteration = args.iteration
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
Expand All @@ -199,7 +199,8 @@ def pretrain(train_valid_test_dataset_provider,
iterator, model,
iteration, False, data_group_name=name)

if args.save and iteration != 0:
# Do not save if the iteration has not changed
if args.save and iteration != args.iteration:
save_checkpoint(iteration, model, optimizer, lr_scheduler)

if args.do_test:
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.