Fix dependency (#3813)

This commit is contained in:
Lianmin Zheng
2025-02-24 03:50:58 -08:00
committed by GitHub
parent c979580817
commit 27a46317b6
6 changed files with 43 additions and 31 deletions

View File

@@ -28,17 +28,11 @@ from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarObject,
)
from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
if is_hip_:
try:
from outlines.fsm.json_schema import build_regex_from_schema
except ImportError:
from outlines_core.fsm.json_schema import build_regex_from_schema
else:
try:
from outlines.fsm.json_schema import build_regex_from_schema
except ImportError:
from outlines_core.fsm.json_schema import build_regex_from_schema
logger = logging.getLogger(__name__)

View File

@@ -29,7 +29,7 @@ SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
self.tp_sync_group = get_tensor_model_parallel_group().device_group
if global_server_args_dict["enable_dp_attention"]:
@@ -48,7 +48,7 @@ class Sampler(nn.Module):
if sampling_info.has_custom_logit_processor:
self._apply_custom_logit_processor(logits, sampling_info)
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
if self.use_nan_detection and torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits
@@ -97,7 +97,7 @@ class Sampler(nn.Module):
filter_apply_order="joint",
)
if self.use_nan_detectioin and not torch.all(success):
if self.use_nan_detection and not torch.all(success):
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)

View File

@@ -162,12 +162,9 @@ class ServerArgs:
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
return_hidden_states: bool = False
# Custom logit processor
enable_custom_logit_processor: bool = False
tool_call_parser: str = None
enable_hierarchical_cache: bool = False
enable_flashinfer_mla: bool = False
def __post_init__(self):
@@ -918,7 +915,6 @@ class ServerArgs:
action="store_true",
help="Return hidden states in the response.",
)
# Function Calling
parser.add_argument(
"--tool-call-parser",
type=str,