Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/layer.py
2026-02-04 17:22:39 +08:00

119 lines
4.5 KiB
Python

"""Attention layer."""
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.layer import Attention
from vllm_mlu.attention.selector import vllm__attention__selector__get_attn_backend as get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm_mlu._mlu_utils import USE_PAGED
from vllm_mlu.mlu_hijack_utils import MluHijackObject
'''
=============================
Modify by vllm_mlu
=============================
@brief: add an arg use_mla for functions get_attn_backend, _cached_get_attn_backend,
which_attn_to_use
'''
'''
==================
End of MLU Hijack
==================
'''
def vllm__attention__layer__Attention__init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: Optional[int] = None,
    alibi_slopes: Optional[List[float]] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    blocksparse_params: Optional[Dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    use_mla: bool = False,
    prefix: str = "",
) -> None:
    """MLU replacement for ``Attention.__init__`` with an extra ``use_mla`` flag.

    Resolves the KV-cache settings, installs a KV-cache quant method when
    ``quant_config`` yields one, selects an attention backend through the
    MLU selector and instantiates that backend's implementation class as
    ``self.impl``.

    Args:
        num_heads: Head count, forwarded to the backend impl.
        head_size: Head size, used for backend selection and forwarded to
            the impl.
        scale: Scale value forwarded to the impl.
        num_kv_heads: KV head count forwarded to the impl; ``None`` means
            "same as ``num_heads``".
        alibi_slopes: Optional slopes forwarded to the impl.
        cache_config: Provides ``cache_dtype``, ``block_size``,
            ``sliding_window`` and ``is_attention_free``. When ``None`` the
            fallbacks are ``"auto"``, ``16``, ``None`` and ``False``.
        quant_config: If truthy, asked for this layer's quant method via
            ``get_quant_method(self, prefix=prefix)``.
        blocksparse_params: Forwarded to the impl; whether it is ``None`` is
            also reported to the backend selector.
        logits_soft_cap: Forwarded to the impl.
        use_mla: Stored as ``self.use_mla`` and passed to the MLU backend
            selector as a keyword argument.
        prefix: Layer name prefix handed to ``get_quant_method``.

    Raises:
        AssertionError: If the returned quant method is not a
            ``BaseKVCacheMethod`` (checked with ``assert``).
        ValueError: If a quant method is active and the KV-cache dtype is
            ``"fp8_e5m2"``.
    """
    # This function is defined at module scope and patched onto ``Attention``,
    # so zero-argument ``super()`` has no ``__class__`` cell to resolve; the
    # explicit two-argument form reaches ``nn.Module.__init__``.
    super(Attention, self).__init__()
    self.use_mla = use_mla

    if cache_config is None:
        # Fallbacks for layers constructed without a cache config.
        kv_cache_dtype, block_size, sliding_window, is_attention_free = (
            "auto", 16, None, False)
    else:
        kv_cache_dtype, block_size, sliding_window, is_attention_free = (
            cache_config.cache_dtype,
            cache_config.block_size,
            cache_config.sliding_window,
            cache_config.is_attention_free,
        )

    # No explicit KV head count: use the same count as the attention heads.
    num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

    # Both k/v scales start at 1.0. They are ignored unless the KV cache is
    # fp8; for fp8_e5m2 they are used as-is, while for fp8_e4m3 pre-quantized
    # scales are expected to be loaded together with the model weights.
    self.kv_cache_dtype = kv_cache_dtype
    self._k_scale = 1.0
    self._v_scale = 1.0

    quant_method = None
    if quant_config:
        quant_method = quant_config.get_quant_method(self, prefix=prefix)

    if quant_method is not None:
        assert isinstance(quant_method, BaseKVCacheMethod)
        # TODO (mgoin): kv cache dtype should be specified in the FP8
        # checkpoint config and become the "auto" behavior
        if self.kv_cache_dtype == "fp8_e5m2":
            raise ValueError("fp8_e5m2 kv-cache is not supported with "
                             "fp8 checkpoints.")
        # With quantization enabled, "k_scale"/"v_scale" are created as
        # parameters so they can be loaded from the checkpoint; they are
        # converted back to native float32 values after weight loading.
        self.quant_method = quant_method
        self.quant_method.create_weights(self)

    # While the model is being built, the default torch dtype is the model's
    # weight/activation dtype, so the backend is selected for that dtype.
    model_dtype = torch.get_default_dtype()
    is_blocksparse = blocksparse_params is not None
    attn_backend = get_attn_backend(head_size, model_dtype, kv_cache_dtype,
                                    block_size, is_attention_free,
                                    is_blocksparse,
                                    use_mla=use_mla)
    impl_cls = attn_backend.get_impl_cls()
    self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap)
def vllm__attention__layer__Attention__forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
    attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
    """MLU replacement for ``Attention.forward``.

    Delegates straight to ``self.impl.forward``. The layer's ``_k_scale`` and
    ``_v_scale`` follow ``attn_metadata`` as positional arguments, and
    ``attn_type`` and ``self.use_mla`` are passed as keyword arguments.

    Returns:
        Whatever ``self.impl.forward`` returns (annotated as a tensor).
    """
    impl_forward = self.impl.forward
    # Positional order matters here: the impl receives the scales right
    # after the attention metadata.
    k_scale, v_scale = self._k_scale, self._v_scale
    return impl_forward(query, key, value, kv_cache, attn_metadata,
                        k_scale, v_scale,
                        attn_type=attn_type, use_mla=self.use_mla)
# Register the MLU replacements for the stock ``Attention.__init__`` and
# ``Attention.forward``. These are module-level statements, so they execute
# when this module is imported.
# NOTE(review): presumably apply_hijack(cls, original, replacement) rebinds
# ``original``'s attribute on ``cls`` to ``replacement``; confirm against
# vllm_mlu.mlu_hijack_utils.MluHijackObject.
MluHijackObject.apply_hijack(Attention,
                             Attention.__init__,
                             vllm__attention__layer__Attention__init__)
MluHijackObject.apply_hijack(Attention,
                             Attention.forward,
                             vllm__attention__layer__Attention__forward)