[gpt-oss] Add gpt-oss mxfp4 support
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -9,19 +9,49 @@ import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
|
||||
|
||||
logger = init_logger(__name__)
|
||||
USE_XFORMERS_OPS = None
|
||||
|
||||
|
||||
def check_xformers_availability():
|
||||
global USE_XFORMERS_OPS
|
||||
if USE_XFORMERS_OPS is not None:
|
||||
return USE_XFORMERS_OPS
|
||||
|
||||
if current_platform.is_cuda() and current_platform.has_device_capability(
|
||||
100):
|
||||
# Xformers FA is not compatible with B200
|
||||
USE_XFORMERS_OPS = False
|
||||
else:
|
||||
try:
|
||||
from importlib.util import find_spec
|
||||
|
||||
find_spec("xformers.ops")
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
|
||||
# the warning only needs to be shown once
|
||||
if not USE_XFORMERS_OPS:
|
||||
logger.warning("Xformers is not available, falling back.")
|
||||
|
||||
return USE_XFORMERS_OPS
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@@ -45,13 +75,13 @@ class Attention(nn.Module):
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
attn_backend: Optional[type[AttentionBackend]] = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -80,6 +110,9 @@ class Attention(nn.Module):
|
||||
calculate_kv_scales = False
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
assert num_heads % num_kv_heads == 0, \
|
||||
f"num_heads ({num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({num_kv_heads})"
|
||||
|
||||
# The default k/v_scale is set to 1.0. This is ignored
|
||||
# when kv-cache is not fp8, and should be used with
|
||||
@@ -105,6 +138,7 @@ class Attention(nn.Module):
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.has_sink = extra_impl_args.get("sinks") is not None
|
||||
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self, prefix=prefix) if quant_config else None
|
||||
@@ -126,19 +160,23 @@ class Attention(nn.Module):
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
blocksparse_params is not None,
|
||||
use_mla=use_mla)
|
||||
impl_cls = attn_backend.get_impl_cls()
|
||||
if attn_backend is None:
|
||||
self.attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
use_mla=use_mla,
|
||||
has_sink=self.has_sink)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
impl_cls = self.attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **extra_impl_args)
|
||||
self.backend = backend_name_to_enum(attn_backend.get_name())
|
||||
self.backend = backend_name_to_enum(self.attn_backend.get_name())
|
||||
self.dtype = dtype
|
||||
|
||||
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||
@@ -148,7 +186,7 @@ class Attention(nn.Module):
|
||||
self.use_direct_call = not current_platform.is_cuda_alike(
|
||||
) and not current_platform.is_cpu()
|
||||
|
||||
self.use_output = attn_backend.accept_output_buffer
|
||||
self.use_output = self.attn_backend.accept_output_buffer
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
@@ -206,7 +244,7 @@ class Attention(nn.Module):
|
||||
if self.use_output:
|
||||
output_shape = (output_shape
|
||||
if output_shape is not None else query.shape)
|
||||
output = torch.empty(output_shape,
|
||||
output = torch.zeros(output_shape,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
@@ -274,6 +312,9 @@ class Attention(nn.Module):
|
||||
if hasattr(self.impl, "process_weights_after_loading"):
|
||||
self.impl.process_weights_after_loading(act_dtype)
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return self.attn_backend
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""Multi-headed attention without any cache, used for ViT."""
|
||||
@@ -291,7 +332,9 @@ class MultiHeadAttention(nn.Module):
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
assert self.num_heads % self.num_kv_heads == 0, \
|
||||
f"num_heads ({self.num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
@@ -301,12 +344,21 @@ class MultiHeadAttention(nn.Module):
|
||||
block_size=16,
|
||||
is_attention_free=False)
|
||||
backend = backend_name_to_enum(attn_backend.get_name())
|
||||
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
|
||||
backend = _Backend.XFORMERS
|
||||
if current_platform.is_rocm():
|
||||
# currently, only torch_sdpa is supported on rocm
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
else:
|
||||
if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
|
||||
_Backend.FLEX_ATTENTION):
|
||||
backend = _Backend.XFORMERS
|
||||
|
||||
self.attn_backend = backend if backend in {
|
||||
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
|
||||
} else _Backend.TORCH_SDPA
|
||||
self.attn_backend = backend if backend in {
|
||||
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
|
||||
} else _Backend.TORCH_SDPA
|
||||
|
||||
if (self.attn_backend == _Backend.XFORMERS
|
||||
and not check_xformers_availability()):
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -430,6 +482,7 @@ def unified_attention_with_output(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
@@ -444,7 +497,8 @@ def unified_attention_with_output(
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
output=output,
|
||||
output_scale=output_scale)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
|
||||
@@ -455,6 +509,7 @@ def unified_attention_with_output_fake(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user