[Feature] Support XiaoMi MIMO Flash V2 (#62)

* [Feature] Support MIMO Flash V2
This commit is contained in:
Xinyu Dong
2025-12-31 10:16:33 +08:00
committed by GitHub
parent 341dc7f296
commit b3c30a3cb9
12 changed files with 1530 additions and 690 deletions

View File

@@ -82,6 +82,9 @@ def register_model():
"LlamaForCausalLM",
"vllm_kunlun.models.llama:LlamaForCausalLM")
ModelRegistry.register_model(
"MiMoV2FlashForCausalLM",
"vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM")
def register_quant_method():
"""to do"""

View File

@@ -0,0 +1,706 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from itertools import islice
import torch
from torch import nn
from vllm.attention.backends.abstract import AttentionType
from vllm_kunlun.ops.attention.layer import Attention
from vllm.config import (
CacheConfig,
VllmConfig,
get_current_vllm_config,
)
from vllm.distributed import (
get_ep_group,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm_kunlun.ops.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
from vllm_kunlun.ops.activation import SiluAndMul
logger = init_logger(__name__)
class MiMoV2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class MiMoV2MoE(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
is_nextn: bool = False,
):
super().__init__()
config = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.n_routed_experts
if self.tp_size > config.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.n_routed_experts}."
)
if config.hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now."
)
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
self.physical_expert_end = (
self.physical_expert_start + self.n_local_physical_experts
)
self.gate_dtype = torch.float32
self.gate = nn.Linear(
config.hidden_size,
config.n_routed_experts,
bias=False,
dtype=self.gate_dtype,
)
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts, dtype=self.gate_dtype)
)
self.experts = FusedMoE(
num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=True,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
scoring_func="sigmoid",
)
self.register_buffer("kunlun_linear_weights", torch.zeros(
config.num_local_experts,config.hidden_size,dtype=torch.float))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert hidden_states.dim() <= 2, "MiMoV2MoE only supports 1D or 2D inputs"
is_input_1d = hidden_states.dim() == 1
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.gate_dtype is not None:
gate_input = hidden_states.to(self.gate_dtype)
else:
gate_input = hidden_states
router_logits = self.gate(gate_input)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits, linear_weights=self.gate.weight
)
return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
class MiMoV2Attention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
v_head_dim: int | None = None,
sliding_window_size: int = -1,
attention_bias: bool = False,
add_swa_attention_sink_bias: bool = False,
layer_id: int = 0,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
partial_rotary_factor: float = 1.0,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.layer_id = layer_id
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim
self.v_head_dim = v_head_dim if v_head_dim is not None else head_dim
self.q_size = self.num_heads * self.head_dim
self.k_size = self.num_kv_heads * self.head_dim
self.v_size = self.num_kv_heads * self.v_head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
v_head_size=self.v_head_dim,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.v_head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=True,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=self.rope_theta,
partial_rotary_factor=partial_rotary_factor
)
self.attention_sink_bias = (
torch.nn.Parameter(torch.empty(self.num_heads), requires_grad=False)
if add_swa_attention_sink_bias
else None
)
sliding_window = sliding_window_size if sliding_window_size > -1 else None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
attn_type=AttentionType.DECODER,
prefix=f"{prefix}.attn",
sinks=self.attention_sink_bias,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
v = v.view(-1, self.num_kv_heads, self.v_head_dim)
v = torch.nn.functional.pad(v, [0, self.head_dim - self.v_head_dim], value=0)
v = v.view(-1, self.num_kv_heads * self.head_dim)
attn_output = self.attn(q, k, v)
attn_output = attn_output.view(-1, self.num_heads, self.head_dim)[
..., : self.v_head_dim
].reshape(-1, self.num_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output
class MiMoV2FlashDecoderLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_text_config
quant_config = vllm_config.quant_config
layer_id = extract_layer_index(prefix)
self.hidden_size = config.hidden_size
self.config = config
self.layer_id = layer_id
rope_theta = getattr(config, "rope_theta", 1000000)
max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
if self.is_compressed_softmax_layer():
self.self_attn = MiMoV2Attention(
hidden_size=self.hidden_size,
num_heads=config.swa_num_attention_heads,
num_kv_heads=config.swa_num_key_value_heads,
head_dim=config.swa_head_dim,
v_head_dim=getattr(config, "swa_v_head_dim", None),
sliding_window_size=config.sliding_window_size,
attention_bias=config.attention_bias,
add_swa_attention_sink_bias=getattr(
config, "add_swa_attention_sink_bias", False
),
layer_id=layer_id,
rope_theta=getattr(config, "swa_rope_theta", rope_theta),
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0),
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = MiMoV2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
v_head_dim=getattr(config, "v_head_dim", None),
sliding_window_size=-1, # normal attention
attention_bias=config.attention_bias,
layer_id=layer_id,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0),
prefix=f"{prefix}.self_attn",
)
self.is_layer_sparse = self.is_moe_layer(layer_id)
if self.is_layer_sparse:
self.mlp = MiMoV2MoE(
vllm_config=vllm_config,
prefix=f"{prefix}.mlp",
)
else:
self.mlp = MiMoV2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.layernorm_epsilon
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
def is_moe_layer(self, layer_idx: int) -> bool:
return (
hasattr(self.config, "moe_layer_freq")
and layer_idx >= 0
and not isinstance(self.config.moe_layer_freq, int)
and self.config.moe_layer_freq[layer_idx]
)
def is_compressed_softmax_layer(self) -> bool:
return self.config.hybrid_layer_pattern[self.layer_id] == 1
class MiMoV2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config.get_text_config()
quant_config = vllm_config.quant_config
eplb_config = vllm_config.parallel_config.eplb_config
self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
self.num_redundant_experts = eplb_config.num_redundant_experts
self.v_scale = getattr(config, "attention_value_scale", None)
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiMoV2FlashDecoderLayer(
vllm_config=vllm_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
else:
self.norm = PPMissingLayer()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_redundant_experts=self.num_redundant_experts,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
continue
if "mtp" in name:
continue
if self.quant_config is not None:
cache_scale_name = self.quant_config.get_cache_scale(name)
if cache_scale_name is not None and cache_scale_name in params_dict:
param = params_dict[cache_scale_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
kv_scale = loaded_weight
if kv_scale.dim() > 0 and kv_scale.numel() > 1:
kv_scale = kv_scale.view(-1)[0]
weight_loader(param, kv_scale)
loaded_params.add(cache_scale_name)
continue
expert_matched = False
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
if weight_name not in name:
continue
name_rewritten = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_rewritten, self):
continue
if (
name_rewritten.endswith(".bias") or name_rewritten.endswith("_bias")
) and name_rewritten not in params_dict:
continue
if name_rewritten not in params_dict:
continue
param = params_dict[name_rewritten]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name_rewritten,
shard_id=shard_id,
expert_id=expert_id,
)
loaded_params.add(name_rewritten)
expert_matched = True
break
if expert_matched:
continue
stacked_matched = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name_rewritten = name.replace(weight_name, param_name)
if (
name_rewritten.endswith(".bias")
and name_rewritten not in params_dict
):
continue
if is_pp_missing_parameter(name_rewritten, self):
continue
if name_rewritten not in params_dict:
continue
param = params_dict[name_rewritten]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if param_name == "qkv_proj" and shard_id == "v":
v_scale = (
self.v_scale
if self.v_scale is not None
else getattr(self.config, "attention_value_scale", None)
)
if v_scale is not None and (
name.endswith("weight_scale_inv") or name.endswith(".bias")
):
loaded_weight *= float(v_scale)
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name_rewritten)
stacked_matched = True
break
if stacked_matched:
continue
if name.endswith(".bias") and name not in params_dict:
continue
orig_name = name
mapped_name = maybe_remap_kv_scale_name(name, params_dict)
name = mapped_name if mapped_name is not None else orig_name
if name not in params_dict:
continue
param = params_dict[name]
if "attention_sink_bias" in name:
total_heads = loaded_weight.shape[0]
heads_per_rank = total_heads // tp_size
head_start = tp_rank * heads_per_rank
narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
else:
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = MiMoV2Model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)

