[Feature] Support XiaoMi MIMO Flash V2 (#62)
* [Feature] Support MIMO Flash V2
This commit is contained in:
@@ -82,6 +82,9 @@ def register_model():
|
||||
"LlamaForCausalLM",
|
||||
"vllm_kunlun.models.llama:LlamaForCausalLM")
|
||||
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"MiMoV2FlashForCausalLM",
|
||||
"vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM")
|
||||
|
||||
def register_quant_method():
    """TODO: register custom quantization methods (not yet implemented)."""
|
||||
|
||||
706
vllm_kunlun/models/mimo_v2_flash.py
Normal file
706
vllm_kunlun/models/mimo_v2_flash.py
Normal file
@@ -0,0 +1,706 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm_kunlun.ops.attention.layer import Attention
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm_kunlun.ops.linear import QKVParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsPP
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm_kunlun.ops.activation import SiluAndMul
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MiMoV2MLP(nn.Module):
    """Dense feed-forward block: merged gate/up projection, SiLU-and-mul
    activation, then down projection back to the hidden size.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections fused into one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        # Only SiLU is supported by the fused activation kernel.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up -> SiLU-and-mul -> down projection."""
        projected, _ = self.gate_up_proj(x)
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out
||||
|
||||
|
||||
class MiMoV2MoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
is_nextn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
self.ep_rank = get_ep_group().rank_in_group
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
|
||||
|
||||
if self.tp_size > config.n_routed_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.n_routed_experts}."
|
||||
)
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
eplb_config = vllm_config.parallel_config.eplb_config
|
||||
self.enable_eplb = parallel_config.enable_eplb
|
||||
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
|
||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||
|
||||
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
|
||||
self.physical_expert_end = (
|
||||
self.physical_expert_start + self.n_local_physical_experts
|
||||
)
|
||||
|
||||
self.gate_dtype = torch.float32
|
||||
self.gate = nn.Linear(
|
||||
config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
dtype=self.gate_dtype,
|
||||
)
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.n_routed_experts, dtype=self.gate_dtype)
|
||||
)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=self.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=True,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
scoring_func="sigmoid",
|
||||
)
|
||||
self.register_buffer("kunlun_linear_weights", torch.zeros(
|
||||
config.num_local_experts,config.hidden_size,dtype=torch.float))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
assert hidden_states.dim() <= 2, "MiMoV2MoE only supports 1D or 2D inputs"
|
||||
is_input_1d = hidden_states.dim() == 1
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if self.gate_dtype is not None:
|
||||
gate_input = hidden_states.to(self.gate_dtype)
|
||||
else:
|
||||
gate_input = hidden_states
|
||||
router_logits = self.gate(gate_input)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits, linear_weights=self.gate.weight
|
||||
)
|
||||
|
||||
return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
|
||||
|
||||
|
||||
class MiMoV2Attention(nn.Module):
    """GQA attention with optional sliding window and an optional learned
    per-head "attention sink" bias (used by the compressed-softmax SWA layers).

    The value head may be narrower than the query/key head
    (``v_head_dim < head_dim``); V is zero-padded up to ``head_dim`` before
    the attention kernel and the padding is sliced off the output.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        v_head_dim: int | None = None,
        sliding_window_size: int = -1,
        attention_bias: bool = False,
        add_swa_attention_sink_bias: bool = False,
        layer_id: int = 0,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        partial_rotary_factor: float = 1.0,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_id = layer_id
        tp_size = get_tensor_model_parallel_world_size()

        # Heads are sharded across tensor-parallel ranks.
        self.total_num_heads = num_heads
        self.num_heads = self.total_num_heads // tp_size

        # KV heads may be fewer than TP ranks; each rank keeps at least one.
        self.total_num_kv_heads = num_kv_heads
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = head_dim

        # Value head defaults to the query/key head size.
        self.v_head_dim = v_head_dim if v_head_dim is not None else head_dim

        # Per-rank split sizes of the fused QKV projection output.
        self.q_size = self.num_heads * self.head_dim
        self.k_size = self.num_kv_heads * self.head_dim
        self.v_size = self.num_kv_heads * self.v_head_dim

        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            v_head_size=self.v_head_dim,
        )

        # Output projection operates on the (unpadded) value width.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.v_head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=True,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            partial_rotary_factor=partial_rotary_factor
        )

        # Non-trainable per-head sink bias, populated in load_weights.
        self.attention_sink_bias = (
            torch.nn.Parameter(torch.empty(self.num_heads), requires_grad=False)
            if add_swa_attention_sink_bias
            else None
        )

        # -1 means "no sliding window" (full attention).
        sliding_window = sliding_window_size if sliding_window_size > -1 else None
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.attention_sink_bias,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Fused QKV -> RoPE -> (padded-V) attention -> output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)

        # Zero-pad V from v_head_dim up to head_dim so the attention kernel
        # sees uniform head sizes.
        v = v.view(-1, self.num_kv_heads, self.v_head_dim)
        v = torch.nn.functional.pad(v, [0, self.head_dim - self.v_head_dim], value=0)
        v = v.view(-1, self.num_kv_heads * self.head_dim)

        attn_output = self.attn(q, k, v)

        # Drop the padded tail of each head before projecting back.
        attn_output = attn_output.view(-1, self.num_heads, self.head_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_heads * self.v_head_dim)

        output, _ = self.o_proj(attn_output)
        return output
|
||||
|
||||
|
||||
class MiMoV2FlashDecoderLayer(nn.Module):
    """One MiMo-V2 Flash decoder layer.

    Per-layer config (``hybrid_layer_pattern``) selects between compressed
    softmax sliding-window attention (swa_* hyperparameters) and full
    attention, and (``moe_layer_freq``) between a sparse MoE MLP and a dense
    MLP.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        # Layer index is recovered from the module prefix ("...layers.<i>...").
        layer_id = extract_layer_index(prefix)

        self.hidden_size = config.hidden_size
        self.config = config
        self.layer_id = layer_id

        rope_theta = getattr(config, "rope_theta", 1000000)
        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)

        # NOTE(review): vllm_config.cache_config is never forwarded to
        # MiMoV2Attention (cache_config stays at its default None) — confirm
        # the Attention layer picks up cache settings elsewhere.
        if self.is_compressed_softmax_layer():
            # Sliding-window attention branch with its own head geometry,
            # rope base, and optional attention-sink bias.
            self.self_attn = MiMoV2Attention(
                hidden_size=self.hidden_size,
                num_heads=config.swa_num_attention_heads,
                num_kv_heads=config.swa_num_key_value_heads,
                head_dim=config.swa_head_dim,
                v_head_dim=getattr(config, "swa_v_head_dim", None),
                sliding_window_size=config.sliding_window_size,
                attention_bias=config.attention_bias,
                add_swa_attention_sink_bias=getattr(
                    config, "add_swa_attention_sink_bias", False
                ),
                layer_id=layer_id,
                rope_theta=getattr(config, "swa_rope_theta", rope_theta),
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0),
                prefix=f"{prefix}.self_attn",
            )
        else:
            self.self_attn = MiMoV2Attention(
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                num_kv_heads=config.num_key_value_heads,
                head_dim=config.head_dim,
                v_head_dim=getattr(config, "v_head_dim", None),
                sliding_window_size=-1,  # normal attention
                attention_bias=config.attention_bias,
                layer_id=layer_id,
                rope_theta=rope_theta,
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0),
                prefix=f"{prefix}.self_attn",
            )

        # Sparse (MoE) vs dense MLP, per moe_layer_freq.
        self.is_layer_sparse = self.is_moe_layer(layer_id)
        if self.is_layer_sparse:
            self.mlp = MiMoV2MoE(
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = MiMoV2MLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.layernorm_epsilon
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pre-norm residual block: norm -> attention -> norm -> MLP.

        ``residual is None`` marks the first layer of this PP stage; the
        fused RMSNorm variant (norm + residual add) is used afterwards.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def is_moe_layer(self, layer_idx: int) -> bool:
        """True when ``moe_layer_freq`` is a per-layer sequence marking this layer sparse."""
        return (
            hasattr(self.config, "moe_layer_freq")
            and layer_idx >= 0
            and not isinstance(self.config.moe_layer_freq, int)
            and self.config.moe_layer_freq[layer_idx]
        )

    def is_compressed_softmax_layer(self) -> bool:
        """True when the hybrid pattern selects SWA/compressed softmax for this layer."""
        return self.config.hybrid_layer_pattern[self.layer_id] == 1
|
||||
|
||||
|
||||
class MiMoV2Model(nn.Module):
    """MiMo-V2 Flash backbone: token embeddings, pipeline-sharded decoder
    layers, and the final RMSNorm, plus checkpoint weight loading.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config.get_text_config()
        quant_config = vllm_config.quant_config
        eplb_config = vllm_config.parallel_config.eplb_config

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size
        self.num_redundant_experts = eplb_config.num_redundant_experts
        # Optional multiplier applied to attention value weights at load time.
        self.v_scale = getattr(config, "attention_value_scale", None)

        # Embeddings live on the first PP rank (and also on the last rank
        # when the LM head is tied to them).
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # make_layers assigns this stage's [start_layer, end_layer) slice.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MiMoV2FlashDecoderLayer(
                vllm_config=vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        # Final norm only on the last PP rank.
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this PP stage's layers; non-final ranks return intermediates."""
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # idx is currently unused; kept for symmetry with other models.
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Map checkpoint tensors onto this model's (fused / sharded) params.

        Resolution order per checkpoint name: skip rotary caches and MTP
        weights -> kv-cache scales -> per-expert FusedMoE params -> stacked
        qkv/gate_up params -> plain params (with attention-sink TP sharding).
        Returns the set of param names actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            # Rotary caches are recomputed; MTP (multi-token prediction)
            # weights are not used by this model.
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                continue
            if "mtp" in name:
                continue

            # KV-cache quantization scales: collapse to a scalar and load.
            if self.quant_config is not None:
                cache_scale_name = self.quant_config.get_cache_scale(name)
                if cache_scale_name is not None and cache_scale_name in params_dict:
                    param = params_dict[cache_scale_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )

                    kv_scale = loaded_weight
                    if kv_scale.dim() > 0 and kv_scale.numel() > 1:
                        kv_scale = kv_scale.view(-1)[0]

                    weight_loader(param, kv_scale)
                    loaded_params.add(cache_scale_name)
                    continue

            # Per-expert weights routed into the fused MoE parameter.
            expert_matched = False
            for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                if weight_name not in name:
                    continue

                name_rewritten = name.replace(weight_name, param_name)

                # Parameter belongs to another pipeline stage.
                if is_pp_missing_parameter(name_rewritten, self):
                    continue

                if (
                    name_rewritten.endswith(".bias") or name_rewritten.endswith("_bias")
                ) and name_rewritten not in params_dict:
                    continue

                if name_rewritten not in params_dict:
                    continue

                param = params_dict[name_rewritten]
                weight_loader = param.weight_loader

                weight_loader(
                    param,
                    loaded_weight,
                    name_rewritten,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                loaded_params.add(name_rewritten)
                expert_matched = True
                break

            if expert_matched:
                continue

            # Stacked params: q/k/v -> qkv_proj, gate/up -> gate_up_proj.
            stacked_matched = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name_rewritten = name.replace(weight_name, param_name)

                if (
                    name_rewritten.endswith(".bias")
                    and name_rewritten not in params_dict
                ):
                    continue

                if is_pp_missing_parameter(name_rewritten, self):
                    continue

                if name_rewritten not in params_dict:
                    continue

                param = params_dict[name_rewritten]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)

                # Optional value-projection rescaling (attention_value_scale).
                # NOTE(review): the scale is applied only to names ending in
                # "weight_scale_inv" or ".bias", not the raw weight — confirm
                # this matches the intended quantized-checkpoint layout.
                if param_name == "qkv_proj" and shard_id == "v":
                    v_scale = (
                        self.v_scale
                        if self.v_scale is not None
                        else getattr(self.config, "attention_value_scale", None)
                    )
                    if v_scale is not None and (
                        name.endswith("weight_scale_inv") or name.endswith(".bias")
                    ):
                        loaded_weight *= float(v_scale)

                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name_rewritten)

                stacked_matched = True
                break

            if stacked_matched:
                continue

            if name.endswith(".bias") and name not in params_dict:
                continue

            # Remap legacy kv-scale names where applicable.
            orig_name = name
            mapped_name = maybe_remap_kv_scale_name(name, params_dict)
            name = mapped_name if mapped_name is not None else orig_name

            if name not in params_dict:
                continue

            param = params_dict[name]

            if "attention_sink_bias" in name:
                # Sink bias is per-head: narrow to this TP rank's head slice.
                total_heads = loaded_weight.shape[0]
                heads_per_rank = total_heads // tp_size
                head_start = tp_rank * heads_per_rank
                narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)

                param.data.copy_(narrow_weight)
                loaded_params.add(name)
            else:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        return loaded_params
