Inferência em lote com Dados de Raios e vLLM

Importante

Este recurso está em versão Beta. Os administradores do espaço de trabalho podem controlar o acesso a esse recurso na página Visualizações . Ver Gerir as pré-visualizações de Azure Databricks.

Este exemplo executa inferência em lote offline de LLM com Ray Data e vLLM em 8 GPU H100 num único nó. Um script de inicialização inicia um cluster Ray no nó e, em seguida, o driver utiliza a API de LLM do Ray Data (ray.data.llm) para lançar uma réplica do vLLM por GPU e fazer passar por estas, em fluxo, um conjunto de dados de prompts, escrevendo o texto gerado num volume do Unity Catalog em formato Parquet.

Utiliza um modelo público (Qwen2.5-7B-Instruct), pelo que funciona tal como está sem um token da Hugging Face.

A carga de trabalho faz o seguinte:

  • Faz o carregamento do projeto local com code_source: snapshot.
  • Arranca uma cabeça Ray com as 8 GPUs, depois executa o driver de inferência em lote.
  • Usa ray.data.llm para correr uma réplica vLLM por GPU e processar os prompts em paralelo.
  • Escreve os prompts e as saídas geradas num volume do Unity Catalog em formato Parquet.

Pré-requisitos

Estrutura do projeto

Crie um diretório com os seguintes ficheiros.

ray_batch_inference/
├── train.yaml            # air workload config (inline dependencies + Ray bootstrap)
└── batch_inference.py    # Ray Data + vLLM batch inference driver

Passo 1: Escrever a carga de trabalho em YAML

train.yaml solicita um único GPU_8xH100 nó. As dependências são declaradas na própria linha em environment (com a imagem do cliente version), e o command inicia um cluster Ray no nó e depois executa o controlador, pelo que a carga de trabalho não necessita de um ficheiro de dependências separado nem de um script de arranque.

O vLLM não está na imagem base, por isso está instalado em linha juntamente com três pinos de que os nós da GPU precisam: hf_transfer (a imagem base permite downloads rápidos do Hugging Face e espera este pacote), um mais fsspec recente (a imagem base envia uma antiga que quebra os downloads) e um fixado opencv-python-headless (o vLLM puxa o OpenCV, cujo volante por defeito faz crashar o auto-teste do OpenSSL FIPS nos nós da GPU).

Define OUTPUT_PATH para um volume do Unity Catalog onde possas escrever.

experiment_name: air-ray-batch-inference

environment:
  version: '4'
  dependencies:
    - ray[data]>=2.44
    - vllm
    - datasets>=3.0
    - huggingface_hub>=0.34
    # The base image sets HF_HUB_ENABLE_HF_TRANSFER=1; install the package it expects
    # so model and dataset downloads don't error out.
    - hf_transfer
    # The base image ships fsspec 2023.5.0, which is too old for modern
    # huggingface_hub and breaks dataset/model downloads. Pin a newer fsspec.
    - fsspec>=2024.6.1
    # vLLM pulls in opencv; its default wheel crashes the OpenSSL FIPS self-test
    # on the GPU nodes. This pinned headless build avoids the crash.
    - opencv-python-headless==4.12.0.88

# 8 H100 on a single node. Ray Data runs one vLLM replica per GPU.
compute:
  num_accelerators: 8
  accelerator_type: GPU_8xH100

code_source:
  type: snapshot
  snapshot:
    root_path: .

command: |
  cd $CODE_SOURCE_PATH
  RAY_HEAD_PORT=6379
  GPUS_PER_NODE=${LOCAL_WORLD_SIZE:-8}
  if [ "${NODE_RANK:-0}" = "0" ]; then
    echo "NODE_RANK=0: starting Ray head with $GPUS_PER_NODE GPU(s)..."
    ray start --head --port=$RAY_HEAD_PORT --num-gpus="$GPUS_PER_NODE" --dashboard-host=0.0.0.0
    python batch_inference.py
    ray stop
  else
    echo "NODE_RANK=$NODE_RANK: connecting to Ray head at $MASTER_ADDR:$RAY_HEAD_PORT..."
    for i in $(seq 1 12); do
      if ray start --address="$MASTER_ADDR:$RAY_HEAD_PORT" --num-gpus="$GPUS_PER_NODE" --block 2>/dev/null; then
        break
      fi
      echo "Attempt $i failed, retrying in 5s..."
      sleep 5
    done
  fi

max_retries: 0
timeout_minutes: 60
env_variables:
  NCCL_SOCKET_IFNAME: eth0
  # Unity Catalog volume where results land as Parquet. Replace with your volume.
  OUTPUT_PATH: /Volumes/main/default/air_examples/ray_batch_inference

O inline command inicia uma cabeça Ray com todas as GPUs no nó, executa o driver com python batch_inference.py, e depois para o cluster. Inclui também uma ramificação de processamento que se associa ao nó principal, pelo que o mesmo comando continua a funcionar se dimensionar a tarefa para vários nós.

Passo 2: Defina o driver de inferência por lote

