[Model] Add LongCat-Flash (#3833)

### What this PR does / why we need it?
Add LongCat-Flash support.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed

- vLLM version: v0.13.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: chuyuelin <923822139@qq.com>
Co-authored-by: chuyuelin <chuyuelin1@huawei.com>
This commit is contained in:
Chu Yuelin
2025-12-31 17:06:55 +08:00
committed by GitHub
parent 03679cf1d3
commit d07d8a4535
8 changed files with 79 additions and 14 deletions

View File

@@ -298,6 +298,12 @@ packed_modules_model_mapping = {
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
"longcat_flash": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
}
@@ -514,6 +520,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
@@ -524,9 +531,9 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
return self.quant_method.apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, e_score_correction_bias,
is_prefill, enable_force_load_balance, log2phy,
global_redundant_expert_num, **kwargs)
custom_routing_function, scoring_func, routed_scaling_factor,
e_score_correction_bias, is_prefill, enable_force_load_balance,
log2phy, global_redundant_expert_num, **kwargs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):

View File

@@ -199,6 +199,7 @@ class AscendW4A16FusedMoEMethod:
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,

View File

@@ -336,6 +336,7 @@ class AscendW4A8DynamicFusedMoEMethod:
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,

View File

@@ -28,7 +28,8 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.flash_common3_context import get_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
zero_experts_compute)
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
@@ -183,6 +184,7 @@ class AscendW8A8DynamicFusedMoEMethod:
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
@@ -194,8 +196,11 @@ class AscendW8A8DynamicFusedMoEMethod:
pertoken_scale: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
if zero_expert_num == 0 or zero_expert_type is None:
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, \
"Number of global experts mismatch (excluding redundancy)"
if self.multistream_overlap_gate:
fc3_context = get_flash_common3_context()
@@ -213,10 +218,19 @@ class AscendW8A8DynamicFusedMoEMethod:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
assert topk_ids is not None
assert topk_weights is not None
if zero_expert_num > 0 and zero_expert_type is not None:
topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
expert_indices=topk_ids,
expert_scales=topk_weights,
num_experts=global_num_experts,
zero_expert_type=zero_expert_type,
hidden_states=x,
)
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
@@ -253,7 +267,7 @@ class AscendW8A8DynamicFusedMoEMethod:
fused_scale_flag = (get_forward_context().moe_comm_type
== MoECommType.FUSED_MC2
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
return moe_comm_method.fused_experts(
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
w1=w1,
@@ -271,6 +285,9 @@ class AscendW8A8DynamicFusedMoEMethod:
dynamic_scale_for_share=dynamic_scale_for_share,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result
return final_hidden_states
def process_weights_after_loading(self, layer):
layer.w13_weight.data = layer.w13_weight.data.transpose(