Fix OOM when updating expert locations (#6660)

fzyzcjy
2025-05-28 00:59:53 +08:00
committed by GitHub
parent 183d9f969c
commit 447be24228
2 changed files with 27 additions and 17 deletions
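The fix replaces the stateless `update_expert_location` helper with a stateful `ExpertLocationUpdater` that calls `torch.cuda.empty_cache()` once, right before the first update, so blocks cached by the CUDA allocator are released before the expert weights are rearranged. Below is a minimal, self-contained sketch of that pattern; the class name, the scratch-buffer body, and the demo weights are illustrative stand-ins, not the sglang implementation shown in the diff.

```python
# Minimal sketch (not sglang code): release cached allocator blocks once,
# before the first large temporary allocation made while moving expert weights.
import torch


class ExpertLocationUpdaterSketch:
    def __init__(self) -> None:
        self._first_execution = True

    def update(self, expert_weights: list) -> None:
        if self._first_execution:
            self._first_execution = False
            # Cached-but-unused blocks held by the caching allocator can leave
            # too little free memory for the temporary buffers created below.
            torch.cuda.empty_cache()
        # Stand-in for the real weight rearrangement, which needs scratch
        # buffers comparable in size to the expert weights being moved.
        for w in expert_weights:
            scratch = torch.empty_like(w)
            scratch.copy_(w)
            w.copy_(scratch)


if __name__ == "__main__":
    updater = ExpertLocationUpdaterSketch()
    if torch.cuda.is_available():
        weights = [torch.randn(16, 16, device="cuda") for _ in range(2)]
        updater.update(weights)
```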

View File

@@ -27,21 +27,30 @@ from sglang.srt.managers.expert_location import (
 logger = logging.getLogger(__name__)
 
 
-def update_expert_location(
-    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
-    new_expert_location_metadata: ExpertLocationMetadata,
-    nnodes: int,
-    rank: int,
-):
-    old_expert_location_metadata = get_global_expert_location_metadata()
-    _update_expert_weights(
-        routed_experts_weights_of_layer,
-        old_expert_location_metadata,
-        new_expert_location_metadata,
-        nnodes,
-        rank,
-    )
-    old_expert_location_metadata.update(new_expert_location_metadata)
+class ExpertLocationUpdater:
+    def __init__(self):
+        self._first_execution = True
+
+    def update(
+        self,
+        routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
+        new_expert_location_metadata: ExpertLocationMetadata,
+        nnodes: int,
+        rank: int,
+    ):
+        if self._first_execution:
+            self._first_execution = False
+            torch.cuda.empty_cache()
+
+        old_expert_location_metadata = get_global_expert_location_metadata()
+        _update_expert_weights(
+            routed_experts_weights_of_layer,
+            old_expert_location_metadata,
+            new_expert_location_metadata,
+            nnodes,
+            rank,
+        )
+        old_expert_location_metadata.update(new_expert_location_metadata)
 
 
 def _update_expert_weights(

View File

@@ -73,8 +73,8 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
-from sglang.srt.model_executor import expert_location_updater
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import (
@@ -267,6 +267,7 @@ class ModelRunner:
             if self.server_args.enable_eplb and (not self.is_draft_worker)
             else None
         )
+        self.expert_location_updater = ExpertLocationUpdater()
 
         # Load the model
         self.sampler = Sampler()
@@ -600,7 +601,7 @@ class ModelRunner:
     def update_expert_location(
         self, new_expert_location_metadata: ExpertLocationMetadata
     ):
-        expert_location_updater.update_expert_location(
+        self.expert_location_updater.update(
            self.model.routed_experts_weights_of_layer,
            new_expert_location_metadata,
            nnodes=self.server_args.nnodes,
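On the caller side, the second file wires the new class into ModelRunner: the updater is constructed once in the constructor, and each expert-location update is delegated to it. The condensed sketch below is based only on the hunks above; the real ModelRunner has many more fields, and since the last hunk ends before the `rank=` argument, the value passed there is an assumption taken from the updater's signature.

```python
# Caller-side sketch, not the full ModelRunner; mirrors the changes above.
from sglang.srt.managers.expert_location import ExpertLocationMetadata
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater


class RunnerSketch:
    def __init__(self, model, nnodes: int, rank: int):
        self.model = model
        self.nnodes = nnodes
        self.rank = rank
        # Constructed once, so the updater can remember whether the CUDA
        # cache has already been emptied.
        self.expert_location_updater = ExpertLocationUpdater()

    def update_expert_location(
        self, new_expert_location_metadata: ExpertLocationMetadata
    ):
        self.expert_location_updater.update(
            self.model.routed_experts_weights_of_layer,
            new_expert_location_metadata,
            nnodes=self.nnodes,
            # The final hunk above ends at nnodes; passing rank here is an
            # assumption based on ExpertLocationUpdater.update's signature.
            rank=self.rank,
        )
```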