batch_inference.py constrói um conjunto de dados Ray de prompts, configura um processador vLLM com ray.data.llm, e escreve os resultados. concurrency é o número de réplicas vLLM que Ray Data executa em paralelo. Defini-lo para o número de GPUs do cluster dá uma réplica por GPU, pelo que os prompts são processados em simultâneo em todas as GPUs e o exemplo escala à medida que se adicionam nós:

from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

# Read the GPU count from the live Ray cluster so concurrency scales with the cluster.
total_gpus = int(ray.cluster_resources().get("GPU", 0))

config = vLLMEngineProcessorConfig(
    model_source="Qwen/Qwen2.5-7B-Instruct",
    engine_kwargs={"max_model_len": 4096, "tensor_parallel_size": 1},
    concurrency=total_gpus,   # one vLLM replica per GPU in the cluster
    batch_size=64,
)

processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{"role": "user", "content": row["instruction"]}],
        sampling_params=dict(max_tokens=256, temperature=0.7),
    ),
    postprocess=lambda row: dict(instruction=row["instruction"], output=row["generated_text"]),
)

out = processor(ds)       # ds is a Ray Dataset with an "instruction" column
out.write_parquet(OUTPUT_PATH)

preprocess transforma cada linha de entrada num pedido de chat e postprocess mantém as colunas para persistir. O Ray Data adiciona uma generated_text coluna com a saída do modelo. O script completo está em script completo do controlador no final desta página.

Para modelos maiores, defina tensor_parallel_size para fragmentar uma réplica por várias GPUs e divida total_gpus por esse valor para que as réplicas continuem a ocupar totalmente o cluster, por exemplo concurrency=total_gpus // 2 com tensor_parallel_size=2.

Passo 3: Submeter a execução

air run -f train.yaml --dry-run
air run -f train.yaml --watch

Passo 4: Inspecionar a pista

air get run <run-id>
air logs <run-id>

Os registos mostram o prompt do motor vLLM e o throughput de geração à medida que o batch é executado, depois uma Wrote <n> rows linha quando a saída é escrita.

Onde os resultados aparecem

O controlador escreve um conjunto de dados Parquet no volume OUTPUT_PATH, com uma coluna instruction e uma coluna output. Leia-o novamente com Spark ou pandas, por exemplo spark.read.parquet(OUTPUT_PATH).

Script completo do driver

A versão completa batch_inference.py para copiar e colar:

#!/usr/bin/env python3
"""Offline batch inference with Ray Data + vLLM on a single 8x H100 node.

The workload `command` starts a Ray head with 8 GPUs and runs this script. Ray Data's
LLM API (`ray.data.llm`) launches one vLLM replica per GPU and streams a dataset of
prompts through them, then writes the generated text to a Unity Catalog volume as
Parquet.

Uses a public model (no Hugging Face token required) so the example runs as-is.
"""

import os

import ray
from datasets import load_dataset
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

MODEL_SOURCE = "Qwen/Qwen2.5-7B-Instruct"
NUM_PROMPTS = 1000
# Unity Catalog volume path where results land as Parquet. Set this in train.yaml.
OUTPUT_PATH = os.environ.get("OUTPUT_PATH", "/Volumes/main/default/air_examples/ray_batch_inference")


def build_prompts():
    """Build a Ray Dataset of prompts from a public instruction dataset."""
    raw = load_dataset("tatsu-lab/alpaca", split=f"train[:{NUM_PROMPTS}]")
    items = []
    for row in raw:
        instruction = row["instruction"]
        if row.get("input"):
            instruction = f"{instruction}\n\n{row['input']}"
        items.append({"instruction": instruction})
    return ray.data.from_items(items)


def main():
    ray.init(address="auto")
    # Derive replicas from the live cluster so the example scales when nodes are added.
    total_gpus = int(ray.cluster_resources().get("GPU", 0))
    print(f"Ray cluster ready: {total_gpus} GPU(s)", flush=True)

    ds = build_prompts()

    # vLLM engine config. concurrency = number of replicas Ray Data runs in parallel;
    # one per GPU in the cluster here. engine_kwargs are passed through to the vLLM engine.
    config = vLLMEngineProcessorConfig(
        model_source=MODEL_SOURCE,
        engine_kwargs={
            "max_model_len": 4096,
            "tensor_parallel_size": 1,
            "enable_chunked_prefill": True,
        },
        concurrency=total_gpus,
        batch_size=64,
    )

    # preprocess maps each input row to a chat request; postprocess keeps the columns
    # we want to persist. ray.data.llm adds a `generated_text` column.
    processor = build_llm_processor(
        config,
        preprocess=lambda row: dict(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": row["instruction"]},
            ],
            sampling_params=dict(max_tokens=256, temperature=0.7),
        ),
        postprocess=lambda row: dict(
            instruction=row["instruction"],
            output=row["generated_text"],
        ),
    )

    # materialize once so the write and the sample print don't re-run inference.
    out = processor(ds).materialize()
    out.write_parquet(OUTPUT_PATH)
    print(f"Wrote {out.count()} rows to {OUTPUT_PATH}", flush=True)

    for row in out.take(2):
        print("INSTRUCTION:", row["instruction"][:120], flush=True)
        print("OUTPUT:", row["output"][:200], flush=True)

    ray.shutdown()


if __name__ == "__main__":
    main()

Passos seguintes