[V1][Structured Output] Add apply_grammar_bitmask() method to model runner (#555)
### What this PR does / why we need it? Add `apply_grammar_bitmask()` method to model runner. This method is necessary for `xgrammar` structured output. --------- Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -38,7 +38,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
|||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||||
LayerBlockType, cdiv)
|
LayerBlockType, LazyLoader, cdiv)
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
@@ -52,7 +52,10 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
|||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
import xgrammar as xgr # type: ignore[import-untyped]
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
else:
|
||||||
|
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||||
|
|
||||||
|
|
||||||
class NPUModelRunner:
|
class NPUModelRunner:
|
||||||
@@ -493,6 +496,60 @@ class NPUModelRunner:
|
|||||||
|
|
||||||
return hidden_states[sample_indices]
|
return hidden_states[sample_indices]
|
||||||
|
|
||||||
|
def apply_grammar_bitmask(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
logits: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Serialization of np.ndarray is much more efficient than a tensor,
|
||||||
|
# so we receive it in that format.
|
||||||
|
grammar_bitmask = scheduler_output.grammar_bitmask
|
||||||
|
if grammar_bitmask is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# We receive the structured output bitmask from the scheduler, but the
|
||||||
|
# indices of the requests in the batch may not match the indices of
|
||||||
|
# the bitmask since the scheduler doesn't know how the gpu runner is
|
||||||
|
# ordering the requests in the batch. We need to sort the bitmask to
|
||||||
|
# match the order of the requests used here.
|
||||||
|
struct_out_req_batch_indices: dict[str, int] = {}
|
||||||
|
indices_match = True
|
||||||
|
for req_id in self.input_batch.req_ids:
|
||||||
|
mask_index = scheduler_output.structured_output_request_ids.get(
|
||||||
|
req_id)
|
||||||
|
if mask_index is None:
|
||||||
|
# not a structured output request
|
||||||
|
continue
|
||||||
|
batch_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
if batch_index != mask_index:
|
||||||
|
indices_match = False
|
||||||
|
struct_out_req_batch_indices[req_id] = batch_index
|
||||||
|
|
||||||
|
if not indices_match:
|
||||||
|
# Sort the bitmask to match the order of the requests
|
||||||
|
sorted_bitmask = np.zeros_like(grammar_bitmask)
|
||||||
|
for req_id, batch_index in struct_out_req_batch_indices.items():
|
||||||
|
orig_index = scheduler_output.structured_output_request_ids[
|
||||||
|
req_id]
|
||||||
|
sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
|
||||||
|
grammar_bitmask = sorted_bitmask
|
||||||
|
|
||||||
|
grammar_bitmask = torch.from_numpy(grammar_bitmask)
|
||||||
|
|
||||||
|
# TODO: compatibility with spec decode.
|
||||||
|
# NOTE:
|
||||||
|
# 1. XGrammar bitmask applying only supports CPU and GPU.
|
||||||
|
# 2. The logits and bitmask should be on the same device.
|
||||||
|
# 3. XGrammar logits on CPU only supports float32 dtype.
|
||||||
|
logits_dtype = logits.dtype
|
||||||
|
logits = logits.to("cpu").float()
|
||||||
|
xgr.apply_token_bitmask_inplace(
|
||||||
|
logits,
|
||||||
|
grammar_bitmask,
|
||||||
|
indices=list(struct_out_req_batch_indices.values()),
|
||||||
|
)
|
||||||
|
return logits.to(self.device).to(logits_dtype)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
@@ -507,6 +564,10 @@ class NPUModelRunner:
|
|||||||
intermediate_tensors)
|
intermediate_tensors)
|
||||||
logits = self.model.compute_logits(hidden_states, None)
|
logits = self.model.compute_logits(hidden_states, None)
|
||||||
|
|
||||||
|
# Apply structured output bitmasks if present
|
||||||
|
if scheduler_output.grammar_bitmask is not None:
|
||||||
|
logits = self.apply_grammar_bitmask(scheduler_output, logits)
|
||||||
|
|
||||||
# Sample the next token and get logprobs if needed.
|
# Sample the next token and get logprobs if needed.
|
||||||
sampling_metadata = self.input_batch.sampling_metadata
|
sampling_metadata = self.input_batch.sampling_metadata
|
||||||
sampler_output = self.model.sample(
|
sampler_output = self.model.sample(
|
||||||
|
|||||||
Reference in New Issue
Block a user