From 4fb3d5e1b2588e91d3c00a434e1eb3daae220a48 Mon Sep 17 00:00:00 2001 From: SILONG ZENG <2609716663@qq.com> Date: Fri, 6 Feb 2026 15:25:08 +0800 Subject: [PATCH] [Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #8) (#6129) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | vllm_ascend/ops/\_\_init\_\_.py | | vllm_ascend/ops/activation.py | | vllm_ascend/ops/flashcomm2_oshard_manager.py | | vllm_ascend/ops/layernorm.py | | vllm_ascend/ops/mla.py | | vllm_ascend/ops/mm_encoder_attention.py | | vllm_ascend/ops/register_custom_ops.py | | vllm_ascend/ops/vocab_parallel_embedding.py | | vllm_ascend/ops/weight_prefetch.py | | vllm_ascend/spec_decode/\_\_init\_\_.py | | vllm_ascend/spec_decode/eagle_proposer.py | | vllm_ascend/spec_decode/interface.py | | vllm_ascend/spec_decode/mtp_proposer.py | | vllm_ascend/spec_decode/ngram_proposer.py | | vllm_ascend/spec_decode/suffix_proposer.py | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com> Signed-off-by: SILONG ZENG <2609716663@qq.com> --- pyproject.toml | 12 - vllm_ascend/ops/__init__.py | 29 +- vllm_ascend/ops/activation.py | 4 +- vllm_ascend/ops/flashcomm2_oshard_manager.py | 19 +- vllm_ascend/ops/layernorm.py | 77 +-- vllm_ascend/ops/mla.py | 47 +- vllm_ascend/ops/mm_encoder_attention.py | 33 +- vllm_ascend/ops/register_custom_ops.py | 238 ++++--- vllm_ascend/ops/vocab_parallel_embedding.py | 213 +++--- vllm_ascend/ops/weight_prefetch.py | 96 ++- vllm_ascend/spec_decode/__init__.py | 5 +- vllm_ascend/spec_decode/eagle_proposer.py | 677 ++++++++----------- vllm_ascend/spec_decode/interface.py | 47 +- vllm_ascend/spec_decode/medusa_proposer.py | 87 ++- vllm_ascend/spec_decode/mtp_proposer.py | 413 ++++++----- vllm_ascend/spec_decode/ngram_proposer.py | 49 +- vllm_ascend/spec_decode/suffix_proposer.py | 49 +- 17 files changed, 948 insertions(+), 1147 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 300492c7..b44570a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,18 +52,6 @@ line-length = 120 exclude = [ "tests/**", - # (8) - "vllm_ascend/ops/__init__.py", - "vllm_ascend/ops/activation.py", - "vllm_ascend/ops/flashcomm2_oshard_manager.py", - "vllm_ascend/ops/layernorm.py", - "vllm_ascend/ops/mla.py", - "vllm_ascend/ops/mm_encoder_attention.py", - "vllm_ascend/ops/register_custom_ops.py", - "vllm_ascend/ops/vocab_parallel_embedding.py", - "vllm_ascend/ops/weight_prefetch.py", - "vllm_ascend/spec_decode/**", - # (10) "vllm_ascend/ops/*linear*.py", "vllm_ascend/worker/worker.py", diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index aadb6164..46c0ceff 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -27,8 +27,7 @@ if HAS_TRITON: import vllm_ascend.ops.vocab_parallel_embedding # noqa from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul -from vllm_ascend.ops.rotary_embedding import ( - AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) +from vllm_ascend.ops.rotary_embedding import AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding class dummyFusionOp: @@ -40,23 +39,13 @@ class dummyFusionOp: def register_dummy_fusion_op() -> None: torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm") - torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp( - name="fused_add_rms_norm") - 
torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp( - name="static_scaled_fp8_quant") - torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp( - name="dynamic_scaled_fp8_quant") - torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp( - name="dynamic_per_token_scaled_fp8_quant") - torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp( - name="rms_norm_static_fp8_quant") - torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp( - name="fused_add_rms_norm_static_fp8_quant") - torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp( - name="rms_norm_dynamic_per_token_quant") + torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm") + torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp(name="static_scaled_fp8_quant") + torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp(name="dynamic_scaled_fp8_quant") + torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(name="dynamic_per_token_scaled_fp8_quant") + torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp(name="rms_norm_static_fp8_quant") + torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(name="fused_add_rms_norm_static_fp8_quant") + torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp(name="rms_norm_dynamic_per_token_quant") -__all__ = [ - "AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding", - "AscendDeepseekScalingRotaryEmbedding" -] +__all__ = ["AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding", "AscendDeepseekScalingRotaryEmbedding"] diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py index a22513fe..ac8730af 100644 --- a/vllm_ascend/ops/activation.py +++ b/vllm_ascend/ops/activation.py @@ -17,10 +17,11 @@ import torch from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul + from vllm_ascend.utils import get_weight_prefetch_method -class AscendQuickGELU(QuickGELU): +class AscendQuickGELU(QuickGELU): def forward_oot(self, x: torch.tensor) -> torch.Tensor: import torch_npu @@ -29,7 +30,6 @@ class AscendQuickGELU(QuickGELU): class AscendSiluAndMul(SiluAndMul): - def forward_oot(self, x: torch.Tensor) -> torch.Tensor: import torch_npu diff --git a/vllm_ascend/ops/flashcomm2_oshard_manager.py b/vllm_ascend/ops/flashcomm2_oshard_manager.py index d95748e6..54f38617 100644 --- a/vllm_ascend/ops/flashcomm2_oshard_manager.py +++ b/vllm_ascend/ops/flashcomm2_oshard_manager.py @@ -1,11 +1,14 @@ -from typing import Any, Dict, Optional +from typing import Any from vllm.model_executor.models.utils import extract_layer_index from vllm_ascend.distributed.parallel_state import get_shard_weight_group from vllm_ascend.ops.layer_shard_linear import ( - is_hidden_layer, post_process_after_loading_for_shard_weight_series, - reach_layer_for_shard_weight_series, register_layer_to_shard_weight_series) + is_hidden_layer, + post_process_after_loading_for_shard_weight_series, + reach_layer_for_shard_weight_series, + register_layer_to_shard_weight_series, +) from vllm_ascend.utils import flashcomm2_enable, o_shard_enable @@ -26,7 +29,7 @@ class Flashcomm2OShardManager: """ def __init__(self): - self._shard_layers: Dict[int, Any] = {} + self._shard_layers: dict[int, Any] = {} def flashcomm2_oshard_enable(self): return flashcomm2_enable() and o_shard_enable() @@ -52,12 +55,10 @@ class Flashcomm2OShardManager: self._shard_layers[layer_idx] = layer register_layer_to_shard_weight_series( - series_name="o_proj", - 
group=get_shard_weight_group(), - layer=layer, - prefetch_step=prefetch_step) + series_name="o_proj", group=get_shard_weight_group(), layer=layer, prefetch_step=prefetch_step + ) - def get_layer(self, layer_idx: int) -> Optional[Any]: + def get_layer(self, layer_idx: int) -> Any | None: """Safely retrieves a registered layer by its index. Args: diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index fa7ef0ae..b3a503e4 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -15,56 +15,53 @@ # This file is a part of the vllm-ascend project. # -from typing import Optional, Tuple, Union import torch from torch import nn from vllm.config import get_current_vllm_config from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm, RMSNormGated -from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu -from vllm_ascend.utils import enable_custom_op -from vllm_ascend.utils import get_weight_prefetch_method +from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu +from vllm_ascend.utils import enable_custom_op, get_weight_prefetch_method + class AscendRMSNorm(RMSNorm): - def __init__( self, hidden_size: int, eps: float = 1e-6, - var_hidden_size: Optional[int] = None, + var_hidden_size: int | None = None, has_weight: bool = True, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, ) -> None: super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) vllm_config = get_current_vllm_config() self.bias = None # quantization with anti_method m4 will generate none-zero norm bias - if vllm_config.quant_config is not None and \ - any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), - requires_grad=False) + if vllm_config.quant_config is not None and any( + "norm.bias" in name for name in vllm_config.quant_config.quant_description + ): + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False) def forward_oot( self, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: import torch_npu if residual is not None: if enable_custom_op(): x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias( - x, residual, self.weight, self.bias, self.variance_epsilon) + x, residual, self.weight, self.bias, self.variance_epsilon + ) else: - x, _, residual = torch_npu.npu_add_rms_norm( - x, residual, self.weight, self.variance_epsilon) + x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight, self.variance_epsilon) if self.bias is not None: x.add_(self.bias) return x, residual - x, residual = torch_npu.npu_rms_norm(x, self.weight, - self.variance_epsilon) + x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) if self.bias is not None: x.add_(self.bias) @@ -75,42 +72,30 @@ class AscendRMSNorm(RMSNorm): class AscendGemmaRMSNorm(GemmaRMSNorm): - def forward_oot( self, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: import torch_npu - from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type if residual is not None: if enable_custom_op(): x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias( - x, residual, 1.0 + 
self.weight, None, - self.variance_epsilon) + x, residual, 1.0 + self.weight, None, self.variance_epsilon + ) else: - x, _, residual = torch_npu.npu_add_rms_norm( - x, residual, 1.0 + self.weight, self.variance_epsilon) + x, _, residual = torch_npu.npu_add_rms_norm(x, residual, 1.0 + self.weight, self.variance_epsilon) return x, residual - x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, - self.variance_epsilon) + x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon) return x + class LayerNormFn(torch.autograd.Function): @staticmethod - def forward(ctx, - x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - is_rms_norm=False): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) - """ + def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" x_shape_og = x.shape # reshape input data into 2D tensor @@ -143,16 +128,16 @@ class LayerNormFn(torch.autograd.Function): ctx.is_rms_norm = is_rms_norm return y.reshape(x_shape_og) -class AscendRMSNormGated(RMSNormGated): +class AscendRMSNormGated(RMSNormGated): def __init__( self, hidden_size, eps: float = 1e-5, - group_size: Optional[int] = None, + group_size: int | None = None, norm_before_gate: bool = False, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ): """If group_size is not None, we do GroupNorm with each group having group_size elements. group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). @@ -170,7 +155,5 @@ class AscendRMSNormGated(RMSNormGated): torch.nn.init.ones_(self.weight) def forward_oot(self, x, z=None): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) - """ - return LayerNormFn.apply(x, self.weight, self.bias, z, self.eps, self.group_size, - self.norm_before_gate, True) \ No newline at end of file + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return LayerNormFn.apply(x, self.weight, self.bias, z, self.eps, self.group_size, self.norm_before_gate, True) diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py index bf3bda6c..64d5d36a 100644 --- a/vllm_ascend/ops/mla.py +++ b/vllm_ascend/ops/mla.py @@ -19,15 +19,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional import torch from torch import nn from vllm.config import CacheConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import ForwardContext, get_forward_context -from vllm.model_executor.layers.mla import (MLAModules, - MultiHeadLatentAttentionWrapper) +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backend import AttentionMetadata # type: ignore @@ -36,20 +34,20 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import vllm_version_is if vllm_version_is("v0.15.0"): - from vllm.attention.layer import MLAAttention # type: ignore + from vllm.attention.layer import MLAAttention # type: ignore else: from vllm.model_executor.layers.attention import MLAAttention class IndexerWrapper(nn.Module): - ''' + """ A wrapper of Indexer for Deepseek v3.2. This wrapper is currently used to solve the fp8 hard code issue of vllm's deepseek_v2.py. It wraps the original Indexer, inherits its module weights (including wq_b, wk, weights_proj, k_norm) - while deletes the unused topk_indices_buffer and k_cache to save memory. + while deletes the unused topk_indices_buffer and k_cache to save memory. TODO: Will be removed once original Indexer supports different quantization methods. - ''' + """ def __init__(self, vllm_indexer: nn.Module) -> None: super().__init__() @@ -71,7 +69,6 @@ class IndexerWrapper(nn.Module): class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): - def __init__( self, hidden_size: int, @@ -80,11 +77,11 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, - q_lora_rank: Optional[int], + q_lora_rank: int | None, kv_lora_rank: int, mla_modules: MLAModules, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: nn.Module.__init__(self) @@ -97,8 +94,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): self.v_head_dim = v_head_dim self.prefix = prefix hf_config = get_current_vllm_config().model_config.hf_text_config - self.enable_shared_expert_dp = get_ascend_config( - ).enable_shared_expert_dp + self.enable_shared_expert_dp = get_ascend_config().enable_shared_expert_dp self.tp_size = get_tensor_model_parallel_world_size() self.layers = hf_config.num_hidden_layers if mla_modules.indexer is not None: @@ -134,6 +130,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): def wrapped_process_weights(act_dtype: torch.dtype): from vllm_ascend.attention.sfa_v1 import AscendSFAImpl + if not isinstance(self.mla_attn.impl, AscendSFAImpl): original_process_weights(act_dtype) self.mla_attn.impl.process_weights_after_loading(act_dtype) @@ -146,19 +143,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): compilation_config.static_forward_context[prefix] = self def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor | None = None, + 
attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: need_gather_q_kv = get_forward_context().sp_enabled output_shape = hidden_states.shape # FIXME: This does not seem right, should make sure the buffer is fixed - output = torch.empty(output_shape, - dtype=hidden_states.dtype, - device=hidden_states.device) - torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, - self.prefix) + output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) + torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, self.prefix) output = output.view(-1, output_shape[-1]) return output @@ -176,9 +171,9 @@ def mla_forward( else: attn_metadata = forward_context.attn_metadata kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine] - self.mla_attn.impl.forward(self.mla_attn.layer_name, hidden_states, - kv_cache, attn_metadata, need_gather_q_kv, - output) + self.mla_attn.impl.forward( + self.mla_attn.layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv, output + ) return diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py index aa9a0737..7beb7b50 100644 --- a/vllm_ascend/ops/mm_encoder_attention.py +++ b/vllm_ascend/ops/mm_encoder_attention.py @@ -19,18 +19,15 @@ import einops import torch import torch.nn.functional as F import torch_npu -from vllm.config import MultiModalConfig from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention # type: ignore import vllm_ascend.envs as envs_ascend - MIN_PAD_SIZE = 64 # min_size to pad weight MAX_PAD_SIZE = 128 # max_size to pad weight class AscendMMEncoderAttention(MMEncoderAttention): - def __init__( self, num_heads: int, @@ -82,13 +79,12 @@ class AscendMMEncoderAttention(MMEncoderAttention): return query, key, value def forward_oot( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: torch.Tensor - | None = None, # Only used for Flash Attention + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention ): bsz, q_len = query.size()[:2] kv_len = key.size(1) @@ -97,9 +93,7 @@ class AscendMMEncoderAttention(MMEncoderAttention): # q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim] q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len) - enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL - and self.head_size > MIN_PAD_SIZE - and self.head_size < MAX_PAD_SIZE) + enable_pad = envs_ascend.USE_OPTIMIZED_MODEL and self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE if enable_pad: origin_shape = q.shape[-1] @@ -114,10 +108,7 @@ class AscendMMEncoderAttention(MMEncoderAttention): context_layer = torch.empty_like(q) if cu_seqlens is None: - cu_seqlens = torch.arange(0, (bsz + 1) * q_len, - step=q_len, - dtype=torch.int32, - device=query.device) + cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device) cu_seqlens = torch.diff(cu_seqlens).to("cpu") @@ -137,11 +128,7 @@ class AscendMMEncoderAttention(MMEncoderAttention): context_layer = context_layer[..., :origin_shape] if is_reshaped: - context_layer = einops.rearrange(context_layer, - "(b s) h d -> b s h d", - b=bsz).contiguous() + context_layer = einops.rearrange(context_layer, "(b s) h d -> b s h d", b=bsz).contiguous() else: - context_layer = einops.rearrange(context_layer, - "(b s) h 
d -> b s (h d)", - b=bsz).contiguous() + context_layer = einops.rearrange(context_layer, "(b s) h d -> b s (h d)", b=bsz).contiguous() return context_layer diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 11d6e8ed..2027369b 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -1,24 +1,25 @@ import torch import torch.nn.functional as F import torch_npu -from vllm.distributed import (get_dp_group, get_ep_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, - tensor_model_parallel_reduce_scatter) +from vllm.distributed import ( + get_dp_group, + get_ep_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter, +) from vllm.forward_context import get_forward_context from vllm.utils.torch_utils import direct_register_custom_op -import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ops.triton.rope import rope_forward_triton from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.utils import npu_stream_switch, prefetch_stream -from typing import Optional, Tuple -from vllm_ascend.ops.triton.rope import rope_forward_triton -def _maybe_chunk_residual_impl(x: torch.Tensor, - residual: torch.Tensor) -> torch.Tensor: + +def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: try: forward_context = get_forward_context() except AssertionError: @@ -26,8 +27,7 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, if x.size(0) != residual.size(0): sp_enabled = forward_context.sp_enabled - assert sp_enabled is True, ("Currently, this situation only occurs " - "when sp is enabled") + assert sp_enabled is True, "Currently, this situation only occurs when sp is enabled" pad_size = forward_context.pad_size if pad_size > 0: residual = F.pad(residual, (0, 0, 0, pad_size)) @@ -38,10 +38,7 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, return residual -def _maybe_all_gather_and_maybe_unpad_impl( - x: torch.Tensor, - label: bool, - is_ep_comm: bool = False) -> torch.Tensor: +def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor: try: forward_context = get_forward_context() except AssertionError: @@ -59,24 +56,20 @@ def _maybe_all_gather_and_maybe_unpad_impl( x = get_ep_group().all_gather(x, 0) # unpad num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu - result = torch.empty( - (num_tokens_across_dp_cpu.sum(), *x.shape[1:]), - device=x.device, - dtype=x.dtype) + result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype) dp_size = get_dp_group().world_size x = x.view(dp_size, forward_context.padded_length, *x.shape[1:]) offset = 0 for idx in range(dp_size): num_tokens_dp = num_tokens_across_dp_cpu[idx] - result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp] + result[offset : offset + num_tokens_dp] = x[idx, :num_tokens_dp] offset += num_tokens_dp x = result return x -def _maybe_pad_and_reduce_impl(x: torch.Tensor, - is_ep_comm: bool = False) -> torch.Tensor: +def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor: try: forward_context = get_forward_context() except AssertionError: @@ -94,63 +87,44 @@ def 
_maybe_pad_and_reduce_impl(x: torch.Tensor, else: # padding dp_size = get_dp_group().world_size - num_tokens_across_dp_cpu = \ - get_forward_context().dp_metadata.num_tokens_across_dp_cpu - padded_x = torch.empty( - (dp_size, forward_context.padded_length, *x.shape[1:]), - device=x.device, - dtype=x.dtype) + num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu + padded_x = torch.empty((dp_size, forward_context.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype) offset = 0 for idx in range(dp_size): num_tokens_dp = num_tokens_across_dp_cpu[idx] - padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp] + padded_x[idx, :num_tokens_dp] = x[offset : offset + num_tokens_dp] offset += num_tokens_dp - return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]), - 0) + return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]), 0) -def _maybe_all_gather_and_maybe_unpad_fake( - x: torch.Tensor, - label: bool, - is_ep_comm: bool = False) -> torch.Tensor: - +def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor: if get_forward_context().sp_enabled and label: return torch.empty( - (x.shape[0] * get_tensor_model_parallel_world_size(), - *x.shape[1:]), - device=x.device, - dtype=x.dtype) + (x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype + ) return x -def _maybe_pad_and_reduce_fake(x: torch.Tensor, - is_ep_comm: bool = False) -> torch.Tensor: +def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor: if get_forward_context().sp_enabled: return torch.empty( - (x.shape[0] // get_tensor_model_parallel_world_size(), - *x.shape[1:]), - device=x.device, - dtype=x.dtype) + (x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype + ) return x -def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor, - max_weight_size: int) -> None: +def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor, max_weight_size: int) -> None: calculation_stream = torch_npu.npu.current_stream() weight_prefetch_stream = prefetch_stream() weight_prefetch_stream.wait_stream(calculation_stream) with npu_stream_switch(weight_prefetch_stream): - maybe_npu_prefetch(inputs=weight, - dependency=start_flag, - max_size=max_weight_size) + maybe_npu_prefetch(inputs=weight, dependency=start_flag, max_size=max_weight_size) -def _prefetch_preprocess_impl_fake(weight: torch.Tensor, - start_flag: torch.Tensor, - max_weight_size: int) -> None: +def _prefetch_preprocess_impl_fake(weight: torch.Tensor, start_flag: torch.Tensor, max_weight_size: int) -> None: return @@ -164,20 +138,16 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None: return -def _maybe_all_reduce_tensor_model_parallel_impl( - final_hidden_states: torch.Tensor) -> torch.Tensor: +def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor: forward_context = get_forward_context() moe_comm_type = forward_context.moe_comm_type - if moe_comm_type in { - MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2 - } or forward_context.sp_enabled: + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} or forward_context.sp_enabled: return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) -def _matmul_and_reduce_impl(input_parallel: torch.Tensor, - layer_name: str) -> 
torch.Tensor: +def _matmul_and_reduce_impl(input_parallel: torch.Tensor, layer_name: str) -> torch.Tensor: forward_context = get_forward_context() self = forward_context.no_compile_layers[layer_name] assert self.custom_op is not None @@ -187,16 +157,15 @@ def _matmul_and_reduce_impl(input_parallel: torch.Tensor, return output -def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, - layer_name: str) -> torch.Tensor: +def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str) -> torch.Tensor: forward_context = get_forward_context() self = forward_context.no_compile_layers[layer_name] num_tokens = input_parallel.size(0) if forward_context.sp_enabled: num_tokens = num_tokens // self.tp_size - output = torch.empty(size=(num_tokens, self.output_size_per_partition), - device=input_parallel.device, - dtype=input_parallel.dtype) + output = torch.empty( + size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype + ) return output @@ -207,77 +176,96 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, # pass input_scale and input_scale_reciprocal at the same time to avoid redundant # reciprocal calculation in fussion pass. We shall remove this once # aclnnAddRmsNormQuantV2 supports div_moe=False. -def _quantize_impl(in_tensor: torch.Tensor, input_scale: torch.Tensor, - input_scale_reciprocal: torch.Tensor, - input_offset: torch.Tensor) -> torch.Tensor: - return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, - input_offset, torch.qint8, -1, False) +def _quantize_impl( + in_tensor: torch.Tensor, input_scale: torch.Tensor, input_scale_reciprocal: torch.Tensor, input_offset: torch.Tensor +) -> torch.Tensor: + return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False) + + +def _quantize_impl_fake( + in_tensor: torch.Tensor, input_scale: torch.Tensor, input_scale_reciprocal: torch.Tensor, input_offset: torch.Tensor +) -> torch.Tensor: + return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False) -def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor, - input_scale_reciprocal: torch.Tensor, - input_offset: torch.Tensor) -> torch.Tensor: - return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, - input_offset, torch.qint8, -1, False) def _rope_forward_triton_fake( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_dim: int = -1, - is_neox_style: bool = True -) -> Tuple[torch.Tensor, torch.Tensor]: + is_neox_style: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(q), torch.empty_like(k) -direct_register_custom_op(op_name="maybe_chunk_residual", - op_func=_maybe_chunk_residual_impl, - fake_impl=lambda x, residual: x, - mutates_args=[], - dispatch_key="PrivateUse1") -direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad", - op_func=_maybe_all_gather_and_maybe_unpad_impl, - fake_impl=_maybe_all_gather_and_maybe_unpad_fake, - mutates_args=[], - dispatch_key="PrivateUse1") +direct_register_custom_op( + op_name="maybe_chunk_residual", + op_func=_maybe_chunk_residual_impl, + fake_impl=lambda x, residual: x, + mutates_args=[], + dispatch_key="PrivateUse1", +) -direct_register_custom_op(op_name="maybe_pad_and_reduce", - op_func=_maybe_pad_and_reduce_impl, - fake_impl=_maybe_pad_and_reduce_fake, - mutates_args=[], - dispatch_key="PrivateUse1") +direct_register_custom_op( + op_name="maybe_all_gather_and_maybe_unpad", + 
op_func=_maybe_all_gather_and_maybe_unpad_impl, + fake_impl=_maybe_all_gather_and_maybe_unpad_fake, + mutates_args=[], + dispatch_key="PrivateUse1", +) -direct_register_custom_op(op_name="prefetch_preprocess", - op_func=_prefetch_preprocess_impl, - fake_impl=_prefetch_preprocess_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1") +direct_register_custom_op( + op_name="maybe_pad_and_reduce", + op_func=_maybe_pad_and_reduce_impl, + fake_impl=_maybe_pad_and_reduce_fake, + mutates_args=[], + dispatch_key="PrivateUse1", +) -direct_register_custom_op(op_name="prefetch_postprocess", - op_func=_prefetch_postprocess_impl, - fake_impl=_prefetch_postprocess_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1") +direct_register_custom_op( + op_name="prefetch_preprocess", + op_func=_prefetch_preprocess_impl, + fake_impl=_prefetch_preprocess_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1", +) -direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel", - op_func=_maybe_all_reduce_tensor_model_parallel_impl, - fake_impl=lambda x: x, - mutates_args=[], - dispatch_key="PrivateUse1") +direct_register_custom_op( + op_name="prefetch_postprocess", + op_func=_prefetch_postprocess_impl, + fake_impl=_prefetch_postprocess_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1", +) -direct_register_custom_op(op_name="matmul_and_reduce", - op_func=_matmul_and_reduce_impl, - fake_impl=_matmul_and_reduce_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1") +direct_register_custom_op( + op_name="maybe_all_reduce_tensor_model_parallel", + op_func=_maybe_all_reduce_tensor_model_parallel_impl, + fake_impl=lambda x: x, + mutates_args=[], + dispatch_key="PrivateUse1", +) -direct_register_custom_op(op_name="quantize", - op_func=_quantize_impl, - fake_impl=_quantize_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1") +direct_register_custom_op( + op_name="matmul_and_reduce", + op_func=_matmul_and_reduce_impl, + fake_impl=_matmul_and_reduce_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1", +) -direct_register_custom_op(op_name="rope_forward_triton", - op_func=rope_forward_triton, - fake_impl=_rope_forward_triton_fake, - mutates_args=[], - dispatch_key="PrivateUse1") +direct_register_custom_op( + op_name="quantize", + op_func=_quantize_impl, + fake_impl=_quantize_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1", +) +direct_register_custom_op( + op_name="rope_forward_triton", + op_func=rope_forward_triton, + fake_impl=_rope_forward_triton_fake, + mutates_args=[], + dispatch_key="PrivateUse1", +) diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index 8fb11724..c23cfdc2 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -15,7 +15,6 @@ # limitations under the License.
# -from typing import Optional, Tuple import torch from torch import nn @@ -24,14 +23,20 @@ from vllm.distributed import divide from vllm.distributed.parallel_state import get_tp_group from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) + QuantizationConfig, + QuantizeMethodBase, + method_has_implemented_embedding, +) from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod, - VocabParallelEmbedding, pad_vocab_size) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + UnquantizedEmbeddingMethod, + VocabParallelEmbedding, + pad_vocab_size, +) from vllm.model_executor.utils import set_weight_attrs -from vllm_ascend.distributed.parallel_state import (get_embed_tp_group, - get_lmhead_tp_group) +from vllm_ascend.distributed.parallel_state import get_embed_tp_group, get_lmhead_tp_group from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable @@ -42,14 +47,16 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): Added the feature of lmheadTP in pure dp scenario """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + params_dtype: torch.dtype | None = None, + org_num_embeddings: int | None = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): nn.Module.__init__(self) self.forward_type = None if lmhead_tp_enable() and "head" in prefix: @@ -67,18 +74,20 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings num_added_embeddings = num_embeddings - self.org_vocab_size - self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, - self.padding_size) + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size) self.num_embeddings_padded = pad_vocab_size( - self.org_vocab_size_padded + num_added_embeddings, - self.padding_size) + self.org_vocab_size_padded + num_added_embeddings, self.padding_size + ) assert self.org_vocab_size_padded <= self.num_embeddings_padded - self.shard_indices = self._get_indices(self.num_embeddings_padded, - self.org_vocab_size_padded, - self.num_embeddings, - self.org_vocab_size, - self.tp_rank, self.tp_size) + self.shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + self.tp_rank, + self.tp_size, + ) self.embedding_dim = embedding_dim quant_method = None if quant_config is not None: @@ -90,12 +99,12 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): # method must implement the embedding operation. If we are another # layer type like ParallelLMHead, this is not important. 
is_embedding_layer = type(self) is VocabParallelEmbedding - quant_method_implements_embedding = method_has_implemented_embedding( - type(quant_method)) + quant_method_implements_embedding = method_has_implemented_embedding(type(quant_method)) if is_embedding_layer and not quant_method_implements_embedding: raise NotImplementedError( f"The class {type(quant_method).__name__} must implement " - "the 'embedding' method, see UnquantizedEmbeddingMethod.") + "the 'embedding' method, see UnquantizedEmbeddingMethod." + ) self.quant_method: QuantizeMethodBase = quant_method @@ -104,46 +113,47 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): self.params_dtype = params_dtype # Divide the weight matrix along the vocaburaly dimension. self.num_added_embeddings = self.num_embeddings - self.org_vocab_size - self.num_embeddings_per_partition = divide(self.num_embeddings_padded, - self.tp_size) - assert (self.shard_indices.num_elements_padded == - self.num_embeddings_per_partition) + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, self.tp_size) + assert self.shard_indices.num_elements_padded == self.num_embeddings_per_partition self.num_org_embeddings_per_partition = ( - self.shard_indices.org_vocab_end_index - - self.shard_indices.org_vocab_start_index) + self.shard_indices.org_vocab_end_index - self.shard_indices.org_vocab_start_index + ) self.num_added_embeddings_per_partition = ( - self.shard_indices.added_vocab_end_index - - self.shard_indices.added_vocab_start_index) + self.shard_indices.added_vocab_end_index - self.shard_indices.added_vocab_start_index + ) - self.quant_method.create_weights(self, - self.embedding_dim, - [self.num_embeddings_per_partition], - self.embedding_dim, - self.num_embeddings_padded, - params_dtype=params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights( + self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) def _get_masked_input_and_mask( - self, input_: torch.Tensor, org_vocab_start_index: int, - org_vocab_end_index: int, num_org_vocab_padding: int, - added_vocab_start_index: int, - added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: + self, + input_: torch.Tensor, + org_vocab_start_index: int, + org_vocab_end_index: int, + num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int, + ) -> tuple[torch.Tensor, torch.Tensor]: # torch.compile will fuse all of the pointwise ops below # into a single kernel, making it very fast - org_vocab_mask = (input_ >= org_vocab_start_index) & ( - input_ < org_vocab_end_index) + org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index) # Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index. 
if added_vocab_start_index == added_vocab_end_index: - valid_offset = (org_vocab_start_index * org_vocab_mask) + valid_offset = org_vocab_start_index * org_vocab_mask vocab_mask = org_vocab_mask else: - added_vocab_mask = (input_ >= added_vocab_start_index) & ( - input_ < added_vocab_end_index) - added_offset = added_vocab_start_index - ( - org_vocab_end_index - - org_vocab_start_index) - num_org_vocab_padding - valid_offset = (org_vocab_start_index * - org_vocab_mask) + (added_offset * added_vocab_mask) + added_vocab_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index) + added_offset = ( + added_vocab_start_index - (org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding + ) + valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset * added_vocab_mask) vocab_mask = org_vocab_mask | added_vocab_mask # Adapt end. input_ = vocab_mask * (input_ - valid_offset) @@ -158,14 +168,15 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): def _forward_embed_tp(self, input_): complete_input = self.comm_group.all_gather(input_, dim=0) masked_input, input_mask = self._get_masked_input_and_mask( - complete_input, self.shard_indices.org_vocab_start_index, + complete_input, + self.shard_indices.org_vocab_start_index, self.shard_indices.org_vocab_end_index, self.shard_indices.num_org_vocab_padding, self.shard_indices.added_vocab_start_index, - self.shard_indices.added_vocab_end_index) + self.shard_indices.added_vocab_end_index, + ) # Get the embeddings. - output_parallel = self.quant_method.embedding(self, - masked_input.long()) + output_parallel = self.quant_method.embedding(self, masked_input.long()) output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) output = self.comm_group.reduce_scatter(output_parallel, dim=0) output = output.view(input_.shape[0], -1) @@ -175,16 +186,17 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): if self.tp_size > 1: # Build the mask. masked_input, input_mask = self._get_masked_input_and_mask( - input_, self.shard_indices.org_vocab_start_index, + input_, + self.shard_indices.org_vocab_start_index, self.shard_indices.org_vocab_end_index, self.shard_indices.num_org_vocab_padding, self.shard_indices.added_vocab_start_index, - self.shard_indices.added_vocab_end_index) + self.shard_indices.added_vocab_end_index, + ) else: masked_input = input_ # Get the embeddings. - output_parallel = self.quant_method.embedding(self, - masked_input.long()) + output_parallel = self.quant_method.embedding(self, masked_input.long()) # Mask the output embedding. 
if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) @@ -197,29 +209,31 @@ class AscendParallelLMHead(ParallelLMHead): """ Register ParallelLMHead as a custom op for Ascend.""" - def __init__(self, - num_embeddings: int, - embedding_dim: int, - bias: bool = False, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - AscendVocabParallelEmbedding.__init__(self, num_embeddings, - embedding_dim, params_dtype, - org_num_embeddings, padding_size, - quant_config, prefix) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: torch.dtype | None = None, + org_num_embeddings: int | None = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + AscendVocabParallelEmbedding.__init__( + self, num_embeddings, embedding_dim, params_dtype, org_num_embeddings, padding_size, quant_config, prefix + ) self.quant_config = quant_config if bias: - self.bias = Parameter( - torch.empty(self.num_embeddings_per_partition, - dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) + self.bias = Parameter(torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) else: self.register_parameter("bias", None) @@ -234,48 +248,41 @@ class AscendLogitsProcessor(LogitsProcessor): self, hidden_states: torch.Tensor, lm_head: AscendParallelLMHead, - embedding_bias: Optional[torch.Tensor] = None, - ) -> Optional[torch.Tensor]: + embedding_bias: torch.Tensor | None = None, + ) -> torch.Tensor | None: if lmhead_tp_enable(): - return self._get_logits_lmheadtp(hidden_states, lm_head, - embedding_bias) + return self._get_logits_lmheadtp(hidden_states, lm_head, embedding_bias) else: - return self._get_logits_normal(hidden_states, lm_head, - embedding_bias) + return self._get_logits_normal(hidden_states, lm_head, embedding_bias) def _get_logits_lmheadtp( self, hidden_states: torch.Tensor, lm_head: AscendParallelLMHead, - embedding_bias: Optional[torch.Tensor], - ) -> Optional[torch.Tensor]: + embedding_bias: torch.Tensor | None, + ) -> torch.Tensor | None: # Gather hidden states from all devices in tensor parallel group - gathered_hidden_states = get_lmhead_tp_group().all_gather( - hidden_states, dim=0) - local_logits = lm_head.quant_method.apply(lm_head, - gathered_hidden_states, - bias=embedding_bias) + gathered_hidden_states = get_lmhead_tp_group().all_gather(hidden_states, dim=0) + local_logits = lm_head.quant_method.apply(lm_head, gathered_hidden_states, bias=embedding_bias) # Gather logits for tensor parallel logits = get_lmhead_tp_group().all_to_all(local_logits) # Remove paddings in vocab (if any) if logits is not None: - logits = logits[..., :self.org_vocab_size] + logits = logits[..., : self.org_vocab_size] return logits def _get_logits_normal( self, hidden_states: torch.Tensor, lm_head: AscendParallelLMHead, - embedding_bias: Optional[torch.Tensor], - ) -> Optional[torch.Tensor]: - local_logits = lm_head.quant_method.apply(lm_head, - hidden_states, - bias=embedding_bias) + embedding_bias: torch.Tensor | None, + ) -> torch.Tensor | None: + local_logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias) # Gather logits for 
tensor parallel logits = self._gather_logits(local_logits) # Remove paddings in vocab (if any) if logits is not None: - logits = logits[..., :self.org_vocab_size] + logits = logits[..., : self.org_vocab_size] return logits diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index e41390ee..e53e899b 100644 --- a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -2,19 +2,18 @@ from dataclasses import dataclass, field import torch import torch_npu -from vllm.forward_context import ForwardContext, get_forward_context from vllm.config import get_current_vllm_config -from vllm.logger import logger +from vllm.forward_context import ForwardContext, get_forward_context from vllm_ascend.ascend_config import WeightPrefetchConfig -from vllm_ascend.ops.linear import (AscendQKVParallelLinear, - AscendRowParallelLinear) +from vllm_ascend.ops.linear import AscendQKVParallelLinear, AscendRowParallelLinear from vllm_ascend.utils import is_moe_model SUPPORTED_MODULES = ["attn", "mlp", "moe"] MOE_PREFETCH_TOKEN_THRESHOLD = 96 MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024 + @dataclass class ModuleWeightPrefetchConfig: module_name: str @@ -24,10 +23,7 @@ class ModuleWeightPrefetchConfig: linear_prefix_map: dict = field(default_factory=dict) def __post_init__(self) -> None: - self.prefetch_ratio = { - prefix: ratio - for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1 - } + self.prefetch_ratio = {prefix: ratio for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1} assert self.module_name in SUPPORTED_MODULES, ( f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}" @@ -41,6 +37,7 @@ class WeightPrefetchMethod: """ Unified weight prefetch method. """ + is_moe: bool = True MLP_GATE_UP: str = "gate_up" MLP_DOWN: str = "down" @@ -54,60 +51,53 @@ class WeightPrefetchMethod: self.attn = ModuleWeightPrefetchConfig( module_name="attn", enable=weight_prefetch_config.enabled, - prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( - "attn", {}) or {'qkv': 1.0, 'o': 1.0}, + prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("attn", {}) or {"qkv": 1.0, "o": 1.0}, linear_prefix_map={ AscendQKVParallelLinear.__name__: "qkv", AscendRowParallelLinear.__name__: "o", - }) + }, + ) self.moe = ModuleWeightPrefetchConfig( module_name="moe", enable=weight_prefetch_config.enabled and self.is_moe, - prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( - "moe", {}) or {'gate_up': 0.8}) + prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("moe", {}) or {"gate_up": 0.8}, + ) self.mlp = ModuleWeightPrefetchConfig( module_name="mlp", enable=weight_prefetch_config.enabled and not self.is_moe, - prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( - "mlp", {}) or {'gate_up': 1.0, 'down': 1.0}) + prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("mlp", {}) or {"gate_up": 1.0, "down": 1.0}, + ) self.mlp_pre_version_compatibale_config = weight_prefetch_config.mlp_pre_version_compatibale_config def maybe_prefetch_attn_weight_preprocess( - self, layer_cls_name: str, weight: torch.Tensor, - start_flag: torch.Tensor) -> None: + self, layer_cls_name: str, weight: torch.Tensor, start_flag: torch.Tensor + ) -> None: if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map: return prefix = self.attn.linear_prefix_map.get(layer_cls_name, "") - weight_size = weight.data.element_size() * weight.data.numel( - ) * self.attn.prefetch_ratio.get(prefix, 0) + weight_size = 
weight.data.element_size() * weight.data.numel() * self.attn.prefetch_ratio.get(prefix, 0) - torch.ops.vllm.prefetch_preprocess(weight=weight, - start_flag=start_flag, - max_weight_size=int(weight_size)) + torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=start_flag, max_weight_size=int(weight_size)) - def maybe_prefetch_attn_weight_postprocess( - self, layer_cls_name: str, stop_flag: torch.Tensor) -> None: + def maybe_prefetch_attn_weight_postprocess(self, layer_cls_name: str, stop_flag: torch.Tensor) -> None: if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map: return torch.ops.vllm.prefetch_postprocess(stop_flag) def maybe_prefetch_moe_weight_preprocess(self, hidden_states, prefix): - self.moe.is_active_this_forward = hidden_states.shape[ - 0] >= MOE_PREFETCH_TOKEN_THRESHOLD if self.moe.enable else False + self.moe.is_active_this_forward = ( + hidden_states.shape[0] >= MOE_PREFETCH_TOKEN_THRESHOLD if self.moe.enable else False + ) if not self.moe.is_active_this_forward: return forward_context = get_forward_context() # layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm. - weight = forward_context.model_instance.model.layers[ - forward_context.layer_idx - 1].mlp.experts.w13_weight - weight_size = weight.data.element_size() * weight.data.numel( - ) * self.moe.prefetch_ratio.get(prefix, 0) - torch.ops.vllm.prefetch_preprocess(weight=weight, - start_flag=None, - max_weight_size=int(weight_size)) + weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight + weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0) + torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=None, max_weight_size=int(weight_size)) def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor): if not self.moe.is_active_this_forward: @@ -116,7 +106,9 @@ class WeightPrefetchMethod: torch.ops.vllm.prefetch_postprocess(stop_flag) # x_dependency only eager mode can pass None - def maybe_prefetch_mlp_weight_preprocess(self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None): + def maybe_prefetch_mlp_weight_preprocess( + self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None + ): if not self.mlp.enable and not self.mlp_pre_version_compatibale_config: self.mlp.is_active_this_forward = False return @@ -140,24 +132,26 @@ class WeightPrefetchMethod: else: raise ValueError(f"Unsupported prefetch weight name: {prefetch_layer_name}") - def _maybe_prefetch_mlp_gate_up_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str | None): + def _maybe_prefetch_mlp_gate_up_weight_preprocess( + self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str | None + ): if not curr_layer_prefix: raise ValueError("curr_layer_prefix must been specified when prefetching mlp gate_up_proj weight") # start point of gate_up_proj weight prefetch - if curr_layer_prefix.split('.')[-2] == "self_attn": + if curr_layer_prefix.split(".")[-2] == "self_attn": model_instance = forward_context.model_instance - layer_idx = int(curr_layer_prefix.split('.')[2]) + layer_idx = int(curr_layer_prefix.split(".")[2]) weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight if self.mlp_pre_version_compatibale_config: weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0) else: - 
weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0) + weight_size = ( + weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0) + ) if weight_size > MAX_PREFETCH_WEIGHT_SIZE: weight_size = MAX_PREFETCH_WEIGHT_SIZE - torch.ops.vllm.prefetch_preprocess(weight=weight, - start_flag=x_dependency, - max_weight_size=int(weight_size)) + torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size)) forward_context.prefetch_mlp_gate_up_proj = True def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext): @@ -167,12 +161,12 @@ class WeightPrefetchMethod: if self.mlp_pre_version_compatibale_config: weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0) else: - weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0) + weight_size = ( + weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0) + ) if weight_size > MAX_PREFETCH_WEIGHT_SIZE: weight_size = MAX_PREFETCH_WEIGHT_SIZE - torch.ops.vllm.prefetch_preprocess(weight=weight, - start_flag=x_dependency, - max_weight_size=int(weight_size)) + torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size)) forward_context.prefetch_mlp_down_proj = True forward_context.layer_idx += 1 @@ -185,19 +179,15 @@ class WeightPrefetchMethod: except AssertionError: return - if forward_context.prefetch_mlp_gate_up_proj or \ - forward_context.prefetch_mlp_down_proj: + if forward_context.prefetch_mlp_gate_up_proj or forward_context.prefetch_mlp_down_proj: torch.ops.vllm.prefetch_postprocess(stop_flag) forward_context.prefetch_mlp_gate_up_proj = False forward_context.prefetch_mlp_down_proj = False -def maybe_npu_prefetch(inputs: torch.Tensor, - dependency: torch.Tensor, - max_size: int = 0, - offset: int = 0, - *, - enabled: bool = True) -> None: +def maybe_npu_prefetch( + inputs: torch.Tensor, dependency: torch.Tensor, max_size: int = 0, offset: int = 0, *, enabled: bool = True +) -> None: if not enabled: return input_size = inputs.element_size() * inputs.numel() diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index 89f558db..6a1a66c9 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -30,10 +30,9 @@ def get_spec_decode_method(method, vllm_config, device, runner): return EagleProposer(vllm_config, device, runner) elif method == "mtp": return MtpProposer(vllm_config, device, runner) - elif method == 'suffix': + elif method == "suffix": return SuffixDecodingProposer(vllm_config, device, runner) elif method == "medusa": return MedusaProposer(vllm_config, device, runner) else: - raise ValueError("Unknown speculative decoding method: " - f"{method}") + raise ValueError(f"Unknown speculative decoding method: {method}") diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 986d7a71..022e2253 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -1,18 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from contextlib import contextmanager, nullcontext -from typing import Any, Callable, ContextManager, Optional, Union +from collections.abc import Callable +from contextlib import AbstractContextManager, contextmanager, 
nullcontext +from typing import Any import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config) -from vllm.distributed.parallel_state import (get_pp_group, get_tp_group, - get_world_group, - init_model_parallel_group, - patch_tensor_parallel_group) +from vllm.config import CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config +from vllm.distributed.parallel_state import ( + get_pp_group, + get_tp_group, + get_world_group, + init_model_parallel_group, + patch_tensor_parallel_group, +) from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -35,13 +38,11 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, - update_full_graph_params) +from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params from vllm_ascend.ops.rotary_embedding import update_cos_sin -from vllm_ascend.ops.triton.spec_decode.utils import \ - prepare_inputs_padded_kernel +from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num -from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled, lmhead_tp_enable, vllm_version_is +from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is # Currently we will fix block size to a small one since `num_reqs` can't be too large _PREPARE_INPUTS_BLOCK_SIZE = 4 @@ -76,28 +77,21 @@ def split_inputs_tp_to_sp(hidden_states, out): # copy only hidden_states in current rank hidden_states_curr_rank = hidden_states[start:end] - out[:hidden_states_curr_rank.shape[0]] = hidden_states_curr_rank + out[: hidden_states_curr_rank.shape[0]] = hidden_states_curr_rank return out[:padded_num_tokens_per_rank] class EagleProposer(VllmEagleProposer): + _runnable: ACLGraphWrapper | Callable - _runnable: Union[ACLGraphWrapper, Callable] - - def __init__(self, - vllm_config: VllmConfig, - device: torch.device, - runner=None): + def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): super().__init__(vllm_config, device, runner) self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling self.decode_threshold = 1 + self.num_speculative_tokens - self.query_start_loc = self.runner._make_buffer( - self.runner.max_num_reqs + 1, dtype=torch.int32) - self.arange_cpu = torch.arange(self.arange.shape[0], - device="cpu", - dtype=torch.int32) + self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 1, dtype=torch.int32) + self.arange_cpu = torch.arange(self.arange.shape[0], device="cpu", dtype=torch.int32) self.attn_mask_builder = AttentionMaskBuilder(self.device) self.enable_shared_expert_dp = shared_expert_dp_enabled() @@ -108,11 +102,11 @@ class EagleProposer(VllmEagleProposer): self.dcp_rank = self.runner.dcp_rank self.full_indices = range( - self.runner.max_num_tokens * self.pcp_size * self.dcp_size + - self.pcp_size * self.dcp_size * self.runner.max_num_reqs) + self.runner.max_num_tokens * self.pcp_size * self.dcp_size + + self.pcp_size * self.dcp_size 
* self.runner.max_num_reqs + ) - self.use_sparse = hasattr(vllm_config.model_config.hf_text_config, - "index_topk") + self.use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk") # NOTE: # `draft_tensor_parallel_size` does not take effect for Eagle: # the draft model uses the same TP size as the target model in practice. @@ -122,8 +116,7 @@ class EagleProposer(VllmEagleProposer): # or the same as target model. # TODO(zhaomingyu13): If we want to adapt to the case where draft model tp # is not 1 and differs from target model, this part should be rewritten. - if (vllm_config.parallel_config.tensor_parallel_size - != self.speculative_config.draft_tensor_parallel_size): + if vllm_config.parallel_config.tensor_parallel_size != self.speculative_config.draft_tensor_parallel_size: tp_group = init_model_parallel_group( [[get_world_group().rank]], get_world_group().rank, @@ -135,47 +128,37 @@ class EagleProposer(VllmEagleProposer): else: self.tp_group_context = nullcontext() - self.use_cuda_graph = (self.runner._use_aclgraph() - and not self.speculative_config.enforce_eager) + self.use_cuda_graph = self.runner._use_aclgraph() and not self.speculative_config.enforce_eager if self.method == "mtp": self.use_cuda_graph = self.use_cuda_graph and not self.use_async_scheduling # TODO: Remove it when the bug of fx-graph is solved - self.maybe_eager_context: ContextManager[Any] = nullcontext() + self.maybe_eager_context: AbstractContextManager[Any] = nullcontext() if not self.use_cuda_graph and enable_sp(vllm_config): self.maybe_eager_context = _maybe_eager_context(vllm_config) self.last_token_indices = torch.zeros( - self.vllm_config.scheduler_config.max_num_batched_tokens, - dtype=torch.int32, - device=device) - slot_mapping_lens = self.runner.max_num_tokens + \ - 2 * self.pcp_size * self.runner.max_num_reqs + self.vllm_config.scheduler_config.max_num_batched_tokens, dtype=torch.int32, device=device + ) + slot_mapping_lens = self.runner.max_num_tokens + 2 * self.pcp_size * self.runner.max_num_reqs self.slot_mapping_group = [ - torch.zeros( - slot_mapping_lens, dtype=torch.int32, device=device, - pin_memory=self.runner.pin_memory) - for _ in range(self.num_speculative_tokens)] + torch.zeros(slot_mapping_lens, dtype=torch.int32, device=device, pin_memory=self.runner.pin_memory) + for _ in range(self.num_speculative_tokens) + ] self._runnable = self._run_merged_draft def load_model(self, model: nn.Module) -> None: - target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase).keys()) - target_indexer_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, - DeepseekV32IndexerCache).keys()) + target_attn_layer_names = set(get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()) + target_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys()) with self.maybe_eager_context: - self.model = get_model(vllm_config=self.vllm_config, - model_config=self.vllm_config. 
- speculative_config.draft_model_config) + self.model = get_model( + vllm_config=self.vllm_config, model_config=self.vllm_config.speculative_config.draft_model_config + ) - indexer_layers = get_layers_from_vllm_config( - self.vllm_config, DeepseekV32IndexerCache).keys() - draft_attn_layer = get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase).keys() + indexer_layers = get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys() + draft_attn_layer = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() draft_attn_layer_names = draft_attn_layer - target_attn_layer_names draft_indexer_layer_names = indexer_layers - target_indexer_layer_names @@ -184,30 +167,25 @@ class EagleProposer(VllmEagleProposer): self.attn_layer_names = list(sorted(draft_attn_layer_names)) self.piece_all_attn_layer_name = [] for _ in range(self.num_speculative_tokens): - self.piece_all_attn_layer_name.append([ - name for name in self.attn_layer_names]) + self.piece_all_attn_layer_name.append([name for name in self.attn_layer_names]) self.attn_layer_names = list(sorted(draft_attn_layer_names)) self.piece_all_attn_layer_name = [] for _ in range(self.num_speculative_tokens): - self.piece_all_attn_layer_name.append([ - name for name in self.attn_layer_names]) + self.piece_all_attn_layer_name.append([name for name in self.attn_layer_names]) if supports_multimodal(model): # handle multimodality if self.get_model_name(model) in [ - "Qwen2_5_VLForConditionalGeneration", - "Qwen3VLForConditionalGeneration", - "Qwen3VLMoeForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", + "Qwen3VLForConditionalGeneration", + "Qwen3VLMoeForConditionalGeneration", ]: self.model.config.image_token_index = model.config.image_token_id - elif self.get_model_name( - model) == "PixtralForConditionalGeneration": - self.model.config.image_token_index = ( - model.config.vision_config.image_token_id) + elif self.get_model_name(model) == "PixtralForConditionalGeneration": + self.model.config.image_token_index = model.config.vision_config.image_token_id else: - self.model.config.image_token_index = ( - model.config.image_token_index) + self.model.config.image_token_index = model.config.image_token_index target_language_model = model.get_language_model() else: target_language_model = model @@ -219,9 +197,7 @@ class EagleProposer(VllmEagleProposer): elif hasattr(target_language_model.model, "embedding"): target_embed_tokens = target_language_model.model.embedding else: - raise AttributeError( - "Target model does not have 'embed_tokens' or 'embedding' attribute" - ) + raise AttributeError("Target model does not have 'embed_tokens' or 'embedding' attribute") # If pp>1, the weights of mtp and the main model's embedding are not on the same device. # check if mtp model use main model's embedding and LMhead share_embeddings = False @@ -258,10 +234,7 @@ class EagleProposer(VllmEagleProposer): else: # MTP model share_embeddings = True - logger.info( - "Detected MTP model. " - "Sharing target model embedding weights with the draft model." - ) + logger.info("Detected MTP model. 
Sharing target model embedding weights with the draft model.") if share_embeddings: if hasattr(self.model.model, "embed_tokens"): @@ -269,7 +242,7 @@ class EagleProposer(VllmEagleProposer): self.model.model.embed_tokens = target_embed_tokens else: logger.info( - "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \ + "Since PP > 1 or other reasons the model head loaded its own vocab embedding" " weights instead of sharing them with the target model." ) # share lm_head with the target model if needed @@ -282,24 +255,19 @@ class EagleProposer(VllmEagleProposer): else: self.model.lm_head = model.lm_head - if self.method == "mtp" and \ - self.vllm_config.model_config.is_deepseek_mla: + if self.method == "mtp" and self.vllm_config.model_config.is_deepseek_mla: for _, layer_module in self.model.model.layers.items(): - if torch.equal(layer_module.shared_head.head.weight, - model.lm_head.weight): + if torch.equal(layer_module.shared_head.head.weight, model.lm_head.weight): layer_module.shared_head.head = model.lm_head - if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( - ) and self.use_cuda_graph: + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.use_cuda_graph: self.update_stream = torch.npu.Stream() if self.method == "mtp": - self.model = ACLGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) else: - self._runnable = ACLGraphWrapper(self._run_merged_draft, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + self._runnable = ACLGraphWrapper( + self._run_merged_draft, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) def get_model(self) -> nn.Module: # get raw model out of the aclgraph wrapper. @@ -313,22 +281,23 @@ class EagleProposer(VllmEagleProposer): return copy.copy(attn_metadata) @torch.inference_mode() - def dummy_run(self, - num_tokens: int, - with_prefill: bool = False, - in_graph_capturing: bool = False, - num_reqs: int = 0, - num_tokens_across_dp: Optional[torch.Tensor] = None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - dummy_compute_logits=lambda hidden_states: None, - is_profile=False): + def dummy_run( + self, + num_tokens: int, + with_prefill: bool = False, + in_graph_capturing: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: torch.Tensor | None = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None, + is_profile=False, + ): ( num_tokens, num_tokens_across_dp, _, - ) = self.runner._sync_metadata_across_dp(num_tokens, - is_draft_model=True) + ) = self.runner._sync_metadata_across_dp(num_tokens, is_draft_model=True) # update global cos, sin update_cos_sin(self._get_positions(num_tokens)) @@ -336,19 +305,15 @@ class EagleProposer(VllmEagleProposer): multi_steps_attn_metadata = [] if not self.use_cuda_graph: aclgraph_runtime_mode = CUDAGraphMode.NONE - if aclgraph_runtime_mode == CUDAGraphMode.FULL and len( - self.runner.attn_groups) > 0: - num_computed_tokens_cpu = ( - self.runner.input_batch. 
- num_computed_tokens_cpu_tensor[:num_reqs]) - self.query_start_loc.cpu[:num_reqs + 1] = torch.tensor( - [0] + self.runner.actual_seq_lengths_q[:num_reqs], - device="cpu", - dtype=torch.int32) + if aclgraph_runtime_mode == CUDAGraphMode.FULL and len(self.runner.attn_groups) > 0: + num_computed_tokens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] + self.query_start_loc.cpu[: num_reqs + 1] = torch.tensor( + [0] + self.runner.actual_seq_lengths_q[:num_reqs], device="cpu", dtype=torch.int32 + ) self.query_start_loc.copy_to_gpu() common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + 1], + query_start_loc=self.query_start_loc.gpu[: num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1], seq_lens_cpu=self.runner.seq_lens.cpu, seq_lens=self.runner.seq_lens.gpu[:num_reqs], num_reqs=num_reqs, @@ -357,11 +322,9 @@ class EagleProposer(VllmEagleProposer): max_query_len=self.num_speculative_tokens + 1, num_computed_tokens_cpu=num_computed_tokens_cpu, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0]. - get_device_tensor()[:num_reqs], + block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor()[:num_reqs], # This is used to hold a position. - slot_mapping=self.runner.input_batch.block_table[0]. - slot_mapping.gpu, + slot_mapping=self.runner.input_batch.block_table[0].slot_mapping.gpu, positions=self.runner.positions.gpu, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, @@ -371,35 +334,33 @@ class EagleProposer(VllmEagleProposer): builder = self.runner.attn_groups[0][0].get_metadata_builder() # update the tensor's address for each step. for draft_step in range(self.num_speculative_tokens): - common_attn_metadata = self.shallow_copy_metadata( - common_attn_metadata) + common_attn_metadata = self.shallow_copy_metadata(common_attn_metadata) # Set the real slot_mapping. 
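# A minimal sketch of the per-step buffering pattern above, using a plain dict
# in place of AscendCommonAttentionMetadata and made-up sizes (only the shape
# of the idea is taken from the code): each draft step takes a shallow copy of
# the shared metadata and points it at its own pre-allocated slot_mapping
# buffer, so a captured full graph always replays against stable tensor
# addresses.
import copy
import torch

_num_steps, _max_slots = 3, 16
_slot_buffers = [torch.full((_max_slots,), -1, dtype=torch.int32) for _ in range(_num_steps)]
_base_meta = {"num_reqs": 4, "slot_mapping": None}
_per_step_meta = []
for _step in range(_num_steps):
    _meta = copy.copy(_base_meta)                 # cheap shallow copy per draft step
    _meta["slot_mapping"] = _slot_buffers[_step]  # distinct, reusable buffer address
    _per_step_meta.append(_meta)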
common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step] attn_metadata_eagle = builder.build_for_graph_capture( - common_attn_metadata, AscendAttentionState.ChunkedPrefill) + common_attn_metadata, AscendAttentionState.ChunkedPrefill + ) per_layer_attn_metadata = dict() for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata_eagle multi_steps_attn_metadata.append(per_layer_attn_metadata) - model_input_ids = self.input_ids[:num_tokens] model_positions = self._get_positions(num_tokens) - model_previous_hidden_states = self.hidden_states[:num_tokens] batch_size = num_tokens // (self.num_speculative_tokens + 1) with set_ascend_forward_context( - multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None, - self.vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, - num_actual_tokens=0, - in_profile_run=is_profile, - batch_descriptor=batch_descriptor, - aclgraph_runtime_mode=aclgraph_runtime_mode, - is_draft_model=True, - draft_attn_metadatas=multi_steps_attn_metadata): - + multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + num_actual_tokens=0, + in_profile_run=is_profile, + batch_descriptor=batch_descriptor, + aclgraph_runtime_mode=aclgraph_runtime_mode, + is_draft_model=True, + draft_attn_metadatas=multi_steps_attn_metadata, + ): if not vllm_version_is("v0.15.0"): # Reset MOE layer index before first model call forward_context = get_forward_context() @@ -417,11 +378,8 @@ class EagleProposer(VllmEagleProposer): is_dummy=True, ) forward_context = get_forward_context() - if (forward_context.cudagraph_runtime_mode - == CUDAGraphMode.FULL - and not forward_context.capturing): - self._update_full_graph_params(forward_context, num_tokens, - multi_steps_attn_metadata) + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing: + self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata) def _propose( self, @@ -433,11 +391,10 @@ class EagleProposer(VllmEagleProposer): target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - last_token_indices: Optional[torch.Tensor], + last_token_indices: torch.Tensor | None, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embed_inputs: Optional[tuple[list[torch.Tensor], - torch.Tensor]] = None, + mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, req_scheduled_tokens=None, long_seq_metadata=None, num_prefill_reqs=0, @@ -445,7 +402,6 @@ class EagleProposer(VllmEagleProposer): scheduler_output: SchedulerOutput = None, num_scheduled_tokens: int = 0, ) -> torch.Tensor: - num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -454,20 +410,17 @@ class EagleProposer(VllmEagleProposer): if self.method == "eagle3": assert isinstance(self.get_model(), Eagle3LlamaForCausalLM) - target_hidden_states = self.model.combine_hidden_states( - target_hidden_states) + target_hidden_states = self.model.combine_hidden_states(target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] + self.input_ids[: num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. 
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids - if self.use_cuda_graph and \ - num_tokens <= self.runner.cudagraph_batch_sizes[-1]: - num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[ - num_tokens] + if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]: + num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens] else: num_input_tokens = num_tokens @@ -475,13 +428,13 @@ class EagleProposer(VllmEagleProposer): num_input_tokens, num_tokens_across_dp, _, - ) = self.runner._sync_metadata_across_dp(num_input_tokens, - is_draft_model=True) + ) = self.runner._sync_metadata_across_dp(num_input_tokens, is_draft_model=True) has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0 if self.use_cuda_graph: - aclgraph_runtime_mode, batch_descriptor = \ - self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=True, has_lora=has_lora) + aclgraph_runtime_mode, batch_descriptor = self.runner.cudagraph_dispatcher.dispatch( + num_tokens=num_input_tokens, uniform_decode=True, has_lora=has_lora + ) else: aclgraph_runtime_mode = CUDAGraphMode.NONE batch_descriptor = None @@ -493,33 +446,30 @@ class EagleProposer(VllmEagleProposer): if self.supports_mm_inputs: mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) inputs_embeds = self.model.embed_input_ids( - self.input_ids[:num_tokens], - multimodal_embeddings=mm_embeds, - is_multimodal=is_mm_embed) + self.input_ids[:num_tokens], multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed + ) self.inputs_embeds[:num_tokens] = inputs_embeds inputs_embeds = self.inputs_embeds[:num_input_tokens] - input_ids = self.input_ids[:num_input_tokens] else: inputs_embeds = None - input_ids = self.input_ids[:num_input_tokens] # Update slot_mapping for different speculative. # NOTE: Currently, we only remake the slot_mapping, because it's the # only tensor which will be used in current FIA. # Strictly speaking, `query_start_loc`, `seq_lens` should also have # their memory allocated separately for each step just like `slot_mapping`. - slot_mapping_lens = num_input_tokens if num_input_tokens < \ - common_attn_metadata.slot_mapping.shape[0] else \ - common_attn_metadata.slot_mapping.shape[0] - self.slot_mapping_group[0][:slot_mapping_lens].copy_( - common_attn_metadata.slot_mapping[:slot_mapping_lens]) + slot_mapping_lens = ( + num_input_tokens + if num_input_tokens < common_attn_metadata.slot_mapping.shape[0] + else common_attn_metadata.slot_mapping.shape[0] + ) + self.slot_mapping_group[0][:slot_mapping_lens].copy_(common_attn_metadata.slot_mapping[:slot_mapping_lens]) self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1) common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens] common_attn_metadata.num_input_tokens = num_input_tokens # FIXME(woosuk): The below two ops cause synchronization. Optimize. 
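# A small worked version of the shift-and-replace trick used at the start of
# _propose, with toy token ids on a plain CPU tensor (nothing here is the real
# runner state): shifting left by one aligns each position with the token it
# should predict, and the last position of every request is overwritten with
# the token just sampled from the target model.
import torch

_target = torch.tensor([11, 21, 22, 31, 32, 33])   # [a1, b1, b2, c1, c2, c3]
_last_idx = torch.tensor([0, 2, 5])                # last position of each request
_next_ids = torch.tensor([12, 23, 34])             # [a2, b3, c4] sampled by the target
_draft_in = torch.empty_like(_target)
_draft_in[:-1] = _target[1:]                       # [b1, b2, c1, c2, c3, ?]
_draft_in[-1] = _target[-1]                        # [b1, b2, c1, c2, c3, c3]
_draft_in[_last_idx] = _next_ids                   # [a2, b2, b3, c2, c3, c4]
assert _draft_in.tolist() == [12, 22, 23, 32, 33, 34]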
builder = self.runner.attn_groups[0][0].get_metadata_builder() - attn_metadata = builder.build(0, common_attn_metadata, - self.runner.get_model()) + attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model()) # update global cos, sin update_cos_sin(self._get_positions(num_input_tokens)) @@ -536,35 +486,34 @@ class EagleProposer(VllmEagleProposer): # Copy the old attn_metadata and update for draft_step in range(1, self.num_speculative_tokens): - common_attn_metadata, attn_metadata = \ - self.attn_update_stack_num_spec_norm( - draft_step, - attn_metadata, - common_attn_metadata, - batch_size, - num_input_tokens, - used_update_positions, - aclgraph_runtime_mode) + common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm( + draft_step, + attn_metadata, + common_attn_metadata, + batch_size, + num_input_tokens, + used_update_positions, + aclgraph_runtime_mode, + ) per_layer_attn_metadata = dict() for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata multi_steps_attn_metadata.append(per_layer_attn_metadata) last_token_indices_len = last_token_indices.shape[0] - self.last_token_indices[:last_token_indices_len].copy_( - last_token_indices) + self.last_token_indices[:last_token_indices_len].copy_(last_token_indices) with set_ascend_forward_context( - multi_steps_attn_metadata[0], - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - num_actual_tokens=num_tokens, - batch_descriptor=batch_descriptor, - aclgraph_runtime_mode=aclgraph_runtime_mode, - is_draft_model=True, - draft_attn_metadatas=multi_steps_attn_metadata): - + multi_steps_attn_metadata[0], + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + num_actual_tokens=num_tokens, + batch_descriptor=batch_descriptor, + aclgraph_runtime_mode=aclgraph_runtime_mode, + is_draft_model=True, + draft_attn_metadatas=multi_steps_attn_metadata, + ): if not vllm_version_is("v0.15.0"): # Reset MOE layer index for forward pass forward_context = get_forward_context() @@ -577,36 +526,37 @@ class EagleProposer(VllmEagleProposer): last_token_indices=self.last_token_indices[:last_token_indices_len], target_positions=target_positions, inputs_embeds=inputs_embeds, - multi_steps_attn_metadata=multi_steps_attn_metadata) + multi_steps_attn_metadata=multi_steps_attn_metadata, + ) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: - self._update_full_graph_params(forward_context, - num_input_tokens, - multi_steps_attn_metadata) + self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata) return draft_token_ids - def _run_merged_draft(self, - num_input_tokens, - batch_size, - last_token_indices, - target_positions, - inputs_embeds, - multi_steps_attn_metadata, - is_dummy=False, + def _run_merged_draft( + self, + num_input_tokens, + batch_size, + last_token_indices, + target_positions, + inputs_embeds, + multi_steps_attn_metadata, + is_dummy=False, ) -> torch.Tensor: - # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings. - # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model. + # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative + # tokens' proposings. 
+ # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the + # inputs of speculative model. model_input_ids = self.input_ids[:num_input_tokens] model_positions = self._get_positions(num_input_tokens) model_hidden_states = self.hidden_states[:num_input_tokens] - model_hidden_states, model_positions = self.maybe_pad_and_reduce( - model_hidden_states, model_positions) + model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions) # Expend the remaining moe layers for suiting vllm. forward_context = get_forward_context() - if forward_context and hasattr(forward_context, 'remaining_moe_layers'): + if forward_context and hasattr(forward_context, "remaining_moe_layers"): if self.num_speculative_tokens > 1: moe_layers_needed = len(forward_context.remaining_moe_layers) * self.num_speculative_tokens if len(forward_context.remaining_moe_layers) < moe_layers_needed: @@ -619,7 +569,7 @@ class EagleProposer(VllmEagleProposer): input_ids=model_input_ids, positions=model_positions, hidden_states=model_hidden_states, - inputs_embeds = inputs_embeds, + inputs_embeds=inputs_embeds, ) if self.method == "mtp": last_hidden_states = ret_hidden_states @@ -628,15 +578,15 @@ class EagleProposer(VllmEagleProposer): last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad( - last_hidden_states, model_positions, hidden_states) + last_hidden_states, model_positions, hidden_states + ) num_indices = last_token_indices.shape[0] if lmhead_tp_enable() and not is_dummy: max_num_reqs_across_dp = ( - self.vllm_config.scheduler_config.max_num_seqs * - self.runner.uniform_decode_query_len) - last_token_indices = nn.functional.pad( - last_token_indices, (0, max_num_reqs_across_dp - num_indices)) + self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len + ) + last_token_indices = nn.functional.pad(last_token_indices, (0, max_num_reqs_across_dp - num_indices)) sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) @@ -654,9 +604,8 @@ class EagleProposer(VllmEagleProposer): # Generate the remaining draft tokens. draft_token_ids_tensor = torch.zeros( - (self.num_speculative_tokens, *draft_token_ids.shape), - dtype=draft_token_ids.dtype, - device=self.device) + (self.num_speculative_tokens, *draft_token_ids.shape), dtype=draft_token_ids.dtype, device=self.device + ) draft_token_ids_tensor[0] = draft_token_ids if self.uses_mrope: positions = target_positions[:, last_token_indices] @@ -677,7 +626,7 @@ class EagleProposer(VllmEagleProposer): forward_context = get_forward_context() if forward_context is not None: forward_context.moe_layer_index = 0 - + # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. @@ -691,25 +640,22 @@ class EagleProposer(VllmEagleProposer): # out-of-range access during the model execution. The draft tokens # generated with this adjustment should be ignored. if self.uses_mrope: - exceeds_max_model_len = positions[ - 0] >= self.vllm_config.model_config.max_model_len + exceeds_max_model_len = positions[0] >= self.vllm_config.model_config.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. 
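# A toy version of the position-clamping step described above (max_model_len
# and the position values are invented): positions that have run past the
# model's maximum length are replaced with 0 so RoPE indexing stays in range;
# the drafts produced for those slots are ignored later anyway.
import torch

_max_model_len = 8
_positions = torch.tensor([3, 7, 8, 9])
_exceeds = _positions >= _max_model_len            # [False, False, True, True]
_clamped = torch.where(_exceeds, torch.zeros_like(_positions), _positions)
assert _clamped.tolist() == [3, 7, 0, 0]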
clamped_positions = torch.where( - exceeds_max_model_len.unsqueeze(0), - torch.zeros_like(positions), positions) + exceeds_max_model_len.unsqueeze(0), torch.zeros_like(positions), positions + ) else: exceeds_max_model_len = positions >= self.vllm_config.model_config.max_model_len - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + clamped_positions = torch.where(exceeds_max_model_len, 0, positions) # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states if self.supports_mm_inputs: - self.inputs_embeds[:batch_size] = self.model.embed_input_ids( - input_ids) + self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids) input_ids = self.input_ids[:input_batch_size] inputs_embeds = self.inputs_embeds[:input_batch_size] @@ -722,22 +668,24 @@ class EagleProposer(VllmEagleProposer): # Run the model. - # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings. - # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model. + # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative + # tokens' proposings. + # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the + # inputs of speculative model. model_input_ids = self.input_ids[:input_batch_size] model_positions = self._get_positions(input_batch_size) model_hidden_states = self.hidden_states[:input_batch_size] - model_hidden_states, model_positions = self.maybe_pad_and_reduce( - model_hidden_states, model_positions) + model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions) - forward_context.attn_metadata = multi_steps_attn_metadata[draft_step + 1] \ - if multi_steps_attn_metadata else None + forward_context.attn_metadata = ( + multi_steps_attn_metadata[draft_step + 1] if multi_steps_attn_metadata else None + ) ret_hidden_states = self.model( input_ids=model_input_ids, positions=model_positions, hidden_states=model_hidden_states, - inputs_embeds = inputs_embeds, + inputs_embeds=inputs_embeds, ) if self.method == "mtp": last_hidden_states = ret_hidden_states @@ -746,13 +694,14 @@ class EagleProposer(VllmEagleProposer): last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad( - last_hidden_states, model_positions, hidden_states) + last_hidden_states, model_positions, hidden_states + ) num_indices = last_token_indices.shape[0] if lmhead_tp_enable() and not is_dummy: max_num_reqs_across_dp = ( - self.vllm_config.scheduler_config.max_num_seqs * - self.runner.uniform_decode_query_len) + self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len + ) last_token_indices = nn.functional.pad( last_token_indices, (0, max_num_reqs_across_dp - num_indices), @@ -774,41 +723,40 @@ class EagleProposer(VllmEagleProposer): draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) return draft_token_ids - def attn_update_stack_num_spec_norm(self, - # `draft_step` must start from `1`, no `0` - draft_step, - old_attn_metadata, - old_common_metadata, - batch_size, - input_batch_size, - used_update_positions, - aclgraph_runtime_mode): - - assert(draft_step > 0) + def attn_update_stack_num_spec_norm( + self, + # `draft_step` must start from `1`, no `0` + draft_step, + old_attn_metadata, + 
old_common_metadata, + batch_size, + input_batch_size, + used_update_positions, + aclgraph_runtime_mode, + ): + assert draft_step > 0 common_attn_metadata = self.shallow_copy_metadata(old_common_metadata) if draft_step == 1: - if ( - aclgraph_runtime_mode == CUDAGraphMode.FULL - and (pad_size := input_batch_size - batch_size) > 0 - ): + if aclgraph_runtime_mode == CUDAGraphMode.FULL and (pad_size := input_batch_size - batch_size) > 0: common_attn_metadata.num_reqs = input_batch_size common_attn_metadata.block_table_tensor = self._pad_tensor( - common_attn_metadata.block_table_tensor, pad_size) - common_attn_metadata.seq_lens = self._pad_tensor( - common_attn_metadata.seq_lens, pad_size) - common_attn_metadata.seq_lens_cpu = self._pad_tensor( - common_attn_metadata.seq_lens_cpu, pad_size) + common_attn_metadata.block_table_tensor, pad_size + ) + common_attn_metadata.seq_lens = self._pad_tensor(common_attn_metadata.seq_lens, pad_size) + common_attn_metadata.seq_lens_cpu = self._pad_tensor(common_attn_metadata.seq_lens_cpu, pad_size) common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor( - common_attn_metadata.num_computed_tokens_cpu, pad_size) - common_attn_metadata.query_start_loc = self.arange[ - :input_batch_size + 1] + common_attn_metadata.num_computed_tokens_cpu, pad_size + ) + common_attn_metadata.query_start_loc = self.arange[: input_batch_size + 1] common_attn_metadata.query_start_loc_cpu = torch.from_numpy( - self.token_arange_np[:input_batch_size + 1]).clone() + self.token_arange_np[: input_batch_size + 1] + ).clone() else: - common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] + common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] common_attn_metadata.query_start_loc_cpu = torch.from_numpy( - self.token_arange_np[:batch_size + 1]).clone() + self.token_arange_np[: batch_size + 1] + ).clone() common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 @@ -816,7 +764,7 @@ class EagleProposer(VllmEagleProposer): common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill common_attn_metadata.graph_pad_size = -1 common_attn_metadata.num_input_tokens = input_batch_size - + # The loop part used_update_positions += 1 @@ -828,18 +776,15 @@ class EagleProposer(VllmEagleProposer): # out-of-range access during the model execution. The draft tokens # generated with this adjustment should be ignored. if self.uses_mrope: - exceeds_max_model_len = used_update_positions[ - 0] >= self.vllm_config.model_config.max_model_len + exceeds_max_model_len = used_update_positions[0] >= self.vllm_config.model_config.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. 
clamped_positions = torch.where( - exceeds_max_model_len.unsqueeze(0), - torch.zeros_like(used_update_positions), used_update_positions) + exceeds_max_model_len.unsqueeze(0), torch.zeros_like(used_update_positions), used_update_positions + ) else: - exceeds_max_model_len = used_update_positions >= \ - self.vllm_config.model_config.max_model_len - clamped_positions = torch.where(exceeds_max_model_len, 0, - used_update_positions) + exceeds_max_model_len = used_update_positions >= self.vllm_config.model_config.max_model_len + clamped_positions = torch.where(exceeds_max_model_len, 0, used_update_positions) # For data integrity when async scheduling, we shouldn't use in place # operations in case they are modified in next step's `prepare_input` @@ -848,15 +793,11 @@ class EagleProposer(VllmEagleProposer): common_attn_metadata.seq_lens[:batch_size] += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - common_attn_metadata.seq_lens[:batch_size].masked_fill_( - exceeds_max_model_len, 1) + common_attn_metadata.seq_lens[:batch_size].masked_fill_(exceeds_max_model_len, 1) - common_attn_metadata.seq_lens_cpu[:batch_size] = ( - common_attn_metadata.seq_lens_cpu[:batch_size] + 1) - exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \ - self.max_model_len - common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_( - exceeds_mask, 1) + common_attn_metadata.seq_lens_cpu[:batch_size] = common_attn_metadata.seq_lens_cpu[:batch_size] + 1 + exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= self.max_model_len + common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(exceeds_mask, 1) common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1 if self.uses_mrope: common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0]) @@ -873,29 +814,22 @@ class EagleProposer(VllmEagleProposer): if self.uses_mrope: block_numbers = clamped_positions[0] // block_size else: - block_numbers = (clamped_positions // block_size) - block_ids = old_common_metadata.block_table_tensor.gather( - dim=1, index=block_numbers.view(-1, 1)) + block_numbers = clamped_positions // block_size + block_ids = old_common_metadata.block_table_tensor.gather(dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) if self.uses_mrope: - slot_mapping = (block_ids * block_size + - clamped_positions[0] % block_size) + slot_mapping = block_ids * block_size + clamped_positions[0] % block_size else: - slot_mapping = (block_ids * block_size + - clamped_positions % block_size) + slot_mapping = block_ids * block_size + clamped_positions % block_size # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
- slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) - self.slot_mapping_group[draft_step][:slot_mapping.shape[0]].copy_( - slot_mapping.to(torch.int32)) - self.slot_mapping_group[draft_step][slot_mapping.shape[0]:].fill_( - PADDING_SLOT_ID) + slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32)) + self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID) # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx] - common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][ - :slot_mapping.shape[0]] + common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]] # Rebuild attention metadata attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore @@ -927,24 +861,22 @@ class EagleProposer(VllmEagleProposer): # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs - self.backup_next_token_ids.np[:num_reqs] = np.array([ - requests[gpu_input_batch.req_ids[i]].get_token_id( - common_attn_metadata.seq_lens_cpu[i].item()) - for i in range(num_reqs) - ]) + self.backup_next_token_ids.np[:num_reqs] = np.array( + [ + requests[gpu_input_batch.req_ids[i]].get_token_id(common_attn_metadata.seq_lens_cpu[i].item()) + for i in range(num_reqs) + ] + ) self.backup_next_token_ids.copy_to_gpu(num_reqs) # Mask out the sampled tokens indices that should not be sampled. - discard_sampled_tokens_req_indices = discard_request_indices[: - num_discarded_requests] + discard_sampled_tokens_req_indices = discard_request_indices[:num_discarded_requests] valid_sampled_token_ids_gpu = sampled_token_ids.clone() - valid_sampled_token_ids_gpu.index_fill_( - 0, discard_sampled_tokens_req_indices, -1) + valid_sampled_token_ids_gpu.index_fill_(0, discard_sampled_tokens_req_indices, -1) # Generate a mask for all valid tokens within those requests - valid_mask = (valid_sampled_token_ids_gpu != -1) & ( - valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size) + valid_mask = (valid_sampled_token_ids_gpu != -1) & (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size) # Count the number of valid tokens in each request valid_sampled_tokens_count = valid_mask.sum(dim=1) @@ -955,9 +887,7 @@ class EagleProposer(VllmEagleProposer): # Get last valid token from each row # (assume undefined state where there is no valid token) - selected_tokens = torch.gather( - valid_sampled_token_ids_gpu, 1, - last_valid_indices_safe.unsqueeze(1)).squeeze(1) + selected_tokens = torch.gather(valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)).squeeze(1) # Use last token if valid, pre-computed backup if not batch_size = valid_sampled_token_ids_gpu.shape[0] @@ -999,22 +929,17 @@ class EagleProposer(VllmEagleProposer): num_actual_reqs = len(num_draft_tokens) num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] - num_rejected_tokens = torch.tensor(num_rejected_tokens, - dtype=torch.int32) + num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32) device = common_attn_metadata.query_start_loc.device - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: - num_actual_reqs - + 1] + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_actual_reqs + 1] 
seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs] new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = query_start_loc_cpu[ - 1:] - query_start_loc_cpu[:-1] + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() @@ -1035,43 +960,38 @@ class EagleProposer(VllmEagleProposer): # [0, 2, 6, 9] -> # [0, 0, 2, 2, 2, 2, 6, 6, 6] # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], - new_num_tokens_per_req_np) + new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], new_num_tokens_per_req_np) # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> # [0, 1, 0, 1, 2, 3, 0, 1, 2] # _r1_ ____r2____ ___r3__ - token_offests = (self.token_arange_np[:total_num_tokens] - - new_query_start_locs_expanded) + token_offests = self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded # Expand starting positions to match token pattern # [0, q1, q1 + q2] -> # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] # _r1_ _____r2_______ ___________r3____________ - old_query_start_locs_expanded = np.repeat( - query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + old_query_start_locs_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) # Final token indices are: # [0, 1, // req 1 # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to( - device, non_blocking=True) + token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True) - common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_( - common_attn_metadata.slot_mapping[token_indices]) - common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1) + common_attn_metadata.slot_mapping[: token_indices.shape[0]].copy_( + common_attn_metadata.slot_mapping[token_indices] + ) + common_attn_metadata.slot_mapping[token_indices.shape[0] :].fill_(-1) # NOTE: Currently positions and seq_lens are not used in attn forward # so we do not need to fixed them. But if they are used in the future, # we should fixed them. spec_common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, - non_blocking=True), + query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), query_start_loc_cpu=new_query_start_loc_cpu, seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. 
- num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, num_input_tokens=common_attn_metadata.num_input_tokens, @@ -1082,7 +1002,8 @@ class EagleProposer(VllmEagleProposer): positions=common_attn_metadata.positions[token_indices], attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, - max_seq_len=0) + max_seq_len=0, + ) return spec_common_attn_metadata, token_indices def prepare_inputs_padded( @@ -1103,15 +1024,12 @@ class EagleProposer(VllmEagleProposer): num_reqs = common_attn_metadata.num_reqs device = valid_sampled_tokens_count.device - token_indices_to_sample = torch.empty((num_reqs, ), - dtype=torch.int32, - device=device) + token_indices_to_sample = torch.empty((num_reqs,), dtype=torch.int32, device=device) - num_blocks_needed = triton.cdiv(num_reqs, - _PREPARE_INPUTS_BLOCK_SIZE) + num_blocks_needed = triton.cdiv(num_reqs, _PREPARE_INPUTS_BLOCK_SIZE) num_vector_core = get_vectorcore_num() grid_size = min(num_blocks_needed, num_vector_core) - grid = (grid_size, ) + grid = (grid_size,) prepare_inputs_padded_kernel[grid]( spec_decode_metadata.cu_num_draft_tokens, @@ -1122,11 +1040,12 @@ class EagleProposer(VllmEagleProposer): BLOCK_SIZE=_PREPARE_INPUTS_BLOCK_SIZE, ) else: - num_draft_tokens_gpu = torch.cat([ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1], - ]) + num_draft_tokens_gpu = torch.cat( + [ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - spec_decode_metadata.cu_num_draft_tokens[:-1], + ] + ) num_rejected_tokens_gpu = torch.where( num_draft_tokens_gpu > 0, @@ -1134,14 +1053,11 @@ class EagleProposer(VllmEagleProposer): torch.zeros_like(num_draft_tokens_gpu), ) - token_indices_to_sample = ( - common_attn_metadata.query_start_loc[1:] - 1 - - num_rejected_tokens_gpu) + token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_query_len_per_req = query_start_loc_cpu[ - 1:] - query_start_loc_cpu[:-1] + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] total_num_tokens = query_start_loc_cpu[-1].item() token_indices = self.arange[:total_num_tokens] @@ -1154,8 +1070,7 @@ class EagleProposer(VllmEagleProposer): query_start_loc_cpu=query_start_loc_cpu, seq_lens_cpu=common_attn_metadata.seq_lens_cpu, num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=common_attn_metadata.num_actual_tokens - if self.pcp_size > 1 else total_num_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens if self.pcp_size > 1 else total_num_tokens, num_input_tokens=common_attn_metadata.num_input_tokens, max_query_len=new_query_len_per_req.max().item(), actual_seq_lengths_q=self.runner.actual_seq_lengths_q, @@ -1164,15 +1079,14 @@ class EagleProposer(VllmEagleProposer): positions=common_attn_metadata.positions, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, - num_computed_tokens_cpu=common_attn_metadata. 
- num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, seq_lens=common_attn_metadata.seq_lens, - max_seq_len=0) + max_seq_len=0, + ) return spec_common_attn_metadata, token_indices, token_indices_to_sample - def _split_pcp_input(self, req_scheduled_tokens, input_ids, - target_hidden_states): + def _split_pcp_input(self, req_scheduled_tokens, input_ids, target_hidden_states): """ Split prefill input_ids and target_hidden_states in pcp group. 1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad] @@ -1185,29 +1099,27 @@ class EagleProposer(VllmEagleProposer): # no prefill inputs to split, return empty result return ( 0, - torch.zeros([0], device='npu'), - torch.zeros([0, target_hidden_states.size(1)], device='npu'), + torch.zeros([0], device="npu"), + torch.zeros([0, target_hidden_states.size(1)], device="npu"), 0, torch.zeros([0]), torch.tensor([0], dtype=torch.int32), ) def _pcp_pad_and_split(num_tokens): - num_pcp_padded_scheduled_tokens = cdiv( - num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size + num_pcp_padded_scheduled_tokens = cdiv(num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size) # split position_ids (and use split position_ids to split input_ids afterwards) req_position_cp: list[int] = [] + req_position_cp.extend(self.full_indices[self.pcp_rank * chunk_size : (self.pcp_rank + 1) * chunk_size]) req_position_cp.extend( - self.full_indices[self.pcp_rank * - chunk_size:(self.pcp_rank + 1) * chunk_size]) - req_position_cp.extend( - self.full_indices[num_pcp_padded_scheduled_tokens - - (self.pcp_rank + 1) * - chunk_size:num_pcp_padded_scheduled_tokens - - self.pcp_rank * chunk_size]) + self.full_indices[ + num_pcp_padded_scheduled_tokens - (self.pcp_rank + 1) * chunk_size : num_pcp_padded_scheduled_tokens + - self.pcp_rank * chunk_size + ] + ) return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad @@ -1217,17 +1129,12 @@ class EagleProposer(VllmEagleProposer): pcp_split_input_ids_list = [] pcp_split_hidden_states_list = [] for ori_num_tokens in req_scheduled_tokens.values(): - req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \ - _pcp_pad_and_split(ori_num_tokens) + req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = _pcp_pad_and_split(ori_num_tokens) actual_num_tokens = len(req_position_pcp) num_pcp_scheduled_tokens.append(actual_num_tokens) - pad_input_ids = F.pad( - input_ids[ori_start_index:ori_start_index + ori_num_tokens], - (0, num_pcp_pad)) + pad_input_ids = F.pad(input_ids[ori_start_index : ori_start_index + ori_num_tokens], (0, num_pcp_pad)) ori_start_index += ori_num_tokens - pcp_chunk_indices = [ - pad_start_index + pos for pos in req_position_pcp - ] + pcp_chunk_indices = [pad_start_index + pos for pos in req_position_pcp] pcp_split_input_ids = pad_input_ids[req_position_pcp] pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices] pcp_split_input_ids_list.append(pcp_split_input_ids) @@ -1238,16 +1145,20 @@ class EagleProposer(VllmEagleProposer): target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0) max_query_len = max(num_pcp_scheduled_tokens) seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32) - cu_num_tokens = torch.tensor( - np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0)) + cu_num_tokens = torch.tensor(np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0)) return 
num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens # update full-graph params for one spec token def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None): update_full_graph_params( - self.runner.attn_backend, self.update_stream, forward_context, num_tokens, - self.vllm_config, self.vllm_config.speculative_config, - draft_attn_metadatas=draft_attn_metadatas) + self.runner.attn_backend, + self.update_stream, + forward_context, + num_tokens, + self.vllm_config, + self.vllm_config.speculative_config, + draft_attn_metadatas=draft_attn_metadatas, + ) # padding tensor into desired size def _pad_tensor(self, tensor, pad_size): @@ -1262,16 +1173,14 @@ class EagleProposer(VllmEagleProposer): ) -> tuple[torch.Tensor, torch.Tensor]: if self.method == "mtp": if self.enable_shared_expert_dp: - hidden_states = torch.ops.vllm.maybe_pad_and_reduce( - hidden_states) + hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states) positions = positions.unsqueeze(-1) positions = torch.ops.vllm.maybe_pad_and_reduce(positions) positions = positions.squeeze(-1) else: forward_context = get_forward_context() if forward_context.sp_enabled: - hidden_states = split_inputs_tp_to_sp( - hidden_states, hidden_states) + hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states) return hidden_states, positions def maybe_all_gather_and_unpad( @@ -1283,17 +1192,17 @@ class EagleProposer(VllmEagleProposer): if self.method == "mtp": if self.enable_shared_expert_dp: last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - last_hidden_states.contiguous(), True) - positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - positions.contiguous(), True) + last_hidden_states.contiguous(), True + ) + positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True) if hidden_states is not None: hidden_states = last_hidden_states else: forward_context = get_forward_context() if forward_context.sp_enabled: last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - last_hidden_states.contiguous(), True) + last_hidden_states.contiguous(), True + ) if hidden_states is not None: - hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - hidden_states.contiguous(), True) + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), True) return last_hidden_states, positions, hidden_states diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py index feec5bcf..138f88dc 100644 --- a/vllm_ascend/spec_decode/interface.py +++ b/vllm_ascend/spec_decode/interface.py @@ -1,5 +1,4 @@ import enum -from typing import Optional import torch from vllm.config import CUDAGraphMode, VllmConfig @@ -18,11 +17,7 @@ class SpecDcodeType(enum.Enum): class Proposer: - - def __init__(self, - vllm_config: VllmConfig, - device: torch.device = None, - runner=None): + def __init__(self, vllm_config: VllmConfig, device: torch.device = None, runner=None): pass def load_model(self, model): @@ -30,25 +25,29 @@ class Proposer: raise NotImplementedError @torch.inference_mode() - def dummy_run(self, - num_tokens: int, - with_prefill: bool = False, - in_graph_capturing: bool = False, - num_reqs: int = 0, - num_tokens_across_dp: Optional[torch.Tensor] = None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None): + def dummy_run( + self, + num_tokens: int, + with_prefill: bool = False, + in_graph_capturing: bool = False, + num_reqs: int = 
0, + num_tokens_across_dp: torch.Tensor | None = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + ): """Called by dummy_run in modle_runner""" raise NotImplementedError - def generate_token_ids(self, - valid_sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata = None, - scheduler_output: SchedulerOutput = None, - spec_decode_metadata: SpecDecodeMetadata = None, - positions: torch.Tensor = None, - num_scheduled_tokens: int = 0, - hidden_states: torch.Tensor = None, - aux_hidden_states: torch.Tensor = None): + def generate_token_ids( + self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + aux_hidden_states: torch.Tensor = None, + ): """Called by execute_model in model_runner""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/vllm_ascend/spec_decode/medusa_proposer.py b/vllm_ascend/spec_decode/medusa_proposer.py index eda2e41a..ff727cc8 100644 --- a/vllm_ascend/spec_decode/medusa_proposer.py +++ b/vllm_ascend/spec_decode/medusa_proposer.py @@ -1,14 +1,9 @@ -from typing import Optional - import torch -import torch.nn as nn from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.medusa import MedusaProposer as VllmMedusaProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.spec_decode.interface import SpecDcodeType @@ -22,72 +17,70 @@ class MedusaProposer(VllmMedusaProposer): """ def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - runner, + self, + vllm_config: VllmConfig, + device: torch.device, + runner, ): # Save config parameters self.name = SpecDcodeType.MEDUSA self.vllm_config = vllm_config self.device = device self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens - self.hidden_size = (vllm_config.speculative_config.draft_model_config. 
- get_hidden_size()) + self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size() self.dtype = vllm_config.model_config.dtype self.runner = runner @torch.inference_mode() - def dummy_run(self, - num_tokens: int, - with_prefill: bool = False, - in_graph_capturing: bool = False, - num_reqs: int = 0, - num_tokens_across_dp: Optional[torch.Tensor] = None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - dummy_compute_logits=lambda hidden_states: None, - is_profile=False): + def dummy_run( + self, + num_tokens: int, + with_prefill: bool = False, + in_graph_capturing: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: torch.Tensor | None = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None, + is_profile=False, + ): hidden_states = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=self.device, ) with set_ascend_forward_context( - None, - self.vllm_config, - num_tokens=num_tokens, - num_actual_tokens=0, - in_profile_run=is_profile, - batch_descriptor=batch_descriptor, - aclgraph_runtime_mode=aclgraph_runtime_mode, - is_draft_model=True): + None, + self.vllm_config, + num_tokens=num_tokens, + num_actual_tokens=0, + in_profile_run=is_profile, + batch_descriptor=batch_descriptor, + aclgraph_runtime_mode=aclgraph_runtime_mode, + is_draft_model=True, + ): self.model(hidden_states) dummy_compute_logits(hidden_states) - def generate_token_ids(self, valid_sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, - spec_decode_metadata: SpecDecodeMetadata, - sample_hidden_states: torch.Tensor, - *args, - **kwargs - ): - + def generate_token_ids( + self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + spec_decode_metadata: SpecDecodeMetadata, + sample_hidden_states: torch.Tensor, + *args, + **kwargs, + ): if sample_hidden_states.shape[0] == len(valid_sampled_token_ids): # The input to the target model does not include draft tokens. hidden_states = sample_hidden_states else: num_accepted_tokens = torch.tensor( - [len(t) for t in valid_sampled_token_ids], - device=self.device, - dtype=torch.long) - num_draft_tokens = torch.tensor( - spec_decode_metadata.num_draft_tokens, - device=self.device, - dtype=torch.long) + [len(t) for t in valid_sampled_token_ids], device=self.device, dtype=torch.long + ) + num_draft_tokens = torch.tensor(spec_decode_metadata.num_draft_tokens, device=self.device, dtype=torch.long) - offsets = torch.cumsum(num_draft_tokens + 1, - dim=0) - (num_draft_tokens + 1) + offsets = torch.cumsum(num_draft_tokens + 1, dim=0) - (num_draft_tokens + 1) indices = offsets + num_accepted_tokens - 1 hidden_states = sample_hidden_states[indices] diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 03f21fed..3f43b194 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import torch import torch.nn as nn from vllm.config import CUDAGraphMode @@ -22,29 +20,33 @@ from vllm_ascend.utils import lmhead_tp_enable, vllm_version_is class MtpProposer(EagleProposer): - # TODO: Find out why ModelRunner does not this explicit typing? 
- model: Union[nn.Module, ACLGraphWrapper] + model: nn.Module | ACLGraphWrapper @torch.inference_mode() - def dummy_run(self, - num_tokens: int, - with_prefill: bool = False, - in_graph_capturing: bool = False, - num_reqs: int = 0, - num_tokens_across_dp=None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - dummy_compute_logits=lambda hidden_states: None, - is_profile=False) -> None: - if ( - self.pcp_size * self.dcp_size == 1 - and not self.speculative_config.disable_padded_drafter_batch - ): + def dummy_run( + self, + num_tokens: int, + with_prefill: bool = False, + in_graph_capturing: bool = False, + num_reqs: int = 0, + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None, + is_profile=False, + ) -> None: + if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch: super().dummy_run( - num_tokens, with_prefill, in_graph_capturing, num_reqs, - num_tokens_across_dp, aclgraph_runtime_mode, batch_descriptor, - dummy_compute_logits, is_profile + num_tokens, + with_prefill, + in_graph_capturing, + num_reqs, + num_tokens_across_dp, + aclgraph_runtime_mode, + batch_descriptor, + dummy_compute_logits, + is_profile, ) return ( @@ -61,14 +63,10 @@ class MtpProposer(EagleProposer): aclgraph_runtime_mode = CUDAGraphMode.NONE if aclgraph_runtime_mode == CUDAGraphMode.FULL: if len(self.runner.attn_groups) > 0: - num_computed_tokens_cpu = ( - self.runner.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs]) + num_computed_tokens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.runner.query_start_loc.gpu[:num_reqs + - 1], - query_start_loc_cpu=self.runner.query_start_loc. - cpu[:num_reqs + 1], + query_start_loc=self.runner.query_start_loc.gpu[: num_reqs + 1], + query_start_loc_cpu=self.runner.query_start_loc.cpu[: num_reqs + 1], seq_lens_cpu=self.runner.seq_lens.cpu, seq_lens=self.runner.seq_lens.gpu[:num_reqs], num_reqs=num_reqs, @@ -77,27 +75,29 @@ class MtpProposer(EagleProposer): max_query_len=self.num_speculative_tokens + 1, num_computed_tokens_cpu=num_computed_tokens_cpu, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0]. - get_device_tensor(), - slot_mapping=self.runner.input_batch.block_table[0]. 
-                    slot_mapping.gpu,
+                    block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor(),
+                    slot_mapping=self.runner.input_batch.block_table[0].slot_mapping.gpu,
                     positions=self.runner.positions.gpu,
                     attn_state=self.runner.attn_state,
                     decode_token_per_req=self.runner.decode_token_per_req,
-                    max_seq_len=0)
+                    max_seq_len=0,
+                )
             if self.pcp_size * self.dcp_size > 1:
                 # update long_seq related params and flatten block_table
-                common_attn_metadata.prefill_context_parallel_metadata = \
-                    self.runner.pcp_manager.long_seq_metadata
-                common_attn_metadata.block_table_tensor = \
-                    self.runner.input_batch.block_table[0].get_device_tensor()[
-                        :num_reqs * self.decode_threshold]
+                common_attn_metadata.prefill_context_parallel_metadata = self.runner.pcp_manager.long_seq_metadata
+                common_attn_metadata.block_table_tensor = self.runner.input_batch.block_table[
+                    0
+                ].get_device_tensor()[: num_reqs * self.decode_threshold]
             builder = self.runner.attn_groups[0][0].get_metadata_builder()
-            # `AscendAttentionState.SpecDecoding` is only designed for mla, `AscendAttentionState.ChunkedPrefill` is used in self-attention.
-            attn_state = AscendAttentionState.SpecDecoding if self.vllm_config.model_config.use_mla else AscendAttentionState.ChunkedPrefill
-            attn_metadata_mtp = builder.build_for_graph_capture(
-                common_attn_metadata, attn_state)
+            # `AscendAttentionState.SpecDecoding` is only designed for mla,
+            # `AscendAttentionState.ChunkedPrefill` is used in self-attention.
+            attn_state = (
+                AscendAttentionState.SpecDecoding
+                if self.vllm_config.model_config.use_mla
+                else AscendAttentionState.ChunkedPrefill
+            )
+            attn_metadata_mtp = builder.build_for_graph_capture(common_attn_metadata, attn_state)
             attn_metadata = {}
             for layer_name in self.attn_layer_names:
                 attn_metadata[layer_name] = attn_metadata_mtp
@@ -113,32 +113,34 @@ class MtpProposer(EagleProposer):
             if i > 0 and not in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
                 aclgraph_runtime_mode = CUDAGraphMode.NONE
             with set_ascend_forward_context(
-                    attn_metadata,
-                    self.vllm_config,
-                    num_tokens=num_tokens,
-                    num_tokens_across_dp=num_tokens_across_dp,
-                    num_actual_tokens=0,
-                    aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    batch_descriptor=batch_descriptor,
-                    is_draft_model=True,
-                    in_profile_run=is_profile):
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                num_actual_tokens=0,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
+                is_draft_model=True,
+                in_profile_run=is_profile,
+            ):
                 if not vllm_version_is("v0.15.0"):
                     # Reset MOE layer index for each MTP step iteration
                     forward_context = get_forward_context()
                     if forward_context is not None:
                         forward_context.moe_layer_index = 0
 
-                previous_hidden_states, positions = self.maybe_pad_and_reduce(
-                    previous_hidden_states, positions)
-                self.model(input_ids=input_ids,
-                           positions=positions,
-                           hidden_states=previous_hidden_states)
+                previous_hidden_states, positions = self.maybe_pad_and_reduce(previous_hidden_states, positions)
+                self.model(input_ids=input_ids, positions=positions, hidden_states=previous_hidden_states)
                 forward_context = get_forward_context()
-                if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
-                        not forward_context.capturing and not self.use_sparse:
+                if (
+                    forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL
+                    and not forward_context.capturing
+                    and not self.use_sparse
+                ):
                     self._update_full_graph_params(forward_context, num_tokens)
                 previous_hidden_states, positions, _ = self.maybe_all_gather_and_unpad(
-                    previous_hidden_states, positions)
+                    previous_hidden_states, positions
+                )
                 dummy_compute_logits(previous_hidden_states)
             if with_prefill:
                 break
@@ -153,11 +155,10 @@ class MtpProposer(EagleProposer):
         target_hidden_states: torch.Tensor,
         # [batch_size]
         next_token_ids: torch.Tensor,
-        last_token_indices: Optional[torch.Tensor],
+        last_token_indices: torch.Tensor | None,
         common_attn_metadata: CommonAttentionMetadata,
         sampling_metadata: SamplingMetadata,
-        mm_embed_inputs: Optional[tuple[list[torch.Tensor],
-                                        torch.Tensor]] = None,
+        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
         req_scheduled_tokens=None,
         long_seq_metadata=None,
         num_prefill_reqs=0,
@@ -165,16 +166,22 @@ class MtpProposer(EagleProposer):
         scheduler_output: SchedulerOutput = None,
         num_scheduled_tokens: int = 0,
     ) -> torch.Tensor:
-        if (
-            self.pcp_size * self.dcp_size == 1
-            and not self.speculative_config.disable_padded_drafter_batch
-        ):
+        if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch:
             draft_token_ids = super()._propose(
-                target_token_ids, target_positions, target_hidden_states,
-                next_token_ids, last_token_indices, common_attn_metadata,
-                sampling_metadata, mm_embed_inputs, req_scheduled_tokens,
-                long_seq_metadata, num_prefill_reqs, num_decode_reqs,
-                scheduler_output, num_scheduled_tokens
+                target_token_ids,
+                target_positions,
+                target_hidden_states,
+                next_token_ids,
+                last_token_indices,
+                common_attn_metadata,
+                sampling_metadata,
+                mm_embed_inputs,
+                req_scheduled_tokens,
+                long_seq_metadata,
+                num_prefill_reqs,
+                num_decode_reqs,
+                scheduler_output,
+                num_scheduled_tokens,
             )
             return draft_token_ids
 
@@ -186,13 +193,12 @@ class MtpProposer(EagleProposer):
         if self.method == "eagle3":
             assert isinstance(self.model, Eagle3LlamaForCausalLM)
-            target_hidden_states = self.model.combine_hidden_states(
-                target_hidden_states)
+            target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
             assert target_hidden_states.shape[-1] == self.hidden_size
 
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
-        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
+        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
         # Replace the last token with the next token.
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
         self.input_ids[last_token_indices] = next_token_ids
@@ -213,20 +219,16 @@ class MtpProposer(EagleProposer):
             num_tokens_d_padded = num_tokens_d * self.pcp_size
             input_ids_d = self.input_ids[:num_tokens_d]
             input_ids_p = self.input_ids[num_tokens_d:num_tokens]
-            target_hidden_states_d_padded = \
-                target_hidden_states[:num_tokens_d_padded]
+            target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded]
             if num_tokens_d:
                 # remove padding (from pcp all-gather) in decode part
-                mask_start_loc = torch.cat([
-                    torch.tensor([0], dtype=torch.int32),
-                    torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
-                ])
+                mask_start_loc = torch.cat(
+                    [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]]
+                )
                 mask_len = query_lens_d
                 mask = []
                 for req_id in range(num_decode_reqs):
-                    mask += list(
-                        range(mask_start_loc[req_id],
-                              mask_start_loc[req_id] + mask_len[req_id]))
+                    mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))
                 target_hidden_states_d = target_hidden_states_d_padded[mask]
             else:
                 target_hidden_states_d = target_hidden_states_d_padded
@@ -234,46 +236,33 @@ class MtpProposer(EagleProposer):
             req_scheduled_tokens_p = {}
             for i, req_id in enumerate(self.runner.input_batch.req_ids):
                 if i >= num_decode_reqs:
-                    req_scheduled_tokens_p[req_id] = \
-                        req_scheduled_tokens[req_id]
-            (num_tokens_p, input_ids_p, target_hidden_states_p,
-             max_query_len_p, seq_lens_p, cu_num_tokens_p) = \
-                self._split_pcp_input(
-                    req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
+                    req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id]
+            (num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = (
+                self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
+            )
             num_tokens = num_tokens_d + num_tokens_p
             target_positions = target_positions[:num_tokens]
-            self.input_ids[:num_tokens].copy_(
-                torch.cat([input_ids_d, input_ids_p], dim=0))
-            target_hidden_states = torch.cat(
-                [target_hidden_states_d, target_hidden_states_p], dim=0)
+            self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0))
+            target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
 
             # 2. update sample_indices according to main model
             if num_decode_reqs:
-                last_token_indices[:num_decode_reqs] = \
-                    self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
+                last_token_indices[:num_decode_reqs] = self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
             if num_prefill_reqs:
-                last_token_indices[-num_prefill_reqs:] = \
-                    self.runner.logits_indices[-num_prefill_reqs:]
+                last_token_indices[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
 
             # 3. update attn_metadata params that may be influenced by pcp
             common_attn_metadata.num_actual_tokens = num_tokens
-            common_attn_metadata.max_query_len = max(
-                self.decode_threshold, max_query_len_p)
+            common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p)
             common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p
-            common_attn_metadata.seq_lens_cpu[
-                -num_prefill_reqs:] = seq_lens_p
-            query_start_loc_p = cu_num_tokens_p[1:] + \
-                common_attn_metadata.query_start_loc[num_decode_reqs].item()
-            common_attn_metadata.query_start_loc[-num_prefill_reqs:] = \
-                query_start_loc_p
-            common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = \
-                query_start_loc_p
+            common_attn_metadata.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p
+            query_start_loc_p = cu_num_tokens_p[1:] + common_attn_metadata.query_start_loc[num_decode_reqs].item()
+            common_attn_metadata.query_start_loc[-num_prefill_reqs:] = query_start_loc_p
+            common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p
 
         assert self.runner is not None
         # Note(qcs): We may need to refactor these check logics.
-        if self.use_cuda_graph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[
-                -1]:
-            num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[
-                num_scheduled_tokens]
+        if self.use_cuda_graph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[-1]:
+            num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_scheduled_tokens]
         else:
             # Eager mode, no padding needed
             num_input_tokens = num_tokens
@@ -282,23 +271,23 @@ class MtpProposer(EagleProposer):
         self._set_positions(num_tokens, target_positions)
         self.hidden_states[:num_tokens] = target_hidden_states
         # eager/acl piecewise mode need to update num_tokens_across_dp
-        (num_input_tokens, num_tokens_across_dp,
-         with_prefill) = self.runner._sync_metadata_across_dp(
-             num_input_tokens, self.runner.with_prefill)
+        (num_input_tokens, num_tokens_across_dp, with_prefill) = self.runner._sync_metadata_across_dp(
+            num_input_tokens, self.runner.with_prefill
+        )
 
         # Enable shared_expert_dp and MTP FULL graph may cause accuracy issues.
         if scheduler_output and not self.enable_shared_expert_dp:
             max_query_len = common_attn_metadata.max_query_len
-            uniform_decode = (max_query_len in list(
-                range(1, self.num_speculative_tokens +
-                      2))) and (scheduler_output.total_num_scheduled_tokens
-                                == self.runner.input_batch.num_reqs *
-                                (self.num_speculative_tokens + 1))
+            uniform_decode = (max_query_len in list(range(1, self.num_speculative_tokens + 2))) and (
+                scheduler_output.total_num_scheduled_tokens
+                == self.runner.input_batch.num_reqs * (self.num_speculative_tokens + 1)
+            )
         else:
             uniform_decode = False
         has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
-        aclgraph_runtime_mode, batch_descriptor = \
-            self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
+        aclgraph_runtime_mode, batch_descriptor = self.runner.cudagraph_dispatcher.dispatch(
+            num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora
+        )
         if not self.use_cuda_graph:
             # there is synchronization between mtp steps when enabling aclgraph,
             # disable aclgraph when use async scheduling to avoid the
@@ -307,8 +296,10 @@ class MtpProposer(EagleProposer):
             # and _propose.
             aclgraph_runtime_mode = CUDAGraphMode.NONE
 
-        if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
-        ) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
+        if (
+            self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
+            and aclgraph_runtime_mode == CUDAGraphMode.FULL
+        ):
             graph_pad_size = num_input_tokens
         else:
             graph_pad_size = -1
@@ -319,64 +310,58 @@ class MtpProposer(EagleProposer):
         common_attn_metadata.graph_pad_size = graph_pad_size
         common_attn_metadata.num_input_tokens = num_input_tokens
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
-        attn_metadata_mtp = builder.build(0, common_attn_metadata,
-                                          self.runner.get_model())
+        attn_metadata_mtp = builder.build(0, common_attn_metadata, self.runner.get_model())
         attn_metadata = {}
         for layer_name in self.attn_layer_names:
             attn_metadata[layer_name] = attn_metadata_mtp
 
         for step in range(self.num_speculative_tokens):
             with set_ascend_forward_context(
-                    attn_metadata,
-                    self.vllm_config,
-                    num_tokens=num_input_tokens,
-                    num_tokens_across_dp=num_tokens_across_dp,
-                    aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    batch_descriptor=batch_descriptor,
-                    num_actual_tokens=num_tokens,
-                    is_draft_model=True):
-
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
+                num_actual_tokens=num_tokens,
+                is_draft_model=True,
+            ):
                 if not vllm_version_is("v0.15.0"):
                     # Reset MOE layer index for each MTP step to match all_moe_layers registration
                     forward_context = get_forward_context()
                     if forward_context is not None:
                         forward_context.moe_layer_index = 0
-                with record_function_or_nullcontext('mtp_forward'):
+                with record_function_or_nullcontext("mtp_forward"):
                     model_kwargs = {}
                     model_kwargs["attn_metadata"] = attn_metadata
                     input_ids = self.input_ids[:num_input_tokens]
                     positions = self._get_positions(num_input_tokens)
                     hidden_states = self.hidden_states[:num_input_tokens]
-                    hidden_states, positions = self.maybe_pad_and_reduce(
-                        hidden_states, positions)
+                    hidden_states, positions = self.maybe_pad_and_reduce(hidden_states, positions)
 
-                    hidden_states = self.model(input_ids=input_ids,
-                                               positions=positions,
-                                               hidden_states=hidden_states)
-                    forward_context = get_forward_context()
-                    if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not self.use_sparse:
-                        self._update_full_graph_params(forward_context,
-                                                       num_input_tokens)
+                    hidden_states = self.model(input_ids=input_ids, positions=positions, hidden_states=hidden_states)
+                    forward_context = get_forward_context()
+                    if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not self.use_sparse:
+                        self._update_full_graph_params(forward_context, num_input_tokens)
 
-                    hidden_states, positions, _ = self.maybe_all_gather_and_unpad(
-                        hidden_states, positions)
+                    hidden_states, positions, _ = self.maybe_all_gather_and_unpad(hidden_states, positions)
 
                     num_indices = last_token_indices.shape[0]
                     if lmhead_tp_enable():
-                        max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
-                        last_token_indices = nn.functional.pad(
-                            last_token_indices,
-                            (0, max_num_reqs_across_dp - num_indices))
+                        max_num_reqs_across_dp = (
+                            self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
+                        )
+                        last_token_indices = nn.functional.pad(last_token_indices, (0, max_num_reqs_across_dp - num_indices))
                     if self.pcp_size > 1 and step == 0:
                         # remove graph padding before all_gather
                         hidden_states = hidden_states[:num_tokens]
                         hidden_states = get_pcp_group().all_gather(hidden_states, 0)
                         hidden_states = torch.index_select(
-                            hidden_states, 0, self.runner.pcp_manager.
-                            pcp_allgather_restore_idx.gpu[:hidden_states.shape[0]])
+                            hidden_states, 0, self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]]
+                        )
                     sample_hidden_states = hidden_states[last_token_indices]
                     logits = self.model.compute_logits(sample_hidden_states)
@@ -409,7 +394,7 @@ class MtpProposer(EagleProposer):
                 hidden_states = hidden_states[last_token_indices]
                 slot_mapping = attn_metadata_i.slot_mapping[last_token_indices]
                 attn_metadata_i.slot_mapping.fill_(-1)
-                attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
+                attn_metadata_i.query_start_loc = self.arange[: batch_size + 1]
                 last_token_indices = self.arange[:batch_size]
                 if getattr(attn_metadata_i, "num_decode_tokens", 0):
                     attn_metadata_i.num_decode_tokens = batch_size
@@ -420,44 +405,44 @@ class MtpProposer(EagleProposer):
                     # Instead, we pre-allocate mtp slot_mapping in model_runner
                    # (_generate_pcp_mtp_input), and use updated slot_indices
                     # to get corresponding slot_mapping in each step.
-                    num_reject_tokens = torch.tensor(
-                        self.runner.pcp_manager.cu_num_tokens_pcp_full,
-                        dtype=torch.int32).to(
-                            self.device) - ori_last_token_indices - 1
-                    num_accept_tokens = \
-                        query_lens_d.to(self.device) - num_reject_tokens
+                    num_reject_tokens = (
+                        torch.tensor(self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to(self.device)
+                        - ori_last_token_indices
+                        - 1
+                    )
+                    num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens
                     ori_seq_len = attn_metadata_i.seq_lens
                     mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad
                     # slot_mapping index base offset:
                     # scheduled tokens + pre-allocated mtp tokens + accepted tokens
                     slot_idx_base = (
-                        torch.cat([
-                            torch.tensor(
-                                [0], dtype=torch.int32, device=self.device),
-                            (torch.cumsum(query_lens_d, dim=0)[:-1] *
-                             self.pcp_size).to(self.device)
-                        ]) +
-                        torch.arange(num_decode_reqs, device=self.device) *
-                        (self.num_speculative_tokens - 1) * self.pcp_size +
-                        (num_accept_tokens - 1) * self.pcp_size)
+                        torch.cat(
+                            [
+                                torch.tensor([0], dtype=torch.int32, device=self.device),
+                                (torch.cumsum(query_lens_d, dim=0)[:-1] * self.pcp_size).to(self.device),
+                            ]
+                        )
+                        + torch.arange(num_decode_reqs, device=self.device)
+                        * (self.num_speculative_tokens - 1)
+                        * self.pcp_size
+                        + (num_accept_tokens - 1) * self.pcp_size
+                    )
                     slot_indices_list = []
                     for req_id in range(num_decode_reqs):
                         slot_indices_list.append(
-                            torch.arange(slot_idx_base[req_id],
-                                         slot_idx_base[req_id] + self.pcp_size,
-                                         device=self.device))
+                            torch.arange(
+                                slot_idx_base[req_id], slot_idx_base[req_id] + self.pcp_size, device=self.device
+                            )
+                        )
                     slot_indices = torch.cat(slot_indices_list, dim=0)
                     # fold block_table (restore it to original size before flattened)
-                    block_indices = torch.cat([
-                        torch.tensor([0], dtype=torch.int32),
-                        torch.cumsum(query_lens_d, dim=0)[:-1]
-                    ])
-                    attn_metadata_i.decode.block_table[:batch_size] = \
-                        attn_metadata_i.decode.block_table[block_indices]
-                    attn_metadata_i.decode.block_table = \
-                        attn_metadata_i.decode.block_table[:batch_size]
+                    block_indices = torch.cat(
+                        [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d, dim=0)[:-1]]
+                    )
+                    attn_metadata_i.decode.block_table[:batch_size] = attn_metadata_i.decode.block_table[block_indices]
+                    attn_metadata_i.decode.block_table = attn_metadata_i.decode.block_table[:batch_size]
 
                 input_ids = draft_token_ids_list[-1].int()
                 positions += 1
@@ -465,38 +450,32 @@ class MtpProposer(EagleProposer):
             decode_metadata = getattr(attn_metadata_i, "decode", None)
             prefill_metadata = getattr(attn_metadata_i, "prefill", None)
             # When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
-            if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \
-                aclgraph_runtime_mode != CUDAGraphMode.FULL):
-                decode_metadata.actual_seq_lengths_q = self.arange_cpu[
-                    1:batch_size + 1].tolist()
+            if decode_metadata is not None and (
+                self.speculative_config.disable_padded_drafter_batch or aclgraph_runtime_mode != CUDAGraphMode.FULL
+            ):
+                decode_metadata.actual_seq_lengths_q = self.arange_cpu[1 : batch_size + 1].tolist()
                 if aclgraph_runtime_mode == CUDAGraphMode.FULL:
-                    decode_metadata.actual_seq_lengths_q = \
-                        builder.pad_actual_seq_len_q_mtp_disable_pad(
-                            graph_pad_size - batch_size,
-                            batch_size,
-                            decode_metadata.actual_seq_lengths_q)
-                decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla(
-                    positions[:batch_size])
+                    decode_metadata.actual_seq_lengths_q = builder.pad_actual_seq_len_q_mtp_disable_pad(
+                        graph_pad_size - batch_size, batch_size, decode_metadata.actual_seq_lengths_q
+                    )
+                decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla(positions[:batch_size])
             # NOTE(woosuk): We should handle the case where the draft model
             # generates tokens beyond the max model length. Since it is complex
             # to remove such requests from the batch, we keep them in the batch
             # but adjust the position ids and slot mappings to avoid the
             # out-of-range access during the model execution. The draft tokens
             # generated with this adjustment should be ignored.
-            exceeds_max_model_len = positions[:
-                                              batch_size] >= self.runner.model_config.max_model_len
+            exceeds_max_model_len = positions[:batch_size] >= self.runner.model_config.max_model_len
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
-            clamped_positions = torch.where(exceeds_max_model_len, 0,
-                                            positions[:batch_size])
+            clamped_positions = torch.where(exceeds_max_model_len, 0, positions[:batch_size])
             # Increment the sequence lengths.
             # This is an out-of-place operation to avoid modifying the original tensor
             # when enable async_scheduling.
             attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.
-            exceeds_mask = attn_metadata_i.seq_lens[:batch_size] > \
-                self.runner.model_config.max_model_len
+            exceeds_mask = attn_metadata_i.seq_lens[:batch_size] > self.runner.model_config.max_model_len
             attn_metadata_i.seq_lens[:batch_size].masked_fill_(exceeds_mask, 1)
             # Mask out the slot mappings that exceed the max model length.
             # Otherwise, the KV cache will be inadvertently updated with the
@@ -504,13 +483,14 @@ class MtpProposer(EagleProposer):
             slot_mapping += 1
             if self.pcp_size > 1:
                 exceeds_max_model_len = exceeds_max_model_len.repeat_interleave(
-                    slot_mapping.size(0) // exceeds_max_model_len.size(0))
+                    slot_mapping.size(0) // exceeds_max_model_len.size(0)
+                )
             slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
 
             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
             self._set_positions(batch_size, clamped_positions)
-            self.hidden_states[:hidden_states.shape[0]] = hidden_states
+            self.hidden_states[: hidden_states.shape[0]] = hidden_states
             if self.pcp_size * self.dcp_size > 1:
                 # update local seq_len
                 num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
@@ -519,19 +499,17 @@ class MtpProposer(EagleProposer):
                     self.dcp_size,
                     self.runner.parallel_config.cp_kv_cache_interleave_size,
                 )
-                cp_seq_len = \
-                    num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
+                cp_seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
                 attn_metadata_i.decode.cp_seq_len = cp_seq_len
                 # update slot_mapping
                 slot_indices += self.pcp_size
                 slot_mapping = mtp_slot_mapping[slot_indices]
-                attn_metadata_i.slot_mapping[:batch_size *
-                                             self.pcp_size] = slot_mapping
+                attn_metadata_i.slot_mapping[: batch_size * self.pcp_size] = slot_mapping
             else:
                 attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
             if self.speculative_config.disable_padded_drafter_batch:
                 if self.uses_mrope:
-                    self.mrope_positions[:, batch_size:num_input_tokens] = 0
+                    self.mrope_positions[:, batch_size:num_input_tokens] = 0
                 else:
                     self.positions[batch_size:num_input_tokens] = 0
                 self.input_ids[batch_size:num_input_tokens] = 0
@@ -539,31 +517,24 @@ class MtpProposer(EagleProposer):
 
             if prefill_metadata is not None:
                 prefill_metadata.seq_lens = attn_metadata_i.seq_lens
-                prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist(
-                )
+                prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist()
                 prefill_metadata.context_lens = attn_metadata_i.seq_lens
-                prefill_metadata.input_positions = self._get_positions(
-                    num_input_tokens)
+                prefill_metadata.input_positions = self._get_positions(num_input_tokens)
                 prefill_metadata.max_seq_lens += 1
                 prefill_metadata.max_seq_lens = min(
-                    prefill_metadata.max_seq_lens,
-                    self.runner.model_config.max_model_len)
+                    prefill_metadata.max_seq_lens, self.runner.model_config.max_model_len
+                )
             if decode_metadata is not None:
                 decode_metadata.seq_lens = attn_metadata_i.seq_lens
-                decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist(
-                )
+                decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist()
                 decode_seq_lens_list = decode_metadata.seq_lens_list
-                if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
-                        self.speculative_config.disable_padded_drafter_batch:
-                    decode_metadata.seq_lens_list = decode_seq_lens_list + [
-                        0
-                    ] * (graph_pad_size - len(decode_seq_lens_list))
-                decode_metadata.input_positions = self._get_positions(
-                    num_input_tokens)
+                if aclgraph_runtime_mode == CUDAGraphMode.FULL and self.speculative_config.disable_padded_drafter_batch:
+                    decode_metadata.seq_lens_list = decode_seq_lens_list + [0] * (
+                        graph_pad_size - len(decode_seq_lens_list)
+                    )
+                decode_metadata.input_positions = self._get_positions(num_input_tokens)
                 decode_metadata.max_seq_lens += 1
-                decode_metadata.max_seq_lens = min(
-                    decode_metadata.max_seq_lens,
-                    self.runner.model_config.max_model_len)
+                decode_metadata.max_seq_lens = min(decode_metadata.max_seq_lens, self.runner.model_config.max_model_len)
 
         # mtp>1: [batch_size, k]
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
diff --git a/vllm_ascend/spec_decode/ngram_proposer.py b/vllm_ascend/spec_decode/ngram_proposer.py
index 22d28b61..280d2ca8 100644
--- a/vllm_ascend/spec_decode/ngram_proposer.py
+++ b/vllm_ascend/spec_decode/ngram_proposer.py
@@ -1,13 +1,11 @@
 import torch
 from vllm.config import CUDAGraphMode
-from vllm.v1.spec_decode.ngram_proposer import \
-    NgramProposer as VllmNgramProposer
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer as VllmNgramProposer
 
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 
 
 class NgramProposer(VllmNgramProposer, Proposer):
-
     def __init__(self, vllm_config, device, runner):
         super().__init__(vllm_config)
         self.name = SpecDcodeType.NGRAM
@@ -19,27 +17,31 @@ class NgramProposer(VllmNgramProposer, Proposer):
         pass
 
     @torch.inference_mode()
-    def dummy_run(self,
-                  num_tokens,
-                  with_prefill=None,
-                  in_graph_capturing=None,
-                  num_reqs=None,
-                  num_tokens_across_dp=None,
-                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None,
-                  dummy_compute_logits=lambda hidden_states: None,
-                  is_profile=False):
+    def dummy_run(
+        self,
+        num_tokens,
+        with_prefill=None,
+        in_graph_capturing=None,
+        num_reqs=None,
+        num_tokens_across_dp=None,
+        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        batch_descriptor=None,
+        dummy_compute_logits=lambda hidden_states: None,
+        is_profile=False,
+    ):
         pass
 
-    def generate_token_ids(self,
-                           valid_sampled_token_ids,
-                           sampling_metadata=None,
-                           scheduler_output=None,
-                           spec_decode_metadata=None,
-                           positions=None,
-                           num_scheduled_tokens=None,
-                           hidden_states=None,
-                           aux_hidden_states=None) -> list[list[int]]:
+    def generate_token_ids(
+        self,
+        valid_sampled_token_ids,
+        sampling_metadata=None,
+        scheduler_output=None,
+        spec_decode_metadata=None,
+        positions=None,
+        num_scheduled_tokens=None,
+        hidden_states=None,
+        aux_hidden_states=None,
+    ) -> list[list[int]]:
         valid_ngram_requests = []
         for i, sampled_ids in enumerate(valid_sampled_token_ids):
             num_sampled_ids = len(sampled_ids)
@@ -57,8 +59,7 @@ class NgramProposer(VllmNgramProposer, Proposer):
 
             start_idx = self.runner.input_batch.num_tokens_no_spec[i]
             end_idx = start_idx + num_sampled_ids
-            self.runner.input_batch.token_ids_cpu[
-                i, start_idx:end_idx] = sampled_ids
+            self.runner.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
 
             valid_ngram_requests.append(i)
 
diff --git a/vllm_ascend/spec_decode/suffix_proposer.py b/vllm_ascend/spec_decode/suffix_proposer.py
index ea9f0f72..1cdbec3c 100644
--- a/vllm_ascend/spec_decode/suffix_proposer.py
+++ b/vllm_ascend/spec_decode/suffix_proposer.py
@@ -1,13 +1,11 @@
 import torch
 from vllm.config import CUDAGraphMode
-from vllm.v1.spec_decode.suffix_decoding import \
-    SuffixDecodingProposer as VllmSuffixDecodingProposer
+from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer as VllmSuffixDecodingProposer
 
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 
 
 class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
-
     def __init__(self, vllm_config, device, runner):
         super().__init__(vllm_config)
         self.name = SpecDcodeType.SUFFIX
@@ -19,27 +17,30 @@ class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
         pass
 
     @torch.inference_mode()
-    def dummy_run(self,
-                  num_tokens,
-                  with_prefill=None,
-                  in_graph_capturing=None,
-                  num_reqs=None,
-                  num_tokens_across_dp=None,
-                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None,
-                  dummy_compute_logits=lambda hidden_states: None,
-                  is_profile=False):
+    def dummy_run(
+        self,
+        num_tokens,
+        with_prefill=None,
+        in_graph_capturing=None,
+        num_reqs=None,
+        num_tokens_across_dp=None,
+        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        batch_descriptor=None,
+        dummy_compute_logits=lambda hidden_states: None,
+        is_profile=False,
+    ):
         pass
 
-    def generate_token_ids(self,
-                           valid_sampled_token_ids,
-                           sampling_metadata=None,
-                           scheduler_output=None,
-                           spec_decode_metadata=None,
-                           positions=None,
-                           num_scheduled_tokens=None,
-                           hidden_states=None,
-                           aux_hidden_states=None) -> list[list[int]]:
-        draft_token_ids = self.propose(self.runner.input_batch,
-                                       valid_sampled_token_ids)
+    def generate_token_ids(
+        self,
+        valid_sampled_token_ids,
+        sampling_metadata=None,
+        scheduler_output=None,
+        spec_decode_metadata=None,
+        positions=None,
+        num_scheduled_tokens=None,
+        hidden_states=None,
+        aux_hidden_states=None,
+    ) -> list[list[int]]:
+        draft_token_ids = self.propose(self.runner.input_batch, valid_sampled_token_ids)
         return draft_token_ids