2025-08-13 21:08:06 -07:00
from __future__ import annotations
2025-05-20 12:05:30 +08:00
import os
from contextlib import contextmanager
from dataclasses import dataclass
2025-08-13 21:08:06 -07:00
from typing import TYPE_CHECKING , Any , Callable , Dict , Generator , List , Sequence , Union
2025-05-20 12:05:30 +08:00
import torch
2025-08-13 21:08:06 -07:00
from sglang . srt . layers . dp_attention import set_dp_buffer_len
if TYPE_CHECKING :
from sglang . srt . model_executor . forward_batch_info import ForwardBatch
2025-05-20 12:05:30 +08:00
_ENABLE_PROFILE = bool ( int ( os . environ . get ( " SGLANG_OPERATIONS_ENABLE_PROFILE " , " 0 " ) ) )
if _ENABLE_PROFILE :
import nvtx
def execute_operations(inputs, operations):
    """Run *operations* as sequential stages over *inputs*, returning the final output."""
    stages = _convert_operations_to_stages(operations)
    executor = _StageExecutor("primary", stages, inputs=inputs)
    # Advance one stage at a time until every stage has executed.
    while not executor.done:
        executor.next()
    return executor.output
2025-05-25 08:39:07 +08:00
def execute_overlapped_operations(
    inputs_arr: Sequence,
    operations_arr: Sequence,
    delta_stages: Sequence[int],
) -> Sequence:
    """Run two operation lists with executor "b" lagging "a" by a fixed stage count.

    Exactly two batches are handled for clarity; if multi-batch overlap is ever
    needed this can be generalized. ``delta_stages[0]`` must be 0, and
    ``delta_stages[1]`` is the number of stages "a" stays ahead of "b".
    Returns ``[output_a, output_b]``.
    """
    inputs_a, inputs_b = inputs_arr
    operations_a, operations_b = operations_arr
    delta_stage_a, delta_stage_b = delta_stages
    assert delta_stage_a == 0
    lag = delta_stage_b

    executor_a = _StageExecutor("a", _convert_operations_to_stages(operations_a), inputs=inputs_a)
    executor_b = _StageExecutor("b", _convert_operations_to_stages(operations_b), inputs=inputs_b)

    # Warm-up: let "a" run alone so it is `lag` stages ahead.
    for _ in range(lag):
        executor_a.next()
    # Overlapped phase: advance both executors in lockstep until "a" finishes.
    while not executor_a.done:
        executor_a.next()
        executor_b.next()
    # Drain: "b" still owes its final `lag` stages.
    for _ in range(lag):
        executor_b.next()

    assert executor_a.done and executor_b.done
    return [executor_a.output, executor_b.output]
2025-05-20 12:05:30 +08:00
class YieldOperation:
    """Marker placed in an operation list to separate consecutive stages.

    ``_convert_operations_to_stages`` splits the operation list at each
    instance of this class; the marker itself is dropped from the stages.
    """

    pass
@dataclass
class ExecutionOperation:
    """A single executable step inside a stage."""

    # Label used for torch-profiler / NVTX annotation of this operation.
    debug_name: str
    # The callable to run; invoked as ``fn(state=..., **previous_stage_output)``.
    fn: Callable
# An element of a user-provided operation list: a stage separator
# (YieldOperation), an already-wrapped ExecutionOperation, or a bare callable
# that _decorate_operation will wrap.
Operation = Union[YieldOperation, ExecutionOperation, Callable]
# One stage: the operations executed back-to-back between two yields.
Stage = List[ExecutionOperation]
class _StageExecutor:
    """Steps through a list of stages, feeding each stage's output into the next."""

    def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
        self._debug_name = debug_name
        self._stages = stages
        self._index = 0
        self._stage_state = _StateDict()
        self._stage_output = inputs

        # DP-attention bookkeeping, captured once from the forward batch and
        # re-applied before every stage (another executor may have changed the
        # process-global buffer length in between).
        forward_batch: ForwardBatch = inputs["forward_batch"]
        self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
        self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
        self._global_num_tokens = forward_batch.global_num_tokens_cpu

    def next(self):
        """Execute the current stage and advance to the next one."""
        assert not self.done
        current_stage = self._stages[self._index]

        if self._global_dp_buffer_len is not None:
            set_dp_buffer_len(
                self._global_dp_buffer_len,
                self._local_dp_buffer_len,
                self._global_num_tokens,
            )

        with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
            for op in current_stage:
                with _annotate_region(debug_name=op.debug_name):
                    # The previous output (a dict) becomes the kwargs of the
                    # next operation; None means "no kwargs".
                    kwargs = self._stage_output if self._stage_output is not None else {}
                    self._stage_output = op.fn(state=self._stage_state, **kwargs)

        self._index += 1

    @property
    def output(self):
        # Only valid once every stage has run.
        assert self.done
        return self._stage_output

    @property
    def done(self):
        return self._index >= self.num_stages

    @property
    def num_stages(self):
        return len(self._stages)
