[DP Attention] Refactor: adding some utility functions (#9136)
This commit is contained in:
@@ -1,10 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Generator, List, Sequence, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.dp_attention import set_dp_buffer_len
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
_ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
|
||||
|
||||
if _ENABLE_PROFILE:
|
||||
@@ -66,18 +73,26 @@ Stage = List[ExecutionOperation]
|
||||
|
||||
|
||||
class _StageExecutor:
|
||||
def __init__(self, debug_name: str, stages: List[Stage], inputs):
|
||||
def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
|
||||
self._debug_name = debug_name
|
||||
self._stages = stages
|
||||
self._index = 0
|
||||
self._stage_state = _StateDict()
|
||||
self._stage_output = inputs
|
||||
|
||||
# handling DP attention
|
||||
forward_batch: ForwardBatch = inputs["forward_batch"]
|
||||
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
|
||||
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
|
||||
|
||||
def next(self):
|
||||
assert not self.done
|
||||
|
||||
stage = self._stages[self._index]
|
||||
|
||||
if self._global_dp_buffer_len is not None:
|
||||
set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
|
||||
|
||||
with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
|
||||
for op in stage:
|
||||
with _annotate_region(debug_name=op.debug_name):
|
||||
|
||||
Reference in New Issue
Block a user