Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions


@@ -21,6 +21,7 @@ class LoadFormat(str, enum.Enum):
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
     LAYERED = "layered"
+    JAX = "jax"


 @dataclass
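
For context: LoadFormat subclasses str, so the new JAX member can be matched directly against plain strings coming from a CLI flag or config file. A minimal, standalone sketch of that behavior (only the members shown above):

import enum

class LoadFormat(str, enum.Enum):
    LAYERED = "layered"
    JAX = "jax"

# str-based enums compare equal to their raw string values,
# so user input like "jax" selects the member without casting.
assert LoadFormat("jax") is LoadFormat.JAX
assert LoadFormat.JAX == "jax"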
@@ -42,13 +43,15 @@ class LoadConfig:
     ignore_patterns: The list of patterns to ignore when loading the model.
         Defaults to "original/**/*" to avoid repeated loading of llama's
         checkpoints.
+    decryption_key_file: If set, decrypts the model weight files with a
+        password read from this file (the key is derived via PBKDF2).
     """

     load_format: Union[str, LoadFormat] = LoadFormat.AUTO
     download_dir: Optional[str] = None
     model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
     ignore_patterns: Optional[Union[List[str], str]] = None
+    decryption_key_file: Optional[str] = None

     def __post_init__(self):
         model_loader_extra_config = self.model_loader_extra_config or {}
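
Putting the new fields together, a hedged usage sketch; the module path and file paths below are assumptions for illustration, not taken from this diff:

from sglang.srt.configs.load_config import LoadConfig, LoadFormat

# Hypothetical: load JAX-format weights, decrypting them with a key file.
load_config = LoadConfig(
    load_format=LoadFormat.JAX,
    download_dir="/tmp/weights",                # hypothetical cache dir
    decryption_key_file="/secrets/model.key",   # hypothetical key file
)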


@@ -44,6 +44,7 @@ class ModelConfig:
         is_embedding: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
+        override_config_file: Optional[str] = None,
     ) -> None:
         self.model_path = model_path
         self.revision = revision
@@ -51,11 +52,16 @@ class ModelConfig:
         # Parse args
         self.model_override_args = json.loads(model_override_args)
+        kwargs = {}
+        if override_config_file and override_config_file.strip():
+            kwargs["_configuration_file"] = override_config_file.strip()
         self.hf_config = get_config(
             model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,
+            **kwargs,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
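
The override_config_file value is forwarded to Hugging Face's config loader as the private _configuration_file keyword, which transformers uses internally to select a specific config JSON instead of the repo's default config.json. A hedged sketch of the effect at the call site; the model path, override file, and the assumption that the remaining ModelConfig arguments can be left at their defaults are all illustrative:

from sglang.srt.configs.model_config import ModelConfig

# Hypothetical: load the model but read its HF config from an
# alternative configuration file rather than the repo default.
model_config = ModelConfig(
    model_path="my-org/my-model",                       # assumption
    override_config_file="/overrides/alt_config.json",  # assumption
)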
@@ -64,6 +70,9 @@ class ModelConfig:
             self.hf_config.architectures, is_embedding
         )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
+        self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
+        self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
+        self.is_audio_model = is_audio_model(self.hf_config.architectures)
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -71,7 +80,9 @@ class ModelConfig:
         derived_context_len = get_context_length(self.hf_text_config)
         if context_length is not None:
             if context_length > derived_context_len:
-                if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
+                if get_bool_env_var(
+                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
+                ):
                     logger.warning(
                         f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                         f"This may lead to incorrect model outputs or CUDA errors."
@@ -416,6 +427,8 @@ def is_multimodal_model(model_architectures: List[str]):
         or "LlavaQwenForCausalLM" in model_architectures
         or "LlavaMistralForCausalLM" in model_architectures
         or "LlavaVidForCausalLM" in model_architectures
+        or "Grok1VForCausalLM" in model_architectures
+        or "Grok1AForCausalLM" in model_architectures
         or "MllamaForConditionalGeneration" in model_architectures
         or "Qwen2VLForConditionalGeneration" in model_architectures
         or "Qwen2_5_VLForConditionalGeneration" in model_architectures
@@ -426,6 +439,18 @@ def is_multimodal_model(model_architectures: List[str]):
         return False


+def is_multimodal_gen_model(model_architectures: List[str]):
+    return False
+
+
+def is_image_gen_model(model_architectures: List[str]):
+    return False
+
+
+def is_audio_model(model_architectures: List[str]):
+    return False
+
+
 def is_encoder_decoder_model(model_architectures: List[str]):
     return "MllamaForConditionalGeneration" in model_architectures