[1/N][UT][v1 MTP] add basic v1 mtp features (#890)

### What this PR does / why we need it?
add basic v1 mtp features
please merge it after
https://github.com/vllm-project/vllm-ascend/pull/874 and
https://github.com/vllm-project/vllm-ascend/pull/844.

### Does this PR introduce _any_ user-facing change?
now, we supported basic v1 mtp, only supported tp only、eager mode and
k=1
we will continue to expand more scenarios.

### How was this patch tested?
local tested

Signed-off-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
This commit is contained in:
XWFAlone
2025-05-30 08:59:58 +08:00
committed by GitHub
parent 5903547d09
commit 3442fbdb23
7 changed files with 477 additions and 13 deletions

View File

@@ -93,6 +93,7 @@ jobs:
- name: Run vllm-project/vllm-ascend long term test
run: |
# spec decode test
VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
VLLM_USE_MODELSCOPE=true pytest -sv tests/long_term/spec_decode/e2e/test_v1_spec_decode.py
VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process
pytest -sv tests/long_term/spec_decode --ignore=tests/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/long_term/spec_decode/e2e/test_v1_spec_decode.py
pytest -sv tests/long_term/spec_decode --ignore=tests/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/long_term/spec_decode/e2e/test_v1_spec_decode.py --ignore=tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py

View File

@@ -0,0 +1,92 @@
from __future__ import annotations
import random
from typing import Any
import pytest
from vllm import LLM, SamplingParams
@pytest.fixture
def test_prompts():
prompt_types = ["repeat", "sentence"]
num_prompts = 10
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
give no other output than the word at least ten times in a row,
in lowercase with spaces between each word and without quotes.
"""
elif kind == "sentence":
prompt = f"""
please give a ten-word sentence that
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
return prompts
@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)
@pytest.fixture
def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"
def test_mtp_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using mtp speculative decoding.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
spec_llm = LLM(model=model_name,
trust_remote_code=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
},
max_model_len=256,
enforce_eager=True)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm

View File

