Organize sampling batch info better (#1562)

This commit is contained in:
Lianmin Zheng
2024-10-03 18:29:49 -07:00
committed by GitHub
parent e0b5dbcec1
commit 32eb6e96f2
8 changed files with 43 additions and 35 deletions

View File

@@ -14,16 +14,17 @@ if TYPE_CHECKING:
@dataclasses.dataclass
class SamplingBatchInfo:
# Basic Info
vocab_size: int
# Batched sampling params
temperatures: torch.Tensor = None
top_ps: torch.Tensor = None
top_ks: torch.Tensor = None
min_ps: torch.Tensor = None
temperatures: torch.Tensor
top_ps: torch.Tensor
top_ks: torch.Tensor
min_ps: torch.Tensor
# Dispatch in CUDA graph
need_min_p_sampling: bool
# Bias Tensors
vocab_size: int
logit_bias: torch.Tensor = None
vocab_mask: torch.Tensor = None
@@ -31,9 +32,6 @@ class SamplingBatchInfo:
regex_fsms: List[RegexGuide] = None
regex_fsm_states: List[int] = None
# Dispatch in CUDA graph
need_min_p_sampling: bool = False
# Penalizer
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
linear_penalties: torch.Tensor = None
@@ -42,25 +40,30 @@ class SamplingBatchInfo:
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
reqs = batch.reqs
ret = cls(vocab_size=vocab_size)
with torch.device("cuda"):
ret.temperatures = torch.tensor(
temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
).view(-1, 1)
ret.top_ps = torch.tensor(
top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float
)
ret.top_ks = torch.tensor(
top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int
)
ret.min_ps = torch.tensor(
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
)
ret = cls(
temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
min_ps=min_ps,
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
# Each penalizer will do nothing if it evaluates itself as not required by looking at
# the sampling_params of the requests (see {_is_required()} of each penalizer). So this