Maintain seq_lens_sum to make more FlashInfer operations non-blocking (#1741)
This commit is contained in:
@@ -621,7 +621,6 @@ Please cite our paper, [SGLang: Efficient Execution of Structured Language Model
|
|||||||
We also learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).
|
We also learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="#sglangtop" target="_blank">
|
<a href="#sglangtop" target="_blank">
|
||||||
<bold>Back To Top </bold>
|
<bold>Back To Top </bold>
|
||||||
|
|||||||
@@ -25,7 +25,11 @@ class AttentionBackend(ABC):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
):
|
):
|
||||||
"""Init the metadata for a forward pass for replying a cuda graph."""
|
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -144,7 +144,11 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
):
|
):
|
||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_sum,
|
||||||
)
|
)
|
||||||
self.forward_metadata = (self.decode_wrappers,)
|
self.forward_metadata = (self.decode_wrappers,)
|
||||||
else:
|
else:
|
||||||
@@ -134,10 +135,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
# Some heuristics to check whether to use ragged forward
|
# Some heuristics to check whether to use ragged forward
|
||||||
use_ragged = False
|
use_ragged = False
|
||||||
if (
|
if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
|
||||||
torch.sum(forward_batch.seq_lens).item() >= 4096
|
|
||||||
and self.num_wrappers == 1
|
|
||||||
):
|
|
||||||
use_ragged = True
|
use_ragged = True
|
||||||
|
|
||||||
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
|
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
|
||||||
@@ -181,15 +179,25 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers)
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
|
self.indices_updater_decode.update(
|
||||||
|
req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers
|
||||||
|
)
|
||||||
self.cuda_graph_metadata[bs] = decode_wrappers
|
self.cuda_graph_metadata[bs] = decode_wrappers
|
||||||
self.forward_metadata = (decode_wrappers,)
|
self.forward_metadata = (decode_wrappers,)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
):
|
):
|
||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs]
|
req_pool_indices[:bs],
|
||||||
|
seq_lens[:bs],
|
||||||
|
seq_lens_sum,
|
||||||
|
self.cuda_graph_metadata[bs],
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
@@ -305,13 +313,30 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
assert attn_backend.num_wrappers == 1
|
assert attn_backend.num_wrappers == 1
|
||||||
self.update = self.update_single_wrapper
|
self.update = self.update_single_wrapper
|
||||||
|
|
||||||
def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None):
|
def update_single_wrapper(
|
||||||
|
self,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
|
decode_wrappers=None,
|
||||||
|
):
|
||||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
self.call_begin_forward(
|
self.call_begin_forward(
|
||||||
decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None
|
decode_wrappers[0],
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
self.kv_indptr[0],
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None):
|
def update_sliding_window(
|
||||||
|
self,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
|
decode_wrappers=None,
|
||||||
|
):
|
||||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
|
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
@@ -331,6 +356,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
decode_wrappers[wrapper_id],
|
decode_wrappers[wrapper_id],
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
|
seq_lens_sum,
|
||||||
self.kv_indptr[wrapper_id],
|
self.kv_indptr[wrapper_id],
|
||||||
kv_start_idx,
|
kv_start_idx,
|
||||||
)
|
)
|
||||||
@@ -339,13 +365,18 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx
|
self,
|
||||||
|
wrapper,
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
kv_indptr,
|
||||||
|
kv_start_idx,
|
||||||
):
|
):
|
||||||
bs = len(req_pool_indices)
|
bs = len(req_pool_indices)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
# TODO: optimize the blocking call on kv_indptr[-1]
|
|
||||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
kv_indices = torch.empty(seq_lens_sum, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
create_flashinfer_kv_indices_triton[(bs,)](
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
|
|||||||
@@ -91,7 +91,11 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
):
|
):
|
||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|||||||
@@ -416,7 +416,6 @@ class ScheduleBatch:
|
|||||||
req_to_token_pool: ReqToTokenPool = None
|
req_to_token_pool: ReqToTokenPool = None
|
||||||
token_to_kv_pool: BaseTokenToKVPool = None
|
token_to_kv_pool: BaseTokenToKVPool = None
|
||||||
tree_cache: BasePrefixCache = None
|
tree_cache: BasePrefixCache = None
|
||||||
|
|
||||||
forward_mode: ForwardMode = None
|
forward_mode: ForwardMode = None
|
||||||
sampling_info: SamplingBatchInfo = None
|
sampling_info: SamplingBatchInfo = None
|
||||||
|
|
||||||
@@ -424,9 +423,13 @@ class ScheduleBatch:
|
|||||||
input_ids: torch.Tensor = None
|
input_ids: torch.Tensor = None
|
||||||
req_pool_indices: torch.Tensor = None
|
req_pool_indices: torch.Tensor = None
|
||||||
seq_lens: torch.Tensor = None
|
seq_lens: torch.Tensor = None
|
||||||
|
# The output locations of the KV cache
|
||||||
out_cache_loc: torch.Tensor = None
|
out_cache_loc: torch.Tensor = None
|
||||||
output_ids: torch.Tensor = None
|
output_ids: torch.Tensor = None
|
||||||
|
|
||||||
|
# The sum of all sequence lengths
|
||||||
|
seq_lens_sum: int = None
|
||||||
|
|
||||||
# For processing logprobs
|
# For processing logprobs
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
top_logprobs_nums: Optional[List[int]] = None
|
top_logprobs_nums: Optional[List[int]] = None
|
||||||
@@ -435,7 +438,6 @@ class ScheduleBatch:
|
|||||||
prefix_lens: List[int] = None
|
prefix_lens: List[int] = None
|
||||||
extend_lens: List[int] = None
|
extend_lens: List[int] = None
|
||||||
extend_num_tokens: int = None
|
extend_num_tokens: int = None
|
||||||
running_bs: int = None
|
|
||||||
decoding_reqs: List[Req] = None
|
decoding_reqs: List[Req] = None
|
||||||
|
|
||||||
# Stream
|
# Stream
|
||||||
@@ -549,10 +551,12 @@ class ScheduleBatch:
|
|||||||
self.device, non_blocking=True
|
self.device, non_blocking=True
|
||||||
)
|
)
|
||||||
|
|
||||||
self.extend_num_tokens = extend_num_tokens
|
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
|
|
||||||
|
self.seq_lens_sum = sum(seq_lens)
|
||||||
if self.return_logprob:
|
if self.return_logprob:
|
||||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||||
|
self.extend_num_tokens = extend_num_tokens
|
||||||
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||||
@@ -571,12 +575,11 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
|
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
|
||||||
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
|
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
|
||||||
extend_num_tokens = self.extend_num_tokens + running_bs
|
|
||||||
|
|
||||||
self.merge_batch(running_batch)
|
self.merge_batch(running_batch)
|
||||||
self.input_ids = input_ids
|
self.input_ids = input_ids
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
self.extend_num_tokens = extend_num_tokens
|
self.extend_num_tokens += running_bs
|
||||||
|
|
||||||
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
||||||
self.prefix_lens.extend(
|
self.prefix_lens.extend(
|
||||||
@@ -775,6 +778,7 @@ class ScheduleBatch:
|
|||||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||||
)
|
)
|
||||||
self.seq_lens.add_(1)
|
self.seq_lens.add_(1)
|
||||||
|
self.seq_lens_sum += bs
|
||||||
|
|
||||||
def filter_batch(
|
def filter_batch(
|
||||||
self,
|
self,
|
||||||
@@ -805,6 +809,7 @@ class ScheduleBatch:
|
|||||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||||
self.seq_lens = self.seq_lens[new_indices]
|
self.seq_lens = self.seq_lens[new_indices]
|
||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
|
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||||
self.output_ids = self.output_ids[new_indices]
|
self.output_ids = self.output_ids[new_indices]
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
if self.return_logprob:
|
if self.return_logprob:
|
||||||
@@ -828,6 +833,7 @@ class ScheduleBatch:
|
|||||||
)
|
)
|
||||||
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
|
self.seq_lens_sum += other.seq_lens_sum
|
||||||
if self.output_ids is not None:
|
if self.output_ids is not None:
|
||||||
self.output_ids = torch.concat([self.output_ids, other.output_ids])
|
self.output_ids = torch.concat([self.output_ids, other.output_ids])
|
||||||
if self.return_logprob and other.return_logprob:
|
if self.return_logprob and other.return_logprob:
|
||||||
@@ -873,9 +879,11 @@ class ScheduleBatch:
|
|||||||
req_pool_indices=self.req_pool_indices,
|
req_pool_indices=self.req_pool_indices,
|
||||||
seq_lens=self.seq_lens,
|
seq_lens=self.seq_lens,
|
||||||
out_cache_loc=self.out_cache_loc,
|
out_cache_loc=self.out_cache_loc,
|
||||||
|
seq_lens_sum=self.seq_lens_sum,
|
||||||
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
|
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
|
||||||
return_logprob=self.return_logprob,
|
return_logprob=self.return_logprob,
|
||||||
top_logprobs_nums=self.top_logprobs_nums,
|
top_logprobs_nums=self.top_logprobs_nums,
|
||||||
|
extend_num_tokens=self.extend_num_tokens,
|
||||||
extend_seq_lens=extend_seq_lens,
|
extend_seq_lens=extend_seq_lens,
|
||||||
extend_prefix_lens=extend_prefix_lens,
|
extend_prefix_lens=extend_prefix_lens,
|
||||||
extend_logprob_start_lens=extend_logprob_start_lens,
|
extend_logprob_start_lens=extend_logprob_start_lens,
|
||||||
@@ -917,6 +925,9 @@ class ModelWorkerBatch:
|
|||||||
# The indices of output tokens in the token_to_kv_pool
|
# The indices of output tokens in the token_to_kv_pool
|
||||||
out_cache_loc: torch.Tensor
|
out_cache_loc: torch.Tensor
|
||||||
|
|
||||||
|
# The sum of all sequence lengths
|
||||||
|
seq_lens_sum: int
|
||||||
|
|
||||||
# The memory pool operation records
|
# The memory pool operation records
|
||||||
req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
|
req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
|
||||||
|
|
||||||
@@ -925,6 +936,7 @@ class ModelWorkerBatch:
|
|||||||
top_logprobs_nums: Optional[List[int]]
|
top_logprobs_nums: Optional[List[int]]
|
||||||
|
|
||||||
# For extend
|
# For extend
|
||||||
|
extend_num_tokens: Optional[int]
|
||||||
extend_seq_lens: Optional[List[int]]
|
extend_seq_lens: Optional[List[int]]
|
||||||
extend_prefix_lens: Optional[List[int]]
|
extend_prefix_lens: Optional[List[int]]
|
||||||
extend_logprob_start_lens: Optional[List[int]]
|
extend_logprob_start_lens: Optional[List[int]]
|
||||||
|
|||||||
@@ -188,6 +188,7 @@ class CudaGraphRunner:
|
|||||||
req_pool_indices = self.req_pool_indices[:bs]
|
req_pool_indices = self.req_pool_indices[:bs]
|
||||||
seq_lens = self.seq_lens[:bs]
|
seq_lens = self.seq_lens[:bs]
|
||||||
out_cache_loc = self.out_cache_loc[:bs]
|
out_cache_loc = self.out_cache_loc[:bs]
|
||||||
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||||
@@ -206,6 +207,7 @@ class CudaGraphRunner:
|
|||||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
attn_backend=self.model_runner.attn_backend,
|
attn_backend=self.model_runner.attn_backend,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
|
seq_lens_sum=seq_lens_sum,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
top_logprobs_nums=[0] * bs,
|
top_logprobs_nums=[0] * bs,
|
||||||
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
||||||
@@ -252,7 +254,10 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
bs, self.req_pool_indices, self.seq_lens
|
bs,
|
||||||
|
self.req_pool_indices,
|
||||||
|
self.seq_lens,
|
||||||
|
forward_batch.seq_lens_sum,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replay
|
# Replay
|
||||||
|
|||||||
@@ -87,6 +87,9 @@ class ForwardBatch:
|
|||||||
# The indices of output tokens in the token_to_kv_pool
|
# The indices of output tokens in the token_to_kv_pool
|
||||||
out_cache_loc: torch.Tensor
|
out_cache_loc: torch.Tensor
|
||||||
|
|
||||||
|
# The sum of all sequence lengths
|
||||||
|
seq_lens_sum: int
|
||||||
|
|
||||||
# For logprob
|
# For logprob
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
top_logprobs_nums: Optional[List[int]] = None
|
top_logprobs_nums: Optional[List[int]] = None
|
||||||
@@ -95,6 +98,7 @@ class ForwardBatch:
|
|||||||
positions: torch.Tensor = None
|
positions: torch.Tensor = None
|
||||||
|
|
||||||
# For extend
|
# For extend
|
||||||
|
extend_num_tokens: Optional[int] = None
|
||||||
extend_seq_lens: Optional[torch.Tensor] = None
|
extend_seq_lens: Optional[torch.Tensor] = None
|
||||||
extend_prefix_lens: Optional[torch.Tensor] = None
|
extend_prefix_lens: Optional[torch.Tensor] = None
|
||||||
extend_start_loc: Optional[torch.Tensor] = None
|
extend_start_loc: Optional[torch.Tensor] = None
|
||||||
@@ -175,21 +179,6 @@ class ForwardBatch:
|
|||||||
)
|
)
|
||||||
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
||||||
|
|
||||||
def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
|
|
||||||
device = model_runner.device
|
|
||||||
if self.forward_mode.is_decode():
|
|
||||||
self.positions = (self.seq_lens - 1).to(torch.int64)
|
|
||||||
else:
|
|
||||||
self.positions = torch.concat(
|
|
||||||
[
|
|
||||||
torch.arange(prefix_len, prefix_len + extend_len, device=device)
|
|
||||||
for prefix_len, extend_len in zip(
|
|
||||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
|
||||||
)
|
|
||||||
],
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
cls,
|
cls,
|
||||||
@@ -205,6 +194,7 @@ class ForwardBatch:
|
|||||||
req_pool_indices=batch.req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
seq_lens=batch.seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
|
seq_lens_sum=batch.seq_lens_sum,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
lora_paths=batch.lora_paths,
|
lora_paths=batch.lora_paths,
|
||||||
@@ -213,7 +203,17 @@ class ForwardBatch:
|
|||||||
|
|
||||||
# Init position information
|
# Init position information
|
||||||
if not ret.forward_mode.is_decode():
|
if not ret.forward_mode.is_decode():
|
||||||
|
ret.positions = torch.concat(
|
||||||
|
[
|
||||||
|
torch.arange(prefix_len, prefix_len + extend_len, device=device)
|
||||||
|
for prefix_len, extend_len in zip(
|
||||||
|
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||||
|
)
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
ret.image_inputs = batch.image_inputs
|
ret.image_inputs = batch.image_inputs
|
||||||
|
ret.extend_num_tokens = batch.extend_num_tokens
|
||||||
ret.extend_seq_lens = torch.tensor(
|
ret.extend_seq_lens = torch.tensor(
|
||||||
batch.extend_seq_lens, dtype=torch.int32
|
batch.extend_seq_lens, dtype=torch.int32
|
||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
@@ -225,12 +225,8 @@ class ForwardBatch:
|
|||||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||||
|
|
||||||
# Init position information
|
if model_runner.model_is_mrope:
|
||||||
is_mrope = model_runner.model_is_mrope
|
|
||||||
if is_mrope:
|
|
||||||
ret.compute_mrope_positions(model_runner, batch)
|
ret.compute_mrope_positions(model_runner, batch)
|
||||||
else:
|
|
||||||
ret.compute_positions(model_runner, batch)
|
|
||||||
|
|
||||||
# Init attention information
|
# Init attention information
|
||||||
ret.req_to_token_pool = model_runner.req_to_token_pool
|
ret.req_to_token_pool = model_runner.req_to_token_pool
|
||||||
|
|||||||
Reference in New Issue
Block a user