Callback for Hugging Face Trainer to automatically log metrics and checkpoints to TrueFoundry
This example implements a `TrainerCallback` that automatically logs metrics and checkpoints to TrueFoundry.
As an example, we train a small language model on a small dataset, but you can use this callback with any model. We also log the final model at the end of training.
Note: set `logging_steps` and `save_steps` in `TrainingArguments` carefully. `logging_steps` controls how often metrics are logged, and `save_steps` controls how often checkpoints are uploaded.
import logging
import math
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional
import numpy as np
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainerCallback,
TrainingArguments,
)
from transformers.integrations import rewrite_logs
from truefoundry import ml
if TYPE_CHECKING:
    # Only needed for the "MlFoundryRun" string annotation; this import is
    # skipped at runtime because TYPE_CHECKING is False outside type checkers.
    from truefoundry.ml import MlFoundryRun

# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
class TrueFoundryMLCallback(TrainerCallback):
    """Hugging Face ``TrainerCallback`` that mirrors Trainer logs and checkpoints to a TrueFoundry run.

    * ``on_log`` forwards finite numeric log values as metrics via ``run.log_metrics``.
    * ``on_save`` uploads the just-saved ``<output_dir>/checkpoint-<global_step>``
      directory as a new version of an artifact via ``run.log_artifact``.

    Both hooks return early unless ``state.is_world_process_zero`` is true, so in
    multi-process training only one process talks to TrueFoundry.

    Args:
        run: The TrueFoundry run to log into.
        log_checkpoints: Whether ``on_save`` should upload checkpoints.
        checkpoint_artifact_name: Artifact name under which checkpoints are
            logged. Required when ``log_checkpoints`` is True.

    Raises:
        ValueError: If ``log_checkpoints`` is True but ``checkpoint_artifact_name``
            is empty or None.
    """

    # Python and NumPy scalar types accepted as numeric metric values.
    _NUMERIC_TYPES = (int, float, np.integer, np.floating)

    def __init__(
        self,
        run: "MlFoundryRun",
        log_checkpoints: bool = True,
        checkpoint_artifact_name: Optional[str] = None,
    ):
        self._run = run
        self._log_checkpoints = log_checkpoints
        if self._log_checkpoints and not checkpoint_artifact_name:
            raise ValueError(
                "`checkpoint_artifact_name` is required when `log_checkpoints` is True"
            )
        self._checkpoint_artifact_name = checkpoint_artifact_name

    @staticmethod
    def _is_non_finite_number(value: Any) -> bool:
        """Return True if *value* is a float-like scalar that is NaN or +/-inf."""
        if isinstance(value, (int, np.integer)):
            # Integers (bool included) are always finite. Skipping math.isfinite
            # also avoids an OverflowError for ints too large to convert to float.
            return False
        if isinstance(value, (float, np.floating)):
            return not math.isfinite(value)
        return False

    @classmethod
    def _is_finite_number(cls, value: Any) -> bool:
        """Return True if *value* is a Python/NumPy int or float with a finite value."""
        return isinstance(value, cls._NUMERIC_TYPES) and not cls._is_non_finite_number(
            value
        )

    def _drop_non_finite_values(self, dct: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *dct* without NaN/inf numeric values.

        Non-numeric values (strings, lists, ...) are kept unchanged; only
        non-finite numbers are dropped, each with a warning.
        """
        sanitized = {}
        for k, v in dct.items():
            if self._is_non_finite_number(v):
                logger.warning(f"Dropping non-finite value for key={k} value={v!r}")
                continue
            sanitized[k] = v
        return sanitized

    # noinspection PyMethodOverriding
    def on_log(self, args, state, control, logs=None, model=None, **kwargs):
        """Log the finite numeric entries of *logs* as metrics at ``state.global_step``.

        Entries that are not finite ints/floats are dropped with a warning. Keys
        are passed through ``transformers.integrations.rewrite_logs`` before
        logging. Nothing is sent when no entry survives the filtering.
        """
        if not state.is_world_process_zero:
            return
        metrics = {}
        for k, v in (logs or {}).items():
            if self._is_finite_number(v):
                metrics[k] = v
            else:
                logger.warning(
                    f'Trainer is attempting to log a value of "{v}" of'
                    f' type {type(v)} for key "{k}" as a metric.'
                    " Mlfoundry's log_metric() only accepts finite float and"
                    " int types so we dropped this attribute."
                )
        if not metrics:
            return
        self._run.log_metrics(rewrite_logs(metrics), step=state.global_step)

    def on_save(self, args, state, control, **kwargs):
        """Upload ``<args.output_dir>/checkpoint-<global_step>`` as a checkpoint artifact version.

        Runs only on world process zero and only when checkpoint logging is
        enabled. The artifact metadata is the merged contents of every
        ``state.log_history`` entry recorded at the current step, with
        non-finite numbers removed. If the checkpoint directory does not exist,
        a warning is logged and nothing is uploaded.
        """
        if not state.is_world_process_zero:
            return
        if not self._log_checkpoints or not self._checkpoint_artifact_name:
            return

        ckpt_dir = f"checkpoint-{state.global_step}"
        artifact_path = os.path.join(args.output_dir, ckpt_dir)
        if not os.path.isdir(artifact_path):
            # The Trainer may save checkpoints elsewhere (for example a per-trial
            # directory during hyperparameter search); don't crash training.
            logger.warning(
                f"Checkpoint directory {artifact_path} not found; skipping upload."
            )
            return

        description = None
        job_name = os.getenv("TFY_INTERNAL_COMPONENT_NAME")
        job_run_name = os.getenv("TFY_INTERNAL_JOB_RUN_NAME")
        if job_name:
            description = f"Checkpoint from job={job_name}"
            # Omit the run part instead of rendering "run=None".
            if job_run_name:
                description += f" run={job_run_name}"

        logger.info(f"Uploading checkpoint {ckpt_dir} ...")
        metadata: Dict[str, Any] = {}
        for entry in state.log_history:
            if isinstance(entry, dict) and entry.get("step") == state.global_step:
                # Several entries can share a step (e.g. a training log and an
                # evaluation log). Merge them rather than keeping only the last.
                metadata.update(entry)
        metadata = self._drop_non_finite_values(metadata)
        self._run.log_artifact(
            name=self._checkpoint_artifact_name,
            artifact_paths=[(artifact_path,)],
            metadata=metadata,
            step=state.global_step,
            description=description,
        )
def tokenize_fn(batch, tokenizer, max_length=128, ignore_index=-100):
    """Tokenize a batch of texts for causal-LM training.

    Args:
        batch: A batch from ``Dataset.map(..., batched=True)``. Its ``"text"``
            column (a list of strings) is tokenized.
        tokenizer: Tokenizer called as ``tokenizer(texts, truncation=True,
            max_length=..., padding="longest")``. It must return a mapping with
            ``"input_ids"`` and, ideally, ``"attention_mask"``.
        max_length: Truncation length in tokens.
        ignore_index: Label written at padded positions. ``-100`` is the default
            ``ignore_index`` of PyTorch's cross-entropy loss, so those positions
            do not contribute to the loss.

    Returns:
        The tokenizer output with an added ``"labels"`` entry. Each label row is
        a fresh copy of its ``input_ids`` row in which positions with
        ``attention_mask == 0`` (padding) are replaced by ``ignore_index``. If the
        tokenizer returns no attention mask, labels are plain copies of
        ``input_ids``.
    """
    out = tokenizer(
        batch["text"], truncation=True, max_length=max_length, padding="longest"
    )
    input_ids = out["input_ids"]
    attention_mask = out.get("attention_mask")
    if attention_mask is None:
        # Without a mask, padding can't be told apart from real tokens.
        out["labels"] = [list(ids) for ids in input_ids]
    else:
        # Mask padding via attention_mask rather than by comparing to
        # pad_token_id: here pad == eos, and real EOS tokens must stay trainable.
        out["labels"] = [
            [tok if keep else ignore_index for tok, keep in zip(ids, mask)]
            for ids, mask in zip(input_ids, attention_mask)
        ]
    return out
def main():
    """Fine-tune a tiny causal LM and log everything to a TrueFoundry run.

    Command-line arguments:
        --ml-repo: Name of the TrueFoundry ML repo to log into (required).
        --run-name: Optional run name; defaults to ``test-run-<timestamp>``.

    The run receives the training parameters, Trainer metrics and checkpoints
    (via ``TrueFoundryMLCallback``), and the final model. The run is ended
    explicitly after the model is logged.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Fine-tune a tiny causal LM and log it to TrueFoundry."
    )
    parser.add_argument(
        "--ml-repo", type=str, required=True, help="TrueFoundry ML repo name."
    )
    parser.add_argument(
        "--run-name", type=str, required=False, help="Optional run name."
    )
    args = parser.parse_args()

    dt = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    run_name = args.run_name or f"test-run-{dt}"
    client = ml.get_client()
    # auto_end=False: the run is ended explicitly with run.end() at the end of
    # this function. NOTE(review): if training raises, the run is not ended here.
    run = client.create_run(
        ml_repo=args.ml_repo,
        run_name=run_name,
        auto_end=False,
    )
    # Use the name actually assigned to the run when naming artifacts and models.
    run_name = run.name
    callback = TrueFoundryMLCallback(
        run=run,
        log_checkpoints=True,
        checkpoint_artifact_name=f"ckpt-{run_name}",
    )

    # 1. Create minimal dataset
    examples = [
        {"text": "Hello, how are you? I am fine."},
        {"text": "Translate: cat -> chat"},
        {"text": "Q: What is 2+2? A: 4"},
        {"text": "Say hello in Spanish: Hola"},
        {"text": "Who wrote 'Hamlet'? William Shakespeare."},
    ]
    ds = Dataset.from_list(examples)

    # 2. Load tokenizer and model
    model_id = "llamafactory/tiny-random-Llama-3"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Padding is needed for batching; reuse EOS as the pad token.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # 3. Tokenize as causal LM
    ds = ds.map(
        tokenize_fn,
        batched=True,
        remove_columns=["text"],
        fn_kwargs={"tokenizer": tokenizer},
    )

    # 4. Setup training arguments
    training_args = TrainingArguments(
        output_dir="./output",
        per_device_train_batch_size=2,
        num_train_epochs=15,
        logging_steps=5,  # how often metrics reach TrueFoundry (via on_log)
        save_steps=5,  # how often checkpoints are uploaded (via on_save)
        save_total_limit=2,
        learning_rate=5e-5,
        fp16=False,
        bf16=False,
        report_to="none",  # TrueFoundry logging is handled by the callback
    )
    run.log_params(
        {
            "model_id": model_id,
            "batch_size": training_args.per_device_train_batch_size,
            "num_train_epochs": training_args.num_train_epochs,
            "learning_rate": training_args.learning_rate,
            "fp16": training_args.fp16,
            "bf16": training_args.bf16,
            "report_to": training_args.report_to,
        }
    )

    # 5. Trainer
    # NOTE: newer transformers releases deprecate `tokenizer=` in favour of
    # `processing_class=`; `tokenizer=` is kept for compatibility with older ones.
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=ds,
        callbacks=[callback],
    )

    # 6. Train and save
    trainer.train()
    model_dir = "./model"
    trainer.save_model(model_dir)

    # 7. Log the final model from a single process only, matching the callback,
    # which also only logs on world process zero.
    if trainer.is_world_process_zero():
        run.log_model(
            name=f"model-{run_name}",
            model_file_or_folder=model_dir,
            framework=ml.TransformersFramework(
                library_name="transformers",  # type: ignore
                pipeline_tag="text-generation",
            ),
            metadata={
                "base_model": model_id,
            },
        )

    # The run was created with auto_end=False, so close it explicitly.
    run.end()


if __name__ == "__main__":
    main()
python train.py --ml-repo <ml-repo-name> --run-name <run-name>
# E.g. python train.py --ml-repo llm-finetuning --run-name test-hf-cb-gxerg
Click the run name to view the run details:
Run Details
Metrics
Checkpoints
Models
Was this page helpful?