Tiny: make FP4 MoE method parameters more static (#8520)
This commit is contained in:
@@ -812,6 +812,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
# Call `swizzle_blockscale` here only to give the parameter its final shape; the real values are copied in later
|
||||
layer.w13_blockscale_swizzled = Parameter(
|
||||
self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
|
||||
)
|
||||
|
||||
w2_weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
layer.num_local_experts,
|
||||
@@ -826,6 +831,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
layer.w2_blockscale_swizzled = Parameter(
|
||||
self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
|
||||
)
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
|
||||
extra_weight_attrs.update(
|
||||
@@ -1129,16 +1138,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# Process w13 weights
|
||||
w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
|
||||
layer.w13_blockscale_swizzled = Parameter(
|
||||
w13_blockscale_swizzled, requires_grad=False
|
||||
)
|
||||
layer.w13_blockscale_swizzled.data.copy_(w13_blockscale_swizzled)
|
||||
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
|
||||
|
||||
# Process w2 weights
|
||||
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
|
||||
layer.w2_blockscale_swizzled = Parameter(
|
||||
w2_blockscale_swizzled, requires_grad=False
|
||||
)
|
||||
layer.w2_blockscale_swizzled.data.copy_(w2_blockscale_swizzled)
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||
|
||||
# Both flashinfer cutlass and regular cutlass use same processing for w2
|
||||
|
||||
Reference in New Issue
Block a user