[1/3] Optimize Slime Update Weights: Remove QWen3MOE Load Weight Overhead (#8751)
This commit is contained in:
@@ -766,7 +766,10 @@ class Qwen3MoeForCausalLM(nn.Module):
|
|||||||
num_experts=self.config.num_experts,
|
num_experts=self.config.num_experts,
|
||||||
)
|
)
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
# Cache params_dict to avoid repeated expensive traversal of model parameters
|
||||||
|
if not hasattr(self, "_cached_params_dict"):
|
||||||
|
self._cached_params_dict = dict(self.named_parameters())
|
||||||
|
params_dict = self._cached_params_dict
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
layer_id = get_layer_id(name)
|
layer_id = get_layer_id(name)
|
||||||
if (
|
if (
|
||||||
@@ -805,11 +808,22 @@ class Qwen3MoeForCausalLM(nn.Module):
|
|||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
# Track if this is an expert weight to enable early skipping
|
||||||
|
is_expert_weight = False
|
||||||
|
|
||||||
for mapping in expert_params_mapping:
|
for mapping in expert_params_mapping:
|
||||||
param_name, weight_name, expert_id, shard_id = mapping
|
param_name, weight_name, expert_id, shard_id = mapping
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Mark as expert weight regardless of whether we can process it
|
||||||
|
is_expert_weight = True
|
||||||
|
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
if name not in params_dict:
|
||||||
|
# Expert weight not on this rank, will be skipped below
|
||||||
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(
|
weight_loader(
|
||||||
@@ -821,6 +835,10 @@ class Qwen3MoeForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
if is_expert_weight:
|
||||||
|
# This is an expert weight but not mapped to this rank, skip all remaining processing
|
||||||
|
continue
|
||||||
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
@@ -837,11 +855,13 @@ class Qwen3MoeForCausalLM(nn.Module):
|
|||||||
logger.warning(f"Parameter {name} not found in params_dict")
|
logger.warning(f"Parameter {name} not found in params_dict")
|
||||||
|
|
||||||
# TODO mimic deepseek
|
# TODO mimic deepseek
|
||||||
self.routed_experts_weights_of_layer = {
|
# Lazy initialization of expert weights cache to avoid slowing down load_weights
|
||||||
layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
|
if not hasattr(self, "routed_experts_weights_of_layer"):
|
||||||
for layer_id in range(self.start_layer, self.end_layer)
|
self.routed_experts_weights_of_layer = {
|
||||||
if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
|
layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
|
||||||
}
|
for layer_id in range(self.start_layer, self.end_layer)
|
||||||
|
if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_config_for_expert_location(cls, config):
|
def get_model_config_for_expert_location(cls, config):
|
||||||
|
|||||||
Reference in New Issue
Block a user