Using TPUs

Deploy apps on Single Host TPU v4, v5e, v5p slices

🚧

Compatibility Notes

  • Currently, only Single Host TPU topologies are supported. If your use case would benefit from Multi Host topologies, please reach out to us.
  • Currently, TPU workloads can only be provisioned and scheduled using Node Auto Provisioning (NAP). Support for node pools will be added soon.

Prerequisites

Kubernetes Version

Different TPU types require different minimum GKE versions. Taking into account all TPU types, Node Auto Provisioning support, and the ability to run without privileged mode, we recommend using 1.28.7-gke.1020000 or later, or 1.29.2-gke.1035000 or later.
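As a rough illustration, the recommendation above can be checked against a cluster's version string. This is a minimal sketch in plain Python with no GKE API calls: the helper names `parse_gke_version` and `meets_recommendation` are hypothetical, and it assumes version strings of the form `MAJOR.MINOR.PATCH-gke.BUILD`.

```python
import re

# Recommended minimums quoted in this guide, keyed by (major, minor).
RECOMMENDED = {
    (1, 28): "1.28.7-gke.1020000",
    (1, 29): "1.29.2-gke.1035000",
}

_VERSION_RE = re.compile(r"^v?(\d+)\.(\d+)\.(\d+)-gke\.(\d+)$")


def parse_gke_version(version):
    """Turn '1.28.7-gke.1020000' into the tuple (1, 28, 7, 1020000)."""
    match = _VERSION_RE.match(version.strip())
    if match is None:
        raise ValueError(f"unrecognised GKE version: {version!r}")
    return tuple(int(part) for part in match.groups())


def meets_recommendation(version):
    """Return True if `version` is at or above the recommended minimum."""
    parsed = parse_gke_version(version)
    minor_line = parsed[:2]
    if minor_line in RECOMMENDED:
        return parsed >= parse_gke_version(RECOMMENDED[minor_line])
    # Assumption: any minor release newer than the listed ones qualifies,
    # and any older minor release does not.
    return minor_line > max(RECOMMENDED)


print(meets_recommendation("1.28.3-gke.1000000"))
print(meets_recommendation("1.29.2-gke.1035000"))
```

The comparison relies on Python's tuple ordering, so the patch number is compared before the GKE build number.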

Please refer to the following links for up-to-date requirements

Regional Availability

TPUs are available only in a limited set of zones. The following table lists the availability of Single Host TPUs.

Please refer to the following links for up-to-date availability:

TPU Type | us-west1-c | us-west4-a | us-central1-a | us-central2-b | us-east1-c | us-east1-d | us-east5-a | us-east5-c | europe-west4-b
v4
v5e Device
v5e PodSlice
v5p

Quota

Please make sure you have enough TPU quota available in the respective region. You can request quota from the GCP Console: https://console.cloud.google.com/iam-admin/quotas

Node Auto Provisioning

Once you have Quota, add the TPU accelerators to NAP's limits. Please follow the guide at https://cloud.google.com/kubernetes-engine/docs/how-to/node-auto-provisioning#config_file to update these limits.

📘

CPU and Memory Limits

TPU VMs often have large amounts of CPU and memory. Please make sure to also raise the CPU and memory resource limits in the NAP config. Inadequate CPU and memory limits can block TPU nodes from scaling up.

Here is a sample config that adds TPU accelerators:

resourceLimits:
  - resourceType: 'cpu'
    minimum: 0
    maximum: 1000
  - resourceType: 'memory'
    minimum: 0
    maximum: 10000
  # ...
  - resourceType: 'tpu-v4-podslice'
    minimum: 0
    maximum: 32
  - resourceType: 'tpu-v5-lite-podslice'
    minimum: 0
    maximum: 32
  - resourceType: 'tpu-v5-lite-device'
    minimum: 0
    maximum: 32
  - resourceType: 'tpu-v5p-slice'
    minimum: 0
    maximum: 32
autoprovisioningLocations:
  # Change these zones according to your cluster
  - us-central2-b
management:
  autoRepair: true
  autoUpgrade: true
shieldedInstanceConfig:
  enableSecureBoot: true
  enableIntegrityMonitoring: true
diskSizeGb: 100
diskType: pd-balanced
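To make the point about CPU and memory limits concrete, the sketch below adds up the resources of a set of planned nodes and reports which NAP maximums would be exceeded. It is only an illustration: `check_nap_limits` is a hypothetical helper, the node shape is an illustrative value rather than an official machine spec, and it assumes NAP counts `cpu` in cores and `memory` in GB.

```python
def check_nap_limits(limits, planned_nodes):
    """Return the resource types whose planned usage exceeds the NAP maximum.

    limits: {resource_type: maximum}, mirroring `resourceLimits` in the config.
    planned_nodes: list of {resource_type: amount} dicts, one per node.
    """
    totals = {}
    for node in planned_nodes:
        for resource_type, amount in node.items():
            totals[resource_type] = totals.get(resource_type, 0) + amount
    # Assumption: a resource type missing from `limits` is treated as a limit of 0.
    return sorted(
        resource_type
        for resource_type, total in totals.items()
        if total > limits.get(resource_type, 0)
    )


# Maximums copied from the sample config above.
limits = {"cpu": 1000, "memory": 10000, "tpu-v5-lite-device": 32}

# Illustrative node shape (an assumption, not an official spec):
# 8 TPU chips, 224 vCPUs, 384 GB of memory.
node = {"cpu": 224, "memory": 384, "tpu-v5-lite-device": 8}

print(check_nap_limits(limits, [node] * 4))  # four nodes fit within every limit
print(check_nap_limits(limits, [node] * 5))  # five nodes exceed cpu and TPU limits
```

With these numbers the CPU maximum is reached at roughly the same node count as the TPU maximum, which is why both limits need to be raised together.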

Adding TPU Devices

Using Spec or Python SDK

📘

Tip

You can use the UI to generate equivalent YAML or Python Spec

You can add a TPU by adding it under the resources.devices section as follows:

name: my-service
type: service
...
resources:
  node:
    type: node_selector
  devices:
+    - name: tpu-v5-lite-device
+      type: gcp_tpu
+      topology: 1x1
  cpu_request: 21
  cpu_limit: 23
  memory_request: 39100
  memory_limit: 46000
  ephemeral_storage_request: 20000
  ephemeral_storage_limit: 100000
...
from truefoundry.deploy import Service, Resources, NodeSelector
+ from truefoundry.deploy import GcpTPU, TPUType
...


service = Service(
  name="my-service",
  ...
  resources=Resources(
+    devices=[
+      GcpTPU(name=TPUType.V5_LITE_DEVICE, topology="1x1")
+    ],
    cpu_request=21,
    cpu_limit=23,
    memory_request=39100,
    memory_limit=46000,
    ephemeral_storage_request=20000,
    ephemeral_storage_limit=100000,
    node=NodeSelector()
  )
    
)

Supported values for name and topologies:

TPU Type | Name | Supported Topologies | Equivalent Devices Count
v4 PodSlice | tpu-v4-podslice | 2x2x1 | 4
v5 Lite Device | tpu-v5-lite-device | 1x1, 2x2, 2x4 | 1, 4, 8
v5 Lite PodSlice | tpu-v5-lite-podslice | 1x1, 2x2, 2x4 | 1, 4, 8
v5p Slice | tpu-v5p-slice | 2x2x1 | 4
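The device count in the table above is the product of the topology dimensions. The following sketch shows that arithmetic and a simple check against the supported combinations; `chips_in_topology`, `validate_device`, and the `SUPPORTED` mapping are hypothetical names that only mirror the table and are not part of the TrueFoundry SDK.

```python
from math import prod

# Supported name/topology combinations, copied from the table above.
SUPPORTED = {
    "tpu-v4-podslice": ["2x2x1"],
    "tpu-v5-lite-device": ["1x1", "2x2", "2x4"],
    "tpu-v5-lite-podslice": ["1x1", "2x2", "2x4"],
    "tpu-v5p-slice": ["2x2x1"],
}


def chips_in_topology(topology):
    """Number of TPU chips in a topology string such as '2x4' or '2x2x1'."""
    dims = [int(dim) for dim in topology.lower().split("x")]
    if any(dim <= 0 for dim in dims):
        raise ValueError(f"invalid topology: {topology!r}")
    return prod(dims)


def validate_device(name, topology):
    """Return the chip count if the combination is listed, else raise ValueError."""
    if topology not in SUPPORTED.get(name, []):
        raise ValueError(f"{topology!r} is not a listed topology for {name!r}")
    return chips_in_topology(topology)


for name, topologies in SUPPORTED.items():
    print(name, [validate_device(name, t) for t in topologies])
```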

Using UI

You can add TPUs to your Service / Job / Notebook / SSH Server by selecting a TPU type and a topology in the Resources section.

Examples

📘

Before you begin

We recommend reading through the following pages first to understand the basics of deployment with TrueFoundry

Enumerate TPU chips

This simple example Job installs JAX and enumerates the assigned TPU devices.

name: enumerate-tpu-devices
type: job
image:
  type: image
  image_uri: python:3.11
  command: bash /run.sh
mounts:
  - type: string
    mount_path: /run.sh
    data: |
      #!/bin/bash
      pip install --quiet --upgrade 'jax[tpu]>0.3.0' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      python -c 'import jax; print(jax.devices())'
retries: 0
trigger:
  type: manual
resources:
  node:
    type: node_selector
    capacity_type: spot_fallback_on_demand
  devices:
    - name: tpu-v5-lite-device
      type: gcp_tpu
      topology: 1x1
  cpu_request: 21
  cpu_limit: 23
  memory_request: 39100
  memory_limit: 46000
  ephemeral_storage_request: 20000
  ephemeral_storage_limit: 100000
trigger_on_deploy: true

When submitted, the Job brings up a TPU node and produces the following output:

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
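If you would rather have the Job fail when the wrong number of chips is attached, the one-liner can be extended along the lines below. This is a sketch, not part of the example above: `count_tpu_chips` and `require_tpu_chips` are hypothetical helpers, and it assumes one JAX device per chip, as in the output shown. JAX is imported lazily so the helpers can be tried off-cluster with stand-in device objects.

```python
from types import SimpleNamespace


def count_tpu_chips(devices=None):
    """Count devices whose platform is 'tpu'.

    When `devices` is omitted, JAX is imported and queried, which only
    works inside a container that has `jax[tpu]` installed.
    """
    if devices is None:
        import jax  # imported lazily; not needed when devices are passed in

        devices = jax.devices()
    return sum(1 for device in devices if getattr(device, "platform", None) == "tpu")


def require_tpu_chips(expected, devices=None):
    """Raise RuntimeError unless exactly `expected` TPU devices are visible."""
    found = count_tpu_chips(devices)
    if found != expected:
        raise RuntimeError(f"expected {expected} TPU chip(s), found {found}")
    return found


# Stand-in device for trying the helper without a TPU; on a real TPU node,
# call require_tpu_chips(1) with no `devices` argument instead.
fake_devices = [SimpleNamespace(platform="tpu")]
print(require_tpu_chips(1, devices=fake_devices))
```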

Running Gemma 2B with Jax

This sample notebook demonstrates running Gemma 2B with JAX in an interactive environment. The notebook can be uploaded directly to a Notebook instance running with TPUs on TrueFoundry.

Deploying Gemma 2B LLM with JetStream & MaxText on TPU v5e

This section is adapted from the official GKE guide: https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-tpu-jetstream

Step 1: Get Access to the model on Kaggle

Step 2: Add kaggle.json to Secrets

Create a Secret and add the contents of kaggle.json as follows.

Once you have saved it, keep the secret FQN handy; we'll use it later.


Step 3: Create Workspace, GCS Bucket, GCP Service Account and Link Them

  • Create a Workspace to deploy the model to
  • Create a GCS Bucket from GCP Console
  • Run the following script to create a GCP Service Account with the correct permissions. Make sure to edit the project ID, bucket name, region, and workspace name:
#!/bin/bash

gcloud config set project your-project
PROJECT_ID=$(gcloud config get project)
BUCKET_NAME=tpu-models-bucket
REGION=us-central1
LOCATION=us-central1-a
WORKSPACE_NAME=tpu-models-workspace
SA_NAME=tpu-models-sa


gcloud iam service-accounts create "${SA_NAME}"

gcloud storage buckets add-iam-policy-binding \
    "gs://${BUCKET_NAME}" \
    --member "serviceAccount:${SA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"  \
    --role roles/storage.objectUser \
    --project "${PROJECT_ID}"
gcloud storage buckets add-iam-policy-binding \
  "gs://${BUCKET_NAME}" \
  --member "serviceAccount:${SA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com" \
  --role roles/storage.insightsCollectorService \
  --project "${PROJECT_ID}"
  
gcloud iam service-accounts add-iam-policy-binding \
    "${SA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com" \
    --role roles/iam.workloadIdentityUser \
    --member "serviceAccount:${PROJECT_ID}.svc.id.goog[${WORKSPACE_NAME}/${SA_NAME}]"
  • Finally, edit the workspace you created and add the service account principal
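The last binding in the script is the one that lets workloads in the workspace act as the GCP Service Account. The sketch below spells out how the two identifiers used by the script are composed; the helper names are hypothetical, and it assumes, as the script does, that the workspace name is the Kubernetes namespace and that the Kubernetes service account shares the GCP Service Account's name.

```python
def service_account_email(project_id, sa_name):
    """Email of the GCP Service Account, as used in the `--member` flags."""
    return f"{sa_name}@{project_id}.iam.gserviceaccount.com"


def workload_identity_member(project_id, namespace, ksa_name):
    """Workload Identity member that maps a Kubernetes service account to GCP."""
    return f"serviceAccount:{project_id}.svc.id.goog[{namespace}/{ksa_name}]"


# Values taken from the script above; replace them with your own.
project_id = "your-project"
workspace_name = "tpu-models-workspace"
sa_name = "tpu-models-sa"

print(service_account_email(project_id, sa_name))
print(workload_identity_member(project_id, workspace_name, sa_name))
```

If the Job or Service later fails to read or write the bucket, comparing these two strings against the bindings on the service account is a reasonable first check.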

Step 4: Deploy and Run the Job to convert the model to MaxText Engine format

This Job pulls the model from Kaggle using the kaggle.json secret and writes the checkpoints to your GCS bucket using the service account, both of which we created earlier.

import logging
from truefoundry.deploy import Job, Image, Resources, GcpTPU, TPUType, SecretMount

logging.basicConfig(level=logging.INFO)

job = Job(
    name="maxengine-ckpt-conv",
    # Mention your bucket in command below
    image=Image(
        image_uri="us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.0",
        command="/usr/bin/checkpoint_converter.sh -b=tpu-models-bucket -m=google/gemma/maxtext/2b-it/3",
    ),
    # Service account to use to write to the bucket
    service_account="tpu-models-sa",
    # Secret to use for fetching the model checkpoint from Kaggle
    mounts=[
        SecretMount(
            mount_path="/kaggle/kaggle.json",
            secret_fqn="tfy-secret://your-org:kaggle-secrets:kaggle-json"
        )
    ],
    resources=Resources(
        cpu_request=20,
        cpu_limit=20,
        memory_request=150000,
        memory_limit=190000,
        ephemeral_storage_request=30000,
        ephemeral_storage_limit=40000,
        devices=[
            GcpTPU(
                name=TPUType.V5_LITE_DEVICE,
                topology="2x2"
            )
        ]
    ),
    trigger_on_deploy=True,
)

# Add the fqn of the Workspace we created earlier
job.deploy(workspace_fqn="<your-workspace-fqn>", wait=False)

Once this runs successfully, you will get the base and converted checkpoints in your bucket

Step 5: Deploy the model via MaxText Engine

import logging
from truefoundry.deploy import Service, Image, Port, Resources, GcpTPU, TPUType

logging.basicConfig(level=logging.INFO)

service = Service(
    name="maxengine-server",
    # Model name and bucket are mentioned in the command
    image=Image(
        image_uri="us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:v0.2.0",
        command="/usr/bin/maxengine_server_entrypoint.sh model_name=gemma-2b tokenizer_path=assets/tokenizer.gemma per_device_batch_size=4 max_prefill_predict_length=1024 max_target_length=2048 async_checkpointing=false ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 ici_tensor_parallelism=1 scan_layers=false weight_dtype=bfloat16 load_parameters_path=gs://tpu-models-bucket/final/unscanned/gemma_2b-it/0/checkpoints/0/items",
    ),
    ports=[
        Port(port=9000, expose=False),
    ],
    # Service account to use to read from the bucket
    service_account="tpu-models-sa",
    resources=Resources(
        cpu_request=20,
        cpu_limit=24,
        memory_request=40000,
        memory_limit=48000,
        ephemeral_storage_request=30000,
        ephemeral_storage_limit=40000,
        devices=[
            GcpTPU(
                name=TPUType.V5_LITE_DEVICE,
                topology="1x1"
            )
        ]
    ),
)

# Add the fqn of the Workspace we created earlier
service.deploy(workspace_fqn="<your-workspace-fqn>", wait=False)

This runs the maxengine server in gRPC mode on port 9000.
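The server command above is a single long string of `key=value` arguments, which is easy to mistype. One way to keep it readable is to assemble it from a dictionary, as sketched below. `build_maxengine_command` is a hypothetical helper; the argument names and values are copied unchanged from the command above.

```python
def build_maxengine_command(entrypoint, args):
    """Join the entrypoint and `key=value` pairs into one command string."""
    return " ".join([entrypoint] + [f"{key}={value}" for key, value in args.items()])


bucket = "tpu-models-bucket"  # the bucket the conversion Job wrote to

args = {
    "model_name": "gemma-2b",
    "tokenizer_path": "assets/tokenizer.gemma",
    "per_device_batch_size": 4,
    "max_prefill_predict_length": 1024,
    "max_target_length": 2048,
    "async_checkpointing": "false",
    "ici_fsdp_parallelism": 1,
    "ici_autoregressive_parallelism": -1,
    "ici_tensor_parallelism": 1,
    "scan_layers": "false",
    "weight_dtype": "bfloat16",
    "load_parameters_path": f"gs://{bucket}/final/unscanned/gemma_2b-it/0/checkpoints/0/items",
}

command = build_maxengine_command("/usr/bin/maxengine_server_entrypoint.sh", args)
print(command)
```

The resulting string can be passed as the `command` of the `Image`. Dictionaries preserve insertion order in Python 3.7+, so the arguments appear in the order written.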

Step 6: Deploy an HTTP Head Server

Finally, we can deploy an HTTP head server, which is just a gRPC client wrapped in a FastAPI app.

import logging
from truefoundry.deploy import Service, Image, Port, Resources

logging.basicConfig(level=logging.INFO)

service = Service(
    name="jetstream-http",
    image=Image(image_uri="truefoundrycloud/jetstream-http-patched:v0.2.0"),
    ports=[
        Port(
            port=8000,
            host="<Your Host>"
        ),
    ],
    resources=Resources(
        cpu_request=0.5,
        cpu_limit=1,
        memory_request=500,
        memory_limit=1000,
        ephemeral_storage_request=1000,
        ephemeral_storage_limit=4000,
    ),
)

# Add the fqn of the Workspace we created earlier
service.deploy(workspace_fqn="<your-workspace-fqn>", wait=False)

Once deployed, you can send a request with the following body to the /generate endpoint:

{
  "server": "maxengine-server.tpu-models-workspace.svc.cluster.local",
  "port": "9000",
  "prompt": "This is a sample prompt ",
  "priority": 0,
  "max_tokens": 500
}
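A request like the one above could be sent from Python using only the standard library, as in this sketch. `build_generate_request` is a hypothetical helper, `https://jetstream-http.example.com` is a placeholder for the host you configured on the port, and the use of POST with a JSON body is an assumption; the response format is not covered here, so the sketch simply prints the raw body.

```python
import json
import urllib.request


def build_generate_request(base_url, prompt, max_tokens=500):
    """Build a POST request for the /generate endpoint of the HTTP head server."""
    payload = {
        # In-cluster address of the maxengine gRPC server deployed earlier.
        "server": "maxengine-server.tpu-models-workspace.svc.cluster.local",
        "port": "9000",
        "prompt": prompt,
        "priority": 0,
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Placeholder host: replace with the host set on the jetstream-http port.
request = build_generate_request(
    "https://jetstream-http.example.com", "This is a sample prompt "
)
print(request.get_method(), request.full_url)

# Uncomment to actually send the request once the service is reachable:
# with urllib.request.urlopen(request, timeout=60) as response:
#     print(response.read().decode("utf-8"))
```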