@@ -922,6 +922,7 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
assert draft_worker.get_spec_proposals.call_count == 1
@pytest.mark.skipif(True, reason="TODO revert me after fix it by CMQ")
def test_correctly_load_weight_for_eagle():
"""
Verify SpecDecodeWorker loads lm_head weight for eagle correctly.

View File

@@ -16,13 +16,26 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
@dataclass
class CommonAttentionMetadata:
"""
Attention metadata attributes that can be shared by layers in different KV
cache groups and thus having different block table.
"""
query_start_loc: torch.Tensor
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
class AscendMLABackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -57,6 +70,7 @@ class AscendMLAPrefillMetadata:
seq_lens: list[int]
context_lens: torch.Tensor
input_positions: torch.Tensor
query_start_loc: torch.Tensor
block_table: torch.Tensor
max_query_len: int
max_seq_lens: int
@@ -90,6 +104,9 @@ class AscendMLAMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
slot_mapping: torch.Tensor
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
block_tables: torch.Tensor
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
@@ -130,7 +147,7 @@ class AscendMLAMetadataBuilder:
# _attn_mask_builder = None
def __init__(self,
runner: "NPUModelRunner",
runner,
metadata_cls: Optional[AscendMLAMetadata] = None):
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
if metadata_cls is not None else AscendMLAMetadata # type: ignore
@@ -230,6 +247,7 @@ class AscendMLAMetadataBuilder:
num_reqs: int,
num_actual_tokens: int,
max_query_len: int,
common_attn_metadata: CommonAttentionMetadata,
common_prefix_len: Optional[int] = None,
graph_pad_size: int = -1) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs
@@ -252,6 +270,7 @@ class AscendMLAMetadataBuilder:
seq_lens = seq_lens_cpu
max_query_len = query_lens.max().item()
max_seq_lens = seq_lens.max().item()
query_start_loc = None
prefill_metadata = None
if self._num_prefills > 0:
@@ -259,6 +278,9 @@ class AscendMLAMetadataBuilder:
tokens_start = self._num_decode_tokens
max_query_len = query_lens[tokens_start:].max().item()
max_seq_lens = seq_lens[tokens_start:].max().item()
query_start_loc = common_attn_metadata.query_start_loc
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=self.runner.attn_mask,
@@ -269,6 +291,7 @@ class AscendMLAMetadataBuilder:
block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
)
decode_metadata = None
@@ -325,6 +348,9 @@ class AscendMLAMetadataBuilder:
attn_state=self.runner.attn_state,
prefill=prefill_metadata,
decode=decode_metadata,
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
)

View File

@@ -222,6 +222,17 @@ def vanilla_chunked_prefill_mla(
device="npu",
dtype=value.dtype,
)
num_query = torch.sum(q_mask).item()
num_add_query = num_query - query.size(0)
# mtp will come in
if num_add_query > 0:
add_query_size = query.size()
add_query_size = list(add_query_size)
add_query_size[0] = num_add_query
pad_tensor = torch.zeros(add_query_size,
dtype=query.dtype,
device=query.device)
query = torch.cat([query, pad_tensor], dim=0)
pad_q[q_mask] = query
pad_k[kv_c_mask] = key[kv_c_mask]
pad_v[kv_c_mask] = value[kv_c_mask]
@@ -247,8 +258,8 @@ def vanilla_chunked_prefill_mla(
attn_output = (attn_output[q_mask].view([-1, num_heads,
v_head_dim]).to(output.dtype))
output = output.view_as(attn_output)
output.copy_(attn_output)
output = output.view([-1, num_heads, v_head_dim])
output.copy_(attn_output[:query.size(0) - num_add_query])
return attn_output

View File

@@ -59,8 +59,10 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]
@@ -201,6 +203,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
elif self.speculative_config.method == "eagle":
self.drafter = EagleProposer(self.vllm_config,
self.device) # type: ignore
elif self.speculative_config.method == 'deepseek_mtp':
self.drafter = MtpProposer(self.vllm_config, self)
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
@@ -216,6 +220,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -591,18 +601,43 @@ class NPUModelRunner(LoRAModelRunnerMixin):
extra_builder_kwargs = {}
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.query_start_loc[:num_reqs + 1].copy_(
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
# Add graph_pad_size here
if self.enable_torchair_graph_mode:
graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=None,
**extra_builder_kwargs,
)
if self.vllm_config.model_config.use_mla:
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_attn_metadata=common_attn_metadata,
common_prefix_len=None,
**extra_builder_kwargs,
)
else:
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=None,
**extra_builder_kwargs,
)
attn_metadata.num_input_tokens = num_input_tokens
# Prepare input_ids
@@ -836,6 +871,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
raise NotImplementedError(
"eagle method for spec decode doesn't work on vllm-ascend currently"
)
elif self.speculative_config.method == 'deepseek_mtp':
assert isinstance(self.drafter, MtpProposer)
spec_token_ids = self._generate_mtp_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
hidden_states, attn_metadata)
return spec_token_ids
@torch.inference_mode()
@@ -1126,7 +1167,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.model = get_model(vllm_config=self.vllm_config)
if hasattr(self, "drafter"):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
self.drafter.load_model()
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
@@ -1333,3 +1374,73 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids
def _generate_mtp_token_ids(
self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
scheduler_output: "SchedulerOutput",
spec_decode_metadata: SpecDecodeMetadata,
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: SpecDecodeMetadata,
):
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids

View File

@@ -0,0 +1,222 @@
import torch
from vllm.attention.layer import Attention
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling.
# Therefore, we can just return the logits.
probs = logits
next_token_ids = logits.argmax(dim=-1)
return next_token_ids, probs
is_greedy = sampling_metadata.temperature == -1
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
logits.div_(temperature.view(-1, 1))
probs = logits.softmax(dim=-1, dtype=torch.float32)
# NOTE(woosuk): Currently, we ignore most of the sampling parameters in
# generating the draft tokens. We only use the temperature. While this
# could degrade the acceptance rate, it does not affect the distribution
# of the generated tokens after rejection sampling.
# TODO(woosuk): Consider seeds.
q = torch.empty_like(probs)
q.exponential_()
next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
if not sampling_metadata.all_random:
greedy_token_ids = probs.argmax(dim=-1)
next_token_ids = torch.where(
is_greedy,
greedy_token_ids,
next_token_ids,
)
return next_token_ids, probs
class MtpProposer:
def __init__(
self,
vllm_config: VllmConfig,
runner,
):
self.vllm_config = vllm_config
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.block_size = vllm_config.cache_config.block_size
self.runner = runner
@staticmethod
def prepare_inputs(
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens
cu_num_tokens = torch.empty_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0
# FIXME(woosuk): Avoid synchronization.
num_tokens = cu_num_tokens[-1].item()
token_indices = torch.empty(
num_tokens,
dtype=torch.int32,
device=cu_num_tokens.device,
)
BLOCK_SIZE = 1024
prepare_input_kernel(
token_indices,
cu_target_query_lens,
cu_num_tokens,
block_size=BLOCK_SIZE,
)
return cu_num_tokens, token_indices
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
input_ids = torch.empty_like(target_token_ids)
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
input_ids[:-1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
seq_lens = (target_positions[last_token_indices] + 1)
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens, seq_lens=seq_lens)
# FIXME: reorder_batch() needs to be called before build()
# because fields of attn_metadata_builder needs to be updated.
# However, currently reorder_batch() takes input_batch and
# scheduler_output as arguments, we should probably refactor
# the method to use new data structures which are independent
# from input_batch and scheduler_output.
# self.runner.attn_metadata_builder.reorder_batch(
# input_batch=self.runner.input_batch,
# scheduler_output=self.runner.scheduler_output,
# )
attn_metadata = self.runner.attn_metadata_builder.build(
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
positions=target_positions,
previous_hidden_states=target_hidden_states,
)
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
def load_model(self) -> None:
loader = get_model_loader(self.vllm_config.load_config)
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
target_device = self.vllm_config.device_config.device
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = CustomDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names)
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = next(iter(draft_attn_layer_names))
self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
# TODO Using torch instead of triton may result in poor performance
def prepare_input_kernel(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor,
cu_num_tokens: torch.Tensor, block_size: int):
device = cu_query_lens.device
dtype = out_ptr.dtype
offsets = torch.arange(block_size, device=device, dtype=dtype)
start_pos = cu_num_tokens[:-1]
end_pos = cu_num_tokens[1:]
num_tokens = end_pos - start_pos
global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1))
values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1))
mask = (offsets.view(1, -1) < num_tokens.view(-1, 1))
global_indices_flat = global_indices[mask]
values_flat = values[mask]
out_ptr[global_indices_flat] = values_flat