[Model] Add LongCat-Flash (#3833)
### What this PR does / why we need it?
Add LongCat-Flash support.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed
- vLLM version: v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: chuyuelin <923822139@qq.com>
Co-authored-by: chuyuelin <chuyuelin1@huawei.com>
This commit is contained in:
@@ -225,7 +225,7 @@ def _select_experts_with_fusion_ops(
|
||||
norm_type=norm_type, # 0: softmax; 1: sigmoid
|
||||
# out_flag=False, # todo new api; should the third output be output
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
routed_scaling_factor=1,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
eps=float(1e-20))
|
||||
if scoring_func == "softmax":
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
@@ -304,3 +304,28 @@ def _native_select_experts(
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def zero_experts_compute(
|
||||
expert_indices: torch.Tensor,
|
||||
expert_scales: torch.Tensor,
|
||||
num_experts: int,
|
||||
zero_expert_type: str,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if zero_expert_type == "identity":
|
||||
zero_expert_mask = expert_indices < num_experts
|
||||
zero_expert_scales = expert_scales.clone()
|
||||
zero_expert_scales = torch.where(zero_expert_mask, 0.0,
|
||||
zero_expert_scales)
|
||||
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
zero_expert_scales = zero_expert_scales.unsqueeze(2)
|
||||
result = hidden_states * zero_expert_scales
|
||||
result = result.sum(dim=1)
|
||||
|
||||
normal_expert_mask = expert_indices >= num_experts
|
||||
expert_indices = torch.where(normal_expert_mask, 0, expert_indices)
|
||||
expert_scales = torch.where(normal_expert_mask, 0.0, expert_scales)
|
||||
|
||||
return expert_indices, expert_scales, result
|
||||
|
||||
@@ -35,7 +35,8 @@ from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
|
||||
from vllm_ascend.eplb.utils import moe_load_async_stream
|
||||
from vllm_ascend.flash_common3_context import (get_flash_common3_context,
|
||||
set_flash_common3_context)
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
|
||||
zero_experts_compute)
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
||||
FusedExpertsResult,
|
||||
setup_moe_comm_method)
|
||||
@@ -92,7 +93,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
enable_force_load_balance: bool = False,
|
||||
shared_experts: Optional[Any] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
|
||||
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
||||
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@@ -107,6 +109,15 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||
topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
|
||||
expert_indices=topk_ids,
|
||||
expert_scales=topk_weights,
|
||||
num_experts=global_num_experts,
|
||||
zero_expert_type=zero_expert_type,
|
||||
hidden_states=x,
|
||||
)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
@@ -119,7 +130,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
final_hidden_states = moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@@ -131,6 +142,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||
final_hidden_states += zero_expert_result
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
@@ -340,6 +354,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
|
||||
@@ -94,8 +94,6 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
hf_config = get_current_vllm_config().model_config.hf_config
|
||||
self.enable_shared_expert_dp = get_ascend_config(
|
||||
).enable_shared_expert_dp
|
||||
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||
self.first_k_dense_replace = hf_config.first_k_dense_replace
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.layers = hf_config.num_hidden_layers
|
||||
if mla_modules.indexer is not None:
|
||||
|
||||
@@ -298,6 +298,12 @@ packed_modules_model_mapping = {
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
},
|
||||
"longcat_flash": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -514,6 +520,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
@@ -524,9 +531,9 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
return self.quant_method.apply(
|
||||
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
||||
global_num_experts, expert_map, topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func, e_score_correction_bias,
|
||||
is_prefill, enable_force_load_balance, log2phy,
|
||||
global_redundant_expert_num, **kwargs)
|
||||
custom_routing_function, scoring_func, routed_scaling_factor,
|
||||
e_score_correction_bias, is_prefill, enable_force_load_balance,
|
||||
log2phy, global_redundant_expert_num, **kwargs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
|
||||
@@ -199,6 +199,7 @@ class AscendW4A16FusedMoEMethod:
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
|
||||
@@ -336,6 +336,7 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
|
||||
@@ -28,7 +28,8 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.flash_common3_context import get_flash_common3_context
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
|
||||
zero_experts_compute)
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
|
||||
|
||||
|
||||
@@ -183,6 +184,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
@@ -194,8 +196,11 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
pertoken_scale: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
||||
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
||||
if zero_expert_num == 0 or zero_expert_type is None:
|
||||
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, \
|
||||
"Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
if self.multistream_overlap_gate:
|
||||
fc3_context = get_flash_common3_context()
|
||||
@@ -213,10 +218,19 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
assert topk_ids is not None
|
||||
assert topk_weights is not None
|
||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||
topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
|
||||
expert_indices=topk_ids,
|
||||
expert_scales=topk_weights,
|
||||
num_experts=global_num_experts,
|
||||
zero_expert_type=zero_expert_type,
|
||||
hidden_states=x,
|
||||
)
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
@@ -253,7 +267,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
fused_scale_flag = (get_forward_context().moe_comm_type
|
||||
== MoECommType.FUSED_MC2
|
||||
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
|
||||
return moe_comm_method.fused_experts(
|
||||
final_hidden_states = moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
pertoken_scale=pertoken_scale,
|
||||
w1=w1,
|
||||
@@ -271,6 +285,9 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||
final_hidden_states += zero_expert_result
|
||||
return final_hidden_states
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||
|
||||
@@ -2240,9 +2240,10 @@ class NPUModelRunner(GPUModelRunner):
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
num_attn_module = 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1
|
||||
bind_kv_cache(kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
self.kv_caches, num_attn_module)
|
||||
return kv_caches
|
||||
|
||||
def _allocate_kv_cache_tensors(
|
||||
|
||||
Reference in New Issue
Block a user