Rename arguments --disable-nan-detection to --enable-nan-detection (#2066)
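This commit renames the server argument --disable-nan-detection to --enable-nan-detection, which also flips the default: NaN detection in the sampler used to be on unless explicitly disabled and is now off unless explicitly enabled. A minimal sketch of that flip, using a plain dict in place of global_server_args_dict (names here are illustrative, not taken from the diff):

    # Old behavior: detection on by default, read through a negated "disable" flag.
    args_before = {"disable_nan_detection": False}
    use_nan_detection = not args_before["disable_nan_detection"]   # True

    # New behavior: detection off by default, read directly from an "enable" flag.
    args_after = {"enable_nan_detection": False}
    use_nan_detection = args_after["enable_nan_detection"]         # False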
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+        self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
 
     def forward(
         self,
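The body of forward is not part of this hunk; purely as an illustration (an assumption, not code from this commit), the stored flag would typically gate a check on the logits before sampling, along these lines:

    import logging

    import torch

    logger = logging.getLogger(__name__)

    def maybe_warn_on_nan(logits: torch.Tensor, enable_nan_detection: bool) -> None:
        # Hypothetical helper, not part of this diff: inspect the logits only
        # when the renamed flag has turned detection on.
        if enable_nan_detection and torch.isnan(logits).any():
            logger.warning("NaN detected in logits before sampling.")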
@@ -57,7 +57,7 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
-    "disable_nan_detection": ServerArgs.disable_nan_detection,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
 }
 
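This dict and the ModelRunner hunk below follow one pattern: a module-level dict is seeded with the ServerArgs class defaults at import time and later overwritten with the values of the live server_args instance. A self-contained sketch of that pattern, using a stand-in dataclass rather than the real ServerArgs:

    from dataclasses import dataclass

    @dataclass
    class ServerArgsSketch:  # stand-in for ServerArgs
        enable_nan_detection: bool = False

    # Seeded from the class-level default at import time...
    global_server_args_dict = {
        "enable_nan_detection": ServerArgsSketch.enable_nan_detection,
    }

    # ...and refreshed with the actual runtime value by the model runner.
    def update_from_instance(server_args: ServerArgsSketch) -> None:
        global_server_args_dict.update(
            {"enable_nan_detection": server_args.enable_nan_detection}
        )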
@@ -139,7 +139,7 @@ class ModelRunner:
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
                 "disable_penalizer": server_args.disable_penalizer,
-                "disable_nan_detection": server_args.disable_nan_detection,
+                "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
             }
         )
@@ -276,6 +276,10 @@ class ModelRunner:
             else None
         )
         self.dtype = self.vllm_model_config.dtype
+        if self.sliding_window_size:
+            assert (
+                self.server_args.attention_backend == "flashinfer"
+            ), "Only flashinfer supports window attention."
 
         logger.info(
             f"Load weight end. "
@@ -332,6 +332,7 @@ class Gemma2ForCausalLM(nn.Module):
     # Gemma does not apply LoRA to the embedding layer.
     embedding_modules = {}
     embedding_padding_modules = []
+    supports_lora = True
 
     def __init__(
         self,
@@ -124,7 +124,6 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_penalizer: bool = False
-    disable_nan_detection: bool = False
     enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -132,6 +131,7 @@ class ServerArgs:
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: int = 160
     torchao_config: str = ""
+    enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
     num_continuous_decode_steps: int = 1
@@ -171,11 +171,11 @@ class ServerArgs:
         else:
             gpu_mem = get_nvgpu_memory_capacity()
             if gpu_mem < 25000:
+                self.chunked_prefill_size //= 4  # make it 2048
+                self.cuda_graph_max_bs = 4
                 logger.warning(
                     "Automatically adjust --chunked-prefill-size for small GPUs."
                 )
-                self.chunked_prefill_size //= 4  # make it 2048
-                self.cuda_graph_max_bs = 4
 
         if not is_flashinfer_available():
             self.attention_backend = "triton"
@@ -194,7 +194,7 @@ class ServerArgs:
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.enable_overlap_schedule = False
             logger.warning(
-                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
+                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size."
             )
@@ -204,21 +204,8 @@ class ServerArgs:
                 "Overlap scheduler mode is enabled. This is an experimental feature. "
                 "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
                 "and embedding APIs are not supported and will lead to wrong results. "
-                "The NaN detection is also disabled."
             )
             self.disable_penalizer = True
-            self.disable_nan_detection = True
-
-        # Model-specific patches
-        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
-            logger.info(
-                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
-            )
-            self.trust_remote_code = False
-
-        if "gemma-2" in self.model_path.lower():
-            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
-            self.attention_backend = "flashinfer"
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
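With detection now opt-in and off by default, the overlap-scheduler branch above presumably no longer needs to force it off, which is why the "The NaN detection is also disabled." message and the self.disable_nan_detection = True assignment are removed.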
@@ -683,6 +670,11 @@ class ServerArgs:
             default=ServerArgs.torchao_config,
             help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
         )
+        parser.add_argument(
+            "--enable-nan-detection",
+            action="store_true",
+            help="Enable the NaN detection for debugging purposes.",
+        )
         parser.add_argument(
             "--enable-p2p-check",
             action="store_true",
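Because the new option uses action="store_true", it defaults to False, matching the new enable_nan_detection: bool = False field, and passing the flag flips it to True. A standalone sketch of that behavior with a toy parser (not the project's real one):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--enable-nan-detection",
        action="store_true",
        help="Enable the NaN detection for debugging purposes.",
    )

    print(parser.parse_args([]).enable_nan_detection)                          # False
    print(parser.parse_args(["--enable-nan-detection"]).enable_nan_detection)  # True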