Skip unnecessary penalizer (#1707)
This commit is contained in:
@@ -515,11 +515,11 @@ class ScheduleBatch:
|
|||||||
assert seq_len - pre_len == req.extend_input_len
|
assert seq_len - pre_len == req.extend_input_len
|
||||||
|
|
||||||
if pre_len > 0:
|
if pre_len > 0:
|
||||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
|
||||||
:pre_len
|
req.prefix_indices
|
||||||
] = req.prefix_indices
|
)
|
||||||
|
|
||||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
|
self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
|
||||||
out_cache_loc[pt : pt + req.extend_input_len]
|
out_cache_loc[pt : pt + req.extend_input_len]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -535,10 +535,15 @@ class ScheduleBatch:
|
|||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
|
|
||||||
# Set fields
|
# Set fields
|
||||||
with out_cache_loc.device:
|
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
self.device, non_blocking=True
|
||||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
|
)
|
||||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
|
||||||
|
self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||||
|
self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
self.extend_num_tokens = extend_num_tokens
|
self.extend_num_tokens = extend_num_tokens
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
@@ -782,8 +787,8 @@ class ScheduleBatch:
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||||
new_indices = torch.tensor(
|
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
|
||||||
keep_indices, dtype=torch.int32, device=self.seq_lens.device
|
self.device, non_blocking=True
|
||||||
)
|
)
|
||||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||||
self.seq_lens = self.seq_lens[new_indices]
|
self.seq_lens = self.seq_lens[new_indices]
|
||||||
|
|||||||
@@ -150,6 +150,7 @@ class Scheduler:
|
|||||||
nccl_port=port_args.nccl_port,
|
nccl_port=port_args.nccl_port,
|
||||||
)
|
)
|
||||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||||
|
self.device = self.tp_worker.device
|
||||||
|
|
||||||
# Get token and memory info from the model worker
|
# Get token and memory info from the model worker
|
||||||
(
|
(
|
||||||
@@ -758,9 +759,7 @@ class Scheduler:
|
|||||||
if logits_output.next_token_logprobs is not None:
|
if logits_output.next_token_logprobs is not None:
|
||||||
logits_output.next_token_logprobs = (
|
logits_output.next_token_logprobs = (
|
||||||
logits_output.next_token_logprobs[
|
logits_output.next_token_logprobs[
|
||||||
torch.arange(
|
torch.arange(len(next_token_ids), device=self.device),
|
||||||
len(next_token_ids), device=next_token_ids.device
|
|
||||||
),
|
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
].tolist()
|
].tolist()
|
||||||
)
|
)
|
||||||
@@ -828,7 +827,7 @@ class Scheduler:
|
|||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
next_token_logprobs = logits_output.next_token_logprobs[
|
next_token_logprobs = logits_output.next_token_logprobs[
|
||||||
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
torch.arange(len(next_token_ids), device=self.device),
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
].tolist()
|
].tolist()
|
||||||
|
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ class BaseTokenToKVPool:
|
|||||||
select_index = self.free_slots[:need_size]
|
select_index = self.free_slots[:need_size]
|
||||||
self.free_slots = self.free_slots[need_size:]
|
self.free_slots = self.free_slots[need_size:]
|
||||||
|
|
||||||
return select_index.to(self.device)
|
return select_index.to(self.device, non_blocking=True)
|
||||||
|
|
||||||
def free(self, free_index: torch.Tensor):
|
def free(self, free_index: torch.Tensor):
|
||||||
if self.is_not_in_free_group:
|
if self.is_not_in_free_group:
|
||||||
|
|||||||
@@ -135,25 +135,22 @@ class ForwardBatch:
|
|||||||
|
|
||||||
# Init position information
|
# Init position information
|
||||||
if not ret.forward_mode.is_decode():
|
if not ret.forward_mode.is_decode():
|
||||||
ret.positions = torch.tensor(
|
ret.positions = torch.concat(
|
||||||
np.concatenate(
|
[
|
||||||
[
|
torch.arange(prefix_len, prefix_len + extend_len, device=device)
|
||||||
np.arange(prefix_len, prefix_len + extend_len)
|
for prefix_len, extend_len in zip(
|
||||||
for prefix_len, extend_len in zip(
|
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
)
|
||||||
)
|
],
|
||||||
],
|
axis=0,
|
||||||
axis=0,
|
|
||||||
),
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ret.image_inputs = batch.image_inputs
|
ret.image_inputs = batch.image_inputs
|
||||||
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
|
ret.extend_seq_lens = torch.tensor(
|
||||||
|
batch.extend_seq_lens, dtype=torch.int32
|
||||||
|
).to(device, non_blocking=True)
|
||||||
ret.extend_prefix_lens = torch.tensor(
|
ret.extend_prefix_lens = torch.tensor(
|
||||||
batch.extend_prefix_lens, device=device
|
batch.extend_prefix_lens, dtype=torch.int32
|
||||||
)
|
).to(device, non_blocking=True)
|
||||||
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
|
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
|
||||||
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
|
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
|
||||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||||
|
|||||||
@@ -37,12 +37,16 @@ class BatchedPenalizerOrchestrator:
|
|||||||
|
|
||||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
||||||
|
|
||||||
|
is_required = False
|
||||||
for penalizer in self.penalizers.values():
|
for penalizer in self.penalizers.values():
|
||||||
penalizer.prepare_if_required()
|
pen_is_required = penalizer.prepare_if_required()
|
||||||
|
is_required |= pen_is_required
|
||||||
|
self.is_required = is_required
|
||||||
|
|
||||||
self.cumulate_input_tokens(
|
if self.is_required:
|
||||||
input_ids=[req.origin_input_ids for req in self.reqs()]
|
self.cumulate_input_tokens(
|
||||||
)
|
input_ids=[req.origin_input_ids for req in self.reqs()]
|
||||||
|
)
|
||||||
|
|
||||||
def reqs(self):
|
def reqs(self):
|
||||||
return self.batch.reqs
|
return self.batch.reqs
|
||||||
@@ -79,6 +83,9 @@ class BatchedPenalizerOrchestrator:
|
|||||||
Args:
|
Args:
|
||||||
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
|
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
|
||||||
"""
|
"""
|
||||||
|
if not self.is_required:
|
||||||
|
return
|
||||||
|
|
||||||
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
|
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
|
||||||
|
|
||||||
for penalizer in self.penalizers.values():
|
for penalizer in self.penalizers.values():
|
||||||
@@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator:
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: The logits after applying the penalizers.
|
torch.Tensor: The logits after applying the penalizers.
|
||||||
"""
|
"""
|
||||||
|
if not self.is_required:
|
||||||
|
return
|
||||||
|
|
||||||
for penalizer in self.penalizers.values():
|
for penalizer in self.penalizers.values():
|
||||||
logits = penalizer.apply(logits)
|
logits = penalizer.apply(logits)
|
||||||
|
|
||||||
@@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator:
|
|||||||
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
|
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
|
||||||
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
||||||
"""
|
"""
|
||||||
|
if not self.is_required:
|
||||||
|
return
|
||||||
|
|
||||||
empty_indices = len(indices_to_keep) == 0
|
empty_indices = len(indices_to_keep) == 0
|
||||||
|
|
||||||
|
is_required = False
|
||||||
for penalizer in self.penalizers.values():
|
for penalizer in self.penalizers.values():
|
||||||
if not penalizer.is_required() or empty_indices:
|
tmp_is_required = penalizer.is_required()
|
||||||
|
is_required = is_required or tmp_is_required
|
||||||
|
if not tmp_is_required or empty_indices:
|
||||||
penalizer.teardown()
|
penalizer.teardown()
|
||||||
else:
|
else:
|
||||||
# create tensor index only when it's needed
|
# create tensor index only when it's needed
|
||||||
@@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator:
|
|||||||
indices_to_keep=indices_to_keep,
|
indices_to_keep=indices_to_keep,
|
||||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
indices_tensor_to_keep=indices_tensor_to_keep,
|
||||||
)
|
)
|
||||||
|
self.is_required = is_required
|
||||||
|
|
||||||
def merge(self, their: "BatchedPenalizerOrchestrator"):
|
def merge(self, their: "BatchedPenalizerOrchestrator"):
|
||||||
"""
|
"""
|
||||||
@@ -140,11 +157,10 @@ class BatchedPenalizerOrchestrator:
|
|||||||
Args:
|
Args:
|
||||||
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
|
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
|
||||||
"""
|
"""
|
||||||
if self.vocab_size != their.vocab_size:
|
if not self.is_required and not their.is_required:
|
||||||
raise ValueError(
|
return
|
||||||
f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
self.is_required |= their.is_required
|
||||||
for Penalizer, their_penalizer in their.penalizers.items():
|
for Penalizer, their_penalizer in their.penalizers.items():
|
||||||
if Penalizer not in self.penalizers:
|
if Penalizer not in self.penalizers:
|
||||||
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
|
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
|
||||||
@@ -250,6 +266,9 @@ class _BatchedPenalizer(abc.ABC):
|
|||||||
def prepare_if_required(self):
|
def prepare_if_required(self):
|
||||||
if self.is_required():
|
if self.is_required():
|
||||||
self.prepare()
|
self.prepare()
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
def teardown(self):
|
def teardown(self):
|
||||||
if self.is_prepared():
|
if self.is_prepared():
|
||||||
|
|||||||
@@ -48,20 +48,24 @@ class SamplingBatchInfo:
|
|||||||
disable_penalizer: bool,
|
disable_penalizer: bool,
|
||||||
):
|
):
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
with batch.input_ids.device:
|
device = batch.input_ids.device
|
||||||
temperatures = torch.tensor(
|
temperatures = (
|
||||||
|
torch.tensor(
|
||||||
[r.sampling_params.temperature for r in reqs],
|
[r.sampling_params.temperature for r in reqs],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
).view(-1, 1)
|
|
||||||
top_ps = torch.tensor(
|
|
||||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
|
||||||
)
|
|
||||||
top_ks = torch.tensor(
|
|
||||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
|
|
||||||
)
|
|
||||||
min_ps = torch.tensor(
|
|
||||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
|
||||||
)
|
)
|
||||||
|
.view(-1, 1)
|
||||||
|
.to(device, non_blocking=True)
|
||||||
|
)
|
||||||
|
top_ps = torch.tensor(
|
||||||
|
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||||
|
).to(device, non_blocking=True)
|
||||||
|
top_ks = torch.tensor(
|
||||||
|
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
|
||||||
|
).to(device, non_blocking=True)
|
||||||
|
min_ps = torch.tensor(
|
||||||
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||||
|
).to(device, non_blocking=True)
|
||||||
|
|
||||||
ret = cls(
|
ret = cls(
|
||||||
temperatures=temperatures,
|
temperatures=temperatures,
|
||||||
@@ -80,7 +84,7 @@ class SamplingBatchInfo:
|
|||||||
#
|
#
|
||||||
# While we choose not to even create the class instances if they are not required, this
|
# While we choose not to even create the class instances if they are not required, this
|
||||||
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
||||||
# handle {filter_batch()} and {merge()} cases as well.
|
# handle {filter_batch()} and {merge_batch()} cases as well.
|
||||||
if disable_penalizer:
|
if disable_penalizer:
|
||||||
ret.penalizer_orchestrator = None
|
ret.penalizer_orchestrator = None
|
||||||
else:
|
else:
|
||||||
@@ -112,19 +116,20 @@ class SamplingBatchInfo:
|
|||||||
self.linear_penalties = None
|
self.linear_penalties = None
|
||||||
|
|
||||||
for penalizer in self.penalizer_orchestrator.penalizers.values():
|
for penalizer in self.penalizer_orchestrator.penalizers.values():
|
||||||
|
if not penalizer.is_prepared():
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
|
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
|
||||||
if penalizer.is_prepared():
|
self.scaling_penalties = penalizer.cumulated_repetition_penalties
|
||||||
self.scaling_penalties = penalizer.cumulated_repetition_penalties
|
|
||||||
else:
|
else:
|
||||||
if penalizer.is_prepared():
|
if self.linear_penalties is None:
|
||||||
if self.linear_penalties is None:
|
bs = self.penalizer_orchestrator.batch.batch_size()
|
||||||
bs = self.penalizer_orchestrator.batch.batch_size()
|
self.linear_penalties = torch.zeros(
|
||||||
self.linear_penalties = torch.zeros(
|
(bs, self.vocab_size),
|
||||||
(bs, self.vocab_size),
|
dtype=torch.float32,
|
||||||
dtype=torch.float32,
|
device=self.device,
|
||||||
device=self.device,
|
)
|
||||||
)
|
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
|
||||||
|
|
||||||
def update_regex_vocab_mask(self):
|
def update_regex_vocab_mask(self):
|
||||||
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
|
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
|
||||||
|
|||||||
@@ -164,19 +164,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
|||||||
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
|
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
|
||||||
)
|
)
|
||||||
|
|
||||||
actual = orchestrator.apply(
|
original = torch.ones(
|
||||||
torch.ones(
|
size=(len(case.test_subjects), self.vocab_size),
|
||||||
size=(len(case.test_subjects), self.vocab_size),
|
dtype=torch.float32,
|
||||||
dtype=torch.float32,
|
device=self.device,
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
actual = orchestrator.apply(original.clone())
|
||||||
expected = torch.cat(
|
expected = torch.cat(
|
||||||
tensors=[
|
tensors=[
|
||||||
subject.steps[0].expected_logits
|
subject.steps[0].expected_logits
|
||||||
for subject in case.test_subjects
|
for subject in case.test_subjects
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
if actual is None:
|
||||||
|
actual = original
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
actual=actual,
|
actual=actual,
|
||||||
expected=expected,
|
expected=expected,
|
||||||
@@ -226,6 +227,8 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if actual_logits is None:
|
||||||
|
continue
|
||||||
filtered_expected_logits = torch.cat(
|
filtered_expected_logits = torch.cat(
|
||||||
tensors=[
|
tensors=[
|
||||||
subject.steps[0].expected_logits
|
subject.steps[0].expected_logits
|
||||||
@@ -317,19 +320,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
|||||||
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
|
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
|
||||||
)
|
)
|
||||||
|
|
||||||
actual_logits = orchestrator.apply(
|
original = torch.ones(
|
||||||
torch.ones(
|
size=(len(filtered_subjects), self.vocab_size),
|
||||||
size=(len(filtered_subjects), self.vocab_size),
|
dtype=torch.float32,
|
||||||
dtype=torch.float32,
|
device=self.device,
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
actual_logits = orchestrator.apply(original.clone())
|
||||||
filtered_expected_logits = torch.cat(
|
filtered_expected_logits = torch.cat(
|
||||||
tensors=[
|
tensors=[
|
||||||
subject.steps[i].expected_logits
|
subject.steps[i].expected_logits
|
||||||
for subject in filtered_subjects
|
for subject in filtered_subjects
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
if actual_logits is None:
|
||||||
|
actual_logits = original
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
actual=actual_logits,
|
actual=actual_logits,
|
||||||
expected=filtered_expected_logits,
|
expected=filtered_expected_logits,
|
||||||
|
|||||||
Reference in New Issue
Block a user