[Feature] Support XiaoMi MIMO Flash V2 (#62)
* [Feature] Support MIMO Flash V2
This commit is contained in:
@@ -82,6 +82,9 @@ def register_model():
|
|||||||
"LlamaForCausalLM",
|
"LlamaForCausalLM",
|
||||||
"vllm_kunlun.models.llama:LlamaForCausalLM")
|
"vllm_kunlun.models.llama:LlamaForCausalLM")
|
||||||
|
|
||||||
|
ModelRegistry.register_model(
|
||||||
|
"MiMoV2FlashForCausalLM",
|
||||||
|
"vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM")
|
||||||
|
|
||||||
def register_quant_method():
|
def register_quant_method():
|
||||||
"""to do"""
|
"""to do"""
|
||||||
|
|||||||
706
vllm_kunlun/models/mimo_v2_flash.py
Normal file
706
vllm_kunlun/models/mimo_v2_flash.py
Normal file
@@ -0,0 +1,706 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from itertools import islice
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
|
from vllm_kunlun.ops.attention.layer import Attention
|
||||||
|
from vllm.config import (
|
||||||
|
CacheConfig,
|
||||||
|
VllmConfig,
|
||||||
|
get_current_vllm_config,
|
||||||
|
)
|
||||||
|
from vllm.distributed import (
|
||||||
|
get_ep_group,
|
||||||
|
get_pp_group,
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_gather,
|
||||||
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (
|
||||||
|
MergedColumnParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from vllm_kunlun.ops.linear import QKVParallelLinear
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader,
|
||||||
|
maybe_remap_kv_scale_name,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsPP
|
||||||
|
from vllm.model_executor.models.utils import (
|
||||||
|
AutoWeightsLoader,
|
||||||
|
PPMissingLayer,
|
||||||
|
extract_layer_index,
|
||||||
|
is_pp_missing_parameter,
|
||||||
|
make_empty_intermediate_tensors_factory,
|
||||||
|
make_layers,
|
||||||
|
maybe_prefix,
|
||||||
|
)
|
||||||
|
from vllm_kunlun.ops.activation import SiluAndMul
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MiMoV2MLP(nn.Module):
    """Dense (non-MoE) SwiGLU feed-forward block for MiMo V2 Flash layers.

    Projects hidden states up with a fused gate/up GEMM, applies SiLU-and-mul,
    and projects back down with a row-parallel GEMM. Only the "silu"
    activation is supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Single column-parallel GEMM producing both the gate and up halves
        # of the SwiGLU input (hence the doubled output width).
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        # Row-parallel down projection; `reduce_results` controls whether the
        # TP all-reduce happens here or is deferred to the caller.
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-mul, then down projection."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.down_proj(activated)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class MiMoV2MoE(nn.Module):
    """Sparse mixture-of-experts block for MiMo V2 Flash.

    Routes tokens through a float32 sigmoid gate (with a learned score
    correction bias) into a :class:`FusedMoE` layer using grouped top-k
    selection. Supports expert parallelism and optional EPLB redundant
    experts.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        is_nextn: bool = False,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        # Expert-parallel topology.
        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.n_routed_experts

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        # Logical experts are the checkpoint experts; physical experts add
        # EPLB redundant replicas, split evenly across EP ranks.
        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        # The router gate runs in float32 for numerical stability.
        self.gate_dtype = torch.float32
        self.gate = nn.Linear(
            config.hidden_size,
            config.n_routed_experts,
            bias=False,
            dtype=self.gate_dtype,
        )
        self.gate.e_score_correction_bias = nn.Parameter(
            torch.empty(config.n_routed_experts, dtype=self.gate_dtype)
        )

        self.experts = FusedMoE(
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=True,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            scoring_func="sigmoid",
        )
        # Scratch buffer consumed by the Kunlun fused-MoE kernel.
        # NOTE(review): sized with config.num_local_experts while routing uses
        # config.n_routed_experts — confirm these agree for this config.
        self.register_buffer("kunlun_linear_weights", torch.zeros(
            config.num_local_experts, config.hidden_size, dtype=torch.float))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route `hidden_states` ((tokens, hidden) or (hidden,)) through the experts.

        Returns a tensor of the same shape as the input.
        """
        assert hidden_states.dim() <= 2, "MiMoV2MoE only supports 1D or 2D inputs"
        is_input_1d = hidden_states.dim() == 1
        # Bug fix: the previous code unpacked `num_tokens, hidden_dim =
        # hidden_states.shape`, which raises for the 1D inputs this method
        # explicitly claims to support. Read only the hidden dim and let
        # `view` normalize the token dimension.
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Cast to the gate's dtype (float32) before routing, if needed.
        if self.gate_dtype is not None:
            gate_input = hidden_states.to(self.gate_dtype)
        else:
            gate_input = hidden_states
        router_logits = self.gate(gate_input)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits, linear_weights=self.gate.weight
        )

        # Restore the original rank for 1D callers.
        return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class MiMoV2Attention(nn.Module):
    """Attention block for MiMo V2 Flash.

    Supports both full attention and sliding-window attention (SWA) layers,
    distinguished by `sliding_window_size`. The value head may be narrower
    than the query/key head (`v_head_dim < head_dim`); values are zero-padded
    to `head_dim` before the attention call (so the KV cache uses one uniform
    head size) and the output is sliced back to `v_head_dim` afterwards.
    SWA layers may carry a per-head attention-sink bias.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        v_head_dim: int | None = None,
        sliding_window_size: int = -1,  # -1 means full (non-windowed) attention
        attention_bias: bool = False,
        add_swa_attention_sink_bias: bool = False,
        layer_id: int = 0,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        partial_rotary_factor: float = 1.0,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_id = layer_id
        tp_size = get_tensor_model_parallel_world_size()

        # Heads are sharded across tensor-parallel ranks.
        self.total_num_heads = num_heads
        self.num_heads = self.total_num_heads // tp_size

        # KV heads: at least one per rank (GQA replication when tp > kv heads).
        self.total_num_kv_heads = num_kv_heads
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = head_dim

        # Value head width; defaults to the q/k head width when unspecified.
        self.v_head_dim = v_head_dim if v_head_dim is not None else head_dim

        # Per-rank split sizes of the fused QKV projection output.
        self.q_size = self.num_heads * self.head_dim
        self.k_size = self.num_kv_heads * self.head_dim
        self.v_size = self.num_kv_heads * self.v_head_dim

        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Fused QKV projection; `v_head_size` lets V use a narrower head.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            v_head_size=self.v_head_dim,
        )

        # Output projection maps from the (unpadded) value width back to hidden.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.v_head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=True,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            partial_rotary_factor=partial_rotary_factor
        )

        # Optional per-head sink bias for SWA layers; loaded from the
        # checkpoint (see MiMoV2Model.load_weights), hence requires_grad=False.
        self.attention_sink_bias = (
            torch.nn.Parameter(torch.empty(self.num_heads), requires_grad=False)
            if add_swa_attention_sink_bias
            else None
        )

        sliding_window = sliding_window_size if sliding_window_size > -1 else None
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.attention_sink_bias,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention over `hidden_states` at `positions`."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)

        # Zero-pad V from v_head_dim to head_dim so the attention kernel /
        # KV cache sees a single uniform head size.
        v = v.view(-1, self.num_kv_heads, self.v_head_dim)
        v = torch.nn.functional.pad(v, [0, self.head_dim - self.v_head_dim], value=0)
        v = v.view(-1, self.num_kv_heads * self.head_dim)

        attn_output = self.attn(q, k, v)

        # Drop the padded tail of each head before the output projection.
        attn_output = attn_output.view(-1, self.num_heads, self.head_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_heads * self.v_head_dim)

        output, _ = self.o_proj(attn_output)
        return output
|
||||||
|
|
||||||
|
|
||||||
|
class MiMoV2FlashDecoderLayer(nn.Module):
    """One MiMo V2 Flash transformer layer.

    Per-layer structure is driven by the HF config:
      * `hybrid_layer_pattern[layer_id] == 1` selects a sliding-window
        attention (SWA) variant of the attention block, otherwise full
        attention is used.
      * `moe_layer_freq[layer_id]` (when it is a sequence) selects a sparse
        MoE MLP for this layer, otherwise a dense SwiGLU MLP.
    Uses pre-norm with fused residual RMSNorm.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        # Layer index is parsed from the module prefix ("...layers.<i>...").
        layer_id = extract_layer_index(prefix)

        self.hidden_size = config.hidden_size
        self.config = config
        self.layer_id = layer_id

        rope_theta = getattr(config, "rope_theta", 1000000)
        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)

        if self.is_compressed_softmax_layer():
            # SWA layer: dedicated head counts/dims and optional sink bias,
            # with an (optionally) separate rope theta.
            self.self_attn = MiMoV2Attention(
                hidden_size=self.hidden_size,
                num_heads=config.swa_num_attention_heads,
                num_kv_heads=config.swa_num_key_value_heads,
                head_dim=config.swa_head_dim,
                v_head_dim=getattr(config, "swa_v_head_dim", None),
                sliding_window_size=config.sliding_window_size,
                attention_bias=config.attention_bias,
                add_swa_attention_sink_bias=getattr(
                    config, "add_swa_attention_sink_bias", False
                ),
                layer_id=layer_id,
                rope_theta=getattr(config, "swa_rope_theta", rope_theta),
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0),
                prefix=f"{prefix}.self_attn",
            )
        else:
            self.self_attn = MiMoV2Attention(
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                num_kv_heads=config.num_key_value_heads,
                head_dim=config.head_dim,
                v_head_dim=getattr(config, "v_head_dim", None),
                sliding_window_size=-1,  # normal attention
                attention_bias=config.attention_bias,
                layer_id=layer_id,
                rope_theta=rope_theta,
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0),
                prefix=f"{prefix}.self_attn",
            )

        self.is_layer_sparse = self.is_moe_layer(layer_id)
        if self.is_layer_sparse:
            self.mlp = MiMoV2MoE(
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = MiMoV2MLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.layernorm_epsilon
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pre-norm attention + MLP; returns (hidden_states, residual)."""
        # First layer of the PP stage has no residual yet; later layers use
        # the fused add+norm path of RMSNorm.
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def is_moe_layer(self, layer_idx: int) -> bool:
        """Whether layer `layer_idx` uses the sparse MoE MLP.

        True only when the config carries a per-layer `moe_layer_freq`
        sequence (an int value means no per-layer MoE pattern) and the entry
        for this layer is truthy.
        """
        return (
            hasattr(self.config, "moe_layer_freq")
            and layer_idx >= 0
            and not isinstance(self.config.moe_layer_freq, int)
            and self.config.moe_layer_freq[layer_idx]
        )

    def is_compressed_softmax_layer(self) -> bool:
        """Whether this layer uses the SWA attention variant.

        Marked by a 1 at this layer's position in `hybrid_layer_pattern`.
        """
        return self.config.hybrid_layer_pattern[self.layer_id] == 1
