Skip to content

vllm.model_executor.warmup.kernel_warmup

Warm up kernels used during model execution. This is especially useful for JIT-compiled kernels, because we don't want JIT compilation to happen during model execution.

Functions:

_warmup_triton_nvfp4_prefill_kernels(runner)

Warm NVFP4 pure-prefill Triton kernels missed by dummy runs.

The NVFP4 Triton path can bypass the paged cache for pure prefill and call `context_attention_fwd` directly. Hybrid models may have several Triton prefill specializations, for example full and sliding-window attention with different head sizes. Use tiny synthetic tensors with the real layer shapes so that those variants compile before the JIT monitor is enabled.

Source code in vllm/model_executor/warmup/kernel_warmup.py
def _warmup_triton_nvfp4_prefill_kernels(runner: "GPUModelRunner") -> None:
    """Compile the NVFP4 pure-prefill Triton kernels that dummy runs miss.

    For pure prefill, the NVFP4 Triton path can skip the paged cache and
    invoke `context_attention_fwd` directly. A hybrid model may need several
    Triton prefill specializations (for example full vs. sliding-window
    attention, or different head sizes). This launches the kernel once per
    distinct layer configuration on small zero-filled tensors shaped like the
    real layer, so every variant is compiled before the JIT monitor is
    enabled.

    Args:
        runner: Model runner supplying the attention groups, vLLM config,
            device, dtype and token limits used to size the synthetic inputs.
    """
    from vllm.config import get_layers_from_vllm_config
    from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
    from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd

    # A short synthetic sequence suffices to trigger compilation; it is also
    # capped by the runner's token and context-length limits.
    num_tokens = min(runner.max_num_tokens, runner.max_model_len, 64)
    if num_tokens <= 0:
        return

    device = runner.device
    dtype = runner.dtype
    # Describe a single sequence that starts at offset 0 and spans all
    # `num_tokens` rows of the synthetic q/k/v tensors.
    start_locs = torch.zeros((1,), dtype=torch.int32, device=device)
    seq_lens = torch.full((1,), num_tokens, dtype=torch.int32, device=device)
    base_layer_type = cast(type[Any], AttentionLayerBase)

    def _uses_direct_prefill(impl: Any) -> bool:
        # Warm only NVFP4 layers that set no KV-sharing target, no alibi
        # slopes / alibi sqrt, no sinks and no chunk lookback.
        return (
            getattr(impl, "kv_cache_dtype", None) == "nvfp4"
            and getattr(impl, "kv_sharing_target_layer_name", None) is None
            and getattr(impl, "alibi_slopes", None) is None
            and not getattr(impl, "use_alibi_sqrt", False)
            and getattr(impl, "sinks", None) is None
            and getattr(impl, "chunk_lookback", -1) == -1
        )

    compiled: set[tuple] = set()
    triton_groups = (
        group
        for groups in runner.attn_groups
        for group in groups
        if _is_attention_backend(group.backend, "TRITON_ATTN")
    )
    for group in triton_groups:
        names = getattr(group, "layer_names", ())
        if not names:
            continue

        layers_by_name = get_layers_from_vllm_config(
            runner.vllm_config,
            base_layer_type,
            names,
        )
        for name in names:
            layer = layers_by_name.get(name)
            if layer is None:
                continue

            impl = cast(Any, layer.impl)
            if not _uses_direct_prefill(impl):
                continue

            window = getattr(impl, "sliding_window", (-1, -1))
            # Deduplicate on every attribute forwarded to the kernel (plus the
            # activation dtype) so each combination is launched only once.
            variant = (
                impl.num_heads,
                impl.num_kv_heads,
                impl.head_size,
                impl.scale,
                impl.logits_soft_cap,
                window,
                dtype,
            )
            if variant in compiled:
                continue
            compiled.add(variant)

            q_buf = torch.zeros(
                (num_tokens, impl.num_heads, impl.head_size),
                dtype=dtype,
                device=device,
            )
            k_buf = torch.zeros(
                (num_tokens, impl.num_kv_heads, impl.head_size),
                dtype=dtype,
                device=device,
            )
            v_buf = torch.zeros_like(k_buf)
            o_buf = torch.empty_like(q_buf)

            context_attention_fwd(
                q=q_buf,
                k=k_buf,
                v=v_buf,
                o=o_buf,
                b_start_loc=start_locs,
                b_seq_len=seq_lens,
                max_input_len=num_tokens,
                is_causal=True,
                softmax_scale=impl.scale,
                softcap=impl.logits_soft_cap,
                sliding_window_q=window[0],
                sliding_window_k=window[1],
            )

