Add Llama4 support (#5092)

Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: ispobock <ispobaoke@163.com>

Author: Chang Su
Date: 2025-04-07 00:29:36 -07:00 (committed by GitHub)
Commit: f04c80dc42 (parent: d1bb171180)

27 changed files with 2214 additions and 22 deletions
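
Among the changes, every quantized MoE backend gains an apply_router_weight_on_input flag: Llama 4 multiplies the router weight into the expert input rather than into the expert output at combine time. The hunks below only thread the flag through to fused_experts; the following is a minimal reference sketch of the two behaviours in plain PyTorch (the shapes and the experts list are illustrative assumptions, not the fused kernel):

import torch

def moe_reference(hidden_states, topk_weights, topk_ids, experts,
                  apply_router_weight_on_input=False):
    """Unfused reference: hidden_states [T, H], topk_weights/topk_ids [T, K];
    experts is a list of per-expert callables standing in for the fused path."""
    out = torch.zeros_like(hidden_states)
    num_tokens, top_k = topk_ids.shape
    for t in range(num_tokens):
        for k in range(top_k):
            expert = experts[int(topk_ids[t, k])]
            weight = topk_weights[t, k]
            if apply_router_weight_on_input:
                # Llama4-style: scale the token before the expert MLP,
                # then sum expert outputs unweighted at combine time.
                out[t] += expert(weight * hidden_states[t])
            else:
                # Default: run the expert on the raw token and weight
                # its output when combining.
                out[t] += weight * expert(hidden_states[t])
    return out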

@@ -280,6 +280,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ):

@@ -370,6 +370,7 @@ class BlockInt8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -398,6 +399,7 @@ class BlockInt8MoEMethod:
             topk_ids=topk_ids,
             inplace=inplace,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int8_w8a8=True,
             w1_scale=(layer.w13_weight_scale_inv),
             w2_scale=(layer.w2_weight_scale_inv),
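
The same two-line pattern, accepting the keyword in apply() and forwarding it to fused_experts, repeats in the Fp8, WNA16, and W8A8-Int8 methods below. The placement of the router weight matters because the experts are nonlinear (SiLU-gated MLPs, per the activation="silu" default), so expert(w * x) and w * expert(x) are not interchangeable. A quick check with an arbitrary toy expert illustrates this:

import torch

torch.manual_seed(0)
expert = torch.nn.Sequential(          # toy SiLU MLP standing in for one expert
    torch.nn.Linear(8, 16), torch.nn.SiLU(), torch.nn.Linear(16, 8)
)
x = torch.randn(8)
w = torch.tensor(0.3)                  # a router weight

on_input = expert(w * x)               # weight applied to the expert input
on_output = w * expert(x)              # weight applied to the expert output
print(torch.allclose(on_input, on_output))  # False: the order changes the result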

@@ -905,6 +905,7 @@ class Fp8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -975,6 +976,7 @@ class Fp8MoEMethod:
             topk_ids=topk_ids,
             inplace=inplace and not no_combine,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv

@@ -344,6 +344,7 @@ class MoeWNA16Method:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -374,6 +375,7 @@ class MoeWNA16Method:
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=inplace,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
             w1_scale=layer.w13_scales,

@@ -230,6 +230,7 @@ class W8A8Int8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -257,6 +258,7 @@ class W8A8Int8MoEMethod:
             topk_ids=topk_ids,
             inplace=inplace,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int8_w8a8=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
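
Downstream, a Llama4-style MoE layer opts in by passing the flag when it invokes the quantization method's apply; that wiring lives in other files of this PR. Purely as an illustrative, hypothetical call (only the keywords visible in the hunks above come from the diff; the remaining argument names are assumptions):

# Hypothetical caller sketch -- not taken from the diff.
output = layer.quant_method.apply(
    layer=layer,
    x=hidden_states,                      # assumed name for the token activations
    router_logits=router_logits,          # assumed name for the gating logits
    top_k=1,                              # Llama 4 routes each token to one expert
    renormalize=False,                    # assumed; router scores are not re-normalized
    activation="silu",
    apply_router_weight_on_input=True,    # router weight scales the expert input
    inplace=True,
)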