[Fix] Fix logprob and normalized_logprob (#1428)
This commit is contained in:
25
.github/workflows/pr-test.yml
vendored
25
.github/workflows/pr-test.yml
vendored
@@ -54,7 +54,7 @@ jobs:
|
|||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 0 --range-end 8
|
python3 run_suite.py --suite minimal --range-begin 0 --range-end 7
|
||||||
|
|
||||||
unit-test-backend-part-2:
|
unit-test-backend-part-2:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
@@ -73,7 +73,26 @@ jobs:
|
|||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 8
|
python3 run_suite.py --suite minimal --range-begin 7 --range-end 14
|
||||||
|
|
||||||
|
unit-test-backend-part-3:
|
||||||
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
|
runs-on: 1-gpu-runner
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install -e "python[dev]"
|
||||||
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
|
||||||
|
|
||||||
|
- name: Run test
|
||||||
|
timeout-minutes: 20
|
||||||
|
run: |
|
||||||
|
cd test/srt
|
||||||
|
python3 run_suite.py --suite minimal --range-begin 14
|
||||||
|
|
||||||
performance-test-1-gpu-part-1:
|
performance-test-1-gpu-part-1:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
@@ -217,7 +236,7 @@ jobs:
|
|||||||
|
|
||||||
finish:
|
finish:
|
||||||
needs: [
|
needs: [
|
||||||
unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2,
|
unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3,
|
||||||
performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
|
performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
|
||||||
accuracy-test-1-gpu, accuracy-test-2-gpu
|
accuracy-test-1-gpu, accuracy-test-2-gpu
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
|
|||||||
# Node 1
|
# Node 1
|
||||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1
|
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1
|
||||||
```
|
```
|
||||||
|
|
||||||
### Supported Models
|
### Supported Models
|
||||||
|
|
||||||
**Generative Models**
|
**Generative Models**
|
||||||
|
|||||||
@@ -164,6 +164,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
|
|||||||
req.prefix_indices = []
|
req.prefix_indices = []
|
||||||
req.sampling_params = sampling_params
|
req.sampling_params = sampling_params
|
||||||
req.fill_ids = req.origin_input_ids
|
req.fill_ids = req.origin_input_ids
|
||||||
|
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
||||||
reqs.append(req)
|
reqs.append(req)
|
||||||
|
|
||||||
return input_ids, reqs
|
return input_ids, reqs
|
||||||
@@ -178,6 +179,7 @@ def prepare_extend_inputs_for_correctness_test(
|
|||||||
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
|
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
|
||||||
i, : bench_args.cut_len
|
i, : bench_args.cut_len
|
||||||
]
|
]
|
||||||
|
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
||||||
return reqs
|
return reqs
|
||||||
|
|
||||||
|
|
||||||
@@ -194,6 +196,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
|
|||||||
req.prefix_indices = []
|
req.prefix_indices = []
|
||||||
req.sampling_params = sampling_params
|
req.sampling_params = sampling_params
|
||||||
req.fill_ids = req.origin_input_ids
|
req.fill_ids = req.origin_input_ids
|
||||||
|
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
||||||
reqs.append(req)
|
reqs.append(req)
|
||||||
|
|
||||||
return reqs
|
return reqs
|
||||||
|
|||||||
@@ -239,9 +239,12 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
# Compute logprob
|
# Compute logprob
|
||||||
data = {
|
data = {
|
||||||
"text": [s.text_ + c for c in choices],
|
"text": [s.text_ + c for c in choices],
|
||||||
"sampling_params": {"max_new_tokens": 0},
|
"sampling_params": {
|
||||||
|
"max_new_tokens": 0,
|
||||||
|
"temperature": 0,
|
||||||
|
},
|
||||||
"return_logprob": True,
|
"return_logprob": True,
|
||||||
"logprob_start_len": max(prompt_len - 2, 0),
|
"logprob_start_len": max(prompt_len - 2, 0), # for token healing
|
||||||
}
|
}
|
||||||
obj = self._generate_http_request(s, data)
|
obj = self._generate_http_request(s, data)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import uuid
|
|||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class AttentionBackend(ABC):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||||
@@ -66,9 +67,11 @@ class AttentionBackend(ABC):
|
|||||||
return self.forward_extend(q, k, v, layer, input_metadata)
|
return self.forward_extend(q, k, v, layer, input_metadata)
|
||||||
|
|
||||||
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||||
|
"""Run a forward for decode."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||||
|
"""Run a forward for extend."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
@@ -299,6 +302,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if total_num_tokens >= global_config.layer_sync_threshold:
|
if total_num_tokens >= global_config.layer_sync_threshold:
|
||||||
|
# TODO: Revisit this. Why is this synchronize needed?
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class LogitsProcessorOutput:
|
|||||||
|
|
||||||
# The normalized logprobs of prompts. shape: [#seq]
|
# The normalized logprobs of prompts. shape: [#seq]
|
||||||
normalized_prompt_logprobs: torch.Tensor
|
normalized_prompt_logprobs: torch.Tensor
|
||||||
# The logprobs of input tokens. shape: [#token, vocab_size]
|
# The logprobs of input tokens. shape: [#token, vocab_size]
|
||||||
input_token_logprobs: torch.Tensor
|
input_token_logprobs: torch.Tensor
|
||||||
|
|
||||||
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||||
@@ -49,25 +49,39 @@ class LogitsProcessorOutput:
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class LogitsMetadata:
|
class LogitsMetadata:
|
||||||
forward_mode: ForwardMode
|
forward_mode: ForwardMode
|
||||||
|
top_logprobs_nums: Optional[List[int]]
|
||||||
|
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
|
return_top_logprob: bool = False
|
||||||
|
|
||||||
extend_seq_lens: Optional[torch.Tensor] = None
|
extend_seq_lens: Optional[torch.Tensor] = None
|
||||||
extend_start_loc: Optional[torch.Tensor] = None
|
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||||
top_logprobs_nums: Optional[List[int]] = None
|
|
||||||
|
|
||||||
extend_seq_lens_cpu: List[int] = None
|
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||||
logprob_start_lens_cpu: List[int] = None
|
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_input_metadata(cls, input_metadata: InputMetadata):
|
def from_input_metadata(cls, input_metadata: InputMetadata):
|
||||||
|
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
||||||
|
if input_metadata.forward_mode.is_extend():
|
||||||
|
extend_logprob_pruned_lens_cpu = [
|
||||||
|
extend_len - start_len
|
||||||
|
for extend_len, start_len in zip(
|
||||||
|
input_metadata.extend_seq_lens,
|
||||||
|
input_metadata.extend_logprob_start_lens_cpu,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
extend_logprob_pruned_lens_cpu = None
|
||||||
return cls(
|
return cls(
|
||||||
forward_mode=input_metadata.forward_mode,
|
forward_mode=input_metadata.forward_mode,
|
||||||
extend_seq_lens=input_metadata.extend_seq_lens,
|
|
||||||
extend_start_loc=input_metadata.extend_start_loc,
|
|
||||||
return_logprob=input_metadata.return_logprob,
|
|
||||||
top_logprobs_nums=input_metadata.top_logprobs_nums,
|
top_logprobs_nums=input_metadata.top_logprobs_nums,
|
||||||
|
return_logprob=input_metadata.return_logprob,
|
||||||
|
return_top_logprob=return_top_logprob,
|
||||||
|
extend_seq_lens=input_metadata.extend_seq_lens,
|
||||||
extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
|
extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
|
||||||
logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
|
extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
|
||||||
|
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -82,57 +96,49 @@ class LogitsProcessor(nn.Module):
|
|||||||
def _get_normalized_prompt_logprobs(
|
def _get_normalized_prompt_logprobs(
|
||||||
self,
|
self,
|
||||||
input_token_logprobs: torch.Tensor,
|
input_token_logprobs: torch.Tensor,
|
||||||
cum_start_len0: torch.Tensor,
|
|
||||||
cum_start_len1: torch.Tensor,
|
|
||||||
logits_metadata: LogitsMetadata,
|
logits_metadata: LogitsMetadata,
|
||||||
):
|
):
|
||||||
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
|
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
|
||||||
|
pruned_lens = torch.tensor(
|
||||||
|
logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
start = logits_metadata.extend_start_loc.clone() - cum_start_len0
|
start = torch.zeros_like(pruned_lens)
|
||||||
end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
|
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
|
||||||
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
|
end = torch.clamp(
|
||||||
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
|
start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
|
||||||
|
)
|
||||||
sum_logp = (
|
sum_logp = (
|
||||||
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
|
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
|
||||||
)
|
)
|
||||||
normalized_prompt_logprobs = sum_logp / (
|
normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
|
||||||
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
|
|
||||||
)
|
|
||||||
|
|
||||||
return normalized_prompt_logprobs
|
return normalized_prompt_logprobs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
|
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
|
||||||
|
max_k = max(logits_metadata.top_logprobs_nums)
|
||||||
|
ret = all_logprobs.topk(max_k, dim=1)
|
||||||
|
values = ret.values.tolist()
|
||||||
|
indices = ret.indices.tolist()
|
||||||
|
|
||||||
if logits_metadata.forward_mode.is_decode():
|
if logits_metadata.forward_mode.is_decode():
|
||||||
output_top_logprobs = []
|
output_top_logprobs = []
|
||||||
max_k = max(logits_metadata.top_logprobs_nums)
|
|
||||||
ret = all_logprobs.topk(max_k, dim=1)
|
|
||||||
values = ret.values.tolist()
|
|
||||||
indices = ret.indices.tolist()
|
|
||||||
for i, k in enumerate(logits_metadata.top_logprobs_nums):
|
for i, k in enumerate(logits_metadata.top_logprobs_nums):
|
||||||
output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
|
output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
|
||||||
return None, output_top_logprobs
|
return None, output_top_logprobs
|
||||||
else:
|
else:
|
||||||
# TODO: vectorize the code below
|
|
||||||
input_top_logprobs, output_top_logprobs = [], []
|
input_top_logprobs, output_top_logprobs = [], []
|
||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu
|
for k, pruned_len in zip(
|
||||||
|
logits_metadata.top_logprobs_nums,
|
||||||
max_k = max(logits_metadata.top_logprobs_nums)
|
logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||||
ret = all_logprobs.topk(max_k, dim=1)
|
):
|
||||||
values = ret.values.tolist()
|
if pruned_len <= 0:
|
||||||
indices = ret.indices.tolist()
|
|
||||||
|
|
||||||
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
|
||||||
start_len = logits_metadata.logprob_start_lens_cpu[i]
|
|
||||||
pruned_len = extend_seq_len - start_len
|
|
||||||
|
|
||||||
if extend_seq_len == 0:
|
|
||||||
input_top_logprobs.append([])
|
input_top_logprobs.append([])
|
||||||
output_top_logprobs.append([])
|
output_top_logprobs.append([])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
k = logits_metadata.top_logprobs_nums[i]
|
|
||||||
input_top_logprobs.append(
|
input_top_logprobs.append(
|
||||||
[
|
[
|
||||||
list(zip(values[pt + j][:k], indices[pt + j][:k]))
|
list(zip(values[pt + j][:k], indices[pt + j][:k]))
|
||||||
@@ -167,10 +173,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
last_index = None
|
last_index = None
|
||||||
last_hidden = hidden_states
|
last_hidden = hidden_states
|
||||||
else:
|
else:
|
||||||
last_index = (
|
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
||||||
torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
|
|
||||||
- 1
|
|
||||||
)
|
|
||||||
last_hidden = hidden_states[last_index]
|
last_hidden = hidden_states[last_index]
|
||||||
|
|
||||||
last_logits = torch.matmul(last_hidden, weight.T)
|
last_logits = torch.matmul(last_hidden, weight.T)
|
||||||
@@ -194,21 +197,15 @@ class LogitsProcessor(nn.Module):
|
|||||||
output_top_logprobs=None,
|
output_top_logprobs=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# When logprob is requested, compute the logits for all tokens.
|
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||||
if logits_metadata.forward_mode.is_decode():
|
|
||||||
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
|
|
||||||
|
|
||||||
# Get the logprob of top-k tokens
|
if logits_metadata.forward_mode.is_decode():
|
||||||
return_top_logprob = any(
|
if logits_metadata.return_top_logprob:
|
||||||
x > 0 for x in logits_metadata.top_logprobs_nums
|
|
||||||
)
|
|
||||||
if return_top_logprob:
|
|
||||||
output_top_logprobs = self.get_top_logprobs(
|
output_top_logprobs = self.get_top_logprobs(
|
||||||
last_logprobs, logits_metadata
|
last_logprobs, logits_metadata
|
||||||
)[1]
|
)[1]
|
||||||
else:
|
else:
|
||||||
output_top_logprobs = None
|
output_top_logprobs = None
|
||||||
|
|
||||||
return LogitsProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
next_token_logprobs=last_logprobs,
|
next_token_logprobs=last_logprobs,
|
||||||
@@ -218,22 +215,18 @@ class LogitsProcessor(nn.Module):
|
|||||||
output_top_logprobs=output_top_logprobs,
|
output_top_logprobs=output_top_logprobs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Slice the requested tokens to compute logprob
|
||||||
pt, states, pruned_input_ids = 0, [], []
|
pt, states, pruned_input_ids = 0, [], []
|
||||||
for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
|
for start_len, extend_len in zip(
|
||||||
start_len = logits_metadata.logprob_start_lens_cpu[i]
|
logits_metadata.extend_logprob_start_lens_cpu,
|
||||||
|
logits_metadata.extend_seq_lens_cpu,
|
||||||
|
):
|
||||||
states.append(hidden_states[pt + start_len : pt + extend_len])
|
states.append(hidden_states[pt + start_len : pt + extend_len])
|
||||||
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
||||||
pt += extend_len
|
pt += extend_len
|
||||||
|
|
||||||
|
# Compute the logits and logprobs for all required tokens
|
||||||
states = torch.cat(states, dim=0)
|
states = torch.cat(states, dim=0)
|
||||||
pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
|
|
||||||
|
|
||||||
cum_start_len1 = torch.tensor(
|
|
||||||
logits_metadata.logprob_start_lens_cpu, device="cuda"
|
|
||||||
).cumsum(0)
|
|
||||||
cum_start_len0 = torch.zeros_like(cum_start_len1)
|
|
||||||
cum_start_len0[1:] = cum_start_len1[:-1]
|
|
||||||
|
|
||||||
all_logits = torch.matmul(states, weight.T)
|
all_logits = torch.matmul(states, weight.T)
|
||||||
if self.do_tensor_parallel_all_gather:
|
if self.do_tensor_parallel_all_gather:
|
||||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||||
@@ -249,35 +242,29 @@ class LogitsProcessor(nn.Module):
|
|||||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||||
|
|
||||||
# Get the logprob of top-k tokens
|
# Get the logprob of top-k tokens
|
||||||
return_top_logprob = any(
|
if logits_metadata.return_top_logprob:
|
||||||
x > 0 for x in logits_metadata.top_logprobs_nums
|
|
||||||
)
|
|
||||||
if return_top_logprob:
|
|
||||||
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
|
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
|
||||||
all_logprobs, logits_metadata
|
all_logprobs, logits_metadata
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
input_top_logprobs = output_top_logprobs = None
|
input_top_logprobs = output_top_logprobs = None
|
||||||
|
|
||||||
last_logprobs = all_logprobs[last_index - cum_start_len1]
|
# Compute the normalized logprobs for the requested tokens.
|
||||||
|
# Note that we pad a zero at the end for easy batching.
|
||||||
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
|
||||||
# Note that we pad a zero at the end of each sequence for easy computation.
|
|
||||||
input_token_logprobs = all_logprobs[
|
input_token_logprobs = all_logprobs[
|
||||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||||
torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
|
torch.cat(
|
||||||
|
[
|
||||||
|
torch.cat(pruned_input_ids)[1:],
|
||||||
|
torch.tensor([0], device="cuda"),
|
||||||
|
]
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||||
input_token_logprobs,
|
input_token_logprobs,
|
||||||
cum_start_len0,
|
|
||||||
cum_start_len1,
|
|
||||||
logits_metadata,
|
logits_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove the last token logprob for the prefill tokens.
|
|
||||||
input_token_logprobs = input_token_logprobs[:-1]
|
|
||||||
|
|
||||||
return LogitsProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
next_token_logprobs=last_logprobs,
|
next_token_logprobs=last_logprobs,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
@@ -43,6 +43,7 @@ class GenerateReqInput:
|
|||||||
# Whether to return logprobs.
|
# Whether to return logprobs.
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||||
# If return logprobs, the start location in the prompt for returning logprobs.
|
# If return logprobs, the start location in the prompt for returning logprobs.
|
||||||
|
# By default, this value is "-1", which means it will only return logprobs for output tokens.
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
# If return logprobs, the number of top logprobs to return at each position.
|
# If return logprobs, the number of top logprobs to return at each position.
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None
|
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ class BaseFinishReason:
|
|||||||
self.is_error = is_error
|
self.is_error = is_error
|
||||||
|
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
raise NotImplementedError("Subclasses must implement this method")
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class FINISH_MATCHED_TOKEN(BaseFinishReason):
|
class FINISH_MATCHED_TOKEN(BaseFinishReason):
|
||||||
@@ -105,7 +105,13 @@ class FINISH_ABORT(BaseFinishReason):
|
|||||||
class Req:
|
class Req:
|
||||||
"""Store all information of a request."""
|
"""Store all information of a request."""
|
||||||
|
|
||||||
def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
rid: str,
|
||||||
|
origin_input_text: str,
|
||||||
|
origin_input_ids: Tuple[int],
|
||||||
|
lora_path: Optional[str] = None,
|
||||||
|
):
|
||||||
# Input and output info
|
# Input and output info
|
||||||
self.rid = rid
|
self.rid = rid
|
||||||
self.origin_input_text = origin_input_text
|
self.origin_input_text = origin_input_text
|
||||||
@@ -118,6 +124,10 @@ class Req:
|
|||||||
# Memory info
|
# Memory info
|
||||||
self.req_pool_idx = None
|
self.req_pool_idx = None
|
||||||
|
|
||||||
|
# Check finish
|
||||||
|
self.tokenizer = None
|
||||||
|
self.finished_reason = None
|
||||||
|
|
||||||
# For incremental decoding
|
# For incremental decoding
|
||||||
# ----- | --------- read_ids -------|
|
# ----- | --------- read_ids -------|
|
||||||
# ----- | surr_ids |
|
# ----- | surr_ids |
|
||||||
@@ -136,7 +146,7 @@ class Req:
|
|||||||
# this does not include the jump forward tokens.
|
# this does not include the jump forward tokens.
|
||||||
self.completion_tokens_wo_jump_forward = 0
|
self.completion_tokens_wo_jump_forward = 0
|
||||||
|
|
||||||
# For vision input
|
# For vision inputs
|
||||||
self.pixel_values = None
|
self.pixel_values = None
|
||||||
self.image_sizes = None
|
self.image_sizes = None
|
||||||
self.image_offsets = None
|
self.image_offsets = None
|
||||||
@@ -144,31 +154,35 @@ class Req:
|
|||||||
self.modalities = None
|
self.modalities = None
|
||||||
|
|
||||||
# Prefix info
|
# Prefix info
|
||||||
self.extend_input_len = 0
|
|
||||||
self.prefix_indices = []
|
self.prefix_indices = []
|
||||||
|
self.extend_input_len = 0
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
|
|
||||||
# Sampling parameters
|
# Sampling parameters
|
||||||
self.sampling_params = None
|
self.sampling_params = None
|
||||||
self.stream = False
|
self.stream = False
|
||||||
|
|
||||||
# Check finish
|
# Logprobs (arguments)
|
||||||
self.tokenizer = None
|
|
||||||
self.finished_reason = None
|
|
||||||
|
|
||||||
# Logprobs
|
|
||||||
self.return_logprob = False
|
self.return_logprob = False
|
||||||
self.embedding = None
|
|
||||||
self.logprob_start_len = 0
|
self.logprob_start_len = 0
|
||||||
self.top_logprobs_num = 0
|
self.top_logprobs_num = 0
|
||||||
|
|
||||||
|
# Logprobs (return value)
|
||||||
self.normalized_prompt_logprob = None
|
self.normalized_prompt_logprob = None
|
||||||
self.input_token_logprobs = None
|
self.input_token_logprobs = None
|
||||||
self.input_top_logprobs = None
|
self.input_top_logprobs = None
|
||||||
self.output_token_logprobs = []
|
self.output_token_logprobs = []
|
||||||
self.output_top_logprobs = []
|
self.output_top_logprobs = []
|
||||||
|
|
||||||
|
# Logprobs (internal values)
|
||||||
# The tokens are prefilled but need to be considered as decode tokens
|
# The tokens are prefilled but need to be considered as decode tokens
|
||||||
# and should be updated for the decode logprobs
|
# and should be updated for the decode logprobs
|
||||||
self.last_update_decode_tokens = 0
|
self.last_update_decode_tokens = 0
|
||||||
|
# The relative logprob_start_len in an extend batch
|
||||||
|
self.extend_logprob_start_len = 0
|
||||||
|
|
||||||
|
# Embedding
|
||||||
|
self.embedding = None
|
||||||
|
|
||||||
# Constrained decoding
|
# Constrained decoding
|
||||||
self.regex_fsm: RegexGuide = None
|
self.regex_fsm: RegexGuide = None
|
||||||
@@ -363,9 +377,13 @@ class ScheduleBatch:
|
|||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
top_logprobs_nums: List[int] = None
|
top_logprobs_nums: List[int] = None
|
||||||
|
|
||||||
|
# Stream
|
||||||
|
has_stream: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||||
return_logprob = any(req.return_logprob for req in reqs)
|
return_logprob = any(req.return_logprob for req in reqs)
|
||||||
|
has_stream = any(req.stream for req in reqs)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
@@ -373,18 +391,15 @@ class ScheduleBatch:
|
|||||||
token_to_kv_pool=token_to_kv_pool,
|
token_to_kv_pool=token_to_kv_pool,
|
||||||
tree_cache=tree_cache,
|
tree_cache=tree_cache,
|
||||||
return_logprob=return_logprob,
|
return_logprob=return_logprob,
|
||||||
|
has_stream=has_stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
return len(self.reqs) if self.reqs else 0
|
return len(self.reqs)
|
||||||
|
|
||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
return len(self.reqs) == 0
|
return len(self.reqs) == 0
|
||||||
|
|
||||||
def has_stream(self) -> bool:
|
|
||||||
# Return whether batch has at least 1 streaming request
|
|
||||||
return any(r.stream for r in self.reqs)
|
|
||||||
|
|
||||||
def alloc_req_slots(self, num_reqs):
|
def alloc_req_slots(self, num_reqs):
|
||||||
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
||||||
if req_pool_indices is None:
|
if req_pool_indices is None:
|
||||||
@@ -427,8 +442,8 @@ class ScheduleBatch:
|
|||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(reqs):
|
||||||
req.req_pool_idx = req_pool_indices_cpu[i]
|
req.req_pool_idx = req_pool_indices_cpu[i]
|
||||||
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||||
ext_len = seq_len - pre_len
|
|
||||||
seq_lens.append(seq_len)
|
seq_lens.append(seq_len)
|
||||||
|
assert seq_len - pre_len == req.extend_input_len
|
||||||
|
|
||||||
if pre_len > 0:
|
if pre_len > 0:
|
||||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
@@ -436,9 +451,19 @@ class ScheduleBatch:
|
|||||||
] = req.prefix_indices
|
] = req.prefix_indices
|
||||||
|
|
||||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
|
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
|
||||||
out_cache_loc[pt : pt + ext_len]
|
out_cache_loc[pt : pt + req.extend_input_len]
|
||||||
)
|
)
|
||||||
pt += ext_len
|
|
||||||
|
# Compute the relative logprob_start_len in an extend batch
|
||||||
|
if req.logprob_start_len >= pre_len:
|
||||||
|
extend_logprob_start_len = min(
|
||||||
|
req.logprob_start_len - pre_len, req.extend_input_len - 1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extend_logprob_start_len = req.extend_input_len - 1
|
||||||
|
|
||||||
|
req.extend_logprob_start_len = extend_logprob_start_len
|
||||||
|
pt += req.extend_input_len
|
||||||
|
|
||||||
# Set fields
|
# Set fields
|
||||||
with torch.device("cuda"):
|
with torch.device("cuda"):
|
||||||
@@ -451,21 +476,13 @@ class ScheduleBatch:
|
|||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||||
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
|
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
|
||||||
|
self.extend_lens_cpu = [r.extend_input_len for r in reqs]
|
||||||
|
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
|
||||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
|
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
|
||||||
|
|
||||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||||
self.forward_mode = ForwardMode.MIXED
|
self.forward_mode = ForwardMode.MIXED
|
||||||
self.running_bs = running_batch.batch_size()
|
running_bs = running_batch.batch_size()
|
||||||
|
|
||||||
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
|
||||||
prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
|
|
||||||
prefix_lens_cpu.extend(
|
|
||||||
[
|
|
||||||
len(r.origin_input_ids) + len(r.output_ids) - 1
|
|
||||||
for r in running_batch.reqs
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
for req in running_batch.reqs:
|
for req in running_batch.reqs:
|
||||||
req.fill_ids = req.origin_input_ids + req.output_ids
|
req.fill_ids = req.origin_input_ids + req.output_ids
|
||||||
@@ -473,12 +490,22 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
|
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
|
||||||
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
|
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
|
||||||
extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
|
extend_num_tokens = self.extend_num_tokens + running_bs
|
||||||
|
|
||||||
self.merge(running_batch)
|
self.merge(running_batch)
|
||||||
self.input_ids = input_ids
|
self.input_ids = input_ids
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
self.extend_num_tokens = extend_num_tokens
|
self.extend_num_tokens = extend_num_tokens
|
||||||
self.prefix_lens_cpu = prefix_lens_cpu
|
|
||||||
|
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
||||||
|
self.prefix_lens_cpu.extend(
|
||||||
|
[
|
||||||
|
len(r.origin_input_ids) + len(r.output_ids) - 1
|
||||||
|
for r in running_batch.reqs
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.extend_lens_cpu.extend([1] * running_bs)
|
||||||
|
self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
|
||||||
|
|
||||||
def check_decode_mem(self):
|
def check_decode_mem(self):
|
||||||
bs = self.batch_size()
|
bs = self.batch_size()
|
||||||
@@ -685,6 +712,7 @@ class ScheduleBatch:
|
|||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
|
self.has_stream = any(req.stream for req in self.reqs)
|
||||||
|
|
||||||
self.sampling_info.filter(unfinished_indices, new_indices)
|
self.sampling_info.filter(unfinished_indices, new_indices)
|
||||||
|
|
||||||
@@ -695,7 +723,6 @@ class ScheduleBatch:
|
|||||||
self.sampling_info.merge(other.sampling_info)
|
self.sampling_info.merge(other.sampling_info)
|
||||||
|
|
||||||
self.reqs.extend(other.reqs)
|
self.reqs.extend(other.reqs)
|
||||||
|
|
||||||
self.req_pool_indices = torch.concat(
|
self.req_pool_indices = torch.concat(
|
||||||
[self.req_pool_indices, other.req_pool_indices]
|
[self.req_pool_indices, other.req_pool_indices]
|
||||||
)
|
)
|
||||||
@@ -706,3 +733,4 @@ class ScheduleBatch:
|
|||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
|
self.has_stream = any(req.stream for req in self.reqs)
|
||||||
|
|||||||
@@ -197,8 +197,6 @@ class TokenizerManager:
|
|||||||
if not_use_index
|
if not_use_index
|
||||||
else obj.logprob_start_len[index]
|
else obj.logprob_start_len[index]
|
||||||
)
|
)
|
||||||
if return_logprob and logprob_start_len == -1:
|
|
||||||
logprob_start_len = len(input_ids) - 1
|
|
||||||
top_logprobs_num = (
|
top_logprobs_num = (
|
||||||
obj.top_logprobs_num
|
obj.top_logprobs_num
|
||||||
if not_use_index
|
if not_use_index
|
||||||
@@ -251,8 +249,6 @@ class TokenizerManager:
|
|||||||
|
|
||||||
# Send to the controller
|
# Send to the controller
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
if return_logprob and logprob_start_len == -1:
|
|
||||||
logprob_start_len = len(input_ids) - 1
|
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
tokenized_obj = TokenizedGenerateReqInput(
|
||||||
rid,
|
rid,
|
||||||
input_text,
|
input_text,
|
||||||
@@ -349,8 +345,6 @@ class TokenizerManager:
|
|||||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
|
|
||||||
obj.logprob_start_len[index] = len(input_ids) - 1
|
|
||||||
pixel_values, image_hashes, image_sizes = (
|
pixel_values, image_hashes, image_sizes = (
|
||||||
await self._get_pixel_values(obj.image_data[index])
|
await self._get_pixel_values(obj.image_data[index])
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ class ModelTpServer:
|
|||||||
self.running_batch = None
|
self.running_batch = None
|
||||||
break
|
break
|
||||||
|
|
||||||
if self.out_pyobjs and self.running_batch.has_stream():
|
if self.out_pyobjs and self.running_batch.has_stream:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
@@ -360,9 +360,13 @@ class ModelTpServer:
|
|||||||
# Only when pixel values is not None we have modalities
|
# Only when pixel values is not None we have modalities
|
||||||
req.modalities = recv_req.modalites
|
req.modalities = recv_req.modalites
|
||||||
req.return_logprob = recv_req.return_logprob
|
req.return_logprob = recv_req.return_logprob
|
||||||
req.logprob_start_len = recv_req.logprob_start_len
|
|
||||||
req.top_logprobs_num = recv_req.top_logprobs_num
|
req.top_logprobs_num = recv_req.top_logprobs_num
|
||||||
req.stream = recv_req.stream
|
req.stream = recv_req.stream
|
||||||
|
req.logprob_start_len = recv_req.logprob_start_len
|
||||||
|
|
||||||
|
if req.logprob_start_len == -1:
|
||||||
|
# By default, only return the logprobs for output tokens
|
||||||
|
req.logprob_start_len = len(recv_req.input_ids) - 1
|
||||||
|
|
||||||
# Init regex FSM
|
# Init regex FSM
|
||||||
if (
|
if (
|
||||||
@@ -384,7 +388,7 @@ class ModelTpServer:
|
|||||||
|
|
||||||
# Truncate prompts that are too long
|
# Truncate prompts that are too long
|
||||||
if len(req.origin_input_ids) >= self.max_req_input_len:
|
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||||
logger.warn(
|
logger.warning(
|
||||||
"Request length is longer than the KV cache pool size or "
|
"Request length is longer than the KV cache pool size or "
|
||||||
"the max context length. Truncated!!!"
|
"the max context length. Truncated!!!"
|
||||||
)
|
)
|
||||||
@@ -583,7 +587,7 @@ class ModelTpServer:
|
|||||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
pt = 0
|
logprob_pt = 0
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
if req is not self.current_inflight_req:
|
if req is not self.current_inflight_req:
|
||||||
# Inflight reqs' prefill is not finished
|
# Inflight reqs' prefill is not finished
|
||||||
@@ -607,10 +611,9 @@ class ModelTpServer:
|
|||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
self.add_logprob_return_values(
|
logprob_pt += self.add_logprob_return_values(
|
||||||
i, req, pt, next_token_ids, logits_output
|
i, req, logprob_pt, next_token_ids, logits_output
|
||||||
)
|
)
|
||||||
pt += req.extend_input_len
|
|
||||||
else:
|
else:
|
||||||
assert batch.extend_num_tokens != 0
|
assert batch.extend_num_tokens != 0
|
||||||
logits_output = self.model_runner.forward(batch)
|
logits_output = self.model_runner.forward(batch)
|
||||||
@@ -638,48 +641,63 @@ class ModelTpServer:
|
|||||||
|
|
||||||
def add_logprob_return_values(
|
def add_logprob_return_values(
|
||||||
self,
|
self,
|
||||||
i,
|
i: int,
|
||||||
req: Req,
|
req: Req,
|
||||||
pt: int,
|
pt: int,
|
||||||
next_token_ids: List[int],
|
next_token_ids: List[int],
|
||||||
output: LogitsProcessorOutput,
|
output: LogitsProcessorOutput,
|
||||||
):
|
):
|
||||||
|
"""Attach logprobs to the return values."""
|
||||||
|
req.output_token_logprobs.append(
|
||||||
|
(output.next_token_logprobs[i], next_token_ids[i])
|
||||||
|
)
|
||||||
|
|
||||||
|
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||||
|
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
|
||||||
|
|
||||||
if req.normalized_prompt_logprob is None:
|
if req.normalized_prompt_logprob is None:
|
||||||
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
||||||
|
|
||||||
if req.input_token_logprobs is None:
|
if req.input_token_logprobs is None:
|
||||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
input_token_logprobs = output.input_token_logprobs[
|
||||||
req.input_token_logprobs = list(
|
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
|
||||||
zip(
|
]
|
||||||
output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
|
input_token_ids = req.fill_ids[
|
||||||
req.fill_ids[-req.extend_input_len + 1 :],
|
len(req.fill_ids)
|
||||||
)
|
- num_input_logprobs
|
||||||
)
|
+ 1 : len(req.fill_ids)
|
||||||
if req.logprob_start_len == 0:
|
- req.last_update_decode_tokens
|
||||||
|
]
|
||||||
|
req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
|
||||||
|
|
||||||
|
if (
|
||||||
|
req.logprob_start_len == 0
|
||||||
|
): # The first token does not have logprob, pad it.
|
||||||
req.input_token_logprobs = [
|
req.input_token_logprobs = [
|
||||||
(None, req.fill_ids[0])
|
(None, req.fill_ids[0])
|
||||||
] + req.input_token_logprobs
|
] + req.input_token_logprobs
|
||||||
|
|
||||||
if req.last_update_decode_tokens != 0:
|
if req.last_update_decode_tokens != 0:
|
||||||
|
# Some decode tokens are re-computed in an extend batch
|
||||||
req.output_token_logprobs.extend(
|
req.output_token_logprobs.extend(
|
||||||
list(
|
list(
|
||||||
zip(
|
zip(
|
||||||
output.input_token_logprobs[
|
output.input_token_logprobs[
|
||||||
pt
|
pt
|
||||||
+ req.extend_input_len
|
+ num_input_logprobs
|
||||||
|
- 1
|
||||||
- req.last_update_decode_tokens : pt
|
- req.last_update_decode_tokens : pt
|
||||||
+ req.extend_input_len
|
+ num_input_logprobs
|
||||||
- 1
|
- 1
|
||||||
],
|
],
|
||||||
req.fill_ids[-req.last_update_decode_tokens + 1 :],
|
req.fill_ids[
|
||||||
|
len(req.fill_ids)
|
||||||
|
- req.last_update_decode_tokens : len(req.fill_ids)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
req.output_token_logprobs.append(
|
|
||||||
(output.next_token_logprobs[i], next_token_ids[i])
|
|
||||||
)
|
|
||||||
|
|
||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
if req.input_top_logprobs is None:
|
if req.input_top_logprobs is None:
|
||||||
req.input_top_logprobs = output.input_top_logprobs[i]
|
req.input_top_logprobs = output.input_top_logprobs[i]
|
||||||
@@ -688,10 +706,12 @@ class ModelTpServer:
|
|||||||
|
|
||||||
if req.last_update_decode_tokens != 0:
|
if req.last_update_decode_tokens != 0:
|
||||||
req.output_top_logprobs.extend(
|
req.output_top_logprobs.extend(
|
||||||
output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
|
output.input_top_logprobs[i][-req.last_update_decode_tokens :]
|
||||||
)
|
)
|
||||||
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
||||||
|
|
||||||
|
return num_input_logprobs
|
||||||
|
|
||||||
def forward_decode_batch(self, batch: ScheduleBatch):
|
def forward_decode_batch(self, batch: ScheduleBatch):
|
||||||
# Check if decode out of memory
|
# Check if decode out of memory
|
||||||
if not batch.check_decode_mem():
|
if not batch.check_decode_mem():
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ class CudaGraphRunner:
|
|||||||
attn_backend=self.model_runner.attn_backend,
|
attn_backend=self.model_runner.attn_backend,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
top_logprobs_nums=0,
|
top_logprobs_nums=[0] * bs,
|
||||||
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
|
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
|
||||||
)
|
)
|
||||||
return forward(input_ids, input_metadata.positions, input_metadata)
|
return forward(input_ids, input_metadata.positions, input_metadata)
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ class InputMetadata:
|
|||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
top_logprobs_nums: List[int] = None
|
top_logprobs_nums: List[int] = None
|
||||||
extend_seq_lens_cpu: List[int] = None
|
extend_seq_lens_cpu: List[int] = None
|
||||||
logprob_start_lens_cpu: List[int] = None
|
extend_logprob_start_lens_cpu: List[int] = None
|
||||||
|
|
||||||
# For multimodal
|
# For multimodal
|
||||||
pixel_values: List[torch.Tensor] = None
|
pixel_values: List[torch.Tensor] = None
|
||||||
@@ -138,27 +138,13 @@ class InputMetadata:
|
|||||||
self.positions = self.positions.to(torch.int64)
|
self.positions = self.positions.to(torch.int64)
|
||||||
|
|
||||||
def compute_extend_infos(self, batch: ScheduleBatch):
|
def compute_extend_infos(self, batch: ScheduleBatch):
|
||||||
extend_lens_cpu = [
|
self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
|
||||||
len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs)
|
|
||||||
]
|
|
||||||
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
|
|
||||||
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||||
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
|
||||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||||
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
|
self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
|
||||||
|
self.extend_seq_lens_cpu = batch.extend_lens_cpu
|
||||||
self.extend_seq_lens_cpu = extend_lens_cpu
|
self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
|
||||||
self.logprob_start_lens_cpu = [
|
|
||||||
(
|
|
||||||
min(
|
|
||||||
req.logprob_start_len - batch.prefix_lens_cpu[i],
|
|
||||||
extend_lens_cpu[i] - 1,
|
|
||||||
)
|
|
||||||
if req.logprob_start_len >= batch.prefix_lens_cpu[i]
|
|
||||||
else extend_lens_cpu[i] - 1 # Fake extend, actually decode
|
|
||||||
)
|
|
||||||
for i, req in enumerate(batch.reqs)
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(
|
def from_schedule_batch(
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List
|
||||||
|
|
||||||
from fastapi import HTTPException, Request, UploadFile
|
from fastapi import HTTPException, Request, UploadFile
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
@@ -472,7 +472,7 @@ def v1_generate_request(
|
|||||||
first_prompt_type = type(all_requests[0].prompt)
|
first_prompt_type = type(all_requests[0].prompt)
|
||||||
for request in all_requests:
|
for request in all_requests:
|
||||||
assert (
|
assert (
|
||||||
type(request.prompt) == first_prompt_type
|
type(request.prompt) is first_prompt_type
|
||||||
), "All prompts must be of the same type in file input settings"
|
), "All prompts must be of the same type in file input settings"
|
||||||
if len(all_requests) > 1 and request.n > 1:
|
if len(all_requests) > 1 and request.n > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -887,7 +887,7 @@ def v1_chat_generate_request(
|
|||||||
input_ids.append(prompt_ids)
|
input_ids.append(prompt_ids)
|
||||||
return_logprobs.append(request.logprobs)
|
return_logprobs.append(request.logprobs)
|
||||||
logprob_start_lens.append(-1)
|
logprob_start_lens.append(-1)
|
||||||
top_logprobs_nums.append(request.top_logprobs)
|
top_logprobs_nums.append(request.top_logprobs or 0)
|
||||||
|
|
||||||
sampling_params = {
|
sampling_params = {
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
|
|||||||
@@ -86,24 +86,24 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||||
device = "cuda"
|
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
ret = cls(vocab_size=vocab_size)
|
ret = cls(vocab_size=vocab_size)
|
||||||
|
|
||||||
ret.temperatures = torch.tensor(
|
with torch.device("cuda"):
|
||||||
[r.sampling_params.temperature for r in reqs],
|
ret.temperatures = torch.tensor(
|
||||||
dtype=torch.float,
|
[r.sampling_params.temperature for r in reqs],
|
||||||
device=device,
|
dtype=torch.float,
|
||||||
).view(-1, 1)
|
).view(-1, 1)
|
||||||
ret.top_ps = torch.tensor(
|
ret.top_ps = torch.tensor(
|
||||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
|
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||||
)
|
)
|
||||||
ret.top_ks = torch.tensor(
|
ret.top_ks = torch.tensor(
|
||||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
[r.sampling_params.top_k for r in reqs], dtype=torch.int
|
||||||
)
|
)
|
||||||
ret.min_ps = torch.tensor(
|
ret.min_ps = torch.tensor(
|
||||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||||
)
|
)
|
||||||
|
|
||||||
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
|
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
|
||||||
|
|
||||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||||
@@ -116,7 +116,7 @@ class SamplingBatchInfo:
|
|||||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
device=device,
|
device="cuda",
|
||||||
Penalizers={
|
Penalizers={
|
||||||
penaltylib.BatchedFrequencyPenalizer,
|
penaltylib.BatchedFrequencyPenalizer,
|
||||||
penaltylib.BatchedMinNewTokensPenalizer,
|
penaltylib.BatchedMinNewTokensPenalizer,
|
||||||
|
|||||||
@@ -11,16 +11,18 @@ suites = {
|
|||||||
"test_chunked_prefill.py",
|
"test_chunked_prefill.py",
|
||||||
"test_embedding_openai_server.py",
|
"test_embedding_openai_server.py",
|
||||||
"test_eval_accuracy_mini.py",
|
"test_eval_accuracy_mini.py",
|
||||||
|
"test_json_constrained.py",
|
||||||
"test_large_max_new_tokens.py",
|
"test_large_max_new_tokens.py",
|
||||||
"test_openai_server.py",
|
"test_openai_server.py",
|
||||||
"test_json_constrained.py",
|
|
||||||
"test_skip_tokenizer_init.py",
|
|
||||||
"test_torch_compile.py",
|
|
||||||
"test_triton_attn_backend.py",
|
|
||||||
"test_pytorch_sampling_backend.py",
|
"test_pytorch_sampling_backend.py",
|
||||||
|
"test_server_args.py",
|
||||||
|
"test_skip_tokenizer_init.py",
|
||||||
|
"test_srt_endpoint.py",
|
||||||
|
"test_torch_compile.py",
|
||||||
|
"test_torchao.py",
|
||||||
|
"test_triton_attn_backend.py",
|
||||||
"test_update_weights.py",
|
"test_update_weights.py",
|
||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
"test_server_args.py",
|
|
||||||
],
|
],
|
||||||
"sampling/penaltylib": glob.glob(
|
"sampling/penaltylib": glob.glob(
|
||||||
"sampling/penaltylib/**/test_*.py", recursive=True
|
"sampling/penaltylib/**/test_*.py", recursive=True
|
||||||
|
|||||||
@@ -33,13 +33,13 @@ class TestChunkedPrefill(unittest.TestCase):
|
|||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
model=model,
|
model=model,
|
||||||
eval_name="mmlu",
|
eval_name="mmlu",
|
||||||
num_examples=32,
|
num_examples=64,
|
||||||
num_threads=32,
|
num_threads=32,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
assert metrics["score"] >= 0.6
|
assert metrics["score"] >= 0.65
|
||||||
finally:
|
finally:
|
||||||
kill_child_process(process.pid)
|
kill_child_process(process.pid)
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
cls.api_key = "sk-123456"
|
|
||||||
cls.json_schema = json.dumps(
|
cls.json_schema = json.dumps(
|
||||||
{
|
{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -28,16 +27,13 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
"required": ["name", "population"],
|
"required": ["name", "population"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
cls.process = popen_launch_server(
|
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
|
||||||
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid)
|
||||||
|
|
||||||
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={
|
json={
|
||||||
@@ -54,7 +50,6 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
"top_logprobs_num": top_logprobs_num,
|
"top_logprobs_num": top_logprobs_num,
|
||||||
"logprob_start_len": 0,
|
"logprob_start_len": 0,
|
||||||
},
|
},
|
||||||
headers=headers,
|
|
||||||
)
|
)
|
||||||
print(json.dumps(response.json()))
|
print(json.dumps(response.json()))
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
@@ -69,7 +64,7 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
self.run_decode()
|
self.run_decode()
|
||||||
|
|
||||||
def test_json_openai(self):
|
def test_json_openai(self):
|
||||||
client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
|
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
|||||||
@@ -75,11 +75,11 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
|
assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
|
||||||
ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
|
ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
|
||||||
|
|
||||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
|
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
|
||||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||||
|
|
||||||
assert ret_num_top_logprobs > 0
|
assert ret_num_top_logprobs > 0
|
||||||
assert response.choices[0].logprobs.token_logprobs[0] != None
|
|
||||||
|
assert response.choices[0].logprobs.token_logprobs[0]
|
||||||
|
|
||||||
assert response.id
|
assert response.id
|
||||||
assert response.created
|
assert response.created
|
||||||
@@ -143,7 +143,7 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
ret_num_top_logprobs = len(
|
ret_num_top_logprobs = len(
|
||||||
response.choices[0].logprobs.top_logprobs[0]
|
response.choices[0].logprobs.top_logprobs[0]
|
||||||
)
|
)
|
||||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
|
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
|
||||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||||
assert ret_num_top_logprobs > 0
|
assert ret_num_top_logprobs > 0
|
||||||
|
|
||||||
@@ -479,6 +479,22 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
assert isinstance(js_obj["name"], str)
|
assert isinstance(js_obj["name"], str)
|
||||||
assert isinstance(js_obj["population"], int)
|
assert isinstance(js_obj["population"], int)
|
||||||
|
|
||||||
|
def test_penalty(self):
    """Send a chat completion request with a frequency penalty.

    Issues one request with ``frequency_penalty=1.0`` at temperature 0
    and a 32-token limit, then checks that the first choice's message
    content is a string.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "Introduce the capital of France."},
    ]

    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    completion = client.chat.completions.create(
        model=self.model,
        messages=conversation,
        temperature=0,
        max_tokens=32,
        frequency_penalty=1.0,
    )

    # Only the type of the reply is checked, not its content.
    assert isinstance(completion.choices[0].message.content, str)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
"""
|
||||||
|
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
|
||||||
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -39,7 +43,7 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
"text": "The capital of France is",
|
"text": "The capital of France is",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0 if n == 1 else 0.5,
|
"temperature": 0 if n == 1 else 0.5,
|
||||||
"max_new_tokens": 32,
|
"max_new_tokens": 16,
|
||||||
"n": n,
|
"n": n,
|
||||||
},
|
},
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
@@ -56,7 +60,8 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
|
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
|
||||||
response_json.append(json.loads(line[6:]))
|
response_json.append(json.loads(line[6:]))
|
||||||
print(json.dumps(response_json))
|
|
||||||
|
print(json.dumps(response_json, indent=2))
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
|
|
||||||
def test_simple_decode(self):
|
def test_simple_decode(self):
|
||||||
@@ -69,13 +74,50 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
self.run_decode(n=3, stream=True)
|
self.run_decode(n=3, stream=True)
|
||||||
|
|
||||||
def test_logprob(self):
|
def test_logprob(self):
|
||||||
for top_logprobs_num in [0, 3]:
|
self.run_decode(
|
||||||
for return_text in [True, False]:
|
return_logprob=True,
|
||||||
self.run_decode(
|
top_logprobs_num=5,
|
||||||
return_logprob=True,
|
return_text=True,
|
||||||
top_logprobs_num=top_logprobs_num,
|
)
|
||||||
return_text=return_text,
|
|
||||||
)
|
def test_logprob_start_len(self):
    """Check the logprob fields of ``/generate`` when ``logprob_start_len`` is set.

    Posts two prompts in one request with ``logprob_start_len=4`` and
    ``max_new_tokens=4``, then verifies for every returned item that:

    * ``prompt_tokens`` equals ``logprob_start_len + 1`` plus the number
      of returned input-token logprobs;
    * the last elements of the input-token logprob entries, joined
      together, form a suffix of the prompt;
    * ``completion_tokens`` and the number of output-token logprobs both
      equal ``max_new_tokens``;
    * the returned text equals the joined last elements of the
      output-token logprob entries.
    """
    logprob_start_len = 4
    new_tokens = 4
    prompts = [
        "I have a very good idea on",
        "Today is a sunndy day and",
    ]

    response = requests.post(
        self.base_url + "/generate",
        json={
            "text": prompts,
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": new_tokens,
            },
            "return_logprob": True,
            "top_logprobs_num": 5,
            "return_text_in_logprobs": True,
            "logprob_start_len": logprob_start_len,
        },
    )
    response_json = response.json()
    print(json.dumps(response_json, indent=2))

    for i, res in enumerate(response_json):
        meta_info = res["meta_info"]

        # x[-1] is the last element of each logprob entry; presumably the
        # token text because return_text_in_logprobs is set (TODO confirm
        # against the server's response format).
        input_text = "".join(x[-1] for x in meta_info["input_token_logprobs"])
        output_text = "".join(x[-1] for x in meta_info["output_token_logprobs"])

        assert meta_info["prompt_tokens"] == logprob_start_len + 1 + len(
            meta_info["input_token_logprobs"]
        )
        assert prompts[i].endswith(input_text)

        assert meta_info["completion_tokens"] == new_tokens
        assert len(meta_info["output_token_logprobs"]) == new_tokens
        # This must be an assert: a bare `==` expression is evaluated and
        # discarded, so it would never fail the test.
        assert res["text"] == output_text
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class TestTorchCompile(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
assert metrics["score"] >= 0.65
|
assert metrics["score"] >= 0.60
|
||||||
|
|
||||||
def run_decode(self, max_new_tokens):
|
def run_decode(self, max_new_tokens):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
|
|||||||
@@ -127,7 +127,6 @@ class TestExtendAttention(unittest.TestCase):
|
|||||||
|
|
||||||
def _test_context_attention_once(self, head_dim):
|
def _test_context_attention_once(self, head_dim):
|
||||||
# Set up a simple test case
|
# Set up a simple test case
|
||||||
batch_size = 2
|
|
||||||
num_heads = 4
|
num_heads = 4
|
||||||
seq_lens = [8, 12]
|
seq_lens = [8, 12]
|
||||||
max_seq_len = max(seq_lens)
|
max_seq_len = max(seq_lens)
|
||||||
|
|||||||
Reference in New Issue
Block a user