diff --git a/README.md b/README.md
index 578a275..64e8424 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,22 @@ By utilizing the vLLM Kunlun plugin, popular open-source models, including Trans
+ | Qwen2 | ✅ | | ✅ | ✅ | |
+ | Qwen2.5 | ✅ | | ✅ | ✅ | |
| Qwen3 |
✅ |
@@ -77,6 +93,38 @@ By utilizing the vLLM Kunlun plugin, popular open-source models, including Trans
✅ |
|
+ | Llama2 | ✅ | | | ✅ | |
+ | Llama3 | ✅ | | | ✅ | |
+ | Llama3.1 | ✅ | | | ✅ | |
+ | gpt-oss | ✅ | | | | |
diff --git a/vllm_kunlun/models/__init__.py b/vllm_kunlun/models/__init__.py
index e7b953a..5dd90ec 100644
--- a/vllm_kunlun/models/__init__.py
+++ b/vllm_kunlun/models/__init__.py
@@ -76,6 +76,10 @@ def register_model():
ModelRegistry.register_model(
"MiMoV2FlashForCausalLM",
"vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM")
+
+ ModelRegistry.register_model(
+ "GptOssForCausalLM",
+ "vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
def register_quant_method():
"""to do"""
diff --git a/vllm_kunlun/models/gpt_oss.py b/vllm_kunlun/models/gpt_oss.py
index 2f5d9dd..532718d 100644
--- a/vllm_kunlun/models/gpt_oss.py
+++ b/vllm_kunlun/models/gpt_oss.py
@@ -8,12 +8,15 @@ import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig
-from vllm.attention import Attention, AttentionType
+from vllm.attention import AttentionType
+from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size)
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.distributed import (get_ep_group, get_pp_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_gather)
+from vllm_kunlun.ops.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
@@ -23,12 +26,16 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
-from .utils import extract_layer_index, maybe_prefix
-
+from vllm.model_executor.models.interfaces import SupportsEagle3, SupportsPP
+from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
+from vllm_kunlun.ops.activation import SiluAndMul
class OAIAttention(nn.Module):
@@ -71,11 +78,8 @@ class OAIAttention(nn.Module):
self.sinks = torch.nn.Parameter(
torch.empty(config.num_attention_heads // tp_size,
- dtype=torch.bfloat16,
requires_grad=False))
- self.norm = RMSNorm(config.hidden_size, eps=1e-5)
-
self.q_size = self.num_attention_heads * self.head_dim // tp_size
self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
self.scaling = self.head_dim**-0.5
@@ -118,36 +122,37 @@ class OAIAttention(nn.Module):
def forward(self, hidden_states: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
- t = self.norm(hidden_states)
-
- qkv, _ = self.qkv(t)
+ qkv, _ = self.qkv(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
v = v.contiguous()
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
-
- return output + hidden_states
+ return output
class MLPBlock(torch.nn.Module):
def __init__(
self,
- config: GptOssConfig,
+ vllm_config: VllmConfig,
layer_idx: int,
- quant_config: QuantizationConfig,
prefix: str = "",
):
super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ parallel_config = vllm_config.parallel_config
+
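+ # Sequence-parallel MoE splits the token dimension across TP ranks for the
+ # expert computation; see forward() below for the chunk/all-gather handling.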
+ self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
self.layer_idx = layer_idx
self.num_experts = config.num_local_experts
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
- self.norm = RMSNorm(config.hidden_size, eps=1e-5)
self.router = torch.nn.Linear(config.hidden_size,
- config.num_local_experts,
- dtype=torch.bfloat16)
+ config.num_local_experts)
assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
@@ -159,36 +164,67 @@ class MLPBlock(torch.nn.Module):
prefix=f"{prefix}.experts",
apply_router_weight_on_input=False,
has_bias=True,
- activation="swigluoai")
+ activation="swigluoai",
+ is_sequence_parallel=self.is_sequence_parallel)
+
+ self.register_buffer("kunlun_linear_weights", torch.zeros(
+ config.num_local_experts,config.hidden_size,dtype=torch.float32))
def forward(self, x: torch.Tensor) -> torch.Tensor:
- t = self.norm(x)
- g = self.router(t)
- t = self.experts(hidden_states=t, router_logits=g)
- return x + t
+ num_tokens = x.shape[0]
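+ # Under sequence parallelism each rank only computes its chunk of tokens;
+ # the full sequence is rebuilt by the all-gather below and trimmed back to
+ # the original token count.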
+ if self.is_sequence_parallel:
+ x = sequence_parallel_chunk(x)
+
+ g = self.router(x)
+ x = self.experts(hidden_states=x, router_logits=g, linear_weights=self.router.weight)
+
+ if self.is_sequence_parallel:
+ x = tensor_model_parallel_all_gather(x.contiguous(), 0)
+ x = x[:num_tokens]
+ return x
class TransformerBlock(torch.nn.Module):
def __init__(
self,
- config: GptOssConfig,
- quant_config: QuantizationConfig,
+ vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
- self.layer_idx = extract_layer_index(prefix)
- self.attn = OAIAttention(config, prefix=f"{prefix}.attn")
- self.mlp = MLPBlock(config,
- self.layer_idx,
- quant_config=quant_config,
- prefix=f"{prefix}.mlp")
- def forward(self, hidden_states: torch.Tensor,
- positions: torch.Tensor) -> torch.Tensor:
- attn_output = self.attn(hidden_states, positions)
- output = self.mlp(attn_output)
- return output
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+
+ self.layer_idx = extract_layer_index(prefix)
+ self.attn = OAIAttention(config,
+ prefix=f"{prefix}.attn",
+ cache_config=cache_config)
+ self.mlp = MLPBlock(vllm_config,
+ self.layer_idx,
+ prefix=f"{prefix}.mlp")
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ positions: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Self Attention
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+ hidden_states = self.attn(hidden_states, positions)
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
+ output = self.mlp(hidden_states)
+ return output, residual
@support_torch_compile
@@ -202,87 +238,86 @@ class GptOssModel(nn.Module):
):
super().__init__()
self.config = vllm_config.model_config.hf_config
- self.quant_config = vllm_config.quant_config
+ self.parallel_config = vllm_config.parallel_config
self.config.hidden_size = self.config.hidden_size
self.embedding = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
)
- self.layers = torch.nn.ModuleList([
- TransformerBlock(
- self.config,
- quant_config=self.quant_config,
- prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
- ) for layer_idx in range(self.config.num_hidden_layers)
- ])
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ self.config.num_hidden_layers,
+ lambda prefix: TransformerBlock(
+ vllm_config,
+ prefix=prefix,
+ ),
+ prefix=f"{prefix}.layers",
+ )
self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
+ self.make_empty_intermediate_tensors = (
+ make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], self.config.hidden_size))
+ self.aux_hidden_state_layers = tuple[int, ...]()
- def forward(self, input_ids: torch.Tensor,
- positions: torch.Tensor) -> torch.Tensor:
- x = self.embedding(input_ids)
- for layer in self.layers:
- x = layer(x, positions)
- x = self.norm(x)
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embedding(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
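+ # With pipeline parallelism only the first rank embeds the input tokens;
+ # later ranks resume from the hidden_states/residual received in
+ # intermediate_tensors.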
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ x = inputs_embeds
+ else:
+ x = self.get_input_embeddings(input_ids)
+
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ x = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ aux_hidden_states = []
+ for i in range(self.start_layer, self.end_layer):
+ layer = self.layers[i]
+ if i in self.aux_hidden_state_layers:
+ aux_hidden_states.append(x if residual is None else x +
+ residual)
+ x, residual = layer(x, positions, residual)
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": x,
+ "residual": residual
+ })
+ x, _ = self.norm(x, residual)
+
+ if len(aux_hidden_states) > 0:
+ return x, aux_hidden_states
return x
-
-class GptOssForCausalLM(nn.Module):
-
- def __init__(
- self,
- vllm_config: VllmConfig,
- prefix: str = "",
- ):
- super().__init__()
- self.vllm_config = vllm_config
- self.model_config = vllm_config.model_config.hf_config
- self.model = GptOssModel(
- vllm_config=vllm_config,
- prefix=maybe_prefix(prefix, "model"),
- )
- self.lm_head = ParallelLMHead(
- self.model_config.vocab_size,
- self.model_config.hidden_size,
- )
- self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
-
- def forward(self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: Optional[IntermediateTensors] = None,
- inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
- assert intermediate_tensors is None
- assert inputs_embeds is None
- return self.model(input_ids, positions)
-
- def compute_logits(self, hidden_states: torch.Tensor,
- sampling_metadata: SamplingMetadata) -> torch.Tensor:
- logits = self.logits_processor(self.lm_head, hidden_states,
- sampling_metadata)
- return logits
-
def _load_weights_mxfp4(
- self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- rename_mapping = {
- "self_attn": "attn",
- "input_layernorm.weight": "attn.norm.weight",
- "post_attention_layernorm.weight": "mlp.norm.weight",
- "embed_tokens": "embedding",
- }
-
- def maybe_rename(name: str) -> str:
- for remap_name, new_name in rename_mapping.items():
- if remap_name in name:
- return name.replace(remap_name, new_name)
- return name
-
+ self,
+ ep_rank_start: int,
+ ep_rank_end: int,
+ heads_per_rank: int,
+ head_start: int,
+ weights: Iterable[tuple[str, torch.Tensor]],
+ stacked_params_mapping: list[tuple[str, ...]],
+ ) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
+
mxfp4_block = 32
+ use_ep = self.parallel_config.enable_expert_parallel
+ num_experts = self.config.num_local_experts
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
- intermediate_size = self.model_config.intermediate_size
+
+ intermediate_size = self.config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
tp_size)
@@ -294,26 +329,54 @@ class GptOssForCausalLM(nn.Module):
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
- # Attention heads per rank
- heads_per_rank = self.model_config.num_attention_heads // tp_size
- head_start = tp_rank * heads_per_rank
-
- use_ep = self.vllm_config.parallel_config.enable_expert_parallel
- ep_size = get_ep_group().world_size
- ep_rank = get_ep_group().rank
- num_experts = self.model_config.num_local_experts
- experts_per_rank = num_experts // ep_size
- ep_rank_start = ep_rank * experts_per_rank
- ep_rank_end = (ep_rank + 1) * experts_per_rank
-
for name, weight in weights:
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+
# FIXME(woosuk): Remove this after testing.
weight = weight.cuda()
- if "gate_up_proj_blocks" in name:
- # Handle MLP gate and up projection weights
- new_name = name.replace("gate_up_proj_blocks", "w13_weight")
+ if ".w13_weight_scale" in name:
+ # Handle MLP gate and up projection weights scale
+ if use_ep:
+ narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+ else:
+ narrow_weight = weight[:,
+ 2 * tp_rank_start:2 * tp_rank_end,
+ ...]
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param,
+ narrow_weight,
+ weight_name=name,
+ shard_id=None,
+ expert_id=None)
+ loaded_params.add(name)
+ continue
+ elif ".w2_weight_scale" in name:
+ # Handle MLP down projection weights
+ if use_ep:
+ narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+ else:
+ narrow_weight = weight[..., tp_rank_start //
+ mxfp4_block:tp_rank_end //
+ mxfp4_block]
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param,
+ narrow_weight,
+ weight_name=name,
+ shard_id=None,
+ expert_id=None)
+ loaded_params.add(name)
+ continue
+ elif ".w13_weight" in name:
+ # Handle MLP gate and up projection weights
# flat weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
weight = weight.view(num_experts, 2 * intermediate_size,
@@ -328,19 +391,18 @@ class GptOssForCausalLM(nn.Module):
2 * tp_rank_start:2 * tp_rank_end,
...]
- param = params_dict[new_name]
+ param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
- weight_name=new_name,
+ weight_name=name,
shard_id=None,
expert_id=None)
- loaded_params.add(new_name)
-
- elif "down_proj_blocks" in name:
+ loaded_params.add(name)
+ continue
+ elif ".w2_weight" in name:
# Handle MLP down projection weights
- new_name = name.replace("down_proj_blocks", "w2_weight")
# same flatten here, but since 2 mx4 value are packed in 1
# uint8, divide by 2
weight = weight.view(num_experts, -1,
@@ -351,60 +413,18 @@ class GptOssForCausalLM(nn.Module):
narrow_weight = weight[...,
tp_rank_start // 2:tp_rank_end // 2]
- param = params_dict[new_name]
+ param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
- weight_name=new_name,
+ weight_name=name,
shard_id=None,
expert_id=None)
- loaded_params.add(new_name)
-
- elif "gate_up_proj_scales" in name:
- # Handle MLP gate and up projection weights scale
- new_name = name.replace("gate_up_proj_scales",
- "w13_weight_scale")
- if use_ep:
- narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
- else:
- narrow_weight = weight[:,
- 2 * tp_rank_start:2 * tp_rank_end,
- ...]
-
- param = params_dict[new_name]
- weight_loader = getattr(param, "weight_loader",
- default_weight_loader)
- weight_loader(param,
- narrow_weight,
- weight_name=new_name,
- shard_id=None,
- expert_id=None)
- loaded_params.add(new_name)
-
- elif "down_proj_scales" in name:
- # Handle MLP down projection weights
- new_name = name.replace("down_proj_scales", "w2_weight_scale")
- if use_ep:
- narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
- else:
- narrow_weight = weight[..., tp_rank_start //
- mxfp4_block:tp_rank_end //
- mxfp4_block]
-
- param = params_dict[new_name]
- weight_loader = getattr(param, "weight_loader",
- default_weight_loader)
- weight_loader(param,
- narrow_weight,
- weight_name=new_name,
- shard_id=None,
- expert_id=None)
- loaded_params.add(new_name)
- elif "gate_up_proj_bias" in name:
+ loaded_params.add(name)
+ continue
+ elif ".w13_bias" in name:
# Handle MLP gate and up projection biases
- new_name = name.replace("gate_up_proj_bias", "w13_bias")
-
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
@@ -412,20 +432,19 @@ class GptOssForCausalLM(nn.Module):
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
- param = params_dict[new_name]
+ param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
- weight_name=new_name,
+ weight_name=name,
shard_id=None,
expert_id=None)
- loaded_params.add(new_name)
-
- elif "down_proj_bias" in name:
+ loaded_params.add(name)
+ continue
+ elif ".w2_bias" in name:
# Handle MLP down projection bias
- new_name = name.replace("down_proj_bias", "w2_bias")
- param = params_dict[new_name]
+ param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if use_ep:
@@ -436,87 +455,73 @@ class GptOssForCausalLM(nn.Module):
weight.zero_()
weight_loader(param,
weight,
- weight_name=new_name,
+ weight_name=name,
shard_id=None,
expert_id=None)
- loaded_params.add(new_name)
+ loaded_params.add(name)
+ continue
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
- name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
- elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
- shard_id = ("q" if "q_proj" in name else
- "k" if "k_proj" in name else "v")
- name = name.replace("self_attn", "attn")
- param_name = name.replace(f"{shard_id}_proj", "qkv")
- param = params_dict[param_name]
- weight_loader = param.weight_loader
- weight_loader(param, weight, loaded_shard_id=shard_id)
- loaded_params.add(param_name)
+ continue
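+ # q/k/v projections are fused into the stacked qkv parameter, so map each
+ # per-projection weight onto its shard of the fused tensor.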
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ if weight_loader == default_weight_loader:
+ weight_loader(param, weight)
+ else:
+ weight_loader(param, weight, shard_id)
+ loaded_params.add(name)
+ break
else:
# Handle all other weights with potential renaming
- renamed_name = maybe_rename(name)
- if renamed_name not in params_dict:
+ if name not in params_dict:
continue
- param = params_dict[renamed_name]
+ param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
- loaded_params.add(renamed_name)
-
+ loaded_params.add(name)
return loaded_params
def _load_weights_other(
- self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- rename_mapping = {
- "self_attn": "attn",
- "input_layernorm.weight": "attn.norm.weight",
- "post_attention_layernorm.weight": "mlp.norm.weight",
- "embed_tokens": "embedding",
- }
-
- def maybe_rename(name: str) -> str:
- for remap_name, new_name in rename_mapping.items():
- if remap_name in name:
- return name.replace(remap_name, new_name)
- return name
-
+ self,
+ ep_rank_start: int,
+ ep_rank_end: int,
+ heads_per_rank: int,
+ head_start: int,
+ weights: Iterable[tuple[str, torch.Tensor]],
+ stacked_params_mapping: list[tuple[str, ...]],
+ ) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
+ use_ep = self.parallel_config.enable_expert_parallel
+
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
- intermediate_size = self.model_config.intermediate_size
+ intermediate_size = self.config.intermediate_size
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
- # Attention heads per rank
- heads_per_rank = self.model_config.num_attention_heads // tp_size
- head_start = tp_rank * heads_per_rank
-
- use_ep = self.vllm_config.parallel_config.enable_expert_parallel
- ep_size = get_ep_group().world_size
- ep_rank = get_ep_group().rank
- num_experts = self.model_config.num_local_experts
- experts_per_rank = num_experts // ep_size
- ep_rank_start = ep_rank * experts_per_rank
- ep_rank_end = (ep_rank + 1) * experts_per_rank
-
for name, weight in weights:
- if ".experts.gate_up_proj" in name and "bias" not in name:
- # Handle MLP gate and up projection weights
- new_name = name.replace(".experts.gate_up_proj",
- ".experts.w13_weight")
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+ if ".w13_weight" in name:
+ # Handle MLP gate and up projection weights
# Extract gate and up projection parts
- # since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
@@ -524,30 +529,25 @@ class GptOssForCausalLM(nn.Module):
2 * tp_rank_start:2 * tp_rank_end]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
- param = params_dict[new_name]
+ param = params_dict[name]
param.copy_(narrow_weight)
- loaded_params.add(new_name)
-
- elif ".experts.down_proj" in name and "bias" not in name:
+ loaded_params.add(name)
+ continue
+ elif ".w2_weight" in name:
# Handle MLP down projection weights
- new_name = name.replace(".experts.down_proj",
- ".experts.w2_weight")
-
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
- param = params_dict[new_name]
+ param = params_dict[name]
param.copy_(narrow_weight)
- loaded_params.add(new_name)
-
- elif "gate_up_proj_bias" in name:
+ loaded_params.add(name)
+ continue
+ elif ".w13_bias" in name:
# Handle MLP gate and up projection biases
- new_name = name.replace("gate_up_proj_bias", "w13_bias")
-
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
@@ -555,60 +555,162 @@ class GptOssForCausalLM(nn.Module):
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
- param = params_dict[new_name]
-
+ param = params_dict[name]
param.copy_(narrow_weight)
- loaded_params.add(new_name)
-
- elif "down_proj_bias" in name:
+ loaded_params.add(name)
+ continue
+ elif ".w2_bias" in name:
# Handle MLP down projection bias
- new_name = name.replace("down_proj_bias", "w2_bias")
-
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
- param = params_dict[new_name]
+ param = params_dict[name]
param.copy_(weight)
- loaded_params.add(new_name)
+ loaded_params.add(name)
+ continue
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
- name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
- elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
- shard_id = ("q" if "q_proj" in name else
- "k" if "k_proj" in name else "v")
- name = name.replace("self_attn", "attn")
- param_name = name.replace(f"{shard_id}_proj", "qkv")
- param = params_dict[param_name]
- weight_loader = param.weight_loader
- weight_loader(param, weight, loaded_shard_id=shard_id)
- loaded_params.add(param_name)
+ continue
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ if weight_loader == default_weight_loader:
+ weight_loader(param, weight)
+ else:
+ weight_loader(param, weight, shard_id)
+ loaded_params.add(name)
+ break
else:
# Handle all other weights with potential renaming
-
- renamed_name = maybe_rename(name)
- if renamed_name not in params_dict:
+ if name not in params_dict:
continue
- param = params_dict[renamed_name]
+ param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
- loaded_params.add(renamed_name)
-
+ loaded_params.add(name)
return loaded_params
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
- quant_method = (self.model_config.quantization_config['quant_method']
- if hasattr(self.model_config, "quantization_config")
- else None)
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".qkv", ".q_proj", "q"),
+ (".qkv", ".k_proj", "k"),
+ (".qkv", ".v_proj", "v"),
+ ]
+
+ tp_rank = get_tensor_model_parallel_rank()
+ tp_size = get_tensor_model_parallel_world_size()
+
+ # Attention heads per rank
+ heads_per_rank = self.config.num_attention_heads // tp_size
+ head_start = tp_rank * heads_per_rank
+
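+ # Expert-parallel slice: each EP rank loads the contiguous block of experts
+ # [ep_rank_start, ep_rank_end) from the global expert dimension.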
+ ep_size = get_ep_group().world_size
+ ep_rank = get_ep_group().rank
+ num_experts = self.config.num_local_experts
+ experts_per_rank = num_experts // ep_size
+ ep_rank_start = ep_rank * experts_per_rank
+ ep_rank_end = (ep_rank + 1) * experts_per_rank
+
+ quant_method = (self.config.quantization_config['quant_method'] if
+ hasattr(self.config, "quantization_config") else None)
if quant_method == "mxfp4":
- return self._load_weights_mxfp4(weights)
+ return self._load_weights_mxfp4(ep_rank_start, ep_rank_end,
+ heads_per_rank, head_start,
+ weights, stacked_params_mapping)
else:
- return self._load_weights_other(weights)
+ return self._load_weights_other(ep_rank_start, ep_rank_end,
+ heads_per_rank, head_start,
+ weights, stacked_params_mapping)
+
+
+class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
+ packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_substr={
+ ".self_attn.": ".attn.",
+ },
+ orig_to_new_suffix={
+ ".embed_tokens.weight": ".embedding.weight",
+
+ # MoE MXFP4 weights
+ ".gate_up_proj_blocks": ".w13_weight",
+ ".down_proj_blocks": ".w2_weight",
+ ".gate_up_proj_scales": ".w13_weight_scale",
+ ".down_proj_scales": ".w2_weight_scale",
+
+ # MoE other weights
+ ".gate_up_proj": ".w13_weight",
+ ".down_proj": ".w2_weight",
+
+ # MoE Bias
+ ".gate_up_proj_bias": ".w13_bias",
+ ".down_proj_bias": ".w2_bias",
+ },
+ )
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.vllm_config = vllm_config
+ self.config = vllm_config.model_config.hf_config
+
+ self.model = GptOssModel(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"),
+ )
+ self.lm_head = ParallelLMHead(
+ self.config.vocab_size,
+ self.config.hidden_size,
+ prefix=maybe_prefix(prefix, "lm_head"),
+ )
+ self.logits_processor = LogitsProcessor(self.config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+
+ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+ self.model.aux_hidden_state_layers = layers
+
+ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+ num_layers = len(self.model.layers)
+ return (2, num_layers // 2, num_layers - 3)
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embeddings(input_ids)
+
+ def forward(self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
+ return self.model(input_ids, positions, intermediate_tensors,
+ inputs_embeds)
+
+ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ logits = self.logits_processor(self.lm_head, hidden_states)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=(["lm_head."]
+ if self.config.tie_word_embeddings else None),
+ )
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm_kunlun/ops/_kunlun_ops.py b/vllm_kunlun/ops/_kunlun_ops.py
index 0849432..6250964 100644
--- a/vllm_kunlun/ops/_kunlun_ops.py
+++ b/vllm_kunlun/ops/_kunlun_ops.py
@@ -418,6 +418,7 @@ class KunlunOps:
w2: torch.Tensor,
router_logits: torch.Tensor,
linear_weights: torch.Tensor,
+ ep_rank: int,
moe_top_k: int,
renormalize: bool,
inplace: bool = False,
@@ -430,7 +431,7 @@ class KunlunOps:
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""fused_moe"""
- global_num_experts = linear_weights.shape[0]
+ global_num_experts = w1.shape[0]
M, N = hidden_states.shape
hidden_dim = w2.shape[1]
normed_score = torch.empty(M,
@@ -445,82 +446,119 @@ class KunlunOps:
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
)
-
- torch.ops._C.moe_sigmoid_group_topk_norm(
+ router_logits = router_logits.to(torch.float)
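+ # Pick the routing kernel by scoring function: plain softmax top-k, or
+ # sigmoid grouped top-k with optional group restriction and score bias.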
+ if scoring_func == "softmax":
+ torch.ops._C.moe_softmax_topk_norm(
x=router_logits,
+ normed_score=normed_score,
topk_index=topk_ids,
- norm_score=normed_score,
- block_static=block_statistic,
- bias=e_score_correction_bias,
- scale=1.0,
- n_group=num_expert_group,
- topk_group=1,
+ block_statistic=None,
+ stable=True)
+ elif scoring_func == "sigmoid":
+ torch.ops._C.moe_sigmoid_group_topk_norm(
+ x=router_logits,
+ topk_index=topk_ids,
+ norm_score=normed_score,
+ block_static=block_statistic,
+ bias=e_score_correction_bias,
+ scale=1.0,
+ n_group=num_expert_group,
+ topk_group=topk_group,
+ )
+
+ if w1_bias is not None or w2_bias is not None:
+ # Right now this branch is only used for gpt-oss
+ # TODO (@xyDong23): speed this up using the moe_fc kernel
+ normed_score = normed_score.to(hidden_states.dtype)
+ out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device)
+ repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
+ topk_ids_flat = topk_ids.flatten()
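+ # Naive per-expert loop: gather the tokens routed to each local expert,
+ # apply the up/gate GEMM + bias, SwiGLU(OAI), then the down GEMM + bias,
+ # and combine the top-k expert outputs weighted by the normalized scores.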
+ for i in range(global_num_experts):
+ experts_id = ep_rank * global_num_experts + i
+ selected_token = topk_ids_flat == experts_id
+ if selected_token.sum():
+ cur_token = repeat_x[selected_token]
+ groupgemm1 = cur_token @ w1[i].T
+ # Add w13 bias
+ if w1_bias is not None:
+ groupgemm1 = groupgemm1 + w1_bias[i]
+ up_gate = torch.ops._C.swigluoai_and_mul(groupgemm1)
+ groupgemm2 = up_gate @ w2[i].T
+ # Add w2 bias
+ if w2_bias is not None:
+ groupgemm2 = groupgemm2 + w2_bias[i]
+ out[selected_token] = groupgemm2
+ output = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype)
+ return output
+ else:
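+ # Fused path: sort tokens by expert, run grouped GEMMs (moe_fc) for the
+ # up/gate and down projections with a SiLU-and-mul in between, then scatter
+ # the results back and apply the routing weights in moe_post.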
+ moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
+ expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
+ sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
+ sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
+
+ torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
+
+ torch.ops._C.moe_pre_sorted(
+ x=hidden_states,
+ topk_index=topk_ids,
+ block_statistic=block_statistic,
+ moe_expand=moe_expand,
+ moe_index=sorted_tokens_idx,
+ expert_m=expert_m,
+ sorted_tokens_num_lod=sorted_tokens_num_lod)
+
+ y = torch.empty(M, moe_top_k,
+ w1.shape[1],
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
+
+ torch.ops._C.moe_fc(
+ x=moe_expand,
+ weight=w1,
+ sorted_tokens_num_lod=sorted_tokens_num_lod,
+ sorted_tokens_idx=sorted_tokens_idx,
+ moe_topk=moe_top_k,
+ y=y,
)
- moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
- expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
- sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
- sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
-
- torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
+ d = y.shape[-1] // 2
+ output_shape = (y.shape[:-1] + (d, ))
+ out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
+ torch.ops._C.silu_and_mul(out1, y)
+
+ out = torch.empty(M, moe_top_k,
+ w2.shape[1],
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
- torch.ops._C.moe_pre_sorted(
- x=hidden_states,
- topk_index=topk_ids,
- block_statistic=block_statistic,
- moe_expand=moe_expand,
- moe_index=sorted_tokens_idx,
- expert_m=expert_m,
- sorted_tokens_num_lod=sorted_tokens_num_lod)
+ out1 = out1.reshape(-1, out1.shape[-1])
- y = torch.empty(M,moe_top_k,
- w1.shape[1],
- dtype=hidden_states.dtype,
- device=hidden_states.device)
+ torch.ops._C.moe_fc(
+ x=out1,
+ weight=w2,
+ sorted_tokens_num_lod=sorted_tokens_num_lod,
+ sorted_tokens_idx=sorted_tokens_idx,
+ moe_topk=moe_top_k,
+ y=out,
+ )
- moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
+ dequant_scale = torch.ones([M, moe_top_k], dtype=torch.float32, device=out.device)
+ output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
+ sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
- torch.ops._C.moe_fc(
- x=moe_expand,
- weight=w1,
- sorted_tokens_num_lod=sorted_tokens_num_lod,
- sorted_tokens_idx=sorted_tokens_idx,
- moe_topk=moe_top_k,
- y=y)
-
- d = y.shape[-1] // 2
- output_shape = (y.shape[:-1] + (d, ))
- out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
- torch.ops._C.silu_and_mul(out1, y)
-
- out = torch.empty(M,moe_top_k,
- w2.shape[1],
- dtype=hidden_states.dtype,
- device=hidden_states.device)
-
- out1 = out1.reshape(-1, out1.shape[-1])
-
- torch.ops._C.moe_fc(
- x=out1,
- weight=w2,
- sorted_tokens_num_lod=sorted_tokens_num_lod,
- sorted_tokens_idx=sorted_tokens_idx,
- moe_topk=moe_top_k,
- y=out)
-
- dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
- output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
- sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
-
- torch.ops._C.moe_post(
- x=out,
- moe_index=sorted_tokens_idx,
- normed_scale=normed_score,
- dequant_scale=dequant_scale,
- y=output
- )
-
- return output
+ torch.ops._C.moe_post(
+ x=out,
+ moe_index=sorted_tokens_idx,
+ normed_scale=normed_score,
+ dequant_scale=dequant_scale,
+ y=output
+ )
+
+ return output
@staticmethod
def fused_moe_ep(
diff --git a/vllm_kunlun/ops/fused_moe/layer.py b/vllm_kunlun/ops/fused_moe/layer.py
index b68a627..84cbd36 100644
--- a/vllm_kunlun/ops/fused_moe/layer.py
+++ b/vllm_kunlun/ops/fused_moe/layer.py
@@ -108,6 +108,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
layer.w2_weight,
router_logits,
linear_weights,
+ self.moe.ep_rank,
top_k,
renormalize=renormalize,
inplace=True,
@@ -116,6 +117,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
+ # w13_bias / w2_bias may be absent when the layer was built without has_bias
+ w1_bias=getattr(layer, "w13_bias", None),
+ w2_bias=getattr(layer, "w2_bias", None),
)
class FusedMoE(VllmFusedMoE):
@@ -144,6 +147,7 @@ class FusedMoE(VllmFusedMoE):
enable_eplb: bool = False,
num_redundant_experts: int = 0,
is_sequence_parallel=False,
+ has_bias: bool = False,
):
super().__init__(
num_experts=num_experts, # Global number of experts
@@ -186,10 +190,12 @@ class FusedMoE(VllmFusedMoE):
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
+ has_bias=has_bias,
# quant_config=quant_config,
)
self.moe_config = moe
self.quant_config = quant_config
+ self.has_bias = has_bias
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
diff --git a/vllm_kunlun/ops/rotary_embedding.py b/vllm_kunlun/ops/rotary_embedding.py
index a151568..bbe32a6 100644
--- a/vllm_kunlun/ops/rotary_embedding.py
+++ b/vllm_kunlun/ops/rotary_embedding.py
@@ -147,7 +147,6 @@ RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
-YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
def Split_Norm_Rope(