|
||||
|
||||
|
||||
class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
    """Causal-LM head on top of MiMoV2Model.

    Wires the backbone to a parallel LM head (last pipeline rank only) and a
    logits processor, and forwards the expert-mapping / weight-loading hooks.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config

        self.config = hf_config
        self.quant_config = quant_cfg
        self.model = MiMoV2Model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        # Only the last pipeline rank owns the LM head.
        self.lm_head = (
            ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_cfg,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if get_pp_group().is_last_rank
            else PPMissingLayer()
        )

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Record which layers should expose auxiliary hidden states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default EAGLE-3 tap points: an early, a middle and a late layer."""
        n = len(self.model.layers)
        return (2, n // 2, n - 3)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Token-id -> embedding lookup, delegated to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone; returns hidden states (or PP intermediates)."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits."""
        return self.logits_processor(self.lm_head, hidden_states)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Expose the backbone's checkpoint -> FusedMoE expert mapping."""
        return self.model.get_expert_mapping()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through the generic auto-loader."""
        return AutoWeightsLoader(self).load_weights(weights)
|
||||
@@ -15,8 +15,8 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
# import vllm_kunlun.ops.linear
|
||||
import vllm_kunlun.ops.rotary_embedding
|
||||
import vllm_kunlun.ops.layernorm
|
||||
import vllm_kunlun.ops.quantization.awq
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||
@@ -1,3 +1,20 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""kunlun custom op entry"""
|
||||
import torch_xmlir
|
||||
import torch
|
||||
@@ -177,51 +194,10 @@ class KunlunOps:
|
||||
"""
|
||||
query_x = query.contiguous()
|
||||
key_x = key.contiguous()
|
||||
query_x_dim = query_x.dim()
|
||||
if not is_neox_style:
|
||||
if cos_sin_cache.dtype == torch.float16:
|
||||
cos_sin_cache = cos_sin_cache.to(torch.float32)
|
||||
positions = positions.to(torch.int)
|
||||
if positions.dim() == 1:
|
||||
positions = positions.unsqueeze(0)
|
||||
query_x = query_x.unsqueeze(0)
|
||||
key_x = key_x.unsqueeze(0)
|
||||
|
||||
xtorch_ops.rotary_embedding_gptj(
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache)
|
||||
query.data = query_x
|
||||
key.data = key_x
|
||||
if query_x_dim != query_x.dim():
|
||||
query_x = query_x.unsqueeze(0)
|
||||
key_x = key_x.unsqueeze(0)
|
||||
return query, key
|
||||
|
||||
# TODO: need opt
|
||||
if cos_sin_cache.dim() == 4:
|
||||
max_seq_len = cos_sin_cache.shape[2]
|
||||
head_dim = cos_sin_cache.shape[3]
|
||||
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0) # 移除前两个维度 [1,1,L,D] -> [L,D]
|
||||
cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)
|
||||
|
||||
# 重塑 query 和 key 的形状
|
||||
num_tokens = query_x.shape[0]
|
||||
num_heads = query_x.shape[1] // head_size
|
||||
num_kv_heads = key_x.shape[1] // head_size
|
||||
|
||||
# # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size]
|
||||
# query_x = query_x.view(num_tokens, num_heads, head_size)
|
||||
# # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size]
|
||||
# key_x = key_x.view(num_tokens, num_kv_heads, head_size)
|
||||
|
||||
# # 确保形状正确
|
||||
# assert query_x.shape == (num_tokens, num_heads, head_size), \
|
||||
# f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}"
|
||||
# assert key_x.shape == (num_tokens, num_kv_heads, head_size), \
|
||||
# f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}"
|
||||
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
@@ -234,8 +210,6 @@ class KunlunOps:
|
||||
query_x = query_x.view(num_tokens, num_heads * head_size)
|
||||
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
|
||||
|
||||
# query.data = query_x
|
||||
# key.data = key_x
|
||||
return query_x, key_x
|
||||
|
||||
# Rotary embedding
|
||||
@@ -433,6 +407,121 @@ class KunlunOps:
|
||||
|
||||
return out
|
||||
|
||||
def _dbg(x):
|
||||
if torch.is_tensor(x):
|
||||
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
|
||||
return (type(x), x)
|
||||
@staticmethod
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    router_logits: torch.Tensor,
    linear_weights: torch.Tensor,
    moe_top_k: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Fused MoE forward using Kunlun custom ops.

    Pipeline: sigmoid/group top-k routing -> token sort by expert ->
    expert FC1 -> SwiGLU -> expert FC2 -> weighted scatter back to
    [M, N] output order.

    NOTE(review): several parameters (renormalize, inplace,
    use_grouped_topk, topk_group, scoring_func, w1_bias, w2_bias) are
    accepted but never read here — confirm intended; in particular
    topk_group is hard-coded to 1 in the routing kernel call below.
    Kernel argument semantics are taken on faith from the op names.
    """
    global_num_experts = linear_weights.shape[0]
    M, N = hidden_states.shape
    hidden_dim = w2.shape[1]
    # Per-token normalized routing scores and selected expert ids.
    normed_score = torch.empty(M,
                               moe_top_k,
                               dtype=torch.float32,
                               device=hidden_states.device)
    topk_ids = torch.empty(M,
                           moe_top_k,
                           dtype=torch.int32,
                           device=hidden_states.device)
    # Fixed block count for the statistic kernel — presumably tied to the
    # Kunlun hardware's cluster count; TODO confirm.
    num_blocks = 12
    block_statistic = torch.zeros(
        num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
    )

    # Routing: biased sigmoid scores, grouped top-k, normalized weights.
    torch.ops._C.moe_sigmoid_group_topk_norm(
        x=router_logits,
        topk_index=topk_ids,
        norm_score=normed_score,
        block_static=block_statistic,
        bias=e_score_correction_bias,
        scale=1.0,
        n_group=num_expert_group,
        topk_group=1,
    )

    # Scratch buffers for expert-sorted token processing.
    moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device)  # [M*top_k, N], float
    expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device)  # [E]
    sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device)  # [E+1]
    sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)

    torch.ops._C.gen_block_statistic(topk_ids,block_statistic)

    # Replicate each token top_k times, grouped (sorted) by expert id.
    torch.ops._C.moe_pre_sorted(
        x=hidden_states,
        topk_index=topk_ids,
        block_statistic=block_statistic,
        moe_expand=moe_expand,
        moe_index=sorted_tokens_idx,
        expert_m=expert_m,
        sorted_tokens_num_lod=sorted_tokens_num_lod)

    # First expert FC: [M*top_k, hidden] x w1 -> [M, top_k, 2*inter].
    y = torch.empty(M,moe_top_k,
                    w1.shape[1],
                    dtype=hidden_states.dtype,
                    device=hidden_states.device)

    # NOTE(review): re-views the [M*top_k, N] buffer as
    # [M*top_k, hidden_dim] — only valid if N == w2.shape[1]; confirm.
    moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)

    torch.ops._C.moe_fc(
        x=moe_expand,
        weight=w1,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=sorted_tokens_idx,
        moe_topk=moe_top_k,
        y=out1 if False else y)

    # SwiGLU halves the last dimension: gate half * silu(up half).
    d = y.shape[-1] // 2
    output_shape = (y.shape[:-1] + (d, ))
    out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
    torch.ops._C.swiglu(y, out1)

    # Second expert FC back to the model hidden size.
    out = torch.empty(M,moe_top_k,
                      w2.shape[1],
                      dtype=hidden_states.dtype,
                      device=hidden_states.device)

    out1 = out1.reshape(-1, out1.shape[-1])

    torch.ops._C.moe_fc(
        x=out1,
        weight=w2,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=sorted_tokens_idx,
        moe_topk=moe_top_k,
        y=out)

    # Unit dequant scales (no quantization on this path).
    dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
    output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
    sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)

    # Scatter expert outputs back to token order, weighted by routing scores.
    torch.ops._C.moe_post(
        x=out,
        moe_index=sorted_tokens_idx,
        normed_scale=normed_score,
        dequant_scale=dequant_scale,
        y=output
    )

    return output
