bugfix for mtp (#3300)
### What this PR does / why we need it? when mtp>1, we need to refresh cos and sin in each step. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? - vLLM version: v0.11.0 Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
@@ -97,7 +97,7 @@ def split_decodes_and_prefills(
|
|||||||
return num_reqs, 0, num_tokens, 0
|
return num_reqs, 0, num_tokens, 0
|
||||||
|
|
||||||
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
||||||
assert torch.all(query_lens[first_prefill:] >= decode_threshold)
|
assert torch.all(query_lens[first_prefill:] > decode_threshold)
|
||||||
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
|
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
|
||||||
num_decodes = first_prefill
|
num_decodes = first_prefill
|
||||||
num_prefills = num_reqs - num_decodes
|
num_prefills = num_reqs - num_decodes
|
||||||
|
|||||||
@@ -523,6 +523,14 @@ class MtpProposer(Proposer):
|
|||||||
input_ids = draft_token_ids_list[-1].int()
|
input_ids = draft_token_ids_list[-1].int()
|
||||||
positions += 1
|
positions += 1
|
||||||
|
|
||||||
|
if not self.torchair_graph_enabled:
|
||||||
|
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
|
||||||
|
1:batch_size + 1].tolist()
|
||||||
|
attn_metadata_i.decode.cos = builder.cos_cache[
|
||||||
|
positions].unsqueeze(1).unsqueeze(2)
|
||||||
|
attn_metadata_i.decode.sin = builder.sin_cache[
|
||||||
|
positions].unsqueeze(1).unsqueeze(2)
|
||||||
|
|
||||||
# NOTE(woosuk): We should handle the case where the draft model
|
# NOTE(woosuk): We should handle the case where the draft model
|
||||||
# generates tokens beyond the max model length. Since it is complex
|
# generates tokens beyond the max model length. Since it is complex
|
||||||
# to remove such requests from the batch, we keep them in the batch
|
# to remove such requests from the batch, we keep them in the batch
|
||||||
@@ -556,6 +564,8 @@ class MtpProposer(Proposer):
|
|||||||
|
|
||||||
if attn_metadata_i.prefill is not None:
|
if attn_metadata_i.prefill is not None:
|
||||||
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
|
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
|
||||||
|
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
|
||||||
|
)
|
||||||
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
|
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
|
||||||
attn_metadata_i.prefill.input_positions = self.positions[:
|
attn_metadata_i.prefill.input_positions = self.positions[:
|
||||||
num_input_tokens]
|
num_input_tokens]
|
||||||
@@ -565,6 +575,8 @@ class MtpProposer(Proposer):
|
|||||||
self.runner.model_config.max_model_len)
|
self.runner.model_config.max_model_len)
|
||||||
if attn_metadata_i.decode is not None:
|
if attn_metadata_i.decode is not None:
|
||||||
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
|
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
|
||||||
|
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
|
||||||
|
)
|
||||||
attn_metadata_i.decode.input_positions = self.positions[:
|
attn_metadata_i.decode.input_positions = self.positions[:
|
||||||
num_input_tokens]
|
num_input_tokens]
|
||||||
attn_metadata_i.decode.max_seq_lens += 1
|
attn_metadata_i.decode.max_seq_lens += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user