Fix retraction + overlap (#1860)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -211,9 +211,6 @@ class Req:
|
||||
# this does not include the jump forward tokens.
|
||||
self.completion_tokens_wo_jump_forward = 0
|
||||
|
||||
# The number of cached tokens, that were already cached in the KV store
|
||||
self.cached_tokens = 0
|
||||
|
||||
# For vision inputs
|
||||
self.image_inputs: Optional[ImageInputs] = None
|
||||
|
||||
@@ -223,6 +220,9 @@ class Req:
|
||||
self.last_node = None
|
||||
self.is_being_chunked = 0
|
||||
|
||||
# For retraction
|
||||
self.is_retracted = False
|
||||
|
||||
# Logprobs (arguments)
|
||||
self.return_logprob = False
|
||||
self.logprob_start_len = 0
|
||||
@@ -242,12 +242,15 @@ class Req:
|
||||
# The relative logprob_start_len in an extend batch
|
||||
self.extend_logprob_start_len = 0
|
||||
|
||||
# Embedding
|
||||
# Embedding (return values)
|
||||
self.embedding = None
|
||||
|
||||
# Constrained decoding
|
||||
self.grammar: Optional[Grammar] = None
|
||||
|
||||
# The number of tokens that were already cached in the KV cache
|
||||
self.cached_tokens = 0
|
||||
|
||||
# For Qwen2-VL
|
||||
self.mrope_position_delta = [] # use mutable object
|
||||
|
||||
@@ -561,7 +564,7 @@ class ScheduleBatch:
|
||||
seq_lens[i] -= encoder_len
|
||||
|
||||
if len(req.prefix_indices) < encoder_len:
|
||||
# NOTE: the encoder part should considered as a whole
|
||||
# NOTE: the encoder part should be considered as a whole
|
||||
assert len(req.prefix_indices) == 0
|
||||
input_ids[i] = input_ids[i][encoder_len:]
|
||||
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
|
||||
@@ -648,6 +651,7 @@ class ScheduleBatch:
|
||||
|
||||
req.extend_logprob_start_len = extend_logprob_start_len
|
||||
pt += req.extend_input_len
|
||||
req.is_retracted = False
|
||||
|
||||
# Set fields
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
@@ -780,6 +784,7 @@ class ScheduleBatch:
|
||||
req.prefix_indices = []
|
||||
req.last_node = None
|
||||
req.extend_input_len = 0
|
||||
req.is_retracted = True
|
||||
|
||||
# For incremental logprobs
|
||||
req.last_update_decode_tokens = 0
|
||||
|
||||
Reference in New Issue
Block a user