View File

@@ -15,8 +15,8 @@
# This file is a part of the vllm-ascend project.
#
# import vllm_kunlun.ops.linear
import vllm_kunlun.ops.rotary_embedding
import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.gptq
import vllm_kunlun.ops.quantization.gptq
import vllm_kunlun.ops.vocab_parallel_embedding

View File

@@ -1,3 +1,20 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun custom op entry"""
import torch_xmlir
import torch
@@ -177,51 +194,10 @@ class KunlunOps:
"""
query_x = query.contiguous()
key_x = key.contiguous()
query_x_dim = query_x.dim()
if not is_neox_style:
if cos_sin_cache.dtype == torch.float16:
cos_sin_cache = cos_sin_cache.to(torch.float32)
positions = positions.to(torch.int)
if positions.dim() == 1:
positions = positions.unsqueeze(0)
query_x = query_x.unsqueeze(0)
key_x = key_x.unsqueeze(0)
xtorch_ops.rotary_embedding_gptj(
positions,
query_x,
key_x,
head_size,
cos_sin_cache)
query.data = query_x
key.data = key_x
if query_x_dim != query_x.dim():
query_x = query_x.unsqueeze(0)
key_x = key_x.unsqueeze(0)
return query, key
# TODO: need opt
if cos_sin_cache.dim() == 4:
max_seq_len = cos_sin_cache.shape[2]
head_dim = cos_sin_cache.shape[3]
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0) # 移除前两个维度 [1,1,L,D] -> [L,D]
cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)
# 重塑 query 和 key 的形状
num_tokens = query_x.shape[0]
num_heads = query_x.shape[1] // head_size
num_kv_heads = key_x.shape[1] // head_size
# # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size]
# query_x = query_x.view(num_tokens, num_heads, head_size)
# # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size]
# key_x = key_x.view(num_tokens, num_kv_heads, head_size)
# # 确保形状正确
# assert query_x.shape == (num_tokens, num_heads, head_size), \
# f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}"
# assert key_x.shape == (num_tokens, num_kv_heads, head_size), \
# f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}"
torch.ops._C.rotary_embedding(
positions,
@@ -234,8 +210,6 @@ class KunlunOps:
query_x = query_x.view(num_tokens, num_heads * head_size)
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
# query.data = query_x
# key.data = key_x
return query_x, key_x
# Rotary embedding
@@ -433,6 +407,121 @@ class KunlunOps:
return out
def _dbg(x):
if torch.is_tensor(x):
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
return (type(x), x)
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
linear_weights: torch.Tensor,
moe_top_k: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""fused_moe"""
global_num_experts = linear_weights.shape[0]
M, N = hidden_states.shape
hidden_dim = w2.shape[1]
normed_score = torch.empty(M,
moe_top_k,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
moe_top_k,
dtype=torch.int32,
device=hidden_states.device)
num_blocks = 12
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
)
torch.ops._C.moe_sigmoid_group_topk_norm(
x=router_logits,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=1,
)
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
y = torch.empty(M,moe_top_k,
w1.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y)
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.swiglu(y, out1)
out = torch.empty(M,moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
out1 = out1.reshape(-1, out1.shape[-1])
torch.ops._C.moe_fc(
x=out1,
weight=w2,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=out)
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
torch.ops._C.moe_post(
x=out,
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
)
return output
@staticmethod
def fused_moe_ep(
hidden_states: torch.Tensor,
@@ -487,42 +576,6 @@ class KunlunOps:
return output
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
linear_weights: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""fused_moe"""
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
device=hidden_states.device)
expert_num = linear_weights.shape[0]
torch.ops._C.moe_ffn_block(
x=hidden_states,
gate_w=linear_weights,
inter_w=w1,
output_w=w2,
expert_num=expert_num,
moe_top_k=topk,
topk_group=topk_group,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
expert_group_num=num_expert_group,
out=output,
)
return output
@staticmethod
def fused_multi_head_latent_page_attention(
hidden_states: torch.Tensor,

View File

@@ -68,7 +68,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
linear_weights=linear_weights)
linear_weights=linear_weights,
e_score_correction_bias=e_score_correction_bias)
def forward_kunlun(
self,
@@ -81,7 +82,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""forward_kunlun"""
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
@@ -99,96 +102,6 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
num_expert_group=num_expert_group,
topk_group=topk_group
)
# fused_moe do not support expert number > 400
elif layer.local_num_experts > 400:
hidden_states = x
global_num_experts = linear_weights.shape[0]
M, N = hidden_states.shape
hidden_dim = layer.w2_weight.shape[1]
normed_score = torch.empty(M,
top_k,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
top_k,
dtype=torch.int32,
device=hidden_states.device)
num_blocks = 12
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
)
router_logits = router_logits.float()
torch.ops._C.moe_softmax_topk_norm(
x=router_logits,
normed_score=normed_score,
topk_index=topk_ids,
block_statistic=None,
stable=True)
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
y = torch.empty(M,top_k,
layer.w13_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
moe_expand = moe_expand.view(M * top_k, hidden_dim)
torch.ops._C.moe_fc(
x=moe_expand,
weight=layer.w13_weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=y)
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.swiglu(y, out1)
out = torch.empty(M,top_k,
layer.w2_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
out1 = out1.reshape(-1, out1.shape[-1])
torch.ops._C.moe_fc(
x=out1,
weight=layer.w2_weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=out)
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
torch.ops._C.moe_post(
x=out,
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
)
return output
else:
return ops.fused_moe(x,
layer.w13_weight,
@@ -200,7 +113,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
class FusedMoE(VllmFusedMoE):

View File

@@ -57,6 +57,8 @@ def vllm_kunlun_forward_cuda(
)
return out
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
@staticmethod

View File

@@ -3,14 +3,30 @@
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.linear import (
WEIGHT_LOADER_V2_SUPPORTED,
ReplicatedLinear,
UnquantizedLinearMethod,
ColumnParallelLinear
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.parameter import (
BasevLLMParameter,
BlockQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter,
)
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -59,4 +75,361 @@ def create_weights(
# rewrite create_weights and remove weight_loader_v2 to suport cuda graph
UnquantizedLinearMethod.create_weights = create_weights
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
class QKVParallelLinear(ColumnParallelLinear):
"""
Base on v0.11.0 QKVParallelLinear, And add v_head size for swa (MIMO V2)
"""
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: int | None = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
v_head_size: int | None = None,
):
self.hidden_size = hidden_size
self.head_size = head_size
self.v_head_size = v_head_size if v_head_size is not None else head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1
input_size = self.hidden_size
output_size = (
self.num_heads * self.head_size
+ self.num_kv_heads * self.head_size
+ self.num_kv_heads * self.v_head_size
) * tp_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.v_head_size * tp_size, # v_proj
]
super().__init__(
input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp,
)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + self.num_kv_heads) * self.head_size
+ self.num_kv_heads * self.v_head_size,
}
return shard_offset_mapping.get(loaded_shard_id)
def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.v_head_size,
}
return shard_size_mapping.get(loaded_shard_id)
def _load_fused_module_from_checkpoint(
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
):
"""
Handle special case for models where QKV layers are already
fused on disk. In this case, we have no shard id. This function
determines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
(
"k",
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
]
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if (
isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
and param.packed_dim == param.output_dim
):
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset
)
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(
self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: str | None = None,
):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight(
loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id in ["q", "k", "v"]
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
# Note(simon): This is needed for Qwen3's fp8 quantization.
if isinstance(param, BlockQuantScaleParameter):
assert self.quant_method is not None
# Assume the weight block size has been set by quant method
assert hasattr(self, "weight_block_size")
weight_block_size = self.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n
param.load_qkv_weight(
loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
)
def weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: str | None = None,
):
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
idx_map = {"q": 0, "k": 1, "v": 2}
if loaded_shard_id is not None:
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
else:
param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
return
if is_gguf_weight:
output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // self.tp_size
start_idx = self.tp_rank * shard_size
if loaded_shard_id is not None:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
return
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv).
# (e.g., Phi-3's qkv_proj).
if output_dim is None:
if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
(
"k",
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset
)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.total_num_heads * self.head_size),
"k": (
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
"v": (
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
"total": (
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size
+ self.total_num_kv_heads * self.v_head_size,
0,
),
}
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, shard_id
)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
assert loaded_shard_id in ["q", "k", "v"]
# If output dim is defined, use the default loading process.
if output_dim is not None:
if loaded_shard_id == "q":
shard_offset = 0
shard_size = self.num_heads * self.head_size
elif loaded_shard_id == "k":
shard_offset = self.num_heads * self.head_size
shard_size = self.num_kv_heads * self.head_size
elif loaded_shard_id == "v":
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.v_head_size
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset
)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
"k": (
self.num_heads * self.head_size,
self.num_kv_heads * self.head_size,
),
"v": (
(self.num_heads + self.num_kv_heads) * self.head_size,
self.num_kv_heads * self.v_head_size,
),
"total": (
(self.num_heads + self.num_kv_heads) * self.head_size
+ self.num_kv_heads * self.v_head_size,
0,
),
}
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id
)
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
if loaded_shard_id == "q":
shard_rank = self.tp_rank
else:
shard_rank = self.tp_rank // self.num_kv_head_replicas
start_idx = shard_rank * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id
)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

