diff --git a/python/src/dapt/train.py b/python/src/dapt/train.py index c6d2dec..e937c7f 100644 --- a/python/src/dapt/train.py +++ b/python/src/dapt/train.py @@ -98,7 +98,7 @@ def train(config: DAPTConfig) -> None: num_train_epochs=config.training.num_train_epochs, per_device_train_batch_size=config.training.per_device_train_batch_size, gradient_accumulation_steps=config.training.gradient_accumulation_steps, - warmup_ratio=config.training.warmup_ratio, + warmup_steps=int(config.training.warmup_ratio * (len(split["train"]) // (config.training.per_device_train_batch_size * config.training.gradient_accumulation_steps))), weight_decay=config.training.weight_decay, bf16=config.training.bf16, gradient_checkpointing=config.training.gradient_checkpointing,