diff --git a/pyproject.toml b/pyproject.toml index 7e90ef2d..3aee3d2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,19 @@ line-length = 120 # Folder to be modified exclude = [ "tests/**", + + # (8) + "vllm_ascend/ops/__init__.py", + "vllm_ascend/ops/activation.py", + "vllm_ascend/ops/flashcomm2_oshard_manager.py", + "vllm_ascend/ops/layernorm.py", + "vllm_ascend/ops/mla.py", + "vllm_ascend/ops/mm_encoder_attention.py", + "vllm_ascend/ops/register_custom_ops.py", + "vllm_ascend/ops/vocab_parallel_embedding.py", + "vllm_ascend/ops/weight_prefetch.py", + "vllm_ascend/spec_decode/**", + ] [tool.ruff.lint] diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index 46c0ceff..aadb6164 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -27,7 +27,8 @@ if HAS_TRITON: import vllm_ascend.ops.vocab_parallel_embedding # noqa from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul -from vllm_ascend.ops.rotary_embedding import AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding +from vllm_ascend.ops.rotary_embedding import ( + AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) class dummyFusionOp: @@ -39,13 +40,23 @@ class dummyFusionOp: def register_dummy_fusion_op() -> None: torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm") - torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm") - torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp(name="static_scaled_fp8_quant") - torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp(name="dynamic_scaled_fp8_quant") - torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(name="dynamic_per_token_scaled_fp8_quant") - torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp(name="rms_norm_static_fp8_quant") - torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(name="fused_add_rms_norm_static_fp8_quant") - torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp(name="rms_norm_dynamic_per_token_quant") + torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp( + name="fused_add_rms_norm") + torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp( + name="static_scaled_fp8_quant") + torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp( + name="dynamic_scaled_fp8_quant") + torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp( + name="dynamic_per_token_scaled_fp8_quant") + torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp( + name="rms_norm_static_fp8_quant") + torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp( + name="fused_add_rms_norm_static_fp8_quant") + torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp( + name="rms_norm_dynamic_per_token_quant") -__all__ = ["AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding", "AscendDeepseekScalingRotaryEmbedding"] +__all__ = [ + "AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding", + "AscendDeepseekScalingRotaryEmbedding" +] diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py index ac8730af..a22513fe 100644 --- a/vllm_ascend/ops/activation.py +++ b/vllm_ascend/ops/activation.py @@ -17,11 +17,10 @@ import torch from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul - from vllm_ascend.utils import get_weight_prefetch_method - class AscendQuickGELU(QuickGELU): + def forward_oot(self, x: torch.tensor) -> torch.Tensor: import torch_npu @@ -30,6 +29,7 @@ class AscendQuickGELU(QuickGELU): class AscendSiluAndMul(SiluAndMul): + def forward_oot(self, x: torch.Tensor) -> torch.Tensor: import torch_npu diff --git a/vllm_ascend/ops/flashcomm2_oshard_manager.py b/vllm_ascend/ops/flashcomm2_oshard_manager.py index 54f38617..d95748e6 100644 --- a/vllm_ascend/ops/flashcomm2_oshard_manager.py +++ b/vllm_ascend/ops/flashcomm2_oshard_manager.py @@ -1,14 +1,11 @@ -from typing import Any +from typing import Any, Dict, Optional from vllm.model_executor.models.utils import extract_layer_index from vllm_ascend.distributed.parallel_state import get_shard_weight_group from vllm_ascend.ops.layer_shard_linear import ( - is_hidden_layer, - post_process_after_loading_for_shard_weight_series, - reach_layer_for_shard_weight_series, - register_layer_to_shard_weight_series, -) + is_hidden_layer, post_process_after_loading_for_shard_weight_series, + reach_layer_for_shard_weight_series, register_layer_to_shard_weight_series) from vllm_ascend.utils import flashcomm2_enable, o_shard_enable @@ -29,7 +26,7 @@ class Flashcomm2OShardManager: """ def __init__(self): - self._shard_layers: dict[int, Any] = {} + self._shard_layers: Dict[int, Any] = {} def flashcomm2_oshard_enable(self): return flashcomm2_enable() and o_shard_enable() @@ -55,10 +52,12 @@ class Flashcomm2OShardManager: self._shard_layers[layer_idx] = layer register_layer_to_shard_weight_series( - series_name="o_proj", group=get_shard_weight_group(), layer=layer, prefetch_step=prefetch_step - ) + series_name="o_proj", + group=get_shard_weight_group(), + layer=layer, + prefetch_step=prefetch_step) - def get_layer(self, layer_idx: int) -> Any | None: + def get_layer(self, layer_idx: int) -> Optional[Any]: """Safely retrieves a registered layer by its index. Args: diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index b3a503e4..fa7ef0ae 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -15,53 +15,56 @@ # This file is a part of the vllm-ascend project. # +from typing import Optional, Tuple, Union import torch from torch import nn from vllm.config import get_current_vllm_config from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm, RMSNormGated - from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu -from vllm_ascend.utils import enable_custom_op, get_weight_prefetch_method +from vllm_ascend.utils import enable_custom_op +from vllm_ascend.utils import get_weight_prefetch_method class AscendRMSNorm(RMSNorm): + def __init__( self, hidden_size: int, eps: float = 1e-6, - var_hidden_size: int | None = None, + var_hidden_size: Optional[int] = None, has_weight: bool = True, - dtype: torch.dtype | None = None, + dtype: Optional[torch.dtype] = None, ) -> None: super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) vllm_config = get_current_vllm_config() self.bias = None # quantization with anti_method m4 will generate none-zero norm bias - if vllm_config.quant_config is not None and any( - "norm.bias" in name for name in vllm_config.quant_config.quant_description - ): - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False) + if vllm_config.quant_config is not None and \ + any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), + requires_grad=False) def forward_oot( self, x: torch.Tensor, - residual: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: import torch_npu if residual is not None: if enable_custom_op(): x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias( - x, residual, self.weight, self.bias, self.variance_epsilon - ) + x, residual, self.weight, self.bias, self.variance_epsilon) else: - x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight, self.variance_epsilon) + x, _, residual = torch_npu.npu_add_rms_norm( + x, residual, self.weight, self.variance_epsilon) if self.bias is not None: x.add_(self.bias) return x, residual - x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) + x, residual = torch_npu.npu_rms_norm(x, self.weight, + self.variance_epsilon) if self.bias is not None: x.add_(self.bias) @@ -72,30 +75,42 @@ class AscendRMSNorm(RMSNorm): class AscendGemmaRMSNorm(GemmaRMSNorm): + def forward_oot( self, x: torch.Tensor, - residual: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: import torch_npu + from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type if residual is not None: if enable_custom_op(): x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias( - x, residual, 1.0 + self.weight, None, self.variance_epsilon - ) + x, residual, 1.0 + self.weight, None, + self.variance_epsilon) else: - x, _, residual = torch_npu.npu_add_rms_norm(x, residual, 1.0 + self.weight, self.variance_epsilon) + x, _, residual = torch_npu.npu_add_rms_norm( + x, residual, 1.0 + self.weight, self.variance_epsilon) return x, residual - x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon) + x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, + self.variance_epsilon) return x - class LayerNormFn(torch.autograd.Function): @staticmethod - def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + def forward(ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ x_shape_og = x.shape # reshape input data into 2D tensor @@ -128,16 +143,16 @@ class LayerNormFn(torch.autograd.Function): ctx.is_rms_norm = is_rms_norm return y.reshape(x_shape_og) - class AscendRMSNormGated(RMSNormGated): + def __init__( self, hidden_size, eps: float = 1e-5, - group_size: int | None = None, + group_size: Optional[int] = None, norm_before_gate: bool = False, - device: torch.device | None = None, - dtype: torch.dtype | None = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): """If group_size is not None, we do GroupNorm with each group having group_size elements. group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). @@ -155,5 +170,7 @@ class AscendRMSNormGated(RMSNormGated): torch.nn.init.ones_(self.weight) def forward_oot(self, x, z=None): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - return LayerNormFn.apply(x, self.weight, self.bias, z, self.eps, self.group_size, self.norm_before_gate, True) + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return LayerNormFn.apply(x, self.weight, self.bias, z, self.eps, self.group_size, + self.norm_before_gate, True) \ No newline at end of file diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py index 64d5d36a..bf3bda6c 100644 --- a/vllm_ascend/ops/mla.py +++ b/vllm_ascend/ops/mla.py @@ -19,13 +19,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional import torch from torch import nn from vllm.config import CacheConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import ForwardContext, get_forward_context -from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper +from vllm.model_executor.layers.mla import (MLAModules, + MultiHeadLatentAttentionWrapper) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backend import AttentionMetadata # type: ignore @@ -34,20 +36,20 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import vllm_version_is if vllm_version_is("v0.15.0"): - from vllm.attention.layer import MLAAttention # type: ignore + from vllm.attention.layer import MLAAttention # type: ignore else: from vllm.model_executor.layers.attention import MLAAttention class IndexerWrapper(nn.Module): - """ + ''' A wrapper of Indexer for Deepseek v3.2. This wrapper is currently used to solve the fp8 hard code issue of vllm's deepseek_v2.py. It wraps the original Indexer, inherits its module weights (including wq_b, wk, weights_proj, k_norm) - while deletes the unused topk_indices_buffer and k_cache to save memory. + while deletes the unused topk_indices_buffer and k_cache to save memory. TODO: Will be removed once original Indexer supports different quantization methods. - """ + ''' def __init__(self, vllm_indexer: nn.Module) -> None: super().__init__() @@ -69,6 +71,7 @@ class IndexerWrapper(nn.Module): class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): + def __init__( self, hidden_size: int, @@ -77,11 +80,11 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, - q_lora_rank: int | None, + q_lora_rank: Optional[int], kv_lora_rank: int, mla_modules: MLAModules, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: nn.Module.__init__(self) @@ -94,7 +97,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): self.v_head_dim = v_head_dim self.prefix = prefix hf_config = get_current_vllm_config().model_config.hf_text_config - self.enable_shared_expert_dp = get_ascend_config().enable_shared_expert_dp + self.enable_shared_expert_dp = get_ascend_config( + ).enable_shared_expert_dp self.tp_size = get_tensor_model_parallel_world_size() self.layers = hf_config.num_hidden_layers if mla_modules.indexer is not None: @@ -130,7 +134,6 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): def wrapped_process_weights(act_dtype: torch.dtype): from vllm_ascend.attention.sfa_v1 import AscendSFAImpl - if not isinstance(self.mla_attn.impl, AscendSFAImpl): original_process_weights(act_dtype) self.mla_attn.impl.process_weights_after_loading(act_dtype) @@ -143,17 +146,19 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): compilation_config.static_forward_context[prefix] = self def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor | None = None, - attn_metadata: AttentionMetadata | None = None, - ) -> torch.Tensor: + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: need_gather_q_kv = get_forward_context().sp_enabled output_shape = hidden_states.shape # FIXME: This does not seem right, should make sure the buffer is fixed - output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) - torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, self.prefix) + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, + self.prefix) output = output.view(-1, output_shape[-1]) return output @@ -171,9 +176,9 @@ def mla_forward( else: attn_metadata = forward_context.attn_metadata kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine] - self.mla_attn.impl.forward( - self.mla_attn.layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv, output - ) + self.mla_attn.impl.forward(self.mla_attn.layer_name, hidden_states, + kv_cache, attn_metadata, need_gather_q_kv, + output) return diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py index 7beb7b50..aa9a0737 100644 --- a/vllm_ascend/ops/mm_encoder_attention.py +++ b/vllm_ascend/ops/mm_encoder_attention.py @@ -19,15 +19,18 @@ import einops import torch import torch.nn.functional as F import torch_npu +from vllm.config import MultiModalConfig from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention # type: ignore import vllm_ascend.envs as envs_ascend + MIN_PAD_SIZE = 64 # min_size to pad weight MAX_PAD_SIZE = 128 # max_size to pad weight class AscendMMEncoderAttention(MMEncoderAttention): + def __init__( self, num_heads: int, @@ -79,12 +82,13 @@ class AscendMMEncoderAttention(MMEncoderAttention): return query, key, value def forward_oot( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor + | None = None, # Only used for Flash Attention ): bsz, q_len = query.size()[:2] kv_len = key.size(1) @@ -93,7 +97,9 @@ class AscendMMEncoderAttention(MMEncoderAttention): # q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim] q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len) - enable_pad = envs_ascend.USE_OPTIMIZED_MODEL and self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE + enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL + and self.head_size > MIN_PAD_SIZE + and self.head_size < MAX_PAD_SIZE) if enable_pad: origin_shape = q.shape[-1] @@ -108,7 +114,10 @@ class AscendMMEncoderAttention(MMEncoderAttention): context_layer = torch.empty_like(q) if cu_seqlens is None: - cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device) + cu_seqlens = torch.arange(0, (bsz + 1) * q_len, + step=q_len, + dtype=torch.int32, + device=query.device) cu_seqlens = torch.diff(cu_seqlens).to("cpu") @@ -128,7 +137,11 @@ class AscendMMEncoderAttention(MMEncoderAttention): context_layer = context_layer[..., :origin_shape] if is_reshaped: - context_layer = einops.rearrange(context_layer, "(b s) h d -> b s h d", b=bsz).contiguous() + context_layer = einops.rearrange(context_layer, + "(b s) h d -> b s h d", + b=bsz).contiguous() else: - context_layer = einops.rearrange(context_layer, "(b s) h d -> b s (h d)", b=bsz).contiguous() + context_layer = einops.rearrange(context_layer, + "(b s) h d -> b s (h d)", + b=bsz).contiguous() return context_layer diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 2027369b..11d6e8ed 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -1,25 +1,24 @@ import torch import torch.nn.functional as F import torch_npu -from vllm.distributed import ( - get_dp_group, - get_ep_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, - tensor_model_parallel_reduce_scatter, -) +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) from vllm.forward_context import get_forward_context from vllm.utils.torch_utils import direct_register_custom_op +import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_forward_context import MoECommType -from vllm_ascend.ops.triton.rope import rope_forward_triton from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.utils import npu_stream_switch, prefetch_stream +from typing import Optional, Tuple +from vllm_ascend.ops.triton.rope import rope_forward_triton - -def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: +def _maybe_chunk_residual_impl(x: torch.Tensor, + residual: torch.Tensor) -> torch.Tensor: try: forward_context = get_forward_context() except AssertionError: @@ -27,7 +26,8 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch if x.size(0) != residual.size(0): sp_enabled = forward_context.sp_enabled - assert sp_enabled is True, "Currently, this situation only occurs when sp is enabled" + assert sp_enabled is True, ("Currently, this situation only occurs " + "when sp is enabled") pad_size = forward_context.pad_size if pad_size > 0: residual = F.pad(residual, (0, 0, 0, pad_size)) @@ -38,7 +38,10 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch return residual -def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor: +def _maybe_all_gather_and_maybe_unpad_impl( + x: torch.Tensor, + label: bool, + is_ep_comm: bool = False) -> torch.Tensor: try: forward_context = get_forward_context() except AssertionError: @@ -56,20 +59,24 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c x = get_ep_group().all_gather(x, 0) # unpad num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu - result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype) + result = torch.empty( + (num_tokens_across_dp_cpu.sum(), *x.shape[1:]), + device=x.device, + dtype=x.dtype) dp_size = get_dp_group().world_size x = x.view(dp_size, forward_context.padded_length, *x.shape[1:]) offset = 0 for idx in range(dp_size): num_tokens_dp = num_tokens_across_dp_cpu[idx] - result[offset : offset + num_tokens_dp] = x[idx, :num_tokens_dp] + result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp] offset += num_tokens_dp x = result return x -def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor: +def _maybe_pad_and_reduce_impl(x: torch.Tensor, + is_ep_comm: bool = False) -> torch.Tensor: try: forward_context = get_forward_context() except AssertionError: @@ -87,44 +94,63 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor else: # padding dp_size = get_dp_group().world_size - num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu - padded_x = torch.empty((dp_size, forward_context.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype) + num_tokens_across_dp_cpu = \ + get_forward_context().dp_metadata.num_tokens_across_dp_cpu + padded_x = torch.empty( + (dp_size, forward_context.padded_length, *x.shape[1:]), + device=x.device, + dtype=x.dtype) offset = 0 for idx in range(dp_size): num_tokens_dp = num_tokens_across_dp_cpu[idx] - padded_x[idx, :num_tokens_dp] = x[offset : offset + num_tokens_dp] + padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp] offset += num_tokens_dp - return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]), 0) + return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]), + 0) -def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor: +def _maybe_all_gather_and_maybe_unpad_fake( + x: torch.Tensor, + label: bool, + is_ep_comm: bool = False) -> torch.Tensor: + if get_forward_context().sp_enabled and label: return torch.empty( - (x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype - ) + (x.shape[0] * get_tensor_model_parallel_world_size(), + *x.shape[1:]), + device=x.device, + dtype=x.dtype) return x -def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor: +def _maybe_pad_and_reduce_fake(x: torch.Tensor, + is_ep_comm: bool = False) -> torch.Tensor: if get_forward_context().sp_enabled: return torch.empty( - (x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype - ) + (x.shape[0] // get_tensor_model_parallel_world_size(), + *x.shape[1:]), + device=x.device, + dtype=x.dtype) return x -def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor, max_weight_size: int) -> None: +def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor, + max_weight_size: int) -> None: calculation_stream = torch_npu.npu.current_stream() weight_prefetch_stream = prefetch_stream() weight_prefetch_stream.wait_stream(calculation_stream) with npu_stream_switch(weight_prefetch_stream): - maybe_npu_prefetch(inputs=weight, dependency=start_flag, max_size=max_weight_size) + maybe_npu_prefetch(inputs=weight, + dependency=start_flag, + max_size=max_weight_size) -def _prefetch_preprocess_impl_fake(weight: torch.Tensor, start_flag: torch.Tensor, max_weight_size: int) -> None: +def _prefetch_preprocess_impl_fake(weight: torch.Tensor, + start_flag: torch.Tensor, + max_weight_size: int) -> None: return @@ -138,16 +164,20 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None: return -def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor: +def _maybe_all_reduce_tensor_model_parallel_impl( + final_hidden_states: torch.Tensor) -> torch.Tensor: forward_context = get_forward_context() moe_comm_type = forward_context.moe_comm_type - if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} or forward_context.sp_enabled: + if moe_comm_type in { + MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2 + } or forward_context.sp_enabled: return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) -def _matmul_and_reduce_impl(input_parallel: torch.Tensor, layer_name: str) -> torch.Tensor: +def _matmul_and_reduce_impl(input_parallel: torch.Tensor, + layer_name: str) -> torch.Tensor: forward_context = get_forward_context() self = forward_context.no_compile_layers[layer_name] assert self.custom_op is not None @@ -157,15 +187,16 @@ def _matmul_and_reduce_impl(input_parallel: torch.Tensor, layer_name: str) -> to return output -def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str) -> torch.Tensor: +def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, + layer_name: str) -> torch.Tensor: forward_context = get_forward_context() self = forward_context.no_compile_layers[layer_name] num_tokens = input_parallel.size(0) if forward_context.sp_enabled: num_tokens = num_tokens // self.tp_size - output = torch.empty( - size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype - ) + output = torch.empty(size=(num_tokens, self.output_size_per_partition), + device=input_parallel.device, + dtype=input_parallel.dtype) return output @@ -176,96 +207,77 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str) # pass input_scale and input_scale_reciprocal at the same time to avoid redundant # reciprocal calculation in fussion pass. We shall remove this once # aclnnAddRmsNormQuantV2 supports div_moe=False. -def _quantize_impl( - in_tensor: torch.Tensor, input_scale: torch.Tensor, input_scale_reciprocal: torch.Tensor, input_offset: torch.Tensor -) -> torch.Tensor: - return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False) - - -def _quantize_impl_fake( - in_tensor: torch.Tensor, input_scale: torch.Tensor, input_scale_reciprocal: torch.Tensor, input_offset: torch.Tensor -) -> torch.Tensor: - return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False) +def _quantize_impl(in_tensor: torch.Tensor, input_scale: torch.Tensor, + input_scale_reciprocal: torch.Tensor, + input_offset: torch.Tensor) -> torch.Tensor: + return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, + input_offset, torch.qint8, -1, False) +def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor, + input_scale_reciprocal: torch.Tensor, + input_offset: torch.Tensor) -> torch.Tensor: + return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, + input_offset, torch.qint8, -1, False) def _rope_forward_triton_fake( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_dim: int = -1, - is_neox_style: bool = True, -) -> tuple[torch.Tensor, torch.Tensor]: + is_neox_style: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(q), torch.empty_like(k) +direct_register_custom_op(op_name="maybe_chunk_residual", + op_func=_maybe_chunk_residual_impl, + fake_impl=lambda x, residual: x, + mutates_args=[], + dispatch_key="PrivateUse1") -direct_register_custom_op( - op_name="maybe_chunk_residual", - op_func=_maybe_chunk_residual_impl, - fake_impl=lambda x, residual: x, - mutates_args=[], - dispatch_key="PrivateUse1", -) +direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad", + op_func=_maybe_all_gather_and_maybe_unpad_impl, + fake_impl=_maybe_all_gather_and_maybe_unpad_fake, + mutates_args=[], + dispatch_key="PrivateUse1") -direct_register_custom_op( - op_name="maybe_all_gather_and_maybe_unpad", - op_func=_maybe_all_gather_and_maybe_unpad_impl, - fake_impl=_maybe_all_gather_and_maybe_unpad_fake, - mutates_args=[], - dispatch_key="PrivateUse1", -) +direct_register_custom_op(op_name="maybe_pad_and_reduce", + op_func=_maybe_pad_and_reduce_impl, + fake_impl=_maybe_pad_and_reduce_fake, + mutates_args=[], + dispatch_key="PrivateUse1") -direct_register_custom_op( - op_name="prefetch_preprocess", - op_func=_prefetch_preprocess_impl, - fake_impl=_prefetch_preprocess_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1", -) +direct_register_custom_op(op_name="prefetch_preprocess", + op_func=_prefetch_preprocess_impl, + fake_impl=_prefetch_preprocess_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") -direct_register_custom_op( - op_name="prefetch_preprocess", - op_func=_prefetch_preprocess_impl, - fake_impl=_prefetch_preprocess_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1", -) +direct_register_custom_op(op_name="prefetch_postprocess", + op_func=_prefetch_postprocess_impl, + fake_impl=_prefetch_postprocess_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") -direct_register_custom_op( - op_name="prefetch_postprocess", - op_func=_prefetch_postprocess_impl, - fake_impl=_prefetch_postprocess_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1", -) +direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel", + op_func=_maybe_all_reduce_tensor_model_parallel_impl, + fake_impl=lambda x: x, + mutates_args=[], + dispatch_key="PrivateUse1") -direct_register_custom_op( - op_name="maybe_all_reduce_tensor_model_parallel", - op_func=_maybe_all_reduce_tensor_model_parallel_impl, - fake_impl=lambda x: x, - mutates_args=[], - dispatch_key="PrivateUse1", -) +direct_register_custom_op(op_name="matmul_and_reduce", + op_func=_matmul_and_reduce_impl, + fake_impl=_matmul_and_reduce_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") -direct_register_custom_op( - op_name="matmul_and_reduce", - op_func=_matmul_and_reduce_impl, - fake_impl=_matmul_and_reduce_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1", -) - -direct_register_custom_op( - op_name="quantize", - op_func=_quantize_impl, - fake_impl=_quantize_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1", -) -direct_register_custom_op( - op_name="rope_forward_triton", - op_func=rope_forward_triton, - fake_impl=_rope_forward_triton_fake, - mutates_args=[], - dispatch_key="PrivateUse1", -) +direct_register_custom_op(op_name="quantize", + op_func=_quantize_impl, + fake_impl=_quantize_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") +direct_register_custom_op(op_name="rope_forward_triton", + op_func=rope_forward_triton, + fake_impl=_rope_forward_triton_fake, + mutates_args=[], + dispatch_key="PrivateUse1") diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index c23cfdc2..8fb11724 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -15,6 +15,7 @@ # limitations under the License. # +from typing import Optional, Tuple import torch from torch import nn @@ -23,20 +24,14 @@ from vllm.distributed import divide from vllm.distributed.parallel_state import get_tp_group from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, - method_has_implemented_embedding, -) + QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, - ParallelLMHead, - UnquantizedEmbeddingMethod, - VocabParallelEmbedding, - pad_vocab_size, -) + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod, + VocabParallelEmbedding, pad_vocab_size) from vllm.model_executor.utils import set_weight_attrs -from vllm_ascend.distributed.parallel_state import get_embed_tp_group, get_lmhead_tp_group +from vllm_ascend.distributed.parallel_state import (get_embed_tp_group, + get_lmhead_tp_group) from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable @@ -47,16 +42,14 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): Added the feature of lmheadTP in pure dp scenario """ - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - params_dtype: torch.dtype | None = None, - org_num_embeddings: int | None = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): + def __init__(self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): nn.Module.__init__(self) self.forward_type = None if lmhead_tp_enable() and "head" in prefix: @@ -74,20 +67,18 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings num_added_embeddings = num_embeddings - self.org_vocab_size - self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size) + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, + self.padding_size) self.num_embeddings_padded = pad_vocab_size( - self.org_vocab_size_padded + num_added_embeddings, self.padding_size - ) + self.org_vocab_size_padded + num_added_embeddings, + self.padding_size) assert self.org_vocab_size_padded <= self.num_embeddings_padded - self.shard_indices = self._get_indices( - self.num_embeddings_padded, - self.org_vocab_size_padded, - self.num_embeddings, - self.org_vocab_size, - self.tp_rank, - self.tp_size, - ) + self.shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + self.tp_rank, self.tp_size) self.embedding_dim = embedding_dim quant_method = None if quant_config is not None: @@ -99,12 +90,12 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): # method must implement the embedding operation. If we are another # layer type like ParallelLMHead, this is not important. is_embedding_layer = type(self) is VocabParallelEmbedding - quant_method_implements_embedding = method_has_implemented_embedding(type(quant_method)) + quant_method_implements_embedding = method_has_implemented_embedding( + type(quant_method)) if is_embedding_layer and not quant_method_implements_embedding: raise NotImplementedError( f"The class {type(quant_method).__name__} must implement " - "the 'embedding' method, see UnquantizedEmbeddingMethod." - ) + "the 'embedding' method, see UnquantizedEmbeddingMethod.") self.quant_method: QuantizeMethodBase = quant_method @@ -113,47 +104,46 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): self.params_dtype = params_dtype # Divide the weight matrix along the vocaburaly dimension. self.num_added_embeddings = self.num_embeddings - self.org_vocab_size - self.num_embeddings_per_partition = divide(self.num_embeddings_padded, self.tp_size) - assert self.shard_indices.num_elements_padded == self.num_embeddings_per_partition + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, + self.tp_size) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) self.num_org_embeddings_per_partition = ( - self.shard_indices.org_vocab_end_index - self.shard_indices.org_vocab_start_index - ) + self.shard_indices.org_vocab_end_index - + self.shard_indices.org_vocab_start_index) self.num_added_embeddings_per_partition = ( - self.shard_indices.added_vocab_end_index - self.shard_indices.added_vocab_start_index - ) + self.shard_indices.added_vocab_end_index - + self.shard_indices.added_vocab_start_index) - self.quant_method.create_weights( - self, - self.embedding_dim, - [self.num_embeddings_per_partition], - self.embedding_dim, - self.num_embeddings_padded, - params_dtype=params_dtype, - weight_loader=self.weight_loader, - ) + self.quant_method.create_weights(self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader) def _get_masked_input_and_mask( - self, - input_: torch.Tensor, - org_vocab_start_index: int, - org_vocab_end_index: int, - num_org_vocab_padding: int, - added_vocab_start_index: int, - added_vocab_end_index: int, - ) -> tuple[torch.Tensor, torch.Tensor]: + self, input_: torch.Tensor, org_vocab_start_index: int, + org_vocab_end_index: int, num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: # torch.compile will fuse all of the pointwise ops below # into a single kernel, making it very fast - org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index) + org_vocab_mask = (input_ >= org_vocab_start_index) & ( + input_ < org_vocab_end_index) # Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index. if added_vocab_start_index == added_vocab_end_index: - valid_offset = org_vocab_start_index * org_vocab_mask + valid_offset = (org_vocab_start_index * org_vocab_mask) vocab_mask = org_vocab_mask else: - added_vocab_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index) - added_offset = ( - added_vocab_start_index - (org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding - ) - valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset * added_vocab_mask) + added_vocab_mask = (input_ >= added_vocab_start_index) & ( + input_ < added_vocab_end_index) + added_offset = added_vocab_start_index - ( + org_vocab_end_index - + org_vocab_start_index) - num_org_vocab_padding + valid_offset = (org_vocab_start_index * + org_vocab_mask) + (added_offset * added_vocab_mask) vocab_mask = org_vocab_mask | added_vocab_mask # Adapt end. input_ = vocab_mask * (input_ - valid_offset) @@ -168,15 +158,14 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): def _forward_embed_tp(self, input_): complete_input = self.comm_group.all_gather(input_, dim=0) masked_input, input_mask = self._get_masked_input_and_mask( - complete_input, - self.shard_indices.org_vocab_start_index, + complete_input, self.shard_indices.org_vocab_start_index, self.shard_indices.org_vocab_end_index, self.shard_indices.num_org_vocab_padding, self.shard_indices.added_vocab_start_index, - self.shard_indices.added_vocab_end_index, - ) + self.shard_indices.added_vocab_end_index) # Get the embeddings. - output_parallel = self.quant_method.embedding(self, masked_input.long()) + output_parallel = self.quant_method.embedding(self, + masked_input.long()) output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) output = self.comm_group.reduce_scatter(output_parallel, dim=0) output = output.view(input_.shape[0], -1) @@ -186,17 +175,16 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): if self.tp_size > 1: # Build the mask. masked_input, input_mask = self._get_masked_input_and_mask( - input_, - self.shard_indices.org_vocab_start_index, + input_, self.shard_indices.org_vocab_start_index, self.shard_indices.org_vocab_end_index, self.shard_indices.num_org_vocab_padding, self.shard_indices.added_vocab_start_index, - self.shard_indices.added_vocab_end_index, - ) + self.shard_indices.added_vocab_end_index) else: masked_input = input_ # Get the embeddings. - output_parallel = self.quant_method.embedding(self, masked_input.long()) + output_parallel = self.quant_method.embedding(self, + masked_input.long()) # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) @@ -209,31 +197,29 @@ class AscendParallelLMHead(ParallelLMHead): """ Register ParallelLMHead as a custom op for Ascend.""" - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - bias: bool = False, - params_dtype: torch.dtype | None = None, - org_num_embeddings: int | None = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - AscendVocabParallelEmbedding.__init__( - self, num_embeddings, embedding_dim, params_dtype, org_num_embeddings, padding_size, quant_config, prefix - ) + def __init__(self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + AscendVocabParallelEmbedding.__init__(self, num_embeddings, + embedding_dim, params_dtype, + org_num_embeddings, padding_size, + quant_config, prefix) self.quant_config = quant_config if bias: - self.bias = Parameter(torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)) - set_weight_attrs( - self.bias, - { - "output_dim": 0, - "weight_loader": self.weight_loader, - }, - ) + self.bias = Parameter( + torch.empty(self.num_embeddings_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) else: self.register_parameter("bias", None) @@ -248,41 +234,48 @@ class AscendLogitsProcessor(LogitsProcessor): self, hidden_states: torch.Tensor, lm_head: AscendParallelLMHead, - embedding_bias: torch.Tensor | None = None, - ) -> torch.Tensor | None: + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: if lmhead_tp_enable(): - return self._get_logits_lmheadtp(hidden_states, lm_head, embedding_bias) + return self._get_logits_lmheadtp(hidden_states, lm_head, + embedding_bias) else: - return self._get_logits_normal(hidden_states, lm_head, embedding_bias) + return self._get_logits_normal(hidden_states, lm_head, + embedding_bias) def _get_logits_lmheadtp( self, hidden_states: torch.Tensor, lm_head: AscendParallelLMHead, - embedding_bias: torch.Tensor | None, - ) -> torch.Tensor | None: + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: # Gather hidden states from all devices in tensor parallel group - gathered_hidden_states = get_lmhead_tp_group().all_gather(hidden_states, dim=0) - local_logits = lm_head.quant_method.apply(lm_head, gathered_hidden_states, bias=embedding_bias) + gathered_hidden_states = get_lmhead_tp_group().all_gather( + hidden_states, dim=0) + local_logits = lm_head.quant_method.apply(lm_head, + gathered_hidden_states, + bias=embedding_bias) # Gather logits for tensor parallel logits = get_lmhead_tp_group().all_to_all(local_logits) # Remove paddings in vocab (if any) if logits is not None: - logits = logits[..., : self.org_vocab_size] + logits = logits[..., :self.org_vocab_size] return logits def _get_logits_normal( self, hidden_states: torch.Tensor, lm_head: AscendParallelLMHead, - embedding_bias: torch.Tensor | None, - ) -> torch.Tensor | None: - local_logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias) + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + local_logits = lm_head.quant_method.apply(lm_head, + hidden_states, + bias=embedding_bias) # Gather logits for tensor parallel logits = self._gather_logits(local_logits) # Remove paddings in vocab (if any) if logits is not None: - logits = logits[..., : self.org_vocab_size] + logits = logits[..., :self.org_vocab_size] return logits diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index e53e899b..e41390ee 100644 --- a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -2,18 +2,19 @@ from dataclasses import dataclass, field import torch import torch_npu -from vllm.config import get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context +from vllm.config import get_current_vllm_config +from vllm.logger import logger from vllm_ascend.ascend_config import WeightPrefetchConfig -from vllm_ascend.ops.linear import AscendQKVParallelLinear, AscendRowParallelLinear +from vllm_ascend.ops.linear import (AscendQKVParallelLinear, + AscendRowParallelLinear) from vllm_ascend.utils import is_moe_model SUPPORTED_MODULES = ["attn", "mlp", "moe"] MOE_PREFETCH_TOKEN_THRESHOLD = 96 MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024 - @dataclass class ModuleWeightPrefetchConfig: module_name: str @@ -23,7 +24,10 @@ class ModuleWeightPrefetchConfig: linear_prefix_map: dict = field(default_factory=dict) def __post_init__(self) -> None: - self.prefetch_ratio = {prefix: ratio for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1} + self.prefetch_ratio = { + prefix: ratio + for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1 + } assert self.module_name in SUPPORTED_MODULES, ( f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}" @@ -37,7 +41,6 @@ class WeightPrefetchMethod: """ Unified weight prefetch method. """ - is_moe: bool = True MLP_GATE_UP: str = "gate_up" MLP_DOWN: str = "down" @@ -51,53 +54,60 @@ class WeightPrefetchMethod: self.attn = ModuleWeightPrefetchConfig( module_name="attn", enable=weight_prefetch_config.enabled, - prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("attn", {}) or {"qkv": 1.0, "o": 1.0}, + prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( + "attn", {}) or {'qkv': 1.0, 'o': 1.0}, linear_prefix_map={ AscendQKVParallelLinear.__name__: "qkv", AscendRowParallelLinear.__name__: "o", - }, - ) + }) self.moe = ModuleWeightPrefetchConfig( module_name="moe", enable=weight_prefetch_config.enabled and self.is_moe, - prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("moe", {}) or {"gate_up": 0.8}, - ) + prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( + "moe", {}) or {'gate_up': 0.8}) self.mlp = ModuleWeightPrefetchConfig( module_name="mlp", enable=weight_prefetch_config.enabled and not self.is_moe, - prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("mlp", {}) or {"gate_up": 1.0, "down": 1.0}, - ) + prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( + "mlp", {}) or {'gate_up': 1.0, 'down': 1.0}) self.mlp_pre_version_compatibale_config = weight_prefetch_config.mlp_pre_version_compatibale_config def maybe_prefetch_attn_weight_preprocess( - self, layer_cls_name: str, weight: torch.Tensor, start_flag: torch.Tensor - ) -> None: + self, layer_cls_name: str, weight: torch.Tensor, + start_flag: torch.Tensor) -> None: if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map: return prefix = self.attn.linear_prefix_map.get(layer_cls_name, "") - weight_size = weight.data.element_size() * weight.data.numel() * self.attn.prefetch_ratio.get(prefix, 0) + weight_size = weight.data.element_size() * weight.data.numel( + ) * self.attn.prefetch_ratio.get(prefix, 0) - torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=start_flag, max_weight_size=int(weight_size)) + torch.ops.vllm.prefetch_preprocess(weight=weight, + start_flag=start_flag, + max_weight_size=int(weight_size)) - def maybe_prefetch_attn_weight_postprocess(self, layer_cls_name: str, stop_flag: torch.Tensor) -> None: + def maybe_prefetch_attn_weight_postprocess( + self, layer_cls_name: str, stop_flag: torch.Tensor) -> None: if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map: return torch.ops.vllm.prefetch_postprocess(stop_flag) def maybe_prefetch_moe_weight_preprocess(self, hidden_states, prefix): - self.moe.is_active_this_forward = ( - hidden_states.shape[0] >= MOE_PREFETCH_TOKEN_THRESHOLD if self.moe.enable else False - ) + self.moe.is_active_this_forward = hidden_states.shape[ + 0] >= MOE_PREFETCH_TOKEN_THRESHOLD if self.moe.enable else False if not self.moe.is_active_this_forward: return forward_context = get_forward_context() # layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm. - weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight - weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0) - torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=None, max_weight_size=int(weight_size)) + weight = forward_context.model_instance.model.layers[ + forward_context.layer_idx - 1].mlp.experts.w13_weight + weight_size = weight.data.element_size() * weight.data.numel( + ) * self.moe.prefetch_ratio.get(prefix, 0) + torch.ops.vllm.prefetch_preprocess(weight=weight, + start_flag=None, + max_weight_size=int(weight_size)) def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor): if not self.moe.is_active_this_forward: @@ -106,9 +116,7 @@ class WeightPrefetchMethod: torch.ops.vllm.prefetch_postprocess(stop_flag) # x_dependency only eager mode can pass None - def maybe_prefetch_mlp_weight_preprocess( - self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None - ): + def maybe_prefetch_mlp_weight_preprocess(self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None): if not self.mlp.enable and not self.mlp_pre_version_compatibale_config: self.mlp.is_active_this_forward = False return @@ -132,26 +140,24 @@ class WeightPrefetchMethod: else: raise ValueError(f"Unsupported prefetch weight name: {prefetch_layer_name}") - def _maybe_prefetch_mlp_gate_up_weight_preprocess( - self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str | None - ): + def _maybe_prefetch_mlp_gate_up_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str | None): if not curr_layer_prefix: raise ValueError("curr_layer_prefix must been specified when prefetching mlp gate_up_proj weight") # start point of gate_up_proj weight prefetch - if curr_layer_prefix.split(".")[-2] == "self_attn": + if curr_layer_prefix.split('.')[-2] == "self_attn": model_instance = forward_context.model_instance - layer_idx = int(curr_layer_prefix.split(".")[2]) + layer_idx = int(curr_layer_prefix.split('.')[2]) weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight if self.mlp_pre_version_compatibale_config: weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0) else: - weight_size = ( - weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0) - ) + weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0) if weight_size > MAX_PREFETCH_WEIGHT_SIZE: weight_size = MAX_PREFETCH_WEIGHT_SIZE - torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size)) + torch.ops.vllm.prefetch_preprocess(weight=weight, + start_flag=x_dependency, + max_weight_size=int(weight_size)) forward_context.prefetch_mlp_gate_up_proj = True def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext): @@ -161,12 +167,12 @@ class WeightPrefetchMethod: if self.mlp_pre_version_compatibale_config: weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0) else: - weight_size = ( - weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0) - ) + weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0) if weight_size > MAX_PREFETCH_WEIGHT_SIZE: weight_size = MAX_PREFETCH_WEIGHT_SIZE - torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size)) + torch.ops.vllm.prefetch_preprocess(weight=weight, + start_flag=x_dependency, + max_weight_size=int(weight_size)) forward_context.prefetch_mlp_down_proj = True forward_context.layer_idx += 1 @@ -179,15 +185,19 @@ class WeightPrefetchMethod: except AssertionError: return - if forward_context.prefetch_mlp_gate_up_proj or forward_context.prefetch_mlp_down_proj: + if forward_context.prefetch_mlp_gate_up_proj or \ + forward_context.prefetch_mlp_down_proj: torch.ops.vllm.prefetch_postprocess(stop_flag) forward_context.prefetch_mlp_gate_up_proj = False forward_context.prefetch_mlp_down_proj = False -def maybe_npu_prefetch( - inputs: torch.Tensor, dependency: torch.Tensor, max_size: int = 0, offset: int = 0, *, enabled: bool = True -) -> None: +def maybe_npu_prefetch(inputs: torch.Tensor, + dependency: torch.Tensor, + max_size: int = 0, + offset: int = 0, + *, + enabled: bool = True) -> None: if not enabled: return input_size = inputs.element_size() * inputs.numel() diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index 6a1a66c9..89f558db 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -30,9 +30,10 @@ def get_spec_decode_method(method, vllm_config, device, runner): return EagleProposer(vllm_config, device, runner) elif method == "mtp": return MtpProposer(vllm_config, device, runner) - elif method == "suffix": + elif method == 'suffix': return SuffixDecodingProposer(vllm_config, device, runner) elif method == "medusa": return MedusaProposer(vllm_config, device, runner) else: - raise ValueError(f"Unknown speculative decoding method: {method}") + raise ValueError("Unknown speculative decoding method: " + f"{method}") diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 022e2253..986d7a71 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -1,21 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from collections.abc import Callable -from contextlib import AbstractContextManager, contextmanager, nullcontext -from typing import Any +from contextlib import contextmanager, nullcontext +from typing import Any, Callable, ContextManager, Optional, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from vllm.config import CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config -from vllm.distributed.parallel_state import ( - get_pp_group, - get_tp_group, - get_world_group, - init_model_parallel_group, - patch_tensor_parallel_group, -) +from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config) +from vllm.distributed.parallel_state import (get_pp_group, get_tp_group, + get_world_group, + init_model_parallel_group, + patch_tensor_parallel_group) from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -38,11 +35,13 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params +from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, + update_full_graph_params) from vllm_ascend.ops.rotary_embedding import update_cos_sin -from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel +from vllm_ascend.ops.triton.spec_decode.utils import \ + prepare_inputs_padded_kernel from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num -from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is +from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled, lmhead_tp_enable, vllm_version_is # Currently we will fix block size to a small one since `num_reqs` can't be too large _PREPARE_INPUTS_BLOCK_SIZE = 4 @@ -77,21 +76,28 @@ def split_inputs_tp_to_sp(hidden_states, out): # copy only hidden_states in current rank hidden_states_curr_rank = hidden_states[start:end] - out[: hidden_states_curr_rank.shape[0]] = hidden_states_curr_rank + out[:hidden_states_curr_rank.shape[0]] = hidden_states_curr_rank return out[:padded_num_tokens_per_rank] class EagleProposer(VllmEagleProposer): - _runnable: ACLGraphWrapper | Callable - def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): + _runnable: Union[ACLGraphWrapper, Callable] + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device, + runner=None): super().__init__(vllm_config, device, runner) self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling self.decode_threshold = 1 + self.num_speculative_tokens - self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 1, dtype=torch.int32) - self.arange_cpu = torch.arange(self.arange.shape[0], device="cpu", dtype=torch.int32) + self.query_start_loc = self.runner._make_buffer( + self.runner.max_num_reqs + 1, dtype=torch.int32) + self.arange_cpu = torch.arange(self.arange.shape[0], + device="cpu", + dtype=torch.int32) self.attn_mask_builder = AttentionMaskBuilder(self.device) self.enable_shared_expert_dp = shared_expert_dp_enabled() @@ -102,11 +108,11 @@ class EagleProposer(VllmEagleProposer): self.dcp_rank = self.runner.dcp_rank self.full_indices = range( - self.runner.max_num_tokens * self.pcp_size * self.dcp_size - + self.pcp_size * self.dcp_size * self.runner.max_num_reqs - ) + self.runner.max_num_tokens * self.pcp_size * self.dcp_size + + self.pcp_size * self.dcp_size * self.runner.max_num_reqs) - self.use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk") + self.use_sparse = hasattr(vllm_config.model_config.hf_text_config, + "index_topk") # NOTE: # `draft_tensor_parallel_size` does not take effect for Eagle: # the draft model uses the same TP size as the target model in practice. @@ -116,7 +122,8 @@ class EagleProposer(VllmEagleProposer): # or the same as target model. # TODO(zhaomingyu13): If we want to adapt to the case where draft model tp # is not 1 and differs from target model, this part should be rewritten. - if vllm_config.parallel_config.tensor_parallel_size != self.speculative_config.draft_tensor_parallel_size: + if (vllm_config.parallel_config.tensor_parallel_size + != self.speculative_config.draft_tensor_parallel_size): tp_group = init_model_parallel_group( [[get_world_group().rank]], get_world_group().rank, @@ -128,37 +135,47 @@ class EagleProposer(VllmEagleProposer): else: self.tp_group_context = nullcontext() - self.use_cuda_graph = self.runner._use_aclgraph() and not self.speculative_config.enforce_eager + self.use_cuda_graph = (self.runner._use_aclgraph() + and not self.speculative_config.enforce_eager) if self.method == "mtp": self.use_cuda_graph = self.use_cuda_graph and not self.use_async_scheduling # TODO: Remove it when the bug of fx-graph is solved - self.maybe_eager_context: AbstractContextManager[Any] = nullcontext() + self.maybe_eager_context: ContextManager[Any] = nullcontext() if not self.use_cuda_graph and enable_sp(vllm_config): self.maybe_eager_context = _maybe_eager_context(vllm_config) self.last_token_indices = torch.zeros( - self.vllm_config.scheduler_config.max_num_batched_tokens, dtype=torch.int32, device=device - ) - slot_mapping_lens = self.runner.max_num_tokens + 2 * self.pcp_size * self.runner.max_num_reqs + self.vllm_config.scheduler_config.max_num_batched_tokens, + dtype=torch.int32, + device=device) + slot_mapping_lens = self.runner.max_num_tokens + \ + 2 * self.pcp_size * self.runner.max_num_reqs self.slot_mapping_group = [ - torch.zeros(slot_mapping_lens, dtype=torch.int32, device=device, pin_memory=self.runner.pin_memory) - for _ in range(self.num_speculative_tokens) - ] + torch.zeros( + slot_mapping_lens, dtype=torch.int32, device=device, + pin_memory=self.runner.pin_memory) + for _ in range(self.num_speculative_tokens)] self._runnable = self._run_merged_draft def load_model(self, model: nn.Module) -> None: - target_attn_layer_names = set(get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()) - target_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys()) + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase).keys()) + target_indexer_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, + DeepseekV32IndexerCache).keys()) with self.maybe_eager_context: - self.model = get_model( - vllm_config=self.vllm_config, model_config=self.vllm_config.speculative_config.draft_model_config - ) + self.model = get_model(vllm_config=self.vllm_config, + model_config=self.vllm_config. + speculative_config.draft_model_config) - indexer_layers = get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys() - draft_attn_layer = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() + indexer_layers = get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache).keys() + draft_attn_layer = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase).keys() draft_attn_layer_names = draft_attn_layer - target_attn_layer_names draft_indexer_layer_names = indexer_layers - target_indexer_layer_names @@ -167,25 +184,30 @@ class EagleProposer(VllmEagleProposer): self.attn_layer_names = list(sorted(draft_attn_layer_names)) self.piece_all_attn_layer_name = [] for _ in range(self.num_speculative_tokens): - self.piece_all_attn_layer_name.append([name for name in self.attn_layer_names]) + self.piece_all_attn_layer_name.append([ + name for name in self.attn_layer_names]) self.attn_layer_names = list(sorted(draft_attn_layer_names)) self.piece_all_attn_layer_name = [] for _ in range(self.num_speculative_tokens): - self.piece_all_attn_layer_name.append([name for name in self.attn_layer_names]) + self.piece_all_attn_layer_name.append([ + name for name in self.attn_layer_names]) if supports_multimodal(model): # handle multimodality if self.get_model_name(model) in [ - "Qwen2_5_VLForConditionalGeneration", - "Qwen3VLForConditionalGeneration", - "Qwen3VLMoeForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", + "Qwen3VLForConditionalGeneration", + "Qwen3VLMoeForConditionalGeneration", ]: self.model.config.image_token_index = model.config.image_token_id - elif self.get_model_name(model) == "PixtralForConditionalGeneration": - self.model.config.image_token_index = model.config.vision_config.image_token_id + elif self.get_model_name( + model) == "PixtralForConditionalGeneration": + self.model.config.image_token_index = ( + model.config.vision_config.image_token_id) else: - self.model.config.image_token_index = model.config.image_token_index + self.model.config.image_token_index = ( + model.config.image_token_index) target_language_model = model.get_language_model() else: target_language_model = model @@ -197,7 +219,9 @@ class EagleProposer(VllmEagleProposer): elif hasattr(target_language_model.model, "embedding"): target_embed_tokens = target_language_model.model.embedding else: - raise AttributeError("Target model does not have 'embed_tokens' or 'embedding' attribute") + raise AttributeError( + "Target model does not have 'embed_tokens' or 'embedding' attribute" + ) # If pp>1, the weights of mtp and the main model's embedding are not on the same device. # check if mtp model use main model's embedding and LMhead share_embeddings = False @@ -234,7 +258,10 @@ class EagleProposer(VllmEagleProposer): else: # MTP model share_embeddings = True - logger.info("Detected MTP model. Sharing target model embedding weights with the draft model.") + logger.info( + "Detected MTP model. " + "Sharing target model embedding weights with the draft model." + ) if share_embeddings: if hasattr(self.model.model, "embed_tokens"): @@ -242,7 +269,7 @@ class EagleProposer(VllmEagleProposer): self.model.model.embed_tokens = target_embed_tokens else: logger.info( - "Since PP > 1 or other reasons the model head loaded its own vocab embedding" + "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \ " weights instead of sharing them with the target model." ) # share lm_head with the target model if needed @@ -255,19 +282,24 @@ class EagleProposer(VllmEagleProposer): else: self.model.lm_head = model.lm_head - if self.method == "mtp" and self.vllm_config.model_config.is_deepseek_mla: + if self.method == "mtp" and \ + self.vllm_config.model_config.is_deepseek_mla: for _, layer_module in self.model.model.layers.items(): - if torch.equal(layer_module.shared_head.head.weight, model.lm_head.weight): + if torch.equal(layer_module.shared_head.head.weight, + model.lm_head.weight): layer_module.shared_head.head = model.lm_head - if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.use_cuda_graph: + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( + ) and self.use_cuda_graph: self.update_stream = torch.npu.Stream() if self.method == "mtp": - self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) + self.model = ACLGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) else: - self._runnable = ACLGraphWrapper( - self._run_merged_draft, self.vllm_config, runtime_mode=CUDAGraphMode.FULL - ) + self._runnable = ACLGraphWrapper(self._run_merged_draft, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) def get_model(self) -> nn.Module: # get raw model out of the aclgraph wrapper. @@ -281,23 +313,22 @@ class EagleProposer(VllmEagleProposer): return copy.copy(attn_metadata) @torch.inference_mode() - def dummy_run( - self, - num_tokens: int, - with_prefill: bool = False, - in_graph_capturing: bool = False, - num_reqs: int = 0, - num_tokens_across_dp: torch.Tensor | None = None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - dummy_compute_logits=lambda hidden_states: None, - is_profile=False, - ): + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + in_graph_capturing: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: Optional[torch.Tensor] = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None, + is_profile=False): ( num_tokens, num_tokens_across_dp, _, - ) = self.runner._sync_metadata_across_dp(num_tokens, is_draft_model=True) + ) = self.runner._sync_metadata_across_dp(num_tokens, + is_draft_model=True) # update global cos, sin update_cos_sin(self._get_positions(num_tokens)) @@ -305,15 +336,19 @@ class EagleProposer(VllmEagleProposer): multi_steps_attn_metadata = [] if not self.use_cuda_graph: aclgraph_runtime_mode = CUDAGraphMode.NONE - if aclgraph_runtime_mode == CUDAGraphMode.FULL and len(self.runner.attn_groups) > 0: - num_computed_tokens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] - self.query_start_loc.cpu[: num_reqs + 1] = torch.tensor( - [0] + self.runner.actual_seq_lengths_q[:num_reqs], device="cpu", dtype=torch.int32 - ) + if aclgraph_runtime_mode == CUDAGraphMode.FULL and len( + self.runner.attn_groups) > 0: + num_computed_tokens_cpu = ( + self.runner.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs]) + self.query_start_loc.cpu[:num_reqs + 1] = torch.tensor( + [0] + self.runner.actual_seq_lengths_q[:num_reqs], + device="cpu", + dtype=torch.int32) self.query_start_loc.copy_to_gpu() common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[: num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1], + query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + 1], seq_lens_cpu=self.runner.seq_lens.cpu, seq_lens=self.runner.seq_lens.gpu[:num_reqs], num_reqs=num_reqs, @@ -322,9 +357,11 @@ class EagleProposer(VllmEagleProposer): max_query_len=self.num_speculative_tokens + 1, num_computed_tokens_cpu=num_computed_tokens_cpu, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor()[:num_reqs], + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor()[:num_reqs], # This is used to hold a position. - slot_mapping=self.runner.input_batch.block_table[0].slot_mapping.gpu, + slot_mapping=self.runner.input_batch.block_table[0]. + slot_mapping.gpu, positions=self.runner.positions.gpu, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, @@ -334,33 +371,35 @@ class EagleProposer(VllmEagleProposer): builder = self.runner.attn_groups[0][0].get_metadata_builder() # update the tensor's address for each step. for draft_step in range(self.num_speculative_tokens): - common_attn_metadata = self.shallow_copy_metadata(common_attn_metadata) + common_attn_metadata = self.shallow_copy_metadata( + common_attn_metadata) # Set the real slot_mapping. common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step] attn_metadata_eagle = builder.build_for_graph_capture( - common_attn_metadata, AscendAttentionState.ChunkedPrefill - ) + common_attn_metadata, AscendAttentionState.ChunkedPrefill) per_layer_attn_metadata = dict() for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata_eagle multi_steps_attn_metadata.append(per_layer_attn_metadata) + model_input_ids = self.input_ids[:num_tokens] model_positions = self._get_positions(num_tokens) + model_previous_hidden_states = self.hidden_states[:num_tokens] batch_size = num_tokens // (self.num_speculative_tokens + 1) with set_ascend_forward_context( - multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None, - self.vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, - num_actual_tokens=0, - in_profile_run=is_profile, - batch_descriptor=batch_descriptor, - aclgraph_runtime_mode=aclgraph_runtime_mode, - is_draft_model=True, - draft_attn_metadatas=multi_steps_attn_metadata, - ): + multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + num_actual_tokens=0, + in_profile_run=is_profile, + batch_descriptor=batch_descriptor, + aclgraph_runtime_mode=aclgraph_runtime_mode, + is_draft_model=True, + draft_attn_metadatas=multi_steps_attn_metadata): + if not vllm_version_is("v0.15.0"): # Reset MOE layer index before first model call forward_context = get_forward_context() @@ -378,8 +417,11 @@ class EagleProposer(VllmEagleProposer): is_dummy=True, ) forward_context = get_forward_context() - if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing: - self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata) + if (forward_context.cudagraph_runtime_mode + == CUDAGraphMode.FULL + and not forward_context.capturing): + self._update_full_graph_params(forward_context, num_tokens, + multi_steps_attn_metadata) def _propose( self, @@ -391,10 +433,11 @@ class EagleProposer(VllmEagleProposer): target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - last_token_indices: torch.Tensor | None, + last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], + torch.Tensor]] = None, req_scheduled_tokens=None, long_seq_metadata=None, num_prefill_reqs=0, @@ -402,6 +445,7 @@ class EagleProposer(VllmEagleProposer): scheduler_output: SchedulerOutput = None, num_scheduled_tokens: int = 0, ) -> torch.Tensor: + num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -410,17 +454,20 @@ class EagleProposer(VllmEagleProposer): if self.method == "eagle3": assert isinstance(self.get_model(), Eagle3LlamaForCausalLM) - target_hidden_states = self.model.combine_hidden_states(target_hidden_states) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[: num_tokens - 1] = target_token_ids[1:] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids - if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]: - num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens] + if self.use_cuda_graph and \ + num_tokens <= self.runner.cudagraph_batch_sizes[-1]: + num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[ + num_tokens] else: num_input_tokens = num_tokens @@ -428,13 +475,13 @@ class EagleProposer(VllmEagleProposer): num_input_tokens, num_tokens_across_dp, _, - ) = self.runner._sync_metadata_across_dp(num_input_tokens, is_draft_model=True) + ) = self.runner._sync_metadata_across_dp(num_input_tokens, + is_draft_model=True) has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0 if self.use_cuda_graph: - aclgraph_runtime_mode, batch_descriptor = self.runner.cudagraph_dispatcher.dispatch( - num_tokens=num_input_tokens, uniform_decode=True, has_lora=has_lora - ) + aclgraph_runtime_mode, batch_descriptor = \ + self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=True, has_lora=has_lora) else: aclgraph_runtime_mode = CUDAGraphMode.NONE batch_descriptor = None @@ -446,30 +493,33 @@ class EagleProposer(VllmEagleProposer): if self.supports_mm_inputs: mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) inputs_embeds = self.model.embed_input_ids( - self.input_ids[:num_tokens], multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed - ) + self.input_ids[:num_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed) self.inputs_embeds[:num_tokens] = inputs_embeds inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = self.input_ids[:num_input_tokens] else: inputs_embeds = None + input_ids = self.input_ids[:num_input_tokens] # Update slot_mapping for different speculative. # NOTE: Currently, we only remake the slot_mapping, because it's the # only tensor which will be used in current FIA. # Strictly speaking, `query_start_loc`, `seq_lens` should also have # their memory allocated separately for each step just like `slot_mapping`. - slot_mapping_lens = ( - num_input_tokens - if num_input_tokens < common_attn_metadata.slot_mapping.shape[0] - else common_attn_metadata.slot_mapping.shape[0] - ) - self.slot_mapping_group[0][:slot_mapping_lens].copy_(common_attn_metadata.slot_mapping[:slot_mapping_lens]) + slot_mapping_lens = num_input_tokens if num_input_tokens < \ + common_attn_metadata.slot_mapping.shape[0] else \ + common_attn_metadata.slot_mapping.shape[0] + self.slot_mapping_group[0][:slot_mapping_lens].copy_( + common_attn_metadata.slot_mapping[:slot_mapping_lens]) self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1) common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens] common_attn_metadata.num_input_tokens = num_input_tokens # FIXME(woosuk): The below two ops cause synchronization. Optimize. builder = self.runner.attn_groups[0][0].get_metadata_builder() - attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model()) + attn_metadata = builder.build(0, common_attn_metadata, + self.runner.get_model()) # update global cos, sin update_cos_sin(self._get_positions(num_input_tokens)) @@ -486,34 +536,35 @@ class EagleProposer(VllmEagleProposer): # Copy the old attn_metadata and update for draft_step in range(1, self.num_speculative_tokens): - common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm( - draft_step, - attn_metadata, - common_attn_metadata, - batch_size, - num_input_tokens, - used_update_positions, - aclgraph_runtime_mode, - ) + common_attn_metadata, attn_metadata = \ + self.attn_update_stack_num_spec_norm( + draft_step, + attn_metadata, + common_attn_metadata, + batch_size, + num_input_tokens, + used_update_positions, + aclgraph_runtime_mode) per_layer_attn_metadata = dict() for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata multi_steps_attn_metadata.append(per_layer_attn_metadata) last_token_indices_len = last_token_indices.shape[0] - self.last_token_indices[:last_token_indices_len].copy_(last_token_indices) + self.last_token_indices[:last_token_indices_len].copy_( + last_token_indices) with set_ascend_forward_context( - multi_steps_attn_metadata[0], - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - num_actual_tokens=num_tokens, - batch_descriptor=batch_descriptor, - aclgraph_runtime_mode=aclgraph_runtime_mode, - is_draft_model=True, - draft_attn_metadatas=multi_steps_attn_metadata, - ): + multi_steps_attn_metadata[0], + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + num_actual_tokens=num_tokens, + batch_descriptor=batch_descriptor, + aclgraph_runtime_mode=aclgraph_runtime_mode, + is_draft_model=True, + draft_attn_metadatas=multi_steps_attn_metadata): + if not vllm_version_is("v0.15.0"): # Reset MOE layer index for forward pass forward_context = get_forward_context() @@ -526,37 +577,36 @@ class EagleProposer(VllmEagleProposer): last_token_indices=self.last_token_indices[:last_token_indices_len], target_positions=target_positions, inputs_embeds=inputs_embeds, - multi_steps_attn_metadata=multi_steps_attn_metadata, - ) + multi_steps_attn_metadata=multi_steps_attn_metadata) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: - self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata) + self._update_full_graph_params(forward_context, + num_input_tokens, + multi_steps_attn_metadata) return draft_token_ids - def _run_merged_draft( - self, - num_input_tokens, - batch_size, - last_token_indices, - target_positions, - inputs_embeds, - multi_steps_attn_metadata, - is_dummy=False, + def _run_merged_draft(self, + num_input_tokens, + batch_size, + last_token_indices, + target_positions, + inputs_embeds, + multi_steps_attn_metadata, + is_dummy=False, ) -> torch.Tensor: - # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative - # tokens' proposings. - # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the - # inputs of speculative model. + # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings. + # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model. model_input_ids = self.input_ids[:num_input_tokens] model_positions = self._get_positions(num_input_tokens) model_hidden_states = self.hidden_states[:num_input_tokens] - model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions) + model_hidden_states, model_positions = self.maybe_pad_and_reduce( + model_hidden_states, model_positions) # Expend the remaining moe layers for suiting vllm. forward_context = get_forward_context() - if forward_context and hasattr(forward_context, "remaining_moe_layers"): + if forward_context and hasattr(forward_context, 'remaining_moe_layers'): if self.num_speculative_tokens > 1: moe_layers_needed = len(forward_context.remaining_moe_layers) * self.num_speculative_tokens if len(forward_context.remaining_moe_layers) < moe_layers_needed: @@ -569,7 +619,7 @@ class EagleProposer(VllmEagleProposer): input_ids=model_input_ids, positions=model_positions, hidden_states=model_hidden_states, - inputs_embeds=inputs_embeds, + inputs_embeds = inputs_embeds, ) if self.method == "mtp": last_hidden_states = ret_hidden_states @@ -578,15 +628,15 @@ class EagleProposer(VllmEagleProposer): last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad( - last_hidden_states, model_positions, hidden_states - ) + last_hidden_states, model_positions, hidden_states) num_indices = last_token_indices.shape[0] if lmhead_tp_enable() and not is_dummy: max_num_reqs_across_dp = ( - self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len - ) - last_token_indices = nn.functional.pad(last_token_indices, (0, max_num_reqs_across_dp - num_indices)) + self.vllm_config.scheduler_config.max_num_seqs * + self.runner.uniform_decode_query_len) + last_token_indices = nn.functional.pad( + last_token_indices, (0, max_num_reqs_across_dp - num_indices)) sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) @@ -604,8 +654,9 @@ class EagleProposer(VllmEagleProposer): # Generate the remaining draft tokens. draft_token_ids_tensor = torch.zeros( - (self.num_speculative_tokens, *draft_token_ids.shape), dtype=draft_token_ids.dtype, device=self.device - ) + (self.num_speculative_tokens, *draft_token_ids.shape), + dtype=draft_token_ids.dtype, + device=self.device) draft_token_ids_tensor[0] = draft_token_ids if self.uses_mrope: positions = target_positions[:, last_token_indices] @@ -626,7 +677,7 @@ class EagleProposer(VllmEagleProposer): forward_context = get_forward_context() if forward_context is not None: forward_context.moe_layer_index = 0 - + # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. @@ -640,22 +691,25 @@ class EagleProposer(VllmEagleProposer): # out-of-range access during the model execution. The draft tokens # generated with this adjustment should be ignored. if self.uses_mrope: - exceeds_max_model_len = positions[0] >= self.vllm_config.model_config.max_model_len + exceeds_max_model_len = positions[ + 0] >= self.vllm_config.model_config.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. clamped_positions = torch.where( - exceeds_max_model_len.unsqueeze(0), torch.zeros_like(positions), positions - ) + exceeds_max_model_len.unsqueeze(0), + torch.zeros_like(positions), positions) else: exceeds_max_model_len = positions >= self.vllm_config.model_config.max_model_len - clamped_positions = torch.where(exceeds_max_model_len, 0, positions) + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states if self.supports_mm_inputs: - self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids) + self.inputs_embeds[:batch_size] = self.model.embed_input_ids( + input_ids) input_ids = self.input_ids[:input_batch_size] inputs_embeds = self.inputs_embeds[:input_batch_size] @@ -668,24 +722,22 @@ class EagleProposer(VllmEagleProposer): # Run the model. - # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative - # tokens' proposings. - # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the - # inputs of speculative model. + # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings. + # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model. model_input_ids = self.input_ids[:input_batch_size] model_positions = self._get_positions(input_batch_size) model_hidden_states = self.hidden_states[:input_batch_size] - model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions) + model_hidden_states, model_positions = self.maybe_pad_and_reduce( + model_hidden_states, model_positions) - forward_context.attn_metadata = ( - multi_steps_attn_metadata[draft_step + 1] if multi_steps_attn_metadata else None - ) + forward_context.attn_metadata = multi_steps_attn_metadata[draft_step + 1] \ + if multi_steps_attn_metadata else None ret_hidden_states = self.model( input_ids=model_input_ids, positions=model_positions, hidden_states=model_hidden_states, - inputs_embeds=inputs_embeds, + inputs_embeds = inputs_embeds, ) if self.method == "mtp": last_hidden_states = ret_hidden_states @@ -694,14 +746,13 @@ class EagleProposer(VllmEagleProposer): last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad( - last_hidden_states, model_positions, hidden_states - ) + last_hidden_states, model_positions, hidden_states) num_indices = last_token_indices.shape[0] if lmhead_tp_enable() and not is_dummy: max_num_reqs_across_dp = ( - self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len - ) + self.vllm_config.scheduler_config.max_num_seqs * + self.runner.uniform_decode_query_len) last_token_indices = nn.functional.pad( last_token_indices, (0, max_num_reqs_across_dp - num_indices), @@ -723,40 +774,41 @@ class EagleProposer(VllmEagleProposer): draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) return draft_token_ids - def attn_update_stack_num_spec_norm( - self, - # `draft_step` must start from `1`, no `0` - draft_step, - old_attn_metadata, - old_common_metadata, - batch_size, - input_batch_size, - used_update_positions, - aclgraph_runtime_mode, - ): - assert draft_step > 0 + def attn_update_stack_num_spec_norm(self, + # `draft_step` must start from `1`, no `0` + draft_step, + old_attn_metadata, + old_common_metadata, + batch_size, + input_batch_size, + used_update_positions, + aclgraph_runtime_mode): + + assert(draft_step > 0) common_attn_metadata = self.shallow_copy_metadata(old_common_metadata) if draft_step == 1: - if aclgraph_runtime_mode == CUDAGraphMode.FULL and (pad_size := input_batch_size - batch_size) > 0: + if ( + aclgraph_runtime_mode == CUDAGraphMode.FULL + and (pad_size := input_batch_size - batch_size) > 0 + ): common_attn_metadata.num_reqs = input_batch_size common_attn_metadata.block_table_tensor = self._pad_tensor( - common_attn_metadata.block_table_tensor, pad_size - ) - common_attn_metadata.seq_lens = self._pad_tensor(common_attn_metadata.seq_lens, pad_size) - common_attn_metadata.seq_lens_cpu = self._pad_tensor(common_attn_metadata.seq_lens_cpu, pad_size) + common_attn_metadata.block_table_tensor, pad_size) + common_attn_metadata.seq_lens = self._pad_tensor( + common_attn_metadata.seq_lens, pad_size) + common_attn_metadata.seq_lens_cpu = self._pad_tensor( + common_attn_metadata.seq_lens_cpu, pad_size) common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor( - common_attn_metadata.num_computed_tokens_cpu, pad_size - ) - common_attn_metadata.query_start_loc = self.arange[: input_batch_size + 1] + common_attn_metadata.num_computed_tokens_cpu, pad_size) + common_attn_metadata.query_start_loc = self.arange[ + :input_batch_size + 1] common_attn_metadata.query_start_loc_cpu = torch.from_numpy( - self.token_arange_np[: input_batch_size + 1] - ).clone() + self.token_arange_np[:input_batch_size + 1]).clone() else: - common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] + common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] common_attn_metadata.query_start_loc_cpu = torch.from_numpy( - self.token_arange_np[: batch_size + 1] - ).clone() + self.token_arange_np[:batch_size + 1]).clone() common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 @@ -764,7 +816,7 @@ class EagleProposer(VllmEagleProposer): common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill common_attn_metadata.graph_pad_size = -1 common_attn_metadata.num_input_tokens = input_batch_size - + # The loop part used_update_positions += 1 @@ -776,15 +828,18 @@ class EagleProposer(VllmEagleProposer): # out-of-range access during the model execution. The draft tokens # generated with this adjustment should be ignored. if self.uses_mrope: - exceeds_max_model_len = used_update_positions[0] >= self.vllm_config.model_config.max_model_len + exceeds_max_model_len = used_update_positions[ + 0] >= self.vllm_config.model_config.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. clamped_positions = torch.where( - exceeds_max_model_len.unsqueeze(0), torch.zeros_like(used_update_positions), used_update_positions - ) + exceeds_max_model_len.unsqueeze(0), + torch.zeros_like(used_update_positions), used_update_positions) else: - exceeds_max_model_len = used_update_positions >= self.vllm_config.model_config.max_model_len - clamped_positions = torch.where(exceeds_max_model_len, 0, used_update_positions) + exceeds_max_model_len = used_update_positions >= \ + self.vllm_config.model_config.max_model_len + clamped_positions = torch.where(exceeds_max_model_len, 0, + used_update_positions) # For data integrity when async scheduling, we shouldn't use in place # operations in case they are modified in next step's `prepare_input` @@ -793,11 +848,15 @@ class EagleProposer(VllmEagleProposer): common_attn_metadata.seq_lens[:batch_size] += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - common_attn_metadata.seq_lens[:batch_size].masked_fill_(exceeds_max_model_len, 1) + common_attn_metadata.seq_lens[:batch_size].masked_fill_( + exceeds_max_model_len, 1) - common_attn_metadata.seq_lens_cpu[:batch_size] = common_attn_metadata.seq_lens_cpu[:batch_size] + 1 - exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= self.max_model_len - common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(exceeds_mask, 1) + common_attn_metadata.seq_lens_cpu[:batch_size] = ( + common_attn_metadata.seq_lens_cpu[:batch_size] + 1) + exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \ + self.max_model_len + common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_( + exceeds_mask, 1) common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1 if self.uses_mrope: common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0]) @@ -814,22 +873,29 @@ class EagleProposer(VllmEagleProposer): if self.uses_mrope: block_numbers = clamped_positions[0] // block_size else: - block_numbers = clamped_positions // block_size - block_ids = old_common_metadata.block_table_tensor.gather(dim=1, index=block_numbers.view(-1, 1)) + block_numbers = (clamped_positions // block_size) + block_ids = old_common_metadata.block_table_tensor.gather( + dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) if self.uses_mrope: - slot_mapping = block_ids * block_size + clamped_positions[0] % block_size + slot_mapping = (block_ids * block_size + + clamped_positions[0] % block_size) else: - slot_mapping = block_ids * block_size + clamped_positions % block_size + slot_mapping = (block_ids * block_size + + clamped_positions % block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. - slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) - self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32)) - self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID) + slot_mapping.masked_fill_(exceeds_max_model_len, + PADDING_SLOT_ID) + self.slot_mapping_group[draft_step][:slot_mapping.shape[0]].copy_( + slot_mapping.to(torch.int32)) + self.slot_mapping_group[draft_step][slot_mapping.shape[0]:].fill_( + PADDING_SLOT_ID) # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx] - common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]] + common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][ + :slot_mapping.shape[0]] # Rebuild attention metadata attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore @@ -861,22 +927,24 @@ class EagleProposer(VllmEagleProposer): # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs - self.backup_next_token_ids.np[:num_reqs] = np.array( - [ - requests[gpu_input_batch.req_ids[i]].get_token_id(common_attn_metadata.seq_lens_cpu[i].item()) - for i in range(num_reqs) - ] - ) + self.backup_next_token_ids.np[:num_reqs] = np.array([ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item()) + for i in range(num_reqs) + ]) self.backup_next_token_ids.copy_to_gpu(num_reqs) # Mask out the sampled tokens indices that should not be sampled. - discard_sampled_tokens_req_indices = discard_request_indices[:num_discarded_requests] + discard_sampled_tokens_req_indices = discard_request_indices[: + num_discarded_requests] valid_sampled_token_ids_gpu = sampled_token_ids.clone() - valid_sampled_token_ids_gpu.index_fill_(0, discard_sampled_tokens_req_indices, -1) + valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1) # Generate a mask for all valid tokens within those requests - valid_mask = (valid_sampled_token_ids_gpu != -1) & (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size) + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size) # Count the number of valid tokens in each request valid_sampled_tokens_count = valid_mask.sum(dim=1) @@ -887,7 +955,9 @@ class EagleProposer(VllmEagleProposer): # Get last valid token from each row # (assume undefined state where there is no valid token) - selected_tokens = torch.gather(valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)).squeeze(1) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, + last_valid_indices_safe.unsqueeze(1)).squeeze(1) # Use last token if valid, pre-computed backup if not batch_size = valid_sampled_token_ids_gpu.shape[0] @@ -929,17 +999,22 @@ class EagleProposer(VllmEagleProposer): num_actual_reqs = len(num_draft_tokens) num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) ] - num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32) + num_rejected_tokens = torch.tensor(num_rejected_tokens, + dtype=torch.int32) device = common_attn_metadata.query_start_loc.device - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_actual_reqs + 1] + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_actual_reqs + + 1] seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs] new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + new_query_len_per_req = query_start_loc_cpu[ + 1:] - query_start_loc_cpu[:-1] # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() @@ -960,38 +1035,43 @@ class EagleProposer(VllmEagleProposer): # [0, 2, 6, 9] -> # [0, 0, 2, 2, 2, 2, 6, 6, 6] # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], new_num_tokens_per_req_np) + new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], + new_num_tokens_per_req_np) # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> # [0, 1, 0, 1, 2, 3, 0, 1, 2] # _r1_ ____r2____ ___r3__ - token_offests = self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded + token_offests = (self.token_arange_np[:total_num_tokens] - + new_query_start_locs_expanded) # Expand starting positions to match token pattern # [0, q1, q1 + q2] -> # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] # _r1_ _____r2_______ ___________r3____________ - old_query_start_locs_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + old_query_start_locs_expanded = np.repeat( + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) # Final token indices are: # [0, 1, // req 1 # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True) + token_indices = torch.from_numpy(token_indices_np).to( + device, non_blocking=True) - common_attn_metadata.slot_mapping[: token_indices.shape[0]].copy_( - common_attn_metadata.slot_mapping[token_indices] - ) - common_attn_metadata.slot_mapping[token_indices.shape[0] :].fill_(-1) + common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_( + common_attn_metadata.slot_mapping[token_indices]) + common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1) # NOTE: Currently positions and seq_lens are not used in attn forward # so we do not need to fixed them. But if they are used in the future, # we should fixed them. spec_common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), + query_start_loc=new_query_start_loc_cpu.to(device, + non_blocking=True), query_start_loc_cpu=new_query_start_loc_cpu, seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, num_input_tokens=common_attn_metadata.num_input_tokens, @@ -1002,8 +1082,7 @@ class EagleProposer(VllmEagleProposer): positions=common_attn_metadata.positions[token_indices], attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, - max_seq_len=0, - ) + max_seq_len=0) return spec_common_attn_metadata, token_indices def prepare_inputs_padded( @@ -1024,12 +1103,15 @@ class EagleProposer(VllmEagleProposer): num_reqs = common_attn_metadata.num_reqs device = valid_sampled_tokens_count.device - token_indices_to_sample = torch.empty((num_reqs,), dtype=torch.int32, device=device) + token_indices_to_sample = torch.empty((num_reqs, ), + dtype=torch.int32, + device=device) - num_blocks_needed = triton.cdiv(num_reqs, _PREPARE_INPUTS_BLOCK_SIZE) + num_blocks_needed = triton.cdiv(num_reqs, + _PREPARE_INPUTS_BLOCK_SIZE) num_vector_core = get_vectorcore_num() grid_size = min(num_blocks_needed, num_vector_core) - grid = (grid_size,) + grid = (grid_size, ) prepare_inputs_padded_kernel[grid]( spec_decode_metadata.cu_num_draft_tokens, @@ -1040,12 +1122,11 @@ class EagleProposer(VllmEagleProposer): BLOCK_SIZE=_PREPARE_INPUTS_BLOCK_SIZE, ) else: - num_draft_tokens_gpu = torch.cat( - [ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - spec_decode_metadata.cu_num_draft_tokens[:-1], - ] - ) + num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1], + ]) num_rejected_tokens_gpu = torch.where( num_draft_tokens_gpu > 0, @@ -1053,11 +1134,14 @@ class EagleProposer(VllmEagleProposer): torch.zeros_like(num_draft_tokens_gpu), ) - token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu + token_indices_to_sample = ( + common_attn_metadata.query_start_loc[1:] - 1 - + num_rejected_tokens_gpu) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + new_query_len_per_req = query_start_loc_cpu[ + 1:] - query_start_loc_cpu[:-1] total_num_tokens = query_start_loc_cpu[-1].item() token_indices = self.arange[:total_num_tokens] @@ -1070,7 +1154,8 @@ class EagleProposer(VllmEagleProposer): query_start_loc_cpu=query_start_loc_cpu, seq_lens_cpu=common_attn_metadata.seq_lens_cpu, num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=common_attn_metadata.num_actual_tokens if self.pcp_size > 1 else total_num_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens + if self.pcp_size > 1 else total_num_tokens, num_input_tokens=common_attn_metadata.num_input_tokens, max_query_len=new_query_len_per_req.max().item(), actual_seq_lengths_q=self.runner.actual_seq_lengths_q, @@ -1079,14 +1164,15 @@ class EagleProposer(VllmEagleProposer): positions=common_attn_metadata.positions, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, - num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, seq_lens=common_attn_metadata.seq_lens, - max_seq_len=0, - ) + max_seq_len=0) return spec_common_attn_metadata, token_indices, token_indices_to_sample - def _split_pcp_input(self, req_scheduled_tokens, input_ids, target_hidden_states): + def _split_pcp_input(self, req_scheduled_tokens, input_ids, + target_hidden_states): """ Split prefill input_ids and target_hidden_states in pcp group. 1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad] @@ -1099,27 +1185,29 @@ class EagleProposer(VllmEagleProposer): # no prefill inputs to split, return empty result return ( 0, - torch.zeros([0], device="npu"), - torch.zeros([0, target_hidden_states.size(1)], device="npu"), + torch.zeros([0], device='npu'), + torch.zeros([0, target_hidden_states.size(1)], device='npu'), 0, torch.zeros([0]), torch.tensor([0], dtype=torch.int32), ) def _pcp_pad_and_split(num_tokens): - num_pcp_padded_scheduled_tokens = cdiv(num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size + num_pcp_padded_scheduled_tokens = cdiv( + num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size) # split position_ids (and use split position_ids to split input_ids afterwards) req_position_cp: list[int] = [] - req_position_cp.extend(self.full_indices[self.pcp_rank * chunk_size : (self.pcp_rank + 1) * chunk_size]) req_position_cp.extend( - self.full_indices[ - num_pcp_padded_scheduled_tokens - (self.pcp_rank + 1) * chunk_size : num_pcp_padded_scheduled_tokens - - self.pcp_rank * chunk_size - ] - ) + self.full_indices[self.pcp_rank * + chunk_size:(self.pcp_rank + 1) * chunk_size]) + req_position_cp.extend( + self.full_indices[num_pcp_padded_scheduled_tokens - + (self.pcp_rank + 1) * + chunk_size:num_pcp_padded_scheduled_tokens - + self.pcp_rank * chunk_size]) return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad @@ -1129,12 +1217,17 @@ class EagleProposer(VllmEagleProposer): pcp_split_input_ids_list = [] pcp_split_hidden_states_list = [] for ori_num_tokens in req_scheduled_tokens.values(): - req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = _pcp_pad_and_split(ori_num_tokens) + req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \ + _pcp_pad_and_split(ori_num_tokens) actual_num_tokens = len(req_position_pcp) num_pcp_scheduled_tokens.append(actual_num_tokens) - pad_input_ids = F.pad(input_ids[ori_start_index : ori_start_index + ori_num_tokens], (0, num_pcp_pad)) + pad_input_ids = F.pad( + input_ids[ori_start_index:ori_start_index + ori_num_tokens], + (0, num_pcp_pad)) ori_start_index += ori_num_tokens - pcp_chunk_indices = [pad_start_index + pos for pos in req_position_pcp] + pcp_chunk_indices = [ + pad_start_index + pos for pos in req_position_pcp + ] pcp_split_input_ids = pad_input_ids[req_position_pcp] pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices] pcp_split_input_ids_list.append(pcp_split_input_ids) @@ -1145,20 +1238,16 @@ class EagleProposer(VllmEagleProposer): target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0) max_query_len = max(num_pcp_scheduled_tokens) seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32) - cu_num_tokens = torch.tensor(np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0)) + cu_num_tokens = torch.tensor( + np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0)) return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens # update full-graph params for one spec token def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None): update_full_graph_params( - self.runner.attn_backend, - self.update_stream, - forward_context, - num_tokens, - self.vllm_config, - self.vllm_config.speculative_config, - draft_attn_metadatas=draft_attn_metadatas, - ) + self.runner.attn_backend, self.update_stream, forward_context, num_tokens, + self.vllm_config, self.vllm_config.speculative_config, + draft_attn_metadatas=draft_attn_metadatas) # padding tensor into desired size def _pad_tensor(self, tensor, pad_size): @@ -1173,14 +1262,16 @@ class EagleProposer(VllmEagleProposer): ) -> tuple[torch.Tensor, torch.Tensor]: if self.method == "mtp": if self.enable_shared_expert_dp: - hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states) + hidden_states = torch.ops.vllm.maybe_pad_and_reduce( + hidden_states) positions = positions.unsqueeze(-1) positions = torch.ops.vllm.maybe_pad_and_reduce(positions) positions = positions.squeeze(-1) else: forward_context = get_forward_context() if forward_context.sp_enabled: - hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states) + hidden_states = split_inputs_tp_to_sp( + hidden_states, hidden_states) return hidden_states, positions def maybe_all_gather_and_unpad( @@ -1192,17 +1283,17 @@ class EagleProposer(VllmEagleProposer): if self.method == "mtp": if self.enable_shared_expert_dp: last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - last_hidden_states.contiguous(), True - ) - positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True) + last_hidden_states.contiguous(), True) + positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + positions.contiguous(), True) if hidden_states is not None: hidden_states = last_hidden_states else: forward_context = get_forward_context() if forward_context.sp_enabled: last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - last_hidden_states.contiguous(), True - ) + last_hidden_states.contiguous(), True) if hidden_states is not None: - hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), True) + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + hidden_states.contiguous(), True) return last_hidden_states, positions, hidden_states diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py index 138f88dc..feec5bcf 100644 --- a/vllm_ascend/spec_decode/interface.py +++ b/vllm_ascend/spec_decode/interface.py @@ -1,4 +1,5 @@ import enum +from typing import Optional import torch from vllm.config import CUDAGraphMode, VllmConfig @@ -17,7 +18,11 @@ class SpecDcodeType(enum.Enum): class Proposer: - def __init__(self, vllm_config: VllmConfig, device: torch.device = None, runner=None): + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device = None, + runner=None): pass def load_model(self, model): @@ -25,29 +30,25 @@ class Proposer: raise NotImplementedError @torch.inference_mode() - def dummy_run( - self, - num_tokens: int, - with_prefill: bool = False, - in_graph_capturing: bool = False, - num_reqs: int = 0, - num_tokens_across_dp: torch.Tensor | None = None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - ): + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + in_graph_capturing: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: Optional[torch.Tensor] = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None): """Called by dummy_run in modle_runner""" raise NotImplementedError - def generate_token_ids( - self, - valid_sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata = None, - scheduler_output: SchedulerOutput = None, - spec_decode_metadata: SpecDecodeMetadata = None, - positions: torch.Tensor = None, - num_scheduled_tokens: int = 0, - hidden_states: torch.Tensor = None, - aux_hidden_states: torch.Tensor = None, - ): + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + aux_hidden_states: torch.Tensor = None): """Called by execute_model in model_runner""" - raise NotImplementedError + raise NotImplementedError \ No newline at end of file diff --git a/vllm_ascend/spec_decode/medusa_proposer.py b/vllm_ascend/spec_decode/medusa_proposer.py index ff727cc8..eda2e41a 100644 --- a/vllm_ascend/spec_decode/medusa_proposer.py +++ b/vllm_ascend/spec_decode/medusa_proposer.py @@ -1,9 +1,14 @@ +from typing import Optional + import torch +import torch.nn as nn from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.medusa import MedusaProposer as VllmMedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.medusa import MedusaProposer as VllmMedusaProposer from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.spec_decode.interface import SpecDcodeType @@ -17,70 +22,72 @@ class MedusaProposer(VllmMedusaProposer): """ def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - runner, + self, + vllm_config: VllmConfig, + device: torch.device, + runner, ): # Save config parameters self.name = SpecDcodeType.MEDUSA self.vllm_config = vllm_config self.device = device self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens - self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size() + self.hidden_size = (vllm_config.speculative_config.draft_model_config. + get_hidden_size()) self.dtype = vllm_config.model_config.dtype self.runner = runner @torch.inference_mode() - def dummy_run( - self, - num_tokens: int, - with_prefill: bool = False, - in_graph_capturing: bool = False, - num_reqs: int = 0, - num_tokens_across_dp: torch.Tensor | None = None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - dummy_compute_logits=lambda hidden_states: None, - is_profile=False, - ): + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + in_graph_capturing: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: Optional[torch.Tensor] = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None, + is_profile=False): hidden_states = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=self.device, ) with set_ascend_forward_context( - None, - self.vllm_config, - num_tokens=num_tokens, - num_actual_tokens=0, - in_profile_run=is_profile, - batch_descriptor=batch_descriptor, - aclgraph_runtime_mode=aclgraph_runtime_mode, - is_draft_model=True, - ): + None, + self.vllm_config, + num_tokens=num_tokens, + num_actual_tokens=0, + in_profile_run=is_profile, + batch_descriptor=batch_descriptor, + aclgraph_runtime_mode=aclgraph_runtime_mode, + is_draft_model=True): self.model(hidden_states) dummy_compute_logits(hidden_states) - def generate_token_ids( - self, - valid_sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, - spec_decode_metadata: SpecDecodeMetadata, - sample_hidden_states: torch.Tensor, - *args, - **kwargs, - ): + def generate_token_ids(self, valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + spec_decode_metadata: SpecDecodeMetadata, + sample_hidden_states: torch.Tensor, + *args, + **kwargs + ): + if sample_hidden_states.shape[0] == len(valid_sampled_token_ids): # The input to the target model does not include draft tokens. hidden_states = sample_hidden_states else: num_accepted_tokens = torch.tensor( - [len(t) for t in valid_sampled_token_ids], device=self.device, dtype=torch.long - ) - num_draft_tokens = torch.tensor(spec_decode_metadata.num_draft_tokens, device=self.device, dtype=torch.long) + [len(t) for t in valid_sampled_token_ids], + device=self.device, + dtype=torch.long) + num_draft_tokens = torch.tensor( + spec_decode_metadata.num_draft_tokens, + device=self.device, + dtype=torch.long) - offsets = torch.cumsum(num_draft_tokens + 1, dim=0) - (num_draft_tokens + 1) + offsets = torch.cumsum(num_draft_tokens + 1, + dim=0) - (num_draft_tokens + 1) indices = offsets + num_accepted_tokens - 1 hidden_states = sample_hidden_states[indices] diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 3f43b194..03f21fed 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -1,3 +1,5 @@ +from typing import Optional, Union + import torch import torch.nn as nn from vllm.config import CUDAGraphMode @@ -20,33 +22,29 @@ from vllm_ascend.utils import lmhead_tp_enable, vllm_version_is class MtpProposer(EagleProposer): + # TODO: Find out why ModelRunner does not this explicit typing? - model: nn.Module | ACLGraphWrapper + model: Union[nn.Module, ACLGraphWrapper] @torch.inference_mode() - def dummy_run( - self, - num_tokens: int, - with_prefill: bool = False, - in_graph_capturing: bool = False, - num_reqs: int = 0, - num_tokens_across_dp=None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - dummy_compute_logits=lambda hidden_states: None, - is_profile=False, - ) -> None: - if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch: + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + in_graph_capturing: bool = False, + num_reqs: int = 0, + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None, + is_profile=False) -> None: + if ( + self.pcp_size * self.dcp_size == 1 + and not self.speculative_config.disable_padded_drafter_batch + ): super().dummy_run( - num_tokens, - with_prefill, - in_graph_capturing, - num_reqs, - num_tokens_across_dp, - aclgraph_runtime_mode, - batch_descriptor, - dummy_compute_logits, - is_profile, + num_tokens, with_prefill, in_graph_capturing, num_reqs, + num_tokens_across_dp, aclgraph_runtime_mode, batch_descriptor, + dummy_compute_logits, is_profile ) return ( @@ -63,10 +61,14 @@ class MtpProposer(EagleProposer): aclgraph_runtime_mode = CUDAGraphMode.NONE if aclgraph_runtime_mode == CUDAGraphMode.FULL: if len(self.runner.attn_groups) > 0: - num_computed_tokens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] + num_computed_tokens_cpu = ( + self.runner.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs]) common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.runner.query_start_loc.gpu[: num_reqs + 1], - query_start_loc_cpu=self.runner.query_start_loc.cpu[: num_reqs + 1], + query_start_loc=self.runner.query_start_loc.gpu[:num_reqs + + 1], + query_start_loc_cpu=self.runner.query_start_loc. + cpu[:num_reqs + 1], seq_lens_cpu=self.runner.seq_lens.cpu, seq_lens=self.runner.seq_lens.gpu[:num_reqs], num_reqs=num_reqs, @@ -75,29 +77,27 @@ class MtpProposer(EagleProposer): max_query_len=self.num_speculative_tokens + 1, num_computed_tokens_cpu=num_computed_tokens_cpu, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor(), - slot_mapping=self.runner.input_batch.block_table[0].slot_mapping.gpu, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), + slot_mapping=self.runner.input_batch.block_table[0]. + slot_mapping.gpu, positions=self.runner.positions.gpu, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, - max_seq_len=0, - ) + max_seq_len=0) if self.pcp_size * self.dcp_size > 1: # update long_seq related params and flatten block_table - common_attn_metadata.prefill_context_parallel_metadata = self.runner.pcp_manager.long_seq_metadata - common_attn_metadata.block_table_tensor = self.runner.input_batch.block_table[ - 0 - ].get_device_tensor()[: num_reqs * self.decode_threshold] + common_attn_metadata.prefill_context_parallel_metadata = \ + self.runner.pcp_manager.long_seq_metadata + common_attn_metadata.block_table_tensor = \ + self.runner.input_batch.block_table[0].get_device_tensor()[ + :num_reqs * self.decode_threshold] builder = self.runner.attn_groups[0][0].get_metadata_builder() - # `AscendAttentionState.SpecDecoding` is only designed for mla, - # `AscendAttentionState.ChunkedPrefill` is used in self-attention. - attn_state = ( - AscendAttentionState.SpecDecoding - if self.vllm_config.model_config.use_mla - else AscendAttentionState.ChunkedPrefill - ) - attn_metadata_mtp = builder.build_for_graph_capture(common_attn_metadata, attn_state) + # `AscendAttentionState.SpecDecoding` is only designed for mla, `AscendAttentionState.ChunkedPrefill` is used in self-attention. + attn_state = AscendAttentionState.SpecDecoding if self.vllm_config.model_config.use_mla else AscendAttentionState.ChunkedPrefill + attn_metadata_mtp = builder.build_for_graph_capture( + common_attn_metadata, attn_state) attn_metadata = {} for layer_name in self.attn_layer_names: attn_metadata[layer_name] = attn_metadata_mtp @@ -113,34 +113,32 @@ class MtpProposer(EagleProposer): if i > 0 and not in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL: aclgraph_runtime_mode = CUDAGraphMode.NONE with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, - num_actual_tokens=0, - aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor, - is_draft_model=True, - in_profile_run=is_profile, - ): + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + num_actual_tokens=0, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + is_draft_model=True, + in_profile_run=is_profile): if not vllm_version_is("v0.15.0"): # Reset MOE layer index for each MTP step iteration forward_context = get_forward_context() if forward_context is not None: forward_context.moe_layer_index = 0 - previous_hidden_states, positions = self.maybe_pad_and_reduce(previous_hidden_states, positions) - self.model(input_ids=input_ids, positions=positions, hidden_states=previous_hidden_states) + previous_hidden_states, positions = self.maybe_pad_and_reduce( + previous_hidden_states, positions) + self.model(input_ids=input_ids, + positions=positions, + hidden_states=previous_hidden_states) forward_context = get_forward_context() - if ( - forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL - and not forward_context.capturing - and not self.use_sparse - ): + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ + not forward_context.capturing and not self.use_sparse: self._update_full_graph_params(forward_context, num_tokens) previous_hidden_states, positions, _ = self.maybe_all_gather_and_unpad( - previous_hidden_states, positions - ) + previous_hidden_states, positions) dummy_compute_logits(previous_hidden_states) if with_prefill: break @@ -155,10 +153,11 @@ class MtpProposer(EagleProposer): target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - last_token_indices: torch.Tensor | None, + last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], + torch.Tensor]] = None, req_scheduled_tokens=None, long_seq_metadata=None, num_prefill_reqs=0, @@ -166,22 +165,16 @@ class MtpProposer(EagleProposer): scheduler_output: SchedulerOutput = None, num_scheduled_tokens: int = 0, ) -> torch.Tensor: - if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch: + if ( + self.pcp_size * self.dcp_size == 1 + and not self.speculative_config.disable_padded_drafter_batch + ): draft_token_ids = super()._propose( - target_token_ids, - target_positions, - target_hidden_states, - next_token_ids, - last_token_indices, - common_attn_metadata, - sampling_metadata, - mm_embed_inputs, - req_scheduled_tokens, - long_seq_metadata, - num_prefill_reqs, - num_decode_reqs, - scheduler_output, - num_scheduled_tokens, + target_token_ids, target_positions, target_hidden_states, + next_token_ids, last_token_indices, common_attn_metadata, + sampling_metadata, mm_embed_inputs, req_scheduled_tokens, + long_seq_metadata, num_prefill_reqs, num_decode_reqs, + scheduler_output, num_scheduled_tokens ) return draft_token_ids @@ -193,12 +186,13 @@ class MtpProposer(EagleProposer): if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) - target_hidden_states = self.model.combine_hidden_states(target_hidden_states) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[: num_tokens - 1] = target_token_ids[1:] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids @@ -219,16 +213,20 @@ class MtpProposer(EagleProposer): num_tokens_d_padded = num_tokens_d * self.pcp_size input_ids_d = self.input_ids[:num_tokens_d] input_ids_p = self.input_ids[num_tokens_d:num_tokens] - target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded] + target_hidden_states_d_padded = \ + target_hidden_states[:num_tokens_d_padded] if num_tokens_d: # remove padding (from pcp all-gather) in decode part - mask_start_loc = torch.cat( - [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]] - ) + mask_start_loc = torch.cat([ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1] + ]) mask_len = query_lens_d mask = [] for req_id in range(num_decode_reqs): - mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id])) + mask += list( + range(mask_start_loc[req_id], + mask_start_loc[req_id] + mask_len[req_id])) target_hidden_states_d = target_hidden_states_d_padded[mask] else: target_hidden_states_d = target_hidden_states_d_padded @@ -236,33 +234,46 @@ class MtpProposer(EagleProposer): req_scheduled_tokens_p = {} for i, req_id in enumerate(self.runner.input_batch.req_ids): if i >= num_decode_reqs: - req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id] - (num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = ( - self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p) - ) + req_scheduled_tokens_p[req_id] = \ + req_scheduled_tokens[req_id] + (num_tokens_p, input_ids_p, target_hidden_states_p, + max_query_len_p, seq_lens_p, cu_num_tokens_p) = \ + self._split_pcp_input( + req_scheduled_tokens_p, input_ids_p, target_hidden_states_p) num_tokens = num_tokens_d + num_tokens_p target_positions = target_positions[:num_tokens] - self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0)) - target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0) + self.input_ids[:num_tokens].copy_( + torch.cat([input_ids_d, input_ids_p], dim=0)) + target_hidden_states = torch.cat( + [target_hidden_states_d, target_hidden_states_p], dim=0) # 2. update sample_indices according to main model if num_decode_reqs: - last_token_indices[:num_decode_reqs] = self.runner.logits_indices[last_token_indices[:num_decode_reqs]] + last_token_indices[:num_decode_reqs] = \ + self.runner.logits_indices[last_token_indices[:num_decode_reqs]] if num_prefill_reqs: - last_token_indices[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:] + last_token_indices[-num_prefill_reqs:] = \ + self.runner.logits_indices[-num_prefill_reqs:] # 3. update attn_metadata params that may be influenced by pcp common_attn_metadata.num_actual_tokens = num_tokens - common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p) + common_attn_metadata.max_query_len = max( + self.decode_threshold, max_query_len_p) common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p - common_attn_metadata.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p - query_start_loc_p = cu_num_tokens_p[1:] + common_attn_metadata.query_start_loc[num_decode_reqs].item() - common_attn_metadata.query_start_loc[-num_prefill_reqs:] = query_start_loc_p - common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p + common_attn_metadata.seq_lens_cpu[ + -num_prefill_reqs:] = seq_lens_p + query_start_loc_p = cu_num_tokens_p[1:] + \ + common_attn_metadata.query_start_loc[num_decode_reqs].item() + common_attn_metadata.query_start_loc[-num_prefill_reqs:] = \ + query_start_loc_p + common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = \ + query_start_loc_p assert self.runner is not None # Note(qcs): We may need to refactor these check logics. - if self.use_cuda_graph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[-1]: - num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_scheduled_tokens] + if self.use_cuda_graph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[ + -1]: + num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[ + num_scheduled_tokens] else: # Eager mode, no padding needed num_input_tokens = num_tokens @@ -271,23 +282,23 @@ class MtpProposer(EagleProposer): self._set_positions(num_tokens, target_positions) self.hidden_states[:num_tokens] = target_hidden_states # eager/acl piecewise mode need to update num_tokens_across_dp - (num_input_tokens, num_tokens_across_dp, with_prefill) = self.runner._sync_metadata_across_dp( - num_input_tokens, self.runner.with_prefill - ) + (num_input_tokens, num_tokens_across_dp, + with_prefill) = self.runner._sync_metadata_across_dp( + num_input_tokens, self.runner.with_prefill) # Enable shared_expert_dp and MTP FULL graph may cause accuracy issues. if scheduler_output and not self.enable_shared_expert_dp: max_query_len = common_attn_metadata.max_query_len - uniform_decode = (max_query_len in list(range(1, self.num_speculative_tokens + 2))) and ( - scheduler_output.total_num_scheduled_tokens - == self.runner.input_batch.num_reqs * (self.num_speculative_tokens + 1) - ) + uniform_decode = (max_query_len in list( + range(1, self.num_speculative_tokens + + 2))) and (scheduler_output.total_num_scheduled_tokens + == self.runner.input_batch.num_reqs * + (self.num_speculative_tokens + 1)) else: uniform_decode = False has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0 - aclgraph_runtime_mode, batch_descriptor = self.runner.cudagraph_dispatcher.dispatch( - num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora - ) + aclgraph_runtime_mode, batch_descriptor = \ + self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora) if not self.use_cuda_graph: # there is synchronization between mtp steps when enabling aclgraph, # disable aclgraph when use async scheduling to avoid the @@ -296,10 +307,8 @@ class MtpProposer(EagleProposer): # and _propose. aclgraph_runtime_mode = CUDAGraphMode.NONE - if ( - self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() - and aclgraph_runtime_mode == CUDAGraphMode.FULL - ): + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( + ) and aclgraph_runtime_mode == CUDAGraphMode.FULL: graph_pad_size = num_input_tokens else: graph_pad_size = -1 @@ -310,58 +319,64 @@ class MtpProposer(EagleProposer): common_attn_metadata.graph_pad_size = graph_pad_size common_attn_metadata.num_input_tokens = num_input_tokens builder = self.runner.attn_groups[0][0].get_metadata_builder() - attn_metadata_mtp = builder.build(0, common_attn_metadata, self.runner.get_model()) + attn_metadata_mtp = builder.build(0, common_attn_metadata, + self.runner.get_model()) attn_metadata = {} for layer_name in self.attn_layer_names: attn_metadata[layer_name] = attn_metadata_mtp for step in range(self.num_speculative_tokens): with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor, - num_actual_tokens=num_tokens, - is_draft_model=True, - ): + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + num_actual_tokens=num_tokens, + is_draft_model=True): + if not vllm_version_is("v0.15.0"): # Reset MOE layer index for each MTP step to match all_moe_layers registration forward_context = get_forward_context() if forward_context is not None: forward_context.moe_layer_index = 0 - with record_function_or_nullcontext("mtp_forward"): + with record_function_or_nullcontext('mtp_forward'): model_kwargs = {} model_kwargs["attn_metadata"] = attn_metadata input_ids = self.input_ids[:num_input_tokens] positions = self._get_positions(num_input_tokens) hidden_states = self.hidden_states[:num_input_tokens] - hidden_states, positions = self.maybe_pad_and_reduce(hidden_states, positions) + hidden_states, positions = self.maybe_pad_and_reduce( + hidden_states, positions) - hidden_states = self.model(input_ids=input_ids, positions=positions, hidden_states=hidden_states) - forward_context = get_forward_context() - if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not self.use_sparse: - self._update_full_graph_params(forward_context, num_input_tokens) + hidden_states = self.model(input_ids=input_ids, + positions=positions, + hidden_states=hidden_states) + forward_context = get_forward_context() + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not self.use_sparse: + self._update_full_graph_params(forward_context, + num_input_tokens) - hidden_states, positions, _ = self.maybe_all_gather_and_unpad(hidden_states, positions) + hidden_states, positions, _ = self.maybe_all_gather_and_unpad( + hidden_states, positions) num_indices = last_token_indices.shape[0] if lmhead_tp_enable(): - max_num_reqs_across_dp = ( - self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len - ) - last_token_indices = nn.functional.pad(last_token_indices, (0, max_num_reqs_across_dp - num_indices)) + max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len + last_token_indices = nn.functional.pad( + last_token_indices, + (0, max_num_reqs_across_dp - num_indices)) if self.pcp_size > 1 and step == 0: # remove graph padding before all_gather hidden_states = hidden_states[:num_tokens] hidden_states = get_pcp_group().all_gather(hidden_states, 0) hidden_states = torch.index_select( - hidden_states, 0, self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]] - ) + hidden_states, 0, self.runner.pcp_manager. + pcp_allgather_restore_idx.gpu[:hidden_states.shape[0]]) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) @@ -394,7 +409,7 @@ class MtpProposer(EagleProposer): hidden_states = hidden_states[last_token_indices] slot_mapping = attn_metadata_i.slot_mapping[last_token_indices] attn_metadata_i.slot_mapping.fill_(-1) - attn_metadata_i.query_start_loc = self.arange[: batch_size + 1] + attn_metadata_i.query_start_loc = self.arange[:batch_size + 1] last_token_indices = self.arange[:batch_size] if getattr(attn_metadata_i, "num_decode_tokens", 0): attn_metadata_i.num_decode_tokens = batch_size @@ -405,44 +420,44 @@ class MtpProposer(EagleProposer): # Instead, we pre-allocate mtp slot_mapping in model_runner # (_generate_pcp_mtp_input), and use updated slot_indices # to get corresponding slot_mapping in each step. - num_reject_tokens = ( - torch.tensor(self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to(self.device) - - ori_last_token_indices - - 1 - ) - num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens + num_reject_tokens = torch.tensor( + self.runner.pcp_manager.cu_num_tokens_pcp_full, + dtype=torch.int32).to( + self.device) - ori_last_token_indices - 1 + num_accept_tokens = \ + query_lens_d.to(self.device) - num_reject_tokens ori_seq_len = attn_metadata_i.seq_lens mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad # slot_mapping index base offset: # scheduled tokens + pre-allocated mtp tokens + accepted tokens slot_idx_base = ( - torch.cat( - [ - torch.tensor([0], dtype=torch.int32, device=self.device), - (torch.cumsum(query_lens_d, dim=0)[:-1] * self.pcp_size).to(self.device), - ] - ) - + torch.arange(num_decode_reqs, device=self.device) - * (self.num_speculative_tokens - 1) - * self.pcp_size - + (num_accept_tokens - 1) * self.pcp_size - ) + torch.cat([ + torch.tensor( + [0], dtype=torch.int32, device=self.device), + (torch.cumsum(query_lens_d, dim=0)[:-1] * + self.pcp_size).to(self.device) + ]) + + torch.arange(num_decode_reqs, device=self.device) * + (self.num_speculative_tokens - 1) * self.pcp_size + + (num_accept_tokens - 1) * self.pcp_size) slot_indices_list = [] for req_id in range(num_decode_reqs): slot_indices_list.append( - torch.arange( - slot_idx_base[req_id], slot_idx_base[req_id] + self.pcp_size, device=self.device - ) - ) + torch.arange(slot_idx_base[req_id], + slot_idx_base[req_id] + self.pcp_size, + device=self.device)) slot_indices = torch.cat(slot_indices_list, dim=0) # fold block_table (restore it to original size before flattened) - block_indices = torch.cat( - [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d, dim=0)[:-1]] - ) - attn_metadata_i.decode.block_table[:batch_size] = attn_metadata_i.decode.block_table[block_indices] - attn_metadata_i.decode.block_table = attn_metadata_i.decode.block_table[:batch_size] + block_indices = torch.cat([ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(query_lens_d, dim=0)[:-1] + ]) + attn_metadata_i.decode.block_table[:batch_size] = \ + attn_metadata_i.decode.block_table[block_indices] + attn_metadata_i.decode.block_table = \ + attn_metadata_i.decode.block_table[:batch_size] input_ids = draft_token_ids_list[-1].int() positions += 1 @@ -450,32 +465,38 @@ class MtpProposer(EagleProposer): decode_metadata = getattr(attn_metadata_i, "decode", None) prefill_metadata = getattr(attn_metadata_i, "prefill", None) # When disable_padded_drafter_batch=False, it should not to be updating these params, maybe. - if decode_metadata is not None and ( - self.speculative_config.disable_padded_drafter_batch or aclgraph_runtime_mode != CUDAGraphMode.FULL - ): - decode_metadata.actual_seq_lengths_q = self.arange_cpu[1 : batch_size + 1].tolist() + if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \ + aclgraph_runtime_mode != CUDAGraphMode.FULL): + decode_metadata.actual_seq_lengths_q = self.arange_cpu[ + 1:batch_size + 1].tolist() if aclgraph_runtime_mode == CUDAGraphMode.FULL: - decode_metadata.actual_seq_lengths_q = builder.pad_actual_seq_len_q_mtp_disable_pad( - graph_pad_size - batch_size, batch_size, decode_metadata.actual_seq_lengths_q - ) - decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla(positions[:batch_size]) + decode_metadata.actual_seq_lengths_q = \ + builder.pad_actual_seq_len_q_mtp_disable_pad( + graph_pad_size - batch_size, + batch_size, + decode_metadata.actual_seq_lengths_q) + decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla( + positions[:batch_size]) # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length. Since it is complex # to remove such requests from the batch, we keep them in the batch # but adjust the position ids and slot mappings to avoid the # out-of-range access during the model execution. The draft tokens # generated with this adjustment should be ignored. - exceeds_max_model_len = positions[:batch_size] >= self.runner.model_config.max_model_len + exceeds_max_model_len = positions[: + batch_size] >= self.runner.model_config.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, positions[:batch_size]) + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions[:batch_size]) # Increment the sequence lengths. # This is an out-of-place operation to avoid modifying the original tensor # when enable async_scheduling. attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - exceeds_mask = attn_metadata_i.seq_lens[:batch_size] > self.runner.model_config.max_model_len + exceeds_mask = attn_metadata_i.seq_lens[:batch_size] > \ + self.runner.model_config.max_model_len attn_metadata_i.seq_lens[:batch_size].masked_fill_(exceeds_mask, 1) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the @@ -483,14 +504,13 @@ class MtpProposer(EagleProposer): slot_mapping += 1 if self.pcp_size > 1: exceeds_max_model_len = exceeds_max_model_len.repeat_interleave( - slot_mapping.size(0) // exceeds_max_model_len.size(0) - ) + slot_mapping.size(0) // exceeds_max_model_len.size(0)) slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids self._set_positions(batch_size, clamped_positions) - self.hidden_states[: hidden_states.shape[0]] = hidden_states + self.hidden_states[:hidden_states.shape[0]] = hidden_states if self.pcp_size * self.dcp_size > 1: # update local seq_len num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens( @@ -499,17 +519,19 @@ class MtpProposer(EagleProposer): self.dcp_size, self.runner.parallel_config.cp_kv_cache_interleave_size, ) - cp_seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank] + cp_seq_len = \ + num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank] attn_metadata_i.decode.cp_seq_len = cp_seq_len # update slot_mapping slot_indices += self.pcp_size slot_mapping = mtp_slot_mapping[slot_indices] - attn_metadata_i.slot_mapping[: batch_size * self.pcp_size] = slot_mapping + attn_metadata_i.slot_mapping[:batch_size * + self.pcp_size] = slot_mapping else: attn_metadata_i.slot_mapping[:batch_size] = slot_mapping if self.speculative_config.disable_padded_drafter_batch: if self.uses_mrope: - self.mrope_positions[:, batch_size:num_input_tokens] = 0 + self.mrope_positions[:, batch_size:num_input_tokens] = 0 else: self.positions[batch_size:num_input_tokens] = 0 self.input_ids[batch_size:num_input_tokens] = 0 @@ -517,24 +539,31 @@ class MtpProposer(EagleProposer): if prefill_metadata is not None: prefill_metadata.seq_lens = attn_metadata_i.seq_lens - prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist() + prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist( + ) prefill_metadata.context_lens = attn_metadata_i.seq_lens - prefill_metadata.input_positions = self._get_positions(num_input_tokens) + prefill_metadata.input_positions = self._get_positions( + num_input_tokens) prefill_metadata.max_seq_lens += 1 prefill_metadata.max_seq_lens = min( - prefill_metadata.max_seq_lens, self.runner.model_config.max_model_len - ) + prefill_metadata.max_seq_lens, + self.runner.model_config.max_model_len) if decode_metadata is not None: decode_metadata.seq_lens = attn_metadata_i.seq_lens - decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist() + decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist( + ) decode_seq_lens_list = decode_metadata.seq_lens_list - if aclgraph_runtime_mode == CUDAGraphMode.FULL and self.speculative_config.disable_padded_drafter_batch: - decode_metadata.seq_lens_list = decode_seq_lens_list + [0] * ( - graph_pad_size - len(decode_seq_lens_list) - ) - decode_metadata.input_positions = self._get_positions(num_input_tokens) + if aclgraph_runtime_mode == CUDAGraphMode.FULL and \ + self.speculative_config.disable_padded_drafter_batch: + decode_metadata.seq_lens_list = decode_seq_lens_list + [ + 0 + ] * (graph_pad_size - len(decode_seq_lens_list)) + decode_metadata.input_positions = self._get_positions( + num_input_tokens) decode_metadata.max_seq_lens += 1 - decode_metadata.max_seq_lens = min(decode_metadata.max_seq_lens, self.runner.model_config.max_model_len) + decode_metadata.max_seq_lens = min( + decode_metadata.max_seq_lens, + self.runner.model_config.max_model_len) # mtp>1: [batch_size, k] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) diff --git a/vllm_ascend/spec_decode/ngram_proposer.py b/vllm_ascend/spec_decode/ngram_proposer.py index 280d2ca8..22d28b61 100644 --- a/vllm_ascend/spec_decode/ngram_proposer.py +++ b/vllm_ascend/spec_decode/ngram_proposer.py @@ -1,11 +1,13 @@ import torch from vllm.config import CUDAGraphMode -from vllm.v1.spec_decode.ngram_proposer import NgramProposer as VllmNgramProposer +from vllm.v1.spec_decode.ngram_proposer import \ + NgramProposer as VllmNgramProposer from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType class NgramProposer(VllmNgramProposer, Proposer): + def __init__(self, vllm_config, device, runner): super().__init__(vllm_config) self.name = SpecDcodeType.NGRAM @@ -17,31 +19,27 @@ class NgramProposer(VllmNgramProposer, Proposer): pass @torch.inference_mode() - def dummy_run( - self, - num_tokens, - with_prefill=None, - in_graph_capturing=None, - num_reqs=None, - num_tokens_across_dp=None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - dummy_compute_logits=lambda hidden_states: None, - is_profile=False, - ): + def dummy_run(self, + num_tokens, + with_prefill=None, + in_graph_capturing=None, + num_reqs=None, + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None, + is_profile=False): pass - def generate_token_ids( - self, - valid_sampled_token_ids, - sampling_metadata=None, - scheduler_output=None, - spec_decode_metadata=None, - positions=None, - num_scheduled_tokens=None, - hidden_states=None, - aux_hidden_states=None, - ) -> list[list[int]]: + def generate_token_ids(self, + valid_sampled_token_ids, + sampling_metadata=None, + scheduler_output=None, + spec_decode_metadata=None, + positions=None, + num_scheduled_tokens=None, + hidden_states=None, + aux_hidden_states=None) -> list[list[int]]: valid_ngram_requests = [] for i, sampled_ids in enumerate(valid_sampled_token_ids): num_sampled_ids = len(sampled_ids) @@ -59,7 +57,8 @@ class NgramProposer(VllmNgramProposer, Proposer): start_idx = self.runner.input_batch.num_tokens_no_spec[i] end_idx = start_idx + num_sampled_ids - self.runner.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids + self.runner.input_batch.token_ids_cpu[ + i, start_idx:end_idx] = sampled_ids valid_ngram_requests.append(i) diff --git a/vllm_ascend/spec_decode/suffix_proposer.py b/vllm_ascend/spec_decode/suffix_proposer.py index 1cdbec3c..ea9f0f72 100644 --- a/vllm_ascend/spec_decode/suffix_proposer.py +++ b/vllm_ascend/spec_decode/suffix_proposer.py @@ -1,11 +1,13 @@ import torch from vllm.config import CUDAGraphMode -from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer as VllmSuffixDecodingProposer +from vllm.v1.spec_decode.suffix_decoding import \ + SuffixDecodingProposer as VllmSuffixDecodingProposer from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer): + def __init__(self, vllm_config, device, runner): super().__init__(vllm_config) self.name = SpecDcodeType.SUFFIX @@ -17,30 +19,27 @@ class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer): pass @torch.inference_mode() - def dummy_run( - self, - num_tokens, - with_prefill=None, - in_graph_capturing=None, - num_reqs=None, - num_tokens_across_dp=None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - dummy_compute_logits=lambda hidden_states: None, - is_profile=False, - ): + def dummy_run(self, + num_tokens, + with_prefill=None, + in_graph_capturing=None, + num_reqs=None, + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None, + is_profile=False): pass - def generate_token_ids( - self, - valid_sampled_token_ids, - sampling_metadata=None, - scheduler_output=None, - spec_decode_metadata=None, - positions=None, - num_scheduled_tokens=None, - hidden_states=None, - aux_hidden_states=None, - ) -> list[list[int]]: - draft_token_ids = self.propose(self.runner.input_batch, valid_sampled_token_ids) + def generate_token_ids(self, + valid_sampled_token_ids, + sampling_metadata=None, + scheduler_output=None, + spec_decode_metadata=None, + positions=None, + num_scheduled_tokens=None, + hidden_states=None, + aux_hidden_states=None) -> list[list[int]]: + draft_token_ids = self.propose(self.runner.input_batch, + valid_sampled_token_ids) return draft_token_ids