# File: sglang/python/sglang/srt/operations.py
from __future__ import annotations
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
import torch
from sglang.srt.layers.dp_attention import set_dp_buffer_len
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
# Opt-in profiling: when SGLANG_OPERATIONS_ENABLE_PROFILE=1, each stage and
# each operation is wrapped in a torch profiler record_function range and an
# NVTX range (see _annotate_region below).
_ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
if _ENABLE_PROFILE:
    # nvtx is imported lazily so it is only required when profiling is enabled.
    import nvtx
def execute_operations(inputs, operations):
    """Convert *operations* into stages and run them all on *inputs*.

    Returns the output dict produced by the final stage.
    """
    stages = _convert_operations_to_stages(operations)
    runner = _StageExecutor("primary", stages, inputs=inputs)
    while not runner.done:
        runner.next()
    return runner.output
def execute_overlapped_operations(
    inputs_arr: Sequence,
    operations_arr: Sequence,
    delta_stages: Sequence[int],
) -> Sequence:
    """Run two operation lists overlapped, with batch "b" lagging batch "a".

    ``delta_stages`` must be ``[0, d]``: batch a leads, batch b trails by
    ``d`` stages. Returns ``[output_a, output_b]``.
    """
    # Kept explicitly two-batch for clarity; generalize if multi-batch
    # overlap is ever needed.
    inputs_a, inputs_b = inputs_arr
    ops_a, ops_b = operations_arr
    lead_a, lag_b = delta_stages
    assert lead_a == 0
    exec_a = _StageExecutor("a", _convert_operations_to_stages(ops_a), inputs=inputs_a)
    exec_b = _StageExecutor("b", _convert_operations_to_stages(ops_b), inputs=inputs_b)
    # Warm-up: advance batch a alone so it stays `lag_b` stages ahead.
    for _ in range(lag_b):
        exec_a.next()
    # Steady state: both batches advance in lockstep until a finishes.
    for _ in range(exec_a.num_stages - lag_b):
        exec_a.next()
        exec_b.next()
    # Drain: finish the remaining stages of batch b.
    for _ in range(lag_b):
        exec_b.next()
    assert exec_a.done and exec_b.done
    return [exec_a.output, exec_b.output]
class YieldOperation:
    """Sentinel placed in an operation list to mark a stage boundary."""
@dataclass
class ExecutionOperation:
    """One callable to run inside a stage, tagged with a profiling label."""

    debug_name: str  # label used for the profiler/NVTX annotation region
    fn: Callable  # invoked with state=... plus the previous stage's outputs
# Anything accepted in an operation list: a stage separator, a pre-wrapped
# ExecutionOperation, or a bare callable (wrapped by _decorate_operation).
Operation = Union[YieldOperation, ExecutionOperation, Callable]
# One stage = the run of executable operations between two separators.
Stage = List[ExecutionOperation]
class _StageExecutor:
    """Steps through a list of stages, feeding each stage's output dict into
    the next stage's operations as keyword arguments."""

    def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
        self._debug_name = debug_name
        self._stages = stages
        self._index = 0
        self._stage_state = _StateDict()
        self._stage_output = inputs
        # DP-attention bookkeeping: capture the buffer lengths once so they
        # can be re-applied before every stage via set_dp_buffer_len.
        forward_batch: ForwardBatch = inputs["forward_batch"]
        self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
        self._local_dp_buffer_len = forward_batch.input_ids.shape[0]

    def next(self):
        """Execute every operation of the current stage, then advance."""
        assert not self.done
        current_stage = self._stages[self._index]
        if self._global_dp_buffer_len is not None:
            set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
        with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
            for operation in current_stage:
                with _annotate_region(debug_name=operation.debug_name):
                    previous = self._stage_output
                    self._stage_output = operation.fn(
                        state=self._stage_state,
                        **({} if previous is None else previous),
                    )
        self._index += 1

    @property
    def output(self):
        # Only valid after the final stage has run.
        assert self.done
        return self._stage_output

    @property
    def done(self):
        return self._index >= self.num_stages

    @property
    def num_stages(self):
        return len(self._stages)
@contextmanager
def _annotate_region(debug_name):
    """Wrap the enclosed work in profiler + NVTX ranges when profiling is on;
    otherwise a no-op."""
    if not _ENABLE_PROFILE:
        yield
        return
    with torch.autograd.profiler.record_function(debug_name):
        with nvtx.annotate(debug_name):
            yield
class _StateDict:
def __init__(self):
self._data = {}
def __setattr__(self, key, value):
if key == "_data":
super().__setattr__(key, value)
return
assert (
key not in self._data
), f"`{key}` already exist, are you sure you want to override it?"
self._data[key] = value
def __getattr__(self, item):
return self._data[item]
def __delattr__(self, item):
del self._data[item]
def pop(self, item):
return self._data.pop(item)
def update(self, values: Dict[str, Any]):
for k, v in values.items():
setattr(self, k, v)
def get(self, item):
return self._data.get(item)
def clear(self, expect_keys: Sequence[str]):
if set(self._data.keys()) != set(expect_keys):
raise Exception(
f"Unexpected keys when clearning. This may indicate you do not release memory early enough but leave it to here. {list(self._data.keys())=} {expect_keys=}"
)
self._data.clear()
def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
    """Wrap raw callables and split the list into stages at YieldOperations."""
    decorated = _decorate_operations(operations)
    stages = list(
        _chunk_by_separator(decorated, lambda op: isinstance(op, YieldOperation))
    )
    assert all(stages)  # no stage may be empty
    return stages
def _chunk_by_separator(
items: List[Any], is_separator: Callable[[Any], bool]
) -> Generator[List[Any], None, None]:
pending_items = []
for item in items:
if is_separator(item):
yield pending_items
pending_items = []
else:
pending_items.append(item)
if len(pending_items) > 0:
yield pending_items
def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
    """Decorate every entry of *operations* via `_decorate_operation`."""
    return [_decorate_operation(entry, debug_name_prefix) for entry in operations]
def _decorate_operation(operation: Operation, debug_name_prefix: str):
    """Wrap a bare callable in an ExecutionOperation with a readable label.

    YieldOperation sentinels pass through unchanged. The conventional `op_`
    prefix is stripped from the callable's name when building the label.
    """
    if isinstance(operation, YieldOperation):
        return operation
    name = getattr(operation, "__name__", "unknown")
    # Strip only a leading "op_". The previous `.replace("op_", "")` also
    # mangled occurrences inside the name (e.g. "loop_op_x" -> "lox").
    if name.startswith("op_"):
        name = name[len("op_"):]
    return ExecutionOperation(debug_name=debug_name_prefix + name, fn=operation)