[gpt-oss] Add gpt-oss mxfp4 support

This commit is contained in:
2025-08-25 17:41:34 +08:00
parent db7f48eeac
commit ce688181e6
33 changed files with 4835 additions and 1192 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import Any, Dict, List, Optional
from typing import List, Optional
import torch
import torch.nn as nn
@@ -9,19 +9,49 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
def check_xformers_availability():
    """Return whether xformers ops can be used, caching the answer.

    The result is stored in the module-level ``USE_XFORMERS_OPS`` so the
    probe (and the fallback warning) only runs on the first call.

    Returns:
        ``True`` if ``xformers.ops`` can be located and the platform is not
        excluded, ``False`` otherwise.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        try:
            # find_spec() returns None (it does not raise) when the parent
            # package imports fine but the submodule cannot be found, so the
            # return value has to be checked explicitly.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except ImportError:
            # Raised when the parent package `xformers` itself is missing
            # or fails to import.
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS
class Attention(nn.Module):
@@ -45,13 +75,13 @@ class Attention(nn.Module):
alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
attn_backend: Optional[type[AttentionBackend]] = None,
**extra_impl_args,
) -> None:
"""
@@ -80,6 +110,9 @@ class Attention(nn.Module):
calculate_kv_scales = False
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, \
f"num_heads ({num_heads}) is not " \
f"divisible by num_kv_heads ({num_kv_heads})"
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
@@ -105,6 +138,7 @@ class Attention(nn.Module):
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
@@ -126,19 +160,23 @@ class Attention(nn.Module):
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
blocksparse_params is not None,
use_mla=use_mla)
impl_cls = attn_backend.get_impl_cls()
if attn_backend is None:
self.attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla,
has_sink=self.has_sink)
else:
self.attn_backend = attn_backend
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name())
self.backend = backend_name_to_enum(self.attn_backend.get_name())
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
@@ -148,7 +186,7 @@ class Attention(nn.Module):
self.use_direct_call = not current_platform.is_cuda_alike(
) and not current_platform.is_cpu()
self.use_output = attn_backend.accept_output_buffer
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
@@ -206,7 +244,7 @@ class Attention(nn.Module):
if self.use_output:
output_shape = (output_shape
if output_shape is not None else query.shape)
output = torch.empty(output_shape,
output = torch.zeros(output_shape,
dtype=query.dtype,
device=query.device)
hidden_size = output_shape[-1]
@@ -274,6 +312,9 @@ class Attention(nn.Module):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
def get_attn_backend(self) -> type[AttentionBackend]:
    """Return the attention backend class held in ``self.attn_backend``."""
    return self.attn_backend
class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""
@@ -291,7 +332,9 @@ class MultiHeadAttention(nn.Module):
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
assert self.num_heads % self.num_kv_heads == 0, \
f"num_heads ({self.num_heads}) is not " \
f"divisible by num_kv_heads ({self.num_kv_heads})"
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
dtype = torch.get_default_dtype()
@@ -301,12 +344,21 @@ class MultiHeadAttention(nn.Module):
block_size=16,
is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name())
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
backend = _Backend.XFORMERS
if current_platform.is_rocm():
# currently, only torch_sdpa is supported on rocm
self.attn_backend = _Backend.TORCH_SDPA
else:
if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
_Backend.FLEX_ATTENTION):
backend = _Backend.XFORMERS
self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
} else _Backend.TORCH_SDPA
self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
} else _Backend.TORCH_SDPA
if (self.attn_backend == _Backend.XFORMERS
and not check_xformers_availability()):
self.attn_backend = _Backend.TORCH_SDPA
def forward(
self,
@@ -430,6 +482,7 @@ def unified_attention_with_output(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
@@ -444,7 +497,8 @@ def unified_attention_with_output(
value,
kv_cache,
attn_metadata,
output=output)
output=output,
output_scale=output_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
@@ -455,6 +509,7 @@ def unified_attention_with_output_fake(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
return

View File

@@ -3,8 +3,9 @@
import os
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cache
from typing import Generator, Optional, Type
from typing import Generator, Optional, Type, Union
import torch
@@ -79,15 +80,72 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return forced_attn_backend
@dataclass(frozen=True)
class _IsSupported:
    """Outcome of an attention-backend support check.

    Each field records one independent criterion; the instance is truthy
    only when every criterion holds.
    """
    # Whether the backend class could be imported / resolved.
    can_import: bool
    # Whether the requested head size is accepted by the backend.
    head_size: bool
    # Whether the requested dtype is accepted by the backend.
    dtype: bool

    def __bool__(self) -> bool:
        # Same short-circuit result as `a and b and c`, grouped in two steps.
        importable_and_sized = self.can_import and self.head_size
        return importable_and_sized and self.dtype
def is_attn_backend_supported(
    attn_backend: Union[str, type[AttentionBackend]],
    head_size: int,
    dtype: torch.dtype,
    *,
    allow_import_error: bool = True,
) -> _IsSupported:
    """Check whether a backend accepts the given head size and dtype.

    Args:
        attn_backend: Backend class, or a qualified name that is resolved
            with ``resolve_obj_by_qualname``.
        head_size: Head size to check.
        dtype: Dtype to check.
        allow_import_error: When resolving a qualified name raises
            ``ImportError``: if true, report the backend as unsupported;
            if false, let the error propagate.

    Returns:
        An ``_IsSupported`` with one flag per criterion.

    Raises:
        ImportError: Resolution failed and ``allow_import_error`` is false.
        NotImplementedError: The backend exposes neither
            ``get_supported_head_sizes`` nor ``validate_head_size``, or it
            lacks ``get_supported_dtypes``.
    """
    if isinstance(attn_backend, str):
        try:
            attn_backend = resolve_obj_by_qualname(attn_backend)
        except ImportError:
            if allow_import_error:
                return _IsSupported(can_import=False,
                                    head_size=False,
                                    dtype=False)
            raise

    assert isinstance(attn_backend, type)

    # TODO: Update the interface once V0 is removed
    head_sizes_getter = getattr(attn_backend, "get_supported_head_sizes",
                                None)
    if head_sizes_getter:
        # Preferred interface: an explicit collection of head sizes.
        head_size_ok = head_size in head_sizes_getter()
    else:
        head_size_validator = getattr(attn_backend, "validate_head_size",
                                      None)
        if not head_size_validator:
            raise NotImplementedError(
                f"{attn_backend.__name__} does not support "
                "head size validation")
        # Fallback interface: any exception from the validator means the
        # head size is rejected.
        try:
            head_size_validator(head_size)
        except Exception:
            head_size_ok = False
        else:
            head_size_ok = True

    dtypes_getter = getattr(attn_backend, "get_supported_dtypes", None)
    if not dtypes_getter:
        raise NotImplementedError(f"{attn_backend.__name__} does not support "
                                  "dtype validation")
    dtype_ok = dtype in dtypes_getter()

    return _IsSupported(
        can_import=True,
        head_size=head_size_ok,
        dtype=dtype_ok,
    )
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
is_attention_free: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
has_sink: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
@@ -99,9 +157,9 @@ def get_attn_backend(
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
is_attention_free=is_attention_free,
is_blocksparse=is_blocksparse,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
has_sink=has_sink,
)
@@ -112,16 +170,10 @@ def _cached_get_attn_backend(
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
use_v1: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend
has_sink: bool = False,
) -> type[AttentionBackend]:
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free:
@@ -144,11 +196,15 @@ def _cached_get_attn_backend(
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. "
f"Valid backends are: {list(_Backend.__members__.keys())}")
# get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
use_mla)
use_mla, has_sink)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}")

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
    """Validate the KV-sharing target chosen for a layer.

    Args:
        current_layer_name: Name of the layer that wants to share KV.
        target_layer_name: Name of the layer it wants to share with.
        static_forward_context: Mapping from layer name to a layer object
            exposing ``attn_type``.

    Raises:
        ValueError: The target is the current layer itself, is missing from
            ``static_forward_context``, or has a different ``attn_type``
            than the current layer.
    """
    prefix = (f"Specified KV sharing target layer for {current_layer_name} "
              f"is not valid: target layer {target_layer_name} ")

    if current_layer_name == target_layer_name:
        raise ValueError(prefix + "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # If target layer name is not in the static fwd context, it means either
        # a) the target layer does not come BEFORE the current layer, or
        # b) the target layer is not an Attention layer that exists in the model
        current_idx = extract_layer_index(current_layer_name)
        target_idx = extract_layer_index(target_layer_name)
        reason = ("must come before the current layer."
                  if current_idx <= target_idx else
                  "is not a valid Attention layer in the model.")
        raise ValueError(prefix + reason)

    # Currently KV sharing is only supported between layers of the same type
    target_type = static_forward_context[target_layer_name].attn_type
    current_type = static_forward_context[current_layer_name].attn_type
    if target_type != current_type:
        raise ValueError(
            prefix +
            f"must be the same type as the current layer ({current_type}).")