Add Llama4 support (#5092)
Co-authored-by: Cheng Wan <cwan39@gatech.edu> Co-authored-by: fzyzcjy <ch271828n@outlook.com> Co-authored-by: ispobock <ispobaoke@163.com>
This commit is contained in:
@@ -280,6 +280,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
|
||||
@@ -370,6 +370,7 @@ class BlockInt8MoEMethod:
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -398,6 +399,7 @@ class BlockInt8MoEMethod:
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=(layer.w13_weight_scale_inv),
|
||||
w2_scale=(layer.w2_weight_scale_inv),
|
||||
|
||||
@@ -905,6 +905,7 @@ class Fp8MoEMethod:
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -975,6 +976,7 @@ class Fp8MoEMethod:
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=(
|
||||
layer.w13_weight_scale_inv
|
||||
|
||||
@@ -344,6 +344,7 @@ class MoeWNA16Method:
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -374,6 +375,7 @@ class MoeWNA16Method:
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
w1_scale=layer.w13_scales,
|
||||
|
||||
@@ -230,6 +230,7 @@ class W8A8Int8MoEMethod:
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -257,6 +258,7 @@ class W8A8Int8MoEMethod:
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
|
||||
Reference in New Issue
Block a user