From e5f89ffabbe9fcfa7499d24fbaf94aeabc0c46a7 Mon Sep 17 00:00:00 2001 From: Joey Eamigh <55670930+JoeyEamigh@users.noreply.github.com> Date: Sun, 29 Mar 2026 21:17:50 -0400 Subject: [PATCH] caching in the pipelines --- python/src/dapt/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/src/dapt/train.py b/python/src/dapt/train.py index c6d2dec..e937c7f 100644 --- a/python/src/dapt/train.py +++ b/python/src/dapt/train.py @@ -98,7 +98,7 @@ def train(config: DAPTConfig) -> None: num_train_epochs=config.training.num_train_epochs, per_device_train_batch_size=config.training.per_device_train_batch_size, gradient_accumulation_steps=config.training.gradient_accumulation_steps, - warmup_ratio=config.training.warmup_ratio, + warmup_steps=int(config.training.warmup_ratio * (len(split["train"]) // (config.training.per_device_train_batch_size * config.training.gradient_accumulation_steps))), weight_decay=config.training.weight_decay, bf16=config.training.bf16, gradient_checkpointing=config.training.gradient_checkpointing,