Fix dependency (#3813)
@@ -28,17 +28,11 @@ from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarObject,
 )
 from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
-from sglang.srt.utils import is_hip
-
-is_hip_ = is_hip()
-
-if is_hip_:
-    try:
-        from outlines.fsm.json_schema import build_regex_from_schema
-    except ImportError:
-        from outlines_core.fsm.json_schema import build_regex_from_schema
-else:
-    from outlines.fsm.json_schema import build_regex_from_schema
+
+try:
+    from outlines.fsm.json_schema import build_regex_from_schema
+except ImportError:
+    from outlines_core.fsm.json_schema import build_regex_from_schema
 
 
 logger = logging.getLogger(__name__)
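For context, this is the fallback-import pattern the hunk converges on: prefer build_regex_from_schema from outlines and fall back to outlines_core when the symbol has moved between packages, with no HIP-specific branching. A minimal standalone sketch (the example schema is made up for illustration):

import json

try:
    from outlines.fsm.json_schema import build_regex_from_schema
except ImportError:
    # Newer outlines releases ship this helper in outlines_core instead.
    from outlines_core.fsm.json_schema import build_regex_from_schema

# Hypothetical schema, only to show the call shape: the helper takes a JSON
# schema string and returns a regex string used to constrain decoding.
schema = json.dumps({"type": "object", "properties": {"name": {"type": "string"}}})
regex = build_regex_from_schema(schema)
print(regex[:80])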
@@ -29,7 +29,7 @@ SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+        self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tensor_model_parallel_group().device_group
 
         if global_server_args_dict["enable_dp_attention"]:
@@ -48,7 +48,7 @@ class Sampler(nn.Module):
         if sampling_info.has_custom_logit_processor:
             self._apply_custom_logit_processor(logits, sampling_info)
 
-        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+        if self.use_nan_detection and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(
                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
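The Sampler hunks only rename the misspelled use_nan_detectioin attribute; the NaN handling itself is unchanged. As a standalone sketch of that behavior (tensor values are illustrative), any NaN logit is pushed to a large negative value so it cannot win the sampling step:

import torch

logits = torch.tensor([[2.0, float("nan"), 0.5]])
if torch.any(torch.isnan(logits)):
    # Replace NaNs with a large negative logit so softmax assigns them ~0 probability.
    logits = torch.where(torch.isnan(logits), torch.full_like(logits, -1e5), logits)
print(torch.softmax(logits, dim=-1))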
@@ -97,7 +97,7 @@ class Sampler(nn.Module):
                 filter_apply_order="joint",
             )
 
-            if self.use_nan_detectioin and not torch.all(success):
+            if self.use_nan_detection and not torch.all(success):
                 logger.warning("Detected errors during sampling!")
                 batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
 
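The second site guarded by the renamed flag is the failure path: when the sampling call whose success flags appear above reports any unsuccessful row, a warning is logged and the batch falls back to all-zero token ids. A reduced sketch, with success and batch_next_token_ids as stand-in tensors:

import torch

batch_next_token_ids = torch.tensor([17, 42, 7])
success = torch.tensor([True, False, True])  # per-request success flags
if not torch.all(success):
    # Any failed row invalidates the batch; fall back to zeros as a safe default.
    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
print(batch_next_token_ids)  # tensor([0, 0, 0])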
@@ -162,12 +162,9 @@ class ServerArgs:
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    return_hidden_states: bool = False

    # Custom logit processor
    enable_custom_logit_processor: bool = False
    tool_call_parser: str = None
    enable_hierarchical_cache: bool = False

    enable_flashinfer_mla: bool = False

    def __post_init__(self):
@@ -918,7 +915,6 @@ class ServerArgs:
             action="store_true",
             help="Return hidden states in the response.",
         )
-        # Function Calling
         parser.add_argument(
             "--tool-call-parser",
             type=str,
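Taken together, the ServerArgs hunks show the usual pairing between a boolean dataclass field and its store_true CLI flag (return_hidden_states paired with --return-hidden-states). A reduced, self-contained sketch of that pattern, with ServerArgs cut down to two of the fields visible above:

import argparse
from dataclasses import dataclass

@dataclass
class ServerArgs:
    # Reduced stand-in: only two of the fields from the diff above.
    return_hidden_states: bool = False
    enable_custom_logit_processor: bool = False

parser = argparse.ArgumentParser()
parser.add_argument(
    "--return-hidden-states",
    action="store_true",
    help="Return hidden states in the response.",
)
parser.add_argument("--enable-custom-logit-processor", action="store_true")

args = parser.parse_args(["--return-hidden-states"])
server_args = ServerArgs(**vars(args))
print(server_args.return_hidden_states)  # True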