Overlapped weight offload (#8034)
This commit is contained in:
@@ -1996,6 +1996,23 @@ class DeepseekV2Model(nn.Module):
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
            offloader_kwargs=dict(
                submodule_accessor=lambda layer: (
                    layer.mlp.experts
                    if isinstance(layer.mlp, DeepseekV2MoE)
                    else layer.mlp
                ),
                whitelist_param_names_creator=lambda module: (
                    [
                        "w13_weight",
                        "w2_weight",
                        "w13_blockscale_swizzled",
                        "w2_blockscale_swizzled",
                    ]
                    if isinstance(module, FusedMoE)
                    else []
                ),
            ),
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Reference in New Issue
Block a user