Qwen2vl support cuda graph and disable radix cache (#1780)
This commit is contained in:
@@ -152,6 +152,7 @@ class CudaGraphRunner:
|
||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when running a mixed batch
|
||||
@@ -233,6 +234,7 @@ class CudaGraphRunner:
|
||||
encoder_lens = None
|
||||
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
mrope_positions = self.mrope_positions[:, :bs]
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||
@@ -259,6 +261,7 @@ class CudaGraphRunner:
|
||||
return_logprob=False,
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=clamp_position(seq_lens),
|
||||
mrope_positions=mrope_positions,
|
||||
)
|
||||
return forward(input_ids, forward_batch.positions, forward_batch)
|
||||
|
||||
@@ -301,6 +304,8 @@ class CudaGraphRunner:
|
||||
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
|
||||
if self.is_encoder_decoder:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
if forward_batch.mrope_positions is not None:
|
||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
|
||||
@@ -142,11 +142,12 @@ class ForwardBatch:
|
||||
int(self.seq_lens[i]),
|
||||
)
|
||||
elif self.forward_mode.is_extend():
|
||||
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
|
||||
for i, image_inputs in enumerate(batch.image_inputs):
|
||||
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
||||
self.extend_start_loc[i],
|
||||
self.extend_seq_lens[i],
|
||||
self.extend_prefix_lens[i],
|
||||
extend_start_loc_cpu[i],
|
||||
batch.extend_seq_lens[i],
|
||||
batch.extend_prefix_lens[i],
|
||||
)
|
||||
if image_inputs is None:
|
||||
# text only
|
||||
@@ -160,20 +161,16 @@ class ForwardBatch:
|
||||
] * 3
|
||||
mrope_position_delta = 0
|
||||
else:
|
||||
# TODO: qwen2-vl currently does not support radix cache because of how mrope positions are calculated
|
||||
mrope_positions, mrope_position_delta = (
|
||||
MRotaryEmbedding.get_input_positions(
|
||||
input_tokens=self.input_ids[
|
||||
extend_start_loc : extend_start_loc + extend_seq_len
|
||||
].tolist(),
|
||||
],
|
||||
image_grid_thw=image_inputs.image_grid_thws,
|
||||
video_grid_thw=None,
|
||||
image_token_id=hf_config.image_token_id,
|
||||
video_token_id=hf_config.video_token_id,
|
||||
vision_start_token_id=hf_config.vision_start_token_id,
|
||||
vision_end_token_id=hf_config.vision_end_token_id,
|
||||
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
||||
context_len=0,
|
||||
extend_prefix_len=extend_prefix_len.item(),
|
||||
)
|
||||
)
|
||||
mrope_positions_list[i] = mrope_positions
|
||||
|
||||
@@ -125,11 +125,11 @@ class ModelRunner:
|
||||
)
|
||||
server_args.chunked_prefill_size = None
|
||||
server_args.mem_fraction_static *= 0.95
|
||||
# TODO: qwen2-vl does not support cuda graph now, set disable-graph=True automatically
|
||||
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
|
||||
if self.model_config.hf_config.architectures == [
|
||||
"Qwen2VLForConditionalGeneration"
|
||||
]:
|
||||
server_args.disable_cuda_graph = True
|
||||
server_args.disable_radix_cache = True
|
||||
|
||||
# Global vars
|
||||
if server_args.show_time_cost:
|
||||
|
||||
Reference in New Issue
Block a user