Fix incorrect spec_num_draft_tokens in draft_extend (#7757)
This commit is contained in:
@@ -237,6 +237,10 @@ def _dp_gather(
|
||||
assert (
|
||||
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
|
||||
), "aliasing between global_tokens and local_tokens not allowed"
|
||||
|
||||
# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
|
||||
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
|
||||
# actual size of the accepted tokens.
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
|
||||
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
|
||||
@@ -291,6 +295,10 @@ def dp_scatter(
|
||||
assert (
|
||||
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
|
||||
), "aliasing between local_tokens and global_tokens not allowed"
|
||||
|
||||
# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
|
||||
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
|
||||
# actual size of the accepted tokens.
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
|
||||
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
|
||||
|
||||
Reference in New Issue
Block a user