upgrade flashinfer v0.2.0.post2 (#3288)
Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
This commit is contained in:
@@ -26,7 +26,7 @@ runtime_common = [
|
||||
srt = [
|
||||
"sglang[runtime_common]", "cuda-python",
|
||||
"sgl-kernel>=0.0.3.post1", "torch", "vllm==0.6.4.post1",
|
||||
"flashinfer==0.1.6", "outlines>=0.0.44,<0.1.0"
|
||||
"flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<0.1.0"
|
||||
]
|
||||
|
||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||
|
||||
@@ -316,8 +316,8 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
# Check flashinfer version
|
||||
if server_args.attention_backend == "flashinfer":
|
||||
assert_pkg_version(
|
||||
"flashinfer",
|
||||
"0.1.6",
|
||||
"flashinfer_python",
|
||||
"0.2.0.post2",
|
||||
"Please uninstall the old version and "
|
||||
"reinstall the latest version by following the instructions "
|
||||
"at https://docs.flashinfer.ai/installation.html.",
|
||||
|
||||
@@ -149,6 +149,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
BatchPrefillWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer,
|
||||
"NHD",
|
||||
backend="fa2",
|
||||
)
|
||||
)
|
||||
self.prefill_wrappers_verify.append(
|
||||
@@ -313,7 +314,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
|
||||
paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
|
||||
custom_mask_buf=self.cuda_graph_custom_mask,
|
||||
qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
|
||||
mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
|
||||
)
|
||||
)
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
@@ -1155,41 +1156,24 @@ def fast_decode_plan(
|
||||
self.last_page_len = torch.ones(32768, dtype=torch.int32)
|
||||
empty_q_data = self.empty_q_data
|
||||
empty_kv_cache = self.empty_kv_cache
|
||||
if self.use_tensor_cores:
|
||||
if not self.is_cuda_graph_enabled:
|
||||
# when not using cudagraph, we need to create the indptr buffer, otherwise
|
||||
# the buffer is already created during initialization
|
||||
self._qo_indptr_buf = torch.arange(
|
||||
batch_size + 1, dtype=torch.int32, device=indptr.device
|
||||
)
|
||||
self._wrapper.plan(
|
||||
self._float_workspace_buffer,
|
||||
self._int_workspace_buffer,
|
||||
self._qo_indptr_buf,
|
||||
indptr,
|
||||
batch_size,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
empty_q_data,
|
||||
)
|
||||
else:
|
||||
self._wrapper.plan(
|
||||
self._float_workspace_buffer,
|
||||
self._int_workspace_buffer,
|
||||
indptr,
|
||||
self.last_page_len,
|
||||
batch_size,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
PosEncodingMode[pos_encoding_mode].value,
|
||||
logits_soft_cap,
|
||||
empty_q_data,
|
||||
empty_kv_cache,
|
||||
)
|
||||
stream = torch.cuda.current_stream()
|
||||
self._cached_module.plan(
|
||||
self._float_workspace_buffer,
|
||||
self._int_workspace_buffer,
|
||||
self._pin_memory_int_workspace_buffer,
|
||||
indptr.to("cpu"),
|
||||
batch_size,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
page_size,
|
||||
self.is_cuda_graph_enabled,
|
||||
window_left,
|
||||
logits_soft_cap,
|
||||
head_dim,
|
||||
empty_q_data,
|
||||
empty_kv_cache,
|
||||
stream.cuda_stream,
|
||||
)
|
||||
self._pos_encoding_mode = pos_encoding_mode
|
||||
self._window_left = window_left
|
||||
self._logits_soft_cap = logits_soft_cap
|
||||
|
||||
@@ -69,6 +69,7 @@ class EagleDraftInput:
|
||||
accept_length_cpu = batch.spec_info.accept_length_cpu
|
||||
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
||||
seq_lens_cpu = batch.seq_lens.tolist()
|
||||
|
||||
pt = 0
|
||||
@@ -353,8 +354,12 @@ class EagleVerifyInput:
|
||||
]
|
||||
if has_finished:
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
|
||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
|
||||
unfinished_index
|
||||
]
|
||||
else:
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
|
||||
|
||||
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
|
||||
return (
|
||||
|
||||
@@ -269,6 +269,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
seq_lens_backup = batch.seq_lens
|
||||
req_pool_indices_backup = batch.req_pool_indices
|
||||
|
||||
self._set_mem_pool(batch, self.model_runner)
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
@@ -284,6 +285,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.seq_lens = seq_lens_backup
|
||||
batch.req_pool_indices = req_pool_indices_backup
|
||||
|
||||
def capture_for_decode(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
|
||||
Reference in New Issue
Block a user