Optimize some rejectsampler functions to make npu op launch non-blocking (#4587)
### What this PR does / why we need it?
- Vectorize the loops (without changing their outputs) in some rejection-sampler
functions include: `expand_pytorch`, `sample_recovered_tokens_pytorch`,
`rejection_random_sample_pytorch`, `sample_recovered_tokens`.
- Remove synchronously-launched torch-npu operators from them to accelerate
sampling + MTP postprocess.
### Does this PR introduce _any_ user-facing change?
- No
### How was this patch tested?
- We tested this change with the serve&bench command:
```
===== serve =====
vllm serve $LOCAL_CKPT_DIR \
--host 0.0.0.0 \
--port 8000 \
--data-parallel-size 4 \
--data-parallel-size-local 2 \
--data-parallel-address $MASTER_NODE_IP \
--data-parallel-start-rank $((2*VC_TASK_INDEX)) \
--data-parallel-rpc-port 13387 \
--tensor-parallel-size 8 \
--seed 1024 \
--enable-expert-parallel \
--served-model-name $NAME \
--max-model-len 4096 \
--max-num-seqs 16 \
--trust-remote-code \
--gpu-memory-utilization 0.90 \
$headless \
--speculative_config '{"method": "deepseek_mtp", "num_speculative_tokens": 1}' \
--additional-config '{"ascend_scheduler_config":{"enabled":false, "enable_chunked_prefill":true, "chunked_prefill_enabled":true}}'
==== bench =====
vllm bench serve --model $LOCAL_CKPT_DIR --served-model-name DeepseekV3ForCausalLM \
--dataset-name spec_bench --spec-bench-output-len 2048 \
--dataset-path question.jsonl \
--top-p 1.0 --temperature 0.8 \
--ignore-eos \
--num-prompts 64 --trust-remote-code --base-url "http://0.0.0.0:8000" --request-rate 64
```
- In this case, our rejection-sampling optimization reduces TPOT from 84.94ms to
64.61ms, an improvement of roughly 24%.
## before
<img width="1068" height="830" alt="image"
src="https://github.com/user-attachments/assets/278ac878-b49d-4588-b87c-316ca4d537f5"
/>
## after
<img width="781" height="756" alt="image"
src="https://github.com/user-attachments/assets/0c6d37ad-ed77-40b3-a1be-4933c468365c"
/>
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: ZongYuan Zhan <zhanzy178@gmail.com>
Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com>
This commit is contained in:
@@ -27,8 +27,22 @@ GREEDY_TEMPERATURE = 0.0
|
||||
MAX_SPEC_LEN = 8 # Used as MAX_NUM_TOKENS in expand_batch_to_tokens
|
||||
|
||||
|
||||
def mock_pin_memory(original_func):
    """Wrap a tensor-factory so any truthy ``pin_memory`` kwarg becomes False.

    Lets unit tests exercise code paths that request pinned host memory on
    machines where page-locked allocation is unavailable.
    """

    def func_wo_pin_memory(*args, **kwargs):
        # Rewrite the flag only when the caller explicitly passed a truthy value;
        # otherwise forward the kwargs untouched.
        if kwargs.get('pin_memory'):
            kwargs['pin_memory'] = False
        return original_func(*args, **kwargs)

    return func_wo_pin_memory
|
||||
|
||||
|
||||
class TestAscendRejectionSampler(TestBase):
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_rejection_greedy_sample_pytorch(self):
|
||||
"""Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
|
||||
batch_size = 2
|
||||
@@ -60,6 +74,10 @@ class TestAscendRejectionSampler(TestBase):
|
||||
assert output_token_ids[1, 0].item() == 20
|
||||
assert output_token_ids[1, 2].item() == PLACEHOLDER_TOKEN_ID
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_rejection_random_sample_pytorch(self):
|
||||
"""Test random rejection sampling: accept based on uniform probability"""
|
||||
batch_size = 2
|
||||
@@ -104,6 +122,10 @@ class TestAscendRejectionSampler(TestBase):
|
||||
assert output_token_ids[0, 1].item() == 0
|
||||
assert output_token_ids[0, 2].item() == 100
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_expand_pytorch(self):
|
||||
"""Test expand_pytorch functionality"""
|
||||
input_ptr = torch.tensor([10, 20, 30], dtype=torch.int32)
|
||||
@@ -122,6 +144,10 @@ class TestAscendRejectionSampler(TestBase):
|
||||
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
|
||||
assert torch.equal(output_ptr, expected)
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_expand_batch_to_tokens(self):
|
||||
"""Test expand_batch_to_tokens wrapper"""
|
||||
x = torch.tensor([10, 20, 30])
|
||||
@@ -154,6 +180,10 @@ class TestAscendRejectionSampler(TestBase):
|
||||
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
|
||||
assert torch.equal(result, expected)
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_sample_recovered_tokens_pytorch_ngram(self):
|
||||
"""Test recovered token sampling under n-gram mode"""
|
||||
output_token_ids = torch.empty(2, dtype=torch.int32)
|
||||
@@ -184,6 +214,10 @@ class TestAscendRejectionSampler(TestBase):
|
||||
assert output_token_ids[0].item() == 0
|
||||
assert output_token_ids[1].item() == 1
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_sample_recovered_tokens_pytorch_autoregressive(self):
|
||||
"""Test recovered token sampling for autoregressive models"""
|
||||
output_token_ids = torch.empty(2, dtype=torch.int32)
|
||||
|
||||
@@ -303,31 +303,32 @@ def expand_batch_to_tokens(
|
||||
def sample_recovered_tokens(
|
||||
max_spec_len: int,
|
||||
num_draft_tokens: list[int],
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): Create only one distribution for each request.
|
||||
batch_size = len(num_draft_tokens)
|
||||
vocab_size = target_probs.shape[-1]
|
||||
|
||||
q = torch.empty(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
|
||||
|
||||
num_draft_tensor = torch.tensor(num_draft_tokens,
|
||||
pin_memory=True).to(device,
|
||||
non_blocking=True)
|
||||
has_draft_mask = num_draft_tensor > 0
|
||||
|
||||
for i, generator in sampling_metadata.generators.items():
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if num_draft_tokens[i] > 0:
|
||||
q[i].exponential_(generator=generator)
|
||||
temp_q = torch.empty_like(q[i])
|
||||
temp_q.exponential_(generator=generator)
|
||||
q[i] = torch.where(has_draft_mask[i], temp_q, q[i])
|
||||
|
||||
recovered_token_ids = torch.empty_like(draft_token_ids)
|
||||
if HAS_TRITON:
|
||||
@@ -459,45 +460,128 @@ def rejection_random_sample_pytorch(
|
||||
vocab_size,
|
||||
IS_NGRAM=False,
|
||||
):
|
||||
"""
|
||||
This function implements the Speculative Decoding rejection sampling step.
|
||||
Instead of looping through each request and each token (which causes high
|
||||
overhead), it uses a fully vectorized approach:
|
||||
|
||||
1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D
|
||||
[batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle
|
||||
variable-length sequences in the batch.
|
||||
2. **Parallel Validation**: Calculates the acceptance condition
|
||||
(target_prob / draft_prob >= uniform_sample) for ALL draft tokens
|
||||
simultaneously across the entire batch.
|
||||
3. **Short-circuit Simulation**: In the loop version, once a token is rejected,
|
||||
subsequent tokens are ignored. Here, we simulate this by finding the
|
||||
'first_reject_pos' using argmax on the rejection mask and creating a
|
||||
'should_skip' mask for all indices after the first failure.
|
||||
4. **Token Selection**: Uses 'torch.where' to select:
|
||||
- Draft tokens (if accepted)
|
||||
- Recovered tokens (at the point of first rejection)
|
||||
- Bonus tokens (if all tokens in a sequence were accepted)
|
||||
5. **Masking**: Ensures operations only apply to non-greedy requests and
|
||||
within valid sequence lengths.
|
||||
"""
|
||||
|
||||
batch_size = output_token_ids.shape[0]
|
||||
device = output_token_ids.device
|
||||
|
||||
for req_idx in range(batch_size):
|
||||
if is_greedy[req_idx]:
|
||||
continue
|
||||
zero_cpu = torch.tensor([0], pin_memory=True)
|
||||
zero_device = zero_cpu.to(device, non_blocking=True)
|
||||
|
||||
if req_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = cu_num_draft_tokens[req_idx - 1].item()
|
||||
end_idx = cu_num_draft_tokens[req_idx].item()
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
cu_start = torch.cat([zero_device, cu_num_draft_tokens[:-1]])
|
||||
cu_end = cu_num_draft_tokens
|
||||
num_draft_per_batch = cu_end - cu_start
|
||||
|
||||
rejected = False
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = draft_token_ids[start_idx + pos].item()
|
||||
max_draft_len = max_spec_len
|
||||
pos_indices_cpu = torch.arange(max_draft_len, pin_memory=True)
|
||||
pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]
|
||||
|
||||
if IS_NGRAM:
|
||||
draft_prob = 1.0
|
||||
else:
|
||||
draft_prob = draft_probs[start_idx + pos,
|
||||
draft_token_id].item()
|
||||
valid_mask = pos_indices < num_draft_per_batch[:, None]
|
||||
global_token_indices = cu_start[:, None] + pos_indices
|
||||
global_token_indices = global_token_indices.clamp(
|
||||
0, draft_token_ids.shape[0] - 1)
|
||||
draft_tokens = draft_token_ids[
|
||||
global_token_indices] # [batch_size, max_draft_len]
|
||||
|
||||
target_prob = target_probs[start_idx + pos,
|
||||
draft_token_id].item()
|
||||
uniform_prob = uniform_probs[start_idx + pos].item()
|
||||
if IS_NGRAM:
|
||||
ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
|
||||
draft_token_probs = ones_cpu.to(
|
||||
device, non_blocking=True).expand_as(draft_tokens)
|
||||
else:
|
||||
flat_indices = global_token_indices.flatten()
|
||||
flat_draft_tokens = draft_tokens.flatten()
|
||||
flat_draft_probs = draft_probs[flat_indices, flat_draft_tokens]
|
||||
draft_token_probs = flat_draft_probs.view(batch_size, max_draft_len)
|
||||
|
||||
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||
token_id = draft_token_id
|
||||
else:
|
||||
rejected = True
|
||||
token_id = recovered_token_ids[start_idx + pos].item()
|
||||
flat_indices = global_token_indices.flatten()
|
||||
flat_draft_tokens = draft_tokens.flatten()
|
||||
flat_target_probs = target_probs[flat_indices, flat_draft_tokens]
|
||||
target_token_probs = flat_target_probs.view(batch_size, max_draft_len)
|
||||
|
||||
output_token_ids[req_idx, pos] = token_id
|
||||
uniform_token_probs = uniform_probs[global_token_indices]
|
||||
recovered_tokens = recovered_token_ids[global_token_indices]
|
||||
|
||||
if not rejected:
|
||||
bonus_token_id = bonus_token_ids[req_idx].item()
|
||||
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
|
||||
zero_threshold_cpu = torch.tensor([0.0],
|
||||
pin_memory=True,
|
||||
dtype=torch.float32)
|
||||
zero_threshold = zero_threshold_cpu.to(device, non_blocking=True)
|
||||
|
||||
acceptance_condition = (draft_token_probs > zero_threshold) & (
|
||||
target_token_probs / draft_token_probs >= uniform_token_probs)
|
||||
|
||||
first_rejection = (~acceptance_condition) & valid_mask
|
||||
|
||||
default_pos_cpu = torch.full([batch_size, 1],
|
||||
max_draft_len,
|
||||
pin_memory=True)
|
||||
default_pos = default_pos_cpu.to(device, non_blocking=True)
|
||||
|
||||
first_reject_pos = torch.where(
|
||||
first_rejection.any(dim=1, keepdim=True),
|
||||
first_rejection.float().argmax(dim=1, keepdim=True), default_pos)
|
||||
pos_mask = pos_indices >= first_reject_pos
|
||||
should_skip = pos_mask & valid_mask
|
||||
|
||||
final_acceptance = acceptance_condition & (~should_skip)
|
||||
non_greedy_mask = ~is_greedy
|
||||
update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip)
|
||||
|
||||
first_reject_mask = (pos_indices == first_reject_pos
|
||||
) & valid_mask & non_greedy_mask[:, None]
|
||||
final_update_mask = update_mask | first_reject_mask
|
||||
final_tokens = torch.where(
|
||||
first_reject_mask, recovered_tokens,
|
||||
torch.where(final_acceptance, draft_tokens,
|
||||
output_token_ids[:, :max_draft_len]))
|
||||
|
||||
output_token_ids[:, :max_draft_len] = torch.where(
|
||||
final_update_mask, final_tokens, output_token_ids[:, :max_draft_len])
|
||||
|
||||
no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch
|
||||
should_add_bonus = non_greedy_mask & no_rejection
|
||||
|
||||
bonus_positions = num_draft_per_batch # [batch_size]
|
||||
|
||||
seq_len = output_token_ids.shape[1]
|
||||
all_positions_cpu = torch.arange(seq_len, pin_memory=True)
|
||||
all_positions = all_positions_cpu.to(
|
||||
device, non_blocking=True)[None, :] # [1, seq_len]
|
||||
|
||||
batch_bonus_positions = bonus_positions[:, None] # [batch_size, 1]
|
||||
|
||||
max_spec_len_cpu = torch.tensor([max_spec_len], pin_memory=True)
|
||||
max_spec_len_device = max_spec_len_cpu.to(device, non_blocking=True)
|
||||
|
||||
valid_bonus_pos = bonus_positions < (max_spec_len_device + 1)
|
||||
final_bonus_mask = should_add_bonus & valid_bonus_pos
|
||||
|
||||
bonus_pos_match = (all_positions == batch_bonus_positions)
|
||||
bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None]
|
||||
|
||||
bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len)
|
||||
output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded,
|
||||
output_token_ids)
|
||||
|
||||
|
||||
def expand_pytorch(
|
||||
@@ -508,21 +592,48 @@ def expand_pytorch(
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS,
|
||||
):
|
||||
batch_size = len(input_ptr)
|
||||
"""
|
||||
This function broadcasts batch-level values (input_ptr) to token-level
|
||||
positions (output_ptr) based on cumulative token offsets. It acts like
|
||||
a "scatter" or "repeat_interleave" operation but with custom logic:
|
||||
|
||||
1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size
|
||||
[num_tokens, batch_size] that identifies which batch index each token
|
||||
belongs to by checking if the token index falls between cu_start and cu_end.
|
||||
2. **Conditional Replacement**: Before expansion, it replaces specific values
|
||||
(e.g., padding or special markers) in the input to prepare the data.
|
||||
3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted
|
||||
sum that effectively "picks" the correct batch value for every token position
|
||||
simultaneously, avoiding a Python loop over the batch.
|
||||
"""
|
||||
device = cu_num_tokens_ptr.device
|
||||
batch_size = input_ptr.shape[0]
|
||||
num_tokens = output_ptr.shape[0]
|
||||
|
||||
for req_idx in range(batch_size):
|
||||
start_idx = 0 if req_idx == 0 else cu_num_tokens_ptr[req_idx - 1]
|
||||
end_idx = cu_num_tokens_ptr[req_idx]
|
||||
num_tokens = end_idx - start_idx
|
||||
if batch_size == 0 or num_tokens == 0:
|
||||
return
|
||||
|
||||
src_val = input_ptr[req_idx]
|
||||
src_val = replace_to if src_val == replace_from else src_val
|
||||
cu_start = torch.cat([
|
||||
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
|
||||
cu_num_tokens_ptr[:-1]
|
||||
])
|
||||
cu_end = cu_num_tokens_ptr
|
||||
|
||||
offset = torch.arange(MAX_NUM_TOKENS, device=num_tokens.device)
|
||||
mask = offset < num_tokens
|
||||
token_indices = torch.arange(num_tokens,
|
||||
device=device)[:, None] # [num_tokens, 1]
|
||||
cu_start_exp = cu_start[None, :] # [1, batch_size]
|
||||
cu_end_exp = cu_end[None, :] # [1, batch_size]
|
||||
|
||||
output_slice = start_idx + offset[mask]
|
||||
output_ptr[output_slice] = src_val
|
||||
in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp)
|
||||
|
||||
replaced_input = torch.where(input_ptr == replace_from, replace_to,
|
||||
input_ptr).float()
|
||||
|
||||
token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)
|
||||
|
||||
needs_update = in_range.any(dim=1)
|
||||
|
||||
output_ptr[:] = torch.where(needs_update, token_values, output_ptr)
|
||||
|
||||
|
||||
def sample_recovered_tokens_pytorch(
|
||||
@@ -535,37 +646,77 @@ def sample_recovered_tokens_pytorch(
|
||||
vocab_size,
|
||||
IS_NGRAM=False,
|
||||
):
|
||||
batch_size = len(cu_num_draft_tokens)
|
||||
"""
|
||||
When a draft token is rejected, we must sample a "recovered" token from
|
||||
a modified distribution. This function calculates that distribution across
|
||||
the entire flattened batch.
|
||||
|
||||
1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it
|
||||
determines which request in the batch each token belongs to. This is
|
||||
necessary because 'q' (normalization factor) is stored per-request.
|
||||
2. **Probability Adjustment**:
|
||||
- If N-GRAM: It zeroes out the draft token's probability in the target.
|
||||
- If Probabilistic: It calculates max(0, target_probs - draft_probs)
|
||||
as per the standard speculative decoding algorithm.
|
||||
3. **Normalization & Sampling**: It divides the adjusted probabilities
|
||||
by the normalization distribution 'q'. To remain vectorized, it
|
||||
broadcasts 'q' from [batch_size, vocab] to [num_tokens, vocab].
|
||||
4. **Argmax Selection**: It selects the best recovery token for every
|
||||
position in one pass using torch.argmax.
|
||||
"""
|
||||
device = output_token_ids.device
|
||||
num_tokens = output_token_ids.shape[0]
|
||||
|
||||
for req_idx in range(batch_size):
|
||||
start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1]
|
||||
end_idx = cu_num_draft_tokens[req_idx]
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
if num_tokens == 0:
|
||||
return
|
||||
|
||||
for pos in range(num_draft_tokens):
|
||||
token_idx = start_idx + pos
|
||||
cu_start = torch.cat([
|
||||
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
|
||||
cu_num_draft_tokens[:-1],
|
||||
])
|
||||
cu_end = cu_num_draft_tokens
|
||||
|
||||
if IS_NGRAM:
|
||||
draft_token_id = draft_token_ids[token_idx]
|
||||
orig_prob = target_probs[token_idx, draft_token_id].item()
|
||||
target_probs[token_idx, draft_token_id] = 0
|
||||
prob = target_probs[token_idx].clone()
|
||||
else:
|
||||
draft_p = draft_probs[token_idx].clone()
|
||||
target_p = target_probs[token_idx].clone()
|
||||
prob = torch.maximum(target_p - draft_p,
|
||||
torch.tensor(0.0, device=target_p.device))
|
||||
token_indices = torch.arange(num_tokens, device=device) # [num_tokens]
|
||||
|
||||
q_values = torch.full((vocab_size, ),
|
||||
float('-inf'),
|
||||
device=q.device)
|
||||
q_values[:vocab_size] = q[req_idx, :vocab_size]
|
||||
token_indices_expanded = token_indices[:, None] # [num_tokens, 1]
|
||||
cu_start_expanded = cu_start[None, :] # [1, batch_size]
|
||||
cu_end_expanded = cu_end[None, :] # [1, batch_size]
|
||||
|
||||
recovered_id = torch.argmax(prob / q_values).item()
|
||||
output_token_ids[token_idx] = recovered_id
|
||||
in_range_mask = (token_indices_expanded >= cu_start_expanded) & (
|
||||
token_indices_expanded < cu_end_expanded)
|
||||
|
||||
if IS_NGRAM:
|
||||
target_probs[token_idx, draft_token_id] = orig_prob
|
||||
token_to_batch = torch.argmax(in_range_mask.int(), dim=1)
|
||||
|
||||
has_match = in_range_mask.any(dim=1)
|
||||
token_to_batch = torch.where(has_match, token_to_batch, 0)
|
||||
|
||||
if IS_NGRAM:
|
||||
token_indices = torch.arange(num_tokens, device=device)
|
||||
|
||||
modified_target_probs = target_probs.clone()
|
||||
modified_target_probs[token_indices, draft_token_ids] = 0
|
||||
prob = modified_target_probs
|
||||
|
||||
else:
|
||||
prob = torch.maximum(
|
||||
target_probs - draft_probs,
|
||||
torch.tensor(0.0, pin_memory=True).to(device, non_blocking=True),
|
||||
)
|
||||
|
||||
q_values = q[token_to_batch] # [num_tokens, vocab_size]
|
||||
|
||||
epsilon = 1e-10
|
||||
q_values_safe = torch.where(q_values == 0, epsilon, q_values)
|
||||
q_values_safe = torch.where(torch.isinf(q_values), epsilon, q_values_safe)
|
||||
|
||||
prob_over_q = prob / q_values_safe
|
||||
|
||||
prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10,
|
||||
prob_over_q)
|
||||
|
||||
recovered_ids = torch.argmax(prob_over_q, dim=1)
|
||||
|
||||
output_token_ids[:] = recovered_ids
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
|
||||
Reference in New Issue
Block a user