diff --git a/vllm_kunlun/models/__init__.py b/vllm_kunlun/models/__init__.py index 69c0713..e7b953a 100644 --- a/vllm_kunlun/models/__init__.py +++ b/vllm_kunlun/models/__init__.py @@ -5,7 +5,6 @@ def register_model(): # from .demo_model import DemoModel # noqa: F401 from .qwen2_vl import Qwen2VLForConditionalGeneration #noqa: F401 from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration #noqa: F401 - from .qwen3 import Qwen3ForCausalLM #noqa: F401 from .qwen3_moe import Qwen3MoeForCausalLM #noqa: F401 from .qwen3_vl import Qwen3VLForConditionalGeneration from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration @@ -48,11 +47,7 @@ def register_model(): ModelRegistry.register_model( "InternLM2ForCausalLM", - "vllm_kunlun.models.internlm2:InternLM2ForCausalLM") - - ModelRegistry.register_model( - "Qwen2ForCausalLM", - "vllm_kunlun.models.qwen2:Qwen2ForCausalLM") + "vllm_kunlun.models.internlm2:InternLM2ForCausalLM") ModelRegistry.register_model( "InternVLChatModel", @@ -78,10 +73,6 @@ def register_model(): "SeedOssForCausalLM", "vllm_kunlun.models.seed_oss:SeedOssForCausalLM") - ModelRegistry.register_model( - "LlamaForCausalLM", - "vllm_kunlun.models.llama:LlamaForCausalLM") - ModelRegistry.register_model( "MiMoV2FlashForCausalLM", "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM") diff --git a/vllm_kunlun/models/llama.py b/vllm_kunlun/models/llama.py deleted file mode 100644 index 65823fe..0000000 --- a/vllm_kunlun/models/llama.py +++ /dev/null @@ -1,673 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only LLaMA model compatible with HuggingFace weights.""" -from collections.abc import Iterable -from itertools import islice -from typing import Any, Optional, Union - -import torch -from torch import nn -from transformers import LlamaConfig - -from vllm.attention import AttentionType -from vllm_kunlun.ops.attention.layer import Attention -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm_kunlun.ops.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope - -from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding - -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.sequence import IntermediateTensors - -from vllm.model_executor.models.interfaces import SupportsEagle3, SupportsLoRA, SupportsPP -from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - - -class LlamaMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - bias: bool = False, - prefix: str = "", - reduce_results: bool = True, - disable_tp: bool = False, - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - input_size=hidden_size, - output_sizes=[intermediate_size] * 2, - bias=bias, - quant_config=quant_config, - disable_tp=disable_tp, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = RowParallelLinear( - input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - reduce_results=reduce_results, - disable_tp=disable_tp, - prefix=f"{prefix}.down_proj", - ) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - x, _ = self.gate_up_proj(x) - x = self.act_fn(x) - x, _ = self.down_proj(x) - return x - - -class LlamaAttention(nn.Module): - - def __init__( - self, - config: LlamaConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, - max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, - bias: bool = False, - bias_o_proj: bool = False, - cache_config: Optional[CacheConfig] = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER, - ) -> None: - super().__init__() - layer_idx = extract_layer_index(prefix) - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - # MistralConfig has an optional head_dim introduced by Mistral-Nemo - head_dim = getattr(config, "head_dim", None) - if head_dim is None: - head_dim = self.hidden_size // self.total_num_heads - self.head_dim = head_dim - # Phi models introduced a partial_rotary_factor parameter in the config - self.partial_rotary_factor = getattr(config, "partial_rotary_factor", - 1) - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size=hidden_size, - head_size=self.head_dim, - total_num_heads=self.total_num_heads, - total_num_kv_heads=self.total_num_kv_heads, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - self.o_proj = RowParallelLinear( - input_size=self.total_num_heads * self.head_dim, - output_size=hidden_size, - bias=bias_o_proj, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - self._init_rotary_emb(config, - rope_scaling=rope_scaling, - quant_config=quant_config) - - sliding_window = None - if layer_types := getattr(config, "layer_types", None): - # Fix for Eagle3 compatibility: - # for draft models, subtract target layer count - # to get draft-relative layer index starting from 0 - if hasattr(config, 'target_layer_count'): - # This is a draft model, - # adjust layer_idx to be relative to draft layers - effective_layer_idx = layer_idx - config.target_layer_count - else: - # This is a target model, use layer_idx directly - effective_layer_idx = layer_idx - assert effective_layer_idx < len(layer_types), \ - f"effective_layer_idx: {effective_layer_idx} \ - is out of bounds for layer_types: {layer_types}" - - is_sliding = layer_types[ - effective_layer_idx] == "sliding_attention" - if is_sliding: - sliding_window = config.sliding_window - - attn_cls = (EncoderOnlyAttention - if attn_type == AttentionType.ENCODER_ONLY else Attention) - - self.attn = attn_cls( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - per_layer_sliding_window=sliding_window, - attn_type=attn_type, - prefix=f"{prefix}.attn", - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output - - def _init_rotary_emb(self, config: LlamaConfig, - rope_scaling: Optional[dict[str, Any]], - quant_config: Optional[QuantizationConfig]) -> None: - is_neox_style = True - is_gguf = quant_config and quant_config.get_name() == "gguf" - if is_gguf and config.model_type == "llama": - is_neox_style = False - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, - rope_scaling=rope_scaling, - is_neox_style=is_neox_style, - partial_rotary_factor=self.partial_rotary_factor, - ) - - -class LlamaDecoderLayer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str = "", - config: Optional[LlamaConfig] = None) -> None: - super().__init__() - - config = config or vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): - rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - # Support abacusai/Smaug-72B-v0.1 with attention_bias - # Support internlm/internlm-7b with bias - attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) - bias_o_proj = attention_bias - # support internlm/internlm3-8b with qkv_bias - if hasattr(config, 'qkv_bias'): - attention_bias = config.qkv_bias - - # By default, Llama uses causal attention as it is a decoder-only model. - # You can override the HF config with `is_causal=False` to enable - # bidirectional attention, which is used in some embedding models - # (e.g. parasail-ai/GritLM-7B-vllm) - if getattr(config, "is_causal", True): - attn_type = AttentionType.DECODER - else: - attn_type = AttentionType.ENCODER_ONLY - - self.self_attn = LlamaAttention( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - quant_config=quant_config, - bias=attention_bias, - bias_o_proj=bias_o_proj, - cache_config=cache_config, - prefix=f"{prefix}.self_attn", - attn_type=attn_type, - ) - self.mlp = LlamaMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - bias=getattr(config, "mlp_bias", False), - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -@support_torch_compile -class LlamaModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): - super().__init__() - - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - quant_config=quant_config, - ) - else: - self.embed_tokens = PPMissingLayer() - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix), - prefix=f"{prefix}.layers", - ) - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - - self.aux_hidden_state_layers = tuple[int, ...]() - - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, - list[torch.Tensor]]]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - aux_hidden_states = [] - for idx, layer in enumerate( - islice(self.layers, self.start_layer, self.end_layer)): - if idx in self.aux_hidden_state_layers: - aux_hidden_states.append(hidden_states + residual) - hidden_states, residual = layer(positions, hidden_states, residual) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - - if len(aux_hidden_states) > 0: - return hidden_states, aux_hidden_states - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - if "scale" in name: - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] - } - - # LoRA specific attributes - embedding_modules = { - "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings" - } - embedding_padding_modules = ["lm_head"] - - # Mistral/Llama models can also be loaded with --load-format mistral - # from consolidated.safetensors checkpoints - mistral_mapping = { - "layers": "model.layers", - "attention": "self_attn", - "qscale_act": "input_scale", - "qscale_weight": "weight_scale", - "kv_fake_quantizer.qscale_act": "kv_scale", - "q_fake_quantizer.qscale_act": "attn.q_scale", - "k_fake_quantizer.qscale_act": "k_scale", - "v_fake_quantizer.qscale_act": "v_scale", - "wq": "q_proj", - "wk": "k_proj", - "wv": "v_proj", - "wo": "o_proj", - "attention_norm": "input_layernorm", - "feed_forward": "mlp", - "w1": "gate_proj", - "w2": "down_proj", - "w3": "up_proj", - "ffn_norm": "post_attention_layernorm", - "tok_embeddings": "model.embed_tokens", - "output": "lm_head", - "norm": "model.norm", - } - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - self.config = config - self.lora_config = lora_config - - self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model"), - layer_type=layer_type) - - if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - if config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) - - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) - else: - self.lm_head = PPMissingLayer() - - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: - self.model.aux_hidden_state_layers = layers - - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: - num_layers = len(self.model.layers) - return (2, num_layers // 2, num_layers - 3) - - def _init_model(self, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): - return LlamaModel(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return model_output - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - ) - return loader.load_weights( - self.maybe_remap_mistral(name, loaded_weight) - for name, loaded_weight in weights) - - # This function is used to remap the mistral format as - # used by Mistral and Llama <=2 - def maybe_remap_mistral( - self, - name: str, - loaded_weight: torch.Tensor, - ) -> tuple[str, torch.Tensor]: - - def permute(w: torch.Tensor, n_heads: int, attn_out: int): - attn_in = self.config.head_dim * n_heads - - return w.view(n_heads, attn_in // n_heads // 2, 2, - attn_out).transpose(1, 2).reshape(attn_in, attn_out) - - mapping = self.mistral_mapping - modules = name.split(".") - - # rotary embeds should be sliced - # If using quantized model in mistral format, - # quantization scales (qscale_weight) also need to be sliced - if "wk" in modules and modules[-1] == "weight": - loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads, - self.config.hidden_size) - elif "wk" in modules and modules[ - -1] == "qscale_weight" and loaded_weight.numel() > 1: - loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads, 1) - elif "wq" in modules and modules[-1] == "weight": - loaded_weight = permute(loaded_weight, - self.config.num_attention_heads, - self.config.hidden_size) - elif "wq" in modules and modules[ - -1] == "qscale_weight" and loaded_weight.numel() > 1: - loaded_weight = permute(loaded_weight, - self.config.num_attention_heads, 1) - - num_modules = len(modules) - for i in range(num_modules): - item = modules[i] - next_item = modules[i + 1] if i < num_modules - 1 else None - - combined_item = (f"{item}.{next_item}" - if next_item is not None else None) - - if combined_item in mapping: - name = name.replace(combined_item, mapping[combined_item]) - elif item in mapping and mapping[item] not in name: - name = name.replace(item, mapping[item]) - - return name, loaded_weight - diff --git a/vllm_kunlun/models/qwen2.py b/vllm_kunlun/models/qwen2.py deleted file mode 100644 index 212200d..0000000 --- a/vllm_kunlun/models/qwen2.py +++ /dev/null @@ -1,511 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py -# Copyright 2024 The Qwen team. -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only Qwen2 model compatible with HuggingFace weights.""" -import os - -from collections.abc import Iterable -from typing import Any, Optional, Union - -import torch -from torch import nn -from transformers import Qwen2Config - -from vllm.attention import AttentionType -from vllm_kunlun.ops.attention.layer import Attention -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm_kunlun.ops.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead) -from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -# from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from vllm.model_executor.models.adapters import as_seq_cls_model -from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP -from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - - -class Qwen2MLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, - [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.down_proj", - ) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class Qwen2Attention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[tuple] = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER, - dual_chunk_attention_config: Optional[dict[str, Any]] = None, - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.dual_chunk_attention_config = dual_chunk_attention_config - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position, - base=self.rope_theta, - rope_scaling=rope_scaling, - dual_chunk_attention_config=dual_chunk_attention_config, - ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - attn_type=attn_type, - prefix=f"{prefix}.attn", - **{ - "layer_idx": extract_layer_index(prefix), - "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # INTERNVL_3暂时使用环境变量来控制是否使用原生rotary_embedding - # 若要修改,可尝试参考 qwen3.py - if os.getenv('INTERNVL_3') == "1": - q, k = self.rotary_emb.forward_native(positions, q, k) - else: - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output - - -class Qwen2DecoderLayer(nn.Module): - - def __init__( - self, - config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 1000000) - rope_scaling = getattr(config, "rope_scaling", None) - dual_chunk_attention_config = getattr(config, - "dual_chunk_attention_config", - None) - - # By default, Qwen2 uses causal attention as it is a decoder-only model. - # You can override the HF config with `is_causal=False` to enable - # bidirectional attention, which is used in some embedding models - # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) - if getattr(config, "is_causal", True): - attn_type = AttentionType.DECODER - else: - attn_type = AttentionType.ENCODER_ONLY - - self.self_attn = Qwen2Attention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - cache_config=cache_config, - quant_config=quant_config, - rope_scaling=rope_scaling, - prefix=f"{prefix}.self_attn", - attn_type=attn_type, - dual_chunk_attention_config=dual_chunk_attention_config, - ) - self.mlp = Qwen2MLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, - # otherwise (seq_len, ). - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }) -class Qwen2Model(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - # TODO (@robertgshaw2): see if this can be moved out - if (cache_config.sliding_window is not None - and hasattr(config, "max_window_layers")): - assert config.max_window_layers == config.num_hidden_layers, ( - "Sliding window for some but all layers is not supported. " - "This model uses sliding window but `max_window_layers` = {} " - "is less than `num_hidden_layers` = {}. Please open an issue " - "to discuss this feature.".format( - config.max_window_layers, - config.num_hidden_layers, - )) - - self.config = config - config = config.get_text_config() - self.quant_config = quant_config - self.vocab_size = config.vocab_size - - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens", - ) - else: - self.embed_tokens = PPMissingLayer() - - # Use the provided decoder layer type or default to Qwen2DecoderLayer - decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: decoder_layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers", - ) - - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: - hidden_states, residual = layer( - positions, - hidden_states, - residual, - ) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.lora_config = lora_config - - self.quant_config = quant_config - self.model = Qwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - if get_pp_group().is_last_rank: - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() - - self.logits_processor = LogitsProcessor(config.vocab_size) - - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - # sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states,) - # sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - ) - return loader.load_weights(weights) - - -Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM) diff --git a/vllm_kunlun/models/qwen3.py b/vllm_kunlun/models/qwen3.py index 7556e1e..3ac34b7 100644 --- a/vllm_kunlun/models/qwen3.py +++ b/vllm_kunlun/models/qwen3.py @@ -47,8 +47,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.sequence import IntermediateTensors from vllm.model_executor.models.interfaces import SupportsEagle3, SupportsLoRA, SupportsPP -from .qwen2 import Qwen2MLP as Qwen3MLP -from .qwen2 import Qwen2Model +from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP +from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix) diff --git a/vllm_kunlun/models/qwen3_vl.py b/vllm_kunlun/models/qwen3_vl.py index 31122be..d695dcb 100644 --- a/vllm_kunlun/models/qwen3_vl.py +++ b/vllm_kunlun/models/qwen3_vl.py @@ -80,8 +80,8 @@ from .qwen2_5_vl import (Qwen2_5_VisionAttention, Qwen2_5_VLImagePixelInputs, Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) -from .qwen2_vl import Qwen2VLProcessingInfo -from .qwen3 import Qwen3ForCausalLM, Qwen3Model +from vllm.model_executor.models.qwen2_vl import Qwen2VLProcessingInfo +from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM, Qwen3Model from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, maybe_prefix, merge_multimodal_embeddings) from vllm.model_executor.models.vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model diff --git a/vllm_kunlun/models/qwen3_vl_moe.py b/vllm_kunlun/models/qwen3_vl_moe.py index 6d1685b..e0643c7 100644 --- a/vllm_kunlun/models/qwen3_vl_moe.py +++ b/vllm_kunlun/models/qwen3_vl_moe.py @@ -42,8 +42,8 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors -from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel -from .qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder, +from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel +from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration, Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) from vllm.model_executor.models.utils import is_pp_missing_parameter, maybe_prefix diff --git a/vllm_kunlun/ops/__init__.py b/vllm_kunlun/ops/__init__.py index b9f47fd..fec8bbb 100644 --- a/vllm_kunlun/ops/__init__.py +++ b/vllm_kunlun/ops/__init__.py @@ -19,4 +19,5 @@ import vllm_kunlun.ops.rotary_embedding import vllm_kunlun.ops.layernorm import vllm_kunlun.ops.quantization.awq import vllm_kunlun.ops.quantization.gptq -import vllm_kunlun.ops.vocab_parallel_embedding \ No newline at end of file +import vllm_kunlun.ops.vocab_parallel_embedding +import vllm_kunlun.ops.linear \ No newline at end of file diff --git a/vllm_kunlun/ops/_kunlun_ops.py b/vllm_kunlun/ops/_kunlun_ops.py index 325bef0..0849432 100644 --- a/vllm_kunlun/ops/_kunlun_ops.py +++ b/vllm_kunlun/ops/_kunlun_ops.py @@ -491,7 +491,7 @@ class KunlunOps: d = y.shape[-1] // 2 output_shape = (y.shape[:-1] + (d, )) out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) - torch.ops._C.swiglu(y, out1) + torch.ops._C.silu_and_mul(out1, y) out = torch.empty(M,moe_top_k, w2.shape[1], @@ -570,7 +570,7 @@ class KunlunOps: cur_token = repeat_x[selected_token] up_gate = torch.empty(selected_token.sum(), up_gate_size//2, dtype=cur_token.dtype, device=cur_token.device) - torch.ops._C.swiglu(cur_token@ w13_weight[i].T, up_gate) + torch.ops._C.silu_and_mul(up_gate, cur_token@ w13_weight[i].T) out[selected_token] = up_gate @ w2_weight[i].T output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype) diff --git a/vllm_kunlun/ops/activation.py b/vllm_kunlun/ops/activation.py index 79dd552..0526acf 100644 --- a/vllm_kunlun/ops/activation.py +++ b/vllm_kunlun/ops/activation.py @@ -98,7 +98,7 @@ class SiluAndMul(CustomOp): d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - torch.ops._C.swiglu(x, out) + torch.ops._C.silu_and_mul(out, x) return out def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm_kunlun/v1/attention/backends/kunlun_attn.py b/vllm_kunlun/v1/attention/backends/kunlun_attn.py index 0c2d050..95ef1bb 100644 --- a/vllm_kunlun/v1/attention/backends/kunlun_attn.py +++ b/vllm_kunlun/v1/attention/backends/kunlun_attn.py @@ -43,7 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable from vllm.config import VllmConfig, get_layers_from_vllm_config - +import inspect class KunlunAttentionBackend(AttentionBackend): """KunlunAttentionBackend""" @@ -723,30 +723,45 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): tmp_block_tables = decode_meta.block_tables else: tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next - - xtorch_ops.speculative_attention( - out=output[:num_decode_tokens], - # Only MLA support q len > 1 right now - q=decode_query.unsqueeze(0), - k_cache=key_cache, - v_cache=value_cache, - context_lens_cpu=decode_meta.seq_lens_tensor_cpu, - context_lens_xpu=decode_meta.seq_lens_tensor, - batch_num=decode_meta.block_tables.shape[0], - # TODO (@xyDong23): Support MTP(q lens >1) - qlen=1, - # TODO (@xyDong23): Support max_context_len to (262144) - max_context_len=131072, - head_num=self.num_heads, - head_dim=self.head_size, - scale=0.0, - kv_head_num=self.num_kv_heads, - block_size=key_cache.shape[2], - max_num_blocks_per_seq=decode_meta.block_tables.shape[1], - max_window_size=self.sliding_window if self.sliding_window is not None else -1, - block_tables=tmp_block_tables, - sink = self.sinks.to(torch.float32) if self.sinks is not None else None - ) + + sig = inspect.signature(xtorch_ops.speculative_attention) + if "max_window_size" in sig.parameters: + xtorch_ops.speculative_attention( + out=output[:num_decode_tokens], + # Only MLA support q len > 1 right now + q=decode_query.unsqueeze(0), + k_cache=key_cache, + v_cache=value_cache, + context_lens_cpu=decode_meta.seq_lens_tensor_cpu, + context_lens_xpu=decode_meta.seq_lens_tensor, + batch_num=decode_meta.block_tables.shape[0], + # TODO (@xyDong23): Support MTP(q lens >1) + qlen=1, + # TODO (@xyDong23): Support max_context_len to (262144) + max_context_len=131072, + head_num=self.num_heads, + head_dim=self.head_size, + scale=0.0, + kv_head_num=self.num_kv_heads, + block_size=key_cache.shape[2], + max_num_blocks_per_seq=decode_meta.block_tables.shape[1], + max_window_size=self.sliding_window if self.sliding_window is not None else -1, + block_tables=tmp_block_tables, + sink = self.sinks.to(torch.float32) if self.sinks is not None else None + ) + else: + xtorch_ops.paged_attention( + x=decode_query, + k_cache=key_cache, + v_cache=value_cache, + block_tables=tmp_block_tables, + context_lens_cpu=decode_meta.seq_lens_tensor_cpu, + context_lens_xpu=decode_meta.seq_lens_tensor, + is_context=False, + is_causal=True, + out=output[:num_decode_tokens], + vo_head_dim=self.head_size + ) # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) def use_cascade_attention( diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index e0729a9..7425ac1 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -323,27 +323,6 @@ def rms_norm_dynamic_per_token_quant_xpu( )->None: pass -@custom_op("_C::silu_and_mul", mutates_args=()) -def silu_and_mul( - result : torch.Tensor, - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - epsilon: float -)->None: - pass -@impl("_C::silu_and_mul", "CUDA") -def silu_and_mul_xpu( - result : torch.Tensor, - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - epsilon: float -)->None: - pass - @custom_op("_C::silu_and_mul_quant", mutates_args=()) def silu_and_mul_quant( result : torch.Tensor, @@ -592,39 +571,39 @@ if hasattr(torch.ops.custom_ops, "fc_fusion"): ) -> None: pass -@custom_op("_C::swiglu", mutates_args=()) -def swiglu( +@custom_op("_C::silu_and_mul", mutates_args=()) +def silu_and_mul( + out: torch.Tensor, x: torch.Tensor, - y: torch.Tensor, axis: int=-1, turn: bool=True ) -> None: xtorch_ops.swiglu( - x, - y, + x=x, + y=out, ) -@impl("_C::swiglu", "CUDA") -def swiglu_cuda( +@impl("_C::silu_and_mul", "CUDA") +def silu_and_mul_cuda( + out: torch.Tensor, x: torch.Tensor, - y: torch.Tensor, axis: int=-1, turn: bool=True ) -> None: xtorch_ops.swiglu( - x, - y, + x=x, + y=out, ) -def _fake_swiglu( +def _fake_silu_and_mul( + out: torch.Tensor, x: torch.Tensor, - y: torch.Tensor, axis: int=-1, turn: bool=True): return None -swiglu.register_fake(_fake_swiglu) +silu_and_mul.register_fake(_fake_silu_and_mul)