Upgrade to 0.11.1 newest vllm commit (#3982)

### What this PR does / why we need it?
Adapt the vllm-ascend main branch to the vLLM `releases/v0.11.1` branch.

Fix the `forward context not set` error in test_vlm.py, caused by:
https://github.com/vllm-project/vllm/pull/23207

Fix the failed import of `cdiv` and `round_down`, caused by:
https://github.com/vllm-project/vllm/pull/27188

Fix the failed import of `init_cached_hf_modules`, caused by:
https://github.com/vllm-project/vllm/pull/27567

Adapt the Triton kernel `fused_recurrent_gated_delta_rule_fwd_kernel` to the change made in:
https://github.com/vllm-project/vllm/pull/27654
- Remove unused code in sigmoid_gating.py: `class FusedRecurrentFunction`,
`fused_recurrent_gated_delta_rule`, and `fused_recurrent_gated_delta_rule_fwd`.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI 


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
This commit is contained in:
22dimensions
2025-11-12 23:01:19 +08:00
committed by GitHub
parent 3ca11d5a7c
commit c272747d13
21 changed files with 267 additions and 251 deletions

View File

@@ -36,7 +36,7 @@ jobs:
- name: Get vLLM version - name: Get vLLM version
run: | run: |
VLLM_COMMIT=83f478bb19489b41e9d208b47b4bb5a95ac171ac VLLM_COMMIT=2918c1b49c88c29783c86f78d2c4221cb9622379
echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> $GITHUB_ENV echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> $GITHUB_ENV
- name: Checkout repository - name: Checkout repository

View File

@@ -42,7 +42,7 @@ jobs:
lint: lint:
uses: ./.github/workflows/pre-commit.yml uses: ./.github/workflows/pre-commit.yml
with: with:
vllm: 83f478bb19489b41e9d208b47b4bb5a95ac171ac vllm: 2918c1b49c88c29783c86f78d2c4221cb9622379
changes: changes:
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs: outputs:
@@ -83,7 +83,7 @@ jobs:
VLLM_USE_MODELSCOPE: True VLLM_USE_MODELSCOPE: True
strategy: strategy:
matrix: matrix:
vllm_version: [83f478bb19489b41e9d208b47b4bb5a95ac171ac, v0.11.0] vllm_version: [2918c1b49c88c29783c86f78d2c4221cb9622379, v0.11.0]
steps: steps:
- name: Install packages - name: Install packages
run: | run: |
@@ -138,7 +138,7 @@ jobs:
name: e2e-light name: e2e-light
strategy: strategy:
matrix: matrix:
vllm_version: [83f478bb19489b41e9d208b47b4bb5a95ac171ac, v0.11.0] vllm_version: [2918c1b49c88c29783c86f78d2c4221cb9622379, v0.11.0]
# Note (yikun): If CI resource are limited we can split job into two chain jobs # Note (yikun): If CI resource are limited we can split job into two chain jobs
needs: [lint, changes] needs: [lint, changes]
# only trigger e2e test after lint passed and the change is e2e related with pull request. # only trigger e2e test after lint passed and the change is e2e related with pull request.

View File

@@ -69,7 +69,7 @@ jobs:
name: e2e-full name: e2e-full
strategy: strategy:
matrix: matrix:
vllm_version: [83f478bb19489b41e9d208b47b4bb5a95ac171ac, v0.11.0] vllm_version: [2918c1b49c88c29783c86f78d2c4221cb9622379, v0.11.0]
needs: [changes] needs: [changes]
if: ${{ needs.changes.outputs.e2e_tracker == 'true' }} if: ${{ needs.changes.outputs.e2e_tracker == 'true' }}
uses: ./.github/workflows/_e2e_test.yaml uses: ./.github/workflows/_e2e_test.yaml

View File

@@ -42,7 +42,7 @@ The table below is the release compatibility matrix for vLLM Ascend release.
For main branch of vLLM Ascend, we usually make it compatible with the latest vLLM release and a newer commit hash of vLLM. Please note that this table is usually updated. Please check it regularly. For main branch of vLLM Ascend, we usually make it compatible with the latest vLLM release and a newer commit hash of vLLM. Please note that this table is usually updated. Please check it regularly.
| vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu | | vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu |
|-------------|--------------|------------------|-------------|--------------------| |-------------|--------------|------------------|-------------|--------------------|
| main | v0.11.0/83f478bb19489b41e9d208b47b4bb5a95ac171ac | >= 3.10, < 3.12 | 8.3.RC1 | 2.7.1 / 2.7.1 | | main | v0.11.0/2918c1b49c88c29783c86f78d2c4221cb9622379 | >= 3.10, < 3.12 | 8.3.RC1 | 2.7.1 / 2.7.1 |
## Release cadence ## Release cadence

View File

@@ -8,6 +8,9 @@ from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.utils import vllm_version_is from vllm_ascend.utils import vllm_version_is
init_cached_hf_modules_path = "vllm.utils.init_cached_hf_modules" if vllm_version_is(
"0.11.0") else "vllm.utils.import_utils.init_cached_hf_modules"
class TestNPUWorker(TestBase): class TestNPUWorker(TestBase):
@@ -53,7 +56,7 @@ class TestNPUWorker(TestBase):
@patch("vllm_ascend.worker.worker_v1.init_ascend_config") @patch("vllm_ascend.worker.worker_v1.init_ascend_config")
@patch("vllm_ascend.worker.worker_v1.init_ascend_soc_version") @patch("vllm_ascend.worker.worker_v1.init_ascend_soc_version")
@patch("vllm_ascend.worker.worker_v1.try_register_lib") @patch("vllm_ascend.worker.worker_v1.try_register_lib")
@patch("vllm.utils.init_cached_hf_modules") @patch(init_cached_hf_modules_path)
@patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler") @patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler")
def test_init_npu_worker_normal_case( def test_init_npu_worker_normal_case(
self, self,
@@ -115,7 +118,7 @@ class TestNPUWorker(TestBase):
@patch("vllm_ascend.worker.worker_v1.init_ascend_config") @patch("vllm_ascend.worker.worker_v1.init_ascend_config")
@patch("vllm_ascend.worker.worker_v1.init_ascend_soc_version") @patch("vllm_ascend.worker.worker_v1.init_ascend_soc_version")
@patch("vllm_ascend.worker.worker_v1.try_register_lib") @patch("vllm_ascend.worker.worker_v1.try_register_lib")
@patch("vllm.utils.init_cached_hf_modules") @patch(init_cached_hf_modules_path)
@patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler") @patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler")
def test_init_npu_worker_with_trust_remote_code( def test_init_npu_worker_with_trust_remote_code(
self, self,
@@ -160,7 +163,7 @@ class TestNPUWorker(TestBase):
@patch("vllm_ascend.worker.worker_v1.init_ascend_config") @patch("vllm_ascend.worker.worker_v1.init_ascend_config")
@patch("vllm_ascend.worker.worker_v1.init_ascend_soc_version") @patch("vllm_ascend.worker.worker_v1.init_ascend_soc_version")
@patch("vllm_ascend.worker.worker_v1.try_register_lib") @patch("vllm_ascend.worker.worker_v1.try_register_lib")
@patch("vllm.utils.init_cached_hf_modules") @patch(init_cached_hf_modules_path)
@patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler") @patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler")
def test_init_npu_worker_with_custom_cache_dtype( def test_init_npu_worker_with_custom_cache_dtype(
self, self,

View File

@@ -31,7 +31,14 @@ from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank, get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size) get_decode_context_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec

View File

@@ -22,7 +22,14 @@ from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger from vllm.logger import logger
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv, round_down
else:
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs from vllm_ascend import envs

View File

@@ -22,7 +22,14 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVEventBatch from vllm.distributed.kv_events import KVEventBatch
from vllm.logger import logger from vllm.logger import logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler

View File

@@ -9,7 +9,15 @@ from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \ from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata KVConnectorMetadata
from vllm.utils import cdiv, logger from vllm.utils import logger
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import NewRequestData from vllm.v1.core.sched.output import NewRequestData
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB

View File

@@ -42,6 +42,7 @@ from vllm.model_executor.models.qwen2_5_vl import (
from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, is_enable_nz, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, is_enable_nz,
vllm_version_is) vllm_version_is)
@@ -536,7 +537,11 @@ class AscendQwen2_5_VLForConditionalGeneration(
image_embeds = image_input["image_embeds"].type(self.visual.dtype) image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else: else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype) pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=grid_thw) if vllm_version_is("0.11.0"):
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
else:
with set_ascend_forward_context(None, self.vllm_config):
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
# Split concatenated embeddings for each image item. # Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size merge_size = self.visual.spatial_merge_size
@@ -553,7 +558,13 @@ class AscendQwen2_5_VLForConditionalGeneration(
else: else:
pixel_values_videos = video_input["pixel_values_videos"].type( pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype) self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) if vllm_version_is("0.11.0"):
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw)
else:
with set_ascend_forward_context(None, self.vllm_config):
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw)
# Split concatenated embeddings for each video item. # Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size merge_size = self.visual.spatial_merge_size

