[Enhancement] Custom Logit Processor Improvement (#2998)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
@@ -232,6 +232,7 @@ def extend(reqs, model_runner):
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
+        enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
@@ -132,6 +132,11 @@ class Sampler(nn.Module):
         """Apply custom logit processors to the logits.
         This function will modify the logits in-place."""
 
+        assert logits.shape[0] == len(sampling_batch_info), (
+            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+
         for _, (
             processor,
             batch_mask,
@@ -139,6 +144,11 @@ class Sampler(nn.Module):
             # Get the batch indices that need to be processed
             batch_indices = batch_mask.nonzero(as_tuple=True)[0]
 
+            assert batch_mask.shape[0] == len(sampling_batch_info), (
+                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+                f"sampling_batch_info ({len(sampling_batch_info)})"
+            )
+
             # Apply the processor to the logits
             logits[batch_mask] = processor(
                 logits[batch_mask],
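For context, the `processor` objects iterated above are user-supplied callables applied only to the rows selected by `batch_mask`. Below is a minimal sketch in the spirit of the `DeterministicLogitProcessor` used by the tests further down; the `CustomLogitProcessor` base class, its module path, and the `(logits, custom_param_list)` call signature are assumptions inferred from how the Sampler and the test invoke it.

```python
import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor  # assumed path


class DeterministicLogitProcessor(CustomLogitProcessor):
    """Force the next token to be custom_params["token_id"] for every masked row."""

    def __call__(self, logits: torch.Tensor, custom_param_list: list) -> torch.Tensor:
        # `logits` holds only the rows selected by batch_mask; custom_param_list has
        # one sampling_params["custom_params"] dict per row, in the same order.
        for i, params in enumerate(custom_param_list):
            logits[i, :] = -float("inf")
            logits[i, params["token_id"]] = 0.0
        return logits
```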
@@ -595,6 +595,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
 
+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -605,6 +608,7 @@ class ScheduleBatch:
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -618,6 +622,7 @@ class ScheduleBatch:
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )
 
     def batch_size(self):
@@ -1201,6 +1206,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )
 
     def __str__(self):
@@ -966,6 +966,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         new_batch.prepare_for_extend()
 
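The flag the Scheduler forwards above originates in `ServerArgs` (read here as `self.server_args.enable_custom_logit_processor`) and is threaded into every `ScheduleBatch`. A minimal sketch of enabling it programmatically; apart from the two field names shown, the values are placeholders, not taken from the diff.

```python
from sglang.srt.server_args import ServerArgs

# Hypothetical configuration: the model path is a placeholder.
server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    enable_custom_logit_processor=True,
)
```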
@@ -1520,6 +1521,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -89,7 +89,10 @@ class SamplingBatchInfo:
         ).to(device, non_blocking=True)
 
         # Check if any request has custom logit processor
-        has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+        has_custom_logit_processor = (
+            batch.enable_custom_logit_processor  # check the flag first.
+            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+        )
 
         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
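When the combined check above is true, the batch builds a `{key: (processor, batch_mask)}` mapping, which is the structure filtered and merged in the hunks below. A rough sketch of that grouping, assuming a hash-based key and a `from_str` deserializer on `CustomLogitProcessor` (both hypothetical here):

```python
import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor  # assumed path


def group_custom_logit_processors(reqs, device: str) -> dict:
    """Group identical custom logit processors so each one gets a single boolean
    mask over the batch (True where the processor applies)."""
    merged = {}
    for i, req in enumerate(reqs):
        if not req.custom_logit_processor:
            continue
        key = hash(req.custom_logit_processor)  # hypothetical grouping key
        if key not in merged:
            processor = CustomLogitProcessor.from_str(req.custom_logit_processor)  # assumed deserializer
            mask = torch.zeros(len(reqs), dtype=torch.bool, device=device)
            merged[key] = (processor, mask)
        merged[key][1][i] = True
    return merged
```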
@@ -247,8 +250,7 @@ class SamplingBatchInfo:
         self, unfinished_indices: List[int], new_indices: torch.Tensor
     ):
         """Filter the custom logit processor and custom params"""
-        if not self.custom_logit_processor:
-            return
+
         self.custom_logit_processor = {
             k: (p, mask[new_indices])
             for k, (p, mask) in self.custom_logit_processor.items()
@@ -258,7 +260,9 @@ class SamplingBatchInfo:
         }
         self.custom_params = [self.custom_params[i] for i in unfinished_indices]
 
-        if len(self) == 0:
+        # If the custom logit processor is an empty dict, set the flag to False,
+        # and set the custom logit processor and custom params to None.
+        if len(self.custom_logit_processor) == 0:
             self.custom_logit_processor = None
             self.custom_params = None
             self.has_custom_logit_processor = False
@@ -290,8 +294,8 @@ class SamplingBatchInfo:
 
     @staticmethod
     def merge_custom_logit_processor(
-        lhs: Optional[Dict[str, torch.Tensor]],
-        rhs: Optional[Dict[str, torch.Tensor]],
+        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
         bs1: int,
         bs2: int,
         device: str,
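With the corrected typing, `lhs` and `rhs` are both `Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]` keyed by processor, each tensor being a boolean mask over its own batch. A sketch of the merge the method performs, assuming the missing side is padded with `False` rows before concatenation:

```python
import torch


def merge_custom_logit_processor_sketch(lhs, rhs, bs1: int, bs2: int, device: str) -> dict:
    merged = {}
    for key in set(lhs or {}) | set(rhs or {}):
        proc_l, mask_l = (lhs or {}).get(
            key, (None, torch.zeros(bs1, dtype=torch.bool, device=device))
        )
        proc_r, mask_r = (rhs or {}).get(
            key, (None, torch.zeros(bs2, dtype=torch.bool, device=device))
        )
        merged[key] = (proc_l or proc_r, torch.cat([mask_l, mask_r]))
        # Mirrors the assert added in the next hunk: the merged mask covers both batches.
        assert merged[key][1].shape[0] == bs1 + bs2
    return merged
```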
@@ -319,27 +323,22 @@ class SamplingBatchInfo:
             )
             merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
 
+            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                f"\n{lhs=}\n{rhs=}"
+            )
+
         return merged_dict
 
     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        # Merge the logit bias tensor
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
-        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
 
         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
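The reordering above matters because `len()` on a `SamplingBatchInfo` is backed by the temperatures tensor (per the note added in the next hunk), so anything that needs the pre-merge batch sizes, such as `merge_bias_tensor` and the custom-processor mask merge, must run before the temperatures are concatenated. A minimal, self-contained sketch of that pitfall under this assumption:

```python
import torch


class SamplingBatchInfoSketch:
    def __init__(self, temperatures: torch.Tensor):
        self.temperatures = temperatures

    def __len__(self) -> int:
        # Assumption: batch size is derived from the temperatures tensor.
        return len(self.temperatures)


a = SamplingBatchInfoSketch(torch.ones(3))
b = SamplingBatchInfoSketch(torch.ones(2))
bs1, bs2 = len(a), len(b)  # 3, 2 -- must be read before the concat below
a.temperatures = torch.cat([a.temperatures, b.temperatures])
assert (bs1, bs2) == (3, 2) and len(a) == 5  # after the concat, len(a) already reports 5
```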
@@ -360,6 +359,22 @@ class SamplingBatchInfo:
             # Set the flag to True if any of the two has custom logit processor
             self.has_custom_logit_processor = True
 
+        # Note: because the __len__() operator is defined on the temperatures tensor,
+        # please make sure any merge operation with len(self) or len(other) is done before
+        # the merge operation of the temperatures tensor below.
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
+
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
         if self.logit_bias is not None:
@@ -4,8 +4,10 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
 """
 
 import json
+import random
 import unittest
 from concurrent.futures import ThreadPoolExecutor
+from typing import Optional
 
 import numpy as np
 import requests
@@ -253,8 +255,11 @@ class TestSRTEndpoint(unittest.TestCase):
 
         self.assertTrue(all(x is not None for x in logprobs))
 
-    def run_custom_logit_processor(self, target_token_id: int):
-        """Test custom logit processor with custom params."""
+    def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
+        """Test custom logit processor with custom params.
+
+        If target_token_id is None, the custom logit processor won't be passed in.
+        """
 
         custom_params = {"token_id": target_token_id}
 
@@ -285,8 +290,12 @@ class TestSRTEndpoint(unittest.TestCase):
 
         # Custom json data with custom logit processor and params.
         custom_json = base_json.copy()
-        custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
-        custom_json["sampling_params"]["custom_params"] = custom_params
+        # Only set the custom logit processor if target_token_id is not None.
+        if target_token_id is not None:
+            custom_json["custom_logit_processor"] = (
+                DeterministicLogitProcessor().to_str()
+            )
+            custom_json["sampling_params"]["custom_params"] = custom_params
 
         custom_response = requests.post(
             self.base_url + "/generate",
@@ -297,22 +306,30 @@ class TestSRTEndpoint(unittest.TestCase):
         sampled_tokens = [x[1] for x in output_token_logprobs]
 
         # The logit processor should always sample the given token as the logits is deterministic.
-        self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
+        if target_token_id is not None:
+            self.assertTrue(
+                all(x == custom_params["token_id"] for x in sampled_tokens),
+                # Print the detailed test case info if the test fails.
+                f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
+            )
 
     def test_custom_logit_processor(self):
         """Test custom logit processor with a single request."""
-        # Temporarily skipped due to buggy implementation
-        return
         self.run_custom_logit_processor(target_token_id=5)
 
     def test_custom_logit_processor_batch(self):
         """Test custom logit processor with a batch of requests."""
-        # Temporarily skipped due to buggy implementation
-        return
         target_token_ids = list(range(32))
         with ThreadPoolExecutor(len(target_token_ids)) as executor:
             list(executor.map(self.run_custom_logit_processor, target_token_ids))
 
+    def test_custom_logit_processor_batch_mixed(self):
+        """Test a batch mixing requests with and without a custom logit processor."""
+        target_token_ids = list(range(32)) + [None] * 16
+        random.shuffle(target_token_ids)
+        with ThreadPoolExecutor(len(target_token_ids)) as executor:
+            list(executor.map(self.run_custom_logit_processor, target_token_ids))
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
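For reference, a hedged end-to-end request mirroring `run_custom_logit_processor`: the processor is serialized with `to_str()` and sent per request, with its parameters under `sampling_params["custom_params"]`. The prompt, sampling parameters, and port below are placeholders; only the field names come from the test, and `DeterministicLogitProcessor` is the sketch given after the Sampler hunk above.

```python
import requests

payload = {
    "text": "The capital of France is",  # placeholder prompt
    "sampling_params": {
        "temperature": 0.0,
        "max_new_tokens": 8,
        "custom_params": {"token_id": 5},
    },
    "custom_logit_processor": DeterministicLogitProcessor().to_str(),
    "return_logprob": True,
}
response = requests.post("http://localhost:30000/generate", json=payload)
print(response.json())
```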