Organize sampling batch info better (#1562)
This commit is contained in:
@@ -14,16 +14,17 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SamplingBatchInfo:
|
||||
# Basic Info
|
||||
vocab_size: int
|
||||
|
||||
# Batched sampling params
|
||||
temperatures: torch.Tensor = None
|
||||
top_ps: torch.Tensor = None
|
||||
top_ks: torch.Tensor = None
|
||||
min_ps: torch.Tensor = None
|
||||
temperatures: torch.Tensor
|
||||
top_ps: torch.Tensor
|
||||
top_ks: torch.Tensor
|
||||
min_ps: torch.Tensor
|
||||
|
||||
# Dispatch in CUDA graph
|
||||
need_min_p_sampling: bool
|
||||
|
||||
# Bias Tensors
|
||||
vocab_size: int
|
||||
logit_bias: torch.Tensor = None
|
||||
vocab_mask: torch.Tensor = None
|
||||
|
||||
@@ -31,9 +32,6 @@ class SamplingBatchInfo:
|
||||
regex_fsms: List[RegexGuide] = None
|
||||
regex_fsm_states: List[int] = None
|
||||
|
||||
# Dispatch in CUDA graph
|
||||
need_min_p_sampling: bool = False
|
||||
|
||||
# Penalizer
|
||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||
linear_penalties: torch.Tensor = None
|
||||
@@ -42,25 +40,30 @@ class SamplingBatchInfo:
|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
reqs = batch.reqs
|
||||
ret = cls(vocab_size=vocab_size)
|
||||
|
||||
with torch.device("cuda"):
|
||||
ret.temperatures = torch.tensor(
|
||||
temperatures = torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
).view(-1, 1)
|
||||
ret.top_ps = torch.tensor(
|
||||
top_ps = torch.tensor(
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||
)
|
||||
ret.top_ks = torch.tensor(
|
||||
top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int
|
||||
)
|
||||
ret.min_ps = torch.tensor(
|
||||
min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
)
|
||||
|
||||
ret = cls(
|
||||
temperatures=temperatures,
|
||||
top_ps=top_ps,
|
||||
top_ks=top_ks,
|
||||
min_ps=min_ps,
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
|
||||
|
||||
# Each penalizer will do nothing if it evaluates itself as not required by looking at
|
||||
# the sampling_params of the requests (See {_is_required()} of each penalizer). So this
|
||||
|
||||
Reference in New Issue
Block a user