Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -21,6 +21,7 @@ class LoadFormat(str, enum.Enum):
|
||||
BITSANDBYTES = "bitsandbytes"
|
||||
MISTRAL = "mistral"
|
||||
LAYERED = "layered"
|
||||
JAX = "jax"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -42,13 +43,15 @@ class LoadConfig:
|
||||
ignore_patterns: The list of patterns to ignore when loading the model.
|
||||
Default to "original/**/*" to avoid repeated loading of llama's
|
||||
checkpoints.
|
||||
|
||||
decryption_key_file: If set, decrypts the output files with a password read
|
||||
from this file (after PBKDF2).
|
||||
"""
|
||||
|
||||
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
||||
download_dir: Optional[str] = None
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||
decryption_key_file: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||
|
||||
@@ -44,6 +44,7 @@ class ModelConfig:
|
||||
is_embedding: Optional[bool] = None,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
override_config_file: Optional[str] = None,
|
||||
) -> None:
|
||||
self.model_path = model_path
|
||||
self.revision = revision
|
||||
@@ -51,11 +52,16 @@ class ModelConfig:
|
||||
|
||||
# Parse args
|
||||
self.model_override_args = json.loads(model_override_args)
|
||||
kwargs = {}
|
||||
if override_config_file and override_config_file.strip():
|
||||
kwargs["_configuration_file"] = override_config_file.strip()
|
||||
|
||||
self.hf_config = get_config(
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
model_override_args=self.model_override_args,
|
||||
**kwargs,
|
||||
)
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
|
||||
@@ -64,6 +70,9 @@ class ModelConfig:
|
||||
self.hf_config.architectures, is_embedding
|
||||
)
|
||||
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
|
||||
self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
|
||||
self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
|
||||
self.is_audio_model = is_audio_model(self.hf_config.architectures)
|
||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
@@ -71,7 +80,9 @@ class ModelConfig:
|
||||
derived_context_len = get_context_length(self.hf_text_config)
|
||||
if context_length is not None:
|
||||
if context_length > derived_context_len:
|
||||
if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
|
||||
if get_bool_env_var(
|
||||
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
|
||||
):
|
||||
logger.warning(
|
||||
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
||||
f"This may lead to incorrect model outputs or CUDA errors."
|
||||
@@ -416,6 +427,8 @@ def is_multimodal_model(model_architectures: List[str]):
|
||||
or "LlavaQwenForCausalLM" in model_architectures
|
||||
or "LlavaMistralForCausalLM" in model_architectures
|
||||
or "LlavaVidForCausalLM" in model_architectures
|
||||
or "Grok1VForCausalLM" in model_architectures
|
||||
or "Grok1AForCausalLM" in model_architectures
|
||||
or "MllamaForConditionalGeneration" in model_architectures
|
||||
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||
or "Qwen2_5_VLForConditionalGeneration" in model_architectures
|
||||
@@ -426,6 +439,18 @@ def is_multimodal_model(model_architectures: List[str]):
|
||||
return False
|
||||
|
||||
|
||||
def is_multimodal_gen_model(model_architectures: List[str]) -> bool:
    """Report whether the architectures denote a multimodal generation model.

    Args:
        model_architectures: Architecture names to check (the caller visible
            in this file passes ``hf_config.architectures``).

    Returns:
        Always ``False``; the argument is not inspected.
    """
    # NOTE(review): presumably a placeholder until such architectures are
    # registered -- confirm before relying on this flag.
    return False
|
||||
|
||||
|
||||
def is_image_gen_model(model_architectures: List[str]) -> bool:
    """Report whether the architectures denote an image generation model.

    Args:
        model_architectures: Architecture names to check (the caller visible
            in this file passes ``hf_config.architectures``).

    Returns:
        Always ``False``; the argument is not inspected.
    """
    # NOTE(review): presumably a placeholder until such architectures are
    # registered -- confirm before relying on this flag.
    return False
|
||||
|
||||
|
||||
def is_audio_model(model_architectures: List[str]) -> bool:
    """Report whether the architectures denote an audio model.

    Args:
        model_architectures: Architecture names to check (the caller visible
            in this file passes ``hf_config.architectures``).

    Returns:
        Always ``False``; the argument is not inspected.
    """
    # NOTE(review): presumably a placeholder until such architectures are
    # registered -- confirm before relying on this flag.
    return False
|
||||
|
||||
|
||||
def is_encoder_decoder_model(model_architectures: List[str]) -> bool:
    """Report whether the architectures include an encoder-decoder model.

    Args:
        model_architectures: Architecture names to check (the caller visible
            in this file passes ``hf_config.architectures``).

    Returns:
        ``True`` if ``"MllamaForConditionalGeneration"`` is one of the given
        names, otherwise ``False``. That is the only architecture this
        function recognizes.
    """
    return "MllamaForConditionalGeneration" in model_architectures
|
||||
|
||||
|
||||
Reference in New Issue
Block a user