From 82544aa0cc23b5d2c951ef0e9703394a4ffd772a Mon Sep 17 00:00:00 2001 From: chanzhennan Date: Sat, 28 Feb 2026 11:15:50 +0800 Subject: [PATCH] [Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222) Signed-off-by: xyDong0223 Co-authored-by: xyDong0223 --- vllm_kunlun/models/__init__.py | 104 ++- vllm_kunlun/models/qwen3_next.py | 755 ++++++++++-------- vllm_kunlun/models/qwen3_next_mtp.py | 303 +++++++ vllm_kunlun/ops/_kunlun_ops.py | 546 +++++++------ vllm_kunlun/ops/fla/chunk.py | 402 ++++++++-- vllm_kunlun/ops/fla/chunk_o.py | 104 ++- vllm_kunlun/ops/fla/fused_recurrent.py | 39 +- vllm_kunlun/ops/fla/l2norm.py | 94 +-- vllm_kunlun/ops/fla/layernorm_guard.py | 209 ++--- vllm_kunlun/ops/fla/wy_fast.py | 56 +- vllm_kunlun/ops/layernorm.py | 120 +-- vllm_kunlun/ops/mamba/causal_conv1d.py | 755 ++++++++---------- vllm_kunlun/v1/attention/backends/gdn_attn.py | 390 +++++++++ .../v1/attention/backends/kunlun_attn.py | 26 +- vllm_kunlun/v1/sample/rejection_sampler.py | 154 ++-- vllm_kunlun/v1/sample/spec_decode/eagle.py | 2 +- vllm_kunlun/vllm_utils_wrapper.py | 141 +++- 17 files changed, 2668 insertions(+), 1532 deletions(-) create mode 100644 vllm_kunlun/models/qwen3_next_mtp.py create mode 100644 vllm_kunlun/v1/attention/backends/gdn_attn.py diff --git a/vllm_kunlun/models/__init__.py b/vllm_kunlun/models/__init__.py index bceb5f5..4681f4e 100644 --- a/vllm_kunlun/models/__init__.py +++ b/vllm_kunlun/models/__init__.py @@ -3,95 +3,113 @@ from vllm import ModelRegistry def register_model(): # from .demo_model import DemoModel # noqa: F401 - from .qwen2_vl import Qwen2VLForConditionalGeneration #noqa: F401 - from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration #noqa: F401 - from .qwen3_moe import Qwen3MoeForCausalLM #noqa: F401 - from .qwen3_vl import Qwen3VLForConditionalGeneration - from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration - from .qwen3_omni_moe_thinker import Qwen3OmniMoeThinkerForConditionalGeneration + from .qwen2_5_vl 
import Qwen2_5_VLForConditionalGeneration # noqa: F401 + from .qwen2_vl import Qwen2VLForConditionalGeneration # noqa: F401 + from .qwen3_moe import Qwen3MoeForCausalLM # noqa: F401 + from .qwen3_omni_moe_thinker import ( # noqa: F401 + Qwen3OmniMoeThinkerForConditionalGeneration, + ) + from .qwen3_vl import Qwen3VLForConditionalGeneration # noqa: F401 + from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration # noqa: F401 + # from .llama4 import Llama4ForCausalLM #noqa: F401 # from .mllama4 import Llama4ForConditionalGeneration #noqa: F401 # from .deepseek_v2 import KunlunDeepseekV2MoE - # ModelRegistry.register_model( # "DemoModel", # "vllm_kunlun.model_executor.models.demo_model:DemoModel") ModelRegistry.register_model( "Qwen2VLForConditionalGeneration", - "vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration") + "vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration", + ) ModelRegistry.register_model( "Qwen2_5_VLForConditionalGeneration", - "vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration") + "vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration", + ) ModelRegistry.register_model( - "Qwen3ForCausalLM", - "vllm_kunlun.models.qwen3:Qwen3ForCausalLM") + "Qwen3ForCausalLM", "vllm_kunlun.models.qwen3:Qwen3ForCausalLM" + ) ModelRegistry.register_model( - "Qwen3MoeForCausalLM", - "vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM") + "Qwen3MoeForCausalLM", "vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM" + ) ModelRegistry.register_model( - "Qwen3NextForCausalLM", - "vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM") + "Qwen3NextForCausalLM", "vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM" + ) ModelRegistry.register_model( - "GptOssForCausalLM", - "vllm_kunlun.models.gpt_oss:GptOssForCausalLM") + "Qwen3NextMTP", "vllm_kunlun.models.qwen3_next_mtp:Qwen3NextMTP" + ) ModelRegistry.register_model( - "InternLM2ForCausalLM", - "vllm_kunlun.models.internlm2:InternLM2ForCausalLM") - + "GlmForCausalLM", 
"vllm_kunlun.models.glm:GlmForCausalLM" + ) + ModelRegistry.register_model( - "InternVLChatModel", - "vllm_kunlun.models.internvl:InternVLChatModel") + "GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM" + ) + + ModelRegistry.register_model( + "InternLM2ForCausalLM", "vllm_kunlun.models.internlm2:InternLM2ForCausalLM" + ) + + ModelRegistry.register_model( + "InternVLChatModel", "vllm_kunlun.models.internvl:InternVLChatModel" + ) ModelRegistry.register_model( "InternS1ForConditionalGeneration", - "vllm_kunlun.models.interns1:InternS1ForConditionalGeneration") - + "vllm_kunlun.models.interns1:InternS1ForConditionalGeneration", + ) + ModelRegistry.register_model( "Qwen3VLForConditionalGeneration", - "vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration") - + "vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration", + ) + ModelRegistry.register_model( "Qwen3VLMoeForConditionalGeneration", - "vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration") + "vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration", + ) ModelRegistry.register_model( "Qwen3OmniMoeForConditionalGeneration", - "vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration") + "vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration", + ) ModelRegistry.register_model( - "SeedOssForCausalLM", - "vllm_kunlun.models.seed_oss:SeedOssForCausalLM") + "SeedOssForCausalLM", "vllm_kunlun.models.seed_oss:SeedOssForCausalLM" + ) ModelRegistry.register_model( "MiMoV2FlashForCausalLM", - "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM") + "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM", + ) ModelRegistry.register_model( - "GptOssForCausalLM", - "vllm_kunlun.models.gpt_oss:GptOssForCausalLM") + "GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM" + ) ModelRegistry.register_model( - "DeepseekV3ForCausalLM", - "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM") + 
"DeepseekV3ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM" + ) ModelRegistry.register_model( - "DeepseekV32ForCausalLM", - "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM") - - ModelRegistry.register_model( - "DeepSeekMTPModel", - "vllm_kunlun.models.deepseek_mtp:DeepSeekMTP") + "DeepseekV32ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM" + ) ModelRegistry.register_model( - "GlmMoeDsaForCausalLM", - "vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM") + "DeepSeekMTPModel", "vllm_kunlun.models.deepseek_mtp:DeepSeekMTP" + ) + + ModelRegistry.register_model( + "GlmMoeDsaForCausalLM", "vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM" + ) + def register_quant_method(): """to do""" diff --git a/vllm_kunlun/models/qwen3_next.py b/vllm_kunlun/models/qwen3_next.py index d34589a..8bb0804 100644 --- a/vllm_kunlun/models/qwen3_next.py +++ b/vllm_kunlun/models/qwen3_next.py @@ -1,98 +1,132 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Qwen3Next model.""" + from collections.abc import Iterable from itertools import islice -from typing import Optional, Union +from typing import Optional +import kunlun_ops import torch import torch.nn.functional as F from einops import rearrange from torch import nn from transformers.activations import ACT2FN - from vllm.attention import AttentionBackend, AttentionMetadata - -from vllm_kunlun.ops.attention.layer import Attention - from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, - VllmConfig, get_current_vllm_config) -from vllm.distributed import (divide, get_ep_group, get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.config import ( + CacheConfig, + ModelConfig, + SpeculativeConfig, + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed 
import ( + divide, + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger -from vllm_kunlun.ops.fla import (fused_recurrent_gated_delta_rule, torch_chunk_gated_delta_rule, chunk_gated_delta_rule) -from vllm.model_executor.layers.fla.ops import ( - RMSNormGated) -from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE + # yapf conflicts with isort for this block # yapf: disable -from vllm.model_executor.layers.layernorm import ( - GemmaRMSNorm as Qwen3NextRMSNorm) +from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm + # yapf: enable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - mamba_v2_sharded_weight_loader) +from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm_kunlun.ops.mamba.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - 
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.utils import sequence_parallel_chunk + default_weight_loader, + sharded_weight_loader, +) +from vllm.model_executor.models.interfaces import ( + HasInnerState, + IsHybrid, + MixtureOfExperts, + SupportsLoRA, + SupportsPP, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, + sequence_parallel_chunk, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, MixtureOfExperts, - SupportsLoRA, SupportsPP) -from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) -from vllm_kunlun.ops.activation import SiluAndMul from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops -from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask -import kunlun_ops +from vllm_kunlun.ops.activation import SiluAndMul +from vllm_kunlun.ops.attention.layer import Attention +from vllm_kunlun.ops.fla import ( + RMSNormGated, + chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule, +) +from vllm_kunlun.ops.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from 
vllm_kunlun.v1.attention.backends.gdn_attn import GDNAttentionMetadata @torch.compile(dynamic=True, backend="aot_eager") def get_masked_input_and_mask_kunlun( - input_: torch.Tensor, org_vocab_start_index: int, - org_vocab_end_index: int, num_org_vocab_padding: int, - added_vocab_start_index: int, - added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]: + input_: torch.Tensor, + org_vocab_start_index: int, + org_vocab_end_index: int, + num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int, +) -> tuple[torch.Tensor, torch.Tensor]: # torch.compile will fuse all of the pointwise ops below # into a single kernel, making it very fast - org_vocab_mask = (input_ >= org_vocab_start_index) & ( - input_ < org_vocab_end_index) + org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index) added_vocab_mask = (input_ >= added_vocab_start_index) & ( - input_ < added_vocab_end_index) - added_offset = added_vocab_start_index - ( - org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding - valid_offset = (org_vocab_start_index * - org_vocab_mask) + (added_offset * added_vocab_mask) + input_ < added_vocab_end_index + ) + added_offset = ( + added_vocab_start_index + - (org_vocab_end_index - org_vocab_start_index) + - num_org_vocab_padding + ) + valid_offset = (org_vocab_start_index * org_vocab_mask) + ( + added_offset * added_vocab_mask + ) vocab_mask = org_vocab_mask | added_vocab_mask input_ = vocab_mask * (input_ - valid_offset) return input_, ~vocab_mask + get_masked_input_and_mask = get_masked_input_and_mask_kunlun logger = init_logger(__name__) @@ -113,19 +147,25 @@ class Qwen3NextMLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - 
bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -134,6 +174,7 @@ class Qwen3NextMLP(nn.Module): x, _ = self.down_proj(x) return x + class Qwen3NextSparseMoeBlock(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): @@ -155,7 +196,8 @@ class Qwen3NextSparseMoeBlock(nn.Module): if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") + f"the number of experts {config.num_experts}." + ) # Load balancing settings. 
vllm_config = get_current_vllm_config() @@ -164,32 +206,35 @@ class Qwen3NextSparseMoeBlock(nn.Module): self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) - self.experts = FusedMoE(num_experts=self.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel) + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) if config.shared_expert_intermediate_size > 
0: self.shared_expert = Qwen3NextMLP( @@ -197,15 +242,12 @@ class Qwen3NextSparseMoeBlock(nn.Module): intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), + reduce_results=self.experts.must_reduce_shared_expert_outputs(), prefix=f"{prefix}.shared_expert", ) else: self.shared_expert = None - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, - 1, - bias=False) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -220,29 +262,34 @@ class Qwen3NextSparseMoeBlock(nn.Module): if self.shared_expert is not None: shared_output = self.shared_expert(hidden_states) if self.shared_expert_gate is not None: - shared_output = F.sigmoid( - self.shared_expert_gate(hidden_states)) * shared_output + shared_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output + ) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - kunlun_linear_weights = self.gate.get_weights() - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits, - linear_weights=kunlun_linear_weights) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( - final_hidden_states, 0) + final_hidden_states, 0 + ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 - final_hidden_states) + final_hidden_states = ( + self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states + ) + ) return 
final_hidden_states.view(orig_shape) + class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): @property @@ -250,17 +297,25 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): return "linear_attention" def get_attn_backend(self) -> type["AttentionBackend"]: - from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend + from vllm_kunlun.v1.attention.backends.gdn_attn import GDNAttentionBackend + return GDNAttentionBackend def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - self.model_config.dtype, self.cache_config.mamba_cache_dtype) + self.model_config.dtype, self.cache_config.mamba_cache_dtype + ) def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.gated_delta_net_state_shape( - self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim, - self.head_v_dim, self.conv_kernel_size, self.num_spec) + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + ) def __init__( self, @@ -294,8 +349,11 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): self.cache_config = cache_config self.quant_config = quant_config self.speculative_config = speculative_config - self.num_spec = (self.speculative_config.num_speculative_tokens - if self.speculative_config else 0) + self.num_spec = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config + else 0 + ) # QKV self.conv_dim = self.key_dim * 2 + self.value_dim @@ -331,31 +389,36 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): delattr(self.conv1d.weight, "weight_loader") set_weight_attrs( - self.conv1d.weight, { - "weight_loader": - mamba_v2_sharded_weight_loader([ - query_key_settings, - query_key_settings, - value_settings, - ], self.tp_size, self.tp_rank) - }) + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + query_key_settings, + query_key_settings, 
+ value_settings, + ], + self.tp_size, + self.tp_rank, + ) + }, + ) # selective projection used to make dt, B and C input dependant # time step projection (discretization) # instantiate once and copy inv_dt in init_weights of PretrainedModel self.dt_bias = nn.Parameter( - torch.ones(self.num_v_heads // self.tp_size), ) + torch.ones(self.num_v_heads // self.tp_size), + ) self.A_log = nn.Parameter( torch.empty( divide(self.num_v_heads, self.tp_size), dtype=torch.float32, - )) + ) + ) - set_weight_attrs(self.A_log, - {"weight_loader": sharded_weight_loader(0)}) - set_weight_attrs(self.dt_bias, - {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) self.norm = RMSNormGated( self.head_v_dim, @@ -363,15 +426,17 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): group_size=None, norm_before_gate=True, device=current_platform.current_device(), - dtype=config.torch_dtype, + dtype=torch.get_default_dtype(), # config.torch_dtype, ) - self.out_proj = RowParallelLinear(self.value_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj") + self.out_proj = RowParallelLinear( + self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: @@ -388,9 +453,13 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): """ new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( self.num_k_heads // self.tp_size, - (self.head_k_dim + self.head_k_dim + - (self.head_v_dim + self.head_v_dim) * self.num_v_heads // - self.num_k_heads), + ( + self.head_k_dim + + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) + * self.num_v_heads + // self.num_k_heads + ), ) new_tensor_shape_ba = 
mixed_qkvz.size()[:-1] + ( self.num_k_heads // self.tp_size, @@ -408,16 +477,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ] split_arg_list_ba = [ self.num_v_heads // self.num_k_heads, - self.num_v_heads // self.num_k_heads + self.num_v_heads // self.num_k_heads, ] # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] - (query, key, value, z) = torch.split(mixed_qkvz, - split_arg_list_qkvz, - dim=2) - (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) + query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) + b, a = torch.split(mixed_ba, split_arg_list_ba, dim=2) # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] value = value.reshape(value.size(0), -1, self.head_v_dim) @@ -440,9 +507,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): dim=-1, ) query, key = map( - lambda x: rearrange(x, 'l (h d) -> 1 l h d', d=self.head_k_dim), - (query, key)) - value = rearrange(value, 'l (h d) -> 1 l h d', d=self.head_v_dim) + lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), + (query, key), + ) + value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) return query, key, value def forward( @@ -476,10 +544,17 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc spec_sequence_masks = attn_metadata.spec_sequence_masks spec_token_masks = attn_metadata.spec_token_masks - spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 - non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + spec_state_indices_tensor = ( + attn_metadata.spec_state_indices_tensor + ) # noqa: E501 + non_spec_state_indices_tensor = ( + attn_metadata.non_spec_state_indices_tensor + ) # noqa: E501 + non_spec_state_indices_tensor_cpu = ( + attn_metadata.non_spec_state_indices_tensor_cpu + ) 
self_kv_cache = self.kv_cache[forward_context.virtual_engine] - conv_state = self_kv_cache[0].transpose(-1, -2) + conv_state = self_kv_cache[0] ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens @@ -487,23 +562,23 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): spec_token_masks = spec_token_masks[:num_actual_tokens] # 1. Set up dimensions for reshapes later - projected_states_qkvz, _ = self.in_proj_qkvz( - hidden_states[:num_actual_tokens]) - projected_states_ba, _ = self.in_proj_ba( - hidden_states[:num_actual_tokens]) + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens]) + projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens]) query, key, value, z, b, a = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba) - query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), - (query, key, value)) + projected_states_qkvz, projected_states_ba + ) + query, key, value = map( + lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value) + ) mixed_qkv = torch.cat((query, key, value), dim=-1) # 2. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) if spec_sequence_masks is not None: - if (attn_metadata.num_prefills == 0 - and attn_metadata.num_decodes == 0): + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: mixed_qkv_spec = mixed_qkv mixed_qkv_non_spec = None else: @@ -521,8 +596,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=spec_state_indices_tensor[:, 0] - [:attn_metadata.num_spec_decodes], + conv_state_indices=spec_state_indices_tensor[:, 0][ + : attn_metadata.num_spec_decodes + ], num_accepted_tokens=num_accepted_tokens, query_start_loc=spec_query_start_loc, max_query_len=spec_state_indices_tensor.size(-1), @@ -531,11 +607,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): # 2.2: process the remaining part if attn_metadata.num_prefills > 0: - mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) # - "cache_indices" updates the conv_state cache in positions # pointed to by "state_indices_tensor" mixed_qkv_non_spec = causal_conv1d_fn( - mixed_qkv_non_spec_T, + mixed_qkv_non_spec, conv_weights, self.conv1d.bias, activation=self.activation, @@ -544,7 +619,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, metadata=attn_metadata, - ).transpose(0, 1) + ) elif attn_metadata.num_decodes > 0: mixed_qkv_non_spec = causal_conv1d_update( mixed_qkv_non_spec, @@ -552,26 +627,29 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=non_spec_state_indices_tensor[:attn_metadata - .num_decodes], + conv_state_indices=non_spec_state_indices_tensor[ + : attn_metadata.num_decodes + ], + 
conv_state_indices_cpu=non_spec_state_indices_tensor_cpu[ + : attn_metadata.num_decodes + ], validate_data=True, ) else: mixed_qkv_non_spec = None - query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( - mixed_qkv_spec) + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( - mixed_qkv_non_spec) + mixed_qkv_non_spec + ) beta = b.sigmoid() g = ops.fused_gdn_gating(self.A_log.float(), a, self.dt_bias.float()) - g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta)) + g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta)) if spec_sequence_masks is not None: - if (attn_metadata.num_prefills == 0 - and attn_metadata.num_decodes == 0): + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: g_spec = g beta_spec = beta g_non_spec = None @@ -591,37 +669,56 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): # 3.1: process the mutlti-query part if spec_sequence_masks is not None: - core_attn_out_spec, last_recurrent_state = ( - fused_recurrent_gated_delta_rule( - q=query_spec, - k=key_spec, - v=value_spec, - g=g_spec, - beta=beta_spec, - initial_state=ssm_state, - inplace_final_state=True, - cu_seqlens=spec_query_start_loc[:attn_metadata. 
- num_spec_decodes + 1], - ssm_state_indices=spec_state_indices_tensor, - num_accepted_tokens=num_accepted_tokens, - use_qk_l2norm_in_kernel=True, - )) + core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + ) else: core_attn_out_spec, last_recurrent_state = None, None # 3.2: process the remaining part if attn_metadata.num_prefills > 0: - if non_spec_state_indices_tensor.shape[0] > 100: - initial_state = ssm_state[ - non_spec_state_indices_tensor].contiguous() - else: - initial_state_shape = non_spec_state_indices_tensor.shape + ssm_state.shape[1: ] - initial_state = torch.empty(initial_state_shape, dtype=ssm_state.dtype, device=ssm_state.device) - for i in range(non_spec_state_indices_tensor.shape[0]): - initial_state[i] = ssm_state[non_spec_state_indices_tensor[i]] - - initial_state = initial_state * has_initial_state.view(has_initial_state.shape[0], 1, 1, 1) + slot_mapping = torch.full( + (ssm_state.shape[0],), -1, dtype=torch.int32, device="cuda" + ) + slot_mapping[non_spec_state_indices_tensor] = torch.arange( + len(non_spec_state_indices_tensor), dtype=torch.int32, device="cuda" + ) + + initial_state_shape = ( + non_spec_state_indices_tensor.shape + ssm_state.shape[1:] + ) + initial_state = torch.empty( + initial_state_shape, dtype=ssm_state.dtype, device=ssm_state.device + ) + initial_state = initial_state.view( + initial_state.shape[0], 1, -1, initial_state.shape[-1] + ) + cast_ssm_state = ssm_state.view(ssm_state.shape[0], -1, ssm_state.shape[-1]) + + kunlun_ops.reshape_and_cache_flash( + cast_ssm_state, + cast_ssm_state, + initial_state, + initial_state, + slot_mapping, + ) + initial_state = 
initial_state.view(initial_state_shape) + + initial_state = initial_state * has_initial_state.view( + has_initial_state.shape[0], 1, 1, 1 + ) initial_state = initial_state.transpose(-1, -2).contiguous() + ( core_attn_out_non_spec, last_recurrent_state, @@ -637,15 +734,23 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): cu_seqlens=non_spec_query_start_loc, ) # Init cache - last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view( - last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1]) - cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1]) + last_recurrent_state = ( + last_recurrent_state.transpose(-1, -2) + .contiguous() + .to(ssm_state.dtype) + .view(last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1]) + ) + cast_ssm_state = ssm_state.view( + ssm_state.shape[0], 1, -1, ssm_state.shape[-1] + ) + kunlun_ops.reshape_and_cache_flash( - last_recurrent_state, - last_recurrent_state, - cast_ssm_state, - cast_ssm_state, - non_spec_state_indices_tensor) + last_recurrent_state, + last_recurrent_state, + cast_ssm_state, + cast_ssm_state, + non_spec_state_indices_tensor, + ) elif attn_metadata.num_decodes > 0: core_attn_out_non_spec, last_recurrent_state = ( fused_recurrent_gated_delta_rule( @@ -656,17 +761,18 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): beta=beta_non_spec, initial_state=ssm_state, inplace_final_state=True, - cu_seqlens=non_spec_query_start_loc[:attn_metadata. 
- num_decodes + 1], + cu_seqlens=non_spec_query_start_loc[ + : attn_metadata.num_decodes + 1 + ], ssm_state_indices=non_spec_state_indices_tensor, use_qk_l2norm_in_kernel=True, - )) + ) + ) else: core_attn_out_non_spec, last_recurrent_state = None, None # Merge core attention output - if (spec_sequence_masks is not None - and core_attn_out_non_spec is not None): + if spec_sequence_masks is not None and core_attn_out_non_spec is not None: core_attn_out = torch.empty( (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), dtype=core_attn_out_non_spec.dtype, @@ -685,7 +791,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): z = z.reshape(-1, z.shape[-1]) core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)') + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") output[:num_actual_tokens], _ = self.out_proj(core_attn_out) @@ -722,7 +828,8 @@ class Qwen3NextAttention(nn.Module): self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.dual_chunk_attention_config = getattr( - config, "dual_chunk_attention_config", None) + config, "dual_chunk_attention_config", None + ) self.attn_output_gate = getattr(config, "attn_output_gate", True) self.qkv_proj = QKVParallelLinear( @@ -752,6 +859,9 @@ class Qwen3NextAttention(nn.Module): partial_rotary_factor=config.partial_rotary_factor, dual_chunk_attention_config=self.dual_chunk_attention_config, ) + self.rotary_dim = self.head_dim + if config.partial_rotary_factor < 1.0: + self.rotary_dim = int(self.rotary_dim * config.partial_rotary_factor) self.attn = Attention( self.num_heads, @@ -761,11 +871,14 @@ class Qwen3NextAttention(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", - **{ - "layer_idx": extract_layer_index(prefix), - "dual_chunk_attention_config": - self.dual_chunk_attention_config, - } if self.dual_chunk_attention_config else 
{}, + **( + { + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": self.dual_chunk_attention_config, + } + if self.dual_chunk_attention_config + else {} + ), ) self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) @@ -779,24 +892,20 @@ class Qwen3NextAttention(nn.Module): ): qkv, _ = self.qkv_proj(hidden_states) - if self.attn_output_gate: - q_gate, k, v = qkv.split( - [self.q_size * 2, self.kv_size, self.kv_size], dim=-1) - orig_shape = q_gate.shape[:-1] - q_gate = q_gate.view(*orig_shape, self.num_heads, -1) - q, gate = torch.chunk(q_gate, 2, dim=-1) - q = q.reshape(*orig_shape, -1) - gate = gate.reshape(*orig_shape, -1) - else: - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) - - q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( - -1, self.num_heads * self.head_dim) - k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( - -1, self.num_kv_heads * self.head_dim) - - q, k = self.rotary_emb(positions, q, k) + q, k, v, gate = torch.ops.xspeedgate_ops.split_norm_rope_neox( + qkv=qkv, + q_weights=self.q_norm.weight, + k_weights=self.k_norm.weight, + positions=positions, + cos_sin_cache=self.rotary_emb.cos_sin_cache, + q_size=self.q_size, + kv_size=self.kv_size, + num_q_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + rotary_dim=self.rotary_dim, + attn_output_gate=True, + ) attn_output = self.attn(q, k, v) @@ -833,23 +942,26 @@ class Qwen3NextDecoderLayer(nn.Module): cache_config=cache_config, quant_config=quant_config, speculative_config=speculative_config, - prefix=f'{prefix}.linear_attn') + prefix=f"{prefix}.linear_attn", + ) elif self.layer_type == "full_attention": self.self_attn = Qwen3NextAttention( config, model_config=model_config, cache_config=cache_config, quant_config=quant_config, - prefix=f'{prefix}.self_attn', + prefix=f"{prefix}.self_attn", ) else: raise ValueError(f"Invalid layer_type {self.layer_type}") - mlp_only_layers = 
([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) if (self.layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 and - (self.layer_idx + 1) % config.decoder_sparse_step == 0): + config.num_experts > 0 + and (self.layer_idx + 1) % config.decoder_sparse_step == 0 + ): self.mlp = Qwen3NextSparseMoeBlock( vllm_config=vllm_config, prefix=f"{prefix}.mlp", @@ -862,10 +974,12 @@ class Qwen3NextDecoderLayer(nn.Module): quant_config=quant_config, ) - self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.post_attention_layernorm = Qwen3NextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps) + config.hidden_size, eps=config.rms_norm_eps + ) self.layer_scale = getattr(config, "layer_scale", False) if self.layer_scale: @@ -875,14 +989,16 @@ class Qwen3NextDecoderLayer(nn.Module): 1, config.hidden_size, dtype=config.torch_dtype, - ), ) + ), + ) self.ffn_layer_scale = torch.nn.Parameter( torch.zeros( 1, 1, config.hidden_size, dtype=config.torch_dtype, - ), ) + ), + ) def forward( self, @@ -895,8 +1011,7 @@ class Qwen3NextDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) self_attention_output = torch.empty_like(hidden_states) if self.layer_type == "linear_attention": @@ -917,26 +1032,29 @@ class Qwen3NextDecoderLayer(nn.Module): if self.layer_scale: if len(hidden_states.shape) == 2: hidden_states = hidden_states * ( - self.attn_layer_scale.to(hidden_states.dtype)[0] + 1) + self.attn_layer_scale.to(hidden_states.dtype)[0] + 1 + ) else: hidden_states = hidden_states * ( - 
self.attn_layer_scale.to(hidden_states.dtype) + 1) + self.attn_layer_scale.to(hidden_states.dtype) + 1 + ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) if self.layer_scale: if len(hidden_states.shape) == 2: hidden_states = hidden_states * ( - self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1) + self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1 + ) else: assert len(hidden_states.shape) == len( self.ffn_layer_scale.shape - ), f'shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}' # noqa: E501 + ), f"shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}" # noqa: E501 hidden_states = hidden_states * ( - self.ffn_layer_scale.to(hidden_states.dtype) + 1) + self.ffn_layer_scale.to(hidden_states.dtype) + 1 + ) return hidden_states, residual @@ -953,8 +1071,11 @@ class Qwen3NextModel(nn.Module): self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.embed_tokens = VocabParallelEmbedding( @@ -971,14 +1092,14 @@ class Qwen3NextModel(nn.Module): ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if 
get_pp_group().is_last_rank: - self.norm = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() @@ -1011,10 +1132,9 @@ class Qwen3NextModel(nn.Module): ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -1026,10 +1146,10 @@ class Qwen3NextModel(nn.Module): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=self.num_redundant_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -1080,16 +1200,19 @@ class Qwen3NextModel(nn.Module): if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. 
@@ -1098,15 +1221,17 @@ class Qwen3NextModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - MixtureOfExperts, IsHybrid): +class Qwen3NextForCausalLM( + nn.Module, HasInnerState, SupportsLoRA, SupportsPP, MixtureOfExperts, IsHybrid +): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1123,15 +1248,17 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Qwen3Next currently does not support prefix caching" + assert ( + not cache_config.enable_prefix_caching + ), "Qwen3Next currently does not support prefix caching" self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config - self.model = Qwen3NextModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Qwen3NextModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -1139,15 +1266,21 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - prefix=maybe_prefix(prefix, "lm_head")) - self.logits_processor = 
LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config + else lora_config.lora_vocab_padding_size + ), + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) # Set MoE hyperparameters self.expert_weights = [] @@ -1199,8 +1332,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) + self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): moe = layer.mlp @@ -1220,8 +1352,9 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ): - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states @@ -1231,23 +1364,30 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - vllm_config.model_config.dtype, - vllm_config.cache_config.mamba_cache_dtype) + vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype + ) @classmethod def get_mamba_state_shape_from_config( - cls, vllm_config: "VllmConfig" + cls, vllm_config: "VllmConfig" ) 
-> tuple[tuple[int, int], tuple[int, int]]: parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config tp_size = parallel_config.tensor_parallel_size - num_spec = (vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config else 0) + num_spec = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ) return MambaStateShapeCalculator.gated_delta_net_state_shape( - tp_size, hf_config.linear_num_key_heads, - hf_config.linear_num_value_heads, hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim, - num_spec) + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + ) def compute_logits( self, @@ -1255,8 +1395,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ) -> Optional[torch.Tensor]: return self.logits_processor(self.lm_head, hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=["mtp."], @@ -1315,8 +1454,9 @@ def fused_gdn_gating_kernel( blk_bias = tl.load(dt_bias + head_off, mask=mask) # If the model is loaded in fp16, without the .float() here, A might be -inf x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) - softplus_x = tl.where(beta * x <= threshold, - (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + softplus_x = tl.where( + beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x + ) blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) @@ -1332,14 +1472,7 @@ def fused_gdn_gating( seq_len = 1 grid = (batch, seq_len, triton.cdiv(num_heads, 8)) g = torch.empty_like(a, dtype=torch.float32) 
- fused_gdn_gating_kernel[grid](g, - A_log, - a, - dt_bias, - seq_len, - num_heads, - beta, - threshold, - 8, - num_warps=1) + fused_gdn_gating_kernel[grid]( + g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1 + ) return g diff --git a/vllm_kunlun/models/qwen3_next_mtp.py b/vllm_kunlun/models/qwen3_next_mtp.py new file mode 100644 index 0000000..7c687f9 --- /dev/null +++ b/vllm_kunlun/models/qwen3_next_mtp.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3Next MTP model.""" + +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig + +from .qwen3_next import Qwen3NextDecoderLayer, Qwen3NextRMSNorm + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +@support_torch_compile +class Qwen3NextMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + model_config = 
vllm_config.model_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + config: Qwen3NextConfig = model_config.hf_config + + self.config = config + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.fc = ColumnParallelLinear( + self.config.hidden_size * 2, + self.config.hidden_size, + gather_output=True, + bias=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fc", + ) + + self.layers = torch.nn.ModuleList( + Qwen3NextDecoderLayer( + vllm_config, + layer_type="full_attention", + prefix=f"{prefix}.layers.{idx}", + ) + for idx in range(self.num_mtp_layers) + ) + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_fc_norm_hidden = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_fc_norm_embedding = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + assert hidden_states.shape[-1] 
== inputs_embeds.shape[-1] + inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds) + hidden_states = self.pre_fc_norm_hidden(hidden_states) + hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1) + hidden_states = self.fc(hidden_states) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + current_step_idx = spec_step_idx % self.num_mtp_layers + hidden_states, residual = self.layers[current_step_idx]( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@support_torch_compile +class Qwen3NextMTP(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + cache_config = vllm_config.cache_config + assert ( + not cache_config.enable_prefix_caching + ), "Qwen3NextMTP currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.model = Qwen3NextMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp") + ) + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + 
self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + hidden_states = self.model( + input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + shared_weight_names = ["embed_tokens", "lm_head"] + + def remap_weight_names(weights): + for name, weight in weights: + if name.startswith("mtp."): + name = name.replace("mtp.", "model.") + elif not any(key in name for key in shared_weight_names): + continue + yield name, weight + + loader = AutoWeightsLoader(self) + return loader.load_weights(remap_weight_names(weights)) diff --git a/vllm_kunlun/ops/_kunlun_ops.py b/vllm_kunlun/ops/_kunlun_ops.py index 25495db..41c6bf2 100644 --- a/vllm_kunlun/ops/_kunlun_ops.py +++ b/vllm_kunlun/ops/_kunlun_ops.py @@ -16,33 +16,33 @@ # limitations under the License. 
"""kunlun custom op entry""" -import torch_xmlir + +from typing import Optional + import torch -import os -from typing import Optional, List, Dict -import vllm.envs as envs -import os -import ctypes from vllm.logger import init_logger logger = init_logger(__name__) try: import kunlun_ops - logger.info(f"Load custom ops library success!") + + logger.info("Load custom ops library success!") except ImportError as e: logger.warning("Import error msg: %s", e.msg) _per_token_smooth_quant = True + def is_per_token_smooth_quant(): - """ is per token smooth quant """ + """is per token smooth quant""" return _per_token_smooth_quant class KunlunOps: """KunlunOps""" + # Attention ops @staticmethod def paged_attention_v1( @@ -67,9 +67,9 @@ class KunlunOps: blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step, - alibi_sqrt=False - ): - """ PagedAttentionV1 """ + alibi_sqrt=False, + ): + """PagedAttentionV1""" # block_size = value_cache.shape[2] kunlun_ops.paged_attention( x=query, @@ -81,7 +81,7 @@ class KunlunOps: is_context=is_context, is_causal=True, out=output, - vo_head_dim=128 + vo_head_dim=128, ) @staticmethod @@ -110,9 +110,9 @@ class KunlunOps: blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step, - alibi_sqrt=False - ): - """ PagedAttentionV2 """ + alibi_sqrt=False, + ): + """PagedAttentionV2""" # block_size = value_cache.shape[2] kunlun_ops.paged_attention( x=query, @@ -124,31 +124,28 @@ class KunlunOps: is_context=is_context, is_causal=True, out=output, - vo_head_dim=128 + vo_head_dim=128, ) - # Activation ops @staticmethod - def silu_and_mul(out: torch.Tensor, - x: torch.Tensor): - """ silu and mul """ + def silu_and_mul(out: torch.Tensor, x: torch.Tensor): + """silu and mul""" kunlun_ops.silu_and_mul( x, axis=-1, turn=True, out=out, - ) + ) # Activation ops @staticmethod - def quick_gelu(out: torch.Tensor, - x: torch.Tensor): - """ quick gelu """ + def quick_gelu(out: torch.Tensor, x: torch.Tensor): + """quick 
gelu""" kunlun_ops.quick_gelu( x, out=out, - ) + ) # Layernorm @staticmethod @@ -159,9 +156,7 @@ class KunlunOps: epsilon, ): """rms_norm""" - kunlun_ops.rmsnorm( - x, weight.to(torch.float32), epsilon, out=out - ) + kunlun_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out) @staticmethod def fused_add_rms_norm( @@ -179,16 +174,11 @@ class KunlunOps: residual.copy_(fused_input, non_blocking=True) x.copy_(output) - # Rotary embedding @staticmethod def rotary_embedding( - positions, - query, - key, - head_size, - cos_sin_cache, - is_neox_style): + positions, query, key, head_size, cos_sin_cache, is_neox_style + ): """ refactor RotaryEmbedding forward function """ @@ -196,62 +186,38 @@ class KunlunOps: key_x = key.contiguous() torch.ops._C.rotary_embedding( - positions, - query_x, - key_x, - head_size, - cos_sin_cache, - is_neox_style) + positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style + ) return query_x, key_x # Rotary embedding @staticmethod def mrotary_embedding( - positions, - mrope_section, - query, - key, - head_size, - cos_sin_cache, - is_neox_style): + positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style + ): """ refactor RotaryEmbedding forward function """ query_x = query.contiguous() key_x = key.contiguous() - query_x_dim = query_x.dim() assert is_neox_style kunlun_ops.mrotary_embedding_neox( - positions, - query_x, - key_x, - head_size, - cos_sin_cache, - mrope_section) + positions, query_x, key_x, head_size, cos_sin_cache, mrope_section + ) query.data = query_x - key.data = key_x + key.data = key_x return query, key @staticmethod - def swap_blocks( - src, - dst, - block_mapping): - """ swap_blocks """ - kunlun_ops.swap_blocks( - src, - dst, - block_mapping - ) + def swap_blocks(src, dst, block_mapping): + """swap_blocks""" + kunlun_ops.swap_blocks(src, dst, block_mapping) @staticmethod - def copy_blocks( - key_caches, - value_caches, - block_mapping): - """ copy_blocks """ + def copy_blocks(key_caches, 
value_caches, block_mapping): + """copy_blocks""" for i in range(len(key_caches)): key_caches[i] = key_caches[i].contiguous() value_caches[i] = value_caches[i].contiguous() @@ -269,16 +235,10 @@ class KunlunOps: value_cache, slot_mapping, kv_cache_dtype, - ): - """ reshape_and_cache """ + ): + """reshape_and_cache""" # slot_mapping_cast = slot_mapping.to(torch.int32) - kunlun_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - slot_mapping - ) + kunlun_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) @staticmethod def multi_query_kv_attention( @@ -287,7 +247,7 @@ class KunlunOps: query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - **kargs + **kargs, ) -> torch.Tensor: """ query: shape = [num_prompt_tokens, num_heads, head_size] @@ -297,16 +257,12 @@ class KunlunOps: key = key.unsqueeze(0) value = value.unsqueeze(0) output = torch.empty_like(query) - alibi_slopes = kargs.get("alibi_slopes", None) - mask = kargs.get("mask", None) - is_causal = kargs.get("is_causal", True) - is_lvsl = kargs.get("is_lvsl", True) B, T, Qh, Hd = query.shape KVh = key.size(2) if KVh != Qh: repeat = Qh // KVh - key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd] + key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd] value = value.repeat_interleave(repeat, dim=2) kunlun_ops.attention( q=query, @@ -321,80 +277,90 @@ class KunlunOps: return output @staticmethod - def quant_fusedresidual_rmsnorm_op(x, - residual, - weight, - bias, - scale_to_int, - eps, - dyn_scale: bool, - type: int = 1): + def quant_fusedresidual_rmsnorm_op( + x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1 + ): """Quantized fused residual layer normalization""" out = torch.empty_like(x, dtype=torch.int8) if is_per_token_smooth_quant(): - out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1) + out_scale = torch.empty( + x.shape[:-1], device=x.device, dtype=torch.float + ).unsqueeze(-1) else: 
out_scale = torch.empty(12, device=x.device, dtype=torch.float) - kunlun_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps, - out=out, out_scale=out_scale , residual_tensor=residual) + kunlun_ops.quant_fusedresidual_rmsnorm( + x, + residual, + weight, + bias, + eps, + out=out, + out_scale=out_scale, + residual_tensor=residual, + ) if residual is None: return out, out_scale return out, out_scale, residual @staticmethod - def quant_rmsnorm_op(x, - weight, - bias, - scale_to_int, - eps, - dyn_scale : bool, - type: int = 1): + def quant_rmsnorm_op( + x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1 + ): """Quantized RMSNorm""" out = torch.empty_like(x, dtype=torch.int8) if is_per_token_smooth_quant(): - out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1) + out_scale = torch.empty( + x.shape[:-1], device=x.device, dtype=torch.float + ).unsqueeze(-1) else: out_scale = torch.empty(12, device=x.device, dtype=torch.float) - kunlun_ops.quant_rmsnorm(x, weight, bias, eps, - out=out, out_scale=out_scale) + kunlun_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale) return out, out_scale @staticmethod - def smooth_quant_matmul_column_row_kernels(input_tensor, - weight, - smoother, - input_scale, - weight_scale, - perTokenScaling, - perChannelScaling, - otype): + def smooth_quant_matmul_column_row_kernels( + input_tensor, + weight, + smoother, + input_scale, + weight_scale, + perTokenScaling, + perChannelScaling, + otype, + ): """smooth_quant_matmul_column_row_kernels""" input_shape = input_tensor.shape weight_shape = weight.shape if input_tensor.dim() == 3: input_tensor = input_tensor.reshape(-1, input_shape[-1]) - out = torch.empty((input_shape[0] * input_shape[1], - weight_shape[0]), - dtype=torch.float16, - device=weight.device) + out = torch.empty( + (input_shape[0] * input_shape[1], weight_shape[0]), + dtype=torch.float16, + device=weight.device, + ) output_bs_shape = [input_shape[0], 
input_shape[1]] elif input_tensor.dim() == 2: - out = torch.empty((input_shape[0], weight_shape[0]), - dtype=torch.float16, - device=weight.device) + out = torch.empty( + (input_shape[0], weight_shape[0]), + dtype=torch.float16, + device=weight.device, + ) output_bs_shape = [-1] - kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor, - weight, smoother, - input_scale, - weight_scale, - perTokenScaling, - perChannelScaling, - out=out) + kunlun_ops.smooth_quant_matmul_column_row_kernels( + input_tensor, + weight, + smoother, + input_scale, + weight_scale, + perTokenScaling, + perChannelScaling, + out=out, + ) out = out.view(*output_bs_shape, weight_shape[0]) @@ -404,6 +370,7 @@ class KunlunOps: if torch.is_tensor(x): return (type(x), x.device, x.dtype, x.shape, x.is_contiguous()) return (type(x), x) + @staticmethod def fused_moe( hidden_states: torch.Tensor, @@ -420,23 +387,24 @@ class KunlunOps: w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: """fused_moe""" global_num_experts, up_gate_size, _ = w1.shape M, N = hidden_states.shape hidden_dim = w2.shape[1] - normed_score = torch.empty(M, - moe_top_k, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - moe_top_k, - dtype=torch.int32, - device=hidden_states.device) + normed_score = torch.empty( + M, moe_top_k, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty( + M, moe_top_k, dtype=torch.int32, device=hidden_states.device + ) num_blocks = 12 block_statistic = torch.zeros( - num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device + num_blocks, + global_num_experts, + dtype=torch.int32, + device=hidden_states.device, ) router_logits = router_logits.to(torch.float) if scoring_func == "softmax": @@ -445,24 +413,27 @@ class KunlunOps: 
normed_score=normed_score, topk_index=topk_ids, block_statistic=None, - stable=True) + stable=False, + ) elif scoring_func == "sigmoid": torch.ops._C.moe_sigmoid_group_topk_norm( - x=router_logits, - topk_index=topk_ids, - norm_score=normed_score, - block_static=block_statistic, - bias=e_score_correction_bias, - scale=1.0, - n_group=num_expert_group, - topk_group=topk_group, - ) + x=router_logits, + topk_index=topk_ids, + norm_score=normed_score, + block_static=block_statistic, + bias=e_score_correction_bias, + scale=1.0, + n_group=num_expert_group, + topk_group=topk_group, + ) - if w1_bias is not None or w2_bias is not None: + if w1_bias is not None or w2_bias is not None: # Rignt now this branch is for gpt oss # TODO (@xyDong23): faster here using moe_fc kernel normed_score = normed_score.to(hidden_states.dtype) - out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device) + out = torch.zeros( + M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device + ) repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0) topk_ids_flat = topk_ids.flatten() for i in range(global_num_experts): @@ -470,9 +441,13 @@ class KunlunOps: selected_token = topk_ids_flat == experts_id if selected_token.sum(): cur_token = repeat_x[selected_token] - up_gate = torch.empty(selected_token.sum(), up_gate_size//2, - dtype=cur_token.dtype, device=cur_token.device) - groupgemm1 = cur_token@ w1[i].T + up_gate = torch.empty( + selected_token.sum(), + up_gate_size // 2, + dtype=cur_token.dtype, + device=cur_token.device, + ) + groupgemm1 = cur_token @ w1[i].T # Add w13 bias if w1_bias is not None: groupgemm1 = groupgemm1 + w1_bias[i] @@ -482,53 +457,129 @@ class KunlunOps: if w2_bias is not None: groupgemm2 = groupgemm2 + w2_bias[i] out[selected_token] = groupgemm2 - ouput = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype) + ouput = ( + (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)) + .sum(dim=1) 
+ .to(hidden_states.dtype) + ) return ouput else: - moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float - expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E] - sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1] - sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device) - - torch.ops._C.gen_block_statistic(topk_ids,block_statistic) + # from vllm.forward_context import get_forward_context + # forward_context = get_forward_context() + # attn_metadata: AttentionMetadata = forward_context.attn_metadata + # prefix = "model.layers.0.linear_attn" + # if attn_metadata is not None: + # attn_metadata = attn_metadata[prefix] - torch.ops._C.moe_pre_sorted( - x=hidden_states, - topk_index=topk_ids, - block_statistic=block_statistic, - moe_expand=moe_expand, - moe_index=sorted_tokens_idx, - expert_m=expert_m, - sorted_tokens_num_lod=sorted_tokens_num_lod) + # if attn_metadata is None or attn_metadata.num_prefills > 0 or : + if M * moe_top_k < 400: + sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = ( + torch.ops.xspeedgate_ops.moe_pre_small( + topk_ids, global_num_experts, False, False, hidden_states + ) + ) + experts_num_lod = torch.ops.xspeedgate_ops.moe_active_expert_balance( + topk_ids, global_num_experts, False + ) + out = torch.ops.xspeedgate_ops.fused_moe( + hidden_states, + w1, + w2, + normed_score.to(hidden_states.dtype), + sorted_tokens_num_lod, + sorted_tokens_idx, + experts_num_lod, + ) + return out.sum(1) - y = torch.empty(M,moe_top_k, - w1.shape[1], + if M * moe_top_k > 768: + moe_expand = torch.empty( + (M * moe_top_k, N), dtype=hidden_states.dtype, - device=hidden_states.device) + device=hidden_states.device, + ) # [M*top_k, N], float + expert_m = torch.zeros( + global_num_experts, dtype=torch.int32, device=hidden_states.device + ) # [E] + 
sorted_tokens_num_lod = torch.zeros( + global_num_experts + 1, + dtype=torch.int32, + device=hidden_states.device, + ) # [E+1] + sorted_tokens_idx = torch.zeros( + M * moe_top_k, dtype=torch.int32, device=hidden_states.device + ) + + torch.ops._C.gen_block_statistic(topk_ids, block_statistic) + + torch.ops._C.moe_pre_sorted( + x=hidden_states, + topk_index=topk_ids, + block_statistic=block_statistic, + moe_expand=moe_expand, + moe_index=sorted_tokens_idx, + expert_m=expert_m, + sorted_tokens_num_lod=sorted_tokens_num_lod, + ) + else: + sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = ( + torch.ops.xspeedgate_ops.moe_pre_small( + topk_ids, + global_num_experts, + index_have_neg=False, + sort_mode=True, + x=hidden_states, + ) + ) + + y = torch.empty( + M, + moe_top_k, + w1.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) moe_expand = moe_expand.view(M * moe_top_k, hidden_dim) - torch.ops._C.moe_fc( - x=moe_expand, - weight=w1, - sorted_tokens_num_lod=sorted_tokens_num_lod, - sorted_tokens_idx=sorted_tokens_idx, - moe_topk=moe_top_k, - y=y, + if M < 1024: + torch.ops._C.moe_fc( + x=moe_expand, + weight=w1, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=moe_top_k, + y=y, + ) + + d = y.shape[-1] // 2 + output_shape = y.shape[:-1] + (d,) + out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) + torch.ops._C.silu_and_mul(out1, y) + + out1 = out1.reshape(-1, out1.shape[-1]) + else: + torch.ops._C.moe_fc( + x=moe_expand, + weight=w1, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=moe_top_k, + y=y, + act="SWISH_GLU", + ) + + y = y[..., : y.shape[-1] // 2] + out1 = y.reshape(-1, y.shape[-1]) + + out = torch.empty( + M, + moe_top_k, + w2.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, ) - d = y.shape[-1] // 2 - output_shape = (y.shape[:-1] + (d, )) - out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) - 
torch.ops._C.silu_and_mul(out1, y) - - out = torch.empty(M,moe_top_k, - w2.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device) - - out1 = out1.reshape(-1, out1.shape[-1]) - torch.ops._C.moe_fc( x=out1, weight=w2, @@ -538,8 +589,12 @@ class KunlunOps: y=out, ) - dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device) - output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device) + dequant_scale = torch.ones( + [M, moe_top_k], dtype=torch.float32, device=out.device + ) + output = torch.empty( + [M, N], dtype=hidden_states.dtype, device=hidden_states.device + ) sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k) torch.ops._C.moe_post( @@ -547,9 +602,9 @@ class KunlunOps: moe_index=sorted_tokens_idx, normed_scale=normed_score, dequant_scale=dequant_scale, - y=output + y=output, ) - + return output @staticmethod @@ -568,23 +623,23 @@ class KunlunOps: topk_group: Optional[int] = None, w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> torch.Tensor: x = hidden_states - batch, hidden_size = x.shape + batch, hidden_size = x.shape num_local_experts, up_gate_size, _ = w13_weight.shape - router_logits = x.to(linear_weights.dtype)@linear_weights.T - - topk_weights = torch.empty(batch, - top_k, - dtype=router_logits.dtype, - device=router_logits.device) - topk_ids = torch.empty(batch, - top_k, - dtype=torch.int32, - device=router_logits.device) - block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device) - torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static) + router_logits = x.to(linear_weights.dtype) @ linear_weights.T + + topk_weights = torch.empty( + batch, top_k, dtype=router_logits.dtype, device=router_logits.device + ) + topk_ids = torch.empty( + batch, top_k, dtype=torch.int32, device=router_logits.device + ) + block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device) + 
torch.ops._C.moe_softmax_topk( + router_logits, topk_weights, topk_ids, block_static + ) if renormalize: topk_weights = topk_weights / topk_weights.sum(1, keepdim=True) @@ -598,11 +653,19 @@ class KunlunOps: selected_token = topk_ids_flat == experts_id if selected_token.sum(): cur_token = repeat_x[selected_token] - up_gate = torch.empty(selected_token.sum(), up_gate_size//2, - dtype=cur_token.dtype, device=cur_token.device) - torch.ops._C.silu_and_mul(up_gate, cur_token@ w13_weight[i].T) + up_gate = torch.empty( + selected_token.sum(), + up_gate_size // 2, + dtype=cur_token.dtype, + device=cur_token.device, + ) + torch.ops._C.silu_and_mul(up_gate, cur_token @ w13_weight[i].T) out[selected_token] = up_gate @ w2_weight[i].T - output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype) + output = ( + (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)) + .sum(dim=1) + .to(x.dtype) + ) return output @@ -638,10 +701,11 @@ class KunlunOps: prompt_lods_cpu: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, - ) -> torch.Tensor: + ) -> torch.Tensor: """mla pa block""" - output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, - device=hidden_states.device) + output = torch.empty( + hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device + ) kunlun_ops.xft_multi_head_latent_page_attention_block( hidden_states, q_lora_rank, @@ -679,7 +743,6 @@ class KunlunOps: ) return output - def fused_gdn_gating( A_log: torch.Tensor, a: torch.Tensor, @@ -695,25 +758,34 @@ class KunlunOps: ) return output - def fused_recurrent_gated_delta_rule_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - h0_source: torch.Tensor, - output_final_state: bool, - use_qk_l2norm_in_kernel: bool, - cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]: - ''' - Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 
将做完sigmoid_gating和delta_rule_update融合在一起 - 1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)。 - 2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。 - ''' + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + h0_source: torch.Tensor, + output_final_state: bool, + use_qk_l2norm_in_kernel: bool, + cu_seqlens: torch.Tensor = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起 + 1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)。 + 2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。 + """ - o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd( - q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel, - cu_seqlens) - return (o, final_state) \ No newline at end of file + o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd( + q, + k, + v, + g, + beta, + scale, + h0_source, + output_final_state, + use_qk_l2norm_in_kernel, + cu_seqlens, + ) + return (o, final_state) diff --git a/vllm_kunlun/ops/fla/chunk.py b/vllm_kunlun/ops/fla/chunk.py index 20a039f..f2f43c0 100644 --- a/vllm_kunlun/ops/fla/chunk.py +++ b/vllm_kunlun/ops/fla/chunk.py @@ -9,60 +9,196 @@ # ruff: noqa: E501 import warnings from typing import Optional -import torch.nn.functional as F +import cocopod # noqa import torch -import torch.distributed as dist +import torch.nn.functional as F from einops import rearrange -from .chunk_delta_h import chunk_gated_delta_rule_fwd_h -from .chunk_o import chunk_fwd_o -from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd -from .cumsum import chunk_local_cumsum +from .index import prepare_chunk_indices, prepare_chunk_offsets from .l2norm import l2norm_fwd -from .solve_tril import solve_tril from .utils import SUPPRESS_LEVEL, input_guard -from .wy_fast import recompute_w_u_fwd -from .index import prepare_chunk_indices -import xspeedgate_ops -import 
cocopod -def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,): - chunk_size=64 - A = -A.transpose(1,2) +def torch_solve_tril( + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + output_dtype: torch.dtype = torch.float, +): + chunk_size = 64 + A = -A.transpose(1, 2) sequence_length = A.shape[-2] pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size A = F.pad(A, (0, 0, 0, pad_size)) A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1]) + # mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0) + # A = A.masked_fill(mask, 0) for i in range(1, chunk_size): row = A[..., i, :i].clone() sub = A[..., :i, :i].clone() A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device) - return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2) + return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[ + :, :, :sequence_length, : + ].transpose(1, 2) -def chunk_gated_delta_rule_fwd(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None): - g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) - A = chunk_scaled_dot_kkt_fwd(k=k, - beta=beta, - g_cumsum=g, - cu_seqlens=cu_seqlens, - output_dtype=q.dtype) - #kernel版 - torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens) - chunk_indices = prepare_chunk_indices( - cu_seqlens, 64) if cu_seqlens is not None else None +def recompute_w_u_fwd_torch( + k: torch.Tensor, # [B, T, H, K] + v: torch.Tensor, # [B, T, H, V] + beta: torch.Tensor, # [B, T, H] + g: torch.Tensor, # [B, T, H] + A: torch.Tensor, # [B, H, T, T] +): + """ + 最简单版本:假设等长序列,key和value头数相同 + """ + chunk_size = 64 + num_v_heads, num_k_heads = 
v.shape[2], k.shape[2] + k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) + k, v, beta, g, A = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A) + ] + + batch_size, num_heads, sequence_length, k_head_dim = k.shape + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + k = F.pad(k, (0, 0, 0, pad_size)) + v = F.pad(v, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + A = F.pad(A, (0, 0, 0, pad_size)) + A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1]) + + v_beta = v * beta.unsqueeze(-1) + k_beta = k * beta.unsqueeze(-1) + + k, v, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (k, v, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + + u = A @ v_beta + w = A @ (k_beta * g.exp().unsqueeze(-1)) + w = ( + w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :] + .transpose(1, 2) + .contiguous() + ) + u = ( + u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :] + .transpose(1, 2) + .contiguous() + ) + + return w, u + + +def split_by_value(tensor, chunk_size=64): + indices = tensor.tolist() + result = set(indices) # 使用集合避免重复 + + for i in range(len(indices) - 1): + start = indices[i] + end = indices[i + 1] + + # 计算第一个对齐边界 + # 我们要找的是 start + n*chunk_size,其中n是使结果大于start的最小整数 + first_boundary = start + chunk_size + + # 在(start, end)范围内插入所有对齐边界 + boundary = first_boundary + while boundary < end: + result.add(boundary) + boundary += chunk_size + + return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device) + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +): + chunk_size = 64 + chunk_indices = ( + 
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None + ) + chunk_offsets = ( + prepare_chunk_offsets(cu_seqlens, chunk_size) + if cu_seqlens is not None + else None + ) + + # ! + # g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + g = torch.ops.xspeedgate_ops.chunk_local_cumsum( + g, + chunk_size=64, + reverse=False, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + head_first=False, + ) + + # ! + # A = chunk_scaled_dot_kkt_fwd(k=k, + # beta=beta, + # g_cumsum=g, + # cu_seqlens=cu_seqlens, + # output_dtype=q.dtype) + A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd( + k, beta, g, cu_seqlens, chunk_indices, chunk_size + ) + + # torch版 + # if get_tensor_model_parallel_rank() == 0: + # torch.save(A, "A_in") + # torch.save(cu_seqlens, "cu_seqlens") + # A2 = A.clone() + torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size) + + # ! + # torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens) + # if get_tensor_model_parallel_rank() == 0: + # err = torch.max(torch.abs(A - A2)) + # print("err", err) + # if err > 1e-3: + # raise + # A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + # for i in range(len(cu_seqlens)-1): + # A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] + # A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype) + + """ + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + for i in range(len(cu_seqlens)-1): + k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] + v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] + beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :] + A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] + g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :] + + w_i, u_i = recompute_w_u_fwd_torch( + k=k_i, + v=v_i, + beta=beta_i, + A=A_i, + g=g_i, + ) + w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i + 
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i + """ w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd( k=k, v=v, @@ -71,17 +207,63 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor, g_cumsum=g, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, - chunk_size=64 + chunk_size=64, ) - h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + """ + w, u = recompute_w_u_fwd( k=k, - w=w, - u=u, - g=g, - initial_state=initial_state, - output_final_state=output_final_state, + v=v, + beta=beta, + A=A, + g_cumsum=g, cu_seqlens=cu_seqlens, ) + """ + + # i + # import os + # if not os.path.exists("/qwen-next/in"): + # os.makedirs("/qwen-next/in") + # torch.save(k, "/qwen-next/in/k.pt") + # torch.save(u, "/qwen-next/in/u.pt") + # torch.save(w, "/qwen-next/in/w.pt") + # torch.save(g, "/qwen-next/in/g.pt") + # torch.save(initial_state, "/qwen-next/in/initial_state.pt") + # torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt") + # torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt") + # torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt") + # torch.save(chunk_size, "/qwen-next/in/chunk_size.pt") + # torch.save(output_final_state, "/qwen-next/in/output_final_state.pt") + + h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h( + k, + u, + w, + g, + initial_state, + cu_seqlens, + chunk_indices, + chunk_offsets.to(torch.int32), + chunk_size, + output_final_state, + True, + ) + + # h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + # k=k, + # w=w, + # u=u, + # g=g, + # initial_state=initial_state, + # output_final_state=output_final_state, + # cu_seqlens=cu_seqlens, + # ) + # if not os.path.exists("/qwen-next/out"): + # os.makedirs("/qwen-next/out") + # torch.save(h, "/qwen-next/out/h.pt") + # torch.save(v_new, "/qwen-next/out/v_new.pt") + # torch.save(final_state, "/qwen-next/out/final_state.pt") + o = torch.ops.xspeedgate_ops.chunk_fwd_o( q=q, k=k, @@ -91,8 +273,19 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor, 
scale=scale, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, - chunk_size=64 + chunk_size=64, ) + """ + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + """ if SUPPRESS_LEVEL < 3: return g, o, A, final_state, None, None, None elif SUPPRESS_LEVEL >= 3: @@ -103,18 +296,20 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function): @staticmethod @input_guard - @torch.amp.custom_fwd(device_type='cuda') - def forward(ctx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, - use_qk_l2norm_in_kernel: bool = False): + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): if use_qk_l2norm_in_kernel: q = l2norm_fwd(q) k = l2norm_fwd(k) @@ -136,17 +331,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function): @torch.compiler.disable -def chunk_gated_delta_rule(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float = None, - initial_state: torch.Tensor = None, - output_final_state: bool = False, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False, - use_qk_l2norm_in_kernel: bool = False): +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False, +): r""" Args: q (torch.Tensor): @@ -211,42 +408,85 @@ def 
chunk_gated_delta_rule(q: torch.Tensor, ) """ assert q.dtype == k.dtype == v.dtype - assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." - assert len( - beta.shape - ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + assert ( + q.dtype != torch.float32 + ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert ( + len(beta.shape) == 3 + ), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." if head_first: raise DeprecationWarning( "head_first is deprecated and will be removed in a future version. " "Please use head_first=False for now instead.", - stacklevel=2) + stacklevel=2, + ) q, k, v, beta, g = map( - lambda x: rearrange(x, 'b h t ... -> b t h ...'), - (q, k, v, beta, g)) + lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g) + ) if not head_first and q.shape[1] < q.shape[2]: warnings.warn( f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " "Please verify your input tensor format matches the expected shape [B, T, H, ...].", - stacklevel=2) + stacklevel=2, + ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") - if initial_state is not None and initial_state.shape[0] != len( - cu_seqlens) - 1: + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: raise ValueError( f"The number of initial states is expected to be equal to the number of input sequences, " f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." 
) if scale is None: - scale = k.shape[-1]**-0.5 - o, final_state = ChunkGatedDeltaRuleFunction.apply( - q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, - use_qk_l2norm_in_kernel) + scale = k.shape[-1] ** -0.5 + + if False: + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + g = g.contiguous() + beta = beta.contiguous() + initial_state = initial_state.contiguous() + + o = torch.empty_like(v) + final_state = torch.empty_like(initial_state) + import kunlun_ops + + kunlun_ops.gated_delta_rule( + q, + k, + v, + initial_state, + g, + beta, + final_state, + o, + scale, + cu_seqlens.cpu(), + cu_seqlens, + cu_seqlens.cpu(), + cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + else: + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + ) if head_first: - o = rearrange(o, 'b t h ... -> b h t ...') + o = rearrange(o, "b t h ... -> b h t ...") return o, final_state diff --git a/vllm_kunlun/ops/fla/chunk_o.py b/vllm_kunlun/ops/fla/chunk_o.py index f861ffc..5eb7260 100644 --- a/vllm_kunlun/ops/fla/chunk_o.py +++ b/vllm_kunlun/ops/fla/chunk_o.py @@ -12,21 +12,21 @@ from typing import Optional import torch - from vllm.triton_utils import tl, triton from .index import prepare_chunk_indices -from .op import exp from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] -@triton.heuristics({ - 'USE_G': lambda args: args['g'] is not None, - 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None -}) +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) # @triton.autotune( # configs=[ # triton.Config({ @@ -40,7 +40,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] # ], # key=['H', 'K', 'V', 'BT'], # ) 
-@triton.jit(do_not_specialize=['T']) +@triton.jit(do_not_specialize=["T"]) def chunk_fwd_kernel_o( q, k, @@ -67,10 +67,12 @@ def chunk_fwd_kernel_o( if IS_VARLEN: i_tg = i_t - i_n, i_t = tl.load(chunk_indices + i_t * 2).to( - tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load( + chunk_indices + i_t * 2 + 1 + ).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load( + cu_seqlens + i_n + 1 + ).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) else: @@ -89,12 +91,15 @@ def chunk_fwd_kernel_o( b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), - (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), - (BK, BT), (0, 1)) - p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), - (BK, BV), (1, 0)) + p_q = tl.make_block_ptr( + q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_k = tl.make_block_ptr( + k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1) + ) + p_h = tl.make_block_ptr( + h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0) + ) # [BT, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BT] @@ -109,8 +114,8 @@ def chunk_fwd_kernel_o( if USE_G: g += bos * H + i_h - p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) - b_g = tl.load(p_g, boundary_check=(0, )) + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) b_o = b_o * tl.exp(b_g)[:, None] b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :]) @@ -120,10 +125,12 @@ def chunk_fwd_kernel_o( # b_A = tl.where(m_A, b_A, 0) b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0) - p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), - (BT, BV), (1, 0)) - p_o = 
tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), - (BT, BV), (1, 0)) + p_v = tl.make_block_ptr( + v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + p_o = tl.make_block_ptr( + o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) b_v = tl.load(p_v, boundary_check=(0, 1)) # to fix mma -> mma layout conversion @@ -133,48 +140,29 @@ def chunk_fwd_kernel_o( def chunk_fwd_o( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - h: torch.Tensor, - g: Optional[torch.Tensor] = None, # cumsum of log decay - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64) -> torch.Tensor: - B, T, Hg, K, V = *q.shape, v.shape[-1] - H = v.shape[-2] + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, # cumsum of log decay + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: + _, T, _, _, _ = *q.shape, v.shape[-1] if FLA_GDN_FIX_BT: BT = 64 else: BT = min(chunk_size, max(16, triton.next_power_of_2(T))) - chunk_indices = prepare_chunk_indices( - cu_seqlens, BT) if cu_seqlens is not None else None - NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 o = torch.empty_like(v) - def grid(meta): - return (triton.cdiv(V, meta['BV']), NT, B * H) - - chunk_fwd_kernel_o[grid]( - q, - k, - v, - h, - g, - o, - cu_seqlens, - chunk_indices, - scale, - T=T, - H=H, - Hg=Hg, - K=K, - V=V, - BT=BT, - BK=64, - BV=32 + o = torch.ops.xspeedgate_ops.chunk_fwd_o( + q, k, v, h, g, scale, cu_seqlens, chunk_indices, chunk_size ) return o diff --git a/vllm_kunlun/ops/fla/fused_recurrent.py b/vllm_kunlun/ops/fla/fused_recurrent.py index d3c81b1..a6acdf0 100644 --- 
a/vllm_kunlun/ops/fla/fused_recurrent.py +++ b/vllm_kunlun/ops/fla/fused_recurrent.py @@ -9,28 +9,28 @@ # ruff: noqa: E501 from typing import Optional -import torch - import kunlun_ops +import torch class FusedRecurrentFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - use_qk_l2norm_in_kernel: bool = False): - + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2( q.contiguous(), k.contiguous(), @@ -44,7 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function): h0_indices=ssm_state_indices, num_accepted_tokens=num_accepted_tokens, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - is_h0_transposed=True + is_h0_transposed=True, ) return o, final_state @@ -130,9 +130,10 @@ def fused_recurrent_gated_delta_rule( if cu_seqlens is not None and q.shape[0] != 1: raise ValueError( f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + f"Please flatten variable-length inputs before processing." 
+ ) if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 else: assert scale > 0, "scale must be positive" if beta is None: diff --git a/vllm_kunlun/ops/fla/l2norm.py b/vllm_kunlun/ops/fla/l2norm.py index 55cc061..d3dcc62 100644 --- a/vllm_kunlun/ops/fla/l2norm.py +++ b/vllm_kunlun/ops/fla/l2norm.py @@ -10,22 +10,21 @@ import os from typing import Optional +import kunlun_ops import torch from vllm.triton_utils import tl, triton -import kunlun_ops - - BT_LIST = [8, 16, 32, 64, 128] USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) -@triton.autotune(configs=[ - triton.Config({}, num_warps=num_warps) - for num_warps in [1, 2, 4, 8, 16, 32] -], - key=['D']) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32] + ], + key=["D"], +) @triton.jit def l2norm_fwd_kernel1( x, @@ -49,11 +48,14 @@ def l2norm_fwd_kernel1( tl.store(y + cols, b_y, mask=mask) -@triton.autotune(configs=[ - triton.Config({'BT': BT}, num_warps=num_warps) - for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST -], - key=['D']) +@triton.autotune( + configs=[ + triton.Config({"BT": BT}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + for BT in BT_LIST + ], + key=["D"], +) @triton.jit(do_not_specialize=["NB"]) def l2norm_fwd_kernel( x, @@ -87,67 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) -def l2norm_fwd_triton(x: torch.Tensor, - eps: float = 1e-6, - output_dtype: Optional[torch.dtype] = None): - x_shape_og = x.shape - x = x.view(-1, x.shape[-1]) - # allocate output - if output_dtype is None: - y = torch.empty_like(x) - else: - y = torch.empty_like(x, dtype=output_dtype) - assert y.stride(-1) == 1 - T, D = x.shape[0], x.shape[-1] - # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BD = 
min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) - if D > BD: - raise RuntimeError("This layer doesn't support feature dim >= 64KB.") - - if not USE_DEFAULT_FLA_NORM: - MBLOCK = 32 - # M, N = x.shape - l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )]( - x, - y, - eps, - T, - D, - MBLOCK, - ) - else: - if D <= 512: - NB = triton.cdiv(T, 2048) - - def grid(meta): - return (triton.cdiv(T, meta['BT']), ) - - l2norm_fwd_kernel[grid]( - x, - y, - eps, - NB=NB, - T=T, - D=D, - BD=BD, - ) - else: - l2norm_fwd_kernel1[(T, )]( - x, - y, - eps=eps, - D=D, - BD=BD, - ) - - return y.view(x_shape_og) - - -def l2norm_fwd(x: torch.Tensor, - eps: float = 1e-6, - output_dtype: Optional[torch.dtype] = None): +def l2norm_fwd( + x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None +): out = torch.empty_like(x) - kunlun_ops.l2norm(x, out, eps) + kunlun_ops.l2norm(x, out, eps) return out diff --git a/vllm_kunlun/ops/fla/layernorm_guard.py b/vllm_kunlun/ops/fla/layernorm_guard.py index a6a5f43..b83c728 100644 --- a/vllm_kunlun/ops/fla/layernorm_guard.py +++ b/vllm_kunlun/ops/fla/layernorm_guard.py @@ -19,20 +19,21 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange - from vllm.triton_utils import tl, triton from .utils import input_guard -def rms_norm_ref(x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - upcast=True): +def rms_norm_ref( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + upcast=True, +): dtype = x.dtype weight = weight.float() bias = bias.float() if bias is not None else None @@ -43,12 +44,10 @@ def rms_norm_ref(x, x = x * F.silu(z) if group_size is None: rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) - out = (x * rstd * weight) + bias if bias is not None else (x * rstd * - weight) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) else: x_group = rearrange(x, "... (g d) -> ... 
g d", d=group_size) - rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + - eps) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight if bias is not None: out = out + bias @@ -57,10 +56,12 @@ def rms_norm_ref(x, return out.to(dtype) -@triton.heuristics({ - "HAS_BIAS": lambda args: args["B"] is not None, - "HAS_Z": lambda args: args["Z"] is not None, -}) +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, + } +) @triton.jit def layer_norm_fwd_kernel( X, # pointer to the input @@ -97,17 +98,17 @@ def layer_norm_fwd_kernel( B += group * N # Compute mean and variance cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_Z and not NORM_BEFORE_GATE: z = tl.load(Z + cols, mask=cols < N).to(tl.float32) x *= z * tl.sigmoid(z) if not IS_RMS_NORM: mean = tl.sum(x, axis=0) / N tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.) + xbar = tl.where(cols < N, x - mean, 0.0) var = tl.sum(xbar * xbar, axis=0) / N else: - xbar = tl.where(cols < N, x, 0.) 
+ xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) tl.store(Rstd + row, rstd) @@ -149,46 +150,50 @@ def layer_norm_fwd( # weight = weight.reshape(N) # print("weight",weight.shape) # print("x",x.shape) - assert weight.shape == (N, ) + assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 - assert bias.shape == (N, ) + assert bias.shape == (N,) # allocate output if out is not None: assert out.shape == x.shape else: out = torch.empty_like(x) assert out.stride(-1) == 1 - mean = torch.empty((ngroups * M, ), dtype=torch.float32, - device=x.device) if not is_rms_norm else None - rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + mean = ( + torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) if group_size > BLOCK_N: - raise RuntimeError( - "This layer norm doesn't support feature dim >= 64KB.") + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) - layer_norm_fwd_kernel[grid](x, - out, - weight, - bias, - z, - mean, - rstd, - x.stride(0), - out.stride(0), - z.stride(0) if z is not None else 0, - M, - group_size, - eps, - BLOCK_N=BLOCK_N, - NORM_BEFORE_GATE=norm_before_gate, - IS_RMS_NORM=is_rms_norm, - num_warps=num_warps) + layer_norm_fwd_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) return out, mean, rstd @@ -196,17 
+201,18 @@ class LayerNormFn(torch.autograd.Function): @input_guard @staticmethod - def forward(ctx, - x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - is_rms_norm=False): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) - """ + def forward( + ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, + ): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" x_shape_og = x.shape # reshape input data into 2D tensor @@ -223,16 +229,15 @@ class LayerNormFn(torch.autograd.Function): weight = weight.contiguous() if bias is not None: bias = bias.contiguous() - y, mean, rstd = layer_norm_fwd( - x, - weight, - bias, - eps, - z=z, - group_size=group_size, - norm_before_gate=norm_before_gate, - is_rms_norm=is_rms_norm, + # y, mean, rstd = torch.ops.xspeedgate_ops.rms_norm_gated_fwd(x, weight, bias, eps, z, group_size, norm_before_gate, is_rms_norm) + y = torch.empty_like(x) + mean, rstd = None, None + import kunlun_ops + + kunlun_ops.rms_norm_gated( + x, y, z, weight, eps, group_size, norm_before_gate, is_rms_norm ) + ctx.save_for_backward(x, weight, bias, mean, rstd, z) ctx.x_shape_og = x_shape_og ctx.eps = eps @@ -242,27 +247,27 @@ class LayerNormFn(torch.autograd.Function): return y.reshape(x_shape_og) -def layernorm_fn(x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - is_rms_norm=False): - return LayerNormFn.apply(x, weight, bias, z, eps, group_size, - norm_before_gate, is_rms_norm) +def layernorm_fn( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + return LayerNormFn.apply( + x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm + ) -def rmsnorm_fn(x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True): - return LayerNormFn.apply(x, weight, bias, 
z, eps, group_size, - norm_before_gate, True) +def rmsnorm_fn( + x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True +): + return LayerNormFn.apply( + x, weight, bias, z, eps, group_size, norm_before_gate, True + ) class LayerNormGated(nn.Module): @@ -294,15 +299,16 @@ class LayerNormGated(nn.Module): torch.nn.init.zeros_(self.bias) def forward(self, x, z=None): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) - """ - return layernorm_fn(x, - self.weight, - self.bias, - z=z, - group_size=self.group_size, - eps=self.eps, - norm_before_gate=self.norm_before_gate) + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return layernorm_fn( + x, + self.weight, + self.bias, + z=z, + group_size=self.group_size, + eps=self.eps, + norm_before_gate=self.norm_before_gate, + ) class RMSNormGated(nn.Module): @@ -332,12 +338,13 @@ class RMSNormGated(nn.Module): torch.nn.init.ones_(self.weight) def forward(self, x, z=None): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) - """ - return rmsnorm_fn(x, - self.weight, - self.bias, - z=z, - eps=self.eps, - group_size=self.group_size, - norm_before_gate=self.norm_before_gate) + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return rmsnorm_fn( + x, + self.weight, + self.bias, + z=z, + eps=self.eps, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + ) diff --git a/vllm_kunlun/ops/fla/wy_fast.py b/vllm_kunlun/ops/fla/wy_fast.py index f3b3b7a..909a47e 100644 --- a/vllm_kunlun/ops/fla/wy_fast.py +++ b/vllm_kunlun/ops/fla/wy_fast.py @@ -11,7 +11,6 @@ from typing import Optional import torch - from vllm.triton_utils import tl, triton from .index import prepare_chunk_indices @@ -28,6 +27,7 @@ RESOLUTION = { torch.complex64: 1.3e-6, } + def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1): assert res.dtype == dtype ref 
= ref.to(dtype) @@ -35,6 +35,7 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1): rtol = RESOLUTION[dtype] torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan) + @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) # @triton.autotune( # configs=[ @@ -80,7 +81,6 @@ def recompute_u_fwd_kernel( p_beta = tl.make_block_ptr( beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) ) - p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) p_A = tl.make_block_ptr( A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0) ) @@ -110,7 +110,6 @@ def recompute_u_fwd_kernel( tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) - @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) # @triton.autotune( # configs=[ @@ -195,53 +194,12 @@ def recompute_w_u_fwd( A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor], ) -> tuple[torch.Tensor, torch.Tensor]: - B, T, Hg, K, V = *k.shape, v.shape[-1] - H = v.shape[-2] BT = A.shape[-1] - chunk_indices = prepare_chunk_indices( - cu_seqlens, BT) if cu_seqlens is not None else None - NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - BK = 64 - BV = 64 - u = torch.empty_like(v) - w = k.new_empty(B, T, H, K) - recompute_u_fwd_kernel[(NT, B * H)]( - k=k, - v=v, - beta=beta, - w=w, - u=u, - A=A, - g=g_cumsum, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - T=T, - H=H, - Hg=Hg, - K=K, - V=V, - BT=BT, - BK=BK, - BV=BV, + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None ) - recompute_w_fwd_kernel[(NT, B * H)]( - k=k, - v=v, - beta=beta, - w=w, - u=u, - A=A, - g=g_cumsum, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - T=T, - H=H, - Hg=Hg, - K=K, - V=V, - BT=BT, - BK=BK, - BV=BV, + w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd( + k, v, beta, g_cumsum, A, cu_seqlens, chunk_indices, chunk_size=BT ) - 
return w, u \ No newline at end of file + return w, u diff --git a/vllm_kunlun/ops/layernorm.py b/vllm_kunlun/ops/layernorm.py index b65dc76..d705b0c 100644 --- a/vllm_kunlun/ops/layernorm.py +++ b/vllm_kunlun/ops/layernorm.py @@ -15,51 +15,52 @@ # This file is a part of the vllm-ascend project. # -import torch - -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm -from vllm.model_executor.layers import layernorm from typing import Optional, Union -import kunlun_ops + +import torch +from vllm.model_executor.layers import layernorm +from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm + def vllm_kunlun_forward_cuda( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """forward_cuda""" - if x.is_contiguous() == False: - # kunlun does not support uncontiguous input and they do not think it is a bug - # so we must make it contiguous() manually - x = x.contiguous() - if self.variance_size_override is not None: - return self.forward_native(x, residual) + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """forward_cuda""" + if not x.is_contiguous(): + # kunlun does not support uncontiguous input and they do not think it is a bug + # so we must make it contiguous() manually + x = x.contiguous() + if self.variance_size_override is not None: + return self.forward_native(x, residual) - - if residual is not None: - # residual_output = torch.empty_like(residual) - torch.ops._C.add_rmsnorm( - x, - residual, - residual_output=residual, - weight=self.weight.data, - eps=self.variance_epsilon, - output=x - ) - return x, residual - out = torch.empty_like(x) - torch.ops._C.rmsnorm( + if residual is not None: + # residual_output = 
torch.empty_like(residual) + torch.ops._C.add_rmsnorm( x, - self.weight.data, - out, - self.variance_epsilon, + residual, + residual_output=residual, + weight=self.weight.data, + eps=self.variance_epsilon, + output=x, ) - return out + return x, residual + out = torch.empty_like(x) + torch.ops._C.rmsnorm( + x, + self.weight.data, + out, + self.variance_epsilon, + ) + return out + RMSNorm.forward_cuda = vllm_kunlun_forward_cuda RMSNorm.forward = vllm_kunlun_forward_cuda + class KunlunGemmaRMSNorm(OriGemmaRMSNorm): @staticmethod def forward_xpu( @@ -68,30 +69,42 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm): x: torch.Tensor, residual: Optional[torch.Tensor], ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - if x.is_contiguous() == False: + if not x.is_contiguous(): # kunlun does not support uncontiguous input and they do not think it is a bug # so we must make it contiguous() manually x = x.contiguous() - + if x.dim() == 3: + x_shape = x.shape + x = x.view(-1, x.size(-1)) if residual is not None: - torch.ops._C.add_rmsnorm( + out = torch.empty_like(x) + out_residual = torch.empty_like(residual) + torch.ops._C.gemma_add_rmsnorm( x, residual, - residual_output=residual, - weight=weight+1, + residual_output=out_residual, + weight=weight, eps=variance_epsilon, - output=x + output=out, + ) + else: + out = torch.empty_like(x) + torch.ops._C.gemma_rmsnorm( + x, + weight, + out, + variance_epsilon, ) - return x, residual - out = torch.empty_like(x) - torch.ops._C.rmsnorm( - x, - weight+1, - out, - variance_epsilon, - ) - return out + if x.dim() == 3: + x = x.view(x_shape) + if out is not None: + out = out.view(x_shape) + + if residual is not None: + return out, out_residual + else: + return out def forward_cuda( self, @@ -99,16 +112,17 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm): residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if torch.compiler.is_compiling(): - self.forward_static = self.forward_xpu # only 
use in cudagraph + self.forward_static = self.forward_xpu # only use in cudagraph return self.forward_native(x, residual) if not getattr(self, "_is_compiled", False): self.forward_static = torch.compile( # type: ignore - self.forward_static, backend="aot_eager") + self.forward_static, backend="aot_eager" + ) self._is_compiled = True return self.forward_native(x, residual) RMSNorm.forward_cuda = vllm_kunlun_forward_cuda RMSNorm.forward = vllm_kunlun_forward_cuda -layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm \ No newline at end of file +layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm diff --git a/vllm_kunlun/ops/mamba/causal_conv1d.py b/vllm_kunlun/ops/mamba/causal_conv1d.py index 4bc6416..1252d27 100644 --- a/vllm_kunlun/ops/mamba/causal_conv1d.py +++ b/vllm_kunlun/ops/mamba/causal_conv1d.py @@ -6,13 +6,12 @@ from typing import Optional, Union +import kunlun_ops import numpy as np import torch import torch.nn.functional as F - from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.triton_utils import tl, triton -import kunlun_ops @triton.jit() @@ -36,8 +35,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching # Strides stride_x_seq: tl.constexpr, # stride to get to next sequence, stride_x_dim: tl.constexpr, # stride to get to next feature-value, - stride_x_token: tl. 
- constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) stride_w_dim: tl.constexpr, # stride to get to next dim-axis value stride_w_width: tl.constexpr, # stride to get to next width-axis value stride_istate_seq: tl.constexpr, @@ -59,14 +57,15 @@ def _causal_conv1d_fwd_kernel( # continuous batching NP2_STATELEN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - ): conv_states_ptr = initial_states_ptr conv_state_indices_ptr = cache_indices_ptr stride_conv_state_seq = stride_istate_seq stride_conv_state_dim = stride_istate_dim stride_conv_state_tok = stride_istate_token - state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value + state_len = ( + KERNEL_WIDTH - 1 + ) # can be passed via argument if it's not the same as this value # one program handles one chunk in a single sequence # rather than mixing sequences - to make updating initial_states across sequences efficiently @@ -90,12 +89,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching segment_len = min(BLOCK_M, seqlen - token_offset) # base of the sequence - x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] + x_base = ( + x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim + ) # [BLOCK_N,] if IS_CONTINUOUS_BATCHING: # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to( - tl.int64) + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64) else: # cache_idx conv_state_batch_coord = idx_seq @@ -103,9 +103,11 @@ def _causal_conv1d_fwd_kernel( # continuous batching if conv_state_batch_coord == pad_slot_id: # not processing as this is not the actual sequence return - conv_states_base = (conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + 
conv_states_base = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] @@ -116,12 +118,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching # read from conv_states load_init_state = False if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES - load_init_state = tl.load(has_initial_states_ptr + idx_seq).to( - tl.int1) + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) if load_init_state: # load from conv_states - prior_tokens = conv_states_base + (state_len - - 1) * stride_conv_state_tok + prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok mask_w = idx_feats < dim if KERNEL_WIDTH == 2: conv_states_ptrs = prior_tokens # [BLOCK_N] @@ -151,38 +151,44 @@ def _causal_conv1d_fwd_kernel( # continuous batching # prior-tokens are zeros if KERNEL_WIDTH >= 2: # STRATEGY1 # first chunk and does not have prior-token, so just set to 0 - col0 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 3: # STRATEGY1 - col1 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 4: # STRATEGY1 - col2 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 5: # STRATEGY1 - col3 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) # STEP 2: # here prepare data for updating conv_state - if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache) + if ( + state_len <= seqlen + ): # SMALL_CACHE=True (only move part of 'x' into conv_state cache) # just read from 'x' # copy 'x' data to conv_state # load only 'x' data (and set 0 before 'x' if seqlen < state_len) idx_tokens_last = (seqlen - 
state_len) + tl.arange( - 0, NP2_STATELEN) # [BLOCK_M] - x_ptrs = x_ptr + ( - (sequence_start_index + idx_tokens_last) * - stride_x_token)[:, None] + ( - idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] - mask_x = ((idx_tokens_last >= 0)[:, None] & - (idx_tokens_last < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + 0, NP2_STATELEN + ) # [BLOCK_M] + x_ptrs = ( + x_ptr + + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) # [BLOCK_M,BLOCK_N,] + mask_x = ( + (idx_tokens_last >= 0)[:, None] + & (idx_tokens_last < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) new_conv_state = tl.load(x_ptrs, mask_x, 0.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - conv_states_ptrs_target = conv_states_base[None, :] + ( - idx_tokens_conv * stride_conv_state_tok)[:, None] + conv_states_ptrs_target = ( + conv_states_base[None, :] + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats - < dim)[None, :] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] # tl.debug_barrier() # NOTE: use this due to bug in Triton compiler tl.store(conv_states_ptrs_target, new_conv_state, mask) @@ -192,27 +198,30 @@ def _causal_conv1d_fwd_kernel( # continuous batching idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] conv_states_ptrs_source = ( - conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, - None] + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & 
((idx_tokens_conv + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) VAL = state_len - seqlen - x_ptrs = x_base[None, :] + ( - (idx_tokens_conv - VAL) * - stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + x_ptrs = ( + x_base[None, :] + + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] - mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & - (idx_tokens_conv - VAL < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) # tl.debug_barrier( @@ -220,11 +229,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching new_conv_state = tl.where( mask, conv_state, loaded_x ) # BUG in 'tl.where' which requires a barrier before this - conv_states_ptrs_target = conv_states_base + ( - idx_tokens_conv * - stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv - < state_len)[:, None] & (idx_feats < dim)[None, :] + conv_states_ptrs_target = ( + conv_states_base + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ + None, : + ] tl.store(conv_states_ptrs_target, new_conv_state, mask) else: # load_init_state == False # update conv_state by shifting left, BUT @@ -233,21 +244,25 @@ def _causal_conv1d_fwd_kernel( # continuous batching VAL = state_len - seqlen - x_ptrs = x_base[None, :] + ( - (idx_tokens_conv - VAL) * - stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + x_ptrs = ( + x_base[None, :] + + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] 
- mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & - (idx_tokens_conv - VAL < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index new_conv_state = tl.load(x_ptrs, mask_x, 0.0) - conv_states_ptrs_target = conv_states_base + ( - idx_tokens_conv * - stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv - < state_len)[:, None] & (idx_feats < dim)[None, :] + conv_states_ptrs_target = ( + conv_states_base + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ + None, : + ] tl.store(conv_states_ptrs_target, new_conv_state, mask) else: # chunk_offset > 0 @@ -257,37 +272,38 @@ def _causal_conv1d_fwd_kernel( # continuous batching mask_w = idx_feats < dim if KERNEL_WIDTH == 2: conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 3: conv_states_ptrs = prior_tokens # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 4: conv_states_ptrs = prior_tokens # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col1 = tl.load(conv_states_ptrs, mask_w, 
0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 5: # ruff: noqa: F841 conv_states_ptrs = prior_tokens # [BLOCK_N] - col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) x_base_1d = x_base + token_offset * stride_x_token # starting of chunk @@ -352,9 +368,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching if SILU_ACTIVATION: acc = acc / (1 + tl.exp(-acc)) mask_1d = (idx_token < segment_len) & ( - idx_feats < dim) # token-index # feature-index - o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token - ) * stride_o_token + (idx_feats * stride_o_dim) + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + + (sequence_start_index + token_offset + idx_token) * stride_o_token + + (idx_feats * 
stride_o_dim) + ) tl.store(o_ptrs, acc, mask=mask_1d) @@ -428,21 +448,15 @@ def causal_conv1d_fn_triton( batch_ptr = metadata.batch_ptr token_chunk_offset_ptr = metadata.token_chunk_offset_ptr else: - seqlens = query_start_loc.diff().to('cpu') + seqlens = query_start_loc.diff().to("cpu") args = seqlens MAX_NUM_PROGRAMS = 1024 batch_ptr = torch.full( - (MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=x.device + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device ) # tracking which seq-idx the Triton program is handling token_chunk_offset_ptr = torch.full( - (MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=x.device + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device ) # tracking BLOCK_M-based index in the sequence the Triton program is handling is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) @@ -468,9 +482,11 @@ def causal_conv1d_fn_triton( # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] # 4. 
computation can be skipped if cache_indices[idx] == pad_slot_id num_cache_lines = conv_states.size(0) - assert (num_cache_lines == conv_states.shape[0] - and dim == conv_states.shape[1] - and width - 1 <= conv_states.shape[2]) + assert ( + num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2] + ) stride_istate_seq = conv_states.stride(0) stride_istate_dim = conv_states.stride(1) stride_istate_token = conv_states.stride(2) @@ -483,8 +499,7 @@ def causal_conv1d_fn_triton( stride_o_seq = out.stride(0) stride_o_dim = out.stride(1) stride_o_token = out.stride(2) - stride_cache_indices = cache_indices.stride( - 0) if cache_indices is not None else 0 + stride_cache_indices = cache_indices.stride(0) if cache_indices is not None else 0 if validate_data: assert x.dim() == 2 @@ -498,8 +513,10 @@ def causal_conv1d_fn_triton( assert cache_indices.dim() == 1 assert padded_batch == cache_indices.size(0) if has_initial_state is not None: - assert has_initial_state.size() == (padded_batch, ) - assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`" + assert has_initial_state.size() == (padded_batch,) + assert ( + conv_states is not None + ), "ERROR: `has_initial_state` is used, which needs also `conv_states`" assert weight.stride(1) == 1 assert (dim, width) == weight.shape assert is_channel_last, "Need to run in channel-last layout" @@ -524,44 +541,46 @@ def causal_conv1d_fn_triton( if META["batch_ptr"].nelement() < len(mlist): newlen = len(mlist) + 1 META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) - META["token_chunk_offset_ptr"].resize_(newlen).fill_( - PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) if META["batch_ptr"].nelement() >= len(mlist): - META["batch_ptr"][0:len(mlist)].copy_( - torch.from_numpy(np.array(mlist))) - META["token_chunk_offset_ptr"][0:len(mlist)].copy_( - torch.from_numpy(np.array(offsetlist))) + META["batch_ptr"][0 
: len(mlist)].copy_( + torch.from_numpy(np.array(mlist)) + ) + META["token_chunk_offset_ptr"][0 : len(mlist)].copy_( + torch.from_numpy(np.array(offsetlist)) + ) META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device) META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to( - META["x_ptr"].device) + META["x_ptr"].device + ) return tot + else: def num_program(META, nums_dict): - tot = nums_dict[META["BLOCK_M"]]['tot'] + tot = nums_dict[META["BLOCK_M"]]["tot"] - mlist = nums_dict[META["BLOCK_M"]]['mlist'] - mlist_len = nums_dict[META["BLOCK_M"]]['mlist_len'] + mlist = nums_dict[META["BLOCK_M"]]["mlist"] + mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"] - offsetlist = nums_dict[META["BLOCK_M"]]['offsetlist'] + offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"] if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None: META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"] - META["token_chunk_offset_ptr"] = nums_dict[ - META["BLOCK_M"]]["token_chunk_offset_ptr"] + META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][ + "token_chunk_offset_ptr" + ] else: if META["batch_ptr"].nelement() < mlist_len: newlen = mlist_len + 1 META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) - META["token_chunk_offset_ptr"].resize_(newlen).fill_( - PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) if META["batch_ptr"].nelement() >= mlist_len: META["batch_ptr"][0:mlist_len].copy_(mlist) - META["token_chunk_offset_ptr"][0:mlist_len].copy_( - offsetlist) + META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist) return tot def grid(META): @@ -614,64 +633,18 @@ def causal_conv1d_fn_triton( IS_CONTINUOUS_BATCHING=cache_indices is not None, USE_PAD_SLOT=pad_slot_id is not None, NP2_STATELEN=np2_statelen, - #launch_cooperative_grid=True + # launch_cooperative_grid=True BLOCK_M=8, BLOCK_N=256, num_stages=2, - groups_per_cluster = np2_statelen, - isCloseUnrollControl = True, - isCloseVectorization = True, - is_use_mask_zero = 
True + groups_per_cluster=np2_statelen, + isCloseUnrollControl=True, + isCloseVectorization=True, + is_use_mask_zero=True, ) return out -def causal_conv1d_single( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - initial_states: Optional[torch.Tensor] = None, - return_final_states: bool = False, - final_states_out: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - - if initial_states is None: - out = F.conv1d(x, - weight.unsqueeze(1), - bias, - padding=width - 1, - groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out[..., :(width - 1)].copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return (out, None) if not return_final_states else (out, final_states_out) - - def causal_conv1d_fn( x: torch.Tensor, weight: torch.Tensor, @@ -685,62 +658,82 @@ def causal_conv1d_fn( metadata=None, validate_data=False, ): - """ - x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen - sequences are concatenated from left to right for varlen - weight: (dim, width) - bias: (dim,) - query_start_loc: (batch + 1) int32 - The cumulative sequence lengths of the sequences in - the batch, used to index into sequence. prepended by 0. 
- for example: query_start_loc = torch.Tensor([0,10,16,17]), - x.shape=(dim,17) - cache_indices: (batch) int32 - indicates the corresponding state index, - like so: conv_state = conv_states[cache_indices[batch_id]] - has_initial_state: (batch) bool - indicates whether should the kernel take the current state as initial - state for the calculations - conv_states: (...,dim,width - 1) itype - updated inplace if provided - activation: either None or "silu" or "swish" - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(-1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - out_ref = [] - out_ref_b = [] - seqlens = query_start_loc[1:] - query_start_loc[:-1] - seqlens = seqlens.tolist() - splits = torch.split(x, seqlens, dim=-1) + x = x.contiguous() + out = torch.empty_like(x) + has_initial_state = has_initial_state.to(torch.int32) + dim = x.shape[-1] + cu_seqlen = x.shape[-2] + width = weight.shape[-1] + num_cache_lines = conv_states.shape[0] + state_width = conv_states.shape[-2] + batch_size = query_start_loc.shape[0] - 1 + stride = conv_states.stride()[0] - for i in range(len(seqlens)): - x_s = splits[i] - if cache_indices[i] == PAD_SLOT_ID: - continue - out_ref_b.append( - causal_conv1d_single( - x_s, - weight, + kunlun_ops.causal_conv1d_fn( + x, + out, + dim, + cu_seqlen, + weight, + width, + conv_states, + num_cache_lines, + state_width, + query_start_loc.cpu(), + query_start_loc, + batch_size, + bias, + cache_indices_cpu=cache_indices.cpu(), + cache_indices_xpu=cache_indices, + has_initial_state_cpu=has_initial_state.cpu(), + has_initial_state_xpu=has_initial_state, 
+ act="SWISH", + state_seq_stride=stride, + ) + # out = torch.nn.functional.silu(out) + return out + + +def torch_causal_conv1d_update_spec( + hidden_states, + conv_state, + weight, + bias=None, + activation=None, + conv_state_indices=None, + num_accepted_tokens=None, +): + out = torch.empty_like(hidden_states) + _, seq_len, hidden_size = hidden_states.shape + for i in range(hidden_states.shape[0]): + tmp_conv_state = conv_state[conv_state_indices[i]] + state_len = tmp_conv_state.shape[-2] + hidden_states_i = hidden_states[i] + hidden_states_new = torch.cat( + [tmp_conv_state[: (2 + num_accepted_tokens[i]), :], hidden_states_i], dim=0 + ).to(weight.dtype) + + hidden_states_new = hidden_states_new.unsqueeze(0) + + conv_state[conv_state_indices[i]] = hidden_states_new[:, -state_len:, :] + for j in range(seq_len): + if j == seq_len - 1: + hidden_states_new_j = hidden_states_new + else: + hidden_states_new_j = hidden_states_new[:, : (1 - seq_len + j)] + hidden_states_new_j = hidden_states_new_j.transpose(-1, -2).contiguous() + out_i = F.conv1d( + hidden_states_new_j, + weight.unsqueeze(1), bias, - activation=activation, - return_final_states=True, - final_states_out=conv_states[cache_indices[i]].unsqueeze(0), - initial_states=conv_states[cache_indices[i]] - if has_initial_state[i] else None)) - out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) - out_ref_tensor = torch.cat(out_ref, dim=0) - return out_ref_tensor + padding=0, + groups=hidden_size, + ) + out_i = F.silu(out_i[:, :, -1:]) + out_i = out_i.to(hidden_states.dtype).squeeze(-1).unsqueeze(0) + out[i, j] = out_i + return out.view(-1, hidden_size) @triton.jit() @@ -797,9 +790,9 @@ def _causal_conv1d_update_kernel_xpu( if IS_CONTINUOUS_BATCHING: # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices).to( - tl.int64) + conv_state_batch_coord = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + ).to(tl.int64) else: 
conv_state_batch_coord = idx_seq if USE_PAD_SLOT: # noqa @@ -821,15 +814,16 @@ def _causal_conv1d_update_kernel_xpu( # - accept 1 tokens: [history2, ..., historyM, draft1] # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] # - and so on. - conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) - - 1) + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1 else: conv_state_token_offset = 0 # STEP 1: READ init_state data - conv_states_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) + conv_states_base = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) mask_w = idx_feats < dim prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok @@ -853,35 +847,45 @@ def _causal_conv1d_update_kernel_xpu( # window manner, at each forward pass, the tokens are shift by 1, so we # load since idx_tokens + 1. 
conv_state_ptrs_source = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + - conv_state_token_offset * stride_conv_state_tok + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * - stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[ + :, None + ] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) VAL = state_len - seqlen - x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim - ) # [BLOCK_N] + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim) # [BLOCK_N] - x_ptrs = x_base[None, :] + ( - (idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + x_ptrs = ( + x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] - mask_x = ((idx_tokens - VAL >= 0)[:, None] & - (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) new_conv_state = tl.where(mask, conv_state, loaded_x) - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = conv_state_base + ( - idx_tokens * stride_conv_state_tok)[:, 
None] # [BLOCK_M, BLOCK_N] + conv_state_base = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + conv_state_ptrs_target = ( + conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] tl.store(conv_state_ptrs_target, new_conv_state, mask) @@ -889,10 +893,11 @@ def _causal_conv1d_update_kernel_xpu( if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) # STEP 4: # PRE-LOAD WEIGHTS @@ -960,14 +965,19 @@ def _causal_conv1d_update_kernel_xpu( if SILU_ACTIVATION: acc = acc / (1 + tl.exp(-acc)) - mask_1d = (idx_token < seqlen) & (idx_feats < dim - ) # token-index # feature-index - o_ptrs = o_ptr + ( - idx_seq) * stride_o_seq + idx_token * stride_o_token + ( - idx_feats * stride_o_dim) + mask_1d = (idx_token < seqlen) & ( + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + + (idx_seq) * stride_o_seq + + idx_token * stride_o_token + + (idx_feats * stride_o_dim) + ) tl.store(o_ptrs, acc, mask=mask_1d) + @triton.jit() def _causal_conv1d_update_kernel( # Pointers to matrices @@ -1011,9 +1021,9 @@ def _causal_conv1d_update_kernel( BLOCK_N: tl.constexpr, ): # ruff: noqa: E501 - # idx_seq = tl.program_id(0) + idx_seq = tl.program_id(0) - idx_seq = batch_id + # idx_seq = batch_id if idx_seq >= batch: return @@ -1022,9 +1032,9 @@ def _causal_conv1d_update_kernel( if IS_CONTINUOUS_BATCHING: # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices).to( - tl.int64) + conv_state_batch_coord = tl.load( + 
conv_state_indices_ptr + idx_seq * stride_state_indices + ).to(tl.int64) else: conv_state_batch_coord = idx_seq if USE_PAD_SLOT: # noqa @@ -1046,15 +1056,16 @@ def _causal_conv1d_update_kernel( # - accept 1 tokens: [history2, ..., historyM, draft1] # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] # - and so on. - conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) - - 1) + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1 else: conv_state_token_offset = 0 # STEP 1: READ init_state data - conv_states_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) + conv_states_base = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) mask_w = idx_feats < dim prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok @@ -1078,36 +1089,46 @@ def _causal_conv1d_update_kernel( # window manner, at each forward pass, the tokens are shift by 1, so we # load since idx_tokens + 1. 
conv_state_ptrs_source = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + - conv_state_token_offset * stride_conv_state_tok + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * - stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[ + :, None + ] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) VAL = state_len - seqlen - x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim - ) # [BLOCK_N] + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim) # [BLOCK_N] - x_ptrs = x_base[None, :] + ( - (idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + x_ptrs = ( + x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] - mask_x = ((idx_tokens - VAL >= 0)[:, None] & - (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) # tl.debug_barrier() new_conv_state = tl.where(mask, conv_state, loaded_x) - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = conv_state_base + ( - idx_tokens * 
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + conv_state_base = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + conv_state_ptrs_target = ( + conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] tl.store(conv_state_ptrs_target, new_conv_state, mask) @@ -1115,10 +1136,11 @@ def _causal_conv1d_update_kernel( if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) # STEP 4: # PRE-LOAD WEIGHTS @@ -1186,44 +1208,19 @@ def _causal_conv1d_update_kernel( if SILU_ACTIVATION: acc = acc / (1 + tl.exp(-acc)) - mask_1d = (idx_token < seqlen) & (idx_feats < dim - ) # token-index # feature-index - o_ptrs = o_ptr + ( - idx_seq) * stride_o_seq + idx_token * stride_o_token + ( - idx_feats * stride_o_dim) + mask_1d = (idx_token < seqlen) & ( + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + + (idx_seq) * stride_o_seq + + idx_token * stride_o_token + + (idx_feats * stride_o_dim) + ) tl.store(o_ptrs, acc, mask=mask_1d) -def torch_causal_conv1d_update( - hidden_states, - conv_state, - weight, - bias=None, - activation=None, - conv_state_indices=None -): - _, hidden_size, seq_len = hidden_states.shape - tmp_conv_state = conv_state[conv_state_indices] - state_len = tmp_conv_state.shape[-1] - - hidden_states_new = torch.cat([tmp_conv_state, hidden_states], dim=-1).to(weight.dtype) - cast_conv_state = conv_state.unsqueeze(0) - tmp_hidden_states = hidden_states_new[:, :, -state_len:] - ori_shape = tmp_hidden_states.shape - tmp_hidden_states = tmp_hidden_states.transpose(1, 
2).reshape(ori_shape) - kunlun_ops.reshape_and_cache_flash( - tmp_hidden_states, - tmp_hidden_states, - cast_conv_state, - cast_conv_state, - conv_state_indices) - out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size) - out = F.silu(out[:, :, -seq_len:]) - out = out.to(hidden_states.dtype).squeeze(-1) - return out - - def causal_conv1d_update( x: torch.Tensor, conv_state: torch.Tensor, @@ -1232,7 +1229,10 @@ def causal_conv1d_update( activation: Union[bool, str, None] = None, cache_seqlens: Optional[torch.Tensor] = None, conv_state_indices: Optional[torch.Tensor] = None, + conv_state_indices_cpu: Optional[torch.Tensor] = None, num_accepted_tokens: Optional[torch.Tensor] = None, + query_start_loc: torch.Tensor | None = None, + max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, metadata=None, validate_data=False, @@ -1278,99 +1278,54 @@ def causal_conv1d_update( # conv_state: (..., dim, state_len), where state_len >= width - 1 num_cache_lines, _, state_len = conv_state.size() - if validate_data: + if False and validate_data: assert dim == weight.size(0) - assert conv_state.stride( - -2 - ) == 1, f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert ( + conv_state.stride(-2) == 1 + ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" assert state_len >= width - 1 # when above happens, we don't shift-left to keep any records in conv_state assert dim == conv_state.size(1) if conv_state_indices is None: assert conv_state.size(0) >= batch else: - assert (batch, ) == conv_state_indices.shape + assert (batch,) == conv_state_indices.shape assert num_cache_lines >= batch assert weight.stride(1) == 1 # Need this assert cache_seqlens is None # not needed for vLLM - circular buffer - if batch > 1: - return torch_causal_conv1d_update( - x, - conv_state, - weight, - bias, - activation, - conv_state_indices=conv_state_indices - ) - - # 
adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' - out = x - stride_w_dim, stride_w_width = weight.stride() - - stride_x_seq, stride_x_dim, stride_x_token = x.stride( - ) # X (batch, dim, seqlen) - - stride_o_seq, stride_o_dim, stride_o_token = out.stride() - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( - ) - stride_state_indices = conv_state_indices.stride( - 0) if conv_state_indices is not None else 0 - if num_accepted_tokens is not None: - state_len = width - 1 + (seqlen - 1) # effective state_len needed + if num_accepted_tokens is None: + x = x.squeeze(-1).unsqueeze(1) else: - state_len = width - 1 - np2_statelen = triton.next_power_of_2(state_len) + x = x.squeeze(-1).view(-1, max_query_len, dim) + if num_accepted_tokens is None: + out = torch.empty_like(x) + import kunlun_ops - def grid(META): - return ( - 1, - triton.cdiv(dim, META["BLOCK_N"]), - ) - for batch_id in range(batch): - _causal_conv1d_update_kernel_xpu[grid]( + stride = conv_state.stride()[0] + kunlun_ops.causal_conv1d_update( x, weight, - bias, - conv_state, - cache_seqlens, - conv_state_indices, - num_accepted_tokens, out, - batch_id=batch_id, - batch=batch, - dim=dim, - seqlen=seqlen, - state_len=state_len, - num_cache_lines=num_cache_lines, - stride_x_seq=stride_x_seq, - stride_x_dim=stride_x_dim, - stride_x_token=stride_x_token, - stride_w_dim=stride_w_dim, - stride_w_width=stride_w_width, - stride_conv_state_seq=stride_istate_seq, - stride_conv_state_dim=stride_istate_dim, - stride_conv_state_tok=stride_istate_token, - stride_state_indices=stride_state_indices, - stride_o_seq=stride_o_seq, - stride_o_dim=stride_o_dim, - stride_o_token=stride_o_token, - pad_slot_id=pad_slot_id, - HAS_BIAS=bias is not None, - KERNEL_WIDTH=width, - SILU_ACTIVATION=activation in ["silu", "swish"], - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, - IS_SPEC_DECODING=num_accepted_tokens is not None, - 
NP2_STATELEN=np2_statelen, - USE_PAD_SLOT=pad_slot_id is not None, - BLOCK_N=256, - groups_per_cluster=np2_statelen, - isCloseUnrollControl=True, - isCloseVectorization=True, - isCloseOffsetAnalysis=True, - is_use_mask_zero = True + conv_state, + None, + bias, + conv_state_indices_cpu=conv_state_indices_cpu, + conv_state_indices_xpu=conv_state_indices, + act="SWISH", + state_seq_stride=stride, + is_ncw=False, + ) + out = out.squeeze(1) + return out + else: + return torch_causal_conv1d_update_spec( + x, + conv_state, + weight, + bias, + activation, + conv_state_indices=conv_state_indices, + num_accepted_tokens=num_accepted_tokens, ) - if unsqueeze: - out = out.squeeze(-1) - return out diff --git a/vllm_kunlun/v1/attention/backends/gdn_attn.py b/vllm_kunlun/v1/attention/backends/gdn_attn.py new file mode 100644 index 0000000..c6bddc3 --- /dev/null +++ b/vllm_kunlun/v1/attention/backends/gdn_attn.py @@ -0,0 +1,390 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Backend for GatedDeltaNet attention.""" + +from dataclasses import dataclass +from typing import Optional + +import torch +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.v1.attention.backends import gdn_attn +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills, +) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class GDNAttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: + return GDNAttentionMetadataBuilder + + +@dataclass +class GDNAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + num_spec_decodes: int + num_spec_decode_tokens: int + 
num_actual_tokens: int + + has_initial_state: Optional[torch.Tensor] = None + has_initial_state_cpu: Optional[torch.Tensor] = None + + spec_query_start_loc: Optional[torch.Tensor] = ( + None # shape: [num_spec_decodes + 1,] + ) + non_spec_query_start_loc: Optional[torch.Tensor] = ( + None # shape: [batch - num_spec_decodes + 1,] + ) + + spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec] + non_spec_state_indices_tensor: Optional[torch.Tensor] = ( + None # shape: [batch - num_spec_decodes,] + ) + non_spec_state_indices_tensor_cpu: Optional[torch.Tensor] = None + spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,] + spec_token_masks: Optional[torch.Tensor] = ( + None # shape: [num_prefill_tokens + num_decode_tokens,] + ) + num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,] + + # The following attributes are for triton implementation of causal_conv1d + nums_dict: Optional[dict] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None + + +class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]): + + cudagraph_support = AttentionCGSupport.UNIFORM_BATCH + + reorder_batch_threshold: int = 1 + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + assert isinstance(kv_cache_spec, MambaSpec) + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.speculative_config = vllm_config.speculative_config + self.kv_cache_spec = kv_cache_spec + if self.speculative_config: + self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501 + else: + self.num_spec = 0 + self.use_spec_decode = self.num_spec > 0 + self._init_reorder_batch_threshold(1, self.use_spec_decode) + + self.use_full_cuda_graph = ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) + self.decode_cudagraph_max_bs = min( + 
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1), + self.compilation_config.max_capture_size, + ) + + self.spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, self.num_spec + 1), + dtype=torch.int32, + device=device, + ) + self.non_spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.spec_sequence_masks = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.bool, + device=device, + ) + self.spec_token_masks = torch.empty( + (self.decode_cudagraph_max_bs * (self.num_spec + 1),), + dtype=torch.bool, + device=device, + ) + self.spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1,), + dtype=torch.int32, + device=device, + ) + self.non_spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1,), + dtype=torch.int32, + device=device, + ) + self.num_accepted_tokens = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + + def build( # type: ignore[override] + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + num_accepted_tokens: Optional[torch.Tensor] = None, + num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None, + fast_build: bool = False, + ) -> GDNAttentionMetadata: + m = common_attn_metadata + + query_start_loc = m.query_start_loc + context_lens = m.num_computed_tokens_cpu + context_lens_tensor = context_lens.to(query_start_loc.device) + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + + if ( + not self.use_spec_decode + or num_decode_draft_tokens_cpu is None + or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0] + .sum() + .item() + == 0 + ): + spec_sequence_masks = None + num_spec_decodes = 0 + else: + spec_sequence_masks = num_decode_draft_tokens_cpu >= 0 + num_spec_decodes = spec_sequence_masks.sum().item() + if num_spec_decodes == 0: + spec_sequence_masks = None + else: + spec_sequence_masks = 
spec_sequence_masks.to( + query_start_loc.device, non_blocking=True + ) + + if spec_sequence_masks is None: + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(m, decode_threshold=1) + ) + num_spec_decode_tokens = 0 + spec_token_masks = None + spec_state_indices_tensor = None + non_spec_state_indices_tensor = m.block_table_tensor[:, 0] + spec_query_start_loc = None + non_spec_query_start_loc = query_start_loc + num_accepted_tokens = None + else: + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + non_spec_query_lens = query_lens[~spec_sequence_masks] + num_decodes = (non_spec_query_lens == 1).sum().item() + num_prefills = non_spec_query_lens.size(0) - num_decodes + num_decode_tokens = num_decodes + num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens + + if num_prefills == 0 and num_decodes == 0: + spec_token_masks = torch.ones( + ( + min( + num_spec_decodes * (self.num_spec + 1), + query_start_loc[-1].item(), + ) + ), + dtype=torch.bool, + device=query_start_loc.device, + ) + spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1] + non_spec_state_indices_tensor = None + spec_query_start_loc = query_start_loc + non_spec_query_start_loc = None + else: + spec_token_masks = torch.repeat_interleave( + spec_sequence_masks, query_lens + ) + spec_state_indices_tensor = m.block_table_tensor[ + spec_sequence_masks, : self.num_spec + 1 + ] + non_spec_state_indices_tensor = m.block_table_tensor[ + ~spec_sequence_masks, 0 + ] + + spec_query_start_loc = torch.zeros( + num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device, + ) + torch.cumsum( + query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:] + ) + non_spec_query_start_loc = torch.zeros( + query_lens.size(0) - num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device, + ) + torch.cumsum( + query_lens[~spec_sequence_masks], + dim=0, + out=non_spec_query_start_loc[1:], + ) + + 
num_spec_decode_tokens = ( + query_lens.sum().item() - num_prefill_tokens - num_decode_tokens + ) + assert num_accepted_tokens is not None + num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] + + if num_prefills > 0: + has_initial_state = context_lens_tensor > 0 + if spec_sequence_masks is not None: + has_initial_state = has_initial_state[~spec_sequence_masks] + has_initial_state_cpu = has_initial_state.cpu() + nums_dict, batch_ptr, token_chunk_offset_ptr = ( + compute_causal_conv1d_metadata(non_spec_query_start_loc) + ) + else: + has_initial_state = None + has_initial_state_cpu = None + num_actual_tokens = ( + num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens + ) + + # prepare tensors for cudagraph + # + # With speculative decoding, the xgrammar backend may rollback tokens + # and causing some sequences has less draft tokens than self.num_spec. + # + # In above cases, the max possible batch size for n tokens, can be + # min(n, cudagraph_max_bs). + if ( + self.use_full_cuda_graph + and num_prefills == 0 + and num_decodes == 0 + and num_spec_decodes <= self.decode_cudagraph_max_bs + and num_spec_decode_tokens <= self.decode_cudagraph_max_bs + ): + num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) + batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens) + + self.spec_state_indices_tensor[:num_spec_decodes].copy_( + spec_state_indices_tensor, non_blocking=True + ) + spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size] + spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) + + self.spec_sequence_masks[:num_spec_decodes].copy_( + spec_sequence_masks, non_blocking=True + ) + spec_sequence_masks = self.spec_sequence_masks[:batch_size] + spec_sequence_masks[num_spec_decodes:].fill_(False) + + assert spec_token_masks is not None + self.spec_token_masks[: spec_token_masks.size(0)].copy_( + spec_token_masks, non_blocking=True + ) + spec_token_masks = 
self.spec_token_masks[:num_actual_tokens] + spec_token_masks[spec_token_masks.size(0) :].fill_(False) + + self.spec_query_start_loc[: num_spec_decodes + 1].copy_( + spec_query_start_loc, non_blocking=True + ) + spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index] + spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1] + spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens) + + self.num_accepted_tokens[:num_spec_decodes].copy_( + num_accepted_tokens, non_blocking=True + ) + num_accepted_tokens = self.num_accepted_tokens[:batch_size] + num_accepted_tokens[num_spec_decodes:].fill_(1) + + if ( + self.use_full_cuda_graph + and num_prefills == 0 + and num_spec_decodes == 0 + and num_decodes <= self.decode_cudagraph_max_bs + ): + num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) + batch_size = num_actual_tokens + + self.non_spec_state_indices_tensor[:num_decodes].copy_( + non_spec_state_indices_tensor, non_blocking=True + ) + non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[ + :batch_size + ] + non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID) + + self.non_spec_query_start_loc[: num_decodes + 1].copy_( + non_spec_query_start_loc, non_blocking=True + ) + non_spec_num_query_tokens = non_spec_query_start_loc[ + -1 + ] # type: ignore[index] + non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1] + non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens) + + if num_accepted_tokens is not None: + num_accepted_tokens = num_accepted_tokens.to(torch.int32) + attn_metadata = GDNAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_spec_decodes=num_spec_decodes, + num_spec_decode_tokens=num_spec_decode_tokens, + num_actual_tokens=num_actual_tokens, + has_initial_state=has_initial_state, + 
has_initial_state_cpu=has_initial_state_cpu, + spec_query_start_loc=spec_query_start_loc, + non_spec_query_start_loc=non_spec_query_start_loc, + spec_state_indices_tensor=spec_state_indices_tensor, + non_spec_state_indices_tensor=non_spec_state_indices_tensor, + non_spec_state_indices_tensor_cpu=( + non_spec_state_indices_tensor.cpu() + if non_spec_state_indices_tensor is not None + else None + ), + spec_sequence_masks=spec_sequence_masks, + spec_token_masks=spec_token_masks, + num_accepted_tokens=num_accepted_tokens, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, + ) + return attn_metadata + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata + ): + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with Mamba. + """ + m = common_attn_metadata + + assert ( + m.num_reqs <= self.decode_cudagraph_max_bs + and m.num_actual_tokens <= self.decode_cudagraph_max_bs + ), ( + f"GDN only supports decode-only full CUDAGraph capture. " + f"Make sure batch size ({m.num_reqs}) <= " + f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), " + f"and number of tokens ({m.num_actual_tokens}) <= " + f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})." 
+ ) + + num_accepted_tokens = torch.diff(m.query_start_loc) + num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu() + m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu() + + return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu) + + +gdn_attn.GDNAttentionMetadata = GDNAttentionMetadata +gdn_attn.GDNAttentionMetadataBuilder = GDNAttentionMetadataBuilder diff --git a/vllm_kunlun/v1/attention/backends/kunlun_attn.py b/vllm_kunlun/v1/attention/backends/kunlun_attn.py index 93a4022..d55a9f2 100644 --- a/vllm_kunlun/v1/attention/backends/kunlun_attn.py +++ b/vllm_kunlun/v1/attention/backends/kunlun_attn.py @@ -770,24 +770,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory value = value.contiguous() - if key_cache.is_contiguous(): - kunlun_ops.reshape_and_cache( - key[: attn_metadata.num_actual_tokens], - value[: attn_metadata.num_actual_tokens], - key_cache, - value_cache, - updated_slot_mapping, - ) - else: - cast_key_cache = key_cache.squeeze(1).unsqueeze(-2) - cast_value_cache = value_cache.squeeze(1).unsqueeze(-2) - kunlun_ops.reshape_and_cache_flash( - key, - value, - cast_key_cache, - cast_value_cache, - updated_slot_mapping, - ) + kunlun_ops.reshape_and_cache_flash( + key[: attn_metadata.num_actual_tokens], + value[: attn_metadata.num_actual_tokens], + key_cache, + value_cache, + updated_slot_mapping, + BLHD_LAYOUT=False, + ) assert attn_type == AttentionType.DECODER # Decoder self-attention supports chunked prefill. 
diff --git a/vllm_kunlun/v1/sample/rejection_sampler.py b/vllm_kunlun/v1/sample/rejection_sampler.py index 0d34385..4c99b54 100644 --- a/vllm_kunlun/v1/sample/rejection_sampler.py +++ b/vllm_kunlun/v1/sample/rejection_sampler.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional -from typing import Union + +import kunlun_ops import torch import torch.nn as nn - from vllm.logger import init_logger - from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -54,7 +53,7 @@ class RejectionSampler(nn.Module): bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - ''' + """ Args: metadata: Metadata for spec decoding. @@ -81,7 +80,7 @@ class RejectionSampler(nn.Module): Returns: output_token_ids (torch.Tensor): A tensor containing the final output token IDs. - ''' + """ assert metadata.max_spec_len <= MAX_SPEC_LEN # [num_tokens, vocab_size] # NOTE(woosuk): `target_logits` can be updated in place inside the @@ -124,11 +123,11 @@ class RejectionSampler(nn.Module): """ output_token_ids_np = output_token_ids.cpu().numpy() # Create mask for valid tokens. - valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & - (output_token_ids_np < vocab_size)) + valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( + output_token_ids_np < vocab_size + ) outputs = [ - row[valid_mask[i]].tolist() - for i, row in enumerate(output_token_ids_np) + row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) ] return outputs @@ -179,25 +178,15 @@ def rejection_sample( if not sampling_metadata.all_random: # Rejection sampling for greedy sampling requests. 
target_argmax = target_probs.argmax(dim=-1) - if min(num_draft_tokens) == 1 and max( - num_draft_tokens) == 1 and sampling_metadata.all_greedy: - rejection_greedy_sample_spec_len_1_pytorch( - output_token_ids, - draft_token_ids, - target_argmax, - bonus_token_ids, - ) - else: - rejection_greedy_sample_pytorch( - output_token_ids, - cu_num_draft_tokens, - draft_token_ids, - target_argmax, - bonus_token_ids, - num_draft_tokens, - max_spec_len, - is_greedy, - ) + kunlun_ops.rejection_greedy_sample( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + is_greedy, + max_spec_len, + ) if sampling_metadata.all_greedy: return output_token_ids @@ -222,8 +211,9 @@ def rejection_sample( sampling_metadata, device, ) + bonus_token_ids = bonus_token_ids.squeeze(1) - rejection_random_sample_pytorch( + kunlun_ops.rejection_random_sample( output_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -235,8 +225,7 @@ def rejection_sample( is_greedy, max_spec_len, vocab_size, - IS_NGRAM=draft_probs is None, - # num_warps=1, + no_draft_probs=draft_probs is None, ) return output_token_ids @@ -374,7 +363,7 @@ def generate_uniform_probs( random values in the range [0, 1). 
""" uniform_probs = torch.rand( - (num_tokens, ), + (num_tokens,), dtype=torch.float32, device=device, ) @@ -422,7 +411,7 @@ def sample_recovered_tokens( q[i].exponential_(generator=generator) recovered_token_ids = torch.empty_like(draft_token_ids) - sample_recovered_tokens_pytorch( + kunlun_ops.sample_recovered_tokens( recovered_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -430,16 +419,16 @@ def sample_recovered_tokens( target_probs, q, vocab_size, - IS_NGRAM=draft_probs is None, + no_draft_probs=draft_probs is None, ) return recovered_token_ids def rejection_greedy_sample_spec_len_1_pytorch( - output_token_ids, # [batch_size, 2] - draft_token_ids, # [num_tokens] - target_argmax, # [num_tokens] - bonus_token_ids, # [batch_size] + output_token_ids, # [batch_size, 2] + draft_token_ids, # [num_tokens] + target_argmax, # [num_tokens] + bonus_token_ids, # [batch_size] ): batch_size = output_token_ids.size(0) num_tokens = draft_token_ids.size(0) @@ -447,73 +436,72 @@ def rejection_greedy_sample_spec_len_1_pytorch( accept_req_mask = draft_token_ids == target_argmax output_token_ids[:, 0] = target_argmax bonus_token_ids = bonus_token_ids.squeeze(1) - output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids, - output_token_ids[:, 1]) + output_token_ids[:, 1] = torch.where( + accept_req_mask, bonus_token_ids, output_token_ids[:, 1] + ) def rejection_greedy_sample_pytorch( - output_token_ids, # [batch_size, max_spec_len + 1] - cu_num_draft_tokens, # [batch_size] - draft_token_ids, # [num_tokens] - target_argmax, # [num_tokens] - bonus_token_ids, # [batch_size] - draft_tokens_per_req, # [batch_size], list - max_spec_len, - is_greedy=None, # [batch_size] or None + output_token_ids, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens, # [batch_size] + draft_token_ids, # [num_tokens] + target_argmax, # [num_tokens] + bonus_token_ids, # [batch_size] + draft_tokens_per_req, # [batch_size], list + max_spec_len, + is_greedy=None, # [batch_size] or None ): 
batch_size = output_token_ids.size(0) num_tokens = draft_token_ids.size(0) device = output_token_ids.device draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to( - device, non_blocking=True) + device, non_blocking=True + ) if is_greedy is None: is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device) start_indices = cu_num_draft_tokens - draft_tokens_per_req req_ids = torch.arange(batch_size, device=device) token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req) - token_positions = torch.arange( - num_tokens, device=device) - start_indices[token_req_ids] + token_positions = ( + torch.arange(num_tokens, device=device) - start_indices[token_req_ids] + ) # Find the first mismatch position of each request. - mismatch_global = (draft_token_ids != target_argmax) + mismatch_global = draft_token_ids != target_argmax if max_spec_len == 0: - first_mismatch_pos_per_req = torch.zeros(batch_size, - dtype=torch.long, - device=device) + first_mismatch_pos_per_req = torch.zeros( + batch_size, dtype=torch.long, device=device + ) else: # [bs, max_spec_len] - pos_matrix = torch.full((batch_size, max_spec_len), - -1, - dtype=torch.long, - device=device) + pos_matrix = torch.full( + (batch_size, max_spec_len), -1, dtype=torch.long, device=device + ) pos_matrix[token_req_ids, token_positions] = token_positions - mismatch_matrix = torch.full((batch_size, max_spec_len), - False, - dtype=torch.bool, - device=device) + mismatch_matrix = torch.full( + (batch_size, max_spec_len), False, dtype=torch.bool, device=device + ) mismatch_matrix[token_req_ids, token_positions] = mismatch_global - mismatch_positions = torch.where(mismatch_matrix, pos_matrix, - max_spec_len * 2) + mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2) first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1) - no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2) + no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2 
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[ - no_mismatch_mask] + no_mismatch_mask + ] # Copy matched target tokens into output. - copy_len = torch.minimum(first_mismatch_pos_per_req + 1, - draft_tokens_per_req) - copy_indices = torch.arange(max_spec_len + 1, - device=device).expand(batch_size, -1) + copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req) + copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1) copy_mask = copy_indices < copy_len.unsqueeze(1) greedy_mask = is_greedy.unsqueeze(1) final_copy_mask = copy_mask & greedy_mask global_idx = start_indices.unsqueeze(1) + copy_indices - output_token_ids[final_copy_mask] = target_argmax[ - global_idx[final_copy_mask]].to(output_token_ids.dtype) + output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to( + output_token_ids.dtype + ) # Fill bonus token. - needs_bonus = is_greedy & (first_mismatch_pos_per_req - >= draft_tokens_per_req) + needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req) if torch.any(needs_bonus): bonus_rows = torch.where(needs_bonus)[0] bonus_cols = draft_tokens_per_req[bonus_rows] @@ -556,11 +544,9 @@ def rejection_random_sample_pytorch( if IS_NGRAM: draft_prob = 1.0 else: - draft_prob = draft_probs[start_idx + pos, - draft_token_id].item() + draft_prob = draft_probs[start_idx + pos, draft_token_id].item() - target_prob = target_probs[start_idx + pos, - draft_token_id].item() + target_prob = target_probs[start_idx + pos, draft_token_id].item() uniform_prob = uniform_probs[start_idx + pos].item() if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: @@ -629,12 +615,11 @@ def sample_recovered_tokens_pytorch( else: draft_p = draft_probs[token_idx].clone() target_p = target_probs[token_idx].clone() - prob = torch.maximum(target_p - draft_p, - torch.tensor(0.0, device=target_p.device)) + prob = torch.maximum( + target_p - draft_p, torch.tensor(0.0, 
device=target_p.device) + ) - q_values = torch.full((vocab_size, ), - float('-inf'), - device=q.device) + q_values = torch.full((vocab_size,), float("-inf"), device=q.device) q_values[:vocab_size] = q[req_idx, :vocab_size] recovered_id = torch.argmax(prob / q_values).item() @@ -642,4 +627,3 @@ def sample_recovered_tokens_pytorch( if IS_NGRAM: target_probs[token_idx, draft_token_id] = orig_prob - diff --git a/vllm_kunlun/v1/sample/spec_decode/eagle.py b/vllm_kunlun/v1/sample/spec_decode/eagle.py index 9f4e244..676694d 100644 --- a/vllm_kunlun/v1/sample/spec_decode/eagle.py +++ b/vllm_kunlun/v1/sample/spec_decode/eagle.py @@ -337,5 +337,5 @@ def prepare_next_token_ids_padded( return next_token_ids, valid_sampled_tokens_count -EagleProposer.propose = propose +# EagleProposer.propose = propose EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index bb4b000..084c59a 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -407,7 +407,7 @@ def add_rmsnorm( ) -> None: kunlun_ops.add_rmsnorm( x, - y, # 原来写 residual,这里其实是 y + y, residual_output=residual_output, weight=weight, eps=eps, @@ -523,6 +523,145 @@ def _fake_add_rmsnorm( add_rmsnorm.register_fake(_fake_add_rmsnorm) +@custom_op("_C::gemma_add_rmsnorm", mutates_args=()) +def gemma_add_rmsnorm( + x: torch.Tensor, + y: torch.Tensor, + weight: torch.Tensor, + output: torch.Tensor, + eps: float = 1e-5, + enable_pdl: bool = False, + interweaved: bool = False, + store_output_before_norm: bool = True, + bias: torch.Tensor = None, + smooth: torch.Tensor = None, + residual_output: torch.Tensor = None, + force_sdnn: bool = False, +) -> None: + # print("gemma_add_rmsnorm wrapper") + kunlun_ops.gemma_add_rmsnorm( + x, + y, + weight=weight, + output=output, + eps=eps, + enable_pdl=enable_pdl, + interweaved=interweaved, + store_output_before_norm=store_output_before_norm, + bias=bias, + 
smooth=smooth, + residual_output=residual_output, + force_sdnn=force_sdnn, + ) + + +@impl("_C::gemma_add_rmsnorm", "CUDA") +def gemma_add_rmsnorm_cuda( + x: torch.Tensor, + y: torch.Tensor, + weight: torch.Tensor, + output: torch.Tensor, + eps: float = 1e-5, + enable_pdl: bool = False, + interweaved: bool = False, + store_output_before_norm: bool = True, + bias: torch.Tensor = None, + smooth: torch.Tensor = None, + residual_output: torch.Tensor = None, + force_sdnn: bool = False, +) -> None: + # print("gemma_add_rmsnorm_cuda wrapper") + kunlun_ops.gemma_add_rmsnorm( + x, + y, + weight=weight, + output=output, + eps=eps, + enable_pdl=enable_pdl, + interweaved=interweaved, + store_output_before_norm=store_output_before_norm, + bias=bias, + smooth=smooth, + residual_output=residual_output, + force_sdnn=force_sdnn, + ) + + +def _fake_gemma_add_rmsnorm( + x: torch.Tensor, + y: torch.Tensor, + weight: torch.Tensor, + output: torch.Tensor, + eps: float = 1e-5, + enable_pdl: bool = False, + interweaved: bool = False, + store_output_before_norm: bool = True, + bias: torch.Tensor = None, + smooth: torch.Tensor = None, + residual_output: torch.Tensor = None, + force_sdnn: bool = False, +): + output.fake_shape = x.shape + output.fake_dtype = x.dtype + return None + + +gemma_add_rmsnorm.register_fake(_fake_gemma_add_rmsnorm) + + +@custom_op("_C::gemma_rmsnorm", mutates_args=()) +def gemma_rmsnorm( + x: torch.Tensor, + weight: torch.Tensor, + output: torch.Tensor, + eps: float = 1e-5, + enable_pdl: bool = False, + interweave: bool = False, + bias: torch.Tensor = None, + force_sdnn: bool = False, +) -> None: + # print("gemma_rmsnorm wrapper") + kunlun_ops.gemma_rmsnorm( + x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn + ) + + +@impl("_C::gemma_rmsnorm", "CUDA") +def gemma_rmsnorm_cuda( + x: torch.Tensor, + weight: torch.Tensor, + output: torch.Tensor, + eps: float = 1e-5, + enable_pdl: bool = False, + interweave: bool = False, + bias: torch.Tensor = None, + 
force_sdnn: bool = False, +) -> None: + # print("gemma_rmsnorm_cuda wrapper") + kunlun_ops.gemma_rmsnorm( + x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn + ) + + +def _fake_gemma_rmsnorm( + x: torch.Tensor, + weight: torch.Tensor, + output: torch.Tensor, + eps: float = 1e-5, + enable_pdl: bool = False, + interweave: bool = False, + bias: torch.Tensor = None, + force_sdnn: bool = False, +): + # Record fake shape/dtype on the output tensor; intentionally returns None + output.fake_shape = x.shape + output.fake_dtype = x.dtype + return None + + +gemma_rmsnorm.register_fake(_fake_gemma_rmsnorm) + + @custom_op("_C::split_norm_rope_neox", mutates_args=()) def split_norm_rope_neox( q_emb: torch.Tensor,