[Model] Add LongCat-Flash (#3833)

### What this PR does / why we need it?
Add LongCat-Flash support.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed

- vLLM version: v0.13.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: chuyuelin <923822139@qq.com>
Co-authored-by: chuyuelin <chuyuelin1@huawei.com>
Author: Chu Yuelin
Date: 2025-12-31 17:06:55 +08:00 (committed by GitHub)
Parent: 03679cf1d3
Commit: d07d8a4535
8 changed files with 79 additions and 14 deletions

---------

@@ -225,7 +225,7 @@ def _select_experts_with_fusion_ops(
             norm_type=norm_type,  # 0: softmax; 1: sigmoid
             # out_flag=False, # todo new api; should the third output be output
             # y2_flag=False, # old api; should the third output be output
-            routed_scaling_factor=1,
+            routed_scaling_factor=routed_scaling_factor,
             eps=float(1e-20))
     if scoring_func == "softmax":
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
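Note on this first hunk: the fused gating op previously received a hardcoded `routed_scaling_factor=1`, so whatever value the model config provides was silently ignored; the hunk threads the real parameter through. Below is a minimal pure-PyTorch sketch of the factor's role, assuming sigmoid scoring and renormalization as in the surrounding code; the helper name `select_topk_with_scaling` and its internals are illustrative, not the torch_npu kernel's actual implementation.

```python
import torch

def select_topk_with_scaling(router_logits: torch.Tensor, top_k: int,
                             routed_scaling_factor: float):
    # norm_type=1 in the hunk above selects sigmoid scoring.
    scores = router_logits.sigmoid()
    topk_weights, topk_ids = scores.topk(top_k, dim=-1)
    # Renormalize so the selected weights sum to 1 per token.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # The configured factor rescales the routed-expert contribution;
    # hardcoding it to 1 is exactly what the hunk fixes.
    return topk_weights * routed_scaling_factor, topk_ids

w, ids = select_topk_with_scaling(torch.randn(4, 16), top_k=2,
                                  routed_scaling_factor=2.5)
```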
@@ -304,3 +304,28 @@ def _native_select_experts(
     topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
     return topk_weights, topk_ids
+
+
+def zero_experts_compute(
+    expert_indices: torch.Tensor,
+    expert_scales: torch.Tensor,
+    num_experts: int,
+    zero_expert_type: str,
+    hidden_states: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales = torch.where(zero_expert_mask, 0.0,
+                                         zero_expert_scales)
+
+        hidden_states = hidden_states.unsqueeze(1)
+        zero_expert_scales = zero_expert_scales.unsqueeze(2)
+        result = hidden_states * zero_expert_scales
+        result = result.sum(dim=1)
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices = torch.where(normal_expert_mask, 0, expert_indices)
+    expert_scales = torch.where(normal_expert_mask, 0.0, expert_scales)
+    return expert_indices, expert_scales, result
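To make the new helper's contract concrete, here is a toy usage sketch. The values and shapes are mine, and it assumes the reconstructed indentation above, where slot ids `>= num_experts` denote LongCat's identity "zero experts".

```python
import torch

hidden = torch.randn(2, 4)                 # [num_tokens, hidden_size]
topk_ids = torch.tensor([[1, 8], [0, 2]])  # id 8 >= num_experts=8: a zero expert
topk_weights = torch.tensor([[0.6, 0.4], [0.7, 0.3]])

ids, weights, identity_out = zero_experts_compute(
    expert_indices=topk_ids,
    expert_scales=topk_weights,
    num_experts=8,
    zero_expert_type="identity",
    hidden_states=hidden,
)

assert ids[0].tolist() == [1, 0]           # zero-expert slot remapped to expert 0
assert torch.allclose(weights[0], torch.tensor([0.6, 0.0]))  # its weight removed
assert torch.allclose(identity_out[0], 0.4 * hidden[0])      # pass-through, scaled
assert torch.allclose(identity_out[1], torch.zeros(4))       # no zero expert hit
```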

---------

@@ -35,7 +35,8 @@ from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
 from vllm_ascend.eplb.utils import moe_load_async_stream
 from vllm_ascend.flash_common3_context import (get_flash_common3_context,
                                                set_flash_common3_context)
-from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
+                                                        zero_experts_compute)
 from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
                                                        FusedExpertsResult,
                                                        setup_moe_comm_method)
@@ -92,7 +93,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             enable_force_load_balance: bool = False,
             shared_experts: Optional[Any] = None,
             **kwargs) -> torch.Tensor:
+        zero_expert_num = getattr(layer, "zero_expert_num", 0)
+        zero_expert_type = getattr(layer, "zero_expert_type", None)
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -107,6 +109,15 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts)
 
+        if zero_expert_num > 0 and zero_expert_type is not None:
+            topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
+                expert_indices=topk_ids,
+                expert_scales=topk_weights,
+                num_experts=global_num_experts,
+                zero_expert_type=zero_expert_type,
+                hidden_states=x,
+            )
+            topk_weights = topk_weights.to(x.dtype)
 
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too many tokens on a single rank.
@@ -119,7 +130,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
 
         moe_comm_method = get_forward_context().moe_comm_method
-        return moe_comm_method.fused_experts(
+        final_hidden_states = moe_comm_method.fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
@@ -131,6 +142,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             apply_router_weight_on_input=apply_router_weight_on_input,
             dynamic_eplb=self.dynamic_eplb,
             mc2_mask=kwargs.get("mc2_mask", None))
+        if zero_expert_num > 0 and zero_expert_type is not None:
+            final_hidden_states += zero_expert_result
+        return final_hidden_states
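Taken together, the hunks in this file change `apply()` from returning the fused-experts output directly to combining it with the identity contribution. A condensed, hedged sketch of the resulting control flow follows; the callables are injected stubs here, whereas the real method pulls them from the layer and the forward context.

```python
import torch

def moe_apply_sketch(x, router_logits, top_k, global_num_experts,
                     zero_expert_num, zero_expert_type,
                     select_experts, zero_experts_compute, fused_experts):
    # 1) Normal routing.
    topk_weights, topk_ids = select_experts(x, router_logits, top_k)
    zero_expert_result = None
    if zero_expert_num > 0 and zero_expert_type is not None:
        # 2) Strip zero-expert slots so the fused kernel never sees them.
        topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
            expert_indices=topk_ids, expert_scales=topk_weights,
            num_experts=global_num_experts, zero_expert_type=zero_expert_type,
            hidden_states=x)
    # 3) Run the routed experts as before.
    final_hidden_states = fused_experts(x, topk_ids, topk_weights)
    # 4) Fold the identity (zero-expert) contribution back in.
    if zero_expert_result is not None:
        final_hidden_states = final_hidden_states + zero_expert_result
    return final_hidden_states
```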
class AscendFusedMoE(FusedMoE):
@@ -340,6 +354,7 @@ class AscendFusedMoE(FusedMoE):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
+            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.e_score_correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,

---------

@@ -94,8 +94,6 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
         hf_config = get_current_vllm_config().model_config.hf_config
         self.enable_shared_expert_dp = get_ascend_config(
         ).enable_shared_expert_dp
-        self.debug_layer_idx = int(self.prefix.split(".")[-2])
-        self.first_k_dense_replace = hf_config.first_k_dense_replace
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layers = hf_config.num_hidden_layers
         if mla_modules.indexer is not None: