Fix OOM when updating expert locations (#6660)
This commit is contained in:
@@ -27,21 +27,30 @@ from sglang.srt.managers.expert_location import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def update_expert_location(
    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
    new_expert_location_metadata: ExpertLocationMetadata,
    nnodes: int,
    rank: int,
):
    """Relocate routed-expert weights to match *new_expert_location_metadata*.

    Args:
        routed_experts_weights_of_layer: per-layer lists of the routed-expert
            weight tensors to be moved (keyed by layer index).
        new_expert_location_metadata: target expert placement to migrate to.
        nnodes: number of nodes participating in the relocation.
        rank: this process's rank.
    """
    current_metadata = get_global_expert_location_metadata()
    # Move the weight tensors first; while the transfer is in flight the
    # global metadata must still describe the OLD layout.
    _update_expert_weights(
        routed_experts_weights_of_layer,
        current_metadata,
        new_expert_location_metadata,
        nnodes,
        rank,
    )
    # Only after the weights have landed do we commit the new layout.
    current_metadata.update(new_expert_location_metadata)
|
||||
class ExpertLocationUpdater:
    """Stateful driver for relocating routed-expert weights.

    Carries a one-shot flag so that cached CUDA allocator blocks are released
    right before the very first relocation — per the commit this belongs to,
    this lowers peak memory during the weight exchange (OOM mitigation).
    """

    def __init__(self):
        # True until update() has run once; gates the one-time cache flush.
        self._first_execution = True

    def update(
        self,
        routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
        new_expert_location_metadata: ExpertLocationMetadata,
        nnodes: int,
        rank: int,
    ):
        """Relocate expert weights to *new_expert_location_metadata*.

        Args:
            routed_experts_weights_of_layer: per-layer lists of the routed-expert
                weight tensors to be moved (keyed by layer index).
            new_expert_location_metadata: target expert placement to migrate to.
            nnodes: number of nodes participating in the relocation.
            rank: this process's rank.
        """
        if self._first_execution:
            self._first_execution = False
            # Free cached allocator blocks once, before the first transfer,
            # so the relocation below has headroom. Doing this on every call
            # would needlessly hurt steady-state performance.
            torch.cuda.empty_cache()

        current_metadata = get_global_expert_location_metadata()
        # Move the tensors while the global metadata still reflects the old
        # layout; _update_expert_weights needs both old and new placements.
        _update_expert_weights(
            routed_experts_weights_of_layer,
            current_metadata,
            new_expert_location_metadata,
            nnodes,
            rank,
        )
        # Commit the new layout only after the weights have been moved.
        current_metadata.update(new_expert_location_metadata)
|
||||
|
||||
|
||||
def _update_expert_weights(
|
||||
|
||||
@@ -73,8 +73,8 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor import expert_location_updater
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.model_loader.loader import (
|
||||
@@ -267,6 +267,7 @@ class ModelRunner:
|
||||
if self.server_args.enable_eplb and (not self.is_draft_worker)
|
||||
else None
|
||||
)
|
||||
self.expert_location_updater = ExpertLocationUpdater()
|
||||
|
||||
# Load the model
|
||||
self.sampler = Sampler()
|
||||
@@ -600,7 +601,7 @@ class ModelRunner:
|
||||
def update_expert_location(
|
||||
self, new_expert_location_metadata: ExpertLocationMetadata
|
||||
):
|
||||
expert_location_updater.update_expert_location(
|
||||
self.expert_location_updater.update(
|
||||
self.model.routed_experts_weights_of_layer,
|
||||
new_expert_location_metadata,
|
||||
nnodes=self.server_args.nnodes,
|
||||
|
||||
Reference in New Issue
Block a user