Upgrade to 0.11.1 newest vllm commit (#3982)

### What this PR does / why we need it?
adapt vllm-ascend main branch with vllm releases/v0.11.1

fix `forward context not set` in test_vlm.py caused by:
https://github.com/vllm-project/vllm/pull/23207

fix import `cdiv round` failed caused by:
https://github.com/vllm-project/vllm/pull/27188

fix import `init_cached_hf_modules` failed caused by:
https://github.com/vllm-project/vllm/pull/27567

adapt triton kernel `fused_recurrent_gated_delta_rule_fwd_kernel` caused
by: https://github.com/vllm-project/vllm/pull/27654
- remove unused code in sigmoid_gating.py
- `class FusedRecurrentFunction` , `fused_recurrent_gated_delta_rule`,
`fused_recurrent_gated_delta_rule_fwd`

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI 


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
This commit is contained in:
22dimensions
2025-11-12 23:01:19 +08:00
committed by GitHub
parent 3ca11d5a7c
commit c272747d13
21 changed files with 267 additions and 251 deletions

View File

@@ -31,7 +31,14 @@ from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec

View File

@@ -22,7 +22,14 @@ from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv, round_down
else:
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs

View File

@@ -22,7 +22,14 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVEventBatch
from vllm.logger import logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler

View File

@@ -9,7 +9,15 @@ from typing import Iterable, List, Optional, Tuple, Union
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.utils import cdiv, logger
from vllm.utils import logger
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import NewRequestData
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB

View File

@@ -42,6 +42,7 @@ from vllm.model_executor.models.qwen2_5_vl import (
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, is_enable_nz,
vllm_version_is)
@@ -536,7 +537,11 @@ class AscendQwen2_5_VLForConditionalGeneration(
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
if vllm_version_is("0.11.0"):
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
else:
with set_ascend_forward_context(None, self.vllm_config):
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
@@ -553,7 +558,13 @@ class AscendQwen2_5_VLForConditionalGeneration(
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
if vllm_version_is("0.11.0"):
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw)
else:
with set_ascend_forward_context(None, self.vllm_config):
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size

View File

@@ -10,9 +10,7 @@
# mypy: ignore-errors
import os
from typing import Optional
import torch
from vllm.triton_utils import tl, tldevice, triton
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
@@ -77,6 +75,147 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
IS_KDA: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
all = T
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T
if T == 0:
# no tokens to process for this sequence
return
o_k = i_k * BK + tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_k[:, None] & mask_v[None, :]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
if IS_CONTINUOUS_BATCHING:
if IS_SPEC_DECODING:
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
i_t).to(tl.int64) * stride_init_state_token
else:
p_h0 = h0 + bos * HV * K * V
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
for i_t in range(0, T):
p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t
if IS_BETA_HEADWISE:
p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
else:
p_beta = beta + bos * HV + i_hv + HV * i_t
if not IS_KDA:
p_g = g + bos * HV + i_hv + HV * i_t
else:
p_gk = g + (bos * HV + i_hv + HV * i_t) * K + o_k
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
# [BK, BV]
# b_h *= tl.exp(b_g)
if not IS_KDA:
b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g)
else:
b_gk = tl.load(p_gk).to(tl.float32)
b_h *= exp(b_gk[:, None])
# [BV]
b_v -= tl.sum(b_h * b_k[:, None], 0)
if IS_BETA_HEADWISE:
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
else:
b_beta = tl.load(p_beta).to(tl.float32)
b_v *= b_beta
# [BK, BV]
b_h += b_k[:, None] * b_v[None, :]
# [BV]
b_o = tl.sum(b_h * b_q[:, None], 0)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
i_t).to(tl.int64) * stride_final_state_token
else:
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
@triton.heuristics({
'USE_INITIAL_STATE':
lambda args: args['h0'] is not None,
'IS_VARLEN':
lambda args: args['cu_seqlens'] is not None,
"IS_CONTINUOUS_BATCHING":
lambda args: args['ssm_state_indices'] is not None,
"IS_SPEC_DECODING":
lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
def fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0(
q,
k,
v,
g,
beta,
o,
h0,
ht,
cu_seqlens,
ssm_state_indices,
num_accepted_tokens,
scale,
N: tl.constexpr, # num of sequences
T: tl.constexpr, # num of tokens
B: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
stride_init_state_token: tl.constexpr,
stride_final_state_token: tl.constexpr,
stride_indices_seq: tl.constexpr,
stride_indices_tok: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
IS_BETA_HEADWISE: tl.
constexpr, # whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
@@ -159,226 +298,3 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
def fused_recurrent_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 3
num_warps = 1
o = q.new_empty(NK, *v.shape)
if inplace_final_state:
final_state = initial_state
else:
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0)
if ssm_state_indices is None:
stride_indices_seq, stride_indices_tok = 1, 1
elif ssm_state_indices.ndim == 1:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
else:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
# print("N: ", N)
# print("T: ", T)
# print("B: ", B)
# print("H: ", H)
# print("HV: ", HV)
# print("K: ", K)
# print("V: ", V)
# print("BK: ", BK)
# print("BV: ", BV)
grid = (NK, NV, N * HV)
fused_recurrent_gated_delta_rule_fwd_kernel[grid](
q=q,
k=k,
v=v,
g=g,
beta=beta,
o=o,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
ssm_state_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
scale=scale,
N=N,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
stride_init_state_token=stride_init_state_token,
stride_final_state_token=stride_final_state_token,
stride_indices_seq=stride_indices_seq,
stride_indices_tok=stride_indices_tok,
IS_BETA_HEADWISE=beta.ndim == v.ndim,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
INPLACE_FINAL_STATE=inplace_final_state,
num_warps=num_warps,
num_stages=num_stages,
)
o = o.squeeze(0)
return o, final_state
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False):
o, final_state = fused_recurrent_gated_delta_rule_fwd(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state,
inplace_final_state=inplace_final_state,
cu_seqlens=cu_seqlens,
ssm_state_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
return o, final_state
def fused_recurrent_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor = None,
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, HV, V]`.
GVA is applied if `HV > H`.
g (torch.Tensor):
g (decays) of shape `[B, T, HV]`.
beta (torch.Tensor):
betas of shape `[B, T, HV]`.
scale (Optional[int]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
inplace_final_state: bool:
Whether to store the final state in-place to save memory.
Default: `True`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
ssm_state_indices (Optional[torch.Tensor]):
Indices to map the input sequences to the initial/final states.
num_accepted_tokens (Optional[torch.Tensor]):
Number of accepted tokens for each sequence during decoding.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HV, V]`.
final_state (torch.Tensor):
Final state of shape `[N, HV, K, V]`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
>>> q = torch.randn(B, T, H, K, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, HV, V, device='cuda')
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
>>> o, ht = fused_gated_recurrent_delta_rule(
q, k, v, g, beta,
initial_state=h0,
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
q, k, v, g, beta,
initial_state=h0,
cu_seqlens=cu_seqlens
)
"""
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
if scale is None:
scale = k.shape[-1]**-0.5
else:
assert scale > 0, "scale must be positive"
if beta is None:
beta = torch.ones_like(q[..., 0])
o, final_state = FusedRecurrentFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
inplace_final_state,
cu_seqlens,
ssm_state_indices,
num_accepted_tokens,
use_qk_l2norm_in_kernel,
)
return o, final_state

View File

@@ -3,7 +3,14 @@ import vllm.model_executor.models.config
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm_ascend.utils import vllm_version_is

View File

@@ -6,11 +6,16 @@ import vllm.model_executor.layers.mamba.ops.causal_conv1d
from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
causal_conv1d_update_npu)
from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule
from vllm_ascend.ops.sigmoid_gating import \
fused_recurrent_gated_delta_rule_fwd_kernel
from vllm_ascend.ops.sigmoid_gating import (
fused_recurrent_gated_delta_rule_fwd_kernel,
fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0)
from vllm_ascend.utils import vllm_version_is
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
if vllm_version_is('0.11.0'):
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0
else:
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn
vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule

View File

@@ -15,7 +15,14 @@ from vllm.model_executor.model_loader.utils import \
process_weights_after_loading
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput

View File

@@ -670,6 +670,8 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
if self.q_lora_rank is not None else None,
q_proj=self.q_proj
if self.q_lora_rank is None else self.q_b_proj,
q_b_proj=self.q_b_proj
if self.q_lora_rank is not None else None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,

View File

@@ -26,7 +26,13 @@ from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
AttentionType)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionMetadataBuilder,

View File

@@ -13,7 +13,13 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv, round_down
else:
from vllm.utils.math_utils import cdiv, round_down
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config

View File

@@ -14,7 +14,13 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv, round_down
else:
from vllm.utils.math_utils import cdiv, round_down
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config

View File

@@ -3,7 +3,13 @@ from typing import Optional, Union
import numpy as np
import torch
from vllm.distributed import get_dcp_group
from vllm.utils import cdiv
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm_ascend.utils import prefill_context_parallel_enable

View File

@@ -72,7 +72,15 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import cdiv
else:
from vllm.utils.math_utils import cdiv
from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (

View File

@@ -142,7 +142,11 @@ class NPUWorker(WorkerBase):
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
if vllm_version_is("0.11.0"):
from vllm.utils import init_cached_hf_modules
else:
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
self.profiler = self._init_profiler()