Fix wrong weight reference in dynamic EPLB (#6818)

fzyzcjy
2025-06-03 14:26:04 +08:00
committed by GitHub
parent 27e327b415
commit 0ea330ca34
3 changed files with 27 additions and 13 deletions

python/sglang/srt/models/deepseek_v2.py

@@ -91,6 +91,7 @@ from sglang.srt.two_batch_overlap import (
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
+    LazyValue,
     add_prefix,
     bind_or_assign,
     get_bool_env_var,
@@ -1661,6 +1662,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_local_attention_dp_size()
+
+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, DeepseekV2MoE)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
 
     def determine_n_share_experts_fusion(
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
@@ -1873,14 +1886,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                 self_attn.use_deep_gemm_bmm = True
 
-        # TODO support nextn later
-        if not is_nextn:
-            self.routed_experts_weights_of_layer = {
-                layer_id: layer.mlp.get_moe_weights()
-                for layer_id, layer in enumerate(self.model.layers)
-                if isinstance(layer.mlp, DeepseekV2MoE)
-            }
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
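The change above is easiest to read outside diff form. The removed code built `routed_experts_weights_of_layer` eagerly inside `post_load_weights` (and only when `is_nextn` was false), so the dict captured whatever tensor objects the MoE layers held at that point; any later step that rebinds those attributes leaves the dynamic EPLB updating stale tensors. The new `LazyValue`-backed property defers the lookup until first access, after all rebinding is done. A minimal sketch of the hazard, using a hypothetical `Layer` stand-in rather than the sglang classes:

```python
import torch


class Layer:
    """Hypothetical stand-in for an MoE module exposing expert weights."""

    def __init__(self):
        self.w = torch.zeros(2)


layer = Layer()

# Eager snapshot, as the removed code did: the dict captures the tensor
# object bound to layer.w at this moment.
weights_eager = {0: layer.w}

# A later post-load step rebinds the attribute to a new tensor
# (dtype conversion, weight packing, a second load_weights pass, ...).
layer.w = layer.w.to(torch.bfloat16)

print(weights_eager[0] is layer.w)  # False: the snapshot went stale


# Lazy lookup, as the new property does: resolve on first use,
# after all rebinding has happened.
def weights_lazy():
    return {0: layer.w}


print(weights_lazy()[0] is layer.w)  # True
```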

python/sglang/srt/models/qwen3_moe.py

@@ -18,15 +18,10 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 import logging
-from dataclasses import dataclass
-from enum import Enum, auto
-from functools import partial
-
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 
 from sglang.srt.distributed import (
     get_pp_group,
@@ -811,6 +806,7 @@ class Qwen3MoeForCausalLM(nn.Module):
                 else:
                     logger.warning(f"Parameter {name} not found in params_dict")
 
+        # TODO mimic deepseek
         self.routed_experts_weights_of_layer = {
             layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
             for layer_id in range(self.start_layer, self.end_layer)
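The TODO above flags that Qwen3MoE still builds the dict eagerly. Purely as an illustration of what "mimic deepseek" might look like (this is not part of the commit), a toy model could reuse the `LazyValue` helper that this commit adds to `sglang.srt.utils` (next file); `ToyMoeLayer` and `ToyQwen3Moe` are hypothetical names:

```python
import torch

from sglang.srt.utils import LazyValue  # added by this commit


class ToyMoeLayer:
    """Hypothetical stand-in for self.model.layers[i].mlp."""

    def get_moe_weights(self):
        return [torch.zeros(4, 4)]


class ToyQwen3Moe:
    def __init__(self):
        self.layers = [ToyMoeLayer(), ToyMoeLayer()]
        # Defer collection until EPLB first asks for the weights.
        self._routed_experts_weights_of_layer = LazyValue(
            lambda: {
                i: layer.get_moe_weights() for i, layer in enumerate(self.layers)
            }
        )

    @property
    def routed_experts_weights_of_layer(self):
        return self._routed_experts_weights_of_layer.value


model = ToyQwen3Moe()
print(model.routed_experts_weights_of_layer.keys())  # dict_keys([0, 1])
```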

python/sglang/srt/utils.py

@@ -2257,3 +2257,16 @@ except:
 
 def cpu_has_amx_support():
     return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+
+
+class LazyValue:
+    def __init__(self, creator: Callable):
+        self._creator = creator
+        self._value = None
+
+    @property
+    def value(self):
+        if self._creator is not None:
+            self._value = self._creator()
+            self._creator = None
+        return self._value
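A quick usage sketch for the new helper: the creator runs exactly once, on first access to `.value`, and is then dropped so anything it closed over (the model's layers, in the deepseek case) is released; later accesses return the cached result:

```python
from sglang.srt.utils import LazyValue

calls = []
lv = LazyValue(lambda: calls.append("run") or 42)

# Nothing is computed at construction time.
assert calls == []

assert lv.value == 42  # first access invokes the creator
assert lv.value == 42  # cached; the creator is not called again
assert calls == ["run"]
```

Because the cache guard checks `_creator` rather than `_value`, even a creator that legitimately returns `None` is invoked only once.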