Fix illegal tokens during sampling (#676)

This commit is contained in:
Liangsheng Yin
2024-07-20 03:11:15 -07:00
committed by GitHub
parent 490a1f39dd
commit f424e76d96
4 changed files with 18 additions and 15 deletions

View File

@@ -7,7 +7,7 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
@dataclass @dataclass
class GenerateReqInput: class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts. # The input prompt. It can be a single prompt or a batch of prompts.
text: Union[List[str], str] text: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids. # The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The image input. It can be a file name, a url, or base64 encoded string. # The image input. It can be a file name, a url, or base64 encoded string.

View File

@@ -665,16 +665,20 @@ class Batch:
# TODO(lmzheng): apply penalty # TODO(lmzheng): apply penalty
probs = torch.softmax(logits, dim=-1) probs = torch.softmax(logits, dim=-1)
try:
max_top_k_round, batch_size = 32, probs.shape[0] max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand( uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
(max_top_k_round, batch_size), device=probs.device batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
) probs, uniform_samples, self.top_ks, self.top_ps
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs( )
probs, uniform_samples, self.top_ks, self.top_ps
) # FIXME: this is a temporary fix for the illegal token ids
except RuntimeError as e: illegal_mask = torch.logical_or(
warnings.warn(f"Ignore errors in sampling: {e}") batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
)
if torch.any(illegal_mask):
warnings.warn("Illegal sampled token ids")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
batch_next_token_ids = torch.argmax(probs, dim=-1) batch_next_token_ids = torch.argmax(probs, dim=-1)
if has_regex: if has_regex:

View File

@@ -246,12 +246,11 @@ class ModelRunner:
self.cuda_graph_runner = CudaGraphRunner( self.cuda_graph_runner = CudaGraphRunner(
self, max_batch_size_to_capture=max(batch_size_list) self, max_batch_size_to_capture=max(batch_size_list)
) )
logger.info(f"Capture for batch sizes {batch_size_list}")
try: try:
self.cuda_graph_runner.capture(batch_size_list) self.cuda_graph_runner.capture(batch_size_list)
except: except RuntimeError as e:
raise Exception( raise Exception(
f"Capture cuda graph failed. Possible solutions:\n" f"Capture cuda graph failed {e}. Possible solutions:\n"
f"1. disable cuda graph by --disable-cuda-graph\n" f"1. disable cuda graph by --disable-cuda-graph\n"
f"2. set --mem-fraction-static to a smaller value\n" f"2. set --mem-fraction-static to a smaller value\n"
f"Open an issue on GitHub with reproducible scripts if you need help.\n" f"Open an issue on GitHub with reproducible scripts if you need help.\n"

View File

@@ -14,7 +14,7 @@ from sglang.srt.sampling_params import SamplingParams
@dataclass @dataclass
class GenerateReqInput: class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts. # The input prompt. It can be a single prompt or a batch of prompts.
text: Union[List[str], str] text: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids. # The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The image input. It can be a file name, a url, or base64 encoded string. # The image input. It can be a file name, a url, or base64 encoded string.