From 447be242287dbe6ff8c21b68d611a46e3f13ea8f Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Wed, 28 May 2025 00:59:53 +0800 Subject: [PATCH] Fix OOM when updating expert locations (#6660) --- .../model_executor/expert_location_updater.py | 39 ++++++++++++------- .../sglang/srt/model_executor/model_runner.py | 5 ++- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 13c4adc8d..8023c029e 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -27,21 +27,30 @@ from sglang.srt.managers.expert_location import ( logger = logging.getLogger(__name__) -def update_expert_location( - routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]], - new_expert_location_metadata: ExpertLocationMetadata, - nnodes: int, - rank: int, -): - old_expert_location_metadata = get_global_expert_location_metadata() - _update_expert_weights( - routed_experts_weights_of_layer, - old_expert_location_metadata, - new_expert_location_metadata, - nnodes, - rank, - ) - old_expert_location_metadata.update(new_expert_location_metadata) +class ExpertLocationUpdater: + def __init__(self): + self._first_execution = True + + def update( + self, + routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]], + new_expert_location_metadata: ExpertLocationMetadata, + nnodes: int, + rank: int, + ): + if self._first_execution: + self._first_execution = False + torch.cuda.empty_cache() + + old_expert_location_metadata = get_global_expert_location_metadata() + _update_expert_weights( + routed_experts_weights_of_layer, + old_expert_location_metadata, + new_expert_location_metadata, + nnodes, + rank, + ) + old_expert_location_metadata.update(new_expert_location_metadata) def _update_expert_weights( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3fad97bd6..f89e8629f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -73,8 +73,8 @@ from sglang.srt.mem_cache.memory_pool import ( TokenToKVPoolAllocator, ) from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator -from sglang.srt.model_executor import expert_location_updater from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner +from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import ( @@ -267,6 +267,7 @@ class ModelRunner: if self.server_args.enable_eplb and (not self.is_draft_worker) else None ) + self.expert_location_updater = ExpertLocationUpdater() # Load the model self.sampler = Sampler() @@ -600,7 +601,7 @@ class ModelRunner: def update_expert_location( self, new_expert_location_metadata: ExpertLocationMetadata ): - expert_location_updater.update_expert_location( + self.expert_location_updater.update( self.model.routed_experts_weights_of_layer, new_expert_location_metadata, nnodes=self.server_args.nnodes,