Llama3.2 vision model support (#1551)
This commit is contained in:
@@ -105,6 +105,7 @@ class CudaGraphRunner:
|
||||
self.graph_memory_pool = None
|
||||
self.use_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
|
||||
|
||||
# Batch sizes to capture
|
||||
if self.model_runner.server_args.disable_cuda_graph_padding:
|
||||
@@ -132,6 +133,9 @@ class CudaGraphRunner:
|
||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||
)
|
||||
|
||||
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
||||
self.encoder_len_fill_value = 0
|
||||
|
||||
if self.use_torch_compile:
|
||||
set_torch_compile_config()
|
||||
|
||||
@@ -144,9 +148,18 @@ class CudaGraphRunner:
|
||||
)
|
||||
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
||||
self.encoder_lens = torch.full(
|
||||
(self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
self.encoder_lens = None
|
||||
|
||||
# Capture
|
||||
try:
|
||||
self.capture()
|
||||
with self.model_capture_mode():
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n"
|
||||
@@ -157,11 +170,32 @@ class CudaGraphRunner:
|
||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||
)
|
||||
|
||||
def can_run(self, batch_size: int):
    """Tell whether a batch of ``batch_size`` requests fits the captured graphs.

    With padding enabled (``self.disable_padding`` is falsy) any batch size up
    to ``self.max_bs`` is accepted. With padding disabled the batch size must
    be present in ``self.graphs``.

    NOTE(review): a second ``can_run`` taking a ``ForwardBatch`` appears later
    in this text; confirm which definition is meant to be live.
    """
    if not self.disable_padding:
        return batch_size <= self.max_bs
    return batch_size in self.graphs
|
||||
@contextmanager
def model_capture_mode(self):
    """Context manager that flags the model as being in capture mode.

    If ``self.model_runner.model`` exposes a ``capture_mode`` attribute, it is
    set to ``True`` for the duration of the ``with`` body and reset to
    ``False`` afterwards. Models without that attribute are left untouched.

    The reset happens in a ``finally`` clause so that an exception raised
    inside the ``with`` body (e.g. a failed capture) does not leave the model
    stuck with ``capture_mode = True``.
    """
    model = self.model_runner.model
    supports_capture_mode = hasattr(model, "capture_mode")

    if supports_capture_mode:
        model.capture_mode = True
    try:
        yield
    finally:
        # Always restore the flag, even when the body raised.
        if supports_capture_mode:
            model.capture_mode = False
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
    """Return whether ``forward_batch`` can be run through the captured graphs.

    Two conditions must hold:

    * the batch size is supported: it is a key of ``self.graphs`` when padding
      is disabled, otherwise it is at most ``self.max_bs``;
    * for encoder-decoder models, every entry of
      ``forward_batch.encoder_lens`` is greater than zero.

    Returns:
        A plain ``bool``. The encoder-length check is skipped when the batch
        size is already unsupported, so no tensor reduction is issued for a
        batch that will be rejected anyway.
    """
    if self.disable_padding:
        is_bs_supported = forward_batch.batch_size in self.graphs
    else:
        is_bs_supported = forward_batch.batch_size <= self.max_bs

    if not is_bs_supported:
        return False

    if not self.is_encoder_decoder:
        return True

    # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
    # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
    # because the full_text_row_masked_out_mask tensor will always be ones
    # bool() turns the 0-dim tensor from torch.all into a Python bool so that
    # callers always receive the same return type.
    return bool(torch.all(forward_batch.encoder_lens > 0))
|
||||
|
||||
def capture(self):
|
||||
with graph_capture() as graph_capture_context:
|
||||
@@ -188,11 +222,19 @@ class CudaGraphRunner:
|
||||
req_pool_indices = self.req_pool_indices[:bs]
|
||||
seq_lens = self.seq_lens[:bs]
|
||||
out_cache_loc = self.out_cache_loc[:bs]
|
||||
if self.is_encoder_decoder:
|
||||
encoder_lens = self.encoder_lens[:bs]
|
||||
else:
|
||||
encoder_lens = None
|
||||
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||
bs, req_pool_indices, seq_lens
|
||||
bs,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
encoder_lens,
|
||||
)
|
||||
|
||||
# Run and capture
|
||||
@@ -208,6 +250,7 @@ class CudaGraphRunner:
|
||||
attn_backend=self.model_runner.attn_backend,
|
||||
out_cache_loc=out_cache_loc,
|
||||
seq_lens_sum=seq_lens_sum,
|
||||
encoder_lens=encoder_lens,
|
||||
return_logprob=False,
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
||||
@@ -251,6 +294,8 @@ class CudaGraphRunner:
|
||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
|
||||
if self.is_encoder_decoder:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
@@ -258,6 +303,7 @@ class CudaGraphRunner:
|
||||
self.req_pool_indices,
|
||||
self.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
self.encoder_lens,
|
||||
)
|
||||
|
||||
# Replay
|
||||
|
||||
@@ -108,6 +108,12 @@ class ForwardBatch:
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]] = None
|
||||
|
||||
# Encoder-decoder
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
encoder_lens: Optional[torch.Tensor] = None
|
||||
encoder_lens_cpu: Optional[List[int]] = None
|
||||
encoder_out_cache_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# For LoRA
|
||||
lora_paths: Optional[List[str]] = None
|
||||
|
||||
@@ -194,6 +200,11 @@ class ForwardBatch:
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
image_inputs=batch.image_inputs,
|
||||
encoder_cached=batch.encoder_cached,
|
||||
encoder_lens=batch.encoder_lens,
|
||||
encoder_lens_cpu=batch.encoder_lens_cpu,
|
||||
encoder_out_cache_loc=batch.encoder_out_cache_loc,
|
||||
seq_lens_sum=batch.seq_lens_sum,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
@@ -212,11 +223,11 @@ class ForwardBatch:
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
ret.image_inputs = batch.image_inputs
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
ret.extend_seq_lens = torch.tensor(
|
||||
batch.extend_seq_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
@@ -270,7 +270,6 @@ class ModelRunner:
|
||||
if hasattr(self.model, "get_attention_sliding_window_size")
|
||||
else None
|
||||
)
|
||||
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
|
||||
self.is_generation = is_generation_model(
|
||||
self.model_config.hf_config.architectures, self.server_args.is_embedding
|
||||
)
|
||||
@@ -510,7 +509,7 @@ class ModelRunner:
|
||||
"Window attention is not supported in the triton attention backend. "
|
||||
"Please use `--attention-backend flashinfer`."
|
||||
)
|
||||
assert not self.has_cross_attention, (
|
||||
assert not self.model_config.is_encoder_decoder, (
|
||||
"Cross attention is not supported in the triton attention backend. "
|
||||
"Please use `--attention-backend flashinfer`."
|
||||
)
|
||||
@@ -558,9 +557,7 @@ class ModelRunner:
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
|
||||
def forward_decode(self, forward_batch: ForwardBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
|
||||
forward_batch.batch_size
|
||||
):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
|
||||
|
||||
Reference in New Issue
Block a user