[V1][Structured Output] Add apply_grammar_bitmask() method to model runner (#555)
### What this PR does / why we need it? Add `apply_grammar_bitmask()` method to model runner. This method is necessary for `xgrammar` structured output. --------- Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -38,7 +38,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
LayerBlockType, cdiv)
|
||||
LayerBlockType, LazyLoader, cdiv)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
@@ -52,7 +52,10 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr # type: ignore[import-untyped]
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
|
||||
class NPUModelRunner:
|
||||
@@ -493,6 +496,60 @@ class NPUModelRunner:
|
||||
|
||||
return hidden_states[sample_indices]
|
||||
|
||||
def apply_grammar_bitmask(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Serialization of np.ndarray is much more efficient than a tensor,
|
||||
# so we receive it in that format.
|
||||
grammar_bitmask = scheduler_output.grammar_bitmask
|
||||
if grammar_bitmask is None:
|
||||
return
|
||||
|
||||
# We receive the structured output bitmask from the scheduler, but the
|
||||
# indices of the requests in the batch may not match the indices of
|
||||
# the bitmask since the scheduler doesn't know how the gpu runner is
|
||||
# ordering the requests in the batch. We need to sort the bitmask to
|
||||
# match the order of the requests used here.
|
||||
struct_out_req_batch_indices: dict[str, int] = {}
|
||||
indices_match = True
|
||||
for req_id in self.input_batch.req_ids:
|
||||
mask_index = scheduler_output.structured_output_request_ids.get(
|
||||
req_id)
|
||||
if mask_index is None:
|
||||
# not a structured output request
|
||||
continue
|
||||
batch_index = self.input_batch.req_id_to_index[req_id]
|
||||
if batch_index != mask_index:
|
||||
indices_match = False
|
||||
struct_out_req_batch_indices[req_id] = batch_index
|
||||
|
||||
if not indices_match:
|
||||
# Sort the bitmask to match the order of the requests
|
||||
sorted_bitmask = np.zeros_like(grammar_bitmask)
|
||||
for req_id, batch_index in struct_out_req_batch_indices.items():
|
||||
orig_index = scheduler_output.structured_output_request_ids[
|
||||
req_id]
|
||||
sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
|
||||
grammar_bitmask = sorted_bitmask
|
||||
|
||||
grammar_bitmask = torch.from_numpy(grammar_bitmask)
|
||||
|
||||
# TODO: compatibility with spec decode.
|
||||
# NOTE:
|
||||
# 1. XGrammar bitmask applying only supports CPU and GPU.
|
||||
# 2. The logits and bitmask should be on the same device.
|
||||
# 3. XGrammar logits on CPU only supports float32 dtype.
|
||||
logits_dtype = logits.dtype
|
||||
logits = logits.to("cpu").float()
|
||||
xgr.apply_token_bitmask_inplace(
|
||||
logits,
|
||||
grammar_bitmask,
|
||||
indices=list(struct_out_req_batch_indices.values()),
|
||||
)
|
||||
return logits.to(self.device).to(logits_dtype)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -507,6 +564,10 @@ class NPUModelRunner:
|
||||
intermediate_tensors)
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
|
||||
# Apply structured output bitmasks if present
|
||||
if scheduler_output.grammar_bitmask is not None:
|
||||
logits = self.apply_grammar_bitmask(scheduler_output, logits)
|
||||
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
sampler_output = self.model.sample(
|
||||
|
||||
Reference in New Issue
Block a user