[Bugfix] grammar_bitmask IndexError caused by outdated apply_grammar_bitmask method (#2022)

### What this PR does / why we need it?
Fix #2033 

Sync https://github.com/vllm-project/vllm/pull/14702 to solve
`grammar_bitmask` IndexError caused by outdated `apply_grammar_bitmask`
method

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tested by upstream vllm


- vLLM version: v0.10.0
- vLLM main:
6e599eebe8

Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
ApsarasX
2025-07-31 09:03:27 +08:00
committed by GitHub
parent 75e28d0356
commit 72eceff94d

View File

@@ -1348,40 +1348,52 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
logits: torch.Tensor,
) -> torch.Tensor:
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = scheduler_output.grammar_bitmask
# We receive the structured output bitmask from the scheduler, but the
# indices of the requests in the batch may not match the indices of
# the bitmask since the scheduler doesn't know how the gpu runner is
# ordering the requests in the batch. We need to sort the bitmask to
# match the order of the requests used here.
# We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests.
# The order of the requests in the bitmask is not guaranteed to be the
# same as the order of the requests in the gpu runner's batch. We need
# to sort the bitmask to match the order of the requests used here.
# Get the batch indices of the structured output requests.
# Keep track of the number of speculative tokens scheduled for every
# request in the batch, as the logit indices are offset by this amount.
struct_out_req_batch_indices: dict[str, int] = {}
indices_match = True
for req_id in self.input_batch.req_ids:
mask_index = scheduler_output.structured_output_request_ids.get(
req_id)
if mask_index is None:
# not a structured output request
continue
batch_index = self.input_batch.req_id_to_index[req_id]
if batch_index != mask_index:
indices_match = False
struct_out_req_batch_indices[req_id] = batch_index
cumulative_offset = 0
seq = sorted(self.input_batch.req_id_to_index.items(),
key=lambda x: x[1])
for req_id, batch_index in seq:
logit_index = batch_index + cumulative_offset
cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
if req_id in scheduler_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index
if not indices_match:
# Sort the bitmask to match the order of the requests
sorted_bitmask = np.zeros_like(grammar_bitmask)
for req_id, batch_index in struct_out_req_batch_indices.items():
orig_index = scheduler_output.structured_output_request_ids[
req_id]
sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
grammar_bitmask = sorted_bitmask
out_indices = []
# Reorder the bitmask to match the order of the requests in the batch.
sorted_bitmask = np.zeros_like(grammar_bitmask,
shape=(logits.shape[0],
grammar_bitmask.shape[1]))
cumulative_index = 0
seq = sorted(scheduler_output.structured_output_request_ids.items(),
key=lambda x: x[1])
for req_id, _ in seq:
logit_index = struct_out_req_batch_indices[req_id]
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
for i in range(1 + num_spec_tokens):
sorted_bitmask[logit_index + i] = \
grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask)
# TODO: compatibility with spec decode.
# NOTE:
# 1. XGrammar bitmask applying only supports CPU and GPU.
# 2. The logits and bitmask should be on the same device.
@@ -1391,7 +1403,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask,
indices=list(struct_out_req_batch_indices.values()),
indices=out_indices,
)
return logits.to(self.device).to(logits_dtype)