|
||||||
|
|
||||||
|
|
||||||
|
class MiMoV2Model(nn.Module):
    """MiMo V2 Flash transformer trunk (embeddings, decoder layers, final norm).

    Pipeline-parallel aware: the embedding lives on the first PP rank (and on
    the last when embeddings are tied), the final norm on the last rank, and
    intermediate ranks exchange (hidden_states, residual) tensors.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config.get_text_config()
        quant_config = vllm_config.quant_config
        eplb_config = vllm_config.parallel_config.eplb_config

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size
        self.num_redundant_experts = eplb_config.num_redundant_experts
        # Optional checkpoint-level scale applied to V-projection scales/biases
        # during weight loading (see load_weights).
        self.v_scale = getattr(config, "attention_value_scale", None)

        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # make_layers instantiates only this PP rank's slice of the stack.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MiMoV2FlashDecoderLayer(
                vllm_config=vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for `input_ids`."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this PP rank's layers.

        Returns final hidden states on the last PP rank, otherwise an
        IntermediateTensors bundle for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        # Final fused add+norm on the last rank.
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return FusedMoE's checkpoint-name → fused-param mapping entries."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this rank's parameters.

        Handles, in order: skipped rotary caches and MTP weights, KV-cache
        quantization scales, MoE expert weights, stacked (fused) projections,
        and finally plain parameters (with manual TP sharding for the SWA
        attention-sink bias). Returns the set of parameter names loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            # Rotary caches are recomputed at runtime; MTP (multi-token
            # prediction) weights are not used by this model.
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                continue
            if "mtp" in name:
                continue

            # KV-cache quantization scales: collapse multi-element scales to
            # a scalar and load via the param's own loader.
            if self.quant_config is not None:
                cache_scale_name = self.quant_config.get_cache_scale(name)
                if cache_scale_name is not None and cache_scale_name in params_dict:
                    param = params_dict[cache_scale_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )

                    kv_scale = loaded_weight
                    if kv_scale.dim() > 0 and kv_scale.numel() > 1:
                        kv_scale = kv_scale.view(-1)[0]

                    weight_loader(param, kv_scale)
                    loaded_params.add(cache_scale_name)
                    continue

            # MoE expert weights (per-expert gate/up/down shards).
            expert_matched = False
            for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                if weight_name not in name:
                    continue

                name_rewritten = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name_rewritten, self):
                    continue

                # Checkpoints may carry biases our modules don't have.
                if (
                    name_rewritten.endswith(".bias") or name_rewritten.endswith("_bias")
                ) and name_rewritten not in params_dict:
                    continue

                if name_rewritten not in params_dict:
                    continue

                param = params_dict[name_rewritten]
                weight_loader = param.weight_loader

                weight_loader(
                    param,
                    loaded_weight,
                    name_rewritten,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                loaded_params.add(name_rewritten)
                expert_matched = True
                break

            if expert_matched:
                continue

            # Stacked (fused) projections: q/k/v -> qkv, gate/up -> gate_up.
            stacked_matched = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name_rewritten = name.replace(weight_name, param_name)

                if (
                    name_rewritten.endswith(".bias")
                    and name_rewritten not in params_dict
                ):
                    continue

                if is_pp_missing_parameter(name_rewritten, self):
                    continue

                if name_rewritten not in params_dict:
                    continue

                param = params_dict[name_rewritten]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)

                # Fold the config's attention_value_scale into the V shard's
                # scale/bias tensors at load time.
                if param_name == "qkv_proj" and shard_id == "v":
                    v_scale = (
                        self.v_scale
                        if self.v_scale is not None
                        else getattr(self.config, "attention_value_scale", None)
                    )
                    if v_scale is not None and (
                        name.endswith("weight_scale_inv") or name.endswith(".bias")
                    ):
                        loaded_weight *= float(v_scale)

                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name_rewritten)

                stacked_matched = True
                break

            if stacked_matched:
                continue

            if name.endswith(".bias") and name not in params_dict:
                continue

            # Remap legacy KV-scale names where applicable; keep the original
            # name if no remap applies (unknown names are skipped below).
            orig_name = name
            mapped_name = maybe_remap_kv_scale_name(name, params_dict)
            name = mapped_name if mapped_name is not None else orig_name

            if name not in params_dict:
                continue

            param = params_dict[name]

            if "attention_sink_bias" in name:
                # Per-head tensor: shard manually along heads for TP since
                # this parameter has no weight_loader.
                total_heads = loaded_weight.shape[0]
                heads_per_rank = total_heads // tp_size
                head_start = tp_rank * heads_per_rank
                narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)

                param.data.copy_(narrow_weight)
                loaded_params.add(name)
            else:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
    """Causal-LM head wrapper around :class:`MiMoV2Model`.

    Adds the (pipeline-parallel-aware) LM head and logits processor, and
    forwards weight loading and expert mapping to the trunk.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.model = MiMoV2Model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        # LM head only exists on the last PP rank.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Record which layers should emit auxiliary hidden states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default EAGLE-3 aux-hidden-state layer indices (early/middle/late)."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the trunk."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the trunk; see MiMoV2Model.forward for PP semantics."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Delegate expert-parameter mapping to the trunk."""
        return self.model.get_expert_mapping()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all weights via AutoWeightsLoader (delegates to the trunk)."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