View File

@@ -8,14 +8,8 @@ from typing import List, Optional, Tuple
from vllm.platforms import current_platform
if current_platform.is_kunlun():
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
else:
from vllm import _custom_ops as ops
from vllm.triton_utils.importing import HAS_TRITON
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512

View File

@@ -70,7 +70,7 @@ def vllm_kunlun_forward_cuda(
self.is_neox_style, self.rotary_dim,
offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
query, key = ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key
@@ -143,14 +143,11 @@ def vllm_kunlun_mrope_forward_cuda(
return query, key
# RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
# RotaryEmbedding.forward = vllm_kunlun_forward_cuda
# RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
# MRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
# YaRNScalingRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
def Split_Norm_Rope(

View File

@@ -1,143 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
DEFAULT_VOCAB_PADDING_SIZE = 64
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
"""Unquantized method for embeddings."""
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for embedding layer."""
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor:
return F.embedding(input_, layer.weight)
def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int,
rank: int,
offset: int = 0) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f + offset, index_l + offset
def vocab_range_from_global_vocab_size(global_vocab_size: int,
rank: int,
world_size: int,
offset: int = 0) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank,
offset=offset)
@dataclass
class VocabParallelEmbeddingShardIndices:
"""Indices for a shard of a vocab parallel embedding."""
padded_org_vocab_start_index: int
padded_org_vocab_end_index: int
padded_added_vocab_start_index: int
padded_added_vocab_end_index: int
org_vocab_start_index: int
org_vocab_end_index: int
added_vocab_start_index: int
added_vocab_end_index: int
@property
def num_org_elements(self) -> int:
return self.org_vocab_end_index - self.org_vocab_start_index
@property
def num_added_elements(self) -> int:
return self.added_vocab_end_index - self.added_vocab_start_index
@property
def num_org_elements_padded(self) -> int:
return (self.padded_org_vocab_end_index -
self.padded_org_vocab_start_index)
@property
def num_added_elements_padded(self) -> int:
return (self.padded_added_vocab_end_index -
self.padded_added_vocab_start_index)
@property
def num_org_vocab_padding(self) -> int:
return self.num_org_elements_padded - self.num_org_elements
@property
def num_added_vocab_padding(self) -> int:
return self.num_added_elements_padded - self.num_added_elements
@property
def num_elements_padded(self) -> int:
return self.num_org_elements_padded + self.num_added_elements_padded
def __post_init__(self):
# sanity checks
assert (self.padded_org_vocab_start_index
<= self.padded_org_vocab_end_index)
assert (self.padded_added_vocab_start_index
<= self.padded_added_vocab_end_index)
assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert (self.added_vocab_start_index
<= self.padded_added_vocab_start_index)
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
assert self.num_org_elements <= self.num_org_elements_padded
assert self.num_added_elements <= self.num_added_elements_padded
@torch.compile(dynamic=True, backend="aot_eager")
def get_masked_input_and_mask(
@@ -159,319 +27,25 @@ def get_masked_input_and_mask(
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
def forward_native_kunlun(self, input_):
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self,
masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
@CustomOp.register("vllm_kunlun_vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp):
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
quant_config: quant config for the layer
prefix: full name of the layer in the state dict
""" # noqa: E501
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# Keep the input dimensions.
tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim
quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the vocaburaly dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.quant_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
@classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
vocab_size: int, org_vocab_size: int, tp_rank: int,
tp_size: int) -> VocabParallelEmbeddingShardIndices:
"""Get start and end indices for vocab parallel embedding, following the
layout outlined in the class docstring, based on the given tp_rank and
tp_size."""
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
padded_org_vocab_start_index, padded_org_vocab_end_index = (
vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
tp_size))
padded_added_vocab_start_index, padded_added_vocab_end_index = (
vocab_range_from_global_vocab_size(num_added_embeddings_padded,
tp_rank,
tp_size,
offset=org_vocab_size))
# remove padding
org_vocab_start_index = min(padded_org_vocab_start_index,
org_vocab_size)
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
added_vocab_start_index = min(padded_added_vocab_start_index,
vocab_size)
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
return VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index, padded_org_vocab_end_index,
padded_added_vocab_start_index, padded_added_vocab_end_index,
org_vocab_start_index, org_vocab_end_index,
added_vocab_start_index, added_vocab_end_index)
def get_sharded_to_full_mapping(self) -> Optional[list[int]]:
"""Get a mapping that can be used to reindex the gathered
logits for sampling.
During sampling, we gather logits from all ranks. The relationship
of index->token_id will follow the same format as outlined in the class
docstring. However, after the gather, we want to reindex the final
logits tensor to map index->token_id one-to-one (the index is always
equal the token_id it corresponds to). The indices returned by this
method allow us to do that.
"""
if self.tp_size < 2:
return None
base_embeddings: list[int] = []
added_embeddings: list[int] = []
padding: list[int] = []
for tp_rank in range(self.tp_size):
shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
range_start = self.num_embeddings_per_partition * tp_rank
range_end = self.num_embeddings_per_partition * (tp_rank + 1)
base_embeddings.extend(
range(range_start,
range_start + shard_indices.num_org_elements))
padding.extend(
range(range_start + shard_indices.num_org_elements,
range_start + shard_indices.num_org_elements_padded))
added_embeddings.extend(
range(
range_start + shard_indices.num_org_elements_padded,
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements))
padding.extend(
range(
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements,
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements_padded))
assert (range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements_padded == range_end)
ret = base_embeddings + added_embeddings + padding
assert len(ret) == self.num_embeddings_padded
return ret
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)
# If the parameter is a gguf weight, then load it directly.
if getattr(param, "is_gguf_weight_type", None):
param.data.copy_(loaded_weight)
param.weight_type = loaded_weight.item()
return
elif isinstance(param, UninitializedParameter):
shape = list(loaded_weight.shape)
if output_dim is not None:
shape[output_dim] = self.num_embeddings_per_partition
param.materialize(tuple(shape), dtype=loaded_weight.dtype)
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if output_dim is None:
assert param.data.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
return
# Shard indexes for loading the weight
start_idx = self.shard_indices.org_vocab_start_index
shard_size = self.shard_indices.org_vocab_end_index - start_idx
# If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor.
if packed_dim is not None and packed_dim == output_dim:
packed_factor = param.packed_factor if isinstance(
param, BasevLLMParameter) else param.pack_factor
assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
param.packed_factor)
start_idx = start_idx // packed_factor
shard_size = shard_size // packed_factor
else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size
# Copy the data. Select chunk corresponding to current shard.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0]:].data.fill_(0)
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self,
masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
def extra_repr(self) -> str:
s = f"num_embeddings={self.num_embeddings_per_partition}"
s += f", embedding_dim={self.embedding_dim}"
s += f", org_vocab_size={self.org_vocab_size}"
s += f', num_embeddings_padded={self.num_embeddings_padded}'
s += f', tp_size={self.tp_size}'
return s
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
Output logits weight matrices used in the Sampler. The weight and bias
tensors are padded to make sure they are divisible by the number of
model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
bias: whether to use bias.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config,
prefix)
self.quant_config = quant_config
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
"""Tie the weights with word embeddings."""
# GGUF quantized embed_tokens.
if self.quant_config and self.quant_config.get_name() == "gguf":
return embed_tokens
else:
self.weight = embed_tokens.weight
return self
def forward(self, input_):
del input_
raise RuntimeError("LMHead's weights should be used in the sampler.")
VocabParallelEmbedding.forward_native = forward_native_kunlun

