[Minor] make the __init__ function of model_runner.py shorter (#4132)

This commit is contained in:
Lianmin Zheng
2025-03-06 01:51:12 -08:00
committed by GitHub
parent fcc2e37f69
commit 98c73d71cb
2 changed files with 119 additions and 102 deletions

View File

@@ -427,7 +427,7 @@ class CudaGraphRunner:
self.capture_hidden_mode = hidden_mode_from_spec_info
self.capture()
def replay(self, forward_batch: ForwardBatch):
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
self.recapture_if_needed(forward_batch)
raw_bs = forward_batch.batch_size

View File

@@ -122,34 +122,115 @@ class ModelRunner:
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
# Model-specific adjustment
self.model_specific_adjustment()
if server_args.show_time_cost:
enable_show_time_cost()
if server_args.disable_outlines_disk_cache:
from outlines.caching import disable_cache
disable_cache()
# Global vars
global_server_args_dict.update(
{
"attention_backend": server_args.attention_backend,
"sampling_backend": server_args.sampling_backend,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"disable_mla": server_args.disable_mla,
"torchao_config": server_args.torchao_config,
"enable_nan_detection": server_args.enable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
}
)
# CPU offload
set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
# Get memory before model loading
min_per_gpu_memory = self.init_torch_distributed()
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=self.server_args.enable_memory_saver
)
# Load the model
self.sampler = Sampler()
self.load_model()
# Apply torchao quantization
torchao_applied = getattr(self.model, "torchao_applied", False)
# In layered loading, torchao may have been applied
if not torchao_applied:
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
)
# Apply torch TP if the model supports it
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
if self.tp_size > 1 and supports_torch_tp:
self.apply_torch_tp()
self.torch_tp_applied = True
else:
self.torch_tp_applied = False
# Init lora
if server_args.lora_paths is not None:
self.init_lora_manager()
# Init memory pool and attention backends
self.init_memory_pool(
min_per_gpu_memory,
server_args.max_running_requests,
server_args.max_total_tokens,
)
if self.device == "cuda":
self.init_cublas()
self.init_attention_backend()
self.init_cuda_graphs()
else:
self.cuda_graph_runner = None
self.init_attention_backend()
def model_specific_adjustment(self):
server_args = self.server_args
if (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
and not server_args.disable_mla
):
# TODO: add MLA optimization on CPU
if self.server_args.device != "cpu":
if server_args.device != "cpu":
if server_args.enable_flashinfer_mla:
logger.info(
"MLA optimization is turned on. Use flashinfer mla backend."
)
self.server_args.attention_backend = "flashinfer_mla"
server_args.attention_backend = "flashinfer_mla"
else:
logger.info("MLA optimization is turned on. Use triton backend.")
self.server_args.attention_backend = "triton"
server_args.attention_backend = "triton"
if self.server_args.enable_double_sparsity:
if server_args.enable_double_sparsity:
logger.info(
"Double sparsity optimization is turned on. Use triton backend without CUDA graph."
)
self.server_args.attention_backend = "triton"
self.server_args.disable_cuda_graph = True
if self.server_args.ds_heavy_channel_type is None:
server_args.attention_backend = "triton"
server_args.disable_cuda_graph = True
if server_args.ds_heavy_channel_type is None:
raise ValueError(
"Please specify the heavy channel type for double sparsity optimization."
)
self.init_double_sparsity_channel_config(
self.server_args.ds_heavy_channel_type
)
self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
if self.is_multimodal:
self.mem_fraction_static *= 0.95
@@ -174,96 +255,10 @@ class ModelRunner:
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True
# Global vars
if server_args.show_time_cost:
enable_show_time_cost()
if server_args.disable_outlines_disk_cache:
from outlines.caching import disable_cache
disable_cache()
global_server_args_dict.update(
{
"attention_backend": server_args.attention_backend,
"sampling_backend": server_args.sampling_backend,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"disable_mla": server_args.disable_mla,
"torchao_config": server_args.torchao_config,
"enable_nan_detection": server_args.enable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
}
)
set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
# Get memory before model loading
min_per_gpu_memory = self.init_torch_distributed()
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=self.server_args.enable_memory_saver
)
# Load the model
self.sampler = Sampler()
self.load_model()
# Handle the case where some of models don't finish loading.
try:
dist.monitored_barrier(
group=get_tp_group().cpu_group,
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
wait_all_ranks=True,
)
except RuntimeError:
raise ValueError(
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
) from None
# Apply torchao quantization
torchao_applied = getattr(self.model, "torchao_applied", False)
# In layered loading, torchao may have been applied
if not torchao_applied:
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
)
# Apply torch TP if the model supports it
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
if self.tp_size > 1 and supports_torch_tp:
self.apply_torch_tp()
self.torch_tp_applied = True
else:
self.torch_tp_applied = False
# Init memory pool and attention backends
if server_args.lora_paths is not None:
self.init_lora_manager()
self.init_memory_pool(
min_per_gpu_memory,
server_args.max_running_requests,
server_args.max_total_tokens,
)
if self.device == "cuda":
self.init_cublas()
self.init_attention_backend()
self.init_cuda_graphs()
else:
self.cuda_graph_runner = None
self.init_attention_backend()
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
torch.get_device_module(self.device).set_device(self.gpu_id)
torch.get_device_module(self.device).set_device(self.gpu_id)
if self.device == "cuda":
backend = "nccl"
elif self.device == "xpu":
@@ -400,6 +395,18 @@ class ModelRunner:
f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
)
# Handle the case where some ranks do not finish loading.
try:
dist.monitored_barrier(
group=get_tp_group().cpu_group,
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
wait_all_ranks=True,
)
except RuntimeError:
raise ValueError(
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
) from None
def update_weights_from_disk(
self, model_path: str, load_format: str
) -> tuple[bool, str]:
@@ -772,6 +779,10 @@ class ModelRunner:
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
self.attn_backend = FlashInferAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
@@ -880,18 +891,24 @@ class ModelRunner:
forward_batch.input_ids, forward_batch.positions, forward_batch
)
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
def forward(
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
) -> LogitsProcessorOutput:
if (
forward_batch.forward_mode.is_cuda_graph()
and self.cuda_graph_runner
and self.cuda_graph_runner.can_run(forward_batch)
):
return self.cuda_graph_runner.replay(forward_batch)
return self.cuda_graph_runner.replay(
forward_batch, skip_attn_backend_init=skip_attn_backend_init
)
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(forward_batch)
return self.forward_extend(
forward_batch, skip_attn_backend_init=skip_attn_backend_init
)
elif forward_batch.forward_mode.is_idle():
return self.forward_idle(forward_batch)
else: