[DP Attention] Refactor: adding some utility functions (#9136)

This commit is contained in:
Cheng Wan
2025-08-13 21:08:06 -07:00
committed by GitHub
parent b3363cc1aa
commit b87aacb5c5
21 changed files with 216 additions and 159 deletions

View File

@@ -1,10 +1,17 @@
from __future__ import annotations
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generator, List, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
import torch
from sglang.srt.layers.dp_attention import set_dp_buffer_len
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
_ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
if _ENABLE_PROFILE:
@@ -66,18 +73,26 @@ Stage = List[ExecutionOperation]
class _StageExecutor:
def __init__(self, debug_name: str, stages: List[Stage], inputs):
def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
self._debug_name = debug_name
self._stages = stages
self._index = 0
self._stage_state = _StateDict()
self._stage_output = inputs
# handling DP attention
forward_batch: ForwardBatch = inputs["forward_batch"]
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
def next(self):
assert not self.done
stage = self._stages[self._index]
if self._global_dp_buffer_len is not None:
set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
for op in stage:
with _annotate_region(debug_name=op.debug_name):