"""Merge a LoRA adapter logged as a TrueFoundry artifact into its base
Hugging Face model, then log the merged model back to a TrueFoundry run."""
import argparse
import json
import math
import os
import re
import shutil
from typing import Any, Dict, Optional

import numpy as np
from huggingface_hub import scan_cache_dir
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from truefoundry.ml import TransformersFramework, get_client


def get_or_create_run(
ml_repo: str, run_name: str, auto_end: bool = False, create_ml_repo: bool = False
):
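    """Return the run named ``run_name`` in ``ml_repo``, creating it if missing."""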
client = get_client()
if create_ml_repo:
client.create_ml_repo(ml_repo=ml_repo)
try:
run = client.get_run_by_name(ml_repo=ml_repo, run_name=run_name)
except Exception as e:
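        # Only swallow "run not found" errors; re-raise anything else.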
if "RESOURCE_DOES_NOT_EXIST" not in str(e):
raise
run = client.create_run(ml_repo=ml_repo, run_name=run_name, auto_end=auto_end)
return run


def log_model_to_truefoundry(
run,
model_name: str,
model_dir: str,
hf_hub_model_id: str,
metadata: Optional[Dict[str, Any]] = None,
step: int = 0,
):
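    """Log ``model_dir`` as a model version on the given TrueFoundry run,
    bundling any custom ``.py`` files cached for ``hf_hub_model_id``."""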
metadata = metadata or {}
print("Uploading Model...")
    # Models that ship custom code (``trust_remote_code``) keep their ``.py``
    # files in the HF cache; collect them so the logged model stays loadable.
    hf_cache_info = scan_cache_dir()
    files_to_save = []
    for repo in hf_cache_info.repos:
        if repo.repo_id == hf_hub_model_id:
            for revision in repo.revisions:
                for file in revision.files:
                    if file.file_path.name.endswith(".py"):
                        files_to_save.append(file.file_path)
                # Only the first cached revision is scanned.
                break
    # Copy the collected files into ``model_dir``, preserving their paths
    # relative to the snapshot root.
    for file_path in files_to_save:
        # Cache paths look like .../snapshots/<revision>/<relative path>;
        # keep only the part after the revision directory.
        match = re.match(r".*snapshots/[^/]+/(.*)", str(file_path))
        if match:
            relative_path = match.group(1)
            destination_path = os.path.join(model_dir, relative_path)
            os.makedirs(os.path.dirname(destination_path), exist_ok=True)
            shutil.copy(str(file_path), destination_path)
        else:
            print("Skipping Python file at unexpected path in HF model cache:", file_path)
metadata.update(
{
"pipeline_tag": "text-generation",
"library_name": "transformers",
"base_model": hf_hub_model_id,
"huggingface_model_url": f"https://huggingface.co/{hf_hub_model_id}"
}
)
    # Drop only numeric values that are NaN or infinite (they do not serialize
    # cleanly); keep strings and other values intact.
    metadata = {
        k: v
        for k, v in metadata.items()
        if not isinstance(v, (int, float, np.integer, np.floating)) or math.isfinite(v)
    }
run.log_model(
name=model_name,
model_file_or_folder=model_dir,
        framework=TransformersFramework(
            pipeline_tag="text-generation",
            library_name="transformers",
            base_model=hf_hub_model_id,
        ),
metadata=metadata,
step=step,
)
print(f"You can view the model at {run.dashboard_link}?tab=models")


def merge_and_upload(
hf_hub_model_id: str,
ml_repo: str,
run_name: str,
artifact_version_fqn: str,
saved_model_name: str,
dtype: str = "bfloat16",
device_map: str = "auto",
):
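    """Merge the LoRA adapter at ``artifact_version_fqn`` into
    ``hf_hub_model_id`` and log the merged model under ``run_name``."""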
    # Imported lazily so torch is only loaded when a merge actually runs.
    import torch
client = get_client()
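    # ``device_map`` may be "auto", a single device string, or a JSON-encoded
    # mapping of module names to devices.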
if device_map.startswith("{"):
device_map = json.loads(device_map)
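    # Download the adapter weights that were logged as an artifact version.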
artifact_version = client.get_artifact_version_by_fqn(artifact_version_fqn)
lora_model_path = artifact_version.download()
tokenizer = AutoTokenizer.from_pretrained(hf_hub_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        hf_hub_model_id,
        device_map=device_map,
        torch_dtype=getattr(torch, dtype),
    )
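    # Apply the adapter on top of the base model and fold its weights in,
    # producing a plain transformers checkpoint.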
model = PeftModel.from_pretrained(model, lora_model_path)
model = model.merge_and_unload(progressbar=True)
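    # Save the tokenizer and merged weights locally before uploading.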
merged_model_dir = os.path.abspath("./merged")
os.makedirs(merged_model_dir, exist_ok=True)
tokenizer.save_pretrained(merged_model_dir)
model.save_pretrained(merged_model_dir)
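    # Attach the merged model to the named run (created if it does not exist).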
run = get_or_create_run(
ml_repo=ml_repo,
run_name=run_name,
auto_end=False,
create_ml_repo=False,
)
log_model_to_truefoundry(
run=run,
model_name=saved_model_name,
model_dir=merged_model_dir,
hf_hub_model_id=hf_hub_model_id,
metadata={
"checkpoint": artifact_version_fqn,
},
step=artifact_version.step,
)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--hf_hub_model_id", type=str, required=True)
parser.add_argument("--ml_repo", type=str, required=True)
parser.add_argument("--run_name", type=str, required=True)
parser.add_argument("--artifact_version_fqn", type=str, required=True)
parser.add_argument("--saved_model_name", type=str, required=True)
parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
parser.add_argument("--device_map", type=str, default="auto")
args = parser.parse_args()
merge_and_upload(
hf_hub_model_id=args.hf_hub_model_id,
ml_repo=args.ml_repo,
run_name=args.run_name,
artifact_version_fqn=args.artifact_version_fqn,
saved_model_name=args.saved_model_name,
dtype=args.dtype,
device_map=args.device_map,
)
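

# Example invocation (illustrative values; the script filename is assumed, and
# the artifact version FQN comes from the TrueFoundry dashboard):
#   python merge_and_upload.py \
#     --hf_hub_model_id mistralai/Mistral-7B-v0.1 \
#     --ml_repo my-ml-repo \
#     --run_name finetune-run-1 \
#     --artifact_version_fqn <artifact-version-fqn> \
#     --saved_model_name merged-model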
if __name__ == "__main__":
main()