Fix the global scale fix does not support EPLB and improve enabling condition (#10369)
@@ -504,14 +504,8 @@ class FusedMoE(torch.nn.Module):
             param.data[:, :dim1, :dim2].copy_(loaded_weight)
             return
 
-        # ModelOptNvFp4FusedMoEMethod uses max of global expert scaling factors for input scaling factor
-        load_global_experts = (
-            isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
-            and "input_scale" in weight_name
-        )
-
         global_expert_location_metadata = get_global_expert_location_metadata()
-        if global_expert_location_metadata is None or load_global_experts:
+        if global_expert_location_metadata is None:
             self._weight_loader_impl(
                 param=param,
                 loaded_weight=loaded_weight,
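For orientation, the following is a self-contained sketch of the dispatch that remains in FusedMoE.weight_loader after this hunk. Only the `is None` fast path mirrors the diff; the stub metadata lookup, the stub implementations, and the name of the EPLB-path helper are assumptions for illustration, not sglang code.

from typing import Optional

_EPLB_METADATA: Optional[dict] = None  # stand-in for the real expert-location metadata

def get_global_expert_location_metadata() -> Optional[dict]:
    return _EPLB_METADATA

class FusedMoESketch:
    def _weight_loader_impl(self, *, param, loaded_weight, weight_name, shard_id, expert_id):
        print(f"direct load of {weight_name} for expert {expert_id}")

    def _load_via_eplb_path(self, **kwargs):
        # Stand-in: the real code walks physical experts and applies the
        # per-parameter gate shown in the next hunk.
        print("EPLB path")

    def weight_loader(self, param, loaded_weight, weight_name, shard_id, expert_id) -> None:
        global_expert_location_metadata = get_global_expert_location_metadata()
        if global_expert_location_metadata is None:
            # No EPLB metadata: expert_id is used as-is.
            self._weight_loader_impl(
                param=param, loaded_weight=loaded_weight, weight_name=weight_name,
                shard_id=shard_id, expert_id=expert_id,
            )
            return
        self._load_via_eplb_path(
            param=param, loaded_weight=loaded_weight, weight_name=weight_name,
            shard_id=shard_id, expert_id=expert_id,
        )

FusedMoESketch().weight_loader(None, None, "w13_input_scale", "w1", expert_id=3)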
@@ -548,10 +542,12 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
-        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-        if expert_id == -1:
-            return
+        # WARN: This makes the `expert_id` mean "local" and "global" in different cases
+        if not getattr(param, "_sglang_require_global_experts", False):
+            expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+            if expert_id == -1:
+                return
 
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
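To make the two paths of the new gate concrete, here is a minimal, self-contained sketch. The toy expert layout (global experts 0-7, experts 4-7 local to this rank) and the map_global_to_local helper are assumptions standing in for self._map_global_expert_id_to_local_expert_id.

import torch

LOCAL_EXPERTS = {4: 0, 5: 1, 6: 2, 7: 3}  # global expert id -> local slot on this rank

def map_global_to_local(global_expert_id: int) -> int:
    # Returns -1 when this rank does not own the expert, like the real mapping helper.
    return LOCAL_EXPERTS.get(global_expert_id, -1)

def load_per_expert_value(param: torch.nn.Parameter, value: float, expert_id: int) -> None:
    # Mirrors the gate above: tagged params keep the global expert id (they hold a slot
    # for every expert); untagged params are remapped and skipped when the expert is remote.
    if not getattr(param, "_sglang_require_global_experts", False):
        expert_id = map_global_to_local(expert_id)
        if expert_id == -1:
            return
    param.data[expert_id] = value

global_scale = torch.nn.Parameter(torch.zeros(8), requires_grad=False)
global_scale._sglang_require_global_experts = True  # as done for w13/w2 input scales below
local_scale = torch.nn.Parameter(torch.zeros(4), requires_grad=False)

load_per_expert_value(global_scale, 0.5, expert_id=1)  # stored at global index 1
load_per_expert_value(local_scale, 0.5, expert_id=1)   # expert 1 not local -> no-op
load_per_expert_value(local_scale, 0.5, expert_id=5)   # stored at local slot 1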
@@ -999,12 +999,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
             data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def swizzle_blockscale(self, scale: torch.Tensor):
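The input-scale parameters are tagged because, as the comment removed in the first hunk notes, ModelOptNvFp4FusedMoEMethod reduces the per-expert input scales with a max over all global experts, so every rank must see every expert's value before reducing. A hedged sketch of that reduction follows; the exact call used after loading may differ, and the shapes only follow the parameter declarations above.

import torch

num_experts = 8  # global expert count
w13_input_scale = torch.nn.Parameter(
    torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False
)
w2_input_scale = torch.nn.Parameter(
    torch.empty(num_experts, dtype=torch.float32), requires_grad=False
)
# ... per-expert values are loaded via weight_loader with global expert ids ...
w13_input_scale.data.uniform_(0.1, 1.0)
w2_input_scale.data.uniform_(0.1, 1.0)

# Taking the max over *global* experts only yields the right single scale if every
# expert's value was actually loaded on this rank, which is what the
# _sglang_require_global_experts tag guarantees under EPLB.
w13_scale = w13_input_scale.max()
w2_scale = w2_input_scale.max()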