[Minor] more code cleanup (#4077)
This commit is contained in:
@@ -238,6 +238,9 @@ class CudaGraphRunner:
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Capture
|
||||
try:
|
||||
@@ -266,9 +269,9 @@ class CudaGraphRunner:
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention:
|
||||
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
|
||||
forward_batch.global_num_tokens
|
||||
)
|
||||
min_num_tokens, max_num_tokens = min(
|
||||
forward_batch.global_num_tokens_cpu
|
||||
), max(forward_batch.global_num_tokens_cpu)
|
||||
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
||||
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
|
||||
if self.disable_padding
|
||||
@@ -360,7 +363,7 @@ class CudaGraphRunner:
|
||||
encoder_lens=encoder_lens,
|
||||
return_logprob=False,
|
||||
positions=positions,
|
||||
global_num_tokens=global_num_tokens,
|
||||
global_num_tokens_cpu=global_num_tokens,
|
||||
gathered_buffer=gathered_buffer,
|
||||
mrope_positions=mrope_positions,
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
@@ -430,7 +433,7 @@ class CudaGraphRunner:
|
||||
# Pad
|
||||
if self.enable_dp_attention:
|
||||
index = bisect.bisect_left(
|
||||
self.capture_bs, max(forward_batch.global_num_tokens)
|
||||
self.capture_bs, max(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
else:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
|
||||
@@ -190,7 +190,16 @@ class ForwardBatch:
|
||||
attn_backend: AttentionBackend = None
|
||||
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]] = None
|
||||
global_num_tokens_cpu: Optional[List[int]] = None
|
||||
global_num_tokens_gpu: Optional[torch.Tensor] = None
|
||||
# Has to be None when cuda graph is captured.
|
||||
global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
|
||||
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
|
||||
# for extend, local start pos and num tokens is different in logits processor
|
||||
# this will be computed in get_dp_local_info
|
||||
# this will be recomputed in LogitsMetadata.from_forward_batch
|
||||
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
|
||||
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
|
||||
gathered_buffer: Optional[torch.Tensor] = None
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
|
||||
@@ -234,7 +243,6 @@ class ForwardBatch:
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
token_ids_logprobs=batch.token_ids_logprobs,
|
||||
global_num_tokens=batch.global_num_tokens,
|
||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||
lora_paths=batch.lora_paths,
|
||||
sampling_info=batch.sampling_info,
|
||||
@@ -248,8 +256,9 @@ class ForwardBatch:
|
||||
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
||||
)
|
||||
|
||||
if ret.global_num_tokens is not None:
|
||||
max_len = max(ret.global_num_tokens)
|
||||
if batch.global_num_tokens is not None:
|
||||
ret.global_num_tokens_cpu = batch.global_num_tokens
|
||||
max_len = max(ret.global_num_tokens_cpu)
|
||||
ret.gathered_buffer = torch.zeros(
|
||||
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
|
||||
dtype=model_runner.dtype,
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# ==============================================================================
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
|
||||
import collections
|
||||
import datetime
|
||||
import gc
|
||||
import json
|
||||
@@ -269,6 +268,7 @@ class ModelRunner:
|
||||
elif self.device == "cpu":
|
||||
backend = "gloo"
|
||||
|
||||
before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
if not self.server_args.enable_p2p_check:
|
||||
monkey_patch_p2p_access_check()
|
||||
|
||||
@@ -299,20 +299,24 @@ class ModelRunner:
|
||||
min_per_gpu_memory = get_available_gpu_memory(
|
||||
self.device, self.gpu_id, distributed=self.tp_size > 1
|
||||
)
|
||||
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
self.tp_group = get_tp_group()
|
||||
self.attention_tp_group = get_attention_tp_group()
|
||||
|
||||
# Check memory for tensor parallelism
|
||||
if self.tp_size > 1:
|
||||
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
if min_per_gpu_memory < local_gpu_memory * 0.9:
|
||||
raise ValueError(
|
||||
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
|
||||
)
|
||||
return min_per_gpu_memory
|
||||
|
||||
def load_model(self):
|
||||
before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
@@ -382,11 +386,13 @@ class ModelRunner:
|
||||
)
|
||||
self.dtype = self.model_config.dtype
|
||||
|
||||
after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Load weight end. "
|
||||
f"type={type(self.model).__name__}, "
|
||||
f"dtype={self.dtype}, "
|
||||
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
f"avail mem={after_avail_memory:.2f} GB, "
|
||||
f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
|
||||
)
|
||||
|
||||
def update_weights_from_disk(
|
||||
@@ -785,12 +791,15 @@ class ModelRunner:
|
||||
return
|
||||
|
||||
tic = time.time()
|
||||
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||
)
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
|
||||
f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
||||
)
|
||||
|
||||
def apply_torch_tp(self):
|
||||
@@ -806,8 +815,12 @@ class ModelRunner:
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
def forward_extend(self, forward_batch: ForwardBatch):
|
||||
self.attn_backend.init_forward_metadata(forward_batch)
|
||||
def forward_extend(
|
||||
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
|
||||
):
|
||||
if not skip_attn_backend_init:
|
||||
self.attn_backend.init_forward_metadata(forward_batch)
|
||||
|
||||
if self.is_generation:
|
||||
if forward_batch.input_embeds is None:
|
||||
return self.model.forward(
|
||||
|
||||
Reference in New Issue
Block a user