diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 0bfe0f84..320481ec 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -21,7 +21,6 @@ import copy
 import gc
 import itertools
 import math
-import re
 import time
 from collections import defaultdict
 from collections.abc import Iterator
@@ -34,6 +33,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
 
 import numpy as np
 import numpy.typing as npt
+import regex as re
 import torch
 import torch._dynamo.cache_size
 import torch.distributed as dist
@@ -92,6 +92,7 @@ from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -1699,70 +1700,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         )
         return metadata
 
-    def apply_grammar_bitmask(
-        self,
-        scheduler_output: "SchedulerOutput",
-        logits: torch.Tensor,
-    ) -> torch.Tensor:
-        grammar_bitmask = scheduler_output.grammar_bitmask
-
-        # We receive the structured output bitmask from the scheduler,
-        # compacted to contain bitmasks only for structured output requests.
-        # The order of the requests in the bitmask is not guaranteed to be the
-        # same as the order of the requests in the gpu runner's batch. We need
-        # to sort the bitmask to match the order of the requests used here.
-
-        # Get the batch indices of the structured output requests.
-        # Keep track of the number of speculative tokens scheduled for every
-        # request in the batch, as the logit indices are offset by this amount.
-        struct_out_req_batch_indices: dict[str, int] = {}
-        cumulative_offset = 0
-        seq = sorted(self.input_batch.req_id_to_index.items(),
-                     key=lambda x: x[1])
-        for req_id, batch_index in seq:
-            logit_index = batch_index + cumulative_offset
-            cumulative_offset += len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            if req_id in scheduler_output.structured_output_request_ids:
-                struct_out_req_batch_indices[req_id] = logit_index
-
-        out_indices = []
-
-        # Reorder the bitmask to match the order of the requests in the batch.
-        sorted_bitmask = np.zeros_like(grammar_bitmask,
-                                       shape=(logits.shape[0],
-                                              grammar_bitmask.shape[1]))
-        cumulative_index = 0
-        seq = sorted(scheduler_output.structured_output_request_ids.items(),
-                     key=lambda x: x[1])
-        for req_id, _ in seq:
-            logit_index = struct_out_req_batch_indices[req_id]
-            num_spec_tokens = len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            for i in range(1 + num_spec_tokens):
-                sorted_bitmask[logit_index + i] = \
-                    grammar_bitmask[cumulative_index + i]
-                out_indices.append(logit_index + i)
-            cumulative_index += 1 + num_spec_tokens
-        grammar_bitmask = sorted_bitmask
-
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
-        grammar_bitmask = torch.from_numpy(grammar_bitmask)
-
-        # NOTE:
-        # 1. XGrammar bitmask applying only supports CPU and GPU.
-        # 2. The logits and bitmask should be on the same device.
-        # 3. XGrammar logits on CPU only supports float32 dtype.
-        logits_dtype = logits.dtype
-        logits = logits.to("cpu").float()
-        xgr.apply_token_bitmask_inplace(
-            logits,
-            grammar_bitmask,
-            indices=out_indices,
-        )
-        return logits.to(self.device).to(logits_dtype)
-
     def propose_draft_token_ids(
         self,
         valid_sampled_token_ids: list[list[int]],
@@ -2011,7 +1948,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         # Apply structured output bitmasks if present
         if scheduler_output.grammar_bitmask is not None:
-            logits = self.apply_grammar_bitmask(scheduler_output, logits)
+            assert logits is not None
+            # NOTE:
+            # 1. XGrammar bitmask applying only supports CPU and GPU.
+            # 2. The logits and bitmask should be on the same device.
+            # 3. XGrammar logits on CPU only supports float32 dtype.
+            logits_dtype = logits.dtype
+            logits = logits.to("cpu").float()
+            apply_grammar_bitmask(scheduler_output, self.input_batch,
+                                  logits, torch.device("cpu"))
+            logits = logits.to(self.device).to(logits_dtype)
 
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
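For context on the NOTE in the new call site: the XGrammar bitmask kernel only runs on CPU or GPU, and on CPU it requires float32 logits, so the NPU runner round-trips the logits through host memory before sampling. The sketch below illustrates that round trip in isolation; the `apply_mask_on_cpu` helper and the plain boolean "allowed token" mask are hypothetical stand-ins for the shared `apply_grammar_bitmask` utility, not the vLLM implementation.

```python
# Minimal sketch of the CPU float32 round trip, assuming logits live on a
# non-CPU device (e.g. NPU) and a boolean mask stands in for the grammar
# bitmask. Hypothetical helper for illustration only.
import torch


def apply_mask_on_cpu(logits: torch.Tensor, allowed: torch.Tensor) -> torch.Tensor:
    """Apply a token mask on CPU in float32, then restore device and dtype."""
    orig_device = logits.device
    orig_dtype = logits.dtype
    # The masking kernel only supports CPU/GPU and needs float32 on CPU,
    # so move the logits off the accelerator first.
    cpu_logits = logits.to("cpu").float()
    # Disallowed tokens get -inf so they can never be sampled.
    cpu_logits.masked_fill_(~allowed.to("cpu"), float("-inf"))
    # Copy the masked logits back to the original device and dtype.
    return cpu_logits.to(orig_device).to(orig_dtype)


if __name__ == "__main__":
    logits = torch.randn(2, 8, dtype=torch.float16)  # pretend these live on NPU
    allowed = torch.zeros(2, 8, dtype=torch.bool)
    allowed[:, :4] = True  # only the first 4 tokens are grammatically legal
    print(apply_mask_on_cpu(logits, allowed))
```

Restoring the original device and dtype last keeps the downstream sampler path unchanged, at the cost of one device-to-host copy per step for batches with structured-output requests.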