Fix illegal tokens during sampling (#676)
This commit is contained in:
@@ -665,16 +665,20 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        try:
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, self.top_ks, self.top_ps
-            )
-        except RuntimeError as e:
-            warnings.warn(f"Ignore errors in sampling: {e}")
+        max_top_k_round, batch_size = 32, probs.shape[0]
+        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
+        batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, self.top_ks, self.top_ps
+        )
+
+        # FIXME: this is a temporary fix for the illegal token ids
+        illegal_mask = torch.logical_or(
+            batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
+        )
+        if torch.any(illegal_mask):
+            warnings.warn("Illegal sampled token ids")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            batch_next_token_ids = torch.argmax(probs, dim=-1)

         if has_regex:
@@ -246,12 +246,11 @@ class ModelRunner:
         self.cuda_graph_runner = CudaGraphRunner(
             self, max_batch_size_to_capture=max(batch_size_list)
         )
         logger.info(f"Capture for batch sizes {batch_size_list}")
         try:
             self.cuda_graph_runner.capture(batch_size_list)
-        except:
+        except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed. Possible solutions:\n"
+                f"Capture cuda graph failed {e}. Possible solutions:\n"
                 f"1. disable cuda graph by --disable-cuda-graph\n"
                 f"2. set --mem-fraction-static to a smaller value\n"
                 f"Open an issue on GitHub with reproducible scripts if you need help.\n"
@@ -14,7 +14,7 @@ from sglang.srt.sampling_params import SamplingParams
 @dataclass
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
Reference in New Issue
Block a user