@contextmanager
def _annotate_region(debug_name):
    """Wrap the enclosed work in torch-profiler + NVTX ranges when profiling is on.

    A no-op context manager when ``_ENABLE_PROFILE`` is false.
    """
    if not _ENABLE_PROFILE:
        yield
        return
    with torch.autograd.profiler.record_function(debug_name):
        with nvtx.annotate(debug_name):
            yield
class _StateDict :
def __init__ ( self ) :
self . _data = { }
def __setattr__ ( self , key , value ) :
if key == " _data " :
super ( ) . __setattr__ ( key , value )
return
assert (
key not in self . _data
) , f " ` { key } ` already exist, are you sure you want to override it? "
self . _data [ key ] = value
def __getattr__ ( self , item ) :
return self . _data [ item ]
def __delattr__ ( self , item ) :
del self . _data [ item ]
def pop ( self , item ) :
return self . _data . pop ( item )
def update ( self , values : Dict [ str , Any ] ) :
for k , v in values . items ( ) :
setattr ( self , k , v )
2025-05-25 08:39:07 +08:00
def get ( self , item ) :
return self . _data . get ( item )
2025-05-20 12:05:30 +08:00
def clear ( self , expect_keys : Sequence [ str ] ) :
if set ( self . _data . keys ( ) ) != set ( expect_keys ) :
raise Exception (
f " Unexpected keys when clearning. This may indicate you do not release memory early enough but leave it to here. { list ( self . _data . keys ( ) ) =} { expect_keys =} "
)
self . _data . clear ( )
def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
    """Wrap bare callables and split the operation list into stages at yields."""
    decorated = _decorate_operations(operations)
    stages = list(
        _chunk_by_separator(decorated, lambda op: isinstance(op, YieldOperation))
    )
    # A YieldOperation at the start or two in a row would produce an empty stage.
    assert all(len(stage) > 0 for stage in stages)
    return stages
def _chunk_by_separator (
items : List [ Any ] , is_separator : Callable [ [ Any ] , bool ]
) - > Generator [ List [ Any ] , None , None ] :
pending_items = [ ]
for item in items :
if is_separator ( item ) :
yield pending_items
pending_items = [ ]
else :
pending_items . append ( item )
if len ( pending_items ) > 0 :
yield pending_items
2025-05-25 08:39:07 +08:00
def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
    """Decorate every operation in the list (see ``_decorate_operation``)."""
    return [_decorate_operation(operation, debug_name_prefix) for operation in operations]
def _decorate_operation(operation: Operation, debug_name_prefix: str):
    """Wrap a bare callable in an ExecutionOperation with a derived debug name.

    YieldOperation markers pass through unchanged. The debug name is the
    callable's ``__name__`` with a leading ``op_`` prefix stripped.

    Fix: the previous ``.replace("op_", "")`` removed *every* occurrence of
    ``op_`` in the name (e.g. ``op_fused_op_x`` -> ``fusedx``); only the
    leading prefix is intended to be removed.
    """
    if isinstance(operation, YieldOperation):
        return operation
    name = getattr(operation, "__name__", "unknown")
    if name.startswith("op_"):
        name = name[len("op_"):]
    return ExecutionOperation(
        debug_name=debug_name_prefix + name,
        fn=operation,
    )