flashinfer_autotune(runner)

Autotune FlashInfer operations. FlashInfer has many implementations of the same operation. Autotuning benchmarks each implementation and stores the results. The results are cached transparently, and future calls to FlashInfer will use the best implementation. Without autotuning, FlashInfer relies on heuristics, which may be significantly slower.

Tuning is performed only on rank 0. The resulting cache is broadcast to every rank so all ranks dispatch the same kernel tactic.

Source code in vllm/model_executor/warmup/kernel_warmup.py
def flashinfer_autotune(runner: "GPUModelRunner") -> None:
    """Autotune FlashInfer operations.

    FlashInfer has many implementations of the same operation. Autotuning
    benchmarks each implementation and stores the results. The results are
    cached transparently, and future calls to FlashInfer will use the best
    implementation. Without autotuning, FlashInfer relies on heuristics,
    which may be significantly slower.

    Two modes are supported:

    * Persistent cache disabled (``_FLASHINFER_USE_PERSISTENT_CACHE`` is
      false): every rank runs the dummy batch under ``fi_utils.autotune()``
      and the ranks then meet at a world barrier. No cache file is used.
    * Persistent cache enabled: tuning is performed only on rank 0, against
      the cache file returned by ``_resolve_flashinfer_autotune_file``. The
      file's bytes are broadcast to every rank so all ranks dispatch the same
      kernel tactic. Each other rank with ``local_rank == 0`` writes those
      bytes to the same path, and every rank then loads the configs from it.
      If rank 0 produced no cache file, a warning is logged and default
      tactics are kept.

    Args:
        runner: Model runner whose ``_dummy_run`` drives the kernels to tune.
    """
    import vllm.utils.flashinfer as fi_utils
    from vllm.distributed.parallel_state import get_world_group

    # We skip EPLB here since we don't want to record dummy metrics.
    # When autotuning with number of tokens m, flashinfer will autotune
    # operations for all number of tokens up to m, so we only need to
    # run with the max number of tokens.
    dummy_run_kwargs = {
        "num_tokens": runner.scheduler_config.max_num_batched_tokens,
        "skip_eplb": True,
        "is_profile": True,
    }
    world = get_world_group()

    if not _FLASHINFER_USE_PERSISTENT_CACHE:
        # No cache file on this path: each rank runs under autotune() itself.
        with torch.inference_mode(), fi_utils.autotune():
            runner._dummy_run(**dummy_run_kwargs)
        world.barrier()
        return

    is_leader = world.rank_in_group == 0

    cache_path = _resolve_flashinfer_autotune_file(runner)
    if is_leader:
        logger.info("Using FlashInfer autotune cache file: %s", cache_path)

    with torch.inference_mode():
        if is_leader:
            # Only rank 0 tunes; the cache file it is given is read back below.
            with fi_utils.autotune(tune_mode=True, cache=str(cache_path)):
                runner._dummy_run(**dummy_run_kwargs)
        else:
            # NOTE(review): non-leaders still execute the same dummy batch,
            # presumably so collective ops in the forward pass stay matched
            # with rank 0 — confirm.
            runner._dummy_run(**dummy_run_kwargs)

    # Broadcast autotune cache from rank 0 to all other ranks so every
    # rank loads the same set of chosen tactics.
    tune_results: bytes | None = None
    if is_leader and cache_path.exists():
        with open(cache_path, "rb") as f:
            tune_results = f.read()

    tune_results = world.broadcast_object(tune_results, src=0)

    if tune_results is None:
        # Every rank receives the same broadcast value, so all ranks take
        # this branch together and none of them waits at the barrier below.
        logger.warning(
            "No FlashInfer autotune cache entries found. "
            "Falling back to default tactics."
        )
        return

    if not is_leader and world.local_rank == 0:
        # Rank 0 already has the file. Other local-rank-0 processes (on other
        # nodes, presumably) materialize it, creating the directory first: if
        # the write failed, the remaining ranks would block on the barrier.
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(cache_path, "wb") as f:
            f.write(tune_results)
    # Make sure the file is in place before any rank reads it.
    world.barrier()
    from flashinfer.autotuner import AutoTuner

    AutoTuner.get().load_configs(str(cache_path))
    logger.info(
        "FlashInfer autotune cache loaded on rank %d from %s.",
        world.rank_in_group,
        cache_path,
    )