diff --git a/README.md b/README.md
index 578a275..64e8424 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,22 @@ By utilizing the vLLM Kunlun plugin, popular open-source models, including Trans
+
+Qwen2
+✅
+
+✅
+✅
+
+
+
+Qwen2.5
+✅
+
+✅
+✅
+
+
 Qwen3
 ✅
@@ -77,6 +93,38 @@ By utilizing the vLLM Kunlun plugin, popular open-source models, including Trans
 ✅
+
+Llama2
+✅
+
+
+✅
+
+
+
+Llama3
+✅
+
+
+✅
+
+
+
+Llama3.1
+✅
+
+
+✅
+
+
+
+gpt-oss
+✅
+
+
+
+
+
diff --git a/vllm_kunlun/models/__init__.py b/vllm_kunlun/models/__init__.py
index e7b953a..5dd90ec 100644
--- a/vllm_kunlun/models/__init__.py
+++ b/vllm_kunlun/models/__init__.py
@@ -76,6 +76,10 @@ def register_model():
     ModelRegistry.register_model(
         "MiMoV2FlashForCausalLM",
         "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM")
+
+    ModelRegistry.register_model(
+        "GptOssForCausalLM",
+        "vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
 
 def register_quant_method():
     """to do"""
diff --git a/vllm_kunlun/models/gpt_oss.py b/vllm_kunlun/models/gpt_oss.py
index 2f5d9dd..532718d 100644
--- a/vllm_kunlun/models/gpt_oss.py
+++ b/vllm_kunlun/models/gpt_oss.py
@@ -8,12 +8,15 @@ import torch.distributed as dist
 from torch import nn
 from transformers import GptOssConfig
 
-from vllm.attention import Attention, AttentionType
+from vllm.attention import AttentionType
+from vllm_kunlun.ops.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.distributed import (get_ep_group, get_pp_group,
+                              get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm_kunlun.ops.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
@@ -23,12 +26,16 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv
 
-from .utils import extract_layer_index, maybe_prefix
-
+from vllm.model_executor.models.interfaces import SupportsEagle3, SupportsPP
+from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
+                                              is_pp_missing_parameter,
+                                              make_empty_intermediate_tensors_factory, make_layers,
+                                              maybe_prefix)
+from vllm_kunlun.ops.activation import SiluAndMul
 
 class OAIAttention(nn.Module):
@@ -71,11 +78,8 @@ class OAIAttention(nn.Module):
         self.sinks = torch.nn.Parameter(
             torch.empty(config.num_attention_heads // tp_size,
-                        dtype=torch.bfloat16,
                         requires_grad=False))
-
-        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
-
         self.q_size = self.num_attention_heads * self.head_dim // tp_size
         self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
         self.scaling = self.head_dim**-0.5
@@ -118,36 +122,37 @@ class OAIAttention(nn.Module):
 
     def forward(self, hidden_states: torch.Tensor,
                 positions: torch.Tensor) -> torch.Tensor:
-        t = self.norm(hidden_states)
-
-        qkv, _ = self.qkv(t)
+        qkv, _ = self.qkv(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                             dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         v = v.contiguous()
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
-
-        return output + hidden_states
+        return output
 
 
 class MLPBlock(torch.nn.Module):
 
     def __init__(
         self,
-        config: GptOssConfig,
+        vllm_config: VllmConfig,
         layer_idx: int,
-        quant_config: QuantizationConfig,
         prefix: str = "",
     ):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+
+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
         self.layer_idx = layer_idx
         self.num_experts = config.num_local_experts
         self.experts_per_token = config.num_experts_per_tok
         self.world_size = dist.get_world_size() if dist.is_initialized() else 1
-        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
         self.router = torch.nn.Linear(config.hidden_size,
-                                      config.num_local_experts,
-                                      dtype=torch.bfloat16)
+                                      config.num_local_experts)
         assert config.intermediate_size % self.world_size == 0
         self.experts = FusedMoE(num_experts=config.num_local_experts,
                                 top_k=config.num_experts_per_tok,
@@ -159,36 +164,67 @@ class MLPBlock(torch.nn.Module):
                                 prefix=f"{prefix}.experts",
                                 apply_router_weight_on_input=False,
                                 has_bias=True,
-                                activation="swigluoai")
+                                activation="swigluoai",
+                                is_sequence_parallel=self.is_sequence_parallel)
+
+        self.register_buffer(
+            "kunlun_linear_weights",
+            torch.zeros(config.num_local_experts, config.hidden_size,
+                        dtype=torch.float32))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        t = self.norm(x)
-        g = self.router(t)
-        t = self.experts(hidden_states=t, router_logits=g)
-        return x + t
+        num_tokens = x.shape[0]
+        if self.is_sequence_parallel:
+            x = sequence_parallel_chunk(x)
+
+        g = self.router(x)
+        x = self.experts(hidden_states=x, router_logits=g,
+                         linear_weights=self.router.weight)
+
+        if self.is_sequence_parallel:
+            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
+            x = x[:num_tokens]
+        return x
 
 
 class TransformerBlock(torch.nn.Module):
 
     def __init__(
         self,
-        config: GptOssConfig,
-        quant_config: QuantizationConfig,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
         super().__init__()
-        self.layer_idx = extract_layer_index(prefix)
-        self.attn = OAIAttention(config, prefix=f"{prefix}.attn")
-        self.mlp = MLPBlock(config,
-                            self.layer_idx,
-                            quant_config=quant_config,
-                            prefix=f"{prefix}.mlp")
-
-    def forward(self, hidden_states: torch.Tensor,
-                positions: torch.Tensor) -> torch.Tensor:
-        attn_output = self.attn(hidden_states, positions)
-        output = self.mlp(attn_output)
-        return output
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+
+        self.layer_idx = extract_layer_index(prefix)
+        self.attn = OAIAttention(config,
+                                 prefix=f"{prefix}.attn",
+                                 cache_config=cache_config)
+        self.mlp = MLPBlock(vllm_config,
+                            self.layer_idx,
+                            prefix=f"{prefix}.mlp")
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        positions: torch.Tensor,
+        residual: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.attn(hidden_states, positions)
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        output = self.mlp(hidden_states)
+        return output, residual
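Reviewer note: the attention and MLP blocks above drop their private pre-norms, and TransformerBlock now threads the residual through vLLM's fused add-RMSNorm contract, where calling the norm with a residual performs the add first and returns both tensors. A minimal pure-PyTorch sketch of that contract (RMSNormRef is a hypothetical stand-in, not the vLLM kernel):

```python
import torch
from torch import nn


class RMSNormRef(nn.Module):
    """Reference RMSNorm with the optional fused-residual contract.

    norm(x) returns the normalized tensor; norm(x, residual) first adds
    the residual, then returns (normed, x_plus_residual) so the caller
    can thread the running residual through the block, as
    TransformerBlock.forward does above.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + self.eps)).type_as(x) * self.weight

    def forward(self, x: torch.Tensor, residual: torch.Tensor = None):
        if residual is None:
            return self._norm(x)
        x = x + residual  # fused add: new residual is the pre-norm sum
        return self._norm(x), x
```

This is also why GptOssModel.forward below finishes with `x, _ = self.norm(x, residual)`: the final norm folds in the last residual before the LM head.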
 
 
 @support_torch_compile
@@ -202,87 +238,86 @@ class GptOssModel(nn.Module):
     ):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
-        self.quant_config = vllm_config.quant_config
+        self.parallel_config = vllm_config.parallel_config
         self.config.hidden_size = self.config.hidden_size
         self.embedding = VocabParallelEmbedding(
             self.config.vocab_size,
             self.config.hidden_size,
         )
-        self.layers = torch.nn.ModuleList([
-            TransformerBlock(
-                self.config,
-                quant_config=self.quant_config,
-                prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
-            ) for layer_idx in range(self.config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            self.config.num_hidden_layers,
+            lambda prefix: TransformerBlock(
+                vllm_config,
+                prefix=prefix,
+            ),
+            prefix=f"{prefix}.layers",
+        )
         self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], self.config.hidden_size))
+        self.aux_hidden_state_layers = tuple[int, ...]()
 
-    def forward(self, input_ids: torch.Tensor,
-                positions: torch.Tensor) -> torch.Tensor:
-        x = self.embedding(input_ids)
-        for layer in self.layers:
-            x = layer(x, positions)
-        x = self.norm(x)
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embedding(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                x = inputs_embeds
+            else:
+                x = self.get_input_embeddings(input_ids)
+
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            if i in self.aux_hidden_state_layers:
+                aux_hidden_states.append(x if residual is None else x +
                                         residual)
+            x, residual = layer(x, positions, residual)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": x,
+                "residual": residual
+            })
+        x, _ = self.norm(x, residual)
+
+        if len(aux_hidden_states) > 0:
+            return x, aux_hidden_states
         return x
 
-
-class GptOssForCausalLM(nn.Module):
-
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ):
-        super().__init__()
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config.hf_config
-        self.model = GptOssModel(
-            vllm_config=vllm_config,
-            prefix=maybe_prefix(prefix, "model"),
-        )
-        self.lm_head = ParallelLMHead(
-            self.model_config.vocab_size,
-            self.model_config.hidden_size,
-        )
-        self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
-
-    def forward(self,
-                input_ids: torch.Tensor,
-                positions: torch.Tensor,
-                intermediate_tensors: Optional[IntermediateTensors] = None,
-                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
-        assert intermediate_tensors is None
-        assert inputs_embeds is None
-        return self.model(input_ids, positions)
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
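Context for the make_layers change above: under pipeline parallelism each rank materializes only its own slice of the transformer stack and ships hidden_states plus residual to the next rank via IntermediateTensors. A simplified sketch of the partitioning idea (make_layers_ref, Placeholder, and the explicit pp_rank/pp_size arguments are illustrative names, not vLLM's exact API):

```python
import torch.nn as nn


class Placeholder(nn.Module):
    """Stands in for layers owned by other pipeline ranks."""


def make_layers_ref(num_layers: int, layer_fn, pp_rank: int, pp_size: int):
    # Each rank builds roughly num_layers / pp_size real layers; the rest
    # stay placeholders so global layer indices remain stable.
    per_rank = (num_layers + pp_size - 1) // pp_size
    start = pp_rank * per_rank
    end = min(start + per_rank, num_layers)
    layers = nn.ModuleList(
        layer_fn(f"model.layers.{i}") if start <= i < end else Placeholder()
        for i in range(num_layers))
    return start, end, layers
```

The forward loop above then iterates only `range(self.start_layer, self.end_layer)` and, on non-final ranks, returns IntermediateTensors instead of hidden states.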
     def _load_weights_mxfp4(
-            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        rename_mapping = {
-            "self_attn": "attn",
-            "input_layernorm.weight": "attn.norm.weight",
-            "post_attention_layernorm.weight": "mlp.norm.weight",
-            "embed_tokens": "embedding",
-        }
-
-        def maybe_rename(name: str) -> str:
-            for remap_name, new_name in rename_mapping.items():
-                if remap_name in name:
-                    return name.replace(remap_name, new_name)
-            return name
-
+            self,
+            ep_rank_start: int,
+            ep_rank_end: int,
+            heads_per_rank: int,
+            head_start: int,
+            weights: Iterable[tuple[str, torch.Tensor]],
+            stacked_params_mapping: list[tuple[str, ...]],
+    ) -> set[str]:
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        mxfp4_block = 32
+        use_ep = self.parallel_config.enable_expert_parallel
+        num_experts = self.config.num_local_experts
 
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
-        intermediate_size = self.model_config.intermediate_size
+
+        intermediate_size = self.config.intermediate_size
         intermediate_size_block = intermediate_size // mxfp4_block
         per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                                 tp_size)
@@ -294,26 +329,54 @@ class GptOssForCausalLM(nn.Module):
         tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                           intermediate_size)
 
-        # Attention heads per rank
-        heads_per_rank = self.model_config.num_attention_heads // tp_size
-        head_start = tp_rank * heads_per_rank
-
-        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
-        ep_size = get_ep_group().world_size
-        ep_rank = get_ep_group().rank
-        num_experts = self.model_config.num_local_experts
-        experts_per_rank = num_experts // ep_size
-        ep_rank_start = ep_rank * experts_per_rank
-        ep_rank_end = (ep_rank + 1) * experts_per_rank
-
         for name, weight in weights:
+            # Skip layers on other devices.
+            if is_pp_missing_parameter(name, self):
+                continue
+
             # FIXME(woosuk): Remove this after testing.
             weight = weight.cuda()
 
-            if "gate_up_proj_blocks" in name:
-                # Handle MLP gate and up projection weights
-                new_name = name.replace("gate_up_proj_blocks", "w13_weight")
+            if ".w13_weight_scale" in name:
+                # Handle MLP gate and up projection weight scales
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[:,
+                                           2 * tp_rank_start:2 * tp_rank_end,
+                                           ...]
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param,
+                              narrow_weight,
+                              weight_name=name,
+                              shard_id=None,
+                              expert_id=None)
+                loaded_params.add(name)
+                continue
+            elif ".w2_weight_scale" in name:
+                # Handle MLP down projection weight scales
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[..., tp_rank_start //
+                                           mxfp4_block:tp_rank_end //
+                                           mxfp4_block]
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param,
+                              narrow_weight,
+                              weight_name=name,
+                              shard_id=None,
+                              expert_id=None)
+                loaded_params.add(name)
+                continue
+            elif ".w13_weight" in name:
+                # Handle MLP gate and up projection weights
                 # flat weight from (E, 2 * N, block_size, entry_per_block)
                 # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                 weight = weight.view(num_experts, 2 * intermediate_size,
@@ -328,19 +391,18 @@ class GptOssForCausalLM(nn.Module):
                                            2 * tp_rank_start:2 * tp_rank_end,
                                            ...]
-                param = params_dict[new_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param,
                               narrow_weight,
-                              weight_name=new_name,
+                              weight_name=name,
                               shard_id=None,
                               expert_id=None)
-                loaded_params.add(new_name)
-
-            elif "down_proj_blocks" in name:
+                loaded_params.add(name)
+                continue
+            elif ".w2_weight" in name:
                 # Handle MLP down projection weights
-                new_name = name.replace("down_proj_blocks", "w2_weight")
                 # same flatten here, but since 2 mx4 value are packed in 1
                 # uint8, divide by 2
                 weight = weight.view(num_experts, -1,
@@ -351,60 +413,18 @@ class GptOssForCausalLM(nn.Module):
                     narrow_weight = weight[...,
                                            tp_rank_start // 2:tp_rank_end // 2]
 
-                param = params_dict[new_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param,
                               narrow_weight,
-                              weight_name=new_name,
+                              weight_name=name,
                               shard_id=None,
                               expert_id=None)
-                loaded_params.add(new_name)
-
-            elif "gate_up_proj_scales" in name:
-                # Handle MLP gate and up projection weights scale
-                new_name = name.replace("gate_up_proj_scales",
-                                        "w13_weight_scale")
-                if use_ep:
-                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
-                else:
-                    narrow_weight = weight[:,
-                                           2 * tp_rank_start:2 * tp_rank_end,
-                                           ...]
-
-                param = params_dict[new_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param,
-                              narrow_weight,
-                              weight_name=new_name,
-                              shard_id=None,
-                              expert_id=None)
-                loaded_params.add(new_name)
-
-            elif "down_proj_scales" in name:
-                # Handle MLP down projection weights
-                new_name = name.replace("down_proj_scales", "w2_weight_scale")
-                if use_ep:
-                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
-                else:
-                    narrow_weight = weight[..., tp_rank_start //
-                                           mxfp4_block:tp_rank_end //
-                                           mxfp4_block]
-
-                param = params_dict[new_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param,
-                              narrow_weight,
-                              weight_name=new_name,
-                              shard_id=None,
-                              expert_id=None)
-                loaded_params.add(new_name)
-
-            elif "gate_up_proj_bias" in name:
+                loaded_params.add(name)
+                continue
+            elif ".w13_bias" in name:
                 # Handle MLP gate and up projection biases
-                new_name = name.replace("gate_up_proj_bias", "w13_bias")
-
                 # Extract gate and up projection bias parts
                 if use_ep:
                     narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
@@ -412,20 +432,19 @@ class GptOssForCausalLM(nn.Module):
                     narrow_weight = weight[:,
                                            2 * tp_rank_start:2 * tp_rank_end]
 
-                param = params_dict[new_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param,
                               narrow_weight,
-                              weight_name=new_name,
+                              weight_name=name,
                               shard_id=None,
                               expert_id=None)
-                loaded_params.add(new_name)
-
-            elif "down_proj_bias" in name:
+                loaded_params.add(name)
+                continue
+            elif ".w2_bias" in name:
                 # Handle MLP down projection bias
-                new_name = name.replace("down_proj_bias", "w2_bias")
-                param = params_dict[new_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 if use_ep:
@@ -436,87 +455,73 @@ class GptOssForCausalLM(nn.Module):
                     weight.zero_()
                 weight_loader(param,
                               weight,
-                              weight_name=new_name,
+                              weight_name=name,
                               shard_id=None,
                               expert_id=None)
-                loaded_params.add(new_name)
+                loaded_params.add(name)
+                continue
             elif "sinks" in name:
                 # Handle attention sinks (distributed across ranks)
-                name = name.replace("self_attn", "attn")
                 param = params_dict[name]
                 narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                 param.data.copy_(narrow_weight)
                 loaded_params.add(name)
-
-            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
-                shard_id = ("q" if "q_proj" in name else
-                            "k" if "k_proj" in name else "v")
-                name = name.replace("self_attn", "attn")
-                param_name = name.replace(f"{shard_id}_proj", "qkv")
-                param = params_dict[param_name]
-                weight_loader = param.weight_loader
-                weight_loader(param, weight, loaded_shard_id=shard_id)
-                loaded_params.add(param_name)
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, weight)
+                else:
+                    weight_loader(param, weight, shard_id)
+                break
             else:
                 # Handle all other weights with potential renaming
-                renamed_name = maybe_rename(name)
-                if renamed_name not in params_dict:
+                if name not in params_dict:
                     continue
-                param = params_dict[renamed_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, weight)
-                loaded_params.add(renamed_name)
-
+                loaded_params.add(name)
         return loaded_params
 
     def _load_weights_other(
-            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        rename_mapping = {
-            "self_attn": "attn",
-            "input_layernorm.weight": "attn.norm.weight",
-            "post_attention_layernorm.weight": "mlp.norm.weight",
-            "embed_tokens": "embedding",
-        }
-
-        def maybe_rename(name: str) -> str:
-            for remap_name, new_name in rename_mapping.items():
-                if remap_name in name:
-                    return name.replace(remap_name, new_name)
-            return name
-
+            self,
+            ep_rank_start: int,
+            ep_rank_end: int,
+            heads_per_rank: int,
+            head_start: int,
+            weights: Iterable[tuple[str, torch.Tensor]],
+            stacked_params_mapping: list[tuple[str, ...]],
+    ) -> set[str]:
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        use_ep = self.parallel_config.enable_expert_parallel
+
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
-        intermediate_size = self.model_config.intermediate_size
+        intermediate_size = self.config.intermediate_size
         per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
 
         # Calculate common slicing bounds for current rank
         tp_rank_start = tp_rank * per_rank_intermediate_size
         tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                           intermediate_size)
 
-        # Attention heads per rank
-        heads_per_rank = self.model_config.num_attention_heads // tp_size
-        head_start = tp_rank * heads_per_rank
-
-        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
-        ep_size = get_ep_group().world_size
-        ep_rank = get_ep_group().rank
-        num_experts = self.model_config.num_local_experts
-        experts_per_rank = num_experts // ep_size
-        ep_rank_start = ep_rank * experts_per_rank
-        ep_rank_end = (ep_rank + 1) * experts_per_rank
-
         for name, weight in weights:
-            if ".experts.gate_up_proj" in name and "bias" not in name:
-                # Handle MLP gate and up projection weights
-                new_name = name.replace(".experts.gate_up_proj",
-                                        ".experts.w13_weight")
+            # Skip layers on other devices.
+            if is_pp_missing_parameter(name, self):
+                continue
+            if ".w13_weight" in name:
+                # Handle MLP gate and up projection weights
                 # Extract gate and up projection parts
-                # since the weight is shuffled, we can slice directly
                 if use_ep:
                     narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                 else:
@@ -524,30 +529,25 @@ class GptOssForCausalLM(nn.Module):
                                           2 * tp_rank_start:2 * tp_rank_end]
                     narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
 
-                param = params_dict[new_name]
+                param = params_dict[name]
                 param.copy_(narrow_weight)
-                loaded_params.add(new_name)
-
-            elif ".experts.down_proj" in name and "bias" not in name:
+                loaded_params.add(name)
+                continue
+            elif ".w2_weight" in name:
                 # Handle MLP down projection weights
-                new_name = name.replace(".experts.down_proj",
-                                        ".experts.w2_weight")
-
                 if use_ep:
                     narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                 else:
                     narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                     narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
 
-                param = params_dict[new_name]
+                param = params_dict[name]
                 param.copy_(narrow_weight)
-                loaded_params.add(new_name)
-
-            elif "gate_up_proj_bias" in name:
+                loaded_params.add(name)
+                continue
+            elif ".w13_bias" in name:
                 # Handle MLP gate and up projection biases
-                new_name = name.replace("gate_up_proj_bias", "w13_bias")
-
                 # Extract gate and up projection bias parts
                 if use_ep:
                     narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
@@ -555,60 +555,162 @@ class GptOssForCausalLM(nn.Module):
                     narrow_weight = weight[:,
                                            2 * tp_rank_start:2 * tp_rank_end]
 
-                param = params_dict[new_name]
-
+                param = params_dict[name]
                 param.copy_(narrow_weight)
-                loaded_params.add(new_name)
-
-            elif "down_proj_bias" in name:
+                loaded_params.add(name)
+                continue
+            elif ".w2_bias" in name:
                 # Handle MLP down projection bias
-                new_name = name.replace("down_proj_bias", "w2_bias")
-
                 if use_ep:
                     weight = weight[ep_rank_start:ep_rank_end, ...]
                 else:
                     # (only load on rank 0 to avoid duplication)
                     if tp_rank != 0:
                         weight.zero_()
-                param = params_dict[new_name]
+                param = params_dict[name]
                 param.copy_(weight)
-                loaded_params.add(new_name)
+                loaded_params.add(name)
+                continue
             elif "sinks" in name:
                 # Handle attention sinks (distributed across ranks)
-                name = name.replace("self_attn", "attn")
                 param = params_dict[name]
                 narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                 param.data.copy_(narrow_weight)
                 loaded_params.add(name)
-
-            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
-                shard_id = ("q" if "q_proj" in name else
-                            "k" if "k_proj" in name else "v")
-                name = name.replace("self_attn", "attn")
-                param_name = name.replace(f"{shard_id}_proj", "qkv")
-                param = params_dict[param_name]
-                weight_loader = param.weight_loader
-                weight_loader(param, weight, loaded_shard_id=shard_id)
-                loaded_params.add(param_name)
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, weight)
+                else:
+                    weight_loader(param, weight, shard_id)
+                break
             else:
                 # Handle all other weights with potential renaming
-
-                renamed_name = maybe_rename(name)
-                if renamed_name not in params_dict:
+                if name not in params_dict:
                     continue
-                param = params_dict[renamed_name]
+                param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, weight)
-                loaded_params.add(renamed_name)
-
+                loaded_params.add(name)
         return loaded_params
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        quant_method = (self.model_config.quantization_config['quant_method']
-                        if hasattr(self.model_config, "quantization_config")
-                        else None)
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv", ".q_proj", "q"),
+            (".qkv", ".k_proj", "k"),
+            (".qkv", ".v_proj", "v"),
+        ]
+
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+
+        # Attention heads per rank
+        heads_per_rank = self.config.num_attention_heads // tp_size
+        head_start = tp_rank * heads_per_rank
+
+        ep_size = get_ep_group().world_size
+        ep_rank = get_ep_group().rank
+        num_experts = self.config.num_local_experts
+        experts_per_rank = num_experts // ep_size
+        ep_rank_start = ep_rank * experts_per_rank
+        ep_rank_end = (ep_rank + 1) * experts_per_rank
+
+        quant_method = (self.config.quantization_config['quant_method'] if
+                        hasattr(self.config, "quantization_config") else None)
         if quant_method == "mxfp4":
-            return self._load_weights_mxfp4(weights)
+            return self._load_weights_mxfp4(ep_rank_start, ep_rank_end,
+                                            heads_per_rank, head_start,
+                                            weights, stacked_params_mapping)
         else:
-            return self._load_weights_other(weights)
+            return self._load_weights_other(ep_rank_start, ep_rank_end,
+                                            heads_per_rank, head_start,
+                                            weights, stacked_params_mapping)
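The stacked_params_mapping introduced above replaces the old hand-rolled q/k/v branch: checkpoint names are rewritten to the fused qkv parameter and each tensor is copied into its shard. A rough standalone illustration of the mechanism, ignoring tensor parallelism (remap_stacked and load_qkv_shard are hypothetical helpers, not the vLLM loader):

```python
import torch


def remap_stacked(name: str, stacked_params_mapping):
    """Rewrite '...attn.q_proj.weight' to '...attn.qkv.weight' and report
    which shard of the fused parameter the tensor carries."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


def load_qkv_shard(qkv_weight: torch.Tensor, shard: torch.Tensor,
                   shard_id: str, q_size: int, kv_size: int) -> None:
    """Copy one of q/k/v into its slice of the row-concatenated qkv weight."""
    offset = {"q": 0, "k": q_size, "v": q_size + kv_size}[shard_id]
    qkv_weight.narrow(0, offset, shard.shape[0]).copy_(shard)


mapping = [(".qkv", ".q_proj", "q"), (".qkv", ".k_proj", "k"),
           (".qkv", ".v_proj", "v")]
name, shard = remap_stacked("model.layers.0.attn.q_proj.weight", mapping)
assert name == "model.layers.0.attn.qkv.weight" and shard == "q"
```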
".w2_weight_scale", + + # MoE other weights + ".gate_up_proj": ".w13_weight", + ".down_proj": ".w2_weight", + + # MoE Bias + ".gate_up_proj_bias": ".w13_bias", + ".down_proj_bias": ".w2_bias", + }, + ) + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.vllm_config = vllm_config + self.config = vllm_config.model_config.hf_config + + self.model = GptOssModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor(self.config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm_kunlun/ops/_kunlun_ops.py b/vllm_kunlun/ops/_kunlun_ops.py index 0849432..6250964 100644 --- a/vllm_kunlun/ops/_kunlun_ops.py +++ b/vllm_kunlun/ops/_kunlun_ops.py @@ -418,6 +418,7 @@ class KunlunOps: w2: torch.Tensor, router_logits: torch.Tensor, linear_weights: torch.Tensor, + ep_rank: int, moe_top_k: int, renormalize: bool, inplace: bool = False, @@ -430,7 +431,7 @@ class KunlunOps: e_score_correction_bias: Optional[torch.Tensor] = None ) -> torch.Tensor: """fused_moe""" - global_num_experts = linear_weights.shape[0] + global_num_experts, up_gate_size, _ = w1.shape M, N = hidden_states.shape hidden_dim = w2.shape[1] normed_score = torch.empty(M, @@ -445,82 +446,119 @@ class KunlunOps: block_statistic = torch.zeros( num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device ) - - torch.ops._C.moe_sigmoid_group_topk_norm( + router_logits = router_logits.to(torch.float) + if scoring_func == "softmax": + torch.ops._C.moe_softmax_topk_norm( x=router_logits, + normed_score=normed_score, topk_index=topk_ids, - norm_score=normed_score, - block_static=block_statistic, - bias=e_score_correction_bias, - scale=1.0, - n_group=num_expert_group, - topk_group=1, + block_statistic=None, + stable=True) + elif scoring_func == "sigmoid": + torch.ops._C.moe_sigmoid_group_topk_norm( + x=router_logits, + topk_index=topk_ids, + norm_score=normed_score, + block_static=block_statistic, + bias=e_score_correction_bias, + scale=1.0, + n_group=num_expert_group, + topk_group=topk_group, + ) + + if w1_bias is not None or w2_bias is not None: + # Rignt now this branch is for gpt oss + # TODO (@xyDong23): faster here using moe_fc 
diff --git a/vllm_kunlun/ops/_kunlun_ops.py b/vllm_kunlun/ops/_kunlun_ops.py
index 0849432..6250964 100644
--- a/vllm_kunlun/ops/_kunlun_ops.py
+++ b/vllm_kunlun/ops/_kunlun_ops.py
@@ -418,6 +418,7 @@ class KunlunOps:
         w2: torch.Tensor,
         router_logits: torch.Tensor,
         linear_weights: torch.Tensor,
+        ep_rank: int,
         moe_top_k: int,
         renormalize: bool,
         inplace: bool = False,
@@ -430,7 +431,7 @@ class KunlunOps:
         e_score_correction_bias: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         """fused_moe"""
-        global_num_experts = linear_weights.shape[0]
+        global_num_experts, up_gate_size, _ = w1.shape
         M, N = hidden_states.shape
         hidden_dim = w2.shape[1]
         normed_score = torch.empty(M,
@@ -445,82 +446,119 @@ class KunlunOps:
         block_statistic = torch.zeros(
             num_blocks, global_num_experts, dtype=torch.int32,
             device=hidden_states.device
         )
-
-        torch.ops._C.moe_sigmoid_group_topk_norm(
+        router_logits = router_logits.to(torch.float)
+        if scoring_func == "softmax":
+            torch.ops._C.moe_softmax_topk_norm(
                 x=router_logits,
+                normed_score=normed_score,
                 topk_index=topk_ids,
-                norm_score=normed_score,
-                block_static=block_statistic,
-                bias=e_score_correction_bias,
-                scale=1.0,
-                n_group=num_expert_group,
-                topk_group=1,
+                block_statistic=None,
+                stable=True)
+        elif scoring_func == "sigmoid":
+            torch.ops._C.moe_sigmoid_group_topk_norm(
+                x=router_logits,
+                topk_index=topk_ids,
+                norm_score=normed_score,
+                block_static=block_statistic,
+                bias=e_score_correction_bias,
+                scale=1.0,
+                n_group=num_expert_group,
+                topk_group=topk_group,
             )
+
+        if w1_bias is not None or w2_bias is not None:
+            # Right now this branch is only used for gpt-oss.
+            # TODO (@xyDong23): make this faster using the moe_fc kernel
+            normed_score = normed_score.to(hidden_states.dtype)
+            out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype,
+                              device=hidden_states.device)
+            repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
+            topk_ids_flat = topk_ids.flatten()
+            for i in range(global_num_experts):
+                experts_id = ep_rank * global_num_experts + i
+                selected_token = topk_ids_flat == experts_id
+                if selected_token.sum():
+                    cur_token = repeat_x[selected_token]
+                    up_gate = torch.empty(selected_token.sum(),
+                                          up_gate_size // 2,
+                                          dtype=cur_token.dtype,
+                                          device=cur_token.device)
+                    groupgemm1 = cur_token @ w1[i].T
+                    # Add w13 bias
+                    if w1_bias is not None:
+                        groupgemm1 = groupgemm1 + w1_bias[i]
+                    up_gate = torch.ops._C.swigluoai_and_mul(groupgemm1)
+                    groupgemm2 = up_gate @ w2[i].T
+                    # Add w2 bias
+                    if w2_bias is not None:
+                        groupgemm2 = groupgemm2 + w2_bias[i]
+                    out[selected_token] = groupgemm2
+            output = (out.view(M, moe_top_k, N) *
+                      normed_score.unsqueeze(2)).sum(dim=1).to(
+                          hidden_states.dtype)
+            return output
+        else:
+            moe_expand = torch.empty((M * moe_top_k, N),
+                                     dtype=hidden_states.dtype,
+                                     device=hidden_states.device)  # [M*top_k, N], float
+            expert_m = torch.zeros(global_num_experts, dtype=torch.int32,
+                                   device=hidden_states.device)  # [E]
+            sorted_tokens_num_lod = torch.zeros(global_num_experts + 1,
+                                                dtype=torch.int32,
+                                                device=hidden_states.device)  # [E+1]
+            sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32,
+                                            device=hidden_states.device)
+
+            torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
+
+            torch.ops._C.moe_pre_sorted(
+                x=hidden_states,
+                topk_index=topk_ids,
+                block_statistic=block_statistic,
+                moe_expand=moe_expand,
+                moe_index=sorted_tokens_idx,
+                expert_m=expert_m,
+                sorted_tokens_num_lod=sorted_tokens_num_lod)
+
+            y = torch.empty(M, moe_top_k,
+                            w1.shape[1],
+                            dtype=hidden_states.dtype,
+                            device=hidden_states.device)
+
+            moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
+
+            torch.ops._C.moe_fc(
+                x=moe_expand,
+                weight=w1,
+                sorted_tokens_num_lod=sorted_tokens_num_lod,
+                sorted_tokens_idx=sorted_tokens_idx,
+                moe_topk=moe_top_k,
+                y=y,
+            )
+
+            d = y.shape[-1] // 2
+            output_shape = (y.shape[:-1] + (d, ))
+            out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
+            torch.ops._C.silu_and_mul(out1, y)
+
+            out = torch.empty(M, moe_top_k,
+                              w2.shape[1],
+                              dtype=hidden_states.dtype,
+                              device=hidden_states.device)
+
+            out1 = out1.reshape(-1, out1.shape[-1])
+
+            torch.ops._C.moe_fc(
+                x=out1,
+                weight=w2,
+                sorted_tokens_num_lod=sorted_tokens_num_lod,
+                sorted_tokens_idx=sorted_tokens_idx,
+                moe_topk=moe_top_k,
+                y=out,
+            )
+
+            dequant_scale = torch.ones([M, moe_top_k], dtype=torch.float32,
+                                       device=out.device)
+            output = torch.empty([M, N], dtype=hidden_states.dtype,
+                                 device=hidden_states.device)
+            sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
+
+            torch.ops._C.moe_post(
+                x=out,
+                moe_index=sorted_tokens_idx,
+                normed_scale=normed_score,
+                dequant_scale=dequant_scale,
+                y=output
+            )
+
+            return output
-        moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
-        expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
-        sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
-        sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
-
-        torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
-
-        torch.ops._C.moe_pre_sorted(
-            x=hidden_states,
-            topk_index=topk_ids,
-            block_statistic=block_statistic,
-            moe_expand=moe_expand,
-            moe_index=sorted_tokens_idx,
-            expert_m=expert_m,
-            sorted_tokens_num_lod=sorted_tokens_num_lod)
-
-        y = torch.empty(M,moe_top_k,
-                        w1.shape[1],
-                        dtype=hidden_states.dtype,
-                        device=hidden_states.device)
-
-        moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
-
-        torch.ops._C.moe_fc(
-            x=moe_expand,
-            weight=w1,
-            sorted_tokens_num_lod=sorted_tokens_num_lod,
-            sorted_tokens_idx=sorted_tokens_idx,
-            moe_topk=moe_top_k,
-            y=y)
-
-        d = y.shape[-1] // 2
-        output_shape = (y.shape[:-1] + (d, ))
-        out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
-        torch.ops._C.silu_and_mul(out1, y)
-
-        out = torch.empty(M,moe_top_k,
-                          w2.shape[1],
-                          dtype=hidden_states.dtype,
-                          device=hidden_states.device)
-
-        out1 = out1.reshape(-1, out1.shape[-1])
-
-        torch.ops._C.moe_fc(
-            x=out1,
-            weight=w2,
-            sorted_tokens_num_lod=sorted_tokens_num_lod,
-            sorted_tokens_idx=sorted_tokens_idx,
-            moe_topk=moe_top_k,
-            y=out)
-
-        dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
-        output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
-        sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
-
-        torch.ops._C.moe_post(
-            x=out,
-            moe_index=sorted_tokens_idx,
-            normed_scale=normed_score,
-            dequant_scale=dequant_scale,
-            y=output
-        )
-
-        return output
 
     @staticmethod
     def fused_moe_ep(
diff --git a/vllm_kunlun/ops/fused_moe/layer.py b/vllm_kunlun/ops/fused_moe/layer.py
index b68a627..84cbd36 100644
--- a/vllm_kunlun/ops/fused_moe/layer.py
+++ b/vllm_kunlun/ops/fused_moe/layer.py
@@ -108,6 +108,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
             layer.w2_weight,
             router_logits,
             linear_weights,
+            self.moe.ep_rank,
             top_k,
             renormalize=renormalize,
             inplace=True,
@@ -116,6 +117,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
             topk_group=topk_group,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
+            w1_bias=getattr(layer, "w13_bias", None),
+            w2_bias=getattr(layer, "w2_bias", None),
         )
 
 class FusedMoE(VllmFusedMoE):
@@ -144,6 +147,7 @@ class FusedMoE(VllmFusedMoE):
         enable_eplb: bool = False,
         num_redundant_experts: int = 0,
         is_sequence_parallel=False,
+        has_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,  # Global number of experts
@@ -186,10 +190,12 @@ class FusedMoE(VllmFusedMoE):
             moe_parallel_config=self.moe_parallel_config,
             in_dtype=model_dtype,
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
+            has_bias=has_bias,
             # quant_config=quant_config,
         )
         self.moe_config = moe
         self.quant_config = quant_config
+        self.has_bias = has_bias
 
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
diff --git a/vllm_kunlun/ops/rotary_embedding.py b/vllm_kunlun/ops/rotary_embedding.py
index a151568..bbe32a6 100644
--- a/vllm_kunlun/ops/rotary_embedding.py
+++ b/vllm_kunlun/ops/rotary_embedding.py
@@ -147,7 +147,6 @@ RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
 RotaryEmbedding.forward = vllm_kunlun_forward_cuda
 MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
 MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
-YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
 
 
 def Split_Norm_Rope(
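Closing note on the new bias-carrying branch in fused_moe: a hedged pure-PyTorch reference that can be used to sanity-check the Kunlun path on small shapes. It assumes w1 produces concatenated [gate; up] halves and uses the gpt-oss SwiGLU variant (alpha = 1.702, clamp limit = 7.0); the actual kernel's weight layout and activation details may differ.

```python
import torch


def swiglu_oai(gate_up: torch.Tensor, alpha: float = 1.702,
               limit: float = 7.0) -> torch.Tensor:
    # gpt-oss-style gated activation: clamp both halves, sigmoid-gate,
    # and shift the up projection by 1.
    gate, up = gate_up.chunk(2, dim=-1)
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return (up + 1) * (gate * torch.sigmoid(gate * alpha))


def moe_ref(x, w1, w2, w1_bias, w2_bias, topk_ids, topk_weight):
    # x: [M, N]; w1: [E, 2I, N]; w2: [E, N, I]; topk_*: [M, K]
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for e, w in zip(topk_ids[t].tolist(), topk_weight[t].tolist()):
            h = swiglu_oai(x[t] @ w1[e].T + w1_bias[e])
            out[t] += w * (h @ w2[e].T + w2_bias[e])
    return out


# Tiny smoke test with random weights.
M, N, I, E, K = 4, 8, 16, 4, 2
x = torch.randn(M, N)
w1, w2 = torch.randn(E, 2 * I, N), torch.randn(E, N, I)
b1, b2 = torch.randn(E, 2 * I), torch.randn(E, N)
ids = torch.randint(0, E, (M, K))
wts = torch.softmax(torch.randn(M, K), dim=-1)
print(moe_ref(x, w1, w2, b1, b2, ids, wts).shape)  # torch.Size([4, 8])
```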