[Minor] more code cleanup (#4077)

Author:    Lianmin Zheng
Date:      2025-03-04 21:23:47 -08:00
Committer: GitHub
Parent:    4725e3f652
Commit:    e074d84e5b

15 changed files with 123 additions and 31 deletions

View File

@@ -238,6 +238,9 @@ class CudaGraphRunner:
                 ),
                 dtype=self.model_runner.dtype,
             )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )

         # Capture
         try:
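The new buffer is allocated once at init time because CUDA graphs replay fixed memory addresses: any tensor the captured kernels read must keep the same storage across replays, with fresh values copied in before each replay. A minimal sketch of that static-input pattern (illustrative only, not the sglang code; the `dp_size` value and the `prepare_replay` helper are made up):

    import torch

    dp_size = 2  # hypothetical DP world size
    # Allocated once; a captured graph would record this exact address.
    global_num_tokens_gpu = torch.zeros((dp_size,), dtype=torch.int32)

    def prepare_replay(global_num_tokens_cpu):
        # Refresh values in place instead of creating a new tensor,
        # so the recorded address stays valid.
        global_num_tokens_gpu.copy_(
            torch.tensor(global_num_tokens_cpu, dtype=torch.int32)
        )

    prepare_replay([8, 8])
    print(global_num_tokens_gpu)  # tensor([8, 8], dtype=torch.int32)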
@@ -266,9 +269,9 @@ class CudaGraphRunner:

     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
-                forward_batch.global_num_tokens
-            )
+            min_num_tokens, max_num_tokens = min(
+                forward_batch.global_num_tokens_cpu
+            ), max(forward_batch.global_num_tokens_cpu)
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
                 (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
                 if self.disable_padding
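This check gates CUDA-graph use under DP attention: with padding disabled, every DP rank must carry the same token count and that exact size must have a captured graph. A self-contained approximation (the `graphs` dict and `max_bs` are stand-ins for the runner's real state, and the padded branch requiring only that the size fit is an assumption):

    graphs = {1: None, 2: None, 4: None, 8: None}  # captured sizes -> graph handles
    max_bs = 8

    def is_bs_supported(global_num_tokens_cpu, disable_padding):
        min_t, max_t = min(global_num_tokens_cpu), max(global_num_tokens_cpu)
        if disable_padding:
            # All DP ranks must agree exactly, and the size must be captured.
            return min_t == max_t and max_t in graphs
        # With padding, every rank is padded up to max_t, so it only has to fit.
        return max_t <= max_bs

    print(is_bs_supported([5, 8], disable_padding=True))   # False
    print(is_bs_supported([8, 8], disable_padding=True))   # True
    print(is_bs_supported([5, 8], disable_padding=False))  # True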
@@ -360,7 +363,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens=global_num_tokens,
+            global_num_tokens_cpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -430,7 +433,7 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens)
+                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
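`bisect.bisect_left` on the sorted `capture_bs` list returns the index of the smallest captured batch size that is greater than or equal to the requested size, so the batch is padded up to the nearest captured graph. For example:

    import bisect

    capture_bs = [1, 2, 4, 8, 16]   # sorted batch sizes with captured graphs
    need = 5                        # e.g. max(forward_batch.global_num_tokens_cpu)
    index = bisect.bisect_left(capture_bs, need)
    print(capture_bs[index])        # 8 -> the batch is padded from 5 up to 8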

View File

@@ -190,7 +190,16 @@ class ForwardBatch:
     attn_backend: AttentionBackend = None

     # For DP attention
-    global_num_tokens: Optional[List[int]] = None
+    global_num_tokens_cpu: Optional[List[int]] = None
+    global_num_tokens_gpu: Optional[torch.Tensor] = None
+    # Has to be None when cuda graph is captured.
+    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # for extend, local start pos and num tokens is different in logits processor
+    # this will be computed in get_dp_local_info
+    # this will be recomputed in LogitsMetadata.from_forward_batch
+    dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
+    dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
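Splitting the field into `_cpu` and `_gpu` copies keeps host-side control flow (the min/max checks and padding decisions above) on a plain Python list while kernels read the tensor; reading the values back from the GPU for branching would force a device synchronization. A sketch of the trade-off (illustrative, not the sglang code):

    import torch

    global_num_tokens_cpu = [8, 8]         # for host-side decisions
    global_num_tokens_gpu = torch.tensor(  # for kernels / CUDA-graph replay
        global_num_tokens_cpu, dtype=torch.int32
    )

    # Host control flow reads the list; no device sync is triggered.
    max_len = max(global_num_tokens_cpu)

    # Reading the tensor instead would stall on a device-to-host copy:
    # max_len = int(global_num_tokens_gpu.max().item())  # sync point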
@@ -234,7 +243,6 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
-            global_num_tokens=batch.global_num_tokens,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
@@ -248,8 +256,9 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )

-        if ret.global_num_tokens is not None:
-            max_len = max(ret.global_num_tokens)
+        if batch.global_num_tokens is not None:
+            ret.global_num_tokens_cpu = batch.global_num_tokens
+            max_len = max(ret.global_num_tokens_cpu)
             ret.gathered_buffer = torch.zeros(
                 (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
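The `gathered_buffer` shape reserves `max_len` padded rows for each of the `tp_size` ranks, giving the gather a fixed, rank-independent destination. A shape-only sketch with made-up sizes:

    import torch

    tp_size, hidden_size = 4, 2048
    global_num_tokens_cpu = [7, 5, 7, 6]
    max_len = max(global_num_tokens_cpu)  # 7

    gathered_buffer = torch.zeros(
        (max_len * tp_size, hidden_size), dtype=torch.float16
    )
    print(gathered_buffer.shape)  # torch.Size([28, 2048])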

View File

@@ -13,7 +13,6 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""

-import collections
 import datetime
 import gc
 import json
@@ -269,6 +268,7 @@ class ModelRunner:
         elif self.device == "cpu":
             backend = "gloo"

+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
             monkey_patch_p2p_access_check()
@@ -299,20 +299,24 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )

+        logger.info(
+            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
+        )
         return min_per_gpu_memory

     def load_model(self):
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
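The balance check compares the distributed minimum of free memory against 90% of the local reading; a rank that sees much more free memory than the global minimum implies some other GPU is partly occupied. The same logic in isolation (the GB values are invented):

    def check_memory_balance(min_per_gpu_memory, local_gpu_memory):
        # min_per_gpu_memory: min free memory across all ranks
        # local_gpu_memory:   free memory on this rank
        if min_per_gpu_memory < local_gpu_memory * 0.9:
            raise ValueError(
                "The memory capacity is unbalanced. "
                "Some GPUs may be occupied by other processes."
            )

    check_memory_balance(78.0, 80.0)    # passes: 78 >= 0.9 * 80 = 72
    # check_memory_balance(60.0, 80.0)  # raises: 60 < 72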
@@ -382,11 +386,13 @@ class ModelRunner:
         )
         self.dtype = self.model_config.dtype
+        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"avail mem={after_avail_memory:.2f} GB, "
+            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )

     def update_weights_from_disk(
@@ -785,12 +791,15 @@ class ModelRunner:
             return

         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = CudaGraphRunner(self)

+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

     def apply_torch_tp(self):
@@ -806,8 +815,12 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    def forward_extend(self, forward_batch: ForwardBatch):
-        self.attn_backend.init_forward_metadata(forward_batch)
+    def forward_extend(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ):
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)

         if self.is_generation:
             if forward_batch.input_embeds is None:
                 return self.model.forward(
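The new `skip_attn_backend_init` flag lets a caller that has already prepared the attention metadata run the forward pass without re-initializing it. A toy version showing only the flag's effect (the backend and batch here are mocks, not sglang classes):

    class MockAttnBackend:
        def __init__(self):
            self.inits = 0

        def init_forward_metadata(self, forward_batch):
            self.inits += 1

    backend = MockAttnBackend()

    def forward_extend(forward_batch, skip_attn_backend_init=False):
        # Skip setup when the caller already initialized the metadata,
        # e.g. on a second pass over the same batch.
        if not skip_attn_backend_init:
            backend.init_forward_metadata(forward_batch)
        # ... model forward would run here ...

    forward_extend(None)
    forward_extend(None, skip_attn_backend_init=True)
    print(backend.inits)  # 1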