|
||||||
@@ -15,8 +15,8 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
#
|
#
|
||||||
|
|
||||||
# import vllm_kunlun.ops.linear
|
|
||||||
import vllm_kunlun.ops.rotary_embedding
|
import vllm_kunlun.ops.rotary_embedding
|
||||||
import vllm_kunlun.ops.layernorm
|
import vllm_kunlun.ops.layernorm
|
||||||
import vllm_kunlun.ops.quantization.awq
|
import vllm_kunlun.ops.quantization.awq
|
||||||
import vllm_kunlun.ops.quantization.gptq
|
import vllm_kunlun.ops.quantization.gptq
|
||||||
|
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||||
@@ -1,3 +1,20 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# This file is a part of the vllm-kunlun project.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""kunlun custom op entry"""
|
"""kunlun custom op entry"""
|
||||||
import torch_xmlir
|
import torch_xmlir
|
||||||
import torch
|
import torch
|
||||||
@@ -177,51 +194,10 @@ class KunlunOps:
|
|||||||
"""
|
"""
|
||||||
query_x = query.contiguous()
|
query_x = query.contiguous()
|
||||||
key_x = key.contiguous()
|
key_x = key.contiguous()
|
||||||
query_x_dim = query_x.dim()
|
|
||||||
if not is_neox_style:
|
|
||||||
if cos_sin_cache.dtype == torch.float16:
|
|
||||||
cos_sin_cache = cos_sin_cache.to(torch.float32)
|
|
||||||
positions = positions.to(torch.int)
|
|
||||||
if positions.dim() == 1:
|
|
||||||
positions = positions.unsqueeze(0)
|
|
||||||
query_x = query_x.unsqueeze(0)
|
|
||||||
key_x = key_x.unsqueeze(0)
|
|
||||||
|
|
||||||
xtorch_ops.rotary_embedding_gptj(
|
|
||||||
positions,
|
|
||||||
query_x,
|
|
||||||
key_x,
|
|
||||||
head_size,
|
|
||||||
cos_sin_cache)
|
|
||||||
query.data = query_x
|
|
||||||
key.data = key_x
|
|
||||||
if query_x_dim != query_x.dim():
|
|
||||||
query_x = query_x.unsqueeze(0)
|
|
||||||
key_x = key_x.unsqueeze(0)
|
|
||||||
return query, key
|
|
||||||
|
|
||||||
# TODO: need opt
|
|
||||||
if cos_sin_cache.dim() == 4:
|
|
||||||
max_seq_len = cos_sin_cache.shape[2]
|
|
||||||
head_dim = cos_sin_cache.shape[3]
|
|
||||||
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0) # 移除前两个维度 [1,1,L,D] -> [L,D]
|
|
||||||
cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)
|
|
||||||
|
|
||||||
# 重塑 query 和 key 的形状
|
|
||||||
num_tokens = query_x.shape[0]
|
num_tokens = query_x.shape[0]
|
||||||
num_heads = query_x.shape[1] // head_size
|
num_heads = query_x.shape[1] // head_size
|
||||||
num_kv_heads = key_x.shape[1] // head_size
|
num_kv_heads = key_x.shape[1] // head_size
|
||||||
|
|
||||||
# # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size]
|
|
||||||
# query_x = query_x.view(num_tokens, num_heads, head_size)
|
|
||||||
# # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size]
|
|
||||||
# key_x = key_x.view(num_tokens, num_kv_heads, head_size)
|
|
||||||
|
|
||||||
# # 确保形状正确
|
|
||||||
# assert query_x.shape == (num_tokens, num_heads, head_size), \
|
|
||||||
# f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}"
|
|
||||||
# assert key_x.shape == (num_tokens, num_kv_heads, head_size), \
|
|
||||||
# f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}"
|
|
||||||
|
|
||||||
torch.ops._C.rotary_embedding(
|
torch.ops._C.rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
@@ -234,8 +210,6 @@ class KunlunOps:
|
|||||||
query_x = query_x.view(num_tokens, num_heads * head_size)
|
query_x = query_x.view(num_tokens, num_heads * head_size)
|
||||||
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
|
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
|
||||||
|
|
||||||
# query.data = query_x
|
|
||||||
# key.data = key_x
|
|
||||||
return query_x, key_x
|
return query_x, key_x
|
||||||
|
|
||||||
# Rotary embedding
|
# Rotary embedding
|
||||||
@@ -433,6 +407,121 @@ class KunlunOps:
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def _dbg(x):
|
||||||
|
if torch.is_tensor(x):
|
||||||
|
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
|
||||||
|
return (type(x), x)
|
||||||
|
@staticmethod
|
||||||
|
def fused_moe(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
linear_weights: torch.Tensor,
|
||||||
|
moe_top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
inplace: bool = False,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
w1_bias: Optional[torch.Tensor] = None,
|
||||||
|
w2_bias: Optional[torch.Tensor] = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""fused_moe"""
|
||||||
|
global_num_experts = linear_weights.shape[0]
|
||||||
|
M, N = hidden_states.shape
|
||||||
|
hidden_dim = w2.shape[1]
|
||||||
|
normed_score = torch.empty(M,
|
||||||
|
moe_top_k,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=hidden_states.device)
|
||||||
|
topk_ids = torch.empty(M,
|
||||||
|
moe_top_k,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=hidden_states.device)
|
||||||
|
num_blocks = 12
|
||||||
|
block_statistic = torch.zeros(
|
||||||
|
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.ops._C.moe_sigmoid_group_topk_norm(
|
||||||
|
x=router_logits,
|
||||||
|
topk_index=topk_ids,
|
||||||
|
norm_score=normed_score,
|
||||||
|
block_static=block_statistic,
|
||||||
|
bias=e_score_correction_bias,
|
||||||
|
scale=1.0,
|
||||||
|
n_group=num_expert_group,
|
||||||
|
topk_group=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
|
||||||
|
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||||
|
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||||
|
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
|
||||||
|
|
||||||
|
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||||
|
|
||||||
|
torch.ops._C.moe_pre_sorted(
|
||||||
|
x=hidden_states,
|
||||||
|
topk_index=topk_ids,
|
||||||
|
block_statistic=block_statistic,
|
||||||
|
moe_expand=moe_expand,
|
||||||
|
moe_index=sorted_tokens_idx,
|
||||||
|
expert_m=expert_m,
|
||||||
|
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||||
|
|
||||||
|
y = torch.empty(M,moe_top_k,
|
||||||
|
w1.shape[1],
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device)
|
||||||
|
|
||||||
|
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
|
||||||
|
|
||||||
|
torch.ops._C.moe_fc(
|
||||||
|
x=moe_expand,
|
||||||
|
weight=w1,
|
||||||
|
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||||
|
sorted_tokens_idx=sorted_tokens_idx,
|
||||||
|
moe_topk=moe_top_k,
|
||||||
|
y=y)
|
||||||
|
|
||||||
|
d = y.shape[-1] // 2
|
||||||
|
output_shape = (y.shape[:-1] + (d, ))
|
||||||
|
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||||
|
torch.ops._C.swiglu(y, out1)
|
||||||
|
|
||||||
|
out = torch.empty(M,moe_top_k,
|
||||||
|
w2.shape[1],
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device)
|
||||||
|
|
||||||
|
out1 = out1.reshape(-1, out1.shape[-1])
|
||||||
|
|
||||||
|
torch.ops._C.moe_fc(
|
||||||
|
x=out1,
|
||||||
|
weight=w2,
|
||||||
|
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||||
|
sorted_tokens_idx=sorted_tokens_idx,
|
||||||
|
moe_topk=moe_top_k,
|
||||||
|
y=out)
|
||||||
|
|
||||||
|
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
|
||||||
|
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
|
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
|
||||||
|
|
||||||
|
torch.ops._C.moe_post(
|
||||||
|
x=out,
|
||||||
|
moe_index=sorted_tokens_idx,
|
||||||
|
normed_scale=normed_score,
|
||||||
|
dequant_scale=dequant_scale,
|
||||||
|
y=output
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fused_moe_ep(
|
def fused_moe_ep(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -487,42 +576,6 @@ class KunlunOps:
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def fused_moe(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
linear_weights: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool,
|
|
||||||
inplace: bool = False,
|
|
||||||
use_grouped_topk: bool = False,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
w1_bias: Optional[torch.Tensor] = None,
|
|
||||||
w2_bias: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""fused_moe"""
|
|
||||||
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
|
||||||
device=hidden_states.device)
|
|
||||||
expert_num = linear_weights.shape[0]
|
|
||||||
|
|
||||||
torch.ops._C.moe_ffn_block(
|
|
||||||
x=hidden_states,
|
|
||||||
gate_w=linear_weights,
|
|
||||||
inter_w=w1,
|
|
||||||
output_w=w2,
|
|
||||||
expert_num=expert_num,
|
|
||||||
moe_top_k=topk,
|
|
||||||
topk_group=topk_group,
|
|
||||||
renormalize=renormalize,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
expert_group_num=num_expert_group,
|
|
||||||
out=output,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fused_multi_head_latent_page_attention(
|
def fused_multi_head_latent_page_attention(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|||||||
@@ -68,7 +68,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
|||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
linear_weights=linear_weights)
|
linear_weights=linear_weights,
|
||||||
|
e_score_correction_bias=e_score_correction_bias)
|
||||||
|
|
||||||
def forward_kunlun(
|
def forward_kunlun(
|
||||||
self,
|
self,
|
||||||
@@ -81,7 +82,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""forward_kunlun"""
|
"""forward_kunlun"""
|
||||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||||
@@ -99,96 +102,6 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
|||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
topk_group=topk_group
|
topk_group=topk_group
|
||||||
)
|
)
|
||||||
# fused_moe do not support expert number > 400
|
|
||||||
elif layer.local_num_experts > 400:
|
|
||||||
hidden_states = x
|
|
||||||
global_num_experts = linear_weights.shape[0]
|
|
||||||
M, N = hidden_states.shape
|
|
||||||
hidden_dim = layer.w2_weight.shape[1]
|
|
||||||
normed_score = torch.empty(M,
|
|
||||||
top_k,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=hidden_states.device)
|
|
||||||
topk_ids = torch.empty(M,
|
|
||||||
top_k,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=hidden_states.device)
|
|
||||||
num_blocks = 12
|
|
||||||
block_statistic = torch.zeros(
|
|
||||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
|
||||||
)
|
|
||||||
|
|
||||||
router_logits = router_logits.float()
|
|
||||||
torch.ops._C.moe_softmax_topk_norm(
|
|
||||||
x=router_logits,
|
|
||||||
normed_score=normed_score,
|
|
||||||
topk_index=topk_ids,
|
|
||||||
block_statistic=None,
|
|
||||||
stable=True)
|
|
||||||
|
|
||||||
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
|
|
||||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
|
||||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
|
||||||
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
|
|
||||||
|
|
||||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
|
||||||
|
|
||||||
torch.ops._C.moe_pre_sorted(
|
|
||||||
x=hidden_states,
|
|
||||||
topk_index=topk_ids,
|
|
||||||
block_statistic=block_statistic,
|
|
||||||
moe_expand=moe_expand,
|
|
||||||
moe_index=sorted_tokens_idx,
|
|
||||||
expert_m=expert_m,
|
|
||||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
|
||||||
|
|
||||||
y = torch.empty(M,top_k,
|
|
||||||
layer.w13_weight.shape[1],
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
device=hidden_states.device)
|
|
||||||
|
|
||||||
moe_expand = moe_expand.view(M * top_k, hidden_dim)
|
|
||||||
|
|
||||||
torch.ops._C.moe_fc(
|
|
||||||
x=moe_expand,
|
|
||||||
weight=layer.w13_weight,
|
|
||||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
|
||||||
sorted_tokens_idx=sorted_tokens_idx,
|
|
||||||
moe_topk=top_k,
|
|
||||||
y=y)
|
|
||||||
|
|
||||||
d = y.shape[-1] // 2
|
|
||||||
output_shape = (y.shape[:-1] + (d, ))
|
|
||||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
|
||||||
torch.ops._C.swiglu(y, out1)
|
|
||||||
|
|
||||||
out = torch.empty(M,top_k,
|
|
||||||
layer.w2_weight.shape[1],
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
device=hidden_states.device)
|
|
||||||
|
|
||||||
out1 = out1.reshape(-1, out1.shape[-1])
|
|
||||||
|
|
||||||
torch.ops._C.moe_fc(
|
|
||||||
x=out1,
|
|
||||||
weight=layer.w2_weight,
|
|
||||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
|
||||||
sorted_tokens_idx=sorted_tokens_idx,
|
|
||||||
moe_topk=top_k,
|
|
||||||
y=out)
|
|
||||||
|
|
||||||
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
|
|
||||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
|
||||||
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
|
|
||||||
|
|
||||||
torch.ops._C.moe_post(
|
|
||||||
x=out,
|
|
||||||
moe_index=sorted_tokens_idx,
|
|
||||||
normed_scale=normed_score,
|
|
||||||
dequant_scale=dequant_scale,
|
|
||||||
y=output
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
else:
|
else:
|
||||||
return ops.fused_moe(x,
|
return ops.fused_moe(x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@@ -200,7 +113,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
|||||||
inplace=True,
|
inplace=True,
|
||||||
use_grouped_topk=use_grouped_topk,
|
use_grouped_topk=use_grouped_topk,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
topk_group=topk_group
|
topk_group=topk_group,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
class FusedMoE(VllmFusedMoE):
|
class FusedMoE(VllmFusedMoE):
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ def vllm_kunlun_forward_cuda(
|
|||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
||||||
|
RMSNorm.forward = vllm_kunlun_forward_cuda
|
||||||
|
|
||||||
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -3,14 +3,30 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
WEIGHT_LOADER_V2_SUPPORTED,
|
WEIGHT_LOADER_V2_SUPPORTED,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
UnquantizedLinearMethod,
|
UnquantizedLinearMethod,
|
||||||
|
ColumnParallelLinear
|
||||||
)
|
)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.model_executor.parameter import ModelWeightParameter
|
from vllm.model_executor.parameter import (
|
||||||
|
BasevLLMParameter,
|
||||||
|
BlockQuantScaleParameter,
|
||||||
|
PackedColumnParameter,
|
||||||
|
PackedvLLMParameter,
|
||||||
|
PerTensorScaleParameter,
|
||||||
|
RowvLLMParameter,
|
||||||
|
)
|
||||||
|
from vllm.distributed import (
|
||||||
|
divide,
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
split_tensor_along_last_dim,
|
||||||
|
tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -59,4 +75,361 @@ def create_weights(
|
|||||||
|
|
||||||
# rewrite create_weights and remove weight_loader_v2 to suport cuda graph
|
# rewrite create_weights and remove weight_loader_v2 to suport cuda graph
|
||||||
UnquantizedLinearMethod.create_weights = create_weights
|
UnquantizedLinearMethod.create_weights = create_weights
|
||||||
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
|
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
|
||||||
|
|
||||||
|
class QKVParallelLinear(ColumnParallelLinear):
|
||||||
|
"""
|
||||||
|
Base on v0.11.0 QKVParallelLinear, And add v_head size for swa (MIMO V2)
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
head_size: int,
|
||||||
|
total_num_heads: int,
|
||||||
|
total_num_kv_heads: int | None = None,
|
||||||
|
bias: bool = True,
|
||||||
|
skip_bias_add: bool = False,
|
||||||
|
params_dtype: torch.dtype | None = None,
|
||||||
|
quant_config: QuantizationConfig | None = None,
|
||||||
|
prefix: str = "",
|
||||||
|
*,
|
||||||
|
return_bias: bool = True,
|
||||||
|
disable_tp: bool = False,
|
||||||
|
v_head_size: int | None = None,
|
||||||
|
):
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.head_size = head_size
|
||||||
|
self.v_head_size = v_head_size if v_head_size is not None else head_size
|
||||||
|
self.total_num_heads = total_num_heads
|
||||||
|
if total_num_kv_heads is None:
|
||||||
|
total_num_kv_heads = total_num_heads
|
||||||
|
self.total_num_kv_heads = total_num_kv_heads
|
||||||
|
# Divide the weight matrix along the last dimension.
|
||||||
|
tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
|
||||||
|
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||||
|
if tp_size >= self.total_num_kv_heads:
|
||||||
|
self.num_kv_heads = 1
|
||||||
|
self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
|
||||||
|
else:
|
||||||
|
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||||
|
self.num_kv_head_replicas = 1
|
||||||
|
input_size = self.hidden_size
|
||||||
|
output_size = (
|
||||||
|
self.num_heads * self.head_size
|
||||||
|
+ self.num_kv_heads * self.head_size
|
||||||
|
+ self.num_kv_heads * self.v_head_size
|
||||||
|
) * tp_size
|
||||||
|
self.output_sizes = [
|
||||||
|
self.num_heads * self.head_size * tp_size, # q_proj
|
||||||
|
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
||||||
|
self.num_kv_heads * self.v_head_size * tp_size, # v_proj
|
||||||
|
]
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
input_size=input_size,
|
||||||
|
output_size=output_size,
|
||||||
|
bias=bias,
|
||||||
|
gather_output=False,
|
||||||
|
skip_bias_add=skip_bias_add,
|
||||||
|
params_dtype=params_dtype,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
|
return_bias=return_bias,
|
||||||
|
disable_tp=disable_tp,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
||||||
|
shard_offset_mapping = {
|
||||||
|
"q": 0,
|
||||||
|
"k": self.num_heads * self.head_size,
|
||||||
|
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
|
||||||
|
"total": (self.num_heads + self.num_kv_heads) * self.head_size
|
||||||
|
+ self.num_kv_heads * self.v_head_size,
|
||||||
|
}
|
||||||
|
return shard_offset_mapping.get(loaded_shard_id)
|
||||||
|
|
||||||
|
def _get_shard_size_mapping(self, loaded_shard_id: str):
|
||||||
|
shard_size_mapping = {
|
||||||
|
"q": self.num_heads * self.head_size,
|
||||||
|
"k": self.num_kv_heads * self.head_size,
|
||||||
|
"v": self.num_kv_heads * self.v_head_size,
|
||||||
|
}
|
||||||
|
return shard_size_mapping.get(loaded_shard_id)
|
||||||
|
|
||||||
|
def _load_fused_module_from_checkpoint(
|
||||||
|
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Handle special case for models where QKV layers are already
|
||||||
|
fused on disk. In this case, we have no shard id. This function
|
||||||
|
determines the shard id by splitting these layers and then calls
|
||||||
|
the weight loader using the shard id.
|
||||||
|
|
||||||
|
An example of a model with these fused layers:
|
||||||
|
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
|
||||||
|
"""
|
||||||
|
shard_offsets = [
|
||||||
|
# (shard_id, shard_offset, shard_size)
|
||||||
|
("q", 0, self.total_num_heads * self.head_size),
|
||||||
|
(
|
||||||
|
"k",
|
||||||
|
self.total_num_heads * self.head_size,
|
||||||
|
self.total_num_kv_heads * self.head_size,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"v",
|
||||||
|
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
|
||||||
|
self.total_num_kv_heads * self.v_head_size,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||||
|
# Special case for Quantization.
|
||||||
|
# If quantized, we need to adjust the offset and size to account
|
||||||
|
# for the packing.
|
||||||
|
if (
|
||||||
|
isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
|
||||||
|
and param.packed_dim == param.output_dim
|
||||||
|
):
|
||||||
|
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
|
||||||
|
shard_size=shard_size, shard_offset=shard_offset
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_weight_shard = loaded_weight.narrow(
|
||||||
|
param.output_dim, shard_offset, shard_size
|
||||||
|
)
|
||||||
|
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
|
||||||
|
|
||||||
|
def weight_loader_v2(
|
||||||
|
self,
|
||||||
|
param: BasevLLMParameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
loaded_shard_id: str | None = None,
|
||||||
|
):
|
||||||
|
if loaded_shard_id is None: # special case for certain models
|
||||||
|
if isinstance(param, PerTensorScaleParameter):
|
||||||
|
param.load_qkv_weight(
|
||||||
|
loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
|
||||||
|
)
|
||||||
|
return
|
||||||
|
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
|
||||||
|
param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
|
||||||
|
return
|
||||||
|
# TODO: @dsikka - move to parameter.py
|
||||||
|
self._load_fused_module_from_checkpoint(param, loaded_weight)
|
||||||
|
return
|
||||||
|
|
||||||
|
assert loaded_shard_id in ["q", "k", "v"]
|
||||||
|
|
||||||
|
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
|
||||||
|
shard_size = self._get_shard_size_mapping(loaded_shard_id)
|
||||||
|
|
||||||
|
# Note(simon): This is needed for Qwen3's fp8 quantization.
|
||||||
|
if isinstance(param, BlockQuantScaleParameter):
|
||||||
|
assert self.quant_method is not None
|
||||||
|
# Assume the weight block size has been set by quant method
|
||||||
|
assert hasattr(self, "weight_block_size")
|
||||||
|
weight_block_size = self.weight_block_size
|
||||||
|
assert weight_block_size is not None
|
||||||
|
block_n, _ = weight_block_size[0], weight_block_size[1]
|
||||||
|
shard_offset = (shard_offset + block_n - 1) // block_n
|
||||||
|
shard_size = (shard_size + block_n - 1) // block_n
|
||||||
|
|
||||||
|
param.load_qkv_weight(
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
num_heads=self.num_kv_head_replicas,
|
||||||
|
shard_id=loaded_shard_id,
|
||||||
|
shard_offset=shard_offset,
|
||||||
|
shard_size=shard_size,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
def weight_loader(
|
||||||
|
self,
|
||||||
|
param: Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
loaded_shard_id: str | None = None,
|
||||||
|
):
|
||||||
|
# Special case for GGUF
|
||||||
|
# initialize GGUF param after we know the quantize type
|
||||||
|
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||||
|
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||||
|
if is_gguf_weight_type:
|
||||||
|
idx_map = {"q": 0, "k": 1, "v": 2}
|
||||||
|
if loaded_shard_id is not None:
|
||||||
|
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
|
||||||
|
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
|
||||||
|
else:
|
||||||
|
param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
|
||||||
|
return
|
||||||
|
|
||||||
|
if is_gguf_weight:
|
||||||
|
output_dim = getattr(param, "output_dim", None)
|
||||||
|
shard_size = loaded_weight.size(output_dim) // self.tp_size
|
||||||
|
start_idx = self.tp_rank * shard_size
|
||||||
|
|
||||||
|
if loaded_shard_id is not None:
|
||||||
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||||
|
param.shard_id.append(loaded_shard_id)
|
||||||
|
param.shard_id_map[loaded_shard_id] = len(param.data_container)
|
||||||
|
param.data_container.append(loaded_weight)
|
||||||
|
return
|
||||||
|
|
||||||
|
param_data = param.data
|
||||||
|
output_dim = getattr(param, "output_dim", None)
|
||||||
|
|
||||||
|
# Special case for per-tensor scales in fused case.
|
||||||
|
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||||
|
|
||||||
|
if loaded_shard_id is None:
|
||||||
|
# Loaded weight is already fused on disk (qkv).
|
||||||
|
# (e.g., Phi-3's qkv_proj).
|
||||||
|
if output_dim is None:
|
||||||
|
if needs_scalar_to_array:
|
||||||
|
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||||
|
param_data, loaded_weight, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
assert param_data.shape == loaded_weight.shape
|
||||||
|
param_data.copy_(loaded_weight)
|
||||||
|
return
|
||||||
|
shard_offsets = [
|
||||||
|
# (shard_id, shard_offset, shard_size)
|
||||||
|
("q", 0, self.total_num_heads * self.head_size),
|
||||||
|
(
|
||||||
|
"k",
|
||||||
|
self.total_num_heads * self.head_size,
|
||||||
|
self.total_num_kv_heads * self.head_size,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"v",
|
||||||
|
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
|
||||||
|
self.total_num_kv_heads * self.v_head_size,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||||
|
|
||||||
|
packed_dim = getattr(param, "packed_dim", None)
|
||||||
|
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||||
|
# Special case for Quantized Weights.
|
||||||
|
# If quantized, we need to adjust the offset and size to account
|
||||||
|
# for the packing.
|
||||||
|
if packed_dim == output_dim:
|
||||||
|
shard_size = shard_size // param.packed_factor
|
||||||
|
shard_offset = shard_offset // param.packed_factor
|
||||||
|
|
||||||
|
# Special case for Marlin.
|
||||||
|
shard_size, shard_offset = adjust_marlin_shard(
|
||||||
|
param, shard_size, shard_offset
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_bitsandbytes_4bit:
|
||||||
|
orig_qkv_offsets = {
|
||||||
|
"q": (0, self.total_num_heads * self.head_size),
|
||||||
|
"k": (
|
||||||
|
self.total_num_heads * self.head_size,
|
||||||
|
self.total_num_kv_heads * self.head_size,
|
||||||
|
),
|
||||||
|
"v": (
|
||||||
|
(self.total_num_heads + self.total_num_kv_heads)
|
||||||
|
* self.head_size,
|
||||||
|
self.total_num_kv_heads * self.v_head_size,
|
||||||
|
),
|
||||||
|
"total": (
|
||||||
|
(self.total_num_heads + self.total_num_kv_heads)
|
||||||
|
* self.head_size
|
||||||
|
+ self.total_num_kv_heads * self.v_head_size,
|
||||||
|
0,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||||
|
param, orig_qkv_offsets, shard_id
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_weight_shard = loaded_weight.narrow(
|
||||||
|
output_dim, shard_offset, shard_size
|
||||||
|
)
|
||||||
|
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
assert loaded_shard_id in ["q", "k", "v"]
|
||||||
|
|
||||||
|
# If output dim is defined, use the default loading process.
|
||||||
|
if output_dim is not None:
|
||||||
|
if loaded_shard_id == "q":
|
||||||
|
shard_offset = 0
|
||||||
|
shard_size = self.num_heads * self.head_size
|
||||||
|
elif loaded_shard_id == "k":
|
||||||
|
shard_offset = self.num_heads * self.head_size
|
||||||
|
shard_size = self.num_kv_heads * self.head_size
|
||||||
|
elif loaded_shard_id == "v":
|
||||||
|
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
|
||||||
|
shard_size = self.num_kv_heads * self.v_head_size
|
||||||
|
# Special case for Quantized Weights.
|
||||||
|
# If quantized, we need to adjust the offset and size to account
|
||||||
|
# for the packing.
|
||||||
|
packed_dim = getattr(param, "packed_dim", None)
|
||||||
|
if packed_dim == output_dim:
|
||||||
|
shard_size = shard_size // param.packed_factor
|
||||||
|
shard_offset = shard_offset // param.packed_factor
|
||||||
|
|
||||||
|
# Special case for Marlin.
|
||||||
|
shard_size, shard_offset = adjust_marlin_shard(
|
||||||
|
param, shard_size, shard_offset
|
||||||
|
)
|
||||||
|
|
||||||
|
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||||
|
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
||||||
|
# bitsandbytes loads the weights of the specific portion
|
||||||
|
# no need to narrow
|
||||||
|
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||||
|
|
||||||
|
if use_bitsandbytes_4bit:
|
||||||
|
orig_qkv_offsets = {
|
||||||
|
"q": (0, self.num_heads * self.head_size),
|
||||||
|
"k": (
|
||||||
|
self.num_heads * self.head_size,
|
||||||
|
self.num_kv_heads * self.head_size,
|
||||||
|
),
|
||||||
|
"v": (
|
||||||
|
(self.num_heads + self.num_kv_heads) * self.head_size,
|
||||||
|
self.num_kv_heads * self.v_head_size,
|
||||||
|
),
|
||||||
|
"total": (
|
||||||
|
(self.num_heads + self.num_kv_heads) * self.head_size
|
||||||
|
+ self.num_kv_heads * self.v_head_size,
|
||||||
|
0,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||||
|
param, orig_qkv_offsets, loaded_shard_id
|
||||||
|
)
|
||||||
|
|
||||||
|
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||||
|
if loaded_shard_id == "q":
|
||||||
|
shard_rank = self.tp_rank
|
||||||
|
else:
|
||||||
|
shard_rank = self.tp_rank // self.num_kv_head_replicas
|
||||||
|
start_idx = shard_rank * shard_size
|
||||||
|
|
||||||
|
if not is_sharded_weight:
|
||||||
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||||
|
|
||||||
|
# Special case for per-tensor scales in fused case.
|
||||||
|
elif needs_scalar_to_array:
|
||||||
|
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||||
|
param_data, loaded_weight, loaded_shard_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ignore_warning = getattr(param, "ignore_warning", False)
|
||||||
|
if not ignore_warning:
|
||||||
|
logger.warning(
|
||||||
|
"Loading a weight without `output_dim` attribute in "
|
||||||
|
"QKVParallelLinear, assume the weight is the same "
|
||||||
|
"for all partitions."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert param_data.shape == loaded_weight.shape
|
||||||
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|||||||
@@ -8,14 +8,8 @@ from typing import List, Optional, Tuple
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
if current_platform.is_kunlun():
|
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
|
||||||
else:
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.triton_utils.importing import HAS_TRITON
|
|
||||||
|
|
||||||
if HAS_TRITON:
|
|
||||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
|
||||||
|
|
||||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||||
_PARTITION_SIZE = 512
|
_PARTITION_SIZE = 512
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ def vllm_kunlun_forward_cuda(
|
|||||||
self.is_neox_style, self.rotary_dim,
|
self.is_neox_style, self.rotary_dim,
|
||||||
offsets)
|
offsets)
|
||||||
else:
|
else:
|
||||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
query, key = ops.rotary_embedding(positions, query, key, self.head_size,
|
||||||
self.cos_sin_cache, self.is_neox_style)
|
self.cos_sin_cache, self.is_neox_style)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
@@ -143,14 +143,11 @@ def vllm_kunlun_mrope_forward_cuda(
|
|||||||
|
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
# RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||||
# RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||||
# RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
|
||||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||||
# MRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
|
||||||
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
|
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
|
||||||
# YaRNScalingRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
|
||||||
|
|
||||||
|
|
||||||
def Split_Norm_Rope(
|
def Split_Norm_Rope(
|
||||||
|
|||||||
@@ -1,143 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
|
||||||
|
|
||||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
|
||||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
|
||||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
|
||||||
from vllm.model_executor.parameter import BasevLLMParameter
|
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
DEFAULT_VOCAB_PADDING_SIZE = 64
|
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
|
||||||
"""Unquantized method for embeddings."""
|
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module,
|
|
||||||
input_size_per_partition: int,
|
|
||||||
output_partition_sizes: list[int], input_size: int,
|
|
||||||
output_size: int, params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs):
|
|
||||||
"""Create weights for embedding layer."""
|
|
||||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
|
||||||
input_size_per_partition,
|
|
||||||
dtype=params_dtype),
|
|
||||||
requires_grad=False)
|
|
||||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
|
||||||
layer.register_parameter("weight", weight)
|
|
||||||
set_weight_attrs(weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
def apply(self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
|
||||||
|
|
||||||
def embedding(self, layer: torch.nn.Module,
|
|
||||||
input_: torch.Tensor) -> torch.Tensor:
|
|
||||||
return F.embedding(input_, layer.weight)
|
|
||||||
|
|
||||||
|
|
||||||
def pad_vocab_size(vocab_size: int,
|
|
||||||
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
|
|
||||||
"""Pad the vocab size to the given value."""
|
|
||||||
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
|
||||||
|
|
||||||
|
|
||||||
def vocab_range_from_per_partition_vocab_size(
|
|
||||||
per_partition_vocab_size: int,
|
|
||||||
rank: int,
|
|
||||||
offset: int = 0) -> Sequence[int]:
|
|
||||||
index_f = rank * per_partition_vocab_size
|
|
||||||
index_l = index_f + per_partition_vocab_size
|
|
||||||
return index_f + offset, index_l + offset
|
|
||||||
|
|
||||||
|
|
||||||
def vocab_range_from_global_vocab_size(global_vocab_size: int,
|
|
||||||
rank: int,
|
|
||||||
world_size: int,
|
|
||||||
offset: int = 0) -> Sequence[int]:
|
|
||||||
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
|
||||||
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
|
|
||||||
rank,
|
|
||||||
offset=offset)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class VocabParallelEmbeddingShardIndices:
|
|
||||||
"""Indices for a shard of a vocab parallel embedding."""
|
|
||||||
padded_org_vocab_start_index: int
|
|
||||||
padded_org_vocab_end_index: int
|
|
||||||
padded_added_vocab_start_index: int
|
|
||||||
padded_added_vocab_end_index: int
|
|
||||||
|
|
||||||
org_vocab_start_index: int
|
|
||||||
org_vocab_end_index: int
|
|
||||||
added_vocab_start_index: int
|
|
||||||
added_vocab_end_index: int
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_org_elements(self) -> int:
|
|
||||||
return self.org_vocab_end_index - self.org_vocab_start_index
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_added_elements(self) -> int:
|
|
||||||
return self.added_vocab_end_index - self.added_vocab_start_index
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_org_elements_padded(self) -> int:
|
|
||||||
return (self.padded_org_vocab_end_index -
|
|
||||||
self.padded_org_vocab_start_index)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_added_elements_padded(self) -> int:
|
|
||||||
return (self.padded_added_vocab_end_index -
|
|
||||||
self.padded_added_vocab_start_index)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_org_vocab_padding(self) -> int:
|
|
||||||
return self.num_org_elements_padded - self.num_org_elements
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_added_vocab_padding(self) -> int:
|
|
||||||
return self.num_added_elements_padded - self.num_added_elements
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_elements_padded(self) -> int:
|
|
||||||
return self.num_org_elements_padded + self.num_added_elements_padded
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# sanity checks
|
|
||||||
assert (self.padded_org_vocab_start_index
|
|
||||||
<= self.padded_org_vocab_end_index)
|
|
||||||
assert (self.padded_added_vocab_start_index
|
|
||||||
<= self.padded_added_vocab_end_index)
|
|
||||||
|
|
||||||
assert self.org_vocab_start_index <= self.org_vocab_end_index
|
|
||||||
assert self.added_vocab_start_index <= self.added_vocab_end_index
|
|
||||||
|
|
||||||
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
|
|
||||||
assert (self.added_vocab_start_index
|
|
||||||
<= self.padded_added_vocab_start_index)
|
|
||||||
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
|
|
||||||
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
|
|
||||||
|
|
||||||
assert self.num_org_elements <= self.num_org_elements_padded
|
|
||||||
assert self.num_added_elements <= self.num_added_elements_padded
|
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True, backend="aot_eager")
|
@torch.compile(dynamic=True, backend="aot_eager")
|
||||||
def get_masked_input_and_mask(
|
def get_masked_input_and_mask(
|
||||||
@@ -159,319 +27,25 @@ def get_masked_input_and_mask(
|
|||||||
input_ = vocab_mask * (input_ - valid_offset)
|
input_ = vocab_mask * (input_ - valid_offset)
|
||||||
return input_, ~vocab_mask
|
return input_, ~vocab_mask
|
||||||
|
|
||||||
|
def forward_native_kunlun(self, input_):
|
||||||
|
if self.tp_size > 1:
|
||||||
|
# Build the mask.
|
||||||
|
masked_input, input_mask = get_masked_input_and_mask(
|
||||||
|
input_, self.shard_indices.org_vocab_start_index,
|
||||||
|
self.shard_indices.org_vocab_end_index,
|
||||||
|
self.shard_indices.num_org_vocab_padding,
|
||||||
|
self.shard_indices.added_vocab_start_index,
|
||||||
|
self.shard_indices.added_vocab_end_index)
|
||||||
|
else:
|
||||||
|
masked_input = input_
|
||||||
|
# Get the embeddings.
|
||||||
|
output_parallel = self.quant_method.embedding(self,
|
||||||
|
masked_input.long())
|
||||||
|
# Mask the output embedding.
|
||||||
|
if self.tp_size > 1:
|
||||||
|
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||||
|
# Reduce across all the model parallel GPUs.
|
||||||
|
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
|
return output
|
||||||
|
|
||||||
@CustomOp.register("vllm_kunlun_vocab_parallel_embedding")
|
VocabParallelEmbedding.forward_native = forward_native_kunlun
|
||||||
class VocabParallelEmbedding(CustomOp):
|
|
||||||
"""Embedding parallelized in the vocabulary dimension.
|
|
||||||
|
|
||||||
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
|
|
||||||
make sure it is divisible by the number of model parallel GPUs.
|
|
||||||
|
|
||||||
In order to support various loading methods, we ensure that LoRA-added
|
|
||||||
embeddings are always at the end of TP-sharded tensors. In other words,
|
|
||||||
we shard base embeddings and LoRA embeddings separately (both padded),
|
|
||||||
and place them in the same tensor.
|
|
||||||
In this example, we will have the original vocab size = 1010,
|
|
||||||
added vocab size = 16 and padding to 64. Therefore, the total
|
|
||||||
vocab size with padding will be 1088 (because we first pad 1010 to
|
|
||||||
1024, add 16, and then pad to 1088).
|
|
||||||
Therefore, the tensor format looks like the following:
|
|
||||||
TP1, rank 0 (no sharding):
|
|
||||||
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
|
|
||||||
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 |
|
|
||||||
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
|
|
||||||
|
|
||||||
TP2, rank 0:
|
|
||||||
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
|
|
||||||
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 |
|
|
||||||
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
|
|
||||||
TP2, rank 1:
|
|
||||||
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
|
|
||||||
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
|
|
||||||
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_embeddings: vocabulary size.
|
|
||||||
embedding_dim: size of hidden state.
|
|
||||||
params_dtype: type of the parameters.
|
|
||||||
org_num_embeddings: original vocabulary size (without LoRA).
|
|
||||||
padding_size: padding size for the vocabulary.
|
|
||||||
quant_config: quant config for the layer
|
|
||||||
prefix: full name of the layer in the state dict
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
num_embeddings: int,
|
|
||||||
embedding_dim: int,
|
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
|
||||||
org_num_embeddings: Optional[int] = None,
|
|
||||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = ""):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# Keep the input dimensions.
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.num_embeddings = num_embeddings
|
|
||||||
self.padding_size = padding_size
|
|
||||||
self.org_vocab_size = org_num_embeddings or num_embeddings
|
|
||||||
num_added_embeddings = num_embeddings - self.org_vocab_size
|
|
||||||
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
|
|
||||||
self.padding_size)
|
|
||||||
self.num_embeddings_padded = pad_vocab_size(
|
|
||||||
self.org_vocab_size_padded + num_added_embeddings,
|
|
||||||
self.padding_size)
|
|
||||||
assert self.org_vocab_size_padded <= self.num_embeddings_padded
|
|
||||||
|
|
||||||
self.shard_indices = self._get_indices(self.num_embeddings_padded,
|
|
||||||
self.org_vocab_size_padded,
|
|
||||||
self.num_embeddings,
|
|
||||||
self.org_vocab_size, tp_rank,
|
|
||||||
self.tp_size)
|
|
||||||
self.embedding_dim = embedding_dim
|
|
||||||
|
|
||||||
quant_method = None
|
|
||||||
if quant_config is not None:
|
|
||||||
quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
|
||||||
if quant_method is None:
|
|
||||||
quant_method = UnquantizedEmbeddingMethod()
|
|
||||||
|
|
||||||
# If we are making an embedding layer, then our quantization linear
|
|
||||||
# method must implement the embedding operation. If we are another
|
|
||||||
# layer type like ParallelLMHead, this is not important.
|
|
||||||
is_embedding_layer = type(self) is VocabParallelEmbedding
|
|
||||||
quant_method_implements_embedding = method_has_implemented_embedding(
|
|
||||||
type(quant_method))
|
|
||||||
if is_embedding_layer and not quant_method_implements_embedding:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"The class {type(quant_method).__name__} must implement "
|
|
||||||
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
|
|
||||||
|
|
||||||
self.quant_method: QuantizeMethodBase = quant_method
|
|
||||||
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
# Divide the weight matrix along the vocaburaly dimension.
|
|
||||||
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
|
|
||||||
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
|
|
||||||
self.tp_size)
|
|
||||||
assert (self.shard_indices.num_elements_padded ==
|
|
||||||
self.num_embeddings_per_partition)
|
|
||||||
self.num_org_embeddings_per_partition = (
|
|
||||||
self.shard_indices.org_vocab_end_index -
|
|
||||||
self.shard_indices.org_vocab_start_index)
|
|
||||||
self.num_added_embeddings_per_partition = (
|
|
||||||
self.shard_indices.added_vocab_end_index -
|
|
||||||
self.shard_indices.added_vocab_start_index)
|
|
||||||
|
|
||||||
self.quant_method.create_weights(self,
|
|
||||||
self.embedding_dim,
|
|
||||||
[self.num_embeddings_per_partition],
|
|
||||||
self.embedding_dim,
|
|
||||||
self.num_embeddings_padded,
|
|
||||||
params_dtype=params_dtype,
|
|
||||||
weight_loader=self.weight_loader)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
|
|
||||||
vocab_size: int, org_vocab_size: int, tp_rank: int,
|
|
||||||
tp_size: int) -> VocabParallelEmbeddingShardIndices:
|
|
||||||
"""Get start and end indices for vocab parallel embedding, following the
|
|
||||||
layout outlined in the class docstring, based on the given tp_rank and
|
|
||||||
tp_size."""
|
|
||||||
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
|
|
||||||
padded_org_vocab_start_index, padded_org_vocab_end_index = (
|
|
||||||
vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
|
|
||||||
tp_size))
|
|
||||||
padded_added_vocab_start_index, padded_added_vocab_end_index = (
|
|
||||||
vocab_range_from_global_vocab_size(num_added_embeddings_padded,
|
|
||||||
tp_rank,
|
|
||||||
tp_size,
|
|
||||||
offset=org_vocab_size))
|
|
||||||
# remove padding
|
|
||||||
org_vocab_start_index = min(padded_org_vocab_start_index,
|
|
||||||
org_vocab_size)
|
|
||||||
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
|
|
||||||
added_vocab_start_index = min(padded_added_vocab_start_index,
|
|
||||||
vocab_size)
|
|
||||||
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
|
|
||||||
return VocabParallelEmbeddingShardIndices(
|
|
||||||
padded_org_vocab_start_index, padded_org_vocab_end_index,
|
|
||||||
padded_added_vocab_start_index, padded_added_vocab_end_index,
|
|
||||||
org_vocab_start_index, org_vocab_end_index,
|
|
||||||
added_vocab_start_index, added_vocab_end_index)
|
|
||||||
|
|
||||||
def get_sharded_to_full_mapping(self) -> Optional[list[int]]:
|
|
||||||
"""Get a mapping that can be used to reindex the gathered
|
|
||||||
logits for sampling.
|
|
||||||
|
|
||||||
During sampling, we gather logits from all ranks. The relationship
|
|
||||||
of index->token_id will follow the same format as outlined in the class
|
|
||||||
docstring. However, after the gather, we want to reindex the final
|
|
||||||
logits tensor to map index->token_id one-to-one (the index is always
|
|
||||||
equal the token_id it corresponds to). The indices returned by this
|
|
||||||
method allow us to do that.
|
|
||||||
"""
|
|
||||||
if self.tp_size < 2:
|
|
||||||
return None
|
|
||||||
|
|
||||||
base_embeddings: list[int] = []
|
|
||||||
added_embeddings: list[int] = []
|
|
||||||
padding: list[int] = []
|
|
||||||
for tp_rank in range(self.tp_size):
|
|
||||||
shard_indices = self._get_indices(self.num_embeddings_padded,
|
|
||||||
self.org_vocab_size_padded,
|
|
||||||
self.num_embeddings,
|
|
||||||
self.org_vocab_size, tp_rank,
|
|
||||||
self.tp_size)
|
|
||||||
range_start = self.num_embeddings_per_partition * tp_rank
|
|
||||||
range_end = self.num_embeddings_per_partition * (tp_rank + 1)
|
|
||||||
base_embeddings.extend(
|
|
||||||
range(range_start,
|
|
||||||
range_start + shard_indices.num_org_elements))
|
|
||||||
padding.extend(
|
|
||||||
range(range_start + shard_indices.num_org_elements,
|
|
||||||
range_start + shard_indices.num_org_elements_padded))
|
|
||||||
added_embeddings.extend(
|
|
||||||
range(
|
|
||||||
range_start + shard_indices.num_org_elements_padded,
|
|
||||||
range_start + shard_indices.num_org_elements_padded +
|
|
||||||
shard_indices.num_added_elements))
|
|
||||||
padding.extend(
|
|
||||||
range(
|
|
||||||
range_start + shard_indices.num_org_elements_padded +
|
|
||||||
shard_indices.num_added_elements,
|
|
||||||
range_start + shard_indices.num_org_elements_padded +
|
|
||||||
shard_indices.num_added_elements_padded))
|
|
||||||
assert (range_start + shard_indices.num_org_elements_padded +
|
|
||||||
shard_indices.num_added_elements_padded == range_end)
|
|
||||||
ret = base_embeddings + added_embeddings + padding
|
|
||||||
assert len(ret) == self.num_embeddings_padded
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
|
||||||
output_dim = getattr(param, "output_dim", None)
|
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
|
||||||
|
|
||||||
# If the parameter is a gguf weight, then load it directly.
|
|
||||||
if getattr(param, "is_gguf_weight_type", None):
|
|
||||||
param.data.copy_(loaded_weight)
|
|
||||||
param.weight_type = loaded_weight.item()
|
|
||||||
return
|
|
||||||
elif isinstance(param, UninitializedParameter):
|
|
||||||
shape = list(loaded_weight.shape)
|
|
||||||
if output_dim is not None:
|
|
||||||
shape[output_dim] = self.num_embeddings_per_partition
|
|
||||||
param.materialize(tuple(shape), dtype=loaded_weight.dtype)
|
|
||||||
|
|
||||||
# If parameter does not have output dim, then it should
|
|
||||||
# be copied onto all gpus (e.g. g_idx for act_order gptq).
|
|
||||||
if output_dim is None:
|
|
||||||
assert param.data.shape == loaded_weight.shape
|
|
||||||
param.data.copy_(loaded_weight)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Shard indexes for loading the weight
|
|
||||||
start_idx = self.shard_indices.org_vocab_start_index
|
|
||||||
shard_size = self.shard_indices.org_vocab_end_index - start_idx
|
|
||||||
|
|
||||||
# If param packed on the same dim we are sharding on, then
|
|
||||||
# need to adjust offsets of loaded weight by pack_factor.
|
|
||||||
if packed_dim is not None and packed_dim == output_dim:
|
|
||||||
packed_factor = param.packed_factor if isinstance(
|
|
||||||
param, BasevLLMParameter) else param.pack_factor
|
|
||||||
assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
|
|
||||||
param.packed_factor)
|
|
||||||
start_idx = start_idx // packed_factor
|
|
||||||
shard_size = shard_size // packed_factor
|
|
||||||
else:
|
|
||||||
assert loaded_weight.shape[output_dim] == self.org_vocab_size
|
|
||||||
|
|
||||||
# Copy the data. Select chunk corresponding to current shard.
|
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
|
||||||
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
|
|
||||||
param[loaded_weight.shape[0]:].data.fill_(0)
|
|
||||||
|
|
||||||
def forward(self, input_):
|
|
||||||
if self.tp_size > 1:
|
|
||||||
# Build the mask.
|
|
||||||
masked_input, input_mask = get_masked_input_and_mask(
|
|
||||||
input_, self.shard_indices.org_vocab_start_index,
|
|
||||||
self.shard_indices.org_vocab_end_index,
|
|
||||||
self.shard_indices.num_org_vocab_padding,
|
|
||||||
self.shard_indices.added_vocab_start_index,
|
|
||||||
self.shard_indices.added_vocab_end_index)
|
|
||||||
else:
|
|
||||||
masked_input = input_
|
|
||||||
# Get the embeddings.
|
|
||||||
output_parallel = self.quant_method.embedding(self,
|
|
||||||
masked_input.long())
|
|
||||||
# Mask the output embedding.
|
|
||||||
if self.tp_size > 1:
|
|
||||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
|
||||||
# Reduce across all the model parallel GPUs.
|
|
||||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
|
||||||
s = f"num_embeddings={self.num_embeddings_per_partition}"
|
|
||||||
s += f", embedding_dim={self.embedding_dim}"
|
|
||||||
s += f", org_vocab_size={self.org_vocab_size}"
|
|
||||||
s += f', num_embeddings_padded={self.num_embeddings_padded}'
|
|
||||||
s += f', tp_size={self.tp_size}'
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
class ParallelLMHead(VocabParallelEmbedding):
|
|
||||||
"""Parallelized LM head.
|
|
||||||
|
|
||||||
Output logits weight matrices used in the Sampler. The weight and bias
|
|
||||||
tensors are padded to make sure they are divisible by the number of
|
|
||||||
model parallel GPUs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_embeddings: vocabulary size.
|
|
||||||
embedding_dim: size of hidden state.
|
|
||||||
bias: whether to use bias.
|
|
||||||
params_dtype: type of the parameters.
|
|
||||||
org_num_embeddings: original vocabulary size (without LoRA).
|
|
||||||
padding_size: padding size for the vocabulary.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
num_embeddings: int,
|
|
||||||
embedding_dim: int,
|
|
||||||
bias: bool = False,
|
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
|
||||||
org_num_embeddings: Optional[int] = None,
|
|
||||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = ""):
|
|
||||||
super().__init__(num_embeddings, embedding_dim, params_dtype,
|
|
||||||
org_num_embeddings, padding_size, quant_config,
|
|
||||||
prefix)
|
|
||||||
self.quant_config = quant_config
|
|
||||||
if bias:
|
|
||||||
self.bias = Parameter(
|
|
||||||
torch.empty(self.num_embeddings_per_partition,
|
|
||||||
dtype=params_dtype))
|
|
||||||
set_weight_attrs(self.bias, {
|
|
||||||
"output_dim": 0,
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
|
|
||||||
"""Tie the weights with word embeddings."""
|
|
||||||
# GGUF quantized embed_tokens.
|
|
||||||
if self.quant_config and self.quant_config.get_name() == "gguf":
|
|
||||||
return embed_tokens
|
|
||||||
else:
|
|
||||||
self.weight = embed_tokens.weight
|
|
||||||
return self
|
|
||||||
|
|
||||||
def forward(self, input_):
|
|
||||||
del input_
|
|
||||||
raise RuntimeError("LMHead's weights should be used in the sampler.")
|
|
||||||
@@ -148,7 +148,6 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
# [4, 6], it is [0, 4, 10].
|
# [4, 6], it is [0, 4, 10].
|
||||||
seq_start_loc: Optional[torch.Tensor] = None
|
seq_start_loc: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
# Prefix cache loc
|
# Prefix cache loc
|
||||||
kv_lod_cpu: Optional[torch.Tensor] = None
|
kv_lod_cpu: Optional[torch.Tensor] = None
|
||||||
kv_lod_xpu: Optional[torch.Tensor] = None
|
kv_lod_xpu: Optional[torch.Tensor] = None
|
||||||
@@ -563,9 +562,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
if blocksparse_params is not None:
|
if blocksparse_params is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"kunlunAttention does not support block-sparse attention.")
|
"kunlunAttention does not support block-sparse attention.")
|
||||||
# if logits_soft_cap is not None:
|
|
||||||
# raise ValueError(
|
|
||||||
# "kunlunAttention does not support attention logits soft capping.")
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
@@ -673,51 +669,84 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
if prefill_meta := attn_metadata.prefill_metadata:
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
# Prompt run.
|
# Prompt run.
|
||||||
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
|
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||||
|
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||||
|
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||||
|
|
||||||
|
# For hybrid Attention (Qwen3-Next.)
|
||||||
if key_cache.is_contiguous():
|
if key_cache.is_contiguous():
|
||||||
tmp_block_tables = prefill_meta.block_tables
|
tmp_block_tables = prefill_meta.block_tables
|
||||||
else:
|
else:
|
||||||
tmp_block_tables = prefill_meta.block_tables * 2 # only test in Qwen3-Next
|
# For hybrid Attention (Qwen3-Next)
|
||||||
|
tmp_block_tables = prefill_meta.block_tables * 2
|
||||||
xtorch_ops.prefill_attention(
|
|
||||||
q=prefill_query,
|
# Prefix cache
|
||||||
k=key_cache, # Key Cache (block_num, head, block_size, dim)
|
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
|
||||||
v=value_cache,
|
xtorch_ops.prefill_attention(
|
||||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
q=prefill_query,
|
||||||
is_causal=True,
|
k=key_cache, # Key Cache [block_num, head, block_size, dim]
|
||||||
is_prefix_cache=True,
|
v=value_cache,
|
||||||
block_table=tmp_block_tables,
|
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
||||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
is_causal=True,
|
||||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
is_prefix_cache=True,
|
||||||
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
|
block_table=tmp_block_tables,
|
||||||
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
|
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||||
alibi_slopes=self.alibi_slopes,
|
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||||
softmax_lse=None,
|
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
|
||||||
sink=self.sinks
|
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
|
||||||
)
|
alibi_slopes=self.alibi_slopes,
|
||||||
|
softmax_lse=None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
xtorch_ops.prefill_attention(
|
||||||
|
q=prefill_query,
|
||||||
|
k=prefill_key,
|
||||||
|
v=prefill_value,
|
||||||
|
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
||||||
|
is_causal=True,
|
||||||
|
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||||
|
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||||
|
alibi_slopes=self.alibi_slopes,
|
||||||
|
softmax_lse=None,
|
||||||
|
swa_left = self.sliding_window if self.sliding_window is not None else -1,
|
||||||
|
swa_right = 0 if self.sliding_window is not None else -1,
|
||||||
|
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
if decode_meta := attn_metadata.decode_metadata:
|
|
||||||
|
if decode_meta := attn_metadata.decode_metadata:
|
||||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||||
"Encoder-only models should not have decode metadata.")
|
"Encoder-only models should not have decode metadata.")
|
||||||
decode_query = query[:num_decode_tokens]
|
decode_query = query[:num_decode_tokens]
|
||||||
|
|
||||||
|
# For hybrid Attention (Qwen3-Next
|
||||||
if key_cache.is_contiguous():
|
if key_cache.is_contiguous():
|
||||||
tmp_block_tables = decode_meta.block_tables
|
tmp_block_tables = decode_meta.block_tables
|
||||||
else:
|
else:
|
||||||
tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next
|
tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next
|
||||||
|
|
||||||
xtorch_ops.paged_attention(
|
xtorch_ops.speculative_attention(
|
||||||
x=decode_query,
|
|
||||||
k_cache=key_cache,
|
|
||||||
v_cache=value_cache,
|
|
||||||
block_tables=tmp_block_tables,
|
|
||||||
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
|
|
||||||
context_lens_xpu=decode_meta.seq_lens_tensor,
|
|
||||||
is_context=False,
|
|
||||||
is_causal=True,
|
|
||||||
out=output[:num_decode_tokens],
|
out=output[:num_decode_tokens],
|
||||||
vo_head_dim=self.head_size
|
# Only MLA support q len > 1 right now
|
||||||
)
|
q=decode_query.unsqueeze(0),
|
||||||
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
|
||||||
|
context_lens_xpu=decode_meta.seq_lens_tensor,
|
||||||
|
batch_num=decode_meta.block_tables.shape[0],
|
||||||
|
# TODO (@xyDong23): Support MTP(q lens >1)
|
||||||
|
qlen=1,
|
||||||
|
# TODO (@xyDong23): Support max_context_len to (262144)
|
||||||
|
max_context_len=131072,
|
||||||
|
head_num=self.num_heads,
|
||||||
|
head_dim=self.head_size,
|
||||||
|
scale=0.0,
|
||||||
|
kv_head_num=self.num_kv_heads,
|
||||||
|
block_size=key_cache.shape[2],
|
||||||
|
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
|
||||||
|
max_window_size=self.sliding_window if self.sliding_window is not None else -1,
|
||||||
|
block_tables=tmp_block_tables,
|
||||||
|
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
|
||||||
|
)
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
return output.view(-1, self.num_heads * self.head_size)
|
return output.view(-1, self.num_heads * self.head_size)
|
||||||
def use_cascade_attention(
|
def use_cascade_attention(
|
||||||
@@ -788,4 +817,4 @@ def use_cascade_attention(
|
|||||||
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
|
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
|
||||||
|
|
||||||
# Use cascade attention if it is faster than FlashDecoding.
|
# Use cascade attention if it is faster than FlashDecoding.
|
||||||
return cascade_time < flash_decoding_time
|
return cascade_time < flash_decoding_time
|
||||||
@@ -938,6 +938,83 @@ def _fake_rotary_embedding(
|
|||||||
|
|
||||||
rotary_embedding.register_fake(_fake_rotary_embedding)
|
rotary_embedding.register_fake(_fake_rotary_embedding)
|
||||||
|
|
||||||
|
@custom_op("_C::quant2d", mutates_args=())
|
||||||
|
def quant2d(
|
||||||
|
x: torch.Tensor,
|
||||||
|
y: torch.Tensor,
|
||||||
|
max: torch.Tensor,
|
||||||
|
force_sdnn: bool,
|
||||||
|
) -> None:
|
||||||
|
xtorch_ops.quant2d(
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
max=max,
|
||||||
|
force_sdnn=force_sdnn
|
||||||
|
)
|
||||||
|
|
||||||
|
@impl("_C::quant2d", "CUDA")
|
||||||
|
def quant2d_cuda(
|
||||||
|
x: torch.Tensor,
|
||||||
|
y: torch.Tensor,
|
||||||
|
max: torch.Tensor,
|
||||||
|
force_sdnn: bool,
|
||||||
|
) -> None:
|
||||||
|
xtorch_ops.quant2d(
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
max=max,
|
||||||
|
force_sdnn=force_sdnn
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fake_quant2d(
|
||||||
|
x: torch.Tensor,
|
||||||
|
y: torch.Tensor,
|
||||||
|
max: torch.Tensor,
|
||||||
|
force_sdnn: bool,
|
||||||
|
) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
quant2d.register_fake(_fake_quant2d)
|
||||||
|
|
||||||
|
@custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=())
|
||||||
|
def gemm_I8_I8_bf16_nt(
|
||||||
|
x_q: torch.Tensor,
|
||||||
|
x_scale: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
out: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
xtorch_ops.gemm_I8_I8_bf16_nt(
|
||||||
|
lhs=(x_q, x_scale),
|
||||||
|
rhs=(weight, weight_scale),
|
||||||
|
out=out
|
||||||
|
)
|
||||||
|
|
||||||
|
@impl("_C::gemm_I8_I8_bf16_nt", "CUDA")
def gemm_I8_I8_bf16_nt_cuda(
    x_q: torch.Tensor,
    x_scale: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """CUDA-key dispatch implementation of ``_C::gemm_I8_I8_bf16_nt``.

    Identical delegation to ``xtorch_ops.gemm_I8_I8_bf16_nt``; result is
    produced via ``out`` (presumably in place — verify), call returns
    ``None``.
    """
    xtorch_ops.gemm_I8_I8_bf16_nt(
        lhs=(x_q, x_scale),
        rhs=(weight, weight_scale),
        out=out
    )
|
||||||
|
|
||||||
|
def _fake_gemm_I8_I8_bf16_nt(
|
||||||
|
x_q: torch.Tensor,
|
||||||
|
x_scale: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
out: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Register the no-op fake so FakeTensor / torch.compile tracing can handle
# _C::gemm_I8_I8_bf16_nt without dispatching to the real kernel.
gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt)
|
||||||
|
|
||||||
@custom_op("_C::moe_softmax_topk_norm", mutates_args=())
|
@custom_op("_C::moe_softmax_topk_norm", mutates_args=())
|
||||||
def moe_softmax_topk_norm(
|
def moe_softmax_topk_norm(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@@ -1068,15 +1145,39 @@ def moe_fc(
|
|||||||
sorted_tokens_num_lod: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
sorted_tokens_idx: torch.Tensor,
|
sorted_tokens_idx: torch.Tensor,
|
||||||
moe_topk: int,
|
moe_topk: int,
|
||||||
y: torch.Tensor
|
y: torch.Tensor,
|
||||||
|
act: Optional[str] = None,
|
||||||
|
x_perchannel_max: Optional[torch.Tensor] = None,
|
||||||
|
w_perchannel_max: Optional[torch.Tensor] = None ,
|
||||||
|
topk_ids: Optional[torch.Tensor] = None,
|
||||||
|
topk_w: Optional[torch.Tensor] = None,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
tgemm_type: Optional[str] = None,
|
||||||
|
tweight_type: Optional[str] = None,
|
||||||
|
scale_n: Optional[int] = 0,
|
||||||
|
scale_k: Optional[int] = 0,
|
||||||
|
use_pack_int4: Optional[bool] = False,
|
||||||
|
sort_mode: Optional[bool] = True
|
||||||
)-> None:
|
)-> None:
|
||||||
xtorch_ops.moe_fc(
|
xtorch_ops.moe_fc(
|
||||||
x,
|
x=x,
|
||||||
weight,
|
weight=weight,
|
||||||
sorted_tokens_num_lod,
|
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||||
sorted_tokens_idx,
|
sorted_tokens_idx=sorted_tokens_idx,
|
||||||
moe_topk,
|
moe_topk=moe_topk,
|
||||||
y)
|
y=y,
|
||||||
|
act=act,
|
||||||
|
x_perchannel_max=x_perchannel_max,
|
||||||
|
w_perchannel_max=w_perchannel_max,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
topk_w=topk_w,
|
||||||
|
bias=bias,
|
||||||
|
tgemm_type=tgemm_type,
|
||||||
|
tweight_type=tweight_type,
|
||||||
|
scale_n=scale_n,
|
||||||
|
scale_k=scale_k,
|
||||||
|
use_pack_int4=use_pack_int4,
|
||||||
|
sort_mode=sort_mode)
|
||||||
|
|
||||||
@impl("_C::moe_fc", "CUDA")
|
@impl("_C::moe_fc", "CUDA")
|
||||||
def moe_fc_cuda(
|
def moe_fc_cuda(
|
||||||
@@ -1085,15 +1186,39 @@ def moe_fc_cuda(
|
|||||||
sorted_tokens_num_lod: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
sorted_tokens_idx: torch.Tensor,
|
sorted_tokens_idx: torch.Tensor,
|
||||||
moe_topk: int,
|
moe_topk: int,
|
||||||
y: torch.Tensor
|
y: torch.Tensor,
|
||||||
|
act: Optional[str] = None,
|
||||||
|
x_perchannel_max: Optional[torch.Tensor] = None,
|
||||||
|
w_perchannel_max: Optional[torch.Tensor] = None ,
|
||||||
|
topk_ids: Optional[torch.Tensor] = None,
|
||||||
|
topk_w: Optional[torch.Tensor] = None,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
tgemm_type: Optional[str] = None,
|
||||||
|
tweight_type: Optional[str] = None,
|
||||||
|
scale_n: Optional[int] = 0,
|
||||||
|
scale_k: Optional[int] = 0,
|
||||||
|
use_pack_int4: Optional[bool] = False,
|
||||||
|
sort_mode: Optional[bool] = True
|
||||||
)-> None:
|
)-> None:
|
||||||
xtorch_ops.moe_fc(
|
xtorch_ops.moe_fc(
|
||||||
x,
|
x=x,
|
||||||
weight,
|
weight=weight,
|
||||||
sorted_tokens_num_lod,
|
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||||
sorted_tokens_idx,
|
sorted_tokens_idx=sorted_tokens_idx,
|
||||||
moe_topk,
|
moe_topk=moe_topk,
|
||||||
y)
|
y=y,
|
||||||
|
act=act,
|
||||||
|
x_perchannel_max=x_perchannel_max,
|
||||||
|
w_perchannel_max=w_perchannel_max,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
topk_w=topk_w,
|
||||||
|
bias=bias,
|
||||||
|
tgemm_type=tgemm_type,
|
||||||
|
tweight_type=tweight_type,
|
||||||
|
scale_n=scale_n,
|
||||||
|
scale_k=scale_k,
|
||||||
|
use_pack_int4=use_pack_int4,
|
||||||
|
sort_mode=sort_mode)
|
||||||
|
|
||||||
def fake_moe_fc(
|
def fake_moe_fc(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@@ -1101,7 +1226,19 @@ def fake_moe_fc(
|
|||||||
sorted_tokens_num_lod: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
sorted_tokens_idx: torch.Tensor,
|
sorted_tokens_idx: torch.Tensor,
|
||||||
moe_topk: int,
|
moe_topk: int,
|
||||||
y: torch.Tensor
|
y: torch.Tensor,
|
||||||
|
act: Optional[str] = None,
|
||||||
|
x_perchannel_max: Optional[torch.Tensor] = None,
|
||||||
|
w_perchannel_max: Optional[torch.Tensor] = None ,
|
||||||
|
topk_ids: Optional[torch.Tensor] = None,
|
||||||
|
topk_w: Optional[torch.Tensor] = None,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
tgemm_type: Optional[str] = None,
|
||||||
|
tweight_type: Optional[str] = None,
|
||||||
|
scale_n: Optional[int] = 0,
|
||||||
|
scale_k: Optional[int] = 0,
|
||||||
|
use_pack_int4: Optional[bool] = False,
|
||||||
|
sort_mode: Optional[bool] = True
|
||||||
)-> None:
|
)-> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -1151,6 +1288,63 @@ def fake_moe_post(
|
|||||||
moe_post.register_fake(fake_moe_post)
|
moe_post.register_fake(fake_moe_post)
|
||||||
|
|
||||||
|
|
||||||
|
@custom_op("_C::moe_sigmoid_group_topk_norm", mutates_args=())
def moe_sigmoid_group_topk_norm(
    x: torch.Tensor,
    topk_index: torch.Tensor,
    norm_score: torch.Tensor,
    block_static: torch.Tensor,
    bias: torch.Tensor,
    scale: float,
    n_group: int,
    topk_group: int
) -> None:
    """MoE router scoring op: sigmoid + grouped top-k + normalization
    (per the op name; semantics live in ``xtorch_ops``).

    Thin wrapper over ``xtorch_ops.moe_sigmoid_group_topk_norm``. The call
    returns ``None``; ``topk_index`` / ``norm_score`` / ``block_static``
    appear to be output buffers filled by the kernel — TODO confirm.

    Note: the kernel call lists ``norm_score`` before ``topk_index``,
    the reverse of this wrapper's parameter order — safe only because
    arguments are passed by keyword.

    NOTE(review): if the kernel writes any arguments in place,
    ``mutates_args=()`` under-declares mutation — confirm against the
    ``custom_op`` helper.
    """
    xtorch_ops.moe_sigmoid_group_topk_norm(
        x=x,
        norm_score=norm_score,
        topk_index=topk_index,
        block_static=block_static,
        bias=bias,
        n_group=n_group,
        topk_group=topk_group,
        scale=scale,
    )
|
||||||
|
|
||||||
|
@impl("_C::moe_sigmoid_group_topk_norm", "CUDA")
def moe_sigmoid_group_topk_norm_cuda(
    x: torch.Tensor,
    topk_index: torch.Tensor,
    norm_score: torch.Tensor,
    block_static: torch.Tensor,
    bias: torch.Tensor,
    scale: float,
    n_group: int,
    topk_group: int
) -> None:
    """CUDA-key dispatch implementation of ``_C::moe_sigmoid_group_topk_norm``.

    Identical delegation to ``xtorch_ops.moe_sigmoid_group_topk_norm`` as
    the default registration; arguments are forwarded by keyword (the
    kernel's keyword order differs from this wrapper's parameter order).
    """
    xtorch_ops.moe_sigmoid_group_topk_norm(
        x=x,
        norm_score=norm_score,
        topk_index=topk_index,
        block_static=block_static,
        bias=bias,
        n_group=n_group,
        topk_group=topk_group,
        scale=scale,
    )
|
||||||
|
|
||||||
|
def _fake_moe_sigmoid_group_topk_norm(
|
||||||
|
x: torch.Tensor,
|
||||||
|
topk_index: torch.Tensor,
|
||||||
|
norm_score: torch.Tensor,
|
||||||
|
block_static: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
n_group: int,
|
||||||
|
topk_group: int
|
||||||
|
) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Register the no-op fake so FakeTensor / torch.compile tracing can handle
# _C::moe_sigmoid_group_topk_norm without dispatching to the real kernel.
moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm)
|
||||||
##################################################
|
##################################################
|
||||||
# --------------- awq_dequantize -----------------
|
# --------------- awq_dequantize -----------------
|
||||||
##################################################
|
##################################################
|
||||||
|
|||||||
Reference in New Issue
Block a user