diff --git a/.github/workflows/format_pr_body.yaml b/.github/workflows/format_pr_body.yaml index a95dcc6f..45f0fc95 100644 --- a/.github/workflows/format_pr_body.yaml +++ b/.github/workflows/format_pr_body.yaml @@ -36,7 +36,7 @@ jobs: - name: Get vLLM version run: | - VLLM_COMMIT=83f478bb19489b41e9d208b47b4bb5a95ac171ac + VLLM_COMMIT=2918c1b49c88c29783c86f78d2c4221cb9622379 echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> $GITHUB_ENV - name: Checkout repository diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 619e8715..ecfa83a6 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -42,7 +42,7 @@ jobs: lint: uses: ./.github/workflows/pre-commit.yml with: - vllm: 83f478bb19489b41e9d208b47b4bb5a95ac171ac + vllm: 2918c1b49c88c29783c86f78d2c4221cb9622379 changes: runs-on: ubuntu-latest outputs: @@ -83,7 +83,7 @@ jobs: VLLM_USE_MODELSCOPE: True strategy: matrix: - vllm_version: [83f478bb19489b41e9d208b47b4bb5a95ac171ac, v0.11.0] + vllm_version: [2918c1b49c88c29783c86f78d2c4221cb9622379, v0.11.0] steps: - name: Install packages run: | @@ -138,7 +138,7 @@ jobs: name: e2e-light strategy: matrix: - vllm_version: [83f478bb19489b41e9d208b47b4bb5a95ac171ac, v0.11.0] + vllm_version: [2918c1b49c88c29783c86f78d2c4221cb9622379, v0.11.0] # Note (yikun): If CI resource are limited we can split job into two chain jobs needs: [lint, changes] # only trigger e2e test after lint passed and the change is e2e related with pull request. diff --git a/.github/workflows/vllm_ascend_test_full.yaml b/.github/workflows/vllm_ascend_test_full.yaml index e16b7619..ec5fb344 100644 --- a/.github/workflows/vllm_ascend_test_full.yaml +++ b/.github/workflows/vllm_ascend_test_full.yaml @@ -69,7 +69,7 @@ jobs: name: e2e-full strategy: matrix: - vllm_version: [83f478bb19489b41e9d208b47b4bb5a95ac171ac, v0.11.0] + vllm_version: [2918c1b49c88c29783c86f78d2c4221cb9622379, v0.11.0] needs: [changes] if: ${{ needs.changes.outputs.e2e_tracker == 'true' }} uses: ./.github/workflows/_e2e_test.yaml diff --git a/docs/source/community/versioning_policy.md b/docs/source/community/versioning_policy.md index 4aa79832..f4ee66df 100644 --- a/docs/source/community/versioning_policy.md +++ b/docs/source/community/versioning_policy.md @@ -42,7 +42,7 @@ The table below is the release compatibility matrix for vLLM Ascend release. For main branch of vLLM Ascend, we usually make it compatible with the latest vLLM release and a newer commit hash of vLLM. Please note that this table is usually updated. Please check it regularly. 
| vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu | |-------------|--------------|------------------|-------------|--------------------| -| main | v0.11.0/83f478bb19489b41e9d208b47b4bb5a95ac171ac | >= 3.10, < 3.12 | 8.3.RC1 | 2.7.1 / 2.7.1 | +| main | v0.11.0/2918c1b49c88c29783c86f78d2c4221cb9622379 | >= 3.10, < 3.12 | 8.3.RC1 | 2.7.1 / 2.7.1 | ## Release cadence diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index c3c5d462..bd3192b5 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -8,6 +8,9 @@ from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from tests.ut.base import TestBase from vllm_ascend.utils import vllm_version_is +init_cached_hf_modules_path = "vllm.utils.init_cached_hf_modules" if vllm_version_is( + "0.11.0") else "vllm.utils.import_utils.init_cached_hf_modules" + class TestNPUWorker(TestBase): @@ -53,7 +56,7 @@ class TestNPUWorker(TestBase): @patch("vllm_ascend.worker.worker_v1.init_ascend_config") @patch("vllm_ascend.worker.worker_v1.init_ascend_soc_version") @patch("vllm_ascend.worker.worker_v1.try_register_lib") - @patch("vllm.utils.init_cached_hf_modules") + @patch(init_cached_hf_modules_path) @patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler") def test_init_npu_worker_normal_case( self, @@ -115,7 +118,7 @@ class TestNPUWorker(TestBase): @patch("vllm_ascend.worker.worker_v1.init_ascend_config") @patch("vllm_ascend.worker.worker_v1.init_ascend_soc_version") @patch("vllm_ascend.worker.worker_v1.try_register_lib") - @patch("vllm.utils.init_cached_hf_modules") + @patch(init_cached_hf_modules_path) @patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler") def test_init_npu_worker_with_trust_remote_code( self, @@ -160,7 +163,7 @@ class TestNPUWorker(TestBase): @patch("vllm_ascend.worker.worker_v1.init_ascend_config") @patch("vllm_ascend.worker.worker_v1.init_ascend_soc_version") @patch("vllm_ascend.worker.worker_v1.try_register_lib") - @patch("vllm.utils.init_cached_hf_modules") + @patch(init_cached_hf_modules_path) @patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler") def test_init_npu_worker_with_custom_cache_dtype( self, diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index a0cc20f0..d18f5a3b 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -31,7 +31,14 @@ from vllm.distributed import (get_dcp_group, get_decode_context_model_parallel_rank, get_decode_context_model_parallel_world_size) from vllm.forward_context import ForwardContext, get_forward_context -from vllm.utils import cdiv + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv +else: + from vllm.utils.math_utils import cdiv + from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 2fa60ca8..adb19cce 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -22,7 +22,14 @@ from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) -from vllm.utils import cdiv, round_down + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv, 
round_down +else: + from vllm.utils.math_utils import cdiv, round_down + from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm_ascend import envs diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index d77605d9..5f02567f 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -22,7 +22,14 @@ from vllm.config import VllmConfig from vllm.distributed.kv_events import KVEventBatch from vllm.logger import logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.utils import cdiv + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv +else: + from vllm.utils.math_utils import cdiv + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py index 36c820b0..4525fe16 100644 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -9,7 +9,15 @@ from typing import Iterable, List, Optional, Tuple, Union import torch from vllm.distributed.kv_transfer.kv_connector.v1.base import \ KVConnectorMetadata -from vllm.utils import cdiv, logger +from vllm.utils import logger + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv +else: + from vllm.utils.math_utils import cdiv + from vllm.v1.core.sched.output import NewRequestData DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py index 9ccfa43f..6f07afdc 100644 --- a/vllm_ascend/models/qwen2_5_vl.py +++ b/vllm_ascend/models/qwen2_5_vl.py @@ -42,6 +42,7 @@ from vllm.model_executor.models.qwen2_5_vl import ( from vllm.model_executor.models.utils import maybe_prefix from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, is_enable_nz, vllm_version_is) @@ -536,7 +537,11 @@ class AscendQwen2_5_VLForConditionalGeneration( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + if vllm_version_is("0.11.0"): + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + else: + with set_ascend_forward_context(None, self.vllm_config): + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size @@ -553,7 +558,13 @@ class AscendQwen2_5_VLForConditionalGeneration( else: pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + if vllm_version_is("0.11.0"): + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw) + else: + with set_ascend_forward_context(None, self.vllm_config): + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw) # Split concatenated embeddings for each video item. 
merge_size = self.visual.spatial_merge_size diff --git a/vllm_ascend/ops/sigmoid_gating.py b/vllm_ascend/ops/sigmoid_gating.py index c99799c0..39e653a5 100644 --- a/vllm_ascend/ops/sigmoid_gating.py +++ b/vllm_ascend/ops/sigmoid_gating.py @@ -10,9 +10,7 @@ # mypy: ignore-errors import os -from typing import Optional -import torch from vllm.triton_utils import tl, tldevice, triton if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': @@ -77,6 +75,147 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( IS_VARLEN: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, + IS_KDA: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_init_state_token + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t + p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t + p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t + + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t + else: + p_beta = beta + bos * HV + i_hv + HV * i_t + + if not IS_KDA: + p_g = g + bos * HV + i_hv + HV * i_t + else: + p_gk = g + (bos * HV + i_hv + HV * i_t) * K + o_k + + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t + + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + # b_h *= tl.exp(b_g) + if not IS_KDA: + b_g = tl.load(p_g).to(tl.float32) + b_h *= exp(b_g) + else: + b_gk = tl.load(p_gk).to(tl.float32) + b_h *= exp(b_gk[:, None]) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': + lambda args: args['h0'] is not None, + 'IS_VARLEN': + 
lambda args: args['cu_seqlens'] is not None, + "IS_CONTINUOUS_BATCHING": + lambda args: args['ssm_state_indices'] is not None, + "IS_SPEC_DECODING": + lambda args: args['num_accepted_tokens'] is not None, +}) +@triton.jit(do_not_specialize=['N', 'T']) +def fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.constexpr, # num of sequences + T: tl.constexpr, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl. + constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_hv = i_nh // HV, i_nh % HV @@ -159,226 +298,3 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_ht = ht + (bos + i_t) * stride_final_state_token p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) - - -def fused_recurrent_gated_delta_rule_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - use_qk_l2norm_in_kernel: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - B, T, H, K, V = *k.shape, v.shape[-1] - HV = v.shape[2] - N = B if cu_seqlens is None else len(cu_seqlens) - 1 - BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) - NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) - assert NK == 1, "NK > 1 is not supported yet" - num_stages = 3 - num_warps = 1 - - o = q.new_empty(NK, *v.shape) - if inplace_final_state: - final_state = initial_state - else: - final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) - - stride_init_state_token = initial_state.stride(0) - stride_final_state_token = final_state.stride(0) - - if ssm_state_indices is None: - stride_indices_seq, stride_indices_tok = 1, 1 - elif ssm_state_indices.ndim == 1: - stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 - else: - stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() - - # print("N: ", N) - # print("T: ", T) - # print("B: ", B) - # print("H: ", H) - # print("HV: ", HV) - # print("K: ", K) - # print("V: ", V) - # print("BK: ", BK) - # print("BV: ", BV) - - grid = (NK, NV, N * HV) - fused_recurrent_gated_delta_rule_fwd_kernel[grid]( - q=q, - k=k, - v=v, - g=g, - beta=beta, - o=o, - h0=initial_state, - ht=final_state, - cu_seqlens=cu_seqlens, - ssm_state_indices=ssm_state_indices, - num_accepted_tokens=num_accepted_tokens, - scale=scale, - N=N, - T=T, - B=B, - H=H, - HV=HV, - K=K, - V=V, - BK=BK, - BV=BV, - stride_init_state_token=stride_init_state_token, - stride_final_state_token=stride_final_state_token, - stride_indices_seq=stride_indices_seq, - stride_indices_tok=stride_indices_tok, 
- IS_BETA_HEADWISE=beta.ndim == v.ndim, - USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, - INPLACE_FINAL_STATE=inplace_final_state, - num_warps=num_warps, - num_stages=num_stages, - ) - o = o.squeeze(0) - return o, final_state - - -class FusedRecurrentFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - use_qk_l2norm_in_kernel: bool = False): - o, final_state = fused_recurrent_gated_delta_rule_fwd( - q=q.contiguous(), - k=k.contiguous(), - v=v.contiguous(), - g=g.contiguous(), - beta=beta.contiguous(), - scale=scale, - initial_state=initial_state, - inplace_final_state=inplace_final_state, - cu_seqlens=cu_seqlens, - ssm_state_indices=ssm_state_indices, - num_accepted_tokens=num_accepted_tokens, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - ) - - return o, final_state - - -def fused_recurrent_gated_delta_rule( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor = None, - scale: float = None, - initial_state: torch.Tensor = None, - inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - use_qk_l2norm_in_kernel: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - r""" - Args: - q (torch.Tensor): - queries of shape `[B, T, H, K]`. - k (torch.Tensor): - keys of shape `[B, T, H, K]`. - v (torch.Tensor): - values of shape `[B, T, HV, V]`. - GVA is applied if `HV > H`. - g (torch.Tensor): - g (decays) of shape `[B, T, HV]`. - beta (torch.Tensor): - betas of shape `[B, T, HV]`. - scale (Optional[int]): - Scale factor for the RetNet attention scores. - If not provided, it will default to `1 / sqrt(K)`. Default: `None`. - initial_state (Optional[torch.Tensor]): - Initial state of shape `[N, HV, K, V]` for `N` input sequences. - For equal-length input sequences, `N` equals the batch size `B`. - Default: `None`. - inplace_final_state: bool: - Whether to store the final state in-place to save memory. - Default: `True`. - cu_seqlens (torch.LongTensor): - Cumulative sequence lengths of shape `[N+1]` used for variable-length training, - consistent with the FlashAttention API. - ssm_state_indices (Optional[torch.Tensor]): - Indices to map the input sequences to the initial/final states. - num_accepted_tokens (Optional[torch.Tensor]): - Number of accepted tokens for each sequence during decoding. - Returns: - o (torch.Tensor): - Outputs of shape `[B, T, HV, V]`. - final_state (torch.Tensor): - Final state of shape `[N, HV, K, V]`. 
- Examples:: - >>> import torch - >>> import torch.nn.functional as F - >>> from einops import rearrange - >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule - # inputs with equal lengths - >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 - >>> q = torch.randn(B, T, H, K, device='cuda') - >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) - >>> v = torch.randn(B, T, HV, V, device='cuda') - >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) - >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() - >>> h0 = torch.randn(B, HV, K, V, device='cuda') - >>> o, ht = fused_gated_recurrent_delta_rule( - q, k, v, g, beta, - initial_state=h0, - ) - # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required - >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) - # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected - >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = fused_gated_recurrent_delta_rule( - q, k, v, g, beta, - initial_state=h0, - cu_seqlens=cu_seqlens - ) - """ - if cu_seqlens is not None and q.shape[0] != 1: - raise ValueError( - f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") - if scale is None: - scale = k.shape[-1]**-0.5 - else: - assert scale > 0, "scale must be positive" - if beta is None: - beta = torch.ones_like(q[..., 0]) - o, final_state = FusedRecurrentFunction.apply( - q, - k, - v, - g, - beta, - scale, - initial_state, - inplace_final_state, - cu_seqlens, - ssm_state_indices, - num_accepted_tokens, - use_qk_l2norm_in_kernel, - ) - return o, final_state \ No newline at end of file diff --git a/vllm_ascend/patch/platform/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_mamba_config.py index 1b077b41..1c35106e 100644 --- a/vllm_ascend/patch/platform/patch_mamba_config.py +++ b/vllm_ascend/patch/platform/patch_mamba_config.py @@ -3,7 +3,14 @@ import vllm.model_executor.models.config from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.config import MambaModelConfig -from vllm.utils import cdiv + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv +else: + from vllm.utils.math_utils import cdiv + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec from vllm_ascend.utils import vllm_version_is diff --git a/vllm_ascend/patch/worker/patch_triton.py b/vllm_ascend/patch/worker/patch_triton.py index cc550ccc..0383da9e 100644 --- a/vllm_ascend/patch/worker/patch_triton.py +++ b/vllm_ascend/patch/worker/patch_triton.py @@ -6,11 +6,16 @@ import vllm.model_executor.layers.mamba.ops.causal_conv1d from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn, causal_conv1d_update_npu) from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule -from vllm_ascend.ops.sigmoid_gating import \ - fused_recurrent_gated_delta_rule_fwd_kernel +from vllm_ascend.ops.sigmoid_gating import ( + fused_recurrent_gated_delta_rule_fwd_kernel, + fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0) +from vllm_ascend.utils import vllm_version_is vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn 
-vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel +if vllm_version_is('0.11.0'): + vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0 +else: + vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 362c6148..627411fe 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -15,7 +15,14 @@ from vllm.model_executor.model_loader.utils import \ process_weights_after_loading from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.utils import cdiv + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv +else: + from vllm.utils.math_utils import cdiv + from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index 3faf28f8..f67a0ff0 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -670,6 +670,8 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention): if self.q_lora_rank is not None else None, q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + q_b_proj=self.q_b_proj + if self.q_lora_rank is not None else None, kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, kv_a_layernorm=self.kv_a_layernorm, kv_b_proj=self.kv_b_proj, diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index 730adbda..a524a3bb 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -26,7 +26,13 @@ from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig -from vllm.utils import cdiv + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv +else: + from vllm.utils.math_utils import cdiv from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, AscendAttentionMetadataBuilder, diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 51becad9..116b124e 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -13,7 +13,13 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) -from vllm.utils import cdiv, round_down + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv, round_down +else: + from vllm.utils.math_utils import cdiv, round_down import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import 
get_ascend_config diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py index 1390aee3..12b8d07a 100644 --- a/vllm_ascend/torchair/torchair_sfa.py +++ b/vllm_ascend/torchair/torchair_sfa.py @@ -14,7 +14,13 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) -from vllm.utils import cdiv, round_down + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv, round_down +else: + from vllm.utils.math_utils import cdiv, round_down import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 6c35fcfa..c33a8afa 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -3,7 +3,13 @@ from typing import Optional, Union import numpy as np import torch from vllm.distributed import get_dcp_group -from vllm.utils import cdiv + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv +else: + from vllm.utils.math_utils import cdiv from vllm_ascend.utils import prefill_context_parallel_enable diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 729cef1b..bf013c28 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -72,7 +72,15 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds +from vllm.utils import length_from_prompt_token_ids_or_embeds + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import cdiv +else: + from vllm.utils.math_utils import cdiv + from vllm.utils.jsontree import json_map_leaves from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 145f38a1..58ac27a0 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -142,7 +142,11 @@ class NPUWorker(WorkerBase): if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules + if vllm_version_is("0.11.0"): + from vllm.utils import init_cached_hf_modules + else: + from vllm.utils.import_utils import init_cached_hf_modules + init_cached_hf_modules() self.profiler = self._init_profiler()
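Note: the version-gated import of cdiv/round_down above is repeated in many modules touched by this patch. Below is a minimal sketch of how those call sites could share a single compatibility shim; it assumes only the vllm_version_is helper already used throughout this diff, and the module name vllm_ascend/compat.py is hypothetical, not part of this change.

    # vllm_ascend/compat.py (hypothetical helper, not introduced by this diff)
    # Re-export cdiv/round_down from whichever location the installed vLLM
    # provides, mirroring the per-file version checks added above.
    from vllm_ascend.utils import vllm_version_is

    if vllm_version_is("0.11.0"):
        # vLLM v0.11.0 still exposes the math helpers at the package root.
        from vllm.utils import cdiv, round_down
    else:
        # Newer vLLM commits (e.g. 2918c1b4) moved them to vllm.utils.math_utils.
        from vllm.utils.math_utils import cdiv, round_down

    __all__ = ["cdiv", "round_down"]

Call sites would then simply do `from vllm_ascend.compat import cdiv, round_down`, keeping the version check in one place when the pinned vLLM commit is bumped again.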