Slightly improve the sampler to skip unnecessary steps (#6956)
@@ -5,7 +5,7 @@ import torch
 import torch.distributed as dist
 from torch import nn
 
-from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -30,7 +30,7 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
-        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+        self.tp_sync_group = get_tp_group().device_group
 
         if global_server_args_dict["enable_dp_attention"]:
             self.tp_sync_group = get_attention_tp_group().device_group
@@ -59,7 +59,7 @@ class Sampler(nn.Module):
 
         # Apply the custom logit processors if registered in the sampling info.
         if sampling_info.has_custom_logit_processor:
-            self._apply_custom_logit_processor(logits, sampling_info)
+            apply_custom_logit_processor(logits, sampling_info)
 
         if self.use_nan_detection and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
@@ -81,49 +81,39 @@ class Sampler(nn.Module):
             probs = logits
             del logits
 
-            if global_server_args_dict["sampling_backend"] == "flashinfer":
-                if return_logprob:
-                    # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
-                    # https://github.com/flashinfer-ai/flashinfer/issues/708
-                    # so we use the torch implementation.
-                    # NOTE: OpenAI's logprobs is independent of top-p, we use the
-                    # same rule.
-                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-
-                max_top_k_round, batch_size = 32, probs.shape[0]
-                if sampling_info.need_min_p_sampling:
-                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                    batch_next_token_ids = min_p_sampling_from_probs(
-                        probs, sampling_info.min_ps
-                    )
-                else:
-                    # Check Nan will throw exception, only check when crash_on_warnings is True
-                    check_nan = self.use_nan_detection and crash_on_warnings()
-                    batch_next_token_ids = top_k_top_p_sampling_from_probs(
-                        probs.contiguous(),
-                        sampling_info.top_ks,
-                        sampling_info.top_ps,
-                        filter_apply_order="joint",
-                        check_nan=check_nan,
-                    )
-
-            elif global_server_args_dict["sampling_backend"] == "pytorch":
-                # A slower fallback implementation with torch native operations.
-                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                    probs,
-                    sampling_info.top_ks,
-                    sampling_info.top_ps,
-                    sampling_info.min_ps,
-                    sampling_info.need_min_p_sampling,
-                )
-
-                if return_logprob:
-                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-            else:
-                raise ValueError(
-                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-                )
+            if True:  # Keep this redundant check to simplify some internal code sync
+                if global_server_args_dict["sampling_backend"] == "flashinfer":
+                    if sampling_info.need_min_p_sampling:
+                        probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                        probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                        batch_next_token_ids = min_p_sampling_from_probs(
+                            probs, sampling_info.min_ps
+                        )
+                    else:
+                        batch_next_token_ids = top_k_top_p_sampling_from_probs(
+                            probs,
+                            sampling_info.top_ks,
+                            sampling_info.top_ps,
+                            filter_apply_order="joint",
+                            check_nan=self.use_nan_detection,
+                        )
+                elif global_server_args_dict["sampling_backend"] == "pytorch":
+                    # A slower fallback implementation with torch native operations.
+                    batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                        probs,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        sampling_info.min_ps,
+                        sampling_info.need_min_p_sampling,
+                    )
+                else:
+                    raise ValueError(
+                        f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+                    )
+
+            if return_logprob:
+                # clamp to avoid -inf
+                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
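The restructuring above keeps both backends but hoists the `return_logprob` handling out of the backend branches, drops the `crash_on_warnings()` gate so `check_nan` simply follows `use_nan_detection`, and removes the `max_top_k_round` bookkeeping line and the `probs.contiguous()` copy. For readers unfamiliar with the joint top-k/top-p filtering that `top_k_top_p_sampling_from_probs(..., filter_apply_order="joint")` performs, here is a minimal torch-only sketch of the idea; it is an illustration written for this note, not the flashinfer kernel or the sglang helper:

```python
import torch


def joint_top_k_top_p_filter(
    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
) -> torch.Tensor:
    """Zero out tokens outside each request's top-k and top-p sets, then renormalize."""
    # probs: [batch, vocab], already softmax-normalized.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs)
    # A token survives if it is within the request's top-k AND the probability
    # mass before it is still below the request's top-p threshold.
    keep = (ranks < top_ks.unsqueeze(-1)) & (
        (cumsum - sorted_probs) < top_ps.unsqueeze(-1)
    )
    keep[..., 0] = True  # always keep the most likely token
    filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    # Scatter back to the original vocabulary order; sampling is then a multinomial draw.
    return torch.zeros_like(probs).scatter(-1, sorted_idx, filtered)
```

With the distribution filtered this way, drawing a token reduces to `torch.multinomial(filtered_probs, num_samples=1)`, which is roughly what the torch fallback path does after its own filtering.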
@@ -160,39 +150,6 @@ class Sampler(nn.Module):
 
         return batch_next_token_ids
 
-    def _apply_custom_logit_processor(
-        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
-    ):
-        """Apply custom logit processors to the logits.
-        This function will modify the logits in-place."""
-
-        assert logits.shape[0] == len(sampling_batch_info), (
-            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
-            f"sampling_batch_info ({len(sampling_batch_info)})"
-        )
-
-        for _, (
-            processor,
-            batch_mask,
-        ) in sampling_batch_info.custom_logit_processor.items():
-            # Get the batch indices that need to be processed
-            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
-
-            assert batch_mask.shape[0] == len(sampling_batch_info), (
-                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
-                f"sampling_batch_info ({len(sampling_batch_info)})"
-            )
-
-            # Apply the processor to the logits
-            logits[batch_mask] = processor(
-                logits[batch_mask],
-                [sampling_batch_info.custom_params[i] for i in batch_indices],
-            )
-
-            logger.debug(
-                f"Custom logit processor {processor.__class__.__name__} is applied."
-            )
-
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
@@ -221,6 +178,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids
 
 
+def sampling_from_probs_torch(probs: torch.Tensor):
+    """A sampling implementation with native pytorch operations, without
+    top-k, top-p, or min-p filtering."""
+    sampled_index = torch.multinomial(probs, num_samples=1)
+    batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
+    return batch_next_token_ids
+
+
 def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
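The new `sampling_from_probs_torch` helper draws one token per request directly with `torch.multinomial`, with no top-k, top-p, or min-p filtering at all. A small self-contained usage sketch (the probability values are made up for illustration):

```python
import torch

# Two requests over a toy 4-token vocabulary; each row is a probability distribution.
probs = torch.tensor([[0.70, 0.10, 0.10, 0.10],
                      [0.25, 0.25, 0.25, 0.25]])

sampled_index = torch.multinomial(probs, num_samples=1)        # shape [2, 1]
batch_next_token_ids = sampled_index.view(-1).to(torch.int32)  # shape [2]
print(batch_next_token_ids)
```

Presumably this is the cheap path a batch can take when none of its requests asks for any filtering, which is what the new `need_top_p_sampling` / `need_top_k_sampling` flags further below make detectable.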
@@ -259,3 +224,44 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
             output_token_ids_logprobs_idx.append([])
 
     return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def apply_custom_logit_processor(
+    logits: torch.Tensor,
+    sampling_batch_info: SamplingBatchInfo,
+    num_tokens_in_batch: int = 1,
+):
+    """Apply custom logit processors to the logits.
+    This function will modify the logits in-place.
+    num_tokens_in_batch is needed to support spec decoding, where each batch can contain multiple
+    tokens. By default, we assume each batch contains only 1 token.
+    """
+
+    assert logits.shape[0] == len(sampling_batch_info) * num_tokens_in_batch, (
+        f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+        f"sampling_batch_info ({len(sampling_batch_info)}) x num_tokens_in_batch "
+        f"({num_tokens_in_batch})"
+    )
+
+    for _, (
+        processor,
+        batch_mask,
+    ) in sampling_batch_info.custom_logit_processor.items():
+        # Get the batch indices that need to be processed
+        batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+        assert batch_mask.shape[0] == len(sampling_batch_info), (
+            f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+        batch_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
+
+        # Apply the processor to the logits
+        logits[batch_mask] = processor(
+            logits[batch_mask],
+            [sampling_batch_info.custom_params[i] for i in batch_indices],
+        )
+
+        logger.debug(
+            f"Custom logit processor {processor.__class__.__name__} is applied."
+        )
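`_apply_custom_logit_processor` is removed from the `Sampler` class (hunk `@@ -160,39 +150,6` above) and re-added as the module-level `apply_custom_logit_processor`, now with a `num_tokens_in_batch` argument so it can also serve speculative decoding, where the logits tensor holds several tokens per request. The key mechanic is `torch.repeat_interleave`, which expands the per-request opt-in mask to per-row granularity before indexing the logits. A hedged sketch of just that expansion (tensor shapes are illustrative):

```python
import torch

batch_mask = torch.tensor([True, False, True])   # per-request: who registered a processor
num_tokens_in_batch = 2                          # e.g. two draft tokens per request

# Expand to one entry per logits row: [req0_tok0, req0_tok1, req1_tok0, ...]
row_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
print(row_mask)  # tensor([ True,  True, False, False,  True,  True])

logits = torch.randn(3 * num_tokens_in_batch, 8)  # [requests * tokens, vocab]
selected_rows = logits[row_mask]                  # shape [4, 8], the rows handed to the processor
```

The call-site contract is unchanged: the processor receives the selected logits rows plus one `custom_params` entry per selected request, returns the rewritten rows, and they are written back in place.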
@@ -852,7 +852,7 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format
         logger.info("Start update_weights. Load format=%s", obj.load_format)
 
-        if True:
+        if True:  # Keep this redundant check to simplify some internal code sync
             # Hold the lock if it is not async. This means that weight sync
             # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -9,10 +9,12 @@ import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.sampling_params import TOP_K_ALL
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -27,6 +29,12 @@ class SamplingBatchInfo:
     # Whether all requests use greedy sampling
     is_all_greedy: bool
 
+    # Whether any requests use top_p sampling
+    need_top_p_sampling: bool
+
+    # Whether any requests use top_k sampling
+    need_top_k_sampling: bool
+
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool
 
@@ -133,6 +141,8 @@ class SamplingBatchInfo:
             top_ks=top_ks,
             min_ps=min_ps,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
+            need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
             penalizer_orchestrator=penalizer_orchestrator,
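The two new flags record, when the batch is built, whether *any* request actually asks for top-p (`top_p != 1.0`) or top-k (`top_k != TOP_K_ALL`) filtering, mirroring the existing `need_min_p_sampling` flag. That is what lets the sampler skip filtering work that would be a no-op for the whole batch, the "unnecessary steps" of the commit title. A hedged sketch of the kind of short-circuit this enables (illustrative control flow reusing names from the diff, not the exact sglang code path):

```python
# If nobody in the batch needs any filtering, a plain multinomial draw suffices.
if not (
    sampling_info.need_top_k_sampling
    or sampling_info.need_top_p_sampling
    or sampling_info.need_min_p_sampling
):
    batch_next_token_ids = sampling_from_probs_torch(probs)
else:
    batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
        probs,
        sampling_info.top_ks,
        sampling_info.top_ps,
        sampling_info.min_ps,
        sampling_info.need_min_p_sampling,
    )
```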
@@ -167,7 +177,7 @@ class SamplingBatchInfo:
 
         # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar and not grammar.finished:
+            if grammar and not grammar.finished and not grammar.is_terminated():
                 grammar.fill_vocab_mask(self.vocab_mask, i)
 
         # Move the mask to the device if needed
@@ -308,4 +318,6 @@ class SamplingBatchInfo:
             setattr(self, item, torch.cat([self_val, other_val]))
 
         self.is_all_greedy &= other.is_all_greedy
+        self.need_top_p_sampling |= other.need_top_p_sampling
+        self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling
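When two batches are merged, the flags combine conservatively: `is_all_greedy` survives only if both halves were all-greedy (`&=`), while each `need_*` flag becomes true if either half needs that filter (`|=`). A tiny self-contained illustration:

```python
# Merging per-batch sampling flags: "all greedy" is an AND, "needs X" is an OR.
a_all_greedy, a_need_top_p = True, False
b_all_greedy, b_need_top_p = False, True

merged_all_greedy = a_all_greedy and b_all_greedy  # False: one half actually samples
merged_need_top_p = a_need_top_p or b_need_top_p   # True: one half needs top-p filtering
```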
@@ -16,6 +16,7 @@
 from typing import Any, Dict, List, Optional, Union
 
 _SAMPLING_EPS = 1e-6
+TOP_K_ALL = 1 << 30
 
 
 class SamplingParams:
@@ -84,7 +85,7 @@ class SamplingParams:
             self.temperature = 1.0
             self.top_k = 1
         if self.top_k == -1:
-            self.top_k = 1 << 30  # whole vocabulary
+            self.top_k = TOP_K_ALL  # whole vocabulary
 
     def verify(self):
         if self.temperature < 0.0:
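`TOP_K_ALL = 1 << 30` gives the "no top-k restriction" sentinel a name: `SamplingParams` normalizes the user-facing `top_k = -1` to it, and `SamplingBatchInfo.from_schedule_batch` above compares against it to decide whether any top-k work is needed. A short sketch of that pair of checks (the helper function is hypothetical, written only to mirror the two call sites):

```python
TOP_K_ALL = 1 << 30  # effectively "consider the whole vocabulary"


def normalize_top_k(top_k: int) -> int:
    # -1 is the user-facing way to disable top-k filtering.
    return TOP_K_ALL if top_k == -1 else top_k


# A request only needs top-k filtering if its normalized value is a real cutoff.
assert normalize_top_k(-1) == TOP_K_ALL  # disabled -> no filtering needed
assert normalize_top_k(40) != TOP_K_ALL  # a real cutoff -> filtering needed
```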
@@ -1,17 +1,15 @@
 import dataclasses
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence
 
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.communicator import (
     CommunicateContext,
-    CommunicateSimpleFn,
     CommunicateSummableTensorPairFn,
     ScatterMode,
 )
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -20,9 +18,6 @@ from sglang.srt.operations import execute_operations, execute_overlapped_operati
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
 
-if TYPE_CHECKING:
-    from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 
 logger = logging.getLogger(__name__)
@@ -46,7 +41,7 @@ def compute_split_seq_index(
         assert num_tokens == 0
         return 0
     else:
-        raise NotImplementedError
+        raise NotImplementedError()
 
 
 def _split_array_by_half_sum(arr: Sequence[int]) -> int:
@@ -1928,16 +1928,18 @@ def next_power_of_2(n: int):
 setattr(triton, "next_power_of_2", next_power_of_2)
 
 
-@contextmanager
-def empty_context(*args, **kwargs):
-    try:
-        # Setup code goes here
-        yield
-    finally:
-        # Cleanup code goes here
-        pass
+class EmptyContextManager:
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
+
+
+def empty_context(*args, **kwargs):
+    return EmptyContextManager()
 
 
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.
 
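`empty_context` keeps its signature and its no-op behavior, but is now backed by a plain class instead of a generator-based `@contextmanager`. One practical difference (and a plausible motivation, though the commit does not state it) is that a generator-based manager object is single-use, whereas an `EmptyContextManager` instance can be entered repeatedly. A self-contained comparison:

```python
import contextlib


@contextlib.contextmanager
def empty_context_old(*args, **kwargs):
    # A generator-based no-op like the previous form: single-use per object.
    yield


class EmptyContextManager:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass


ctx = EmptyContextManager()
with ctx:
    pass
with ctx:  # fine: the class-based manager can be re-entered
    pass

old = empty_context_old()
with old:
    pass
# with old:  # would raise RuntimeError: the generator has already been exhausted
```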