Fix logprob in the overlapped mode (#1795)
This commit is contained in:
@@ -60,7 +60,7 @@ pip install "sglang[all]"
|
|||||||
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
||||||
```
|
```
|
||||||
|
|
||||||
**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
|
Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
|
||||||
|
|
||||||
### Method 2: From source
|
### Method 2: From source
|
||||||
```
|
```
|
||||||
@@ -75,7 +75,7 @@ pip install -e "python[all]"
|
|||||||
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
||||||
```
|
```
|
||||||
|
|
||||||
**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
|
Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
|
||||||
|
|
||||||
### Method 3: Using docker
|
### Method 3: Using docker
|
||||||
The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
|
The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ pip install "sglang[all]"
|
|||||||
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
||||||
```
|
```
|
||||||
|
|
||||||
**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
|
Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
|
||||||
|
|
||||||
### Method 2: From source
|
### Method 2: From source
|
||||||
```
|
```
|
||||||
@@ -26,7 +26,7 @@ pip install -e "python[all]"
|
|||||||
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
||||||
```
|
```
|
||||||
|
|
||||||
**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
|
Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
|
||||||
|
|
||||||
### Method 3: Using docker
|
### Method 3: Using docker
|
||||||
The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
|
The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
|
||||||
|
|||||||
@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
|
|||||||
# The logits of the next tokens. shape: [#seq, vocab_size]
|
# The logits of the next tokens. shape: [#seq, vocab_size]
|
||||||
next_token_logits: torch.Tensor
|
next_token_logits: torch.Tensor
|
||||||
# The logprobs of the next tokens. shape: [#seq, vocab_size]
|
# The logprobs of the next tokens. shape: [#seq, vocab_size]
|
||||||
next_token_logprobs: torch.Tensor
|
next_token_logprobs: torch.Tensor = None
|
||||||
|
|
||||||
# The normalized logprobs of prompts. shape: [#seq]
|
# The normalized logprobs of prompts. shape: [#seq]
|
||||||
normalized_prompt_logprobs: torch.Tensor
|
normalized_prompt_logprobs: torch.Tensor = None
|
||||||
# The logprobs of input tokens. shape: [#token, vocab_size]
|
# The logprobs of input tokens. shape: [#token, vocab_size]
|
||||||
input_token_logprobs: torch.Tensor
|
input_token_logprobs: torch.Tensor = None
|
||||||
|
|
||||||
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||||
input_top_logprobs: List
|
input_top_logprobs: List = None
|
||||||
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||||
output_top_logprobs: List
|
output_top_logprobs: List = None
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
|||||||
@@ -833,6 +833,7 @@ class Scheduler:
|
|||||||
|
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
|
logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
|
||||||
|
next_token_logprobs = logits_output.next_token_logprobs
|
||||||
else:
|
else:
|
||||||
# Move next_token_ids and logprobs to cpu
|
# Move next_token_ids and logprobs to cpu
|
||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
|
|||||||
@@ -103,6 +103,8 @@ class TpModelWorkerClient:
|
|||||||
while True:
|
while True:
|
||||||
self.has_inflight_batch = False
|
self.has_inflight_batch = False
|
||||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||||
|
if not model_worker_batch:
|
||||||
|
break
|
||||||
self.has_inflight_batch = True
|
self.has_inflight_batch = True
|
||||||
self.launch_event = threading.Event()
|
self.launch_event = threading.Event()
|
||||||
|
|
||||||
@@ -122,19 +124,48 @@ class TpModelWorkerClient:
|
|||||||
] = next_token_ids
|
] = next_token_ids
|
||||||
|
|
||||||
# Copy results to the CPU
|
# Copy results to the CPU
|
||||||
|
if model_worker_batch.return_logprob:
|
||||||
|
logits_output.next_token_logprobs = logits_output.next_token_logprobs[
|
||||||
|
torch.arange(len(next_token_ids), device=self.device),
|
||||||
|
next_token_ids,
|
||||||
|
].to("cpu", non_blocking=True)
|
||||||
|
if logits_output.input_token_logprobs is not None:
|
||||||
|
logits_output.input_token_logprobs = (
|
||||||
|
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
|
||||||
|
)
|
||||||
|
logits_output.normalized_prompt_logprobs = (
|
||||||
|
logits_output.normalized_prompt_logprobs.to(
|
||||||
|
"cpu", non_blocking=True
|
||||||
|
)
|
||||||
|
)
|
||||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||||
copy_event = torch.cuda.Event(blocking=True)
|
copy_event = torch.cuda.Event(blocking=True)
|
||||||
copy_event.record()
|
copy_event.record()
|
||||||
|
|
||||||
self.launch_event.set()
|
self.launch_event.set()
|
||||||
self.copy_queue.put((copy_event, next_token_ids))
|
self.copy_queue.put((copy_event, logits_output, next_token_ids))
|
||||||
|
|
||||||
def copy_thread_func(self):
|
def copy_thread_func(self):
|
||||||
while True:
|
while True:
|
||||||
copy_event, next_token_ids = self.copy_queue.get()
|
copy_event, logits_output, next_token_ids = self.copy_queue.get()
|
||||||
|
if not copy_event:
|
||||||
|
break
|
||||||
while not copy_event.query():
|
while not copy_event.query():
|
||||||
time.sleep(1e-5)
|
time.sleep(1e-5)
|
||||||
self.output_queue.put((None, next_token_ids.tolist()))
|
|
||||||
|
if logits_output.next_token_logprobs is not None:
|
||||||
|
logits_output.next_token_logprobs = (
|
||||||
|
logits_output.next_token_logprobs.tolist()
|
||||||
|
)
|
||||||
|
if logits_output.input_token_logprobs is not None:
|
||||||
|
logits_output.input_token_logprobs = (
|
||||||
|
logits_output.input_token_logprobs.tolist()
|
||||||
|
)
|
||||||
|
logits_output.normalized_prompt_logprobs = (
|
||||||
|
logits_output.normalized_prompt_logprobs.tolist()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.output_queue.put((logits_output, next_token_ids.tolist()))
|
||||||
|
|
||||||
def resulve_batch_result(self, bid: int):
|
def resulve_batch_result(self, bid: int):
|
||||||
logits_output, next_token_ids = self.output_queue.get()
|
logits_output, next_token_ids = self.output_queue.get()
|
||||||
@@ -172,3 +203,7 @@ class TpModelWorkerClient:
|
|||||||
recv_req.model_path, recv_req.load_format
|
recv_req.model_path, recv_req.load_format
|
||||||
)
|
)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def __delete__(self):
|
||||||
|
self.input_queue.put((None, None))
|
||||||
|
self.copy_queue.put((None, None, None))
|
||||||
|
|||||||
@@ -263,7 +263,8 @@ class CudaGraphRunner:
|
|||||||
positions=clamp_position(seq_lens),
|
positions=clamp_position(seq_lens),
|
||||||
mrope_positions=mrope_positions,
|
mrope_positions=mrope_positions,
|
||||||
)
|
)
|
||||||
return forward(input_ids, forward_batch.positions, forward_batch)
|
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
||||||
|
return logits_output.next_token_logits
|
||||||
|
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@@ -318,23 +319,16 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
# Replay
|
# Replay
|
||||||
self.graphs[bs].replay()
|
self.graphs[bs].replay()
|
||||||
logits_output = self.output_buffers[bs]
|
next_token_logits = self.output_buffers[bs][:raw_bs]
|
||||||
|
|
||||||
# Unpad
|
|
||||||
if bs != raw_bs:
|
|
||||||
logits_output = LogitsProcessorOutput(
|
|
||||||
next_token_logits=logits_output.next_token_logits[:raw_bs],
|
|
||||||
next_token_logprobs=None,
|
|
||||||
normalized_prompt_logprobs=None,
|
|
||||||
input_token_logprobs=None,
|
|
||||||
input_top_logprobs=None,
|
|
||||||
output_top_logprobs=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract logprobs
|
# Extract logprobs
|
||||||
if forward_batch.return_logprob:
|
if forward_batch.return_logprob:
|
||||||
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
|
next_token_logprobs = torch.nn.functional.log_softmax(
|
||||||
logits_output.next_token_logits, dim=-1
|
next_token_logits, dim=-1
|
||||||
|
)
|
||||||
|
logits_output = LogitsProcessorOutput(
|
||||||
|
next_token_logits=next_token_logits,
|
||||||
|
next_token_logprobs=next_token_logprobs,
|
||||||
)
|
)
|
||||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||||
if return_top_logprob:
|
if return_top_logprob:
|
||||||
@@ -343,7 +337,11 @@ class CudaGraphRunner:
|
|||||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||||
)
|
)
|
||||||
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
||||||
logits_output.next_token_logprobs, logits_metadata
|
next_token_logprobs, logits_metadata
|
||||||
)[1]
|
)[1]
|
||||||
|
else:
|
||||||
|
logits_output = LogitsProcessorOutput(
|
||||||
|
next_token_logits=next_token_logits,
|
||||||
|
)
|
||||||
|
|
||||||
return logits_output
|
return logits_output
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ suites = {
|
|||||||
"test_openai_server.py",
|
"test_openai_server.py",
|
||||||
"test_overlap_schedule.py",
|
"test_overlap_schedule.py",
|
||||||
"test_pytorch_sampling_backend.py",
|
"test_pytorch_sampling_backend.py",
|
||||||
"test_radix_attention.py",
|
|
||||||
"test_retract_decode.py",
|
"test_retract_decode.py",
|
||||||
"test_server_args.py",
|
"test_server_args.py",
|
||||||
"test_skip_tokenizer_init.py",
|
"test_skip_tokenizer_init.py",
|
||||||
|
|||||||
Reference in New Issue
Block a user