The following example demonstrates how to use a Hugging Face TrainerCallback to automatically log metrics and checkpoints to TrueFoundry during training. As an example, we train a small language model on a tiny dataset, but you can use this callback with any model trained through the Trainer API. We also log the final model at the end of training.
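
At a high level, the integration has three pieces: create a TrueFoundry run, wrap it in the callback defined in the script below, and register that callback with the Hugging Face Trainer. Here is a condensed sketch; model, train_dataset, and the ml repo / run names are placeholders, and TrueFoundryMLCallback is the class defined in train.py further down:

from transformers import Trainer, TrainingArguments
from truefoundry import ml

client = ml.get_client()
run = client.create_run(ml_repo="<ml-repo-name>", run_name="<run-name>", auto_end=False)

callback = TrueFoundryMLCallback(  # defined in train.py below
    run=run,
    log_checkpoints=True,
    checkpoint_artifact_name="ckpt-<run-name>",
)

trainer = Trainer(
    model=model,  # any Hugging Face model
    args=TrainingArguments(output_dir="./output", logging_steps=5, save_steps=5),
    train_dataset=train_dataset,  # your tokenized dataset
    callbacks=[callback],
)
trainer.train()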

Training Script with Hugging Face Trainer Callback

Please choose logging_steps and save_steps in TrainingArguments carefully:
  • logging_steps controls how often metrics are logged
  • save_steps controls how often checkpoints are saved and uploaded
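
As a rough guide, with the toy settings used in the script below (5 examples, per-device batch size 2, 15 epochs, a single process, and no gradient accumulation), you can estimate how many metric points and checkpoint uploads a run will produce:

import math

num_examples = 5           # size of the toy dataset in train.py
per_device_batch_size = 2  # per_device_train_batch_size in TrainingArguments
num_train_epochs = 15

steps_per_epoch = math.ceil(num_examples / per_device_batch_size)  # 3
total_steps = steps_per_epoch * num_train_epochs                   # 45

# With logging_steps=5 and save_steps=5, metrics and a checkpoint are produced
# every 5 steps, i.e. 9 times over the whole run.
print(total_steps // 5)
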
train.py
import logging
import math
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional

import numpy as np
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
from transformers.integrations import rewrite_logs
from truefoundry import ml

if TYPE_CHECKING:
    from truefoundry.ml import MlFoundryRun

logger = logging.getLogger(__name__)


class TrueFoundryMLCallback(TrainerCallback):
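    """
    Hugging Face TrainerCallback that logs training metrics to a TrueFoundry run and,
    optionally, uploads saved checkpoints as versioned artifacts (main process only).
    """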
    def __init__(
        self,
        run: "MlFoundryRun",
        log_checkpoints: bool = True,
        checkpoint_artifact_name: Optional[str] = None,
    ):
        self._run = run
        self._log_checkpoints = log_checkpoints
        if self._log_checkpoints and not checkpoint_artifact_name:
            raise ValueError(
                "`checkpoint_artifact_name` is required when `log_checkpoints` is True"
            )
        self._checkpoint_artifact_name = checkpoint_artifact_name

    def _drop_non_finite_values(self, dct: Dict[str, Any]) -> Dict[str, Any]:
        sanitized = {}
        for k, v in dct.items():
            if isinstance(v, (int, float, np.integer, np.floating)):
                if not math.isfinite(v):
                    logger.warning(f"Dropping non-finite value for key={k} value={v!r}")
                    continue
            sanitized[k] = v
        return sanitized

    # noinspection PyMethodOverriding
    def on_log(self, args, state, control, logs=None, model=None, **kwargs):
        logs = logs or {}
        if not state.is_world_process_zero:
            return

        metrics = {}
        for k, v in logs.items():
            if isinstance(v, (int, float, np.integer, np.floating)) and math.isfinite(
                v
            ):
                metrics[k] = v
            else:
                logger.warning(
                    f'Trainer is attempting to log a value of "{v}" of'
                    f' type {type(v)} for key "{k}" as a metric.'
                    " Mlfoundry's log_metric() only accepts finite float and"
                    " int types so we dropped this attribute."
                )
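        # rewrite_logs namespaces the keys, e.g. "loss" -> "train/loss" and "eval_loss" -> "eval/loss"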
        self._run.log_metrics(rewrite_logs(metrics), step=state.global_step)

    def on_save(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        if not self._log_checkpoints:
            return

        if not self._checkpoint_artifact_name:
            return

        ckpt_dir = f"checkpoint-{state.global_step}"
        artifact_path = os.path.join(args.output_dir, ckpt_dir)
        description = None
        _job_name = os.getenv("TFY_INTERNAL_COMPONENT_NAME")
        _job_run_name = os.getenv("TFY_INTERNAL_JOB_RUN_NAME")
        if _job_name:
            description = f"Checkpoint from job={_job_name} run={_job_run_name}"
        logger.info(f"Uploading checkpoint {ckpt_dir} ...")
        metadata = {}
        for log in state.log_history:
            if isinstance(log, dict) and log.get("step") == state.global_step:
                metadata = log.copy()
        metadata = self._drop_non_finite_values(metadata)
        self._run.log_artifact(
            name=self._checkpoint_artifact_name,
            artifact_paths=[(artifact_path,)],
            metadata=metadata,
            step=state.global_step,
            description=description,
        )


def tokenize_fn(batch, tokenizer):
    out = tokenizer(batch["text"], truncation=True, max_length=128, padding="longest")
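    # For causal LM training, labels are a copy of input_ids; the model shifts them internally when computing the loss.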
    out["labels"] = out["input_ids"].copy()
    return out


def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ml-repo", type=str, required=True)
    parser.add_argument("--run-name", type=str, required=False)
    args = parser.parse_args()

    dt = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    run_name = args.run_name or f"test-run-{dt}"

    client = ml.get_client()
    run = client.create_run(
        ml_repo=args.ml_repo,
        run_name=run_name,
        auto_end=False,
    )
    run_name = run.name
    callback = TrueFoundryMLCallback(
        run=run,
        log_checkpoints=True,
        checkpoint_artifact_name=f"ckpt-{run_name}",
    )

    # 1. Create minimal dataset
    examples = [
        {"text": "Hello, how are you? I am fine."},
        {"text": "Translate: cat -> chat"},
        {"text": "Q: What is 2+2? A: 4"},
        {"text": "Say hello in Spanish: Hola"},
        {"text": "Who wrote 'Hamlet'? William Shakespeare."},
    ]
    ds = Dataset.from_list(examples)

    # 2. Load tokenizer and model
    model_id = "llamafactory/tiny-random-Llama-3"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # 3. Tokenize as causal LM
    ds = ds.map(
        tokenize_fn,
        batched=True,
        remove_columns=["text"],
        fn_kwargs={"tokenizer": tokenizer},
    )

    # 4. Setup training arguments
    training_args = TrainingArguments(
        output_dir="./output",
        per_device_train_batch_size=2,
        num_train_epochs=15,
        logging_steps=5,
        save_steps=5,
        save_total_limit=2,
        learning_rate=5e-5,
        fp16=False,
        bf16=False,
        report_to="none",
    )

    run.log_params(
        {
            "model_id": model_id,
            "batch_size": training_args.per_device_train_batch_size,
            "num_train_epochs": training_args.num_train_epochs,
            "learning_rate": training_args.learning_rate,
            "fp16": training_args.fp16,
            "bf16": training_args.bf16,
            "report_to": training_args.report_to,
        }
    )

    # 5. Trainer
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=ds,
        callbacks=[callback],
    )

    # 6. Train and save
    trainer.train()
    model_dir = "./model"
    trainer.save_model(model_dir)

    # 7. Log the final model
    run.log_model(
        name=f"model-{run_name}",
        model_file_or_folder=model_dir,
        framework=ml.TransformersFramework(
            library_name="transformers",  # type: ignore
            pipeline_tag="text-generation",
        ),
        metadata={
            "base_model": model_id,
        },
    )


if __name__ == "__main__":
    main()
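
After the run finishes, the uploaded checkpoints and the final model can be pulled back with the TrueFoundry client. A minimal sketch, assuming the FQN-based lookup helpers of truefoundry.ml; the FQNs below are placeholders that you can copy from the artifact and model pages in the UI:

from truefoundry import ml

client = ml.get_client()

# Download a checkpoint version logged by the callback (copy the FQN from the UI)
ckpt_version = client.get_artifact_version_by_fqn("artifact:<org>/<ml-repo-name>/ckpt-<run-name>:1")
ckpt_dir = ckpt_version.download(path="./downloaded-checkpoint")

# Download the final model logged with run.log_model
model_version = client.get_model_version_by_fqn("model:<org>/<ml-repo-name>/model-<run-name>:1")
model_dir = model_version.download(path="./downloaded-model")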

Running the script

  1. Create an ML Repo if not already created
  2. Set up the TrueFoundry CLI (a quick install and login example is shown below)
  3. Run the script
python train.py --ml-repo <ml-repo-name> --run-name <run-name>
# E.g. python train.py  --ml-repo llm-finetuning --run-name test-hf-cb-gxerg
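
For step 2, installing and logging in with the TrueFoundry CLI typically looks like the following; the host URL is a placeholder for your TrueFoundry control plane:

pip install -U truefoundry
tfy login --host https://<your-org>.truefoundry.cloud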

Viewing the Run on UI

If you ran the script as a TrueFoundry Job, you can access the run from the Job Details column of the job run.

Click on the Run Name to view the run details.

Navigate to the ML Repo; you should see the run in the Runs tab. All metrics, checkpoints, and models logged by the script will be available in this run.

Screenshots of the run in the UI: Run Details, Metrics, Checkpoints, and Models.