forked from EngineX-Ascend/enginex-ascend-910-vllm
init v0.11.0rc0
This commit is contained in:
@@ -15,8 +15,18 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
import vllm_ascend.patch.worker.patch_common.patch_triton
|
||||
|
||||
# isort: off
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_lora_embedding # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_weight_loader # noqa
|
||||
|
||||
# TODO: revert me when triton import is fixed
|
||||
# import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||
|
||||
202
vllm_ascend/patch/worker/patch_common/patch_attention_layer.py
Normal file
202
vllm_ascend/patch/worker/patch_common/patch_attention_layer.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from torch import nn
|
||||
from vllm.attention import Attention, AttentionType, get_attn_backend
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import backend_name_to_enum
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
class AscendAttention(Attention, nn.Module, AttentionLayerBase):
|
||||
"""Attention layer.
|
||||
|
||||
This class takes query, key, and value tensors as input. The input tensors
|
||||
can either contain prompt tokens or generation tokens.
|
||||
The class does the following:
|
||||
|
||||
1. Store the input key and value tensors in the KV cache.
|
||||
2. Perform (multi-head/multi-query/grouped-query) attention.
|
||||
3. Return the output tensor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
attn_backend: Optional[type[AttentionBackend]] = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
The KV cache is stored inside this class and is accessed via
|
||||
`self.kv_cache`.
|
||||
"""
|
||||
nn.Module.__init__(self)
|
||||
AttentionLayerBase.__init__(self)
|
||||
|
||||
if per_layer_sliding_window is not None:
|
||||
# per-layer sliding window
|
||||
sliding_window = per_layer_sliding_window
|
||||
elif cache_config is not None:
|
||||
# model-level sliding window
|
||||
sliding_window = cache_config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
is_attention_free = cache_config.is_attention_free
|
||||
calculate_kv_scales = cache_config.calculate_kv_scales
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
is_attention_free = False
|
||||
calculate_kv_scales = False
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
assert num_heads % num_kv_heads == 0, \
|
||||
f"num_heads ({num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({num_kv_heads})"
|
||||
|
||||
# The default k/v_scale is set to 1.0. This is ignored
|
||||
# when kv-cache is not fp8, and should be used with
|
||||
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
|
||||
# expect the pre-quantized k/v_scale to be loaded along
|
||||
# with the model weights.
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
# FlashAttn doesn't support quantizing the kv-cache only
|
||||
# but requires q to be quantized as well.
|
||||
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
|
||||
# We also keep q/k/v_scale on host (cpu) memory for attention
|
||||
# backends that require the scales to be on host instead of on device.
|
||||
# e.g. Flashinfer
|
||||
self._q_scale_float = 1.0
|
||||
self._k_scale_float = 1.0
|
||||
self._v_scale_float = 1.0
|
||||
|
||||
# The output scale on host memory. This should be the input scale of
|
||||
# the quant op after this attention layer.
|
||||
self._o_scale_float: Optional[float] = None
|
||||
|
||||
self.use_mla = use_mla
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.has_sink = extra_impl_args.get("sinks") is not None
|
||||
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self, prefix=prefix) if quant_config else None
|
||||
if quant_method is not None and not isinstance(
|
||||
quant_method, UnquantizedLinearMethod):
|
||||
assert isinstance(quant_method, BaseKVCacheMethod)
|
||||
# TODO (mgoin): kv cache dtype should be specified in the FP8
|
||||
# checkpoint config and become the "auto" behavior
|
||||
if self.kv_cache_dtype == "fp8_e5m2":
|
||||
raise ValueError("fp8_e5m2 kv-cache is not supported with "
|
||||
"fp8 checkpoints.")
|
||||
# If quantization is enabled, we make "k_scale" and "v_scale"
|
||||
# parameters so that it can be loaded from the model checkpoint.
|
||||
# The k/v_scale will then be converted back to native float32
|
||||
# values after weight loading.
|
||||
self.quant_method = quant_method
|
||||
self.quant_method.create_weights(self)
|
||||
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
if attn_backend is None:
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=self.has_sink)
|
||||
else:
|
||||
self.attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=self.has_sink)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
impl_cls = self.attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **extra_impl_args)
|
||||
self.backend = backend_name_to_enum(self.attn_backend.get_name())
|
||||
self.dtype = dtype
|
||||
|
||||
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||
# torch.compile works by registering the attention as one giant
|
||||
# opaque custom op. For other platforms, we directly call them
|
||||
# and let torch.compile handle them.
|
||||
self.use_direct_call = not current_platform.opaque_attention_op()
|
||||
|
||||
self.use_output = self.attn_backend.accept_output_buffer
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.layer_name = prefix
|
||||
self.attn_type = attn_type
|
||||
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
validate_kv_sharing_target(
|
||||
prefix,
|
||||
kv_sharing_target_layer_name,
|
||||
compilation_config.static_forward_context,
|
||||
)
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
# use a placeholder kv cache tensor during init, which will be replaced
|
||||
# by bind_kv_cache
|
||||
# this variable will not be accessed if use_direct_call is True
|
||||
self.kv_cache = [
|
||||
torch.tensor([]) for _ in range(get_current_vllm_config(
|
||||
).parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.query_quant = None
|
||||
|
||||
|
||||
vllm.attention.Attention = AscendAttention
|
||||
@@ -0,0 +1,181 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# mypy: ignore-errors
|
||||
from functools import cache
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import (backend_name_to_enum,
|
||||
get_global_forced_attn_backend)
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.10.2"):
|
||||
|
||||
def get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool = False,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
# value to be returned from the cache if the value changes between calls.
|
||||
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
|
||||
# private function.
|
||||
return _cached_get_attn_backend(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
is_attention_free=is_attention_free,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=has_sink,
|
||||
)
|
||||
|
||||
@cache
|
||||
def _cached_get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
# If there are no attention layers (e.g. we are running Mamba),
|
||||
# use the placeholder NO_ATTENTION
|
||||
if is_attention_free:
|
||||
from vllm.attention.backends.placeholder_attn import \
|
||||
PlaceholderAttentionBackend
|
||||
return PlaceholderAttentionBackend
|
||||
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
#
|
||||
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||
# ENVIRONMENT VARIABLE.
|
||||
selected_backend = None
|
||||
backend_by_global_setting: Optional[_Backend] = (
|
||||
get_global_forced_attn_backend())
|
||||
if backend_by_global_setting is not None:
|
||||
selected_backend = backend_by_global_setting
|
||||
else:
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. "
|
||||
f"Valid backends are: {list(_Backend.__members__.keys())}"
|
||||
)
|
||||
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size,
|
||||
use_v1, use_mla, use_sfa, has_sink)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}"
|
||||
)
|
||||
return resolve_obj_by_qualname(attention_cls)
|
||||
else:
|
||||
|
||||
def get_attn_backend( # type: ignore[misc]
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
# value to be returned from the cache if the value changes between calls.
|
||||
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
|
||||
# private function.
|
||||
return _cached_get_attn_backend(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=has_sink,
|
||||
)
|
||||
|
||||
@cache
|
||||
def _cached_get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
#
|
||||
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||
# ENVIRONMENT VARIABLE.
|
||||
selected_backend = None
|
||||
backend_by_global_setting: Optional[_Backend] = (
|
||||
get_global_forced_attn_backend())
|
||||
if backend_by_global_setting is not None:
|
||||
selected_backend = backend_by_global_setting
|
||||
else:
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. "
|
||||
f"Valid backends are: {list(_Backend.__members__.keys())}"
|
||||
)
|
||||
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size,
|
||||
use_v1, use_mla, use_sfa, has_sink)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}"
|
||||
)
|
||||
return resolve_obj_by_qualname(attention_cls)
|
||||
|
||||
|
||||
vllm.attention.get_attn_backend = get_attn_backend
|
||||
vllm.attention.selector._cached_get_attn_backend = _cached_get_attn_backend
|
||||
110
vllm_ascend/patch/worker/patch_common/patch_attentionspec.py
Normal file
110
vllm_ascend/patch/worker/patch_common/patch_attentionspec.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
from typing_extensions import Self
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import cdiv, get_dtype_size
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (FullAttentionManager,
|
||||
spec_manager_map)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AttentionSpec(KVCacheSpec):
|
||||
num_kv_heads: int
|
||||
head_size: int
|
||||
dtype: torch.dtype
|
||||
use_mla: bool
|
||||
use_sfa: bool
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
# For MLA we only store a single latent vector
|
||||
coef = 1 if self.use_mla else 2
|
||||
sfa_bytes = 128 * self.block_size * get_dtype_size(
|
||||
self.dtype) if self.use_sfa else 0
|
||||
|
||||
return coef * self.block_size * self.num_kv_heads * self.head_size \
|
||||
* get_dtype_size(self.dtype) + sfa_bytes
|
||||
|
||||
|
||||
vllm.v1.kv_cache_interface.AttentionSpec = AttentionSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AscendFullAttentionSpec(FullAttentionSpec, AttentionSpec):
|
||||
sliding_window: Optional[int] = None
|
||||
attention_chunk_size: Optional[int] = None
|
||||
"""
|
||||
When hybrid allocator is disabled and the model contains both full
|
||||
attention layers and sliding window attention layers, sliding
|
||||
window attention are regarded as full attention in KV cache manager
|
||||
(blocks are allocated for all tokens), while computed as sliding window
|
||||
attention in model runner.
|
||||
In this case, we use FullAttentionSpec and record the sliding window size.
|
||||
Default to None for not using sliding window attention.
|
||||
"""
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
dcp_world_size = \
|
||||
vllm_config.parallel_config.decode_context_parallel_size
|
||||
# Note(hc): each dcp rank only need save
|
||||
# (max_model_len//dcp_world_size) tokens locally.
|
||||
if dcp_world_size > 1:
|
||||
max_model_len = cdiv(max_model_len, dcp_world_size)
|
||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
@classmethod
|
||||
def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
|
||||
if len(window_sizes) == 0:
|
||||
return None
|
||||
elif len(window_sizes) == 1:
|
||||
return window_sizes.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
"All attention layers in the same KV cache group must have the "
|
||||
"same window size.")
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
"""
|
||||
Merge a list of FullAttentionSpec objects into a single
|
||||
FullAttentionSpec object.
|
||||
"""
|
||||
assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be "
|
||||
"FullAttentionSpec.")
|
||||
|
||||
sliding_window = set(spec.sliding_window for spec in specs
|
||||
if spec.sliding_window is not None)
|
||||
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
|
||||
if spec.attention_chunk_size is not None)
|
||||
merged_spec = cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
dtype=specs[0].dtype,
|
||||
use_mla=specs[0].use_mla,
|
||||
use_sfa=specs[0].use_sfa,
|
||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||
)
|
||||
for spec in specs:
|
||||
for f in fields(AttentionSpec):
|
||||
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
|
||||
"All attention layers in the same KV cache group must have "
|
||||
"the same attention spec.")
|
||||
assert (
|
||||
(merged_spec.sliding_window is not None) +
|
||||
(merged_spec.attention_chunk_size is not None) <= 1
|
||||
), ("Model with both sliding window layers and chunked local attention "
|
||||
"layers is not supported.")
|
||||
return merged_spec
|
||||
|
||||
|
||||
spec_manager_map.update({AscendFullAttentionSpec: FullAttentionManager})
|
||||
|
||||
vllm.v1.kv_cache_interface.FullAttentionSpec = AscendFullAttentionSpec
|
||||
@@ -1,147 +0,0 @@
|
||||
"""
|
||||
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
This file is a part of the vllm-ascend project.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
split_tensor_along_last_dim)
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
_HCOMM_INFO = None
|
||||
|
||||
|
||||
class AscendRowParallelLinear(RowParallelLinear):
|
||||
"""
|
||||
AscendRowParallelLinear is a custom implementation of RowParallelLinear
|
||||
that overrides the forward method to handle Ascend-specific operations.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the AscendRowParallelLinear layer.
|
||||
|
||||
Args:
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
"""
|
||||
tp_group = get_tp_group().device_group
|
||||
hcomm_info = self.get_hcomm_info(tp_group)
|
||||
self.hcomm_info = hcomm_info
|
||||
super().__init__(*args, **kwargs)
|
||||
self.weight_t = self.weight.t()
|
||||
|
||||
@staticmethod
|
||||
def get_hcomm_info(group: ProcessGroup) -> str:
|
||||
"""Get the HCCL communication information for the given group.
|
||||
|
||||
Args:
|
||||
group (ProcessGroup): The process group for which to get the HCCL communication info.
|
||||
|
||||
Returns:
|
||||
str: The HCCL communication name for the given group.
|
||||
"""
|
||||
global _HCOMM_INFO
|
||||
if _HCOMM_INFO is not None:
|
||||
return _HCOMM_INFO
|
||||
|
||||
rank = torch.distributed.get_rank(group)
|
||||
if torch.__version__ > "2.0":
|
||||
global_rank = torch.distributed.get_global_rank(group, rank)
|
||||
_HCOMM_INFO = group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(global_rank)
|
||||
|
||||
else:
|
||||
_HCOMM_INFO = group.get_hccl_comm_name(rank)
|
||||
return _HCOMM_INFO
|
||||
|
||||
def forward(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Forward pass for the AscendRowParallelLinear layer.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): the input tensor to the layer.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
The output tensor after applying the linear transformation,
|
||||
and optionally the bias if `return_bias` is True.
|
||||
"""
|
||||
input_parallel = self.calc_input(input_)
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
output = self.calc_output(input_parallel)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the input tensor for parallel processing.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): the input tensor to be processed.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The input tensor split along the last dimension
|
||||
for tensor model parallelism, or the original input if not parallel.
|
||||
"""
|
||||
if self.input_is_parallel:
|
||||
return input_
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
return splitted_input[tp_rank].contiguous()
|
||||
|
||||
def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the output tensor of forward by considering
|
||||
fusing communication and computation.
|
||||
|
||||
Args:
|
||||
input_parallel (_type_): the input tensor to be processed in parallel.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the output tensor after applying the linear transformation
|
||||
and optionally handle communication between tensor model parallel ranks.
|
||||
"""
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
|
||||
self.weight_t,
|
||||
self.hcomm_info,
|
||||
bias=bias_)
|
||||
else:
|
||||
output = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||
return output
|
||||
|
||||
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
|
||||
logger.info("AscendRowParallelLinear: Matmul all-reduce is enabled. ")
|
||||
vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
|
||||
@@ -1,29 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import vllm
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
|
||||
from vllm.lora.utils import _all_lora_classes
|
||||
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import \
|
||||
AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
# Patch for lora register_model issue after overriding VocabParallelEmbedding class (#2515)
|
||||
_all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes = _all_lora_classes
|
||||
16
vllm_ascend/patch/worker/patch_common/patch_triton.py
Normal file
16
vllm_ascend/patch/worker/patch_common/patch_triton.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import vllm.model_executor.layers.fla.ops.chunk
|
||||
import vllm.model_executor.layers.fla.ops.fused_recurrent
|
||||
import vllm.model_executor.layers.fla.ops.layernorm_guard
|
||||
import vllm.model_executor.layers.mamba.ops.causal_conv1d
|
||||
|
||||
from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
|
||||
causal_conv1d_update_npu)
|
||||
from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule
|
||||
from vllm_ascend.ops.sigmoid_gating import \
|
||||
fused_recurrent_gated_delta_rule_fwd_kernel
|
||||
|
||||
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
|
||||
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn
|
||||
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
|
||||
vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn
|
||||
vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule
|
||||
44
vllm_ascend/patch/worker/patch_common/patch_weight_loader.py
Normal file
44
vllm_ascend/patch/worker/patch_common/patch_weight_loader.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import GiB_bytes
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
# This method creates unquantized linear weights.
|
||||
# The weights are not quantized, and they are not sharded.
|
||||
# The amount of memory allocated for the weights is
|
||||
# sum(output_partition_sizes) * input_size_per_partition.
|
||||
try:
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
logger.error("Failed to create unquantized linear weights: %s", e)
|
||||
if torch.cuda.is_available():
|
||||
logger.debug("CUDA device: %s", torch.cuda.current_device())
|
||||
logger.debug("Allocated: %.2f GiB",
|
||||
torch.cuda.memory_allocated() / GiB_bytes)
|
||||
logger.debug("Reserved: %.2f GiB",
|
||||
torch.cuda.memory_reserved() / GiB_bytes)
|
||||
raise RuntimeError(
|
||||
"Failed to create unquantized linear weights. "
|
||||
"This may be caused by insufficient memory to allocate "
|
||||
"the weight.") from e
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
|
||||
if not vllm_version_is("0.10.2"):
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
UnquantizedLinearMethod.create_weights = create_weights
|
||||
Reference in New Issue
Block a user