[DP Attention] Refactor: adding some utility functions (#9136)
@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_global_dp_buffer,
+    get_local_dp_buffer,
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
         context: CommunicateContext,
     ) -> torch.Tensor:
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-                torch.empty_like(
-                    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
-                ),
+                get_local_dp_buffer(),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -424,7 +424,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             residual = hidden_states
             hidden_states = layernorm(hidden_states)
             hidden_states, local_hidden_states = (
-                torch.empty_like(forward_batch.gathered_buffer),
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -548,7 +548,7 @@ class CommunicateSummableTensorPairFn:
         allow_reduce_scatter: bool = False,
     ):
         hidden_states, global_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
         )
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
@@ -569,7 +569,7 @@ class CommunicateSummableTensorPairFn:
             hidden_states += residual
             residual = None
             hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
                 hidden_states,
             )
             attn_tp_all_gather_into_tensor(
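
Note: the pattern in the hunks above is uniform. Call sites stop slicing
`forward_batch.gathered_buffer` (a tensor preallocated on the batch) and
instead ask `get_local_dp_buffer()` / `get_global_dp_buffer()` for a freshly
allocated destination of the right shape. A standalone sketch of the
equivalence, with hypothetical sizes:

    import torch

    hidden_size, local_tokens, dp_size = 4096, 8, 4
    dtype, device = torch.bfloat16, "cpu"  # sglang uses the model dtype on CUDA

    # Old style: slice a view out of a buffer sized for the whole DP group.
    gathered_buffer = torch.zeros(
        (local_tokens * dp_size, hidden_size), dtype=dtype, device=device
    )
    dst_old = gathered_buffer[:local_tokens]

    # New style: allocate exactly the local shape on demand.
    dst_new = torch.empty((local_tokens, hidden_size), dtype=dtype, device=device)
    assert dst_old.shape == dst_new.shape
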
@@ -4,7 +4,7 @@ import functools
 import logging
 from contextlib import contextmanager
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 import triton
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.server_args import ServerArgs
+
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
-_ATTN_TP_GROUP = None
-_ATTN_TP_RANK = None
-_ATTN_TP_SIZE = None
-_ATTN_DP_RANK = None
-_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_RANK = None
+_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
+_ATTN_TP_RANK: Optional[int] = None
+_ATTN_TP_SIZE: Optional[int] = None
+_ATTN_DP_RANK: Optional[int] = None
+_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_RANK: Optional[int] = None
+_ENABLE_DP_ATTENTION_FLAG: bool = False
 
 
-class DPPaddingMode(IntEnum):
+class DpPaddingMode(IntEnum):
 
     # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
     MAX_LEN = auto()
@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
     SUM_LEN = auto()
 
     def is_max_len(self):
-        return self == DPPaddingMode.MAX_LEN
+        return self == DpPaddingMode.MAX_LEN
 
     def is_sum_len(self):
-        return self == DPPaddingMode.SUM_LEN
+        return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
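
Note: the comparison that picks between the two modes is outside this hunk,
but the quantities it weighs are visible here. A toy calculation with
hypothetical per-rank token counts:

    # MAX_LEN pads every rank to the longest local batch so a fixed-shape
    # all_gather_into_tensor can be used; SUM_LEN keeps the exact lengths.
    global_num_tokens = [6, 9, 4, 7]   # hypothetical tokens per DP rank
    max_len = max(global_num_tokens)   # 9
    sum_len = sum(global_num_tokens)   # 26
    rows_if_max_len = max_len * len(global_num_tokens)  # 36 gathered rows
    rows_if_sum_len = sum_len                           # 26 gathered rows
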
@@ -56,10 +61,76 @@ class DPPaddingMode(IntEnum):
         return cls.SUM_LEN
 
     @classmethod
-    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+    def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
         return cls.MAX_LEN
 
 
+class _DpGatheredBufferWrapper:
+
+    _hidden_size: int
+    _dtype: torch.dtype
+    _device: torch.device
+    _global_dp_buffer_len: int
+    _local_dp_buffer_len: int
+
+    @classmethod
+    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
+        cls._hidden_size = hidden_size
+        cls._dtype = dtype
+        cls._device = device
+
+    @classmethod
+    def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+        cls._global_dp_buffer_len = global_dp_buffer_len
+        cls._local_dp_buffer_len = local_dp_buffer_len
+
+    @classmethod
+    def get_global_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._global_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_local_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._local_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_global_dp_buffer_len(cls) -> int:
+        return cls._global_dp_buffer_len
+
+    @classmethod
+    def get_local_dp_buffer_len(cls) -> int:
+        return cls._local_dp_buffer_len
+
+
+def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+    _DpGatheredBufferWrapper.set_dp_buffer_len(
+        global_dp_buffer_len, local_dp_buffer_len
+    )
+
+
+def get_global_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer()
+
+
+def get_local_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer()
+
+
+def get_global_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
+
+
+def get_local_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
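
Note: `_DpGatheredBufferWrapper` replaces the per-batch `gathered_buffer`
tensor with process-wide metadata plus on-demand allocation. A usage sketch
built only from names introduced in this hunk (sizes hypothetical;
`initialize_dp_attention` normally performs the `set_metadata` call):

    import torch

    from sglang.srt.layers.dp_attention import (
        _DpGatheredBufferWrapper,
        get_global_dp_buffer,
        get_local_dp_buffer,
        set_dp_buffer_len,
    )

    _DpGatheredBufferWrapper.set_metadata(
        hidden_size=4096, dtype=torch.bfloat16, device=torch.device("cpu")
    )
    set_dp_buffer_len(global_dp_buffer_len=32, local_dp_buffer_len=8)

    assert get_global_dp_buffer().shape == (32, 4096)
    assert get_local_dp_buffer().shape == (8, 4096)
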
@@ -89,18 +160,24 @@ def compute_dp_attention_local_info(
 
 
 def initialize_dp_attention(
-    enable_dp_attention: bool,
-    tp_rank: int,
-    tp_size: int,
-    dp_size: int,
-    moe_dense_tp_size: int,
-    pp_size: int,
+    server_args: ServerArgs,
+    model_config: ModelConfig,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
-    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
+    enable_dp_attention = server_args.enable_dp_attention
+    tp_size = server_args.tp_size
+    dp_size = server_args.dp_size
+    moe_dense_tp_size = server_args.moe_dense_tp_size
+    pp_size = server_args.pp_size
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
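
Note: the refactored signature collapses six scalar arguments into two
objects; the scalar values are now derived inside the function from
`server_args`, and `tp_rank` comes from `get_tensor_model_parallel_rank()`
rather than from the caller. The updated call site, as it appears verbatim
in the `ModelRunner` hunk further below:

    initialize_dp_attention(
        server_args=self.server_args,
        model_config=self.model_config,
    )
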
@@ -135,38 +212,48 @@ def initialize_dp_attention(
             group_name="attention_tp",
         )
 
+    _DpGatheredBufferWrapper.set_metadata(
+        hidden_size=model_config.hidden_size,
+        dtype=model_config.dtype,
+        device=torch.device("cuda"),
+    )
+
+
+def is_dp_attention_enabled() -> bool:
+    return _ENABLE_DP_ATTENTION_FLAG
+
 
-def get_attention_tp_group():
+def get_attention_tp_group() -> GroupCoordinator:
     assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
     return _ATTN_TP_GROUP
 
 
-def get_attention_tp_rank():
+def get_attention_tp_rank() -> int:
     assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
     return _ATTN_TP_RANK
 
 
-def get_attention_tp_size():
+def get_attention_tp_size() -> int:
     assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_TP_SIZE
 
 
-def get_attention_dp_rank():
+def get_attention_dp_rank() -> int:
     assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _ATTN_DP_RANK
 
 
-def get_attention_dp_size():
+def get_attention_dp_size() -> int:
     assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_DP_SIZE
 
 
-def get_local_attention_dp_rank():
+def get_local_attention_dp_rank() -> int:
     assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_RANK
 
 
-def get_local_attention_dp_size():
+def get_local_attention_dp_size() -> int:
     assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_SIZE
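
Note: all getters are now typed and assert on use-before-init instead of
silently returning `None`, and `is_dp_attention_enabled()` removes the need
to consult `global_server_args_dict`. A minimal guard sketch using only
functions from this file:

    from sglang.srt.layers.dp_attention import (
        get_attention_tp_group,
        is_dp_attention_enabled,
    )

    if is_dp_attention_enabled():
        # Raises AssertionError ("dp attention not initialized!") if
        # initialize_dp_attention() has not run in this process.
        device_group = get_attention_tp_group().device_group
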
@@ -27,7 +27,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     attn_tp_all_gather,
     attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
@@ -35,7 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_global_dp_buffer,
     get_local_attention_dp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -108,14 +110,12 @@ class LogitsMetadata:
     # The start position of local hidden states.
     dp_local_start_pos: Optional[torch.Tensor] = None
     dp_local_num_tokens: Optional[torch.Tensor] = None
-    gathered_buffer: Optional[torch.Tensor] = None
-    # Buffer to gather logits from all ranks.
-    forward_batch_gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The gather mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for padding
     padded_static_len: int = -1
@@ -164,11 +164,10 @@ class LogitsMetadata:
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
-            gathered_buffer=forward_batch.gathered_buffer,
-            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_dp_buffer_len=forward_batch.global_dp_buffer_len,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.SUM_LEN,
+            dp_padding_mode=DpPaddingMode.SUM_LEN,
         )
 
     def compute_dp_attention_metadata(self):
@@ -188,16 +187,11 @@ class LogitsMetadata:
 
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
-            self.gathered_buffer = torch.empty(
-                (
-                    sum(self.global_num_tokens_for_logprob_cpu),
-                    self.gathered_buffer.shape[1],
-                ),
-                dtype=self.gathered_buffer.dtype,
-                device=self.gathered_buffer.device,
-            )
+            self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
-            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+            self.global_dp_buffer_len = self.global_dp_buffer_len
+
+        set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
 
 
 class LogitsProcessor(nn.Module):
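
Note: `LogitsMetadata` no longer materializes the gather buffer eagerly. It
records `global_dp_buffer_len` (shrunk to the sum of per-rank logprob token
counts when available, to reduce peak memory) and publishes it through
`set_dp_buffer_len`; the tensor itself is created only when
`get_global_dp_buffer()` is called in the forward pass. A toy comparison
with hypothetical counts:

    global_num_tokens_for_logprob_cpu = [3, 5, 2]  # sampled tokens per DP rank
    global_dp_buffer_len = sum(global_num_tokens_for_logprob_cpu)  # 10 rows
    # versus a buffer sized for every input token on every rank
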
@@ -443,7 +437,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

@@ -6,7 +6,10 @@ import torch.distributed as dist
 from torch import nn
 
 from sglang.srt.distributed import get_tp_group
-from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -32,7 +35,7 @@ class Sampler(nn.Module):
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tp_group().device_group
 
-        if global_server_args_dict["enable_dp_attention"]:
+        if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(

@@ -84,7 +84,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_dp_attention",
     "enable_two_batch_overlap",
     "tbo_token_distribution_threshold",
     "enable_dp_lm_head",

@@ -34,9 +34,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_tp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -349,30 +350,15 @@ class CudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         self.custom_mask = torch.ones(
             (
@@ -556,7 +542,7 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -572,9 +558,9 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         spec_info = self.get_spec_info(num_tokens)
         if self.capture_hidden_mode != CaptureHiddenMode.FULL:
@@ -607,8 +593,8 @@ class CudaGraphRunner:
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -637,6 +623,7 @@ class CudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             kwargs = {}
             if (
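
Note: during CUDA graph capture the runner no longer slices a preallocated
`self.gathered_buffer`; it computes `global_dp_buffer_len` and re-applies it
inside `run_once()` via `set_dp_buffer_len`, so the process-global lengths
match what the captured kernels will see on replay. Sketch of the
bookkeeping (values hypothetical):

    num_tokens, dp_size = 16, 4
    global_dp_buffer_len = num_tokens * dp_size  # MLP-TP gather case
    # attn-TP gather case: global_dp_buffer_len = num_tokens
    set_dp_buffer_len(global_dp_buffer_len, num_tokens)
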
@@ -40,9 +40,10 @@ import triton.language as tl
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_dp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
@@ -274,13 +275,13 @@ class ForwardBatch:
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The padding mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for extend, local start pos and num tokens is different in logits processor
-    # this will be computed in get_dp_local_info
+    # this will be recomputed in LogitsMetadata.from_forward_batch
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
-    gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
@@ -628,7 +629,7 @@ class ForwardBatch:
                     (global_num_tokens[i] - 1) // attn_tp_size + 1
                 ) * attn_tp_size
 
-        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
@@ -642,17 +643,14 @@ class ForwardBatch:
         else:
             buffer_len = sum(global_num_tokens)
 
-        self.gathered_buffer = torch.zeros(
-            (buffer_len, model_runner.model_config.hidden_size),
-            dtype=model_runner.dtype,
-            device=model_runner.device,
-        )
-
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
 
+        self.global_dp_buffer_len = buffer_len
+        set_dp_buffer_len(buffer_len, num_tokens)
+
         bs = self.batch_size
 
         if self.forward_mode.is_decode():
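
Note: `ForwardBatch` now stores only the length and forwards both the global
and the local token counts to the new helper; allocation happens later at
the communication call sites. Sketch with hypothetical values:

    global_num_tokens = [6, 9, 4, 7]     # tokens per DP rank
    buffer_len = sum(global_num_tokens)  # SUM_LEN mode: 26 rows
    num_tokens = global_num_tokens[2]    # this rank's share (rank 2, hypothetical)
    set_dp_buffer_len(buffer_len, num_tokens)
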
@@ -603,12 +603,8 @@ class ModelRunner:
             duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
-            enable_dp_attention=self.server_args.enable_dp_attention,
-            tp_rank=self.tp_rank,
-            tp_size=self.tp_size,
-            dp_size=self.server_args.dp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
-            pp_size=self.server_args.pp_size,
+            server_args=self.server_args,
+            model_config=self.model_config,
         )
 
         min_per_gpu_memory = get_available_gpu_memory(

@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class DeepseekModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
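
Note: the model-file hunks below (DeepSeek-V2, GLM-4-MoE and its NextN
variant, GPT-OSS, Llama4, Qwen2, Qwen2-MoE, Step3) all apply the same two
mechanical changes: import `is_dp_attention_enabled` and replace the
dictionary lookup in the embedding constructor:

    # before
    enable_tp=not global_server_args_dict["enable_dp_attention"],
    # after
    enable_tp=not is_dp_attention_enabled(),
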
@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -1797,7 +1798,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
         self.is_nextn = is_nextn
@@ -1917,7 +1917,9 @@ class DeepseekV2DecoderLayer(nn.Module):
 
         should_allreduce_fusion = (
             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
+            and not (
+                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
+            )
             and not self.is_nextn
         )
@@ -2047,7 +2049,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(

@@ -40,6 +40,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -634,7 +635,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.self_attn = Glm4MoeAttention(
             hidden_size=self.hidden_size,
@@ -744,7 +744,7 @@ class Glm4MoeModel(DeepseekV2Model):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(

@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class Glm4MoeModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )

@@ -41,6 +41,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -565,7 +566,7 @@ class GptOssModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:

@@ -32,6 +32,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,7 +46,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -466,7 +466,7 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.layers = make_layers(
             config.num_hidden_layers,

@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -43,7 +44,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -273,7 +273,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:

@@ -46,6 +46,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -420,7 +421,7 @@ class Qwen2MoeModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:

@@ -25,7 +25,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -437,7 +441,7 @@ class Step3TextModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )

@@ -1,10 +1,17 @@
 from __future__ import annotations
 
 import os
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Generator, List, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
 
 import torch
 
+from sglang.srt.layers.dp_attention import set_dp_buffer_len
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
 _ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
 
 if _ENABLE_PROFILE:
@@ -66,18 +73,26 @@ Stage = List[ExecutionOperation]
 
 
 class _StageExecutor:
-    def __init__(self, debug_name: str, stages: List[Stage], inputs):
+    def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
         self._debug_name = debug_name
         self._stages = stages
         self._index = 0
         self._stage_state = _StateDict()
         self._stage_output = inputs
 
+        # handling DP attention
+        forward_batch: ForwardBatch = inputs["forward_batch"]
+        self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
+        self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
+
     def next(self):
         assert not self.done
 
         stage = self._stages[self._index]
 
+        if self._global_dp_buffer_len is not None:
+            set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
+
         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
             for op in stage:
                 with _annotate_region(debug_name=op.debug_name):
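
Note: `_StageExecutor` snapshots the DP buffer lengths from the
`ForwardBatch` at construction and re-applies them before every stage. The
lengths live in process-global state, so when two microbatches interleave
(two-batch overlap) each executor must restore its own values before running
its ops. A standalone sketch of the idea, with a hypothetical `Executor`
stand-in:

    from typing import Optional

    from sglang.srt.layers.dp_attention import set_dp_buffer_len

    class Executor:  # hypothetical stand-in for _StageExecutor
        def __init__(self, global_len: Optional[int], local_len: int):
            self._global_len = global_len
            self._local_len = local_len

        def next_stage(self):
            # Restore this microbatch's lengths; the other microbatch may
            # have overwritten the process-global values in between.
            if self._global_len is not None:
                set_dp_buffer_len(self._global_len, self._local_len)
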
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -105,30 +105,15 @@ class EAGLEDraftCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         # Capture
         try:
@@ -193,7 +178,7 @@ class EAGLEDraftCudaGraphRunner:
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -211,11 +196,11 @@ class EAGLEDraftCudaGraphRunner:
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
             global_num_tokens = None
-            gathered_buffer = None
+            global_dp_buffer_len = None
             global_num_tokens_for_logprob = None
 
         spec_info = EagleDraftInput(
@@ -239,8 +224,8 @@ class EAGLEDraftCudaGraphRunner:
             return_logprob=False,
             positions=positions,
             global_num_tokens_gpu=global_num_tokens,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
@@ -258,6 +243,7 @@ class EAGLEDraftCudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -117,30 +117,15 @@ class EAGLEDraftExtendCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         if hasattr(
             self.model_runner.model_config.hf_config, "draft_vocab_size"
@@ -222,7 +207,7 @@ class EAGLEDraftExtendCudaGraphRunner:
                     device=self.input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -238,9 +223,9 @@ class EAGLEDraftExtendCudaGraphRunner:
                     device=self.input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
@@ -264,8 +249,8 @@ class EAGLEDraftExtendCudaGraphRunner:
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -288,6 +273,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc

@@ -678,16 +678,12 @@ class TboForwardBatchPreparer:
         # TODO improve, e.g. unify w/ `init_raw`
         if (
             global_server_args_dict["moe_dense_tp_size"] == 1
-            and batch.gathered_buffer is not None
+            and batch.global_dp_buffer_len is not None
         ):
             sum_len = end_token_index - start_token_index
-            gathered_buffer = torch.zeros(
-                (sum_len, batch.gathered_buffer.shape[1]),
-                dtype=batch.gathered_buffer.dtype,
-                device=batch.gathered_buffer.device,
-            )
+            global_dp_buffer_len = sum_len
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         output_dict.update(
             dict(
@@ -706,7 +702,7 @@ class TboForwardBatchPreparer:
                 global_num_tokens_gpu=None,
                 global_num_tokens_cpu=None,
                 dp_padding_mode=None,
-                gathered_buffer=gathered_buffer,
+                global_dp_buffer_len=global_dp_buffer_len,
                 global_num_tokens_for_logprob_gpu=None,
                 global_num_tokens_for_logprob_cpu=None,
                 sampling_info=None,
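
Note: each two-batch-overlap sub-batch now carries just the row count for
its own token slice instead of a zero-filled tensor. Sketch (indices
hypothetical):

    start_token_index, end_token_index = 0, 13
    global_dp_buffer_len = end_token_index - start_token_index  # 13 rows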