View File

@@ -10,9 +10,7 @@
# mypy: ignore-errors # mypy: ignore-errors
import os import os
from typing import Optional
import torch
from vllm.triton_utils import tl, tldevice, triton from vllm.triton_utils import tl, tldevice, triton
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
@@ -77,6 +75,147 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
IS_VARLEN: tl.constexpr, IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr, IS_SPEC_DECODING: tl.constexpr,
IS_KDA: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
all = T
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T
if T == 0:
# no tokens to process for this sequence
return
o_k = i_k * BK + tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_k[:, None] & mask_v[None, :]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
if IS_CONTINUOUS_BATCHING:
if IS_SPEC_DECODING:
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
i_t).to(tl.int64) * stride_init_state_token
else:
p_h0 = h0 + bos * HV * K * V
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
for i_t in range(0, T):
p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t
if IS_BETA_HEADWISE:
p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
else:
p_beta = beta + bos * HV + i_hv + HV * i_t
if not IS_KDA:
p_g = g + bos * HV + i_hv + HV * i_t
else:
p_gk = g + (bos * HV + i_hv + HV * i_t) * K + o_k
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
# [BK, BV]
# b_h *= tl.exp(b_g)
if not IS_KDA:
b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g)
else:
b_gk = tl.load(p_gk).to(tl.float32)
b_h *= exp(b_gk[:, None])
# [BV]
b_v -= tl.sum(b_h * b_k[:, None], 0)
if IS_BETA_HEADWISE:
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
else:
b_beta = tl.load(p_beta).to(tl.float32)
b_v *= b_beta
# [BK, BV]
b_h += b_k[:, None] * b_v[None, :]
# [BV]
b_o = tl.sum(b_h * b_q[:, None], 0)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
i_t).to(tl.int64) * stride_final_state_token
else:
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
@triton.heuristics({
'USE_INITIAL_STATE':
lambda args: args['h0'] is not None,
'IS_VARLEN':
lambda args: args['cu_seqlens'] is not None,
"IS_CONTINUOUS_BATCHING":
lambda args: args['ssm_state_indices'] is not None,
"IS_SPEC_DECODING":
lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
def fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0(
q,
k,
v,
g,
beta,
o,
h0,
ht,
cu_seqlens,
ssm_state_indices,
num_accepted_tokens,
scale,
N: tl.constexpr, # num of sequences
T: tl.constexpr, # num of tokens
B: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
stride_init_state_token: tl.constexpr,
stride_final_state_token: tl.constexpr,
stride_indices_seq: tl.constexpr,
stride_indices_tok: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
IS_BETA_HEADWISE: tl.
constexpr, # whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
): ):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV i_n, i_hv = i_nh // HV, i_nh % HV
@@ -159,226 +298,3 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_ht = ht + (bos + i_t) * stride_final_state_token p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    ssm_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch `fused_recurrent_gated_delta_rule_fwd_kernel` and return its outputs.

    Args:
        q: Queries; passed to the kernel unchanged.
        k: Keys; must be 4-D, its shape is unpacked as `(B, T, H, K)`.
        v: Values; `v.shape[2]` is taken as `HV` and `v.shape[-1]` as `V`.
        g: Gate/decay tensor; passed to the kernel unchanged.
        beta: Beta tensor; treated as head-wise when `beta.ndim == v.ndim`.
        scale: Scale factor forwarded to the kernel.
        initial_state: Initial recurrent state. Must not be `None` here: its
            `stride(0)` (and, when `inplace_final_state` is false, its dtype)
            is read unconditionally.
        inplace_final_state: If true, the kernel writes the final state into
            `initial_state`; otherwise a new `(T, HV, K, V)` buffer is
            allocated.
        cu_seqlens: Optional cumulative sequence lengths. When given, the
            number of sequences is `len(cu_seqlens) - 1` instead of `B`.
        ssm_state_indices: Optional 1-D or 2-D index tensor whose strides are
            forwarded to the kernel.
        num_accepted_tokens: Optional tensor forwarded to the kernel.
        use_qk_l2norm_in_kernel: Forwarded as `USE_QK_L2NORM_IN_KERNEL`.

    Returns:
        `(o, final_state)` where `o` has the same shape as `v`.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    # One output slab per K-block; NK == 1, so the leading dim is squeezed
    # away before returning.
    o = q.new_empty(NK, *v.shape)
    if inplace_final_state:
        final_state = initial_state
    else:
        final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    o = o.squeeze(0)
    return o, final_state
class FusedRecurrentFunction(torch.autograd.Function):
    """Autograd entry point for the fused recurrent gated delta rule.

    Only `forward` is defined in this class; it delegates to
    `fused_recurrent_gated_delta_rule_fwd`.
    """

    @staticmethod
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                inplace_final_state: bool = True,
                cu_seqlens: Optional[torch.LongTensor] = None,
                ssm_state_indices: Optional[torch.Tensor] = None,
                num_accepted_tokens: Optional[torch.Tensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        """Run the forward kernel on contiguous copies of the tensor inputs.

        Returns:
            The `(o, final_state)` pair produced by
            `fused_recurrent_gated_delta_rule_fwd`.
        """
        # The five per-token tensors are made contiguous before the launch;
        # the state and index tensors are forwarded as given.
        q_c, k_c, v_c, g_c, beta_c = (
            tensor.contiguous() for tensor in (q, k, v, g, beta))
        out, last_state = fused_recurrent_gated_delta_rule_fwd(
            q=q_c,
            k=k_c,
            v=v_c,
            g=g_c,
            beta=beta_c,
            scale=scale,
            initial_state=initial_state,
            inplace_final_state=inplace_final_state,
            cu_seqlens=cu_seqlens,
            ssm_state_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        return out, last_state
def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor = None,
    scale: float = None,
    initial_state: torch.Tensor = None,
    inplace_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    ssm_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Validate the inputs, fill in defaults, and run `FusedRecurrentFunction`.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, HV]`.
        beta (Optional[torch.Tensor]):
            betas of shape `[B, T, HV]`.
            If not provided, it defaults to `torch.ones_like(q[..., 0])`.
        scale (Optional[float]):
            Scale factor forwarded to the kernel. Must be positive.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        inplace_final_state (bool):
            Whether to store the final state in-place to save memory.
            Default: `True`.
        cu_seqlens (Optional[torch.LongTensor]):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        ssm_state_indices (Optional[torch.Tensor]):
            Indices to map the input sequences to the initial/final states.
        num_accepted_tokens (Optional[torch.Tensor]):
            Number of accepted tokens for each sequence during decoding.
        use_qk_l2norm_in_kernel (bool):
            Forwarded unchanged to the kernel launcher. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, HV, K, V]`.

    Raises:
        ValueError: if `cu_seqlens` is given and the batch size of `q` is not 1.
        AssertionError: if an explicit `scale` is not positive.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            cu_seqlens=cu_seqlens
        )
    """
    # Variable-length mode expects every sequence packed into a single batch row.
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
            "Please flatten variable-length inputs before processing.")
    if scale is None:
        scale = k.shape[-1]**-0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        # NOTE(review): this yields shape [B, T, H], while beta is documented
        # as [B, T, HV]; the two differ when HV > H. Confirm against callers.
        beta = torch.ones_like(q[..., 0])
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        inplace_final_state,
        cu_seqlens,
        ssm_state_indices,
        num_accepted_tokens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state

View File

@@ -3,7 +3,14 @@ import vllm.model_executor.models.config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm_ascend.utils import vllm_version_is from vllm_ascend.utils import vllm_version_is

View File

@@ -6,11 +6,16 @@ import vllm.model_executor.layers.mamba.ops.causal_conv1d
from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn, from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
causal_conv1d_update_npu) causal_conv1d_update_npu)
from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule
from vllm_ascend.ops.sigmoid_gating import \ from vllm_ascend.ops.sigmoid_gating import (
fused_recurrent_gated_delta_rule_fwd_kernel fused_recurrent_gated_delta_rule_fwd_kernel,
fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0)
from vllm_ascend.utils import vllm_version_is
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel if vllm_version_is('0.11.0'):
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0
else:
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn
vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule

View File

@@ -15,7 +15,14 @@ from vllm.model_executor.model_loader.utils import \
process_weights_after_loading process_weights_after_loading
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput

View File

@@ -670,6 +670,8 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
if self.q_lora_rank is not None else None, if self.q_lora_rank is not None else None,
q_proj=self.q_proj q_proj=self.q_proj
if self.q_lora_rank is None else self.q_b_proj, if self.q_lora_rank is None else self.q_b_proj,
q_b_proj=self.q_b_proj
if self.q_lora_rank is not None else None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm, kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj, kv_b_proj=self.kv_b_proj,

View File

@@ -26,7 +26,13 @@ from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
AttentionType) AttentionType)
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionMetadataBuilder, AscendAttentionMetadataBuilder,

View File

@@ -13,7 +13,13 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv, round_down
else:
from vllm.utils.math_utils import cdiv, round_down
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config

View File

@@ -14,7 +14,13 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv, round_down
else:
from vllm.utils.math_utils import cdiv, round_down
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config

View File

@@ -3,7 +3,13 @@ from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
from vllm.distributed import get_dcp_group from vllm.distributed import get_dcp_group
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm_ascend.utils import prefill_context_parallel_enable from vllm_ascend.utils import prefill_context_parallel_enable

View File

@@ -72,7 +72,15 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.utils.jsontree import json_map_leaves from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (

View File

@@ -142,7 +142,11 @@ class NPUWorker(WorkerBase):
if self.model_config.trust_remote_code: if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules if vllm_version_is("0.11.0"):
from vllm.utils import init_cached_hf_modules
else:
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules() init_cached_hf_modules()
self.profiler = self._init_profiler() self.profiler = self._init_profiler()