Simplify pytorch sampling kernel and logit processor (#2491)
This commit is contained in:
@@ -100,9 +100,154 @@ class LogitsProcessor(nn.Module):
|
|||||||
self.do_tensor_parallel_all_gather = (
|
self.do_tensor_parallel_all_gather = (
|
||||||
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
|
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
|
||||||
)
|
)
|
||||||
|
self.final_logit_softcapping = getattr(
|
||||||
|
self.config, "final_logit_softcapping", None
|
||||||
|
)
|
||||||
|
|
||||||
def _get_normalized_prompt_logprobs(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
input_ids,
|
||||||
|
hidden_states,
|
||||||
|
lm_head: VocabParallelEmbedding,
|
||||||
|
logits_metadata: Union[LogitsMetadata, ForwardBatch],
|
||||||
|
):
|
||||||
|
if isinstance(logits_metadata, ForwardBatch):
|
||||||
|
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
||||||
|
assert isinstance(logits_metadata, LogitsMetadata)
|
||||||
|
|
||||||
|
# Get the last hidden states and last logits for the next token prediction
|
||||||
|
if logits_metadata.forward_mode.is_decode():
|
||||||
|
last_index = None
|
||||||
|
last_hidden = hidden_states
|
||||||
|
else:
|
||||||
|
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
||||||
|
last_hidden = hidden_states[last_index]
|
||||||
|
|
||||||
|
last_logits = self._get_logits(last_hidden, lm_head)
|
||||||
|
if self.do_tensor_parallel_all_gather:
|
||||||
|
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||||
|
last_logits = last_logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
|
if self.final_logit_softcapping:
|
||||||
|
last_logits.div_(self.final_logit_softcapping)
|
||||||
|
torch.tanh(last_logits, out=last_logits)
|
||||||
|
last_logits.mul_(self.final_logit_softcapping)
|
||||||
|
|
||||||
|
# Return only last_logits if logprob is not requested
|
||||||
|
if not logits_metadata.return_logprob:
|
||||||
|
return LogitsProcessorOutput(
|
||||||
|
next_token_logits=last_logits,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
last_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||||
|
last_logits, logits_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
if logits_metadata.forward_mode.is_decode():
|
||||||
|
if logits_metadata.return_top_logprob:
|
||||||
|
output_top_logprobs_val, output_top_logprobs_idx = (
|
||||||
|
self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_top_logprobs_val = output_top_logprobs_idx = None
|
||||||
|
return LogitsProcessorOutput(
|
||||||
|
next_token_logits=last_logits,
|
||||||
|
next_token_logprobs=last_logprobs,
|
||||||
|
output_top_logprobs_val=output_top_logprobs_val,
|
||||||
|
output_top_logprobs_idx=output_top_logprobs_idx,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Slice the requested tokens to compute logprob
|
||||||
|
pt, states, pruned_input_ids = 0, [], []
|
||||||
|
for start_len, extend_len in zip(
|
||||||
|
logits_metadata.extend_logprob_start_lens_cpu,
|
||||||
|
logits_metadata.extend_seq_lens_cpu,
|
||||||
|
):
|
||||||
|
states.append(hidden_states[pt + start_len : pt + extend_len])
|
||||||
|
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
||||||
|
pt += extend_len
|
||||||
|
|
||||||
|
# Compute the logits and logprobs for all required tokens
|
||||||
|
states = torch.cat(states, dim=0)
|
||||||
|
all_logits = self._get_logits(states, lm_head)
|
||||||
|
if self.do_tensor_parallel_all_gather:
|
||||||
|
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||||
|
|
||||||
|
# The LM head's weights may be zero-padded for parallelism. Remove any
|
||||||
|
# extra logits that this padding may have produced.
|
||||||
|
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
|
if self.final_logit_softcapping:
|
||||||
|
all_logits.div_(self.final_logit_softcapping)
|
||||||
|
torch.tanh(all_logits, out=all_logits)
|
||||||
|
all_logits.mul_(self.final_logit_softcapping)
|
||||||
|
|
||||||
|
all_logprobs = all_logits
|
||||||
|
del all_logits, hidden_states
|
||||||
|
|
||||||
|
all_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||||
|
all_logprobs, logits_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the logprob of top-k tokens
|
||||||
|
if logits_metadata.return_top_logprob:
|
||||||
|
(
|
||||||
|
input_top_logprobs_val,
|
||||||
|
input_top_logprobs_idx,
|
||||||
|
output_top_logprobs_val,
|
||||||
|
output_top_logprobs_idx,
|
||||||
|
) = self.get_top_logprobs(all_logprobs, logits_metadata)
|
||||||
|
else:
|
||||||
|
input_top_logprobs_val = input_top_logprobs_idx = (
|
||||||
|
output_top_logprobs_val
|
||||||
|
) = output_top_logprobs_idx = None
|
||||||
|
|
||||||
|
# Compute the normalized logprobs for the requested tokens.
|
||||||
|
# Note that we pad a zero at the end for easy batching.
|
||||||
|
input_token_logprobs = all_logprobs[
|
||||||
|
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||||
|
torch.cat(
|
||||||
|
[
|
||||||
|
torch.cat(pruned_input_ids)[1:],
|
||||||
|
torch.tensor([0], device="cuda"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||||
|
input_token_logprobs,
|
||||||
|
logits_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
return LogitsProcessorOutput(
|
||||||
|
next_token_logits=last_logits,
|
||||||
|
next_token_logprobs=last_logprobs,
|
||||||
|
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
||||||
|
input_token_logprobs=input_token_logprobs,
|
||||||
|
input_top_logprobs_val=input_top_logprobs_val,
|
||||||
|
input_top_logprobs_idx=input_top_logprobs_idx,
|
||||||
|
output_top_logprobs_val=output_top_logprobs_val,
|
||||||
|
output_top_logprobs_idx=output_top_logprobs_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
lm_head: VocabParallelEmbedding,
|
||||||
|
embedding_bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if hasattr(lm_head, "weight"):
|
||||||
|
logits = torch.matmul(hidden_states, lm_head.weight.T)
|
||||||
|
else:
|
||||||
|
# GGUF models
|
||||||
|
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
|
||||||
|
|
||||||
|
# Optional scaling factor
|
||||||
|
if self.logit_scale is not None:
|
||||||
|
logits.mul_(self.logit_scale) # In-place multiply
|
||||||
|
return logits
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_normalized_prompt_logprobs(
|
||||||
input_token_logprobs: torch.Tensor,
|
input_token_logprobs: torch.Tensor,
|
||||||
logits_metadata: LogitsMetadata,
|
logits_metadata: LogitsMetadata,
|
||||||
):
|
):
|
||||||
@@ -177,142 +322,11 @@ class LogitsProcessor(nn.Module):
|
|||||||
output_top_logprobs_idx,
|
output_top_logprobs_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
@staticmethod
|
||||||
self,
|
def compute_temp_top_p_normalized_logprobs(
|
||||||
input_ids,
|
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
|
||||||
hidden_states,
|
|
||||||
lm_head: VocabParallelEmbedding,
|
|
||||||
logits_metadata: Union[LogitsMetadata, ForwardBatch],
|
|
||||||
):
|
|
||||||
if isinstance(logits_metadata, ForwardBatch):
|
|
||||||
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
|
||||||
assert isinstance(logits_metadata, LogitsMetadata)
|
|
||||||
|
|
||||||
# Get the last hidden states and last logits for the next token prediction
|
|
||||||
if logits_metadata.forward_mode.is_decode():
|
|
||||||
last_index = None
|
|
||||||
last_hidden = hidden_states
|
|
||||||
else:
|
|
||||||
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
|
||||||
last_hidden = hidden_states[last_index]
|
|
||||||
|
|
||||||
last_logits = self._get_logits(last_hidden, lm_head)
|
|
||||||
if self.do_tensor_parallel_all_gather:
|
|
||||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
|
||||||
last_logits = last_logits[:, : self.config.vocab_size].float()
|
|
||||||
|
|
||||||
if hasattr(self.config, "final_logit_softcapping"):
|
|
||||||
last_logits.div_(self.config.final_logit_softcapping)
|
|
||||||
torch.tanh(last_logits, out=last_logits)
|
|
||||||
last_logits.mul_(self.config.final_logit_softcapping)
|
|
||||||
|
|
||||||
# Return only last_logits if logprob is not requested
|
|
||||||
if not logits_metadata.return_logprob:
|
|
||||||
return LogitsProcessorOutput(
|
|
||||||
next_token_logits=last_logits,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
|
|
||||||
|
|
||||||
if logits_metadata.forward_mode.is_decode():
|
|
||||||
if logits_metadata.return_top_logprob:
|
|
||||||
output_top_logprobs_val, output_top_logprobs_idx = (
|
|
||||||
self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
output_top_logprobs_val = output_top_logprobs_idx = None
|
|
||||||
return LogitsProcessorOutput(
|
|
||||||
next_token_logits=last_logits,
|
|
||||||
next_token_logprobs=last_logprobs,
|
|
||||||
output_top_logprobs_val=output_top_logprobs_val,
|
|
||||||
output_top_logprobs_idx=output_top_logprobs_idx,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Slice the requested tokens to compute logprob
|
|
||||||
pt, states, pruned_input_ids = 0, [], []
|
|
||||||
for start_len, extend_len in zip(
|
|
||||||
logits_metadata.extend_logprob_start_lens_cpu,
|
|
||||||
logits_metadata.extend_seq_lens_cpu,
|
|
||||||
):
|
|
||||||
states.append(hidden_states[pt + start_len : pt + extend_len])
|
|
||||||
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
|
||||||
pt += extend_len
|
|
||||||
|
|
||||||
# Compute the logits and logprobs for all required tokens
|
|
||||||
states = torch.cat(states, dim=0)
|
|
||||||
all_logits = self._get_logits(states, lm_head)
|
|
||||||
if self.do_tensor_parallel_all_gather:
|
|
||||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
|
||||||
|
|
||||||
# The LM head's weights may be zero-padded for parallelism. Remove any
|
|
||||||
# extra logits that this padding may have produced.
|
|
||||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
|
||||||
|
|
||||||
if hasattr(self.config, "final_logit_softcapping"):
|
|
||||||
all_logits.div_(self.config.final_logit_softcapping)
|
|
||||||
torch.tanh(all_logits, out=all_logits)
|
|
||||||
all_logits.mul_(self.config.final_logit_softcapping)
|
|
||||||
|
|
||||||
all_logprobs = all_logits
|
|
||||||
del all_logits, hidden_states
|
|
||||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
|
||||||
|
|
||||||
# Get the logprob of top-k tokens
|
|
||||||
if logits_metadata.return_top_logprob:
|
|
||||||
(
|
|
||||||
input_top_logprobs_val,
|
|
||||||
input_top_logprobs_idx,
|
|
||||||
output_top_logprobs_val,
|
|
||||||
output_top_logprobs_idx,
|
|
||||||
) = self.get_top_logprobs(all_logprobs, logits_metadata)
|
|
||||||
else:
|
|
||||||
input_top_logprobs_val = input_top_logprobs_idx = (
|
|
||||||
output_top_logprobs_val
|
|
||||||
) = output_top_logprobs_idx = None
|
|
||||||
|
|
||||||
# Compute the normalized logprobs for the requested tokens.
|
|
||||||
# Note that we pad a zero at the end for easy batching.
|
|
||||||
input_token_logprobs = all_logprobs[
|
|
||||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
|
||||||
torch.cat(
|
|
||||||
[
|
|
||||||
torch.cat(pruned_input_ids)[1:],
|
|
||||||
torch.tensor([0], device="cuda"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
]
|
|
||||||
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
|
||||||
input_token_logprobs,
|
|
||||||
logits_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
return LogitsProcessorOutput(
|
|
||||||
next_token_logits=last_logits,
|
|
||||||
next_token_logprobs=last_logprobs,
|
|
||||||
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
|
||||||
input_token_logprobs=input_token_logprobs,
|
|
||||||
input_top_logprobs_val=input_top_logprobs_val,
|
|
||||||
input_top_logprobs_idx=input_top_logprobs_idx,
|
|
||||||
output_top_logprobs_val=output_top_logprobs_val,
|
|
||||||
output_top_logprobs_idx=output_top_logprobs_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_logits(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
lm_head: VocabParallelEmbedding,
|
|
||||||
embedding_bias: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if hasattr(lm_head, "weight"):
|
return torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||||
logits = torch.matmul(hidden_states, lm_head.weight.T)
|
|
||||||
else:
|
|
||||||
# GGUF models
|
|
||||||
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
|
|
||||||
|
|
||||||
# Optional scaling factor, backported from vLLM 0.4
|
|
||||||
if self.logit_scale is not None:
|
|
||||||
logits.mul_(self.logit_scale) # In-place multiply
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ class Sampler(nn.Module):
|
|||||||
# Post process logits
|
# Post process logits
|
||||||
logits.div_(sampling_info.temperatures)
|
logits.div_(sampling_info.temperatures)
|
||||||
probs = torch.softmax(logits, dim=-1)
|
probs = torch.softmax(logits, dim=-1)
|
||||||
logits = None
|
|
||||||
del logits
|
del logits
|
||||||
|
|
||||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||||
@@ -84,6 +83,7 @@ class Sampler(nn.Module):
|
|||||||
sampling_info.top_ks,
|
sampling_info.top_ks,
|
||||||
sampling_info.top_ps,
|
sampling_info.top_ps,
|
||||||
sampling_info.min_ps,
|
sampling_info.min_ps,
|
||||||
|
sampling_info.need_min_p_sampling,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
|
|||||||
top_ks: torch.Tensor,
|
top_ks: torch.Tensor,
|
||||||
top_ps: torch.Tensor,
|
top_ps: torch.Tensor,
|
||||||
min_ps: torch.Tensor,
|
min_ps: torch.Tensor,
|
||||||
|
need_min_p_sampling: bool,
|
||||||
):
|
):
|
||||||
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
|
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
|
||||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
min_p_thresholds = probs_sort[:, 0] * min_ps
|
|
||||||
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
|
||||||
probs_sort[
|
probs_sort[
|
||||||
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
|
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
|
||||||
>= top_ks.view(-1, 1)
|
>= top_ks.view(-1, 1)
|
||||||
] = 0.0
|
] = 0.0
|
||||||
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
||||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
|
||||||
|
if need_min_p_sampling:
|
||||||
|
min_p_thresholds = probs_sort[:, 0] * min_ps
|
||||||
|
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
||||||
|
|
||||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
||||||
# int32 range is enough to represent the token ids
|
# int32 range is enough to represent the token ids
|
||||||
probs_idx = probs_idx.to(torch.int32)
|
probs_idx = probs_idx.to(torch.int32)
|
||||||
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
|
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
|
||||||
return batch_next_token_ids
|
return batch_next_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
def top_p_normalize_probs(
|
||||||
|
probs: torch.Tensor,
|
||||||
|
top_ps: torch.Tensor,
|
||||||
|
):
|
||||||
|
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||||
|
return top_p_renorm_prob(probs, top_ps)
|
||||||
|
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||||
|
# See also top_k_top_p_min_p_sampling_from_probs_torch
|
||||||
|
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||||
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
||||||
|
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||||
|
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -1086,9 +1086,9 @@ class ScheduleBatch:
|
|||||||
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
||||||
self.reqs.extend(other.reqs)
|
self.reqs.extend(other.reqs)
|
||||||
|
|
||||||
self.return_logprob = self.return_logprob or other.return_logprob
|
self.return_logprob |= other.return_logprob
|
||||||
self.has_stream = self.has_stream or other.has_stream
|
self.has_stream |= other.has_stream
|
||||||
self.has_grammar = self.has_grammar or other.has_grammar
|
self.has_grammar |= other.has_grammar
|
||||||
|
|
||||||
def get_model_worker_batch(self):
|
def get_model_worker_batch(self):
|
||||||
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
|
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
|
||||||
@@ -1115,7 +1115,6 @@ class ScheduleBatch:
|
|||||||
seq_lens=self.seq_lens,
|
seq_lens=self.seq_lens,
|
||||||
out_cache_loc=self.out_cache_loc,
|
out_cache_loc=self.out_cache_loc,
|
||||||
seq_lens_sum=self.seq_lens_sum,
|
seq_lens_sum=self.seq_lens_sum,
|
||||||
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
|
|
||||||
return_logprob=self.return_logprob,
|
return_logprob=self.return_logprob,
|
||||||
top_logprobs_nums=self.top_logprobs_nums,
|
top_logprobs_nums=self.top_logprobs_nums,
|
||||||
global_num_tokens=self.global_num_tokens,
|
global_num_tokens=self.global_num_tokens,
|
||||||
@@ -1170,9 +1169,6 @@ class ModelWorkerBatch:
|
|||||||
# The sum of all sequence lengths
|
# The sum of all sequence lengths
|
||||||
seq_lens_sum: int
|
seq_lens_sum: int
|
||||||
|
|
||||||
# The memory pool operation records
|
|
||||||
req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
|
|
||||||
|
|
||||||
# For logprob
|
# For logprob
|
||||||
return_logprob: bool
|
return_logprob: bool
|
||||||
top_logprobs_nums: Optional[List[int]]
|
top_logprobs_nums: Optional[List[int]]
|
||||||
|
|||||||
@@ -387,8 +387,14 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
# Extract logprobs
|
# Extract logprobs
|
||||||
if forward_batch.return_logprob:
|
if forward_batch.return_logprob:
|
||||||
next_token_logprobs = torch.nn.functional.log_softmax(
|
logits_metadata = LogitsMetadata(
|
||||||
next_token_logits, dim=-1
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||||
|
)
|
||||||
|
next_token_logprobs = (
|
||||||
|
LogitsProcessor.compute_temp_top_p_normalized_logprobs(
|
||||||
|
next_token_logits, logits_metadata
|
||||||
|
)
|
||||||
)
|
)
|
||||||
logits_output = LogitsProcessorOutput(
|
logits_output = LogitsProcessorOutput(
|
||||||
next_token_logits=next_token_logits,
|
next_token_logits=next_token_logits,
|
||||||
@@ -396,10 +402,6 @@ class CudaGraphRunner:
|
|||||||
)
|
)
|
||||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||||
if return_top_logprob:
|
if return_top_logprob:
|
||||||
logits_metadata = LogitsMetadata(
|
|
||||||
forward_mode=ForwardMode.DECODE,
|
|
||||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
|
||||||
)
|
|
||||||
(
|
(
|
||||||
logits_output.output_top_logprobs_val,
|
logits_output.output_top_logprobs_val,
|
||||||
logits_output.output_top_logprobs_idx,
|
logits_output.output_top_logprobs_idx,
|
||||||
|
|||||||
@@ -698,11 +698,6 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
|
help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--disable-nan-detection",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable the NaN detection for better performance.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-overlap-schedule",
|
"--disable-overlap-schedule",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user