[Minor] make the __init__ function of model_runner.py shorter (#4132)
@@ -427,7 +427,7 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()

-    def replay(self, forward_batch: ForwardBatch):
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
         self.recapture_if_needed(forward_batch)

         raw_bs = forward_batch.batch_size
@@ -122,34 +122,115 @@ class ModelRunner:
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

         # Model-specific adjustment
+        self.model_specific_adjustment()
+
+        if server_args.show_time_cost:
+            enable_show_time_cost()
+
+        if server_args.disable_outlines_disk_cache:
+            from outlines.caching import disable_cache
+
+            disable_cache()
+
+        # Global vars
+        global_server_args_dict.update(
+            {
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
+                "disable_mla": server_args.disable_mla,
+                "torchao_config": server_args.torchao_config,
+                "enable_nan_detection": server_args.enable_nan_detection,
+                "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_ep_moe": server_args.enable_ep_moe,
+                "device": server_args.device,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
+                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "disable_radix_cache": server_args.disable_radix_cache,
+                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+            }
+        )
+
+        # CPU offload
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
+        # Get memory before model loading
+        min_per_gpu_memory = self.init_torch_distributed()
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.server_args.enable_memory_saver
+        )
+
+        # Load the model
+        self.sampler = Sampler()
+        self.load_model()
+
+        # Apply torchao quantization
+        torchao_applied = getattr(self.model, "torchao_applied", False)
+        # In layered loading, torchao may have been applied
+        if not torchao_applied:
+            apply_torchao_config_to_model(
+                self.model, global_server_args_dict["torchao_config"]
+            )
+
+        # Apply torch TP if the model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
+        # Init lora
+        if server_args.lora_paths is not None:
+            self.init_lora_manager()
+
+        # Init memory pool and attention backends
+        self.init_memory_pool(
+            min_per_gpu_memory,
+            server_args.max_running_requests,
+            server_args.max_total_tokens,
+        )
+        if self.device == "cuda":
+            self.init_cublas()
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+        else:
+            self.cuda_graph_runner = None
+            self.init_attention_backend()
+
+    def model_specific_adjustment(self):
+        server_args = self.server_args
+
         if (
             self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
+            and not server_args.disable_mla
         ):
             # TODO: add MLA optimization on CPU
-            if self.server_args.device != "cpu":
+            if server_args.device != "cpu":
                 if server_args.enable_flashinfer_mla:
                     logger.info(
                         "MLA optimization is turned on. Use flashinfer mla backend."
                     )
-                    self.server_args.attention_backend = "flashinfer_mla"
+                    server_args.attention_backend = "flashinfer_mla"
                 else:
                     logger.info("MLA optimization is turned on. Use triton backend.")
-                    self.server_args.attention_backend = "triton"
+                    server_args.attention_backend = "triton"

-        if self.server_args.enable_double_sparsity:
+        if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
             )
-            self.server_args.attention_backend = "triton"
-            self.server_args.disable_cuda_graph = True
-            if self.server_args.ds_heavy_channel_type is None:
+            server_args.attention_backend = "triton"
+            server_args.disable_cuda_graph = True
+            if server_args.ds_heavy_channel_type is None:
                 raise ValueError(
                     "Please specify the heavy channel type for double sparsity optimization."
                 )
-            self.init_double_sparsity_channel_config(
-                self.server_args.ds_heavy_channel_type
-            )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

         if self.is_multimodal:
             self.mem_fraction_static *= 0.95
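For readers skimming the hunk above: this is the core of the refactor. The model-specific MLA, double-sparsity, and multimodal tweaks move out of __init__ into a new model_specific_adjustment() helper (which aliases self.server_args to a local server_args), and the remaining startup steps stay grouped under short comments so __init__ reads like a list of phases. Below is a minimal sketch of the same pattern, assuming a toy Runner class and a placeholder ServerArgs; none of this is the real ModelRunner code.

from dataclasses import dataclass


@dataclass
class ServerArgs:
    # Placeholder config object; the real ServerArgs has many more fields.
    disable_mla: bool = False
    enable_double_sparsity: bool = False
    attention_backend: str = "flashinfer"


class Runner:
    def __init__(self, server_args: ServerArgs):
        self.server_args = server_args

        # Each startup phase lives in its own helper, so __init__ reads as a
        # table of contents instead of one long block.
        self.model_specific_adjustment()
        self.init_backends()

    def model_specific_adjustment(self):
        # Mutate the shared config in one place, via a local alias.
        server_args = self.server_args
        if server_args.enable_double_sparsity:
            server_args.attention_backend = "triton"

    def init_backends(self):
        self.attention_backend = self.server_args.attention_backend


if __name__ == "__main__":
    runner = Runner(ServerArgs(enable_double_sparsity=True))
    print(runner.attention_backend)  # -> "triton"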
@@ -174,96 +255,10 @@ class ModelRunner:
             server_args.chunked_prefill_size = -1
             server_args.disable_radix_cache = True

-        # Global vars
-        if server_args.show_time_cost:
-            enable_show_time_cost()
-        if server_args.disable_outlines_disk_cache:
-            from outlines.caching import disable_cache
-
-            disable_cache()
-
-        global_server_args_dict.update(
-            {
-                "attention_backend": server_args.attention_backend,
-                "sampling_backend": server_args.sampling_backend,
-                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "disable_mla": server_args.disable_mla,
-                "torchao_config": server_args.torchao_config,
-                "enable_nan_detection": server_args.enable_nan_detection,
-                "enable_dp_attention": server_args.enable_dp_attention,
-                "enable_ep_moe": server_args.enable_ep_moe,
-                "device": server_args.device,
-                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
-                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
-                "disable_radix_cache": server_args.disable_radix_cache,
-                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
-                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
-                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
-            }
-        )
-
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
-
-        # Get memory before model loading
-        min_per_gpu_memory = self.init_torch_distributed()
-
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=self.server_args.enable_memory_saver
-        )
-
-        # Load the model
-        self.sampler = Sampler()
-        self.load_model()
-
-        # Handle the case where some of models don't finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
-
-        # Apply torchao quantization
-        torchao_applied = getattr(self.model, "torchao_applied", False)
-        # In layered loading, torchao may have been applied
-        if not torchao_applied:
-            apply_torchao_config_to_model(
-                self.model, global_server_args_dict["torchao_config"]
-            )
-
-        # Apply torch TP if the model supports it
-        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
-        if self.tp_size > 1 and supports_torch_tp:
-            self.apply_torch_tp()
-            self.torch_tp_applied = True
-        else:
-            self.torch_tp_applied = False
-
-        # Init memory pool and attention backends
-        if server_args.lora_paths is not None:
-            self.init_lora_manager()
-        self.init_memory_pool(
-            min_per_gpu_memory,
-            server_args.max_running_requests,
-            server_args.max_total_tokens,
-        )
-        if self.device == "cuda":
-            self.init_cublas()
-            self.init_attention_backend()
-            self.init_cuda_graphs()
-        else:
-            self.cuda_graph_runner = None
-            self.init_attention_backend()
-
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-        torch.get_device_module(self.device).set_device(self.gpu_id)

+        torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
@@ -400,6 +395,18 @@ class ModelRunner:
                 f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
             )

+        # Handle the case where some ranks do not finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
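The block added above is the unbalanced-loading check that the earlier hunk removed from __init__: it now runs at the end of load_model, where dist.monitored_barrier on the CPU (gloo) group makes every TP rank wait for its peers and turns a hung peer, for example one that hit an OOM while loading, into a clear error on the ranks that did finish. Below is a minimal single-process sketch of how torch.distributed.monitored_barrier is used, assuming a free local TCP port; the address is just an example.

import datetime

import torch.distributed as dist

# monitored_barrier is only supported by the gloo backend, which is likely why
# the diff passes the CPU process group rather than the NCCL one.
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
)

try:
    # With world_size=1 this returns immediately; with more ranks it raises a
    # RuntimeError on the ranks that did reach the barrier if any peer fails
    # to arrive within the timeout.
    dist.monitored_barrier(timeout=datetime.timedelta(seconds=30), wait_all_ranks=True)
    print("all ranks finished loading")
except RuntimeError:
    raise ValueError("some ranks did not finish loading") from None
finally:
    dist.destroy_process_group()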
@@ -772,6 +779,10 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            # Init streams
+            if self.server_args.speculative_algorithm == "EAGLE":
+                self.plan_stream_for_flashinfer = torch.cuda.Stream()
+
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
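The new plan_stream_for_flashinfer above is a spare CUDA stream created only when EAGLE speculative decoding is enabled, presumably so FlashInfer's plan/metadata preparation can be issued off the default stream and overlapped with other work. Below is a generic PyTorch side-stream sketch, purely illustrative and unrelated to SGLang's actual scheduling.

import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")

    # The side stream must see the default stream's prior work (the
    # allocation and initialization of x) before it reads x.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        # Work launched here can overlap with later default-stream kernels.
        y = x @ x

    # The default stream waits for the side stream before consuming y.
    torch.cuda.current_stream().wait_stream(side_stream)
    print(y.sum().item())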
@@ -880,18 +891,24 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+    def forward(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if (
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         ):
-            return self.cuda_graph_runner.replay(forward_batch)
+            return self.cuda_graph_runner.replay(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )

         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(forward_batch)
+            return self.forward_extend(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
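Together with the CudaGraphRunner.replay change in the first hunk, the diff above threads one new keyword, skip_attn_backend_init, from ModelRunner.forward into both the CUDA-graph replay path and forward_extend, presumably so a caller that has already prepared the attention backend (for example, a speculative-decoding driver issuing several forwards back to back) can ask the runner not to redo that setup. Below is a stripped-down sketch of this flag-threading pattern, with toy classes that are not the real call chain.

class GraphRunner:
    def replay(self, batch: list[int], skip_attn_backend_init: bool = False) -> int:
        if not skip_attn_backend_init:
            # Normally the attention backend metadata is (re)initialized here.
            print("init attention backend for replay")
        return sum(batch)


class Runner:
    def __init__(self) -> None:
        self.graph_runner = GraphRunner()

    def forward(self, batch: list[int], skip_attn_backend_init: bool = False) -> int:
        # The flag is simply forwarded; forward() itself does not act on it.
        return self.graph_runner.replay(
            batch, skip_attn_backend_init=skip_attn_backend_init
        )


if __name__ == "__main__":
    runner = Runner()
    runner.forward([1, 2, 3])                               # initializes the backend
    runner.forward([1, 2, 3], skip_attn_backend_init=True)  # skips it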