import argparse
import logging
from datetime import datetime

from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments
from truefoundry import ml
from truefoundry.ml.integrations.huggingface.trainer_callback import (
    TrueFoundryMLCallback,
)

logger = logging.getLogger(__name__)


def tokenize_fn(batch, tokenizer):
    out = tokenizer(batch["text"], truncation=True, max_length=128, padding="longest")
    # For causal LM training the labels are the input ids themselves;
    # the model shifts them internally to predict the next token.
    out["labels"] = out["input_ids"].copy()
    return out


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ml-repo", type=str, required=True)
    parser.add_argument("--run-name", type=str, required=False)
    args = parser.parse_args()

    dt = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    run_name = args.run_name or f"test-run-{dt}"
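    # Create the TrueFoundry run that params, metrics, and artifacts will be logged to.
    # auto_end=False keeps the run open until we close it explicitly after logging the model.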
    client = ml.get_client()
    run = client.create_run(
        ml_repo=args.ml_repo,
        run_name=run_name,
        auto_end=False,
    )
    run_name = run.run_name
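    # The callback forwards Trainer logs/metrics to the run; with log_checkpoints=True,
    # each saved checkpoint is also uploaded as an artifact under checkpoint_artifact_name.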
    callback = TrueFoundryMLCallback(
        run=run,
        log_checkpoints=True,
        checkpoint_artifact_name=f"ckpt-{run_name}",
    )
    # 1. Create minimal dataset
    examples = [
        {"text": "Hello, how are you? I am fine."},
        {"text": "Translate: cat -> chat"},
        {"text": "Q: What is 2+2? A: 4"},
        {"text": "Say hello in Spanish: Hola"},
        {"text": "Who wrote 'Hamlet'? William Shakespeare."},
    ]
    ds = Dataset.from_list(examples)
    # 2. Load tokenizer and model
    # llamafactory/tiny-random-Llama-3 is a tiny randomly initialized checkpoint,
    # useful for fast end-to-end smoke tests rather than output quality.
    model_id = "llamafactory/tiny-random-Llama-3"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Llama tokenizers ship without a pad token; reuse EOS so inputs can be padded.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id)
    # 3. Tokenize as causal LM
    ds = ds.map(
        tokenize_fn,
        batched=True,
        remove_columns=["text"],
        fn_kwargs={"tokenizer": tokenizer},
    )
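    # With per_device_train_batch_size == len(examples), each optimizer step sees the
    # whole toy dataset. Evaluating and saving every 5 steps exercises the callback's
    # checkpoint uploads and satisfies load_best_model_at_end, which requires the
    # eval and save cadences to line up.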
    # 4. Set up training arguments
    training_args = TrainingArguments(
        output_dir="./output",
        per_device_train_batch_size=len(examples),
        num_train_epochs=15,
        logging_strategy="steps",
        logging_steps=5,
        eval_strategy="steps",
        eval_steps=5,
        save_strategy="steps",
        save_steps=5,
        save_total_limit=2,
        learning_rate=5e-5,
        fp16=False,
        bf16=False,
        report_to="none",
        load_best_model_at_end=True,
    )
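    # Record the key hyperparameters on the run so they show up alongside the metrics.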
    run.log_params(
        {
            "model_id": model_id,
            "batch_size": training_args.per_device_train_batch_size,
            "num_train_epochs": training_args.num_train_epochs,
            "learning_rate": training_args.learning_rate,
            "fp16": training_args.fp16,
            "bf16": training_args.bf16,
            "report_to": training_args.report_to,
        }
    )
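    # Reusing the training set as the eval set keeps this smoke test self-contained;
    # a real run would pass a held-out split as eval_dataset.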
    # 5. Trainer
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=ds,
        eval_dataset=ds,
        callbacks=[callback],
    )
    # 6. Train and save
    trainer.train()
    model_dir = "./model"
    trainer.save_model(model_dir)
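    # log_model uploads the saved folder as a versioned model artifact on the run,
    # along with framework metadata describing how it was produced.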
    # 7. Log the final model
    run.log_model(
        name=f"model-{run_name}",
        model_file_or_folder=model_dir,
        framework=ml.TransformersFramework(
            library_name="transformers",  # type: ignore
            pipeline_tag="text-generation",
        ),
        metadata={
            "base_model": model_id,
        },
        # Prefer the step of the best checkpoint (restored by load_best_model_at_end),
        # falling back to the last global step.
        step=trainer.state.best_global_step or trainer.state.global_step,
    )
    # create_run was called with auto_end=False, so close the run explicitly.
    run.end()


if __name__ == "__main__":
    main()
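# Example invocation (script filename assumed):
#   python train.py --ml-repo my-ml-repo --run-name demo-run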