diff --git a/tests/ut/sample/test_rejection_sampler.py b/tests/ut/sample/test_rejection_sampler.py
index 9e2c23b1..bfae2c65 100644
--- a/tests/ut/sample/test_rejection_sampler.py
+++ b/tests/ut/sample/test_rejection_sampler.py
@@ -27,8 +27,22 @@
 GREEDY_TEMPERATURE = 0.0
 MAX_SPEC_LEN = 8  # Used as MAX_NUM_TOKENS in expand_batch_to_tokens
 
+# Replace pin_memory=True with pin_memory=False in tensor factory calls so
+# the unit tests can run on hosts without pinned-memory support.
+def mock_pin_memory(original_func):
+
+    def func_wo_pin_memory(*args, **kwargs):
+        if kwargs.get('pin_memory', False):
+            kwargs['pin_memory'] = False
+        return original_func(*args, **kwargs)
+
+    return func_wo_pin_memory
+
+
 class TestAscendRejectionSampler(TestBase):
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_rejection_greedy_sample_pytorch(self):
         """Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
         batch_size = 2
@@ -60,6 +74,10 @@ class TestAscendRejectionSampler(TestBase):
         assert output_token_ids[1, 0].item() == 20
         assert output_token_ids[1, 2].item() == PLACEHOLDER_TOKEN_ID
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_rejection_random_sample_pytorch(self):
         """Test random rejection sampling: accept based on uniform probability"""
         batch_size = 2
@@ -104,6 +122,10 @@ class TestAscendRejectionSampler(TestBase):
         assert output_token_ids[0, 1].item() == 0
         assert output_token_ids[0, 2].item() == 100
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_expand_pytorch(self):
         """Test expand_pytorch functionality"""
         input_ptr = torch.tensor([10, 20, 30], dtype=torch.int32)
@@ -122,6 +144,10 @@ class TestAscendRejectionSampler(TestBase):
         expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
         assert torch.equal(output_ptr, expected)
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_expand_batch_to_tokens(self):
         """Test expand_batch_to_tokens wrapper"""
         x = torch.tensor([10, 20, 30])
@@ -154,6 +180,10 @@ class TestAscendRejectionSampler(TestBase):
         expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
         assert torch.equal(result, expected)
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_sample_recovered_tokens_pytorch_ngram(self):
         """Test recovered token sampling under n-gram mode"""
         output_token_ids = torch.empty(2, dtype=torch.int32)
@@ -184,6 +214,10 @@ class TestAscendRejectionSampler(TestBase):
         assert output_token_ids[0].item() == 0
         assert output_token_ids[1].item() == 1
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_sample_recovered_tokens_pytorch_autoregressive(self):
         """Test recovered token sampling for autoregressive models"""
         output_token_ids = torch.empty(2, dtype=torch.int32)
diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index 3361c6f2..b9a6a10b 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -303,31 +303,32 @@ def expand_batch_to_tokens(
 def sample_recovered_tokens(
     max_spec_len: int,
     num_draft_tokens: list[int],
-    # [batch_size]
     cu_num_draft_tokens: torch.Tensor,
-    # [num_tokens]
     draft_token_ids: torch.Tensor,
-    # [num_tokens, vocab_size]
     draft_probs: Optional[torch.Tensor],
-    # [num_tokens, vocab_size]
     target_probs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     device: torch.device,
 ) -> torch.Tensor:
-    # NOTE(woosuk): Create only one distribution for each request.
     batch_size = len(num_draft_tokens)
     vocab_size = target_probs.shape[-1]
+
     q = torch.empty(
         (batch_size, vocab_size),
         dtype=torch.float32,
         device=device,
     )
     q.exponential_()
+
+    num_draft_tensor = torch.tensor(num_draft_tokens,
+                                    pin_memory=True).to(device,
+                                                        non_blocking=True)
+    has_draft_mask = num_draft_tensor > 0
+
     for i, generator in sampling_metadata.generators.items():
-        # Do not generate random numbers for requests with no draft tokens.
-        # This can be important for reproducibility.
-        if num_draft_tokens[i] > 0:
-            q[i].exponential_(generator=generator)
+        temp_q = torch.empty_like(q[i])
+        temp_q.exponential_(generator=generator)
+        q[i] = torch.where(has_draft_mask[i], temp_q, q[i])
 
     recovered_token_ids = torch.empty_like(draft_token_ids)
     if HAS_TRITON:
@@ -459,45 +460,128 @@ def rejection_random_sample_pytorch(
     vocab_size,
     IS_NGRAM=False,
 ):
+    """
+    This function implements the Speculative Decoding rejection sampling step.
+    Instead of looping through each request and each token (which causes high
+    overhead), it uses a fully vectorized approach:
+
+    1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D
+       [batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle
+       variable-length sequences in the batch.
+    2. **Parallel Validation**: Calculates the acceptance condition
+       (target_prob / draft_prob >= uniform_sample) for ALL draft tokens
+       simultaneously across the entire batch.
+    3. **Short-circuit Simulation**: In the loop version, once a token is
+       rejected, subsequent tokens are ignored. Here, we simulate this by
+       finding the 'first_reject_pos' using argmax on the rejection mask and
+       creating a 'should_skip' mask for all indices after the first failure.
+    4. **Token Selection**: Uses 'torch.where' to select:
+       - Draft tokens (if accepted)
+       - Recovered tokens (at the point of first rejection)
+       - Bonus tokens (if all tokens in a sequence were accepted)
+    5. **Masking**: Ensures operations only apply to non-greedy requests and
+       within valid sequence lengths.
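+
+    Illustrative walk-through (hypothetical values, not from a real run):
+    with cu_num_draft_tokens = [2, 5] and max_spec_len = 3, index mapping
+    gives cu_start = [0, 2] and per-request lengths [2, 3], so
+    valid_mask = [[True, True, False], [True, True, True]]. If request 0
+    rejects at position 1 while request 1 accepts everything,
+    first_reject_pos = [[1], [3]]: request 0 keeps its accepted draft at
+    position 0 and takes the recovered token at position 1, while request 1
+    keeps all three drafts and writes its bonus token at position 3.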
+ """ + batch_size = output_token_ids.shape[0] + device = output_token_ids.device - for req_idx in range(batch_size): - if is_greedy[req_idx]: - continue + zero_cpu = torch.tensor([0], pin_memory=True) + zero_device = zero_cpu.to(device, non_blocking=True) - if req_idx == 0: - start_idx = 0 - else: - start_idx = cu_num_draft_tokens[req_idx - 1].item() - end_idx = cu_num_draft_tokens[req_idx].item() - num_draft_tokens = end_idx - start_idx + cu_start = torch.cat([zero_device, cu_num_draft_tokens[:-1]]) + cu_end = cu_num_draft_tokens + num_draft_per_batch = cu_end - cu_start - rejected = False - for pos in range(num_draft_tokens): - if not rejected: - draft_token_id = draft_token_ids[start_idx + pos].item() + max_draft_len = max_spec_len + pos_indices_cpu = torch.arange(max_draft_len, pin_memory=True) + pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :] - if IS_NGRAM: - draft_prob = 1.0 - else: - draft_prob = draft_probs[start_idx + pos, - draft_token_id].item() + valid_mask = pos_indices < num_draft_per_batch[:, None] + global_token_indices = cu_start[:, None] + pos_indices + global_token_indices = global_token_indices.clamp( + 0, draft_token_ids.shape[0] - 1) + draft_tokens = draft_token_ids[ + global_token_indices] # [batch_size, max_draft_len] - target_prob = target_probs[start_idx + pos, - draft_token_id].item() - uniform_prob = uniform_probs[start_idx + pos].item() + if IS_NGRAM: + ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32) + draft_token_probs = ones_cpu.to( + device, non_blocking=True).expand_as(draft_tokens) + else: + flat_indices = global_token_indices.flatten() + flat_draft_tokens = draft_tokens.flatten() + flat_draft_probs = draft_probs[flat_indices, flat_draft_tokens] + draft_token_probs = flat_draft_probs.view(batch_size, max_draft_len) - if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: - token_id = draft_token_id - else: - rejected = True - token_id = recovered_token_ids[start_idx + pos].item() + flat_indices = global_token_indices.flatten() + flat_draft_tokens = draft_tokens.flatten() + flat_target_probs = target_probs[flat_indices, flat_draft_tokens] + target_token_probs = flat_target_probs.view(batch_size, max_draft_len) - output_token_ids[req_idx, pos] = token_id + uniform_token_probs = uniform_probs[global_token_indices] + recovered_tokens = recovered_token_ids[global_token_indices] - if not rejected: - bonus_token_id = bonus_token_ids[req_idx].item() - output_token_ids[req_idx, num_draft_tokens] = bonus_token_id + zero_threshold_cpu = torch.tensor([0.0], + pin_memory=True, + dtype=torch.float32) + zero_threshold = zero_threshold_cpu.to(device, non_blocking=True) + + acceptance_condition = (draft_token_probs > zero_threshold) & ( + target_token_probs / draft_token_probs >= uniform_token_probs) + + first_rejection = (~acceptance_condition) & valid_mask + + default_pos_cpu = torch.full([batch_size, 1], + max_draft_len, + pin_memory=True) + default_pos = default_pos_cpu.to(device, non_blocking=True) + + first_reject_pos = torch.where( + first_rejection.any(dim=1, keepdim=True), + first_rejection.float().argmax(dim=1, keepdim=True), default_pos) + pos_mask = pos_indices >= first_reject_pos + should_skip = pos_mask & valid_mask + + final_acceptance = acceptance_condition & (~should_skip) + non_greedy_mask = ~is_greedy + update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip) + + first_reject_mask = (pos_indices == first_reject_pos + ) & valid_mask & non_greedy_mask[:, None] + final_update_mask = update_mask | 
+    final_tokens = torch.where(
+        first_reject_mask, recovered_tokens,
+        torch.where(final_acceptance, draft_tokens,
+                    output_token_ids[:, :max_draft_len]))
+
+    output_token_ids[:, :max_draft_len] = torch.where(
+        final_update_mask, final_tokens, output_token_ids[:, :max_draft_len])
+
+    no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch
+    should_add_bonus = non_greedy_mask & no_rejection
+
+    bonus_positions = num_draft_per_batch  # [batch_size]
+
+    seq_len = output_token_ids.shape[1]
+    all_positions_cpu = torch.arange(seq_len, pin_memory=True)
+    all_positions = all_positions_cpu.to(
+        device, non_blocking=True)[None, :]  # [1, seq_len]
+
+    batch_bonus_positions = bonus_positions[:, None]  # [batch_size, 1]
+
+    max_spec_len_cpu = torch.tensor([max_spec_len], pin_memory=True)
+    max_spec_len_device = max_spec_len_cpu.to(device, non_blocking=True)
+
+    valid_bonus_pos = bonus_positions < (max_spec_len_device + 1)
+    final_bonus_mask = should_add_bonus & valid_bonus_pos
+
+    bonus_pos_match = (all_positions == batch_bonus_positions)
+    bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None]
+
+    bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len)
+    output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded,
+                                      output_token_ids)
 
 
 def expand_pytorch(
@@ -508,21 +592,48 @@
     replace_to,
     MAX_NUM_TOKENS,
 ):
-    batch_size = len(input_ptr)
+    """
+    This function broadcasts batch-level values (input_ptr) to token-level
+    positions (output_ptr) based on cumulative token offsets. It acts like
+    a "scatter" or "repeat_interleave" operation but with custom logic:
+
+    1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size
+       [num_tokens, batch_size] that identifies which batch index each token
+       belongs to by checking if the token index falls between cu_start and
+       cu_end.
+    2. **Conditional Replacement**: Before expansion, it replaces specific
+       values (e.g., padding or special markers) in the input to prepare the
+       data.
+    3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted
+       sum that effectively "picks" the correct batch value for every token
+       position simultaneously, avoiding a Python loop over the batch.
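+
+    Illustrative example (values mirror test_expand_pytorch): with
+    input_ptr = [10, 20, 30] and cumulative offsets [2, 5, 7], token index 3
+    falls in the half-open range [2, 5), so row 3 of 'in_range' is
+    [False, True, False] and the einsum picks 20; over all seven tokens the
+    result is [10, 10, 20, 20, 20, 30, 30].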
+ """ + device = cu_num_tokens_ptr.device + batch_size = input_ptr.shape[0] + num_tokens = output_ptr.shape[0] - for req_idx in range(batch_size): - start_idx = 0 if req_idx == 0 else cu_num_tokens_ptr[req_idx - 1] - end_idx = cu_num_tokens_ptr[req_idx] - num_tokens = end_idx - start_idx + if batch_size == 0 or num_tokens == 0: + return - src_val = input_ptr[req_idx] - src_val = replace_to if src_val == replace_from else src_val + cu_start = torch.cat([ + torch.tensor([0], pin_memory=True).to(device, non_blocking=True), + cu_num_tokens_ptr[:-1] + ]) + cu_end = cu_num_tokens_ptr - offset = torch.arange(MAX_NUM_TOKENS, device=num_tokens.device) - mask = offset < num_tokens + token_indices = torch.arange(num_tokens, + device=device)[:, None] # [num_tokens, 1] + cu_start_exp = cu_start[None, :] # [1, batch_size] + cu_end_exp = cu_end[None, :] # [1, batch_size] - output_slice = start_idx + offset[mask] - output_ptr[output_slice] = src_val + in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp) + + replaced_input = torch.where(input_ptr == replace_from, replace_to, + input_ptr).float() + + token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input) + + needs_update = in_range.any(dim=1) + + output_ptr[:] = torch.where(needs_update, token_values, output_ptr) def sample_recovered_tokens_pytorch( @@ -535,37 +646,77 @@ def sample_recovered_tokens_pytorch( vocab_size, IS_NGRAM=False, ): - batch_size = len(cu_num_draft_tokens) + """ + When a draft token is rejected, we must sample a "recovered" token from + a modified distribution. This function calculates that distribution across + the entire flattened batch. + + 1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it + determines which request in the batch each token belongs to. This is + necessary because 'q' (normalization factor) is stored per-request. + 2. **Probability Adjustment**: + - If N-GRAM: It zeroes out the draft token's probability in the target. + - If Probabilistic: It calculates max(0, target_probs - draft_probs) + as per the standard speculative decoding algorithm. + 3. **Normalization & Sampling**: It divides the adjusted probabilities + by the normalization distribution 'q'. To remain vectorized, it + broadcasts 'q' from [batch_size, vocab] to [num_tokens, vocab]. + 4. **Argmax Selection**: It selects the best recovery token for every + position in one pass using torch.argmax. 
+ """ + device = output_token_ids.device + num_tokens = output_token_ids.shape[0] - for req_idx in range(batch_size): - start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1] - end_idx = cu_num_draft_tokens[req_idx] - num_draft_tokens = end_idx - start_idx + if num_tokens == 0: + return - for pos in range(num_draft_tokens): - token_idx = start_idx + pos + cu_start = torch.cat([ + torch.tensor([0], pin_memory=True).to(device, non_blocking=True), + cu_num_draft_tokens[:-1], + ]) + cu_end = cu_num_draft_tokens - if IS_NGRAM: - draft_token_id = draft_token_ids[token_idx] - orig_prob = target_probs[token_idx, draft_token_id].item() - target_probs[token_idx, draft_token_id] = 0 - prob = target_probs[token_idx].clone() - else: - draft_p = draft_probs[token_idx].clone() - target_p = target_probs[token_idx].clone() - prob = torch.maximum(target_p - draft_p, - torch.tensor(0.0, device=target_p.device)) + token_indices = torch.arange(num_tokens, device=device) # [num_tokens] - q_values = torch.full((vocab_size, ), - float('-inf'), - device=q.device) - q_values[:vocab_size] = q[req_idx, :vocab_size] + token_indices_expanded = token_indices[:, None] # [num_tokens, 1] + cu_start_expanded = cu_start[None, :] # [1, batch_size] + cu_end_expanded = cu_end[None, :] # [1, batch_size] - recovered_id = torch.argmax(prob / q_values).item() - output_token_ids[token_idx] = recovered_id + in_range_mask = (token_indices_expanded >= cu_start_expanded) & ( + token_indices_expanded < cu_end_expanded) - if IS_NGRAM: - target_probs[token_idx, draft_token_id] = orig_prob + token_to_batch = torch.argmax(in_range_mask.int(), dim=1) + + has_match = in_range_mask.any(dim=1) + token_to_batch = torch.where(has_match, token_to_batch, 0) + + if IS_NGRAM: + token_indices = torch.arange(num_tokens, device=device) + + modified_target_probs = target_probs.clone() + modified_target_probs[token_indices, draft_token_ids] = 0 + prob = modified_target_probs + + else: + prob = torch.maximum( + target_probs - draft_probs, + torch.tensor(0.0, pin_memory=True).to(device, non_blocking=True), + ) + + q_values = q[token_to_batch] # [num_tokens, vocab_size] + + epsilon = 1e-10 + q_values_safe = torch.where(q_values == 0, epsilon, q_values) + q_values_safe = torch.where(torch.isinf(q_values), epsilon, q_values_safe) + + prob_over_q = prob / q_values_safe + + prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10, + prob_over_q) + + recovered_ids = torch.argmax(prob_over_q, dim=1) + + output_token_ids[:] = recovered_ids @triton.jit(do_not_specialize=["max_spec_len"])