View File

@@ -148,7 +148,6 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# Prefix cache loc
kv_lod_cpu: Optional[torch.Tensor] = None
kv_lod_xpu: Optional[torch.Tensor] = None
@@ -563,9 +562,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if blocksparse_params is not None:
raise ValueError(
"kunlunAttention does not support block-sparse attention.")
# if logits_soft_cap is not None:
# raise ValueError(
# "kunlunAttention does not support attention logits soft capping.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -673,51 +669,84 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
# For hybrid Attention (Qwen3-Next.)
if key_cache.is_contiguous():
tmp_block_tables = prefill_meta.block_tables
else:
tmp_block_tables = prefill_meta.block_tables * 2 # only test in Qwen3-Next
xtorch_ops.prefill_attention(
q=prefill_query,
k=key_cache, # Key Cache (block_num, head, block_size, dim)
v=value_cache,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
is_causal=True,
is_prefix_cache=True,
block_table=tmp_block_tables,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc,
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
alibi_slopes=self.alibi_slopes,
softmax_lse=None,
sink=self.sinks
)
# For hybrid Attention (Qwen3-Next)
tmp_block_tables = prefill_meta.block_tables * 2
# Prefix cache
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
xtorch_ops.prefill_attention(
q=prefill_query,
k=key_cache, # Key Cache [block_num, head, block_size, dim]
v=value_cache,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
is_causal=True,
is_prefix_cache=True,
block_table=tmp_block_tables,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc,
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
alibi_slopes=self.alibi_slopes,
softmax_lse=None
)
else:
xtorch_ops.prefill_attention(
q=prefill_query,
k=prefill_key,
v=prefill_value,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
is_causal=True,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc,
alibi_slopes=self.alibi_slopes,
softmax_lse=None,
swa_left = self.sliding_window if self.sliding_window is not None else -1,
swa_right = 0 if self.sliding_window is not None else -1,
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
)
if decode_meta := attn_metadata.decode_metadata:
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
decode_query = query[:num_decode_tokens]
# For hybrid Attention (Qwen3-Next
if key_cache.is_contiguous():
tmp_block_tables = decode_meta.block_tables
else:
tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next
xtorch_ops.paged_attention(
x=decode_query,
k_cache=key_cache,
v_cache=value_cache,
block_tables=tmp_block_tables,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
is_context=False,
is_causal=True,
xtorch_ops.speculative_attention(
out=output[:num_decode_tokens],
vo_head_dim=self.head_size
)
# Only MLA support q len > 1 right now
q=decode_query.unsqueeze(0),
k_cache=key_cache,
v_cache=value_cache,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
batch_num=decode_meta.block_tables.shape[0],
# TODO (@xyDong23): Support MTP(q lens >1)
qlen=1,
# TODO (@xyDong23): Support max_context_len to (262144)
max_context_len=131072,
head_num=self.num_heads,
head_dim=self.head_size,
scale=0.0,
kv_head_num=self.num_kv_heads,
block_size=key_cache.shape[2],
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
max_window_size=self.sliding_window if self.sliding_window is not None else -1,
block_tables=tmp_block_tables,
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def use_cascade_attention(
@@ -788,4 +817,4 @@ def use_cascade_attention(
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
# Use cascade attention if it is faster than FlashDecoding.
return cascade_time < flash_decoding_time
return cascade_time < flash_decoding_time

View File

@@ -938,6 +938,83 @@ def _fake_rotary_embedding(
rotary_embedding.register_fake(_fake_rotary_embedding)
@custom_op("_C::quant2d", mutates_args=())
def quant2d(
x: torch.Tensor,
y: torch.Tensor,
max: torch.Tensor,
force_sdnn: bool,
) -> None:
xtorch_ops.quant2d(
x=x,
y=y,
max=max,
force_sdnn=force_sdnn
)
@impl("_C::quant2d", "CUDA")
def quant2d_cuda(
x: torch.Tensor,
y: torch.Tensor,
max: torch.Tensor,
force_sdnn: bool,
) -> None:
xtorch_ops.quant2d(
x=x,
y=y,
max=max,
force_sdnn=force_sdnn
)
def _fake_quant2d(
x: torch.Tensor,
y: torch.Tensor,
max: torch.Tensor,
force_sdnn: bool,
) -> None:
return None
quant2d.register_fake(_fake_quant2d)
@custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=())
def gemm_I8_I8_bf16_nt(
x_q: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
xtorch_ops.gemm_I8_I8_bf16_nt(
lhs=(x_q, x_scale),
rhs=(weight, weight_scale),
out=out
)
@impl("_C::gemm_I8_I8_bf16_nt", "CUDA")
def gemm_I8_I8_bf16_nt_cuda(
x_q: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
xtorch_ops.gemm_I8_I8_bf16_nt(
lhs=(x_q, x_scale),
rhs=(weight, weight_scale),
out=out
)
def _fake_gemm_I8_I8_bf16_nt(
x_q: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
return None
gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt)
@custom_op("_C::moe_softmax_topk_norm", mutates_args=())
def moe_softmax_topk_norm(
x: torch.Tensor,
@@ -1068,15 +1145,39 @@ def moe_fc(
sorted_tokens_num_lod: torch.Tensor,
sorted_tokens_idx: torch.Tensor,
moe_topk: int,
y: torch.Tensor
y: torch.Tensor,
act: Optional[str] = None,
x_perchannel_max: Optional[torch.Tensor] = None,
w_perchannel_max: Optional[torch.Tensor] = None ,
topk_ids: Optional[torch.Tensor] = None,
topk_w: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
tgemm_type: Optional[str] = None,
tweight_type: Optional[str] = None,
scale_n: Optional[int] = 0,
scale_k: Optional[int] = 0,
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True
)-> None:
xtorch_ops.moe_fc(
x,
weight,
sorted_tokens_num_lod,
sorted_tokens_idx,
moe_topk,
y)
x=x,
weight=weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_topk,
y=y,
act=act,
x_perchannel_max=x_perchannel_max,
w_perchannel_max=w_perchannel_max,
topk_ids=topk_ids,
topk_w=topk_w,
bias=bias,
tgemm_type=tgemm_type,
tweight_type=tweight_type,
scale_n=scale_n,
scale_k=scale_k,
use_pack_int4=use_pack_int4,
sort_mode=sort_mode)
@impl("_C::moe_fc", "CUDA")
def moe_fc_cuda(
@@ -1085,15 +1186,39 @@ def moe_fc_cuda(
sorted_tokens_num_lod: torch.Tensor,
sorted_tokens_idx: torch.Tensor,
moe_topk: int,
y: torch.Tensor
y: torch.Tensor,
act: Optional[str] = None,
x_perchannel_max: Optional[torch.Tensor] = None,
w_perchannel_max: Optional[torch.Tensor] = None ,
topk_ids: Optional[torch.Tensor] = None,
topk_w: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
tgemm_type: Optional[str] = None,
tweight_type: Optional[str] = None,
scale_n: Optional[int] = 0,
scale_k: Optional[int] = 0,
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True
)-> None:
xtorch_ops.moe_fc(
x,
weight,
sorted_tokens_num_lod,
sorted_tokens_idx,
moe_topk,
y)
x=x,
weight=weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_topk,
y=y,
act=act,
x_perchannel_max=x_perchannel_max,
w_perchannel_max=w_perchannel_max,
topk_ids=topk_ids,
topk_w=topk_w,
bias=bias,
tgemm_type=tgemm_type,
tweight_type=tweight_type,
scale_n=scale_n,
scale_k=scale_k,
use_pack_int4=use_pack_int4,
sort_mode=sort_mode)
def fake_moe_fc(
x: torch.Tensor,
@@ -1101,7 +1226,19 @@ def fake_moe_fc(
sorted_tokens_num_lod: torch.Tensor,
sorted_tokens_idx: torch.Tensor,
moe_topk: int,
y: torch.Tensor
y: torch.Tensor,
act: Optional[str] = None,
x_perchannel_max: Optional[torch.Tensor] = None,
w_perchannel_max: Optional[torch.Tensor] = None ,
topk_ids: Optional[torch.Tensor] = None,
topk_w: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
tgemm_type: Optional[str] = None,
tweight_type: Optional[str] = None,
scale_n: Optional[int] = 0,
scale_k: Optional[int] = 0,
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True
)-> None:
return None
@@ -1151,6 +1288,63 @@ def fake_moe_post(
moe_post.register_fake(fake_moe_post)
@custom_op("_C::moe_sigmoid_group_topk_norm", mutates_args=())
def moe_sigmoid_group_topk_norm(
x: torch.Tensor,
topk_index: torch.Tensor,
norm_score: torch.Tensor,
block_static: torch.Tensor,
bias: torch.Tensor,
scale: float,
n_group: int,
topk_group: int
) -> None:
xtorch_ops.moe_sigmoid_group_topk_norm(
x=x,
norm_score=norm_score,
topk_index=topk_index,
block_static=block_static,
bias=bias,
n_group=n_group,
topk_group=topk_group,
scale=scale,
)
@impl("_C::moe_sigmoid_group_topk_norm", "CUDA")
def moe_sigmoid_group_topk_norm_cuda(
x: torch.Tensor,
topk_index: torch.Tensor,
norm_score: torch.Tensor,
block_static: torch.Tensor,
bias: torch.Tensor,
scale: float,
n_group: int,
topk_group: int
) -> None:
xtorch_ops.moe_sigmoid_group_topk_norm(
x=x,
norm_score=norm_score,
topk_index=topk_index,
block_static=block_static,
bias=bias,
n_group=n_group,
topk_group=topk_group,
scale=scale,
)
def _fake_moe_sigmoid_group_topk_norm(
x: torch.Tensor,
topk_index: torch.Tensor,
norm_score: torch.Tensor,
block_static: torch.Tensor,
bias: torch.Tensor,
scale: float,
n_group: int,
topk_group: int
) -> None:
return None
moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm)
##################################################
# --------------- awq_dequantize -----------------
##################################################