|
||||
|
||||
@staticmethod
|
||||
def fused_moe_ep(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -487,42 +576,6 @@ class KunlunOps:
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""fused_moe"""
|
||||
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
expert_num = linear_weights.shape[0]
|
||||
|
||||
torch.ops._C.moe_ffn_block(
|
||||
x=hidden_states,
|
||||
gate_w=linear_weights,
|
||||
inter_w=w1,
|
||||
output_w=w2,
|
||||
expert_num=expert_num,
|
||||
moe_top_k=topk,
|
||||
topk_group=topk_group,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
expert_group_num=num_expert_group,
|
||||
out=output,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def fused_multi_head_latent_page_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -68,7 +68,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
linear_weights=linear_weights)
|
||||
linear_weights=linear_weights,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
def forward_kunlun(
|
||||
self,
|
||||
@@ -81,7 +82,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""forward_kunlun"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
@@ -99,96 +102,6 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
)
|
||||
# fused_moe do not support expert number > 400
|
||||
elif layer.local_num_experts > 400:
|
||||
hidden_states = x
|
||||
global_num_experts = linear_weights.shape[0]
|
||||
M, N = hidden_states.shape
|
||||
hidden_dim = layer.w2_weight.shape[1]
|
||||
normed_score = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
num_blocks = 12
|
||||
block_statistic = torch.zeros(
|
||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
router_logits = router_logits.float()
|
||||
torch.ops._C.moe_softmax_topk_norm(
|
||||
x=router_logits,
|
||||
normed_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=None,
|
||||
stable=True)
|
||||
|
||||
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
|
||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||
|
||||
y = torch.empty(M,top_k,
|
||||
layer.w13_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
moe_expand = moe_expand.view(M * top_k, hidden_dim)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=moe_expand,
|
||||
weight=layer.w13_weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=y)
|
||||
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = (y.shape[:-1] + (d, ))
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.swiglu(y, out1)
|
||||
|
||||
out = torch.empty(M,top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=out1,
|
||||
weight=layer.w2_weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=out)
|
||||
|
||||
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
|
||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
|
||||
|
||||
torch.ops._C.moe_post(
|
||||
x=out,
|
||||
moe_index=sorted_tokens_idx,
|
||||
normed_scale=normed_score,
|
||||
dequant_scale=dequant_scale,
|
||||
y=output
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return ops.fused_moe(x,
|
||||
layer.w13_weight,
|
||||
@@ -200,7 +113,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
class FusedMoE(VllmFusedMoE):
|
||||
|
||||
@@ -57,6 +57,8 @@ def vllm_kunlun_forward_cuda(
|
||||
)
|
||||
return out
|
||||
|
||||
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RMSNorm.forward = vllm_kunlun_forward_cuda
|
||||
|
||||
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
||||
@staticmethod
|
||||
|
||||
@@ -3,14 +3,30 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.linear import (
|
||||
WEIGHT_LOADER_V2_SUPPORTED,
|
||||
ReplicatedLinear,
|
||||
UnquantizedLinearMethod,
|
||||
ColumnParallelLinear
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
BlockQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
PerTensorScaleParameter,
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
divide,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -59,4 +75,361 @@ def create_weights(
|
||||
|
||||
# rewrite create_weights and remove weight_loader_v2 to suport cuda graph
|
||||
UnquantizedLinearMethod.create_weights = create_weights
|
||||
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
|
||||
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
"""
|
||||
Base on v0.11.0 QKVParallelLinear, And add v_head size for swa (MIMO V2)
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
head_size: int,
|
||||
total_num_heads: int,
|
||||
total_num_kv_heads: int | None = None,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
v_head_size: int | None = None,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
self.v_head_size = v_head_size if v_head_size is not None else head_size
|
||||
self.total_num_heads = total_num_heads
|
||||
if total_num_kv_heads is None:
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
|
||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||
if tp_size >= self.total_num_kv_heads:
|
||||
self.num_kv_heads = 1
|
||||
self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
|
||||
else:
|
||||
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||
self.num_kv_head_replicas = 1
|
||||
input_size = self.hidden_size
|
||||
output_size = (
|
||||
self.num_heads * self.head_size
|
||||
+ self.num_kv_heads * self.head_size
|
||||
+ self.num_kv_heads * self.v_head_size
|
||||
) * tp_size
|
||||
self.output_sizes = [
|
||||
self.num_heads * self.head_size * tp_size, # q_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
||||
self.num_kv_heads * self.v_head_size * tp_size, # v_proj
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
input_size=input_size,
|
||||
output_size=output_size,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp,
|
||||
)
|
||||
|
||||
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
||||
shard_offset_mapping = {
|
||||
"q": 0,
|
||||
"k": self.num_heads * self.head_size,
|
||||
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
|
||||
"total": (self.num_heads + self.num_kv_heads) * self.head_size
|
||||
+ self.num_kv_heads * self.v_head_size,
|
||||
}
|
||||
return shard_offset_mapping.get(loaded_shard_id)
|
||||
|
||||
def _get_shard_size_mapping(self, loaded_shard_id: str):
|
||||
shard_size_mapping = {
|
||||
"q": self.num_heads * self.head_size,
|
||||
"k": self.num_kv_heads * self.head_size,
|
||||
"v": self.num_kv_heads * self.v_head_size,
|
||||
}
|
||||
return shard_size_mapping.get(loaded_shard_id)
|
||||
|
||||
def _load_fused_module_from_checkpoint(
|
||||
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
|
||||
):
|
||||
"""
|
||||
Handle special case for models where QKV layers are already
|
||||
fused on disk. In this case, we have no shard id. This function
|
||||
determines the shard id by splitting these layers and then calls
|
||||
the weight loader using the shard id.
|
||||
|
||||
An example of a model with these fused layers:
|
||||
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
|
||||
"""
|
||||
shard_offsets = [
|
||||
# (shard_id, shard_offset, shard_size)
|
||||
("q", 0, self.total_num_heads * self.head_size),
|
||||
(
|
||||
"k",
|
||||
self.total_num_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size,
|
||||
),
|
||||
(
|
||||
"v",
|
||||
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
|
||||
self.total_num_kv_heads * self.v_head_size,
|
||||
),
|
||||
]
|
||||
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if (
|
||||
isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
|
||||
and param.packed_dim == param.output_dim
|
||||
):
|
||||
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
|
||||
shard_size=shard_size, shard_offset=shard_offset
|
||||
)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
param.output_dim, shard_offset, shard_size
|
||||
)
|
||||
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
|
||||
|
||||
def weight_loader_v2(
|
||||
self,
|
||||
param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: str | None = None,
|
||||
):
|
||||
if loaded_shard_id is None: # special case for certain models
|
||||
if isinstance(param, PerTensorScaleParameter):
|
||||
param.load_qkv_weight(
|
||||
loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
|
||||
)
|
||||
return
|
||||
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
|
||||
param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
|
||||
return
|
||||
# TODO: @dsikka - move to parameter.py
|
||||
self._load_fused_module_from_checkpoint(param, loaded_weight)
|
||||
return
|
||||
|
||||
assert loaded_shard_id in ["q", "k", "v"]
|
||||
|
||||
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
|
||||
shard_size = self._get_shard_size_mapping(loaded_shard_id)
|
||||
|
||||
# Note(simon): This is needed for Qwen3's fp8 quantization.
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
assert self.quant_method is not None
|
||||
# Assume the weight block size has been set by quant method
|
||||
assert hasattr(self, "weight_block_size")
|
||||
weight_block_size = self.weight_block_size
|
||||
assert weight_block_size is not None
|
||||
block_n, _ = weight_block_size[0], weight_block_size[1]
|
||||
shard_offset = (shard_offset + block_n - 1) // block_n
|
||||
shard_size = (shard_size + block_n - 1) // block_n
|
||||
|
||||
param.load_qkv_weight(
|
||||
loaded_weight=loaded_weight,
|
||||
num_heads=self.num_kv_head_replicas,
|
||||
shard_id=loaded_shard_id,
|
||||
shard_offset=shard_offset,
|
||||
shard_size=shard_size,
|
||||
tp_rank=self.tp_rank,
|
||||
)
|
||||
|
||||
def weight_loader(
|
||||
self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: str | None = None,
|
||||
):
|
||||
# Special case for GGUF
|
||||
# initialize GGUF param after we know the quantize type
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if is_gguf_weight_type:
|
||||
idx_map = {"q": 0, "k": 1, "v": 2}
|
||||
if loaded_shard_id is not None:
|
||||
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
|
||||
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
|
||||
else:
|
||||
param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
|
||||
return
|
||||
|
||||
if is_gguf_weight:
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
shard_size = loaded_weight.size(output_dim) // self.tp_size
|
||||
start_idx = self.tp_rank * shard_size
|
||||
|
||||
if loaded_shard_id is not None:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
param.shard_id.append(loaded_shard_id)
|
||||
param.shard_id_map[loaded_shard_id] = len(param.data_container)
|
||||
param.data_container.append(loaded_weight)
|
||||
return
|
||||
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
|
||||
# Special case for per-tensor scales in fused case.
|
||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already fused on disk (qkv).
|
||||
# (e.g., Phi-3's qkv_proj).
|
||||
if output_dim is None:
|
||||
if needs_scalar_to_array:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
param_data, loaded_weight, 0
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
shard_offsets = [
|
||||
# (shard_id, shard_offset, shard_size)
|
||||
("q", 0, self.total_num_heads * self.head_size),
|
||||
(
|
||||
"k",
|
||||
self.total_num_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size,
|
||||
),
|
||||
(
|
||||
"v",
|
||||
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
|
||||
self.total_num_kv_heads * self.v_head_size,
|
||||
),
|
||||
]
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.packed_factor
|
||||
shard_offset = shard_offset // param.packed_factor
|
||||
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset
|
||||
)
|
||||
|
||||
if use_bitsandbytes_4bit:
|
||||
orig_qkv_offsets = {
|
||||
"q": (0, self.total_num_heads * self.head_size),
|
||||
"k": (
|
||||
self.total_num_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size,
|
||||
),
|
||||
"v": (
|
||||
(self.total_num_heads + self.total_num_kv_heads)
|
||||
* self.head_size,
|
||||
self.total_num_kv_heads * self.v_head_size,
|
||||
),
|
||||
"total": (
|
||||
(self.total_num_heads + self.total_num_kv_heads)
|
||||
* self.head_size
|
||||
+ self.total_num_kv_heads * self.v_head_size,
|
||||
0,
|
||||
),
|
||||
}
|
||||
|
||||
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||
param, orig_qkv_offsets, shard_id
|
||||
)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size
|
||||
)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
return
|
||||
|
||||
assert loaded_shard_id in ["q", "k", "v"]
|
||||
|
||||
# If output dim is defined, use the default loading process.
|
||||
if output_dim is not None:
|
||||
if loaded_shard_id == "q":
|
||||
shard_offset = 0
|
||||
shard_size = self.num_heads * self.head_size
|
||||
elif loaded_shard_id == "k":
|
||||
shard_offset = self.num_heads * self.head_size
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
elif loaded_shard_id == "v":
|
||||
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
|
||||
shard_size = self.num_kv_heads * self.v_head_size
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.packed_factor
|
||||
shard_offset = shard_offset // param.packed_factor
|
||||
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset
|
||||
)
|
||||
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow
|
||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||
|
||||
if use_bitsandbytes_4bit:
|
||||
orig_qkv_offsets = {
|
||||
"q": (0, self.num_heads * self.head_size),
|
||||
"k": (
|
||||
self.num_heads * self.head_size,
|
||||
self.num_kv_heads * self.head_size,
|
||||
),
|
||||
"v": (
|
||||
(self.num_heads + self.num_kv_heads) * self.head_size,
|
||||
self.num_kv_heads * self.v_head_size,
|
||||
),
|
||||
"total": (
|
||||
(self.num_heads + self.num_kv_heads) * self.head_size
|
||||
+ self.num_kv_heads * self.v_head_size,
|
||||
0,
|
||||
),
|
||||
}
|
||||
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||
param, orig_qkv_offsets, loaded_shard_id
|
||||
)
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||
if loaded_shard_id == "q":
|
||||
shard_rank = self.tp_rank
|
||||
else:
|
||||
shard_rank = self.tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_rank * shard_size
|
||||
|
||||
if not is_sharded_weight:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
|
||||
# Special case for per-tensor scales in fused case.
|
||||
elif needs_scalar_to_array:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
param_data, loaded_weight, loaded_shard_id
|
||||
)
|
||||
else:
|
||||
ignore_warning = getattr(param, "ignore_warning", False)
|
||||
if not ignore_warning:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"QKVParallelLinear, assume the weight is the same "
|
||||
"for all partitions."
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
@@ -8,14 +8,8 @@ from typing import List, Optional, Tuple
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
if current_platform.is_kunlun():
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
else:
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.triton_utils.importing import HAS_TRITON
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
@@ -70,7 +70,7 @@ def vllm_kunlun_forward_cuda(
|
||||
self.is_neox_style, self.rotary_dim,
|
||||
offsets)
|
||||
else:
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
query, key = ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
@@ -143,14 +143,11 @@ def vllm_kunlun_mrope_forward_cuda(
|
||||
|
||||
return query, key
|
||||
|
||||
# RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
# RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
# RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||
# MRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
|
||||
# YaRNScalingRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
|
||||
|
||||
def Split_Norm_Rope(
|
||||
|
||||
@@ -1,143 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
from vllm.model_executor.parameter import BasevLLMParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
|
||||
|
||||
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
||||
"""Unquantized method for embeddings."""
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""Create weights for embedding layer."""
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||
|
||||
def embedding(self, layer: torch.nn.Module,
|
||||
input_: torch.Tensor) -> torch.Tensor:
|
||||
return F.embedding(input_, layer.weight)
|
||||
|
||||
|
||||
def pad_vocab_size(vocab_size: int,
|
||||
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
|
||||
"""Pad the vocab size to the given value."""
|
||||
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
||||
|
||||
|
||||
def vocab_range_from_per_partition_vocab_size(
|
||||
per_partition_vocab_size: int,
|
||||
rank: int,
|
||||
offset: int = 0) -> Sequence[int]:
|
||||
index_f = rank * per_partition_vocab_size
|
||||
index_l = index_f + per_partition_vocab_size
|
||||
return index_f + offset, index_l + offset
|
||||
|
||||
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
offset: int = 0) -> Sequence[int]:
|
||||
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
||||
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
|
||||
rank,
|
||||
offset=offset)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabParallelEmbeddingShardIndices:
|
||||
"""Indices for a shard of a vocab parallel embedding."""
|
||||
padded_org_vocab_start_index: int
|
||||
padded_org_vocab_end_index: int
|
||||
padded_added_vocab_start_index: int
|
||||
padded_added_vocab_end_index: int
|
||||
|
||||
org_vocab_start_index: int
|
||||
org_vocab_end_index: int
|
||||
added_vocab_start_index: int
|
||||
added_vocab_end_index: int
|
||||
|
||||
@property
|
||||
def num_org_elements(self) -> int:
|
||||
return self.org_vocab_end_index - self.org_vocab_start_index
|
||||
|
||||
@property
|
||||
def num_added_elements(self) -> int:
|
||||
return self.added_vocab_end_index - self.added_vocab_start_index
|
||||
|
||||
@property
|
||||
def num_org_elements_padded(self) -> int:
|
||||
return (self.padded_org_vocab_end_index -
|
||||
self.padded_org_vocab_start_index)
|
||||
|
||||
@property
|
||||
def num_added_elements_padded(self) -> int:
|
||||
return (self.padded_added_vocab_end_index -
|
||||
self.padded_added_vocab_start_index)
|
||||
|
||||
@property
|
||||
def num_org_vocab_padding(self) -> int:
|
||||
return self.num_org_elements_padded - self.num_org_elements
|
||||
|
||||
@property
|
||||
def num_added_vocab_padding(self) -> int:
|
||||
return self.num_added_elements_padded - self.num_added_elements
|
||||
|
||||
@property
|
||||
def num_elements_padded(self) -> int:
|
||||
return self.num_org_elements_padded + self.num_added_elements_padded
|
||||
|
||||
def __post_init__(self):
|
||||
# sanity checks
|
||||
assert (self.padded_org_vocab_start_index
|
||||
<= self.padded_org_vocab_end_index)
|
||||
assert (self.padded_added_vocab_start_index
|
||||
<= self.padded_added_vocab_end_index)
|
||||
|
||||
assert self.org_vocab_start_index <= self.org_vocab_end_index
|
||||
assert self.added_vocab_start_index <= self.added_vocab_end_index
|
||||
|
||||
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
|
||||
assert (self.added_vocab_start_index
|
||||
<= self.padded_added_vocab_start_index)
|
||||
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
|
||||
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
|
||||
|
||||
assert self.num_org_elements <= self.num_org_elements_padded
|
||||
assert self.num_added_elements <= self.num_added_elements_padded
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend="aot_eager")
|
||||
def get_masked_input_and_mask(
|
||||
@@ -159,319 +27,25 @@ def get_masked_input_and_mask(
|
||||
input_ = vocab_mask * (input_ - valid_offset)
|
||||
return input_, ~vocab_mask
|
||||
|
||||
def forward_native_kunlun(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
masked_input, input_mask = get_masked_input_and_mask(
|
||||
input_, self.shard_indices.org_vocab_start_index,
|
||||
self.shard_indices.org_vocab_end_index,
|
||||
self.shard_indices.num_org_vocab_padding,
|
||||
self.shard_indices.added_vocab_start_index,
|
||||
self.shard_indices.added_vocab_end_index)
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.quant_method.embedding(self,
|
||||
masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
|
||||
|
||||
@CustomOp.register("vllm_kunlun_vocab_parallel_embedding")
|
||||
class VocabParallelEmbedding(CustomOp):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
|
||||
make sure it is divisible by the number of model parallel GPUs.
|
||||
|
||||
In order to support various loading methods, we ensure that LoRA-added
|
||||
embeddings are always at the end of TP-sharded tensors. In other words,
|
||||
we shard base embeddings and LoRA embeddings separately (both padded),
|
||||
and place them in the same tensor.
|
||||
In this example, we will have the original vocab size = 1010,
|
||||
added vocab size = 16 and padding to 64. Therefore, the total
|
||||
vocab size with padding will be 1088 (because we first pad 1010 to
|
||||
1024, add 16, and then pad to 1088).
|
||||
Therefore, the tensor format looks like the following:
|
||||
TP1, rank 0 (no sharding):
|
||||
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
|
||||
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 |
|
||||
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
|
||||
|
||||
TP2, rank 0:
|
||||
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
|
||||
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 |
|
||||
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
|
||||
TP2, rank 1:
|
||||
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
|
||||
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
|
||||
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
|
||||
|
||||
Args:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
params_dtype: type of the parameters.
|
||||
org_num_embeddings: original vocabulary size (without LoRA).
|
||||
padding_size: padding size for the vocabulary.
|
||||
quant_config: quant config for the layer
|
||||
prefix: full name of the layer in the state dict
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
# Keep the input dimensions.
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.padding_size = padding_size
|
||||
self.org_vocab_size = org_num_embeddings or num_embeddings
|
||||
num_added_embeddings = num_embeddings - self.org_vocab_size
|
||||
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
|
||||
self.padding_size)
|
||||
self.num_embeddings_padded = pad_vocab_size(
|
||||
self.org_vocab_size_padded + num_added_embeddings,
|
||||
self.padding_size)
|
||||
assert self.org_vocab_size_padded <= self.num_embeddings_padded
|
||||
|
||||
self.shard_indices = self._get_indices(self.num_embeddings_padded,
|
||||
self.org_vocab_size_padded,
|
||||
self.num_embeddings,
|
||||
self.org_vocab_size, tp_rank,
|
||||
self.tp_size)
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
quant_method = None
|
||||
if quant_config is not None:
|
||||
quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
if quant_method is None:
|
||||
quant_method = UnquantizedEmbeddingMethod()
|
||||
|
||||
# If we are making an embedding layer, then our quantization linear
|
||||
# method must implement the embedding operation. If we are another
|
||||
# layer type like ParallelLMHead, this is not important.
|
||||
is_embedding_layer = type(self) is VocabParallelEmbedding
|
||||
quant_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(quant_method))
|
||||
if is_embedding_layer and not quant_method_implements_embedding:
|
||||
raise NotImplementedError(
|
||||
f"The class {type(quant_method).__name__} must implement "
|
||||
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
|
||||
|
||||
self.quant_method: QuantizeMethodBase = quant_method
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
# Divide the weight matrix along the vocaburaly dimension.
|
||||
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
|
||||
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
|
||||
self.tp_size)
|
||||
assert (self.shard_indices.num_elements_padded ==
|
||||
self.num_embeddings_per_partition)
|
||||
self.num_org_embeddings_per_partition = (
|
||||
self.shard_indices.org_vocab_end_index -
|
||||
self.shard_indices.org_vocab_start_index)
|
||||
self.num_added_embeddings_per_partition = (
|
||||
self.shard_indices.added_vocab_end_index -
|
||||
self.shard_indices.added_vocab_start_index)
|
||||
|
||||
self.quant_method.create_weights(self,
|
||||
self.embedding_dim,
|
||||
[self.num_embeddings_per_partition],
|
||||
self.embedding_dim,
|
||||
self.num_embeddings_padded,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
@classmethod
|
||||
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
|
||||
vocab_size: int, org_vocab_size: int, tp_rank: int,
|
||||
tp_size: int) -> VocabParallelEmbeddingShardIndices:
|
||||
"""Get start and end indices for vocab parallel embedding, following the
|
||||
layout outlined in the class docstring, based on the given tp_rank and
|
||||
tp_size."""
|
||||
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
|
||||
padded_org_vocab_start_index, padded_org_vocab_end_index = (
|
||||
vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
|
||||
tp_size))
|
||||
padded_added_vocab_start_index, padded_added_vocab_end_index = (
|
||||
vocab_range_from_global_vocab_size(num_added_embeddings_padded,
|
||||
tp_rank,
|
||||
tp_size,
|
||||
offset=org_vocab_size))
|
||||
# remove padding
|
||||
org_vocab_start_index = min(padded_org_vocab_start_index,
|
||||
org_vocab_size)
|
||||
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
|
||||
added_vocab_start_index = min(padded_added_vocab_start_index,
|
||||
vocab_size)
|
||||
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
|
||||
return VocabParallelEmbeddingShardIndices(
|
||||
padded_org_vocab_start_index, padded_org_vocab_end_index,
|
||||
padded_added_vocab_start_index, padded_added_vocab_end_index,
|
||||
org_vocab_start_index, org_vocab_end_index,
|
||||
added_vocab_start_index, added_vocab_end_index)
|
||||
|
||||
def get_sharded_to_full_mapping(self) -> Optional[list[int]]:
|
||||
"""Get a mapping that can be used to reindex the gathered
|
||||
logits for sampling.
|
||||
|
||||
During sampling, we gather logits from all ranks. The relationship
|
||||
of index->token_id will follow the same format as outlined in the class
|
||||
docstring. However, after the gather, we want to reindex the final
|
||||
logits tensor to map index->token_id one-to-one (the index is always
|
||||
equal the token_id it corresponds to). The indices returned by this
|
||||
method allow us to do that.
|
||||
"""
|
||||
if self.tp_size < 2:
|
||||
return None
|
||||
|
||||
base_embeddings: list[int] = []
|
||||
added_embeddings: list[int] = []
|
||||
padding: list[int] = []
|
||||
for tp_rank in range(self.tp_size):
|
||||
shard_indices = self._get_indices(self.num_embeddings_padded,
|
||||
self.org_vocab_size_padded,
|
||||
self.num_embeddings,
|
||||
self.org_vocab_size, tp_rank,
|
||||
self.tp_size)
|
||||
range_start = self.num_embeddings_per_partition * tp_rank
|
||||
range_end = self.num_embeddings_per_partition * (tp_rank + 1)
|
||||
base_embeddings.extend(
|
||||
range(range_start,
|
||||
range_start + shard_indices.num_org_elements))
|
||||
padding.extend(
|
||||
range(range_start + shard_indices.num_org_elements,
|
||||
range_start + shard_indices.num_org_elements_padded))
|
||||
added_embeddings.extend(
|
||||
range(
|
||||
range_start + shard_indices.num_org_elements_padded,
|
||||
range_start + shard_indices.num_org_elements_padded +
|
||||
shard_indices.num_added_elements))
|
||||
padding.extend(
|
||||
range(
|
||||
range_start + shard_indices.num_org_elements_padded +
|
||||
shard_indices.num_added_elements,
|
||||
range_start + shard_indices.num_org_elements_padded +
|
||||
shard_indices.num_added_elements_padded))
|
||||
assert (range_start + shard_indices.num_org_elements_padded +
|
||||
shard_indices.num_added_elements_padded == range_end)
|
||||
ret = base_embeddings + added_embeddings + padding
|
||||
assert len(ret) == self.num_embeddings_padded
|
||||
return ret
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
|
||||
# If the parameter is a gguf weight, then load it directly.
|
||||
if getattr(param, "is_gguf_weight_type", None):
|
||||
param.data.copy_(loaded_weight)
|
||||
param.weight_type = loaded_weight.item()
|
||||
return
|
||||
elif isinstance(param, UninitializedParameter):
|
||||
shape = list(loaded_weight.shape)
|
||||
if output_dim is not None:
|
||||
shape[output_dim] = self.num_embeddings_per_partition
|
||||
param.materialize(tuple(shape), dtype=loaded_weight.dtype)
|
||||
|
||||
# If parameter does not have output dim, then it should
|
||||
# be copied onto all gpus (e.g. g_idx for act_order gptq).
|
||||
if output_dim is None:
|
||||
assert param.data.shape == loaded_weight.shape
|
||||
param.data.copy_(loaded_weight)
|
||||
return
|
||||
|
||||
# Shard indexes for loading the weight
|
||||
start_idx = self.shard_indices.org_vocab_start_index
|
||||
shard_size = self.shard_indices.org_vocab_end_index - start_idx
|
||||
|
||||
# If param packed on the same dim we are sharding on, then
|
||||
# need to adjust offsets of loaded weight by pack_factor.
|
||||
if packed_dim is not None and packed_dim == output_dim:
|
||||
packed_factor = param.packed_factor if isinstance(
|
||||
param, BasevLLMParameter) else param.pack_factor
|
||||
assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
|
||||
param.packed_factor)
|
||||
start_idx = start_idx // packed_factor
|
||||
shard_size = shard_size // packed_factor
|
||||
else:
|
||||
assert loaded_weight.shape[output_dim] == self.org_vocab_size
|
||||
|
||||
# Copy the data. Select chunk corresponding to current shard.
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
|
||||
param[loaded_weight.shape[0]:].data.fill_(0)
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
masked_input, input_mask = get_masked_input_and_mask(
|
||||
input_, self.shard_indices.org_vocab_start_index,
|
||||
self.shard_indices.org_vocab_end_index,
|
||||
self.shard_indices.num_org_vocab_padding,
|
||||
self.shard_indices.added_vocab_start_index,
|
||||
self.shard_indices.added_vocab_end_index)
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.quant_method.embedding(self,
|
||||
masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"num_embeddings={self.num_embeddings_per_partition}"
|
||||
s += f", embedding_dim={self.embedding_dim}"
|
||||
s += f", org_vocab_size={self.org_vocab_size}"
|
||||
s += f', num_embeddings_padded={self.num_embeddings_padded}'
|
||||
s += f', tp_size={self.tp_size}'
|
||||
return s
|
||||
|
||||
|
||||
class ParallelLMHead(VocabParallelEmbedding):
|
||||
"""Parallelized LM head.
|
||||
|
||||
Output logits weight matrices used in the Sampler. The weight and bias
|
||||
tensors are padded to make sure they are divisible by the number of
|
||||
model parallel GPUs.
|
||||
|
||||
Args:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
bias: whether to use bias.
|
||||
params_dtype: type of the parameters.
|
||||
org_num_embeddings: original vocabulary size (without LoRA).
|
||||
padding_size: padding size for the vocabulary.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
bias: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__(num_embeddings, embedding_dim, params_dtype,
|
||||
org_num_embeddings, padding_size, quant_config,
|
||||
prefix)
|
||||
self.quant_config = quant_config
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
|
||||
"""Tie the weights with word embeddings."""
|
||||
# GGUF quantized embed_tokens.
|
||||
if self.quant_config and self.quant_config.get_name() == "gguf":
|
||||
return embed_tokens
|
||||
else:
|
||||
self.weight = embed_tokens.weight
|
||||
return self
|
||||
|
||||
def forward(self, input_):
|
||||
del input_
|
||||
raise RuntimeError("LMHead's weights should be used in the sampler.")
|
||||
VocabParallelEmbedding.forward_native = forward_native_kunlun
|
||||
@@ -148,7 +148,6 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
# Prefix cache loc
|
||||
kv_lod_cpu: Optional[torch.Tensor] = None
|
||||
kv_lod_xpu: Optional[torch.Tensor] = None
|
||||
@@ -563,9 +562,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"kunlunAttention does not support block-sparse attention.")
|
||||
# if logits_soft_cap is not None:
|
||||
# raise ValueError(
|
||||
# "kunlunAttention does not support attention logits soft capping.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@@ -673,51 +669,84 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||
|
||||
# For hybrid Attention (Qwen3-Next.)
|
||||
if key_cache.is_contiguous():
|
||||
tmp_block_tables = prefill_meta.block_tables
|
||||
else:
|
||||
tmp_block_tables = prefill_meta.block_tables * 2 # only test in Qwen3-Next
|
||||
|
||||
xtorch_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=key_cache, # Key Cache (block_num, head, block_size, dim)
|
||||
v=value_cache,
|
||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
||||
is_causal=True,
|
||||
is_prefix_cache=True,
|
||||
block_table=tmp_block_tables,
|
||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
|
||||
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softmax_lse=None,
|
||||
sink=self.sinks
|
||||
)
|
||||
# For hybrid Attention (Qwen3-Next)
|
||||
tmp_block_tables = prefill_meta.block_tables * 2
|
||||
|
||||
# Prefix cache
|
||||
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
|
||||
xtorch_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=key_cache, # Key Cache [block_num, head, block_size, dim]
|
||||
v=value_cache,
|
||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
||||
is_causal=True,
|
||||
is_prefix_cache=True,
|
||||
block_table=tmp_block_tables,
|
||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
|
||||
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softmax_lse=None
|
||||
)
|
||||
else:
|
||||
xtorch_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=prefill_key,
|
||||
v=prefill_value,
|
||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
||||
is_causal=True,
|
||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softmax_lse=None,
|
||||
swa_left = self.sliding_window if self.sliding_window is not None else -1,
|
||||
swa_right = 0 if self.sliding_window is not None else -1,
|
||||
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
decode_query = query[:num_decode_tokens]
|
||||
|
||||
# For hybrid Attention (Qwen3-Next
|
||||
if key_cache.is_contiguous():
|
||||
tmp_block_tables = decode_meta.block_tables
|
||||
else:
|
||||
tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next
|
||||
|
||||
xtorch_ops.paged_attention(
|
||||
x=decode_query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
block_tables=tmp_block_tables,
|
||||
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
|
||||
context_lens_xpu=decode_meta.seq_lens_tensor,
|
||||
is_context=False,
|
||||
is_causal=True,
|
||||
|
||||
xtorch_ops.speculative_attention(
|
||||
out=output[:num_decode_tokens],
|
||||
vo_head_dim=self.head_size
|
||||
)
|
||||
# Only MLA support q len > 1 right now
|
||||
q=decode_query.unsqueeze(0),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
|
||||
context_lens_xpu=decode_meta.seq_lens_tensor,
|
||||
batch_num=decode_meta.block_tables.shape[0],
|
||||
# TODO (@xyDong23): Support MTP(q lens >1)
|
||||
qlen=1,
|
||||
# TODO (@xyDong23): Support max_context_len to (262144)
|
||||
max_context_len=131072,
|
||||
head_num=self.num_heads,
|
||||
head_dim=self.head_size,
|
||||
scale=0.0,
|
||||
kv_head_num=self.num_kv_heads,
|
||||
block_size=key_cache.shape[2],
|
||||
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
|
||||
max_window_size=self.sliding_window if self.sliding_window is not None else -1,
|
||||
block_tables=tmp_block_tables,
|
||||
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
|
||||
)
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
def use_cascade_attention(
|
||||
@@ -788,4 +817,4 @@ def use_cascade_attention(
|
||||
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
|
||||
|
||||
# Use cascade attention if it is faster than FlashDecoding.
|
||||
return cascade_time < flash_decoding_time
|
||||
return cascade_time < flash_decoding_time
|
||||
@@ -938,6 +938,83 @@ def _fake_rotary_embedding(
|
||||
|
||||
rotary_embedding.register_fake(_fake_rotary_embedding)
|
||||
|
||||
@custom_op("_C::quant2d", mutates_args=())
|
||||
def quant2d(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
max: torch.Tensor,
|
||||
force_sdnn: bool,
|
||||
) -> None:
|
||||
xtorch_ops.quant2d(
|
||||
x=x,
|
||||
y=y,
|
||||
max=max,
|
||||
force_sdnn=force_sdnn
|
||||
)
|
||||
|
||||
@impl("_C::quant2d", "CUDA")
|
||||
def quant2d_cuda(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
max: torch.Tensor,
|
||||
force_sdnn: bool,
|
||||
) -> None:
|
||||
xtorch_ops.quant2d(
|
||||
x=x,
|
||||
y=y,
|
||||
max=max,
|
||||
force_sdnn=force_sdnn
|
||||
)
|
||||
|
||||
def _fake_quant2d(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
max: torch.Tensor,
|
||||
force_sdnn: bool,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
quant2d.register_fake(_fake_quant2d)
|
||||
|
||||
@custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=())
|
||||
def gemm_I8_I8_bf16_nt(
|
||||
x_q: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
xtorch_ops.gemm_I8_I8_bf16_nt(
|
||||
lhs=(x_q, x_scale),
|
||||
rhs=(weight, weight_scale),
|
||||
out=out
|
||||
)
|
||||
|
||||
@impl("_C::gemm_I8_I8_bf16_nt", "CUDA")
|
||||
def gemm_I8_I8_bf16_nt_cuda(
|
||||
x_q: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
xtorch_ops.gemm_I8_I8_bf16_nt(
|
||||
lhs=(x_q, x_scale),
|
||||
rhs=(weight, weight_scale),
|
||||
out=out
|
||||
)
|
||||
|
||||
def _fake_gemm_I8_I8_bf16_nt(
|
||||
x_q: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt)
|
||||
|
||||
@custom_op("_C::moe_softmax_topk_norm", mutates_args=())
|
||||
def moe_softmax_topk_norm(
|
||||
x: torch.Tensor,
|
||||
@@ -1068,15 +1145,39 @@ def moe_fc(
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
sorted_tokens_idx: torch.Tensor,
|
||||
moe_topk: int,
|
||||
y: torch.Tensor
|
||||
y: torch.Tensor,
|
||||
act: Optional[str] = None,
|
||||
x_perchannel_max: Optional[torch.Tensor] = None,
|
||||
w_perchannel_max: Optional[torch.Tensor] = None ,
|
||||
topk_ids: Optional[torch.Tensor] = None,
|
||||
topk_w: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tgemm_type: Optional[str] = None,
|
||||
tweight_type: Optional[str] = None,
|
||||
scale_n: Optional[int] = 0,
|
||||
scale_k: Optional[int] = 0,
|
||||
use_pack_int4: Optional[bool] = False,
|
||||
sort_mode: Optional[bool] = True
|
||||
)-> None:
|
||||
xtorch_ops.moe_fc(
|
||||
x,
|
||||
weight,
|
||||
sorted_tokens_num_lod,
|
||||
sorted_tokens_idx,
|
||||
moe_topk,
|
||||
y)
|
||||
x=x,
|
||||
weight=weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_topk,
|
||||
y=y,
|
||||
act=act,
|
||||
x_perchannel_max=x_perchannel_max,
|
||||
w_perchannel_max=w_perchannel_max,
|
||||
topk_ids=topk_ids,
|
||||
topk_w=topk_w,
|
||||
bias=bias,
|
||||
tgemm_type=tgemm_type,
|
||||
tweight_type=tweight_type,
|
||||
scale_n=scale_n,
|
||||
scale_k=scale_k,
|
||||
use_pack_int4=use_pack_int4,
|
||||
sort_mode=sort_mode)
|
||||
|
||||
@impl("_C::moe_fc", "CUDA")
|
||||
def moe_fc_cuda(
|
||||
@@ -1085,15 +1186,39 @@ def moe_fc_cuda(
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
sorted_tokens_idx: torch.Tensor,
|
||||
moe_topk: int,
|
||||
y: torch.Tensor
|
||||
y: torch.Tensor,
|
||||
act: Optional[str] = None,
|
||||
x_perchannel_max: Optional[torch.Tensor] = None,
|
||||
w_perchannel_max: Optional[torch.Tensor] = None ,
|
||||
topk_ids: Optional[torch.Tensor] = None,
|
||||
topk_w: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tgemm_type: Optional[str] = None,
|
||||
tweight_type: Optional[str] = None,
|
||||
scale_n: Optional[int] = 0,
|
||||
scale_k: Optional[int] = 0,
|
||||
use_pack_int4: Optional[bool] = False,
|
||||
sort_mode: Optional[bool] = True
|
||||
)-> None:
|
||||
xtorch_ops.moe_fc(
|
||||
x,
|
||||
weight,
|
||||
sorted_tokens_num_lod,
|
||||
sorted_tokens_idx,
|
||||
moe_topk,
|
||||
y)
|
||||
x=x,
|
||||
weight=weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_topk,
|
||||
y=y,
|
||||
act=act,
|
||||
x_perchannel_max=x_perchannel_max,
|
||||
w_perchannel_max=w_perchannel_max,
|
||||
topk_ids=topk_ids,
|
||||
topk_w=topk_w,
|
||||
bias=bias,
|
||||
tgemm_type=tgemm_type,
|
||||
tweight_type=tweight_type,
|
||||
scale_n=scale_n,
|
||||
scale_k=scale_k,
|
||||
use_pack_int4=use_pack_int4,
|
||||
sort_mode=sort_mode)
|
||||
|
||||
def fake_moe_fc(
|
||||
x: torch.Tensor,
|
||||
@@ -1101,7 +1226,19 @@ def fake_moe_fc(
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
sorted_tokens_idx: torch.Tensor,
|
||||
moe_topk: int,
|
||||
y: torch.Tensor
|
||||
y: torch.Tensor,
|
||||
act: Optional[str] = None,
|
||||
x_perchannel_max: Optional[torch.Tensor] = None,
|
||||
w_perchannel_max: Optional[torch.Tensor] = None ,
|
||||
topk_ids: Optional[torch.Tensor] = None,
|
||||
topk_w: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tgemm_type: Optional[str] = None,
|
||||
tweight_type: Optional[str] = None,
|
||||
scale_n: Optional[int] = 0,
|
||||
scale_k: Optional[int] = 0,
|
||||
use_pack_int4: Optional[bool] = False,
|
||||
sort_mode: Optional[bool] = True
|
||||
)-> None:
|
||||
return None
|
||||
|
||||
@@ -1151,6 +1288,63 @@ def fake_moe_post(
|
||||
moe_post.register_fake(fake_moe_post)
|
||||
|
||||
|
||||
@custom_op("_C::moe_sigmoid_group_topk_norm", mutates_args=())
|
||||
def moe_sigmoid_group_topk_norm(
|
||||
x: torch.Tensor,
|
||||
topk_index: torch.Tensor,
|
||||
norm_score: torch.Tensor,
|
||||
block_static: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
scale: float,
|
||||
n_group: int,
|
||||
topk_group: int
|
||||
) -> None:
|
||||
xtorch_ops.moe_sigmoid_group_topk_norm(
|
||||
x=x,
|
||||
norm_score=norm_score,
|
||||
topk_index=topk_index,
|
||||
block_static=block_static,
|
||||
bias=bias,
|
||||
n_group=n_group,
|
||||
topk_group=topk_group,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
@impl("_C::moe_sigmoid_group_topk_norm", "CUDA")
|
||||
def moe_sigmoid_group_topk_norm_cuda(
|
||||
x: torch.Tensor,
|
||||
topk_index: torch.Tensor,
|
||||
norm_score: torch.Tensor,
|
||||
block_static: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
scale: float,
|
||||
n_group: int,
|
||||
topk_group: int
|
||||
) -> None:
|
||||
xtorch_ops.moe_sigmoid_group_topk_norm(
|
||||
x=x,
|
||||
norm_score=norm_score,
|
||||
topk_index=topk_index,
|
||||
block_static=block_static,
|
||||
bias=bias,
|
||||
n_group=n_group,
|
||||
topk_group=topk_group,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
def _fake_moe_sigmoid_group_topk_norm(
|
||||
x: torch.Tensor,
|
||||
topk_index: torch.Tensor,
|
||||
norm_score: torch.Tensor,
|
||||
block_static: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
scale: float,
|
||||
n_group: int,
|
||||
topk_group: int
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm)
|
||||
##################################################
|
||||
# --------------- awq_dequantize -----------------
|
||||
##################################################
|
||||
|
||||
Reference in New Issue
Block a user