Fix OOM when updating expert locations (#6660)
This commit is contained in:
@@ -27,21 +27,30 @@ from sglang.srt.managers.expert_location import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def update_expert_location(
|
class ExpertLocationUpdater:
|
||||||
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
def __init__(self):
|
||||||
new_expert_location_metadata: ExpertLocationMetadata,
|
self._first_execution = True
|
||||||
nnodes: int,
|
|
||||||
rank: int,
|
def update(
|
||||||
):
|
self,
|
||||||
old_expert_location_metadata = get_global_expert_location_metadata()
|
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
||||||
_update_expert_weights(
|
new_expert_location_metadata: ExpertLocationMetadata,
|
||||||
routed_experts_weights_of_layer,
|
nnodes: int,
|
||||||
old_expert_location_metadata,
|
rank: int,
|
||||||
new_expert_location_metadata,
|
):
|
||||||
nnodes,
|
if self._first_execution:
|
||||||
rank,
|
self._first_execution = False
|
||||||
)
|
torch.cuda.empty_cache()
|
||||||
old_expert_location_metadata.update(new_expert_location_metadata)
|
|
||||||
|
old_expert_location_metadata = get_global_expert_location_metadata()
|
||||||
|
_update_expert_weights(
|
||||||
|
routed_experts_weights_of_layer,
|
||||||
|
old_expert_location_metadata,
|
||||||
|
new_expert_location_metadata,
|
||||||
|
nnodes,
|
||||||
|
rank,
|
||||||
|
)
|
||||||
|
old_expert_location_metadata.update(new_expert_location_metadata)
|
||||||
|
|
||||||
|
|
||||||
def _update_expert_weights(
|
def _update_expert_weights(
|
||||||
|
|||||||
@@ -73,8 +73,8 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
TokenToKVPoolAllocator,
|
TokenToKVPoolAllocator,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor import expert_location_updater
|
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
|
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
from sglang.srt.model_loader.loader import (
|
from sglang.srt.model_loader.loader import (
|
||||||
@@ -267,6 +267,7 @@ class ModelRunner:
|
|||||||
if self.server_args.enable_eplb and (not self.is_draft_worker)
|
if self.server_args.enable_eplb and (not self.is_draft_worker)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
self.expert_location_updater = ExpertLocationUpdater()
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
@@ -600,7 +601,7 @@ class ModelRunner:
|
|||||||
def update_expert_location(
|
def update_expert_location(
|
||||||
self, new_expert_location_metadata: ExpertLocationMetadata
|
self, new_expert_location_metadata: ExpertLocationMetadata
|
||||||
):
|
):
|
||||||
expert_location_updater.update_expert_location(
|
self.expert_location_updater.update(
|
||||||
self.model.routed_experts_weights_of_layer,
|
self.model.routed_experts_weights_of_layer,
|
||||||
new_expert_location_metadata,
|
new_expert_location_metadata,
|
||||||
nnodes=self.server_args.nnodes,
|
nnodes=self.server_args.nnodes,
|
||||||
|
|||||||
Reference in New Issue
Block a user