diff --git a/vllm_kunlun/models/__init__.py b/vllm_kunlun/models/__init__.py index 7dd089c..69c0713 100644 --- a/vllm_kunlun/models/__init__.py +++ b/vllm_kunlun/models/__init__.py @@ -82,6 +82,9 @@ def register_model(): "LlamaForCausalLM", "vllm_kunlun.models.llama:LlamaForCausalLM") - + ModelRegistry.register_model( + "MiMoV2FlashForCausalLM", + "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM") + def register_quant_method(): """to do""" diff --git a/vllm_kunlun/models/mimo_v2_flash.py b/vllm_kunlun/models/mimo_v2_flash.py new file mode 100644 index 0000000..0dfcb86 --- /dev/null +++ b/vllm_kunlun/models/mimo_v2_flash.py @@ -0,0 +1,706 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from itertools import islice + +import torch +from torch import nn + +from vllm.attention.backends.abstract import AttentionType +from vllm_kunlun.ops.attention.layer import Attention +from vllm.config import ( + CacheConfig, + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.logger import init_logger +from vllm_kunlun.ops.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + RowParallelLinear, +) +from vllm_kunlun.ops.linear import QKVParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.sequence import IntermediateTensors + +from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsPP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from vllm_kunlun.ops.activation import SiluAndMul +logger = init_logger(__name__) + + +class MiMoV2MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MiMoV2MoE(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + is_nextn: bool = False, + ): + super().__init__() + + config = vllm_config.model_config.hf_text_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = get_ep_group().rank_in_group + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.n_routed_experts + + + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}." + ) + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) + + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.gate_dtype = torch.float32 + self.gate = nn.Linear( + config.hidden_size, + config.n_routed_experts, + bias=False, + dtype=self.gate_dtype, + ) + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=self.gate_dtype) + ) + + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=True, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + scoring_func="sigmoid", + ) + self.register_buffer("kunlun_linear_weights", torch.zeros( + config.num_local_experts,config.hidden_size,dtype=torch.float)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert hidden_states.dim() <= 2, "MiMoV2MoE only supports 1D or 2D inputs" + is_input_1d = hidden_states.dim() == 1 + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.gate_dtype is not None: + gate_input = hidden_states.to(self.gate_dtype) + else: + gate_input = hidden_states + router_logits = self.gate(gate_input) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits, linear_weights=self.gate.weight + ) + + return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states + + +class MiMoV2Attention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + v_head_dim: int | None = None, + sliding_window_size: int = -1, + attention_bias: bool = False, + add_swa_attention_sink_bias: bool = False, + layer_id: int = 0, 
+ rope_theta: float = 1000000, + max_position_embeddings: int = 32768, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + partial_rotary_factor: float = 1.0, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.layer_id = layer_id + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = num_heads + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.head_dim = head_dim + + self.v_head_dim = v_head_dim if v_head_dim is not None else head_dim + + self.q_size = self.num_heads * self.head_dim + self.k_size = self.num_kv_heads * self.head_dim + self.v_size = self.num_kv_heads * self.v_head_dim + + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + v_head_size=self.v_head_dim, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.v_head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + partial_rotary_factor=partial_rotary_factor + ) + + self.attention_sink_bias = ( + torch.nn.Parameter(torch.empty(self.num_heads), requires_grad=False) + if add_swa_attention_sink_bias + else None + ) + + sliding_window = sliding_window_size if sliding_window_size > -1 else None + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=AttentionType.DECODER, + prefix=f"{prefix}.attn", + sinks=self.attention_sink_bias, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + v = v.view(-1, self.num_kv_heads, self.v_head_dim) + v = torch.nn.functional.pad(v, [0, self.head_dim - self.v_head_dim], value=0) + v = v.view(-1, self.num_kv_heads * self.head_dim) + + attn_output = self.attn(q, k, v) + + attn_output = attn_output.view(-1, self.num_heads, self.head_dim)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_heads * self.v_head_dim) + + output, _ = self.o_proj(attn_output) + return output + + +class MiMoV2FlashDecoderLayer(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_text_config + quant_config = vllm_config.quant_config + layer_id = extract_layer_index(prefix) + + self.hidden_size = config.hidden_size + self.config = config + self.layer_id = layer_id + + rope_theta = getattr(config, "rope_theta", 1000000) + max_position_embeddings = getattr(config, "max_position_embeddings", 32768) + + if self.is_compressed_softmax_layer(): + self.self_attn = MiMoV2Attention( + hidden_size=self.hidden_size, + num_heads=config.swa_num_attention_heads, + num_kv_heads=config.swa_num_key_value_heads, + head_dim=config.swa_head_dim, + 
v_head_dim=getattr(config, "swa_v_head_dim", None), + sliding_window_size=config.sliding_window_size, + attention_bias=config.attention_bias, + add_swa_attention_sink_bias=getattr( + config, "add_swa_attention_sink_bias", False + ), + layer_id=layer_id, + rope_theta=getattr(config, "swa_rope_theta", rope_theta), + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0), + prefix=f"{prefix}.self_attn", + ) + else: + self.self_attn = MiMoV2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + v_head_dim=getattr(config, "v_head_dim", None), + sliding_window_size=-1, # normal attention + attention_bias=config.attention_bias, + layer_id=layer_id, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0), + prefix=f"{prefix}.self_attn", + ) + + self.is_layer_sparse = self.is_moe_layer(layer_id) + if self.is_layer_sparse: + self.mlp = MiMoV2MoE( + vllm_config=vllm_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = MiMoV2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.layernorm_epsilon + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + def is_moe_layer(self, layer_idx: int) -> bool: + return ( + hasattr(self.config, "moe_layer_freq") + and layer_idx >= 0 + and not isinstance(self.config.moe_layer_freq, int) + and self.config.moe_layer_freq[layer_idx] + ) + + def is_compressed_softmax_layer(self) -> bool: + return self.config.hybrid_layer_pattern[self.layer_id] == 1 + + +class MiMoV2Model(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config.get_text_config() + quant_config = vllm_config.quant_config + eplb_config = vllm_config.parallel_config.eplb_config + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + self.num_redundant_experts = eplb_config.num_redundant_experts + self.v_scale = getattr(config, "attention_value_scale", None) + + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiMoV2FlashDecoderLayer( + vllm_config=vllm_config, + prefix=prefix, + ), + 
prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon) + else: + self.norm = PPMissingLayer() + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + continue + if "mtp" in name: + continue + + if self.quant_config is not None: + cache_scale_name = self.quant_config.get_cache_scale(name) + if cache_scale_name is not None and cache_scale_name in params_dict: + param = params_dict[cache_scale_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + + kv_scale = loaded_weight + if kv_scale.dim() > 0 and kv_scale.numel() > 1: + kv_scale = kv_scale.view(-1)[0] + + weight_loader(param, kv_scale) + loaded_params.add(cache_scale_name) + continue + + expert_matched = False + for param_name, weight_name, expert_id, shard_id in expert_params_mapping: + if weight_name not in name: + continue + + name_rewritten = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_rewritten, self): + continue + + if ( + name_rewritten.endswith(".bias") or name_rewritten.endswith("_bias") + ) and name_rewritten not in params_dict: + continue + + if name_rewritten not in params_dict: + continue + + param = params_dict[name_rewritten] + 
weight_loader = param.weight_loader + + weight_loader( + param, + loaded_weight, + name_rewritten, + shard_id=shard_id, + expert_id=expert_id, + ) + loaded_params.add(name_rewritten) + expert_matched = True + break + + if expert_matched: + continue + + stacked_matched = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name_rewritten = name.replace(weight_name, param_name) + + if ( + name_rewritten.endswith(".bias") + and name_rewritten not in params_dict + ): + continue + + if is_pp_missing_parameter(name_rewritten, self): + continue + + if name_rewritten not in params_dict: + continue + + param = params_dict[name_rewritten] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + + if param_name == "qkv_proj" and shard_id == "v": + v_scale = ( + self.v_scale + if self.v_scale is not None + else getattr(self.config, "attention_value_scale", None) + ) + if v_scale is not None and ( + name.endswith("weight_scale_inv") or name.endswith(".bias") + ): + loaded_weight *= float(v_scale) + + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name_rewritten) + + stacked_matched = True + break + + if stacked_matched: + continue + + if name.endswith(".bias") and name not in params_dict: + continue + + orig_name = name + mapped_name = maybe_remap_kv_scale_name(name, params_dict) + name = mapped_name if mapped_name is not None else orig_name + + if name not in params_dict: + continue + + param = params_dict[name] + + if "attention_sink_bias" in name: + total_heads = loaded_weight.shape[0] + heads_per_rank = total_heads // tp_size + head_start = tp_rank * heads_per_rank + narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank) + + param.data.copy_(narrow_weight) + loaded_params.add(name) + else: + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.quant_config = quant_config + self.model = MiMoV2Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + 
hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm_kunlun/ops/__init__.py b/vllm_kunlun/ops/__init__.py index 042d1f0..b9f47fd 100644 --- a/vllm_kunlun/ops/__init__.py +++ b/vllm_kunlun/ops/__init__.py @@ -15,8 +15,8 @@ # This file is a part of the vllm-ascend project. # -# import vllm_kunlun.ops.linear import vllm_kunlun.ops.rotary_embedding import vllm_kunlun.ops.layernorm import vllm_kunlun.ops.quantization.awq -import vllm_kunlun.ops.quantization.gptq \ No newline at end of file +import vllm_kunlun.ops.quantization.gptq +import vllm_kunlun.ops.vocab_parallel_embedding \ No newline at end of file diff --git a/vllm_kunlun/ops/_kunlun_ops.py b/vllm_kunlun/ops/_kunlun_ops.py index 013f75f..325bef0 100644 --- a/vllm_kunlun/ops/_kunlun_ops.py +++ b/vllm_kunlun/ops/_kunlun_ops.py @@ -1,3 +1,20 @@ +# +# Copyright (c) 2025 Baidu, Inc. All Rights Reserved. +# +# This file is a part of the vllm-kunlun project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ """kunlun custom op entry""" import torch_xmlir import torch @@ -177,51 +194,10 @@ class KunlunOps: """ query_x = query.contiguous() key_x = key.contiguous() - query_x_dim = query_x.dim() - if not is_neox_style: - if cos_sin_cache.dtype == torch.float16: - cos_sin_cache = cos_sin_cache.to(torch.float32) - positions = positions.to(torch.int) - if positions.dim() == 1: - positions = positions.unsqueeze(0) - query_x = query_x.unsqueeze(0) - key_x = key_x.unsqueeze(0) - xtorch_ops.rotary_embedding_gptj( - positions, - query_x, - key_x, - head_size, - cos_sin_cache) - query.data = query_x - key.data = key_x - if query_x_dim != query_x.dim(): - query_x = query_x.unsqueeze(0) - key_x = key_x.unsqueeze(0) - return query, key - - # TODO: need opt - if cos_sin_cache.dim() == 4: - max_seq_len = cos_sin_cache.shape[2] - head_dim = cos_sin_cache.shape[3] - cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0) # 移除前两个维度 [1,1,L,D] -> [L,D] - cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim) - - # 重塑 query 和 key 的形状 num_tokens = query_x.shape[0] num_heads = query_x.shape[1] // head_size num_kv_heads = key_x.shape[1] // head_size - - # # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size] - # query_x = query_x.view(num_tokens, num_heads, head_size) - # # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size] - # key_x = key_x.view(num_tokens, num_kv_heads, head_size) - - # # 确保形状正确 - # assert query_x.shape == (num_tokens, num_heads, head_size), \ - # f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}" - # assert key_x.shape == (num_tokens, num_kv_heads, head_size), \ - # f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}" torch.ops._C.rotary_embedding( positions, @@ -234,8 +210,6 @@ class KunlunOps: query_x = query_x.view(num_tokens, num_heads * head_size) key_x = key_x.view(num_tokens, num_kv_heads * head_size) - # query.data = query_x - # key.data = key_x return query_x, key_x # Rotary embedding @@ -433,6 +407,121 @@ class KunlunOps: return out + def _dbg(x): + if torch.is_tensor(x): + return (type(x), x.device, x.dtype, x.shape, x.is_contiguous()) + return (type(x), x) + @staticmethod + def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + linear_weights: torch.Tensor, + moe_top_k: int, + renormalize: bool, + inplace: bool = False, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """fused_moe""" + global_num_experts = linear_weights.shape[0] + M, N = hidden_states.shape + hidden_dim = w2.shape[1] + normed_score = torch.empty(M, + moe_top_k, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + moe_top_k, + dtype=torch.int32, + device=hidden_states.device) + num_blocks = 12 + block_statistic = torch.zeros( + num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device + ) + + torch.ops._C.moe_sigmoid_group_topk_norm( + x=router_logits, + topk_index=topk_ids, + norm_score=normed_score, + block_static=block_statistic, + bias=e_score_correction_bias, + scale=1.0, + n_group=num_expert_group, + topk_group=1, + ) + + moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # 
[M*top_k, N], float + expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E] + sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1] + sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device) + + torch.ops._C.gen_block_statistic(topk_ids,block_statistic) + + torch.ops._C.moe_pre_sorted( + x=hidden_states, + topk_index=topk_ids, + block_statistic=block_statistic, + moe_expand=moe_expand, + moe_index=sorted_tokens_idx, + expert_m=expert_m, + sorted_tokens_num_lod=sorted_tokens_num_lod) + + y = torch.empty(M,moe_top_k, + w1.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device) + + moe_expand = moe_expand.view(M * moe_top_k, hidden_dim) + + torch.ops._C.moe_fc( + x=moe_expand, + weight=w1, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=moe_top_k, + y=y) + + d = y.shape[-1] // 2 + output_shape = (y.shape[:-1] + (d, )) + out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) + torch.ops._C.swiglu(y, out1) + + out = torch.empty(M,moe_top_k, + w2.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device) + + out1 = out1.reshape(-1, out1.shape[-1]) + + torch.ops._C.moe_fc( + x=out1, + weight=w2, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=moe_top_k, + y=out) + + dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device) + output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device) + sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k) + + torch.ops._C.moe_post( + x=out, + moe_index=sorted_tokens_idx, + normed_scale=normed_score, + dequant_scale=dequant_scale, + y=output + ) + + return output + @staticmethod def fused_moe_ep( hidden_states: torch.Tensor, @@ -487,42 +576,6 @@ class KunlunOps: return output - @staticmethod - def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - linear_weights: torch.Tensor, - topk: int, - renormalize: bool, - inplace: bool = False, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """fused_moe""" - output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, - device=hidden_states.device) - expert_num = linear_weights.shape[0] - - torch.ops._C.moe_ffn_block( - x=hidden_states, - gate_w=linear_weights, - inter_w=w1, - output_w=w2, - expert_num=expert_num, - moe_top_k=topk, - topk_group=topk_group, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - expert_group_num=num_expert_group, - out=output, - ) - return output - @staticmethod def fused_multi_head_latent_page_attention( hidden_states: torch.Tensor, diff --git a/vllm_kunlun/ops/fused_moe/layer.py b/vllm_kunlun/ops/fused_moe/layer.py index 9f01f70..b68a627 100644 --- a/vllm_kunlun/ops/fused_moe/layer.py +++ b/vllm_kunlun/ops/fused_moe/layer.py @@ -68,7 +68,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod): topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, - linear_weights=linear_weights) + linear_weights=linear_weights, + e_score_correction_bias=e_score_correction_bias) def forward_kunlun( self, @@ -81,7 +82,9 @@ class 
UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod): renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None ) -> torch.Tensor: """forward_kunlun""" from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops @@ -99,96 +102,6 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod): num_expert_group=num_expert_group, topk_group=topk_group ) - # fused_moe do not support expert number > 400 - elif layer.local_num_experts > 400: - hidden_states = x - global_num_experts = linear_weights.shape[0] - M, N = hidden_states.shape - hidden_dim = layer.w2_weight.shape[1] - normed_score = torch.empty(M, - top_k, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - top_k, - dtype=torch.int32, - device=hidden_states.device) - num_blocks = 12 - block_statistic = torch.zeros( - num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device - ) - - router_logits = router_logits.float() - torch.ops._C.moe_softmax_topk_norm( - x=router_logits, - normed_score=normed_score, - topk_index=topk_ids, - block_statistic=None, - stable=True) - - moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float - expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E] - sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1] - sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device) - - torch.ops._C.gen_block_statistic(topk_ids,block_statistic) - - torch.ops._C.moe_pre_sorted( - x=hidden_states, - topk_index=topk_ids, - block_statistic=block_statistic, - moe_expand=moe_expand, - moe_index=sorted_tokens_idx, - expert_m=expert_m, - sorted_tokens_num_lod=sorted_tokens_num_lod) - - y = torch.empty(M,top_k, - layer.w13_weight.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device) - - moe_expand = moe_expand.view(M * top_k, hidden_dim) - - torch.ops._C.moe_fc( - x=moe_expand, - weight=layer.w13_weight, - sorted_tokens_num_lod=sorted_tokens_num_lod, - sorted_tokens_idx=sorted_tokens_idx, - moe_topk=top_k, - y=y) - - d = y.shape[-1] // 2 - output_shape = (y.shape[:-1] + (d, )) - out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) - torch.ops._C.swiglu(y, out1) - - out = torch.empty(M,top_k, - layer.w2_weight.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device) - - out1 = out1.reshape(-1, out1.shape[-1]) - - torch.ops._C.moe_fc( - x=out1, - weight=layer.w2_weight, - sorted_tokens_num_lod=sorted_tokens_num_lod, - sorted_tokens_idx=sorted_tokens_idx, - moe_topk=top_k, - y=out) - - dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device) - output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device) - sorted_tokens_idx = sorted_tokens_idx.view(M, top_k) - - torch.ops._C.moe_post( - x=out, - moe_index=sorted_tokens_idx, - normed_scale=normed_score, - dequant_scale=dequant_scale, - y=output - ) - return output else: return ops.fused_moe(x, layer.w13_weight, @@ -200,7 +113,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod): inplace=True, use_grouped_topk=use_grouped_topk, num_expert_group=num_expert_group, - topk_group=topk_group + topk_group=topk_group, 
+ scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, ) class FusedMoE(VllmFusedMoE): diff --git a/vllm_kunlun/ops/layernorm.py b/vllm_kunlun/ops/layernorm.py index badaba1..85b8124 100644 --- a/vllm_kunlun/ops/layernorm.py +++ b/vllm_kunlun/ops/layernorm.py @@ -57,6 +57,8 @@ def vllm_kunlun_forward_cuda( ) return out +RMSNorm.forward_cuda = vllm_kunlun_forward_cuda +RMSNorm.forward = vllm_kunlun_forward_cuda class KunlunGemmaRMSNorm(OriGemmaRMSNorm): @staticmethod diff --git a/vllm_kunlun/ops/linear.py b/vllm_kunlun/ops/linear.py index db854b8..e057dd6 100644 --- a/vllm_kunlun/ops/linear.py +++ b/vllm_kunlun/ops/linear.py @@ -3,14 +3,30 @@ import torch import torch.nn as nn from torch.nn.parameter import Parameter - +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.linear import ( WEIGHT_LOADER_V2_SUPPORTED, ReplicatedLinear, UnquantizedLinearMethod, + ColumnParallelLinear ) from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.parameter import ModelWeightParameter +from vllm.model_executor.parameter import ( + BasevLLMParameter, + BlockQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter, +) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) from vllm.logger import init_logger logger = init_logger(__name__) @@ -59,4 +75,361 @@ def create_weights( # rewrite create_weights and remove weight_loader_v2 to suport cuda graph UnquantizedLinearMethod.create_weights = create_weights -WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod") \ No newline at end of file +WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod") + +class QKVParallelLinear(ColumnParallelLinear): + """ + Base on v0.11.0 QKVParallelLinear, And add v_head size for swa (MIMO V2) + """ + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: int | None = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + v_head_size: int | None = None, + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.v_head_size = v_head_size if v_head_size is not None else head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. 
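# Illustrative sizing example (numbers assumed, not taken from any config): with
# total_num_heads=32, total_num_kv_heads=4, head_size=128, v_head_size=64 and
# tp_size=4, each rank gets num_heads=8 and num_kv_heads=1, so the per-rank
# output width is 8*128 (q) + 1*128 (k) + 1*64 (v) = 1216 columns, and
# output_size below is that value times tp_size. Only the v block is sized with
# v_head_size; q and k keep head_size.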
+ tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = ( + self.num_heads * self.head_size + + self.num_kv_heads * self.head_size + + self.num_kv_heads * self.v_head_size + ) * tp_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.v_head_size * tp_size, # v_proj + ] + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) + + def _get_shard_offset_mapping(self, loaded_shard_id: str): + shard_offset_mapping = { + "q": 0, + "k": self.num_heads * self.head_size, + "v": (self.num_heads + self.num_kv_heads) * self.head_size, + "total": (self.num_heads + self.num_kv_heads) * self.head_size + + self.num_kv_heads * self.v_head_size, + } + return shard_offset_mapping.get(loaded_shard_id) + + def _get_shard_size_mapping(self, loaded_shard_id: str): + shard_size_mapping = { + "q": self.num_heads * self.head_size, + "k": self.num_kv_heads * self.head_size, + "v": self.num_kv_heads * self.v_head_size, + } + return shard_size_mapping.get(loaded_shard_id) + + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ): + """ + Handle special case for models where QKV layers are already + fused on disk. In this case, we have no shard id. This function + determines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.v_head_size, + ), + ] + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. 
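# Illustrative note (example values assumed): for a parameter packed along the
# output dim, e.g. 4-bit values packed eight to an int32 (packed_factor == 8), a
# logical shard at offset 4096 with size 512 becomes offset 512 / size 64 in the
# packed tensor, which is the kind of adjustment
# adjust_shard_indexes_for_packing performs below.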
+ if ( + isinstance(param, (PackedColumnParameter, PackedvLLMParameter)) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: str | None = None, + ): + if loaded_shard_id is None: # special case for certain models + if isinstance(param, PerTensorScaleParameter): + param.load_qkv_weight( + loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank + ) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id in ["q", "k", "v"] + + shard_offset = self._get_shard_offset_mapping(loaded_shard_id) + shard_size = self._get_shard_size_mapping(loaded_shard_id) + + # Note(simon): This is needed for Qwen3's fp8 quantization. + if isinstance(param, BlockQuantScaleParameter): + assert self.quant_method is not None + # Assume the weight block size has been set by quant method + assert hasattr(self, "weight_block_size") + weight_block_size = self.weight_block_size + assert weight_block_size is not None + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = (shard_offset + block_n - 1) // block_n + shard_size = (shard_size + block_n - 1) // block_n + + param.load_qkv_weight( + loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + tp_rank=self.tp_rank, + ) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: str | None = None, + ): + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + idx_map = {"q": 0, "k": 1, "v": 2} + if loaded_shard_id is not None: + param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + else: + param.shard_weight_type = {k: loaded_weight.item() for k in idx_map} + return + + if is_gguf_weight: + output_dim = getattr(param, "output_dim", None) + shard_size = loaded_weight.size(output_dim) // self.tp_size + start_idx = self.tp_rank * shard_size + + if loaded_shard_id is not None: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + return + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + + # Special case for per-tensor scales in fused case. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv). + # (e.g., Phi-3's qkv_proj). 
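# Illustrative note (layout implied by the shard_offsets list below): a fused
# qkv checkpoint tensor is split as [q | k | v] along output_dim, where q spans
# total_num_heads*head_size, k spans total_num_kv_heads*head_size, and v spans
# total_num_kv_heads*v_head_size, so a narrower value head only shrinks the
# trailing v block.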
+ if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0 + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.v_head_size, + ), + ] + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + if use_bitsandbytes_4bit: + orig_qkv_offsets = { + "q": (0, self.total_num_heads * self.head_size), + "k": ( + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + "v": ( + (self.total_num_heads + self.total_num_kv_heads) + * self.head_size, + self.total_num_kv_heads * self.v_head_size, + ), + "total": ( + (self.total_num_heads + self.total_num_kv_heads) + * self.head_size + + self.total_num_kv_heads * self.v_head_size, + 0, + ), + } + + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_qkv_offsets, shard_id + ) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. + if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.v_head_size + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor + + # Special case for Marlin. 
+ shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + if use_bitsandbytes_4bit: + orig_qkv_offsets = { + "q": (0, self.num_heads * self.head_size), + "k": ( + self.num_heads * self.head_size, + self.num_kv_heads * self.head_size, + ), + "v": ( + (self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.v_head_size, + ), + "total": ( + (self.num_heads + self.num_kv_heads) * self.head_size + + self.num_kv_heads * self.v_head_size, + 0, + ), + } + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_qkv_offsets, loaded_shard_id + ) + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + if loaded_shard_id == "q": + shard_rank = self.tp_rank + else: + shard_rank = self.tp_rank // self.num_kv_head_replicas + start_idx = shard_rank * shard_size + + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions." + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + diff --git a/vllm_kunlun/ops/paged_attn.py b/vllm_kunlun/ops/paged_attn.py index b2bcd14..9f2efb2 100644 --- a/vllm_kunlun/ops/paged_attn.py +++ b/vllm_kunlun/ops/paged_attn.py @@ -8,14 +8,8 @@ from typing import List, Optional, Tuple from vllm.platforms import current_platform -if current_platform.is_kunlun(): - from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops -else: - from vllm import _custom_ops as ops - from vllm.triton_utils.importing import HAS_TRITON +from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops - if HAS_TRITON: - from vllm.attention.ops.prefix_prefill import context_attention_fwd # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
_PARTITION_SIZE = 512 diff --git a/vllm_kunlun/ops/rotary_embedding.py b/vllm_kunlun/ops/rotary_embedding.py index 2d65f3d..a151568 100644 --- a/vllm_kunlun/ops/rotary_embedding.py +++ b/vllm_kunlun/ops/rotary_embedding.py @@ -70,7 +70,7 @@ def vllm_kunlun_forward_cuda( self.is_neox_style, self.rotary_dim, offsets) else: - ops.rotary_embedding(positions, query, key, self.head_size, + query, key = ops.rotary_embedding(positions, query, key, self.head_size, self.cos_sin_cache, self.is_neox_style) return query, key @@ -143,14 +143,11 @@ def vllm_kunlun_mrope_forward_cuda( return query, key -# RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda -# RotaryEmbedding.forward = vllm_kunlun_forward_cuda -# RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache +RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda +RotaryEmbedding.forward = vllm_kunlun_forward_cuda MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda -# MRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq -# YaRNScalingRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache def Split_Norm_Rope( diff --git a/vllm_kunlun/ops/vocab_parallel_embedding.py b/vllm_kunlun/ops/vocab_parallel_embedding.py index 5d1b37a..e3d5a8a 100644 --- a/vllm_kunlun/ops/vocab_parallel_embedding.py +++ b/vllm_kunlun/ops/vocab_parallel_embedding.py @@ -1,143 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Optional - import torch -import torch.nn.functional as F -from torch.nn.parameter import Parameter, UninitializedParameter - +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) -from vllm.model_executor.layers.utils import dispatch_unquantized_gemm -from vllm.model_executor.parameter import BasevLLMParameter -from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform - -DEFAULT_VOCAB_PADDING_SIZE = 64 - - -class UnquantizedEmbeddingMethod(QuantizeMethodBase): - """Unquantized method for embeddings.""" - - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - """Create weights for embedding layer.""" - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, extra_weight_attrs) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) - - def embedding(self, layer: torch.nn.Module, - input_: torch.Tensor) -> torch.Tensor: - return F.embedding(input_, layer.weight) - - -def 
pad_vocab_size(vocab_size: int, - pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: - """Pad the vocab size to the given value.""" - return ((vocab_size + pad_to - 1) // pad_to) * pad_to - - -def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, - rank: int, - offset: int = 0) -> Sequence[int]: - index_f = rank * per_partition_vocab_size - index_l = index_f + per_partition_vocab_size - return index_f + offset, index_l + offset - - -def vocab_range_from_global_vocab_size(global_vocab_size: int, - rank: int, - world_size: int, - offset: int = 0) -> Sequence[int]: - per_partition_vocab_size = divide(global_vocab_size, world_size) - return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, - rank, - offset=offset) - - -@dataclass -class VocabParallelEmbeddingShardIndices: - """Indices for a shard of a vocab parallel embedding.""" - padded_org_vocab_start_index: int - padded_org_vocab_end_index: int - padded_added_vocab_start_index: int - padded_added_vocab_end_index: int - - org_vocab_start_index: int - org_vocab_end_index: int - added_vocab_start_index: int - added_vocab_end_index: int - - @property - def num_org_elements(self) -> int: - return self.org_vocab_end_index - self.org_vocab_start_index - - @property - def num_added_elements(self) -> int: - return self.added_vocab_end_index - self.added_vocab_start_index - - @property - def num_org_elements_padded(self) -> int: - return (self.padded_org_vocab_end_index - - self.padded_org_vocab_start_index) - - @property - def num_added_elements_padded(self) -> int: - return (self.padded_added_vocab_end_index - - self.padded_added_vocab_start_index) - - @property - def num_org_vocab_padding(self) -> int: - return self.num_org_elements_padded - self.num_org_elements - - @property - def num_added_vocab_padding(self) -> int: - return self.num_added_elements_padded - self.num_added_elements - - @property - def num_elements_padded(self) -> int: - return self.num_org_elements_padded + self.num_added_elements_padded - - def __post_init__(self): - # sanity checks - assert (self.padded_org_vocab_start_index - <= self.padded_org_vocab_end_index) - assert (self.padded_added_vocab_start_index - <= self.padded_added_vocab_end_index) - - assert self.org_vocab_start_index <= self.org_vocab_end_index - assert self.added_vocab_start_index <= self.added_vocab_end_index - - assert self.org_vocab_start_index <= self.padded_org_vocab_start_index - assert (self.added_vocab_start_index - <= self.padded_added_vocab_start_index) - assert self.org_vocab_end_index <= self.padded_org_vocab_end_index - assert self.added_vocab_end_index <= self.padded_added_vocab_end_index - - assert self.num_org_elements <= self.num_org_elements_padded - assert self.num_added_elements <= self.num_added_elements_padded - @torch.compile(dynamic=True, backend="aot_eager") def get_masked_input_and_mask( @@ -159,319 +27,25 @@ def get_masked_input_and_mask( input_ = vocab_mask * (input_ - valid_offset) return input_, ~vocab_mask +def forward_native_kunlun(self, input_): + if self.tp_size > 1: + # Build the mask. + masked_input, input_mask = get_masked_input_and_mask( + input_, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) + else: + masked_input = input_ + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, + masked_input.long()) + # Mask the output embedding. 
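# Illustrative example (toy numbers assumed): with tp_size=2 and a padded vocab
# of 8, rank 0 owns ids [0, 4). get_masked_input_and_mask maps an out-of-shard
# id such as 5 to local index 0 and flags it in the returned mask, the
# masked_fill_ below zeroes that row, and the all-reduce then sums the per-rank
# partial embeddings so every rank ends up with the full lookup result.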
+ if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel) + return output -@CustomOp.register("vllm_kunlun_vocab_parallel_embedding") -class VocabParallelEmbedding(CustomOp): - """Embedding parallelized in the vocabulary dimension. - - Adapted from torch.nn.Embedding, note that we pad the vocabulary size to - make sure it is divisible by the number of model parallel GPUs. - - In order to support various loading methods, we ensure that LoRA-added - embeddings are always at the end of TP-sharded tensors. In other words, - we shard base embeddings and LoRA embeddings separately (both padded), - and place them in the same tensor. - In this example, we will have the original vocab size = 1010, - added vocab size = 16 and padding to 64. Therefore, the total - vocab size with padding will be 1088 (because we first pad 1010 to - 1024, add 16, and then pad to 1088). - Therefore, the tensor format looks like the following: - TP1, rank 0 (no sharding): - |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >| - corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 | - index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 | - - TP2, rank 0: - |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >| - corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 | - index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 | - TP2, rank 1: - |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >| - corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 | - index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 | - - Args: - num_embeddings: vocabulary size. - embedding_dim: size of hidden state. - params_dtype: type of the parameters. - org_num_embeddings: original vocabulary size (without LoRA). - padding_size: padding size for the vocabulary. - quant_config: quant config for the layer - prefix: full name of the layer in the state dict - """ # noqa: E501 - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - super().__init__() - - # Keep the input dimensions. 
- tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_embeddings = num_embeddings - self.padding_size = padding_size - self.org_vocab_size = org_num_embeddings or num_embeddings - num_added_embeddings = num_embeddings - self.org_vocab_size - self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, - self.padding_size) - self.num_embeddings_padded = pad_vocab_size( - self.org_vocab_size_padded + num_added_embeddings, - self.padding_size) - assert self.org_vocab_size_padded <= self.num_embeddings_padded - - self.shard_indices = self._get_indices(self.num_embeddings_padded, - self.org_vocab_size_padded, - self.num_embeddings, - self.org_vocab_size, tp_rank, - self.tp_size) - self.embedding_dim = embedding_dim - - quant_method = None - if quant_config is not None: - quant_method = quant_config.get_quant_method(self, prefix=prefix) - if quant_method is None: - quant_method = UnquantizedEmbeddingMethod() - - # If we are making an embedding layer, then our quantization linear - # method must implement the embedding operation. If we are another - # layer type like ParallelLMHead, this is not important. - is_embedding_layer = type(self) is VocabParallelEmbedding - quant_method_implements_embedding = method_has_implemented_embedding( - type(quant_method)) - if is_embedding_layer and not quant_method_implements_embedding: - raise NotImplementedError( - f"The class {type(quant_method).__name__} must implement " - "the 'embedding' method, see UnquantizedEmbeddingMethod.") - - self.quant_method: QuantizeMethodBase = quant_method - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - # Divide the weight matrix along the vocaburaly dimension. - self.num_added_embeddings = self.num_embeddings - self.org_vocab_size - self.num_embeddings_per_partition = divide(self.num_embeddings_padded, - self.tp_size) - assert (self.shard_indices.num_elements_padded == - self.num_embeddings_per_partition) - self.num_org_embeddings_per_partition = ( - self.shard_indices.org_vocab_end_index - - self.shard_indices.org_vocab_start_index) - self.num_added_embeddings_per_partition = ( - self.shard_indices.added_vocab_end_index - - self.shard_indices.added_vocab_start_index) - - self.quant_method.create_weights(self, - self.embedding_dim, - [self.num_embeddings_per_partition], - self.embedding_dim, - self.num_embeddings_padded, - params_dtype=params_dtype, - weight_loader=self.weight_loader) - - @classmethod - def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, - vocab_size: int, org_vocab_size: int, tp_rank: int, - tp_size: int) -> VocabParallelEmbeddingShardIndices: - """Get start and end indices for vocab parallel embedding, following the - layout outlined in the class docstring, based on the given tp_rank and - tp_size.""" - num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded - padded_org_vocab_start_index, padded_org_vocab_end_index = ( - vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, - tp_size)) - padded_added_vocab_start_index, padded_added_vocab_end_index = ( - vocab_range_from_global_vocab_size(num_added_embeddings_padded, - tp_rank, - tp_size, - offset=org_vocab_size)) - # remove padding - org_vocab_start_index = min(padded_org_vocab_start_index, - org_vocab_size) - org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) - added_vocab_start_index = min(padded_added_vocab_start_index, - vocab_size) - added_vocab_end_index = min(padded_added_vocab_end_index, 
vocab_size) - return VocabParallelEmbeddingShardIndices( - padded_org_vocab_start_index, padded_org_vocab_end_index, - padded_added_vocab_start_index, padded_added_vocab_end_index, - org_vocab_start_index, org_vocab_end_index, - added_vocab_start_index, added_vocab_end_index) - - def get_sharded_to_full_mapping(self) -> Optional[list[int]]: - """Get a mapping that can be used to reindex the gathered - logits for sampling. - - During sampling, we gather logits from all ranks. The relationship - of index->token_id will follow the same format as outlined in the class - docstring. However, after the gather, we want to reindex the final - logits tensor to map index->token_id one-to-one (the index is always - equal the token_id it corresponds to). The indices returned by this - method allow us to do that. - """ - if self.tp_size < 2: - return None - - base_embeddings: list[int] = [] - added_embeddings: list[int] = [] - padding: list[int] = [] - for tp_rank in range(self.tp_size): - shard_indices = self._get_indices(self.num_embeddings_padded, - self.org_vocab_size_padded, - self.num_embeddings, - self.org_vocab_size, tp_rank, - self.tp_size) - range_start = self.num_embeddings_per_partition * tp_rank - range_end = self.num_embeddings_per_partition * (tp_rank + 1) - base_embeddings.extend( - range(range_start, - range_start + shard_indices.num_org_elements)) - padding.extend( - range(range_start + shard_indices.num_org_elements, - range_start + shard_indices.num_org_elements_padded)) - added_embeddings.extend( - range( - range_start + shard_indices.num_org_elements_padded, - range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements)) - padding.extend( - range( - range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements, - range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements_padded)) - assert (range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements_padded == range_end) - ret = base_embeddings + added_embeddings + padding - assert len(ret) == self.num_embeddings_padded - return ret - - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - output_dim = getattr(param, "output_dim", None) - packed_dim = getattr(param, "packed_dim", None) - - # If the parameter is a gguf weight, then load it directly. - if getattr(param, "is_gguf_weight_type", None): - param.data.copy_(loaded_weight) - param.weight_type = loaded_weight.item() - return - elif isinstance(param, UninitializedParameter): - shape = list(loaded_weight.shape) - if output_dim is not None: - shape[output_dim] = self.num_embeddings_per_partition - param.materialize(tuple(shape), dtype=loaded_weight.dtype) - - # If parameter does not have output dim, then it should - # be copied onto all gpus (e.g. g_idx for act_order gptq). - if output_dim is None: - assert param.data.shape == loaded_weight.shape - param.data.copy_(loaded_weight) - return - - # Shard indexes for loading the weight - start_idx = self.shard_indices.org_vocab_start_index - shard_size = self.shard_indices.org_vocab_end_index - start_idx - - # If param packed on the same dim we are sharding on, then - # need to adjust offsets of loaded weight by pack_factor. 
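# -----------------------------------------------------------------------------
# Illustration only: how the list returned by get_sharded_to_full_mapping()
# is meant to be used downstream.  After gathering per-rank logits, indexing
# along the vocab dimension with the mapping restores the index == token_id
# ordering described in the docstring.  The shapes and the identity `mapping`
# stand-in below are hypothetical; the real mapping interleaves base, added,
# and padding slots per rank.
# -----------------------------------------------------------------------------
import torch

num_tokens, padded_vocab = 4, 1088
gathered_logits = torch.randn(num_tokens, padded_vocab)   # concat of TP shards
mapping = list(range(padded_vocab))                        # stand-in mapping

reordered = gathered_logits.index_select(
    1, torch.tensor(mapping, dtype=torch.long))
logits = reordered[:, :1026]   # keep org vocab (1010) + added vocab (16)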
- if packed_dim is not None and packed_dim == output_dim: - packed_factor = param.packed_factor if isinstance( - param, BasevLLMParameter) else param.pack_factor - assert loaded_weight.shape[output_dim] == (self.org_vocab_size // - param.packed_factor) - start_idx = start_idx // packed_factor - shard_size = shard_size // packed_factor - else: - assert loaded_weight.shape[output_dim] == self.org_vocab_size - - # Copy the data. Select chunk corresponding to current shard. - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - param[:loaded_weight.shape[0]].data.copy_(loaded_weight) - param[loaded_weight.shape[0]:].data.fill_(0) - - def forward(self, input_): - if self.tp_size > 1: - # Build the mask. - masked_input, input_mask = get_masked_input_and_mask( - input_, self.shard_indices.org_vocab_start_index, - self.shard_indices.org_vocab_end_index, - self.shard_indices.num_org_vocab_padding, - self.shard_indices.added_vocab_start_index, - self.shard_indices.added_vocab_end_index) - else: - masked_input = input_ - # Get the embeddings. - output_parallel = self.quant_method.embedding(self, - masked_input.long()) - # Mask the output embedding. - if self.tp_size > 1: - output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) - # Reduce across all the model parallel GPUs. - output = tensor_model_parallel_all_reduce(output_parallel) - return output - - def extra_repr(self) -> str: - s = f"num_embeddings={self.num_embeddings_per_partition}" - s += f", embedding_dim={self.embedding_dim}" - s += f", org_vocab_size={self.org_vocab_size}" - s += f', num_embeddings_padded={self.num_embeddings_padded}' - s += f', tp_size={self.tp_size}' - return s - - -class ParallelLMHead(VocabParallelEmbedding): - """Parallelized LM head. - - Output logits weight matrices used in the Sampler. The weight and bias - tensors are padded to make sure they are divisible by the number of - model parallel GPUs. - - Args: - num_embeddings: vocabulary size. - embedding_dim: size of hidden state. - bias: whether to use bias. - params_dtype: type of the parameters. - org_num_embeddings: original vocabulary size (without LoRA). - padding_size: padding size for the vocabulary. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - bias: bool = False, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings, padding_size, quant_config, - prefix) - self.quant_config = quant_config - if bias: - self.bias = Parameter( - torch.empty(self.num_embeddings_per_partition, - dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) - else: - self.register_parameter("bias", None) - - def tie_weights(self, embed_tokens: VocabParallelEmbedding): - """Tie the weights with word embeddings.""" - # GGUF quantized embed_tokens. 
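# -----------------------------------------------------------------------------
# Illustration only: the shard copy performed by weight_loader() above.  A
# full-vocab weight is narrowed to this rank's [start_idx, start_idx + shard)
# slice, copied into the (padded) local parameter, and the padded tail is
# zero-filled.  All sizes here are made-up toy values.
# -----------------------------------------------------------------------------
import torch

org_vocab_size, embedding_dim = 10, 4
start_idx, shard_size = 5, 5                  # this rank owns token ids 5..9
num_rows_per_partition = 8                    # padded local rows (> shard_size)

loaded_weight = torch.arange(org_vocab_size * embedding_dim,
                             dtype=torch.float32).view(org_vocab_size, -1)
param = torch.empty(num_rows_per_partition, embedding_dim)

shard = loaded_weight.narrow(0, start_idx, shard_size)   # rows 5..9
param[:shard_size].copy_(shard)
param[shard_size:].zero_()                               # padding rows stay zero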
- if self.quant_config and self.quant_config.get_name() == "gguf": - return embed_tokens - else: - self.weight = embed_tokens.weight - return self - - def forward(self, input_): - del input_ - raise RuntimeError("LMHead's weights should be used in the sampler.") \ No newline at end of file +VocabParallelEmbedding.forward_native = forward_native_kunlun \ No newline at end of file diff --git a/vllm_kunlun/v1/attention/backends/kunlun_attn.py b/vllm_kunlun/v1/attention/backends/kunlun_attn.py index 4f2555e..0c2d050 100644 --- a/vllm_kunlun/v1/attention/backends/kunlun_attn.py +++ b/vllm_kunlun/v1/attention/backends/kunlun_attn.py @@ -148,7 +148,6 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] = None - # Prefix cache loc kv_lod_cpu: Optional[torch.Tensor] = None kv_lod_xpu: Optional[torch.Tensor] = None @@ -563,9 +562,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): if blocksparse_params is not None: raise ValueError( "kunlunAttention does not support block-sparse attention.") - # if logits_soft_cap is not None: - # raise ValueError( - # "kunlunAttention does not support attention logits soft capping.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -673,51 +669,84 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens] + prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens] + prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens] + # For hybrid Attention (Qwen3-Next.) if key_cache.is_contiguous(): tmp_block_tables = prefill_meta.block_tables else: - tmp_block_tables = prefill_meta.block_tables * 2 # only test in Qwen3-Next - - xtorch_ops.prefill_attention( - q=prefill_query, - k=key_cache, # Key Cache (block_num, head, block_size, dim) - v=value_cache, - out=output[num_decode_tokens:attn_metadata.num_actual_tokens], - is_causal=True, - is_prefix_cache=True, - block_table=tmp_block_tables, - context_qlen_lod_cpu=prefill_meta.query_start_loc_host, - context_qlen_lod_xpu=prefill_meta.query_start_loc, - context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu, - context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu, - alibi_slopes=self.alibi_slopes, - softmax_lse=None, - sink=self.sinks - ) + # For hybrid Attention (Qwen3-Next) + tmp_block_tables = prefill_meta.block_tables * 2 + + # Prefix cache + if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]: + xtorch_ops.prefill_attention( + q=prefill_query, + k=key_cache, # Key Cache [block_num, head, block_size, dim] + v=value_cache, + out=output[num_decode_tokens:attn_metadata.num_actual_tokens], + is_causal=True, + is_prefix_cache=True, + block_table=tmp_block_tables, + context_qlen_lod_cpu=prefill_meta.query_start_loc_host, + context_qlen_lod_xpu=prefill_meta.query_start_loc, + context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu, + context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu, + alibi_slopes=self.alibi_slopes, + softmax_lse=None + ) + else: + xtorch_ops.prefill_attention( + q=prefill_query, + k=prefill_key, + v=prefill_value, + out=output[num_decode_tokens:attn_metadata.num_actual_tokens], + is_causal=True, + context_qlen_lod_cpu=prefill_meta.query_start_loc_host, + context_qlen_lod_xpu=prefill_meta.query_start_loc, + alibi_slopes=self.alibi_slopes, + softmax_lse=None, + swa_left = self.sliding_window if self.sliding_window is not None else -1, + 
swa_right = 0 if self.sliding_window is not None else -1, + sink = self.sinks.to(torch.float32) if self.sinks is not None else None + ) - if decode_meta := attn_metadata.decode_metadata: + + if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( "Encoder-only models should not have decode metadata.") decode_query = query[:num_decode_tokens] + # For hybrid Attention (Qwen3-Next if key_cache.is_contiguous(): tmp_block_tables = decode_meta.block_tables else: tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next - - xtorch_ops.paged_attention( - x=decode_query, - k_cache=key_cache, - v_cache=value_cache, - block_tables=tmp_block_tables, - context_lens_cpu=decode_meta.seq_lens_tensor_cpu, - context_lens_xpu=decode_meta.seq_lens_tensor, - is_context=False, - is_causal=True, + + xtorch_ops.speculative_attention( out=output[:num_decode_tokens], - vo_head_dim=self.head_size - ) + # Only MLA support q len > 1 right now + q=decode_query.unsqueeze(0), + k_cache=key_cache, + v_cache=value_cache, + context_lens_cpu=decode_meta.seq_lens_tensor_cpu, + context_lens_xpu=decode_meta.seq_lens_tensor, + batch_num=decode_meta.block_tables.shape[0], + # TODO (@xyDong23): Support MTP(q lens >1) + qlen=1, + # TODO (@xyDong23): Support max_context_len to (262144) + max_context_len=131072, + head_num=self.num_heads, + head_dim=self.head_size, + scale=0.0, + kv_head_num=self.num_kv_heads, + block_size=key_cache.shape[2], + max_num_blocks_per_seq=decode_meta.block_tables.shape[1], + max_window_size=self.sliding_window if self.sliding_window is not None else -1, + block_tables=tmp_block_tables, + sink = self.sinks.to(torch.float32) if self.sinks is not None else None + ) # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) def use_cascade_attention( @@ -788,4 +817,4 @@ def use_cascade_attention( flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) # Use cascade attention if it is faster than FlashDecoding. 
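# -----------------------------------------------------------------------------
# Illustration only: the dispatch rule used by the prefill branch above.  When
# the cumulative query length equals the cumulative KV length there is no
# cached prefix, so the plain causal kernel on (q, k, v) is sufficient;
# otherwise the prefix-cache kernel that reads K/V from the paged cache is
# required.  `q_lod` / `kv_lod` are hypothetical cumulative-length tensors in
# the same [0, len0, len0+len1, ...] layout as query_start_loc / kv_lod.
# -----------------------------------------------------------------------------
import torch

def needs_prefix_cache_kernel(q_lod: torch.Tensor, kv_lod: torch.Tensor) -> bool:
    # A larger KV total means some tokens are served from the prefix cache.
    return bool(q_lod[-1] != kv_lod[-1])

q_lod = torch.tensor([0, 4, 10])
kv_lod = torch.tensor([0, 9, 25])        # KV includes previously cached tokens
assert needs_prefix_cache_kernel(q_lod, kv_lod)        # cached prefix present
assert not needs_prefix_cache_kernel(q_lod, q_lod)     # pure prefill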
- return cascade_time < flash_decoding_time + return cascade_time < flash_decoding_time \ No newline at end of file diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index a799a82..e0729a9 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -938,6 +938,83 @@ def _fake_rotary_embedding( rotary_embedding.register_fake(_fake_rotary_embedding) +@custom_op("_C::quant2d", mutates_args=()) +def quant2d( + x: torch.Tensor, + y: torch.Tensor, + max: torch.Tensor, + force_sdnn: bool, +) -> None: + xtorch_ops.quant2d( + x=x, + y=y, + max=max, + force_sdnn=force_sdnn + ) + +@impl("_C::quant2d", "CUDA") +def quant2d_cuda( + x: torch.Tensor, + y: torch.Tensor, + max: torch.Tensor, + force_sdnn: bool, +) -> None: + xtorch_ops.quant2d( + x=x, + y=y, + max=max, + force_sdnn=force_sdnn + ) + +def _fake_quant2d( + x: torch.Tensor, + y: torch.Tensor, + max: torch.Tensor, + force_sdnn: bool, +) -> None: + return None + +quant2d.register_fake(_fake_quant2d) + +@custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=()) +def gemm_I8_I8_bf16_nt( + x_q: torch.Tensor, + x_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + xtorch_ops.gemm_I8_I8_bf16_nt( + lhs=(x_q, x_scale), + rhs=(weight, weight_scale), + out=out + ) + +@impl("_C::gemm_I8_I8_bf16_nt", "CUDA") +def gemm_I8_I8_bf16_nt_cuda( + x_q: torch.Tensor, + x_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + xtorch_ops.gemm_I8_I8_bf16_nt( + lhs=(x_q, x_scale), + rhs=(weight, weight_scale), + out=out + ) + +def _fake_gemm_I8_I8_bf16_nt( + x_q: torch.Tensor, + x_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + return None + +gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt) + @custom_op("_C::moe_softmax_topk_norm", mutates_args=()) def moe_softmax_topk_norm( x: torch.Tensor, @@ -1068,15 +1145,39 @@ def moe_fc( sorted_tokens_num_lod: torch.Tensor, sorted_tokens_idx: torch.Tensor, moe_topk: int, - y: torch.Tensor + y: torch.Tensor, + act: Optional[str] = None, + x_perchannel_max: Optional[torch.Tensor] = None, + w_perchannel_max: Optional[torch.Tensor] = None , + topk_ids: Optional[torch.Tensor] = None, + topk_w: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + tgemm_type: Optional[str] = None, + tweight_type: Optional[str] = None, + scale_n: Optional[int] = 0, + scale_k: Optional[int] = 0, + use_pack_int4: Optional[bool] = False, + sort_mode: Optional[bool] = True )-> None: xtorch_ops.moe_fc( - x, - weight, - sorted_tokens_num_lod, - sorted_tokens_idx, - moe_topk, - y) + x=x, + weight=weight, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=moe_topk, + y=y, + act=act, + x_perchannel_max=x_perchannel_max, + w_perchannel_max=w_perchannel_max, + topk_ids=topk_ids, + topk_w=topk_w, + bias=bias, + tgemm_type=tgemm_type, + tweight_type=tweight_type, + scale_n=scale_n, + scale_k=scale_k, + use_pack_int4=use_pack_int4, + sort_mode=sort_mode) @impl("_C::moe_fc", "CUDA") def moe_fc_cuda( @@ -1085,15 +1186,39 @@ def moe_fc_cuda( sorted_tokens_num_lod: torch.Tensor, sorted_tokens_idx: torch.Tensor, moe_topk: int, - y: torch.Tensor + y: torch.Tensor, + act: Optional[str] = None, + x_perchannel_max: Optional[torch.Tensor] = None, + w_perchannel_max: Optional[torch.Tensor] = None , + topk_ids: Optional[torch.Tensor] = None, + topk_w: Optional[torch.Tensor] = 
None, + bias: Optional[torch.Tensor] = None, + tgemm_type: Optional[str] = None, + tweight_type: Optional[str] = None, + scale_n: Optional[int] = 0, + scale_k: Optional[int] = 0, + use_pack_int4: Optional[bool] = False, + sort_mode: Optional[bool] = True )-> None: xtorch_ops.moe_fc( - x, - weight, - sorted_tokens_num_lod, - sorted_tokens_idx, - moe_topk, - y) + x=x, + weight=weight, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=moe_topk, + y=y, + act=act, + x_perchannel_max=x_perchannel_max, + w_perchannel_max=w_perchannel_max, + topk_ids=topk_ids, + topk_w=topk_w, + bias=bias, + tgemm_type=tgemm_type, + tweight_type=tweight_type, + scale_n=scale_n, + scale_k=scale_k, + use_pack_int4=use_pack_int4, + sort_mode=sort_mode) def fake_moe_fc( x: torch.Tensor, @@ -1101,7 +1226,19 @@ def fake_moe_fc( sorted_tokens_num_lod: torch.Tensor, sorted_tokens_idx: torch.Tensor, moe_topk: int, - y: torch.Tensor + y: torch.Tensor, + act: Optional[str] = None, + x_perchannel_max: Optional[torch.Tensor] = None, + w_perchannel_max: Optional[torch.Tensor] = None , + topk_ids: Optional[torch.Tensor] = None, + topk_w: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + tgemm_type: Optional[str] = None, + tweight_type: Optional[str] = None, + scale_n: Optional[int] = 0, + scale_k: Optional[int] = 0, + use_pack_int4: Optional[bool] = False, + sort_mode: Optional[bool] = True )-> None: return None @@ -1151,6 +1288,63 @@ def fake_moe_post( moe_post.register_fake(fake_moe_post) +@custom_op("_C::moe_sigmoid_group_topk_norm", mutates_args=()) +def moe_sigmoid_group_topk_norm( + x: torch.Tensor, + topk_index: torch.Tensor, + norm_score: torch.Tensor, + block_static: torch.Tensor, + bias: torch.Tensor, + scale: float, + n_group: int, + topk_group: int +) -> None: + xtorch_ops.moe_sigmoid_group_topk_norm( + x=x, + norm_score=norm_score, + topk_index=topk_index, + block_static=block_static, + bias=bias, + n_group=n_group, + topk_group=topk_group, + scale=scale, + ) + +@impl("_C::moe_sigmoid_group_topk_norm", "CUDA") +def moe_sigmoid_group_topk_norm_cuda( + x: torch.Tensor, + topk_index: torch.Tensor, + norm_score: torch.Tensor, + block_static: torch.Tensor, + bias: torch.Tensor, + scale: float, + n_group: int, + topk_group: int +) -> None: + xtorch_ops.moe_sigmoid_group_topk_norm( + x=x, + norm_score=norm_score, + topk_index=topk_index, + block_static=block_static, + bias=bias, + n_group=n_group, + topk_group=topk_group, + scale=scale, + ) + +def _fake_moe_sigmoid_group_topk_norm( + x: torch.Tensor, + topk_index: torch.Tensor, + norm_score: torch.Tensor, + block_static: torch.Tensor, + bias: torch.Tensor, + scale: float, + n_group: int, + topk_group: int +) -> None: + return None + +moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm) ################################################## # --------------- awq_dequantize ----------------- ##################################################
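# -----------------------------------------------------------------------------
# Illustration only: the registration pattern used throughout this file.  Each
# XPU kernel is exposed three ways -- an eager wrapper, a "CUDA" dispatch-key
# impl, and a fake (meta) impl that does nothing -- so torch.compile can trace
# the op without running the Kunlun kernel.  This sketch uses torch.library
# directly with a made-up op name; the real file goes through its own
# custom_op / impl helpers, and the eager path calls into xtorch_ops.
# -----------------------------------------------------------------------------
import torch

@torch.library.custom_op("_illustration::scale_inplace", mutates_args=("out",))
def scale_inplace(x: torch.Tensor, out: torch.Tensor, factor: float) -> None:
    # Eager path: stands in for the xtorch_ops call.
    out.copy_(x * factor)

@scale_inplace.register_fake
def _(x: torch.Tensor, out: torch.Tensor, factor: float) -> None:
    # Fake path: no computation, only shape/aliasing bookkeeping for tracing.
    return None

x = torch.randn(4)
out = torch.empty_like(x)
torch.ops._illustration.scale_inplace(x, out, 2.0)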