[Feature] Support gpt-oss and update model list (#71)

* [Docs] Update Support Models

* [Feature] Support gpt-oss

* [Docs] fix model support list

* Fix MoE

* Fix

* Fix moe_ep

* Remove gpt-oss graph support (not yet supported)

---------

Co-authored-by: hanhaowen <hanhaowen@baidu.com>
Author: Xinyu Dong
Date: 2026-01-04 21:19:49 +08:00
Committed by: GitHub
Parent: ded24f5026
Commit: fe666fb24f
6 changed files with 537 additions and 340 deletions

View File

@@ -45,6 +45,22 @@ By utilizing the vLLM Kunlun plugin, popular open-source models, including Trans
 </tr>
 </thead>
 <tbody>
+<tr>
+<td class="model-name">Qwen2</td>
+<td class="status-support"></td>
+<td></td>
+<td class="status-support"></td>
+<td class="status-support"></td>
+<td></td>
+</tr>
+<tr>
+<td class="model-name">Qwen2.5</td>
+<td class="status-support"></td>
+<td></td>
+<td class="status-support"></td>
+<td class="status-support"></td>
+<td></td>
+</tr>
 <tr>
 <td class="model-name">Qwen3</td>
 <td class="status-support"></td>
@@ -77,6 +93,38 @@ By utilizing the vLLM Kunlun plugin, popular open-source models, including Trans
 <td class="status-support"></td>
 <td></td>
 </tr>
+<tr>
+<td class="model-name">Llama2</td>
+<td class="status-support"></td>
+<td></td>
+<td></td>
+<td class="status-support"></td>
+<td></td>
+</tr>
+<tr>
+<td class="model-name">Llama3</td>
+<td class="status-support"></td>
+<td></td>
+<td></td>
+<td class="status-support"></td>
+<td></td>
+</tr>
+<tr>
+<td class="model-name">Llama3.1</td>
+<td class="status-support"></td>
+<td></td>
+<td></td>
+<td class="status-support"></td>
+<td></td>
+</tr>
+<tr>
+<td class="model-name">gpt-oss</td>
+<td class="status-support"></td>
+<td></td>
+<td></td>
+<td></td>
+<td></td>
+</tr>
 </tbody>
 </table>

View File

@@ -76,6 +76,10 @@ def register_model():
     ModelRegistry.register_model(
         "MiMoV2FlashForCausalLM",
         "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM")
+    ModelRegistry.register_model(
+        "GptOssForCausalLM",
+        "vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
+
 def register_quant_method():
     """to do"""

View File

@@ -8,12 +8,15 @@ import torch.distributed as dist
 from torch import nn
 from transformers import GptOssConfig
-from vllm.attention import Attention, AttentionType
+from vllm.attention import AttentionType
+from vllm_kunlun.ops.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.distributed import (get_ep_group, get_pp_group,
+                              get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm_kunlun.ops.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
@@ -23,12 +26,16 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv
-from .utils import extract_layer_index, maybe_prefix
+from vllm.model_executor.models.interfaces import SupportsEagle3, SupportsPP
+from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
+                                              is_pp_missing_parameter,
+                                              make_empty_intermediate_tensors_factory, make_layers,
+                                              maybe_prefix)
+from vllm_kunlun.ops.activation import SiluAndMul
 class OAIAttention(nn.Module):
@@ -71,11 +78,8 @@ class OAIAttention(nn.Module):
         self.sinks = torch.nn.Parameter(
             torch.empty(config.num_attention_heads // tp_size,
-                        dtype=torch.bfloat16,
                         requires_grad=False))
-        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
         self.q_size = self.num_attention_heads * self.head_dim // tp_size
         self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
         self.scaling = self.head_dim**-0.5
@@ -118,36 +122,37 @@ class OAIAttention(nn.Module):
     def forward(self, hidden_states: torch.Tensor,
                 positions: torch.Tensor) -> torch.Tensor:
-        t = self.norm(hidden_states)
-        qkv, _ = self.qkv(t)
+        qkv, _ = self.qkv(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         v = v.contiguous()
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
-        return output + hidden_states
+        return output
 class MLPBlock(torch.nn.Module):
     def __init__(
         self,
-        config: GptOssConfig,
+        vllm_config: VllmConfig,
         layer_idx: int,
-        quant_config: QuantizationConfig,
         prefix: str = "",
     ):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
         self.layer_idx = layer_idx
         self.num_experts = config.num_local_experts
         self.experts_per_token = config.num_experts_per_tok
         self.world_size = dist.get_world_size() if dist.is_initialized() else 1
-        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
         self.router = torch.nn.Linear(config.hidden_size,
-                                      config.num_local_experts,
-                                      dtype=torch.bfloat16)
+                                      config.num_local_experts)
         assert config.intermediate_size % self.world_size == 0
         self.experts = FusedMoE(num_experts=config.num_local_experts,
                                 top_k=config.num_experts_per_tok,
@@ -159,36 +164,67 @@ class MLPBlock(torch.nn.Module):
                                 prefix=f"{prefix}.experts",
                                 apply_router_weight_on_input=False,
                                 has_bias=True,
-                                activation="swigluoai")
+                                activation="swigluoai",
+                                is_sequence_parallel=self.is_sequence_parallel)
+        self.register_buffer("kunlun_linear_weights", torch.zeros(
+            config.num_local_experts, config.hidden_size, dtype=torch.float32))
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        t = self.norm(x)
-        g = self.router(t)
-        t = self.experts(hidden_states=t, router_logits=g)
-        return x + t
+        num_tokens = x.shape[0]
+        if self.is_sequence_parallel:
+            x = sequence_parallel_chunk(x)
+        g = self.router(x)
+        x = self.experts(hidden_states=x, router_logits=g,
+                         linear_weights=self.router.weight)
+        if self.is_sequence_parallel:
+            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
+            x = x[:num_tokens]
+        return x
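
The new sequence-parallel path splits the token dimension across ranks before the experts run, then all-gathers and trims the padding. A standalone sketch of that chunk/compute/gather pattern — the helper below is illustrative, not the vLLM API, and it assumes torch.distributed is already initialized:

    # Illustrative sketch of the sequence-parallel MoE pattern used above:
    # chunk tokens across ranks, run the expensive expert computation on the
    # local shard only, then all-gather and drop the padding.
    import torch
    import torch.distributed as dist

    def sp_moe_forward(x: torch.Tensor, experts, tp_size: int, tp_rank: int) -> torch.Tensor:
        num_tokens = x.shape[0]
        # Pad so the token dim divides evenly, then keep this rank's shard.
        pad = (-num_tokens) % tp_size
        if pad:
            x = torch.cat([x, x.new_zeros(pad, x.shape[1])], dim=0)
        shard = x.chunk(tp_size, dim=0)[tp_rank]
        shard = experts(shard)  # per-token work on 1/tp_size of the tokens
        out = torch.empty(x.shape[0], shard.shape[1],
                          dtype=shard.dtype, device=shard.device)
        dist.all_gather_into_tensor(out, shard.contiguous())
        return out[:num_tokens]  # trim padding, as in the diff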
 class TransformerBlock(torch.nn.Module):
     def __init__(
         self,
-        config: GptOssConfig,
-        quant_config: QuantizationConfig,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
         super().__init__()
-        self.layer_idx = extract_layer_index(prefix)
-        self.attn = OAIAttention(config, prefix=f"{prefix}.attn")
-        self.mlp = MLPBlock(config,
-                            self.layer_idx,
-                            quant_config=quant_config,
-                            prefix=f"{prefix}.mlp")
-    def forward(self, hidden_states: torch.Tensor,
-                positions: torch.Tensor) -> torch.Tensor:
-        attn_output = self.attn(hidden_states, positions)
-        output = self.mlp(attn_output)
-        return output
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        self.layer_idx = extract_layer_index(prefix)
+        self.attn = OAIAttention(config,
+                                 prefix=f"{prefix}.attn",
+                                 cache_config=cache_config)
+        self.mlp = MLPBlock(vllm_config,
+                            self.layer_idx,
+                            prefix=f"{prefix}.mlp")
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        positions: torch.Tensor,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.attn(hidden_states, positions)
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        output = self.mlp(hidden_states)
+        return output, residual
 @support_torch_compile
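
The refactored block threads the residual through the RMSNorms instead of adding it inside attention and MLP; when called with a residual, vLLM's RMSNorm fuses the add and returns both the normed tensor and the updated residual. A minimal unfused sketch of that two-output contract, for clarity only:

    # Minimal sketch of the fused add+RMSNorm contract used by the block above:
    # norm(x) on the first call, norm(x + residual) afterwards, returning the
    # pre-norm sum as the new residual.
    import torch

    class AddRMSNorm(torch.nn.Module):
        def __init__(self, hidden_size: int, eps: float = 1e-5):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x, residual=None):
            if residual is not None:
                x = x + residual  # fused into the norm kernel in vLLM
                residual = x      # pre-norm sum becomes the new residual
            normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
            normed = normed * self.weight
            return normed if residual is None else (normed, residual)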
@@ -202,87 +238,86 @@ class GptOssModel(nn.Module):
     ):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
-        self.quant_config = vllm_config.quant_config
+        self.parallel_config = vllm_config.parallel_config
         self.config.hidden_size = self.config.hidden_size
         self.embedding = VocabParallelEmbedding(
             self.config.vocab_size,
             self.config.hidden_size,
         )
-        self.layers = torch.nn.ModuleList([
-            TransformerBlock(
-                self.config,
-                quant_config=self.quant_config,
-                prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
-            ) for layer_idx in range(self.config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            self.config.num_hidden_layers,
+            lambda prefix: TransformerBlock(
+                vllm_config,
+                prefix=prefix,
+            ),
+            prefix=f"{prefix}.layers",
+        )
         self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], self.config.hidden_size))
+        self.aux_hidden_state_layers = tuple[int, ...]()
-    def forward(self, input_ids: torch.Tensor,
-                positions: torch.Tensor) -> torch.Tensor:
-        x = self.embedding(input_ids)
-        for layer in self.layers:
-            x = layer(x, positions)
-        x = self.norm(x)
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embedding(input_ids)
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                x = inputs_embeds
+            else:
+                x = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            if i in self.aux_hidden_state_layers:
+                aux_hidden_states.append(x if residual is None else x +
+                                         residual)
+            x, residual = layer(x, positions, residual)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": x,
+                "residual": residual
+            })
+        x, _ = self.norm(x, residual)
+        if len(aux_hidden_states) > 0:
+            return x, aux_hidden_states
         return x
-class GptOssForCausalLM(nn.Module):
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ):
-        super().__init__()
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config.hf_config
-        self.model = GptOssModel(
-            vllm_config=vllm_config,
-            prefix=maybe_prefix(prefix, "model"),
-        )
-        self.lm_head = ParallelLMHead(
-            self.model_config.vocab_size,
-            self.model_config.hidden_size,
-        )
-        self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
-    def forward(self,
-                input_ids: torch.Tensor,
-                positions: torch.Tensor,
-                intermediate_tensors: Optional[IntermediateTensors] = None,
-                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
-        assert intermediate_tensors is None
-        assert inputs_embeds is None
-        return self.model(input_ids, positions)
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
     def _load_weights_mxfp4(
-            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        rename_mapping = {
-            "self_attn": "attn",
-            "input_layernorm.weight": "attn.norm.weight",
-            "post_attention_layernorm.weight": "mlp.norm.weight",
-            "embed_tokens": "embedding",
-        }
-        def maybe_rename(name: str) -> str:
-            for remap_name, new_name in rename_mapping.items():
-                if remap_name in name:
-                    return name.replace(remap_name, new_name)
-            return name
+            self,
+            ep_rank_end: int,
+            ep_rank_start: int,
+            heads_per_rank: int,
+            head_start: int,
+            weights: Iterable[tuple[str, torch.Tensor]],
+            stacked_params_mapping: list[tuple[str, ...]],
+    ) -> set[str]:
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         mxfp4_block = 32
+        use_ep = self.parallel_config.enable_expert_parallel
+        num_experts = self.config.num_local_experts
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
-        intermediate_size = self.model_config.intermediate_size
+        intermediate_size = self.config.intermediate_size
         intermediate_size_block = intermediate_size // mxfp4_block
         per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                                 tp_size)
@@ -294,26 +329,54 @@ class GptOssForCausalLM(nn.Module):
         tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                           intermediate_size)
-        # Attention heads per rank
-        heads_per_rank = self.model_config.num_attention_heads // tp_size
-        head_start = tp_rank * heads_per_rank
-        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
-        ep_size = get_ep_group().world_size
-        ep_rank = get_ep_group().rank
-        num_experts = self.model_config.num_local_experts
-        experts_per_rank = num_experts // ep_size
-        ep_rank_start = ep_rank * experts_per_rank
-        ep_rank_end = (ep_rank + 1) * experts_per_rank
         for name, weight in weights:
+            # Skip layers on other devices.
+            if is_pp_missing_parameter(name, self):
+                continue
             # FIXME(woosuk): Remove this after testing.
             weight = weight.cuda()
-            if "gate_up_proj_blocks" in name:
-                # Handle MLP gate and up projection weights
-                new_name = name.replace("gate_up_proj_blocks", "w13_weight")
+            if ".w13_weight_scale" in name:
+                # Handle MLP gate and up projection weights scale
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[:,
+                                           2 * tp_rank_start:2 * tp_rank_end,
+                                           ...]
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param,
+                              narrow_weight,
+                              weight_name=name,
+                              shard_id=None,
+                              expert_id=None)
+                loaded_params.add(name)
+                continue
+            elif ".w2_weight_scale" in name:
+                # Handle MLP down projection weights
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[..., tp_rank_start //
+                                           mxfp4_block:tp_rank_end //
+                                           mxfp4_block]
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param,
+                              narrow_weight,
+                              weight_name=name,
+                              shard_id=None,
+                              expert_id=None)
+                loaded_params.add(name)
+                continue
+            elif ".w13_weight" in name:
+                # Handle MLP gate and up projection weights
                 # flat weight from (E, 2 * N, block_size, entry_per_block)
                 # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                 weight = weight.view(num_experts, 2 * intermediate_size,
@@ -328,19 +391,18 @@ class GptOssForCausalLM(nn.Module):
                                            2 * tp_rank_start:2 * tp_rank_end,
                                            ...]
-                param = params_dict[new_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param,
                               narrow_weight,
-                              weight_name=new_name,
+                              weight_name=name,
                               shard_id=None,
                               expert_id=None)
-                loaded_params.add(new_name)
-            elif "down_proj_blocks" in name:
+                loaded_params.add(name)
+                continue
+            elif ".w2_weight" in name:
                 # Handle MLP down projection weights
-                new_name = name.replace("down_proj_blocks", "w2_weight")
                 # same flatten here, but since 2 mx4 value are packed in 1
                 # uint8, divide by 2
                 weight = weight.view(num_experts, -1,
@@ -351,60 +413,18 @@ class GptOssForCausalLM(nn.Module):
                 narrow_weight = weight[...,
                                        tp_rank_start // 2:tp_rank_end // 2]
-                param = params_dict[new_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param,
                               narrow_weight,
-                              weight_name=new_name,
+                              weight_name=name,
                               shard_id=None,
                               expert_id=None)
-                loaded_params.add(new_name)
-            elif "gate_up_proj_scales" in name:
-                # Handle MLP gate and up projection weights scale
-                new_name = name.replace("gate_up_proj_scales",
-                                        "w13_weight_scale")
-                if use_ep:
-                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
-                else:
-                    narrow_weight = weight[:,
-                                           2 * tp_rank_start:2 * tp_rank_end,
-                                           ...]
-                param = params_dict[new_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param,
-                              narrow_weight,
-                              weight_name=new_name,
-                              shard_id=None,
-                              expert_id=None)
-                loaded_params.add(new_name)
-            elif "down_proj_scales" in name:
-                # Handle MLP down projection weights
-                new_name = name.replace("down_proj_scales", "w2_weight_scale")
-                if use_ep:
-                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
-                else:
-                    narrow_weight = weight[..., tp_rank_start //
-                                           mxfp4_block:tp_rank_end //
-                                           mxfp4_block]
-                param = params_dict[new_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param,
-                              narrow_weight,
-                              weight_name=new_name,
-                              shard_id=None,
-                              expert_id=None)
-                loaded_params.add(new_name)
-            elif "gate_up_proj_bias" in name:
+                loaded_params.add(name)
+                continue
+            elif ".w13_bias" in name:
                 # Handle MLP gate and up projection biases
-                new_name = name.replace("gate_up_proj_bias", "w13_bias")
                 # Extract gate and up projection bias parts
                 if use_ep:
                     narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
@@ -412,20 +432,19 @@ class GptOssForCausalLM(nn.Module):
                     narrow_weight = weight[:,
                                            2 * tp_rank_start:2 * tp_rank_end]
-                param = params_dict[new_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param,
                               narrow_weight,
-                              weight_name=new_name,
+                              weight_name=name,
                               shard_id=None,
                               expert_id=None)
-                loaded_params.add(new_name)
-            elif "down_proj_bias" in name:
+                loaded_params.add(name)
+                continue
+            elif ".w2_bias" in name:
                 # Handle MLP down projection bias
-                new_name = name.replace("down_proj_bias", "w2_bias")
-                param = params_dict[new_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 if use_ep:
@@ -436,87 +455,73 @@ class GptOssForCausalLM(nn.Module):
                         weight.zero_()
                 weight_loader(param,
                               weight,
-                              weight_name=new_name,
+                              weight_name=name,
                               shard_id=None,
                               expert_id=None)
-                loaded_params.add(new_name)
+                loaded_params.add(name)
+                continue
             elif "sinks" in name:
                 # Handle attention sinks (distributed across ranks)
-                name = name.replace("self_attn", "attn")
                 param = params_dict[name]
                 narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                 param.data.copy_(narrow_weight)
                 loaded_params.add(name)
-            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
-                shard_id = ("q" if "q_proj" in name else
-                            "k" if "k_proj" in name else "v")
-                name = name.replace("self_attn", "attn")
-                param_name = name.replace(f"{shard_id}_proj", "qkv")
-                param = params_dict[param_name]
-                weight_loader = param.weight_loader
-                weight_loader(param, weight, loaded_shard_id=shard_id)
-                loaded_params.add(param_name)
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, weight)
+                else:
+                    weight_loader(param, weight, shard_id)
+                break
             else:
                 # Handle all other weights with potential renaming
-                renamed_name = maybe_rename(name)
-                if renamed_name not in params_dict:
+                if name not in params_dict:
                     continue
-                param = params_dict[renamed_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, weight)
-                loaded_params.add(renamed_name)
+                loaded_params.add(name)
         return loaded_params
     def _load_weights_other(
-            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        rename_mapping = {
-            "self_attn": "attn",
-            "input_layernorm.weight": "attn.norm.weight",
-            "post_attention_layernorm.weight": "mlp.norm.weight",
-            "embed_tokens": "embedding",
-        }
-        def maybe_rename(name: str) -> str:
-            for remap_name, new_name in rename_mapping.items():
-                if remap_name in name:
-                    return name.replace(remap_name, new_name)
-            return name
+            self,
+            ep_rank_end: int,
+            ep_rank_start: int,
+            heads_per_rank: int,
+            head_start: int,
+            weights: Iterable[tuple[str, torch.Tensor]],
+            stacked_params_mapping: list[tuple[str, ...]],
+    ) -> set[str]:
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        use_ep = self.parallel_config.enable_expert_parallel
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
-        intermediate_size = self.model_config.intermediate_size
+        intermediate_size = self.config.intermediate_size
         per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
         # Calculate common slicing bounds for current rank
         tp_rank_start = tp_rank * per_rank_intermediate_size
         tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                           intermediate_size)
-        # Attention heads per rank
-        heads_per_rank = self.model_config.num_attention_heads // tp_size
-        head_start = tp_rank * heads_per_rank
-        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
-        ep_size = get_ep_group().world_size
-        ep_rank = get_ep_group().rank
-        num_experts = self.model_config.num_local_experts
-        experts_per_rank = num_experts // ep_size
-        ep_rank_start = ep_rank * experts_per_rank
-        ep_rank_end = (ep_rank + 1) * experts_per_rank
         for name, weight in weights:
-            if ".experts.gate_up_proj" in name and "bias" not in name:
-                # Handle MLP gate and up projection weights
-                new_name = name.replace(".experts.gate_up_proj",
-                                        ".experts.w13_weight")
+            # Skip layers on other devices.
+            if is_pp_missing_parameter(name, self):
+                continue
+            if ".w13_weight" in name:
+                # Handle MLP gate and up projection weights
                 # Extract gate and up projection parts
+                # since the weight is shuffled, we can slice directly
                 if use_ep:
                     narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                 else:
@@ -524,30 +529,25 @@ class GptOssForCausalLM(nn.Module):
                                            2 * tp_rank_start:2 * tp_rank_end]
                 narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
-                param = params_dict[new_name]
+                param = params_dict[name]
                 param.copy_(narrow_weight)
-                loaded_params.add(new_name)
-            elif ".experts.down_proj" in name and "bias" not in name:
+                loaded_params.add(name)
+                continue
+            elif ".w2_weight" in name:
                 # Handle MLP down projection weights
-                new_name = name.replace(".experts.down_proj",
-                                        ".experts.w2_weight")
                 if use_ep:
                     narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                 else:
                     narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                 narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
-                param = params_dict[new_name]
+                param = params_dict[name]
                 param.copy_(narrow_weight)
-                loaded_params.add(new_name)
-            elif "gate_up_proj_bias" in name:
+                loaded_params.add(name)
+                continue
+            elif ".w13_bias" in name:
                 # Handle MLP gate and up projection biases
-                new_name = name.replace("gate_up_proj_bias", "w13_bias")
                 # Extract gate and up projection bias parts
                 if use_ep:
                     narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
@@ -555,60 +555,162 @@ class GptOssForCausalLM(nn.Module):
                     narrow_weight = weight[:,
                                            2 * tp_rank_start:2 * tp_rank_end]
-                param = params_dict[new_name]
+                param = params_dict[name]
                 param.copy_(narrow_weight)
-                loaded_params.add(new_name)
-            elif "down_proj_bias" in name:
+                loaded_params.add(name)
+                continue
+            elif ".w2_bias" in name:
                 # Handle MLP down projection bias
-                new_name = name.replace("down_proj_bias", "w2_bias")
                 if use_ep:
                     weight = weight[ep_rank_start:ep_rank_end, ...]
                 else:
                     # (only load on rank 0 to avoid duplication)
                     if tp_rank != 0:
                         weight.zero_()
-                param = params_dict[new_name]
+                param = params_dict[name]
                 param.copy_(weight)
-                loaded_params.add(new_name)
+                loaded_params.add(name)
+                continue
             elif "sinks" in name:
                 # Handle attention sinks (distributed across ranks)
-                name = name.replace("self_attn", "attn")
                 param = params_dict[name]
                 narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                 param.data.copy_(narrow_weight)
                 loaded_params.add(name)
-            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
-                shard_id = ("q" if "q_proj" in name else
-                            "k" if "k_proj" in name else "v")
-                name = name.replace("self_attn", "attn")
-                param_name = name.replace(f"{shard_id}_proj", "qkv")
-                param = params_dict[param_name]
-                weight_loader = param.weight_loader
-                weight_loader(param, weight, loaded_shard_id=shard_id)
-                loaded_params.add(param_name)
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, weight)
+                else:
+                    weight_loader(param, weight, shard_id)
+                break
             else:
                 # Handle all other weights with potential renaming
-                renamed_name = maybe_rename(name)
-                if renamed_name not in params_dict:
+                if name not in params_dict:
                     continue
-                param = params_dict[renamed_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, weight)
-                loaded_params.add(renamed_name)
+                loaded_params.add(name)
         return loaded_params
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        quant_method = (self.model_config.quantization_config['quant_method']
-                        if hasattr(self.model_config, "quantization_config")
-                        else None)
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv", ".q_proj", "q"),
+            (".qkv", ".k_proj", "k"),
+            (".qkv", ".v_proj", "v"),
+        ]
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        # Attention heads per rank
+        heads_per_rank = self.config.num_attention_heads // tp_size
+        head_start = tp_rank * heads_per_rank
+        ep_size = get_ep_group().world_size
+        ep_rank = get_ep_group().rank
+        num_experts = self.config.num_local_experts
+        experts_per_rank = num_experts // ep_size
+        ep_rank_start = ep_rank * experts_per_rank
+        ep_rank_end = (ep_rank + 1) * experts_per_rank
+        quant_method = (self.config.quantization_config['quant_method'] if
+                        hasattr(self.config, "quantization_config") else None)
         if quant_method == "mxfp4":
-            return self._load_weights_mxfp4(weights)
+            return self._load_weights_mxfp4(ep_rank_end, ep_rank_start,
+                                            heads_per_rank, head_start,
+                                            weights, stacked_params_mapping)
         else:
-            return self._load_weights_other(weights)
+            return self._load_weights_other(ep_rank_end, ep_rank_start,
+                                            heads_per_rank, head_start,
+                                            weights, stacked_params_mapping)
+class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
+    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            ".self_attn.": ".attn.",
+        },
+        orig_to_new_suffix={
+            ".embed_tokens.weight": ".embedding.weight",
+            # MoE MXFP4 weights
+            ".gate_up_proj_blocks": ".w13_weight",
+            ".down_proj_blocks": ".w2_weight",
+            ".gate_up_proj_scales": ".w13_weight_scale",
+            ".down_proj_scales": ".w2_weight_scale",
+            # MoE other weights
+            ".gate_up_proj": ".w13_weight",
+            ".down_proj": ".w2_weight",
+            # MoE Bias
+            ".gate_up_proj_bias": ".w13_bias",
+            ".down_proj_bias": ".w2_bias",
+        },
+    )
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.vllm_config = vllm_config
+        self.config = vllm_config.model_config.hf_config
+        self.model = GptOssModel(
+            vllm_config=vllm_config,
+            prefix=maybe_prefix(prefix, "model"),
+        )
+        self.lm_head = ParallelLMHead(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "lm_head"),
+        )
+        self.logits_processor = LogitsProcessor(self.config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return self.model(input_ids, positions, intermediate_tensors,
+                          inputs_embeds)
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        return logits
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
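
load_weights now computes the per-rank shard bounds once and hands them to both loaders. A small worked example of that partitioning arithmetic, runnable standalone — the sizes below are illustrative, not taken from a real checkpoint:

    # Worked example of the per-rank slice math used by load_weights above.
    from math import ceil

    num_experts, intermediate_size, num_heads = 32, 2880, 64
    tp_size, ep_size = 4, 4

    for tp_rank in range(tp_size):
        per_rank = ceil(intermediate_size / tp_size)  # cdiv in vLLM
        start = tp_rank * per_rank
        end = min((tp_rank + 1) * per_rank, intermediate_size)
        heads_per_rank = num_heads // tp_size
        head_start = tp_rank * heads_per_rank
        print(f"tp{tp_rank}: intermediate [{start}:{end}), "
              f"heads [{head_start}:{head_start + heads_per_rank})")

    for ep_rank in range(ep_size):
        experts_per_rank = num_experts // ep_size
        print(f"ep{ep_rank}: experts "
              f"[{ep_rank * experts_per_rank}:{(ep_rank + 1) * experts_per_rank})")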

View File

@@ -418,6 +418,7 @@ class KunlunOps:
         w2: torch.Tensor,
         router_logits: torch.Tensor,
         linear_weights: torch.Tensor,
+        ep_rank: int,
         moe_top_k: int,
         renormalize: bool,
         inplace: bool = False,
@@ -430,7 +431,7 @@ class KunlunOps:
         e_score_correction_bias: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         """fused_moe"""
-        global_num_experts = linear_weights.shape[0]
+        global_num_experts, up_gate_size, _ = w1.shape
         M, N = hidden_states.shape
         hidden_dim = w2.shape[1]
         normed_score = torch.empty(M,
@@ -445,82 +446,119 @@ class KunlunOps:
         block_statistic = torch.zeros(
             num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
         )
-        torch.ops._C.moe_sigmoid_group_topk_norm(
-            x=router_logits,
-            topk_index=topk_ids,
-            norm_score=normed_score,
-            block_static=block_statistic,
-            bias=e_score_correction_bias,
-            scale=1.0,
-            n_group=num_expert_group,
-            topk_group=1,
-        )
-        moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device)  # [M*top_k, N], float
-        expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device)  # [E]
-        sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device)  # [E+1]
-        sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
-        torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
-        torch.ops._C.moe_pre_sorted(
-            x=hidden_states,
-            topk_index=topk_ids,
-            block_statistic=block_statistic,
-            moe_expand=moe_expand,
-            moe_index=sorted_tokens_idx,
-            expert_m=expert_m,
-            sorted_tokens_num_lod=sorted_tokens_num_lod)
-        y = torch.empty(M, moe_top_k,
-                        w1.shape[1],
-                        dtype=hidden_states.dtype,
-                        device=hidden_states.device)
-        moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
-        torch.ops._C.moe_fc(
-            x=moe_expand,
-            weight=w1,
-            sorted_tokens_num_lod=sorted_tokens_num_lod,
-            sorted_tokens_idx=sorted_tokens_idx,
-            moe_topk=moe_top_k,
-            y=y)
-        d = y.shape[-1] // 2
-        output_shape = (y.shape[:-1] + (d, ))
-        out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
-        torch.ops._C.silu_and_mul(out1, y)
-        out = torch.empty(M, moe_top_k,
-                          w2.shape[1],
-                          dtype=hidden_states.dtype,
-                          device=hidden_states.device)
-        out1 = out1.reshape(-1, out1.shape[-1])
-        torch.ops._C.moe_fc(
-            x=out1,
-            weight=w2,
-            sorted_tokens_num_lod=sorted_tokens_num_lod,
-            sorted_tokens_idx=sorted_tokens_idx,
-            moe_topk=moe_top_k,
-            y=out)
-        dequant_scale = torch.ones([M, moe_top_k], dtype=torch.float32, device=out.device)
-        output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
-        sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
-        torch.ops._C.moe_post(
-            x=out,
-            moe_index=sorted_tokens_idx,
-            normed_scale=normed_score,
-            dequant_scale=dequant_scale,
-            y=output
-        )
-        return output
+        router_logits = router_logits.to(torch.float)
+        if scoring_func == "softmax":
+            torch.ops._C.moe_softmax_topk_norm(
+                x=router_logits,
+                normed_score=normed_score,
+                topk_index=topk_ids,
+                block_statistic=None,
+                stable=True)
+        elif scoring_func == "sigmoid":
+            torch.ops._C.moe_sigmoid_group_topk_norm(
+                x=router_logits,
+                topk_index=topk_ids,
+                norm_score=normed_score,
+                block_static=block_statistic,
+                bias=e_score_correction_bias,
+                scale=1.0,
+                n_group=num_expert_group,
+                topk_group=topk_group,
+            )
+        if w1_bias is not None or w2_bias is not None:
+            # Right now this branch is for gpt-oss
+            # TODO (@xyDong23): faster here using moe_fc kernel
+            normed_score = normed_score.to(hidden_states.dtype)
+            out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device)
+            repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
+            topk_ids_flat = topk_ids.flatten()
+            for i in range(global_num_experts):
+                experts_id = ep_rank * global_num_experts + i
+                selected_token = topk_ids_flat == experts_id
+                if selected_token.sum():
+                    cur_token = repeat_x[selected_token]
+                    up_gate = torch.empty(selected_token.sum(), up_gate_size // 2,
+                                          dtype=cur_token.dtype, device=cur_token.device)
+                    groupgemm1 = cur_token @ w1[i].T
+                    # Add w13 bias
+                    if w1_bias is not None:
+                        groupgemm1 = groupgemm1 + w1_bias[i]
+                    up_gate = torch.ops._C.swigluoai_and_mul(groupgemm1)
+                    groupgemm2 = up_gate @ w2[i].T
+                    # Add w2 bias
+                    if w2_bias is not None:
+                        groupgemm2 = groupgemm2 + w2_bias[i]
+                    out[selected_token] = groupgemm2
+            output = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype)
+            return output
+        else:
+            moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device)  # [M*top_k, N], float
+            expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device)  # [E]
+            sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device)  # [E+1]
+            sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
+            torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
+            torch.ops._C.moe_pre_sorted(
+                x=hidden_states,
+                topk_index=topk_ids,
+                block_statistic=block_statistic,
+                moe_expand=moe_expand,
+                moe_index=sorted_tokens_idx,
+                expert_m=expert_m,
+                sorted_tokens_num_lod=sorted_tokens_num_lod)
+            y = torch.empty(M, moe_top_k,
+                            w1.shape[1],
+                            dtype=hidden_states.dtype,
+                            device=hidden_states.device)
+            moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
+            torch.ops._C.moe_fc(
+                x=moe_expand,
+                weight=w1,
+                sorted_tokens_num_lod=sorted_tokens_num_lod,
+                sorted_tokens_idx=sorted_tokens_idx,
+                moe_topk=moe_top_k,
+                y=y,
+            )
+            d = y.shape[-1] // 2
+            output_shape = (y.shape[:-1] + (d, ))
+            out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
+            torch.ops._C.silu_and_mul(out1, y)
+            out = torch.empty(M, moe_top_k,
+                              w2.shape[1],
+                              dtype=hidden_states.dtype,
+                              device=hidden_states.device)
+            out1 = out1.reshape(-1, out1.shape[-1])
+            torch.ops._C.moe_fc(
+                x=out1,
+                weight=w2,
+                sorted_tokens_num_lod=sorted_tokens_num_lod,
+                sorted_tokens_idx=sorted_tokens_idx,
+                moe_topk=moe_top_k,
+                y=out,
+            )
+            dequant_scale = torch.ones([M, moe_top_k], dtype=torch.float32, device=out.device)
+            output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
+            sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
+            torch.ops._C.moe_post(
+                x=out,
+                moe_index=sorted_tokens_idx,
+                normed_scale=normed_score,
+                dequant_scale=dequant_scale,
+                y=output
+            )
+            return output
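
The bias branch added above is a dense per-expert fallback: every selected (token, expert) pair goes through an explicit matmul pair instead of the sorted moe_fc pipeline. A pure-PyTorch sketch of the same computation, with stated simplifications — no expert parallelism, and a plain SiLU-gated SwiGLU standing in for the Kunlun swigluoai_and_mul kernel, whose exact gate/up interleaving and clamping this sketch does not reproduce:

    # Pure-PyTorch reference for the bias branch above (dense per-expert loop).
    import torch

    def moe_with_bias_reference(x, w1, w1_bias, w2, w2_bias, topk_ids, topk_weight):
        # x: [M, N]; w1: [E, 2*I, N]; w1_bias: [E, 2*I]; w2: [E, N, I]; w2_bias: [E, N]
        # topk_ids / topk_weight: [M, K] routing results from the top-k kernels above.
        M, N = x.shape
        E = w1.shape[0]
        K = topk_ids.shape[1]
        out = torch.zeros(M * K, N, dtype=x.dtype, device=x.device)
        rep_x = x.repeat_interleave(K, dim=0)  # one row per (token, expert) pair
        flat_ids = topk_ids.flatten()
        for e in range(E):  # with EP, e would be offset by ep_rank * E
            sel = flat_ids == e
            if sel.any():
                h = rep_x[sel] @ w1[e].T + w1_bias[e]    # gate/up projection + bias
                gate, up = h.chunk(2, dim=-1)
                h = torch.nn.functional.silu(gate) * up  # SwiGLU stand-in
                out[sel] = h @ w2[e].T + w2_bias[e]      # down projection + bias
        # Weight each expert output by its routing score and sum over the top-k.
        return (out.view(M, K, N) * topk_weight.unsqueeze(-1)).sum(dim=1)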
     @staticmethod
     def fused_moe_ep(

View File

@@ -108,6 +108,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
             layer.w2_weight,
             router_logits,
             linear_weights,
+            self.moe.ep_rank,
             top_k,
             renormalize=renormalize,
             inplace=True,
@@ -116,6 +117,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
             topk_group=topk_group,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
+            w1_bias=layer.w13_bias,
+            w2_bias=layer.w2_bias,
         )
 class FusedMoE(VllmFusedMoE):
@@ -144,6 +147,7 @@ class FusedMoE(VllmFusedMoE):
         enable_eplb: bool = False,
         num_redundant_experts: int = 0,
         is_sequence_parallel=False,
+        has_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,  # Global number of experts
@@ -186,10 +190,12 @@ class FusedMoE(VllmFusedMoE):
             moe_parallel_config=self.moe_parallel_config,
             in_dtype=model_dtype,
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
+            has_bias=has_bias,
             # quant_config=quant_config,
         )
         self.moe_config = moe
         self.quant_config = quant_config
+        self.has_bias = has_bias
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
View File

@@ -147,7 +147,6 @@ RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
 RotaryEmbedding.forward = vllm_kunlun_forward_cuda
 MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
 MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
-YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
 def Split_Norm_Rope(