Fix a draft model accuracy bug in eagle; support step=1; return logprob in eagle (#4134)
Co-authored-by: Sehoon Kim <kssteven418@gmail.com> Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
2
.github/workflows/pr-test.yml
vendored
2
.github/workflows/pr-test.yml
vendored
@@ -95,7 +95,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100]
|
range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-48, 48-100]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|||||||
@@ -7,16 +7,14 @@ FlashInfer is faster and Triton is easier to customize.
|
|||||||
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
@@ -37,7 +35,7 @@ if is_flashinfer_available():
|
|||||||
BatchPrefillWithRaggedKVCacheWrapper,
|
BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
)
|
)
|
||||||
from flashinfer.cascade import merge_state
|
from flashinfer.cascade import merge_state
|
||||||
from flashinfer.decode import PosEncodingMode
|
from flashinfer.decode import _get_range_buf, get_seq_lens
|
||||||
|
|
||||||
|
|
||||||
class WrapperDispatch(Enum):
|
class WrapperDispatch(Enum):
|
||||||
@@ -73,8 +71,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.is_multimodal = model_runner.model_config.is_multimodal
|
|
||||||
|
|
||||||
# Parse constants
|
# Parse constants
|
||||||
self.decode_use_tensor_cores = should_use_tensor_core(
|
self.decode_use_tensor_cores = should_use_tensor_core(
|
||||||
kv_cache_dtype=model_runner.kv_cache_dtype,
|
kv_cache_dtype=model_runner.kv_cache_dtype,
|
||||||
@@ -86,6 +82,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
self.skip_prefill = skip_prefill
|
self.skip_prefill = skip_prefill
|
||||||
|
self.is_multimodal = model_runner.model_config.is_multimodal
|
||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
model_runner.sliding_window_size is not None
|
model_runner.sliding_window_size is not None
|
||||||
@@ -115,7 +112,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
device=model_runner.device,
|
device=model_runner.device,
|
||||||
)
|
)
|
||||||
self.workspace_buffer = global_workspace_buffer
|
self.workspace_buffer = global_workspace_buffer
|
||||||
|
|
||||||
max_bs = model_runner.req_to_token_pool.size
|
max_bs = model_runner.req_to_token_pool.size
|
||||||
if kv_indptr_buf is None:
|
if kv_indptr_buf is None:
|
||||||
self.kv_indptr = [
|
self.kv_indptr = [
|
||||||
@@ -163,9 +159,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.prefill_wrappers_verify.append(
|
self.prefill_wrappers_verify.append(
|
||||||
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
self.workspace_buffer,
|
||||||
|
"NHD",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.decode_wrappers.append(
|
self.decode_wrappers.append(
|
||||||
BatchDecodeWithPagedKVCacheWrapper(
|
BatchDecodeWithPagedKVCacheWrapper(
|
||||||
self.workspace_buffer,
|
self.workspace_buffer,
|
||||||
@@ -178,13 +176,14 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
if not skip_prefill:
|
if not skip_prefill:
|
||||||
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
|
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
|
||||||
model_runner, self
|
model_runner, self
|
||||||
)
|
) # for verify
|
||||||
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
|
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
|
||||||
|
|
||||||
# Other metadata
|
# Other metadata
|
||||||
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
||||||
self.decode_cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
self.prefill_cuda_graph_metadata = {}
|
self.prefill_cuda_graph_metadata = {} # For verify
|
||||||
|
self.draft_extend_cuda_graph_metadata = {} # For draft extend
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
@@ -300,7 +299,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
seq_lens_sum = seq_lens.sum().item()
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
@@ -312,6 +310,10 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
self.decode_cuda_graph_metadata[bs] = decode_wrappers
|
self.decode_cuda_graph_metadata[bs] = decode_wrappers
|
||||||
self.forward_metadata = DecodeMetadata(decode_wrappers)
|
self.forward_metadata = DecodeMetadata(decode_wrappers)
|
||||||
|
for i in range(self.num_wrappers):
|
||||||
|
decode_wrappers[i].begin_forward = partial(
|
||||||
|
fast_decode_plan, decode_wrappers[i]
|
||||||
|
)
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
prefill_wrappers = []
|
prefill_wrappers = []
|
||||||
for i in range(self.num_wrappers):
|
for i in range(self.num_wrappers):
|
||||||
@@ -437,7 +439,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||||
causal=False,
|
causal=False,
|
||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
logits_soft_cap=layer.logit_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
)
|
)
|
||||||
|
|
||||||
o, _ = merge_state(o1, s1, o2, s2)
|
o, _ = merge_state(o1, s1, o2, s2)
|
||||||
@@ -636,9 +638,15 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
bs = len(req_pool_indices)
|
bs = len(req_pool_indices)
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
kv_indices = torch.empty(
|
|
||||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
if wrapper.is_cuda_graph_enabled:
|
||||||
)
|
# Directly write to the cuda graph input buffer
|
||||||
|
kv_indices = wrapper._paged_kv_indices_buf
|
||||||
|
else:
|
||||||
|
kv_indices = torch.empty(
|
||||||
|
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
create_flashinfer_kv_indices_triton[(bs,)](
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
@@ -649,9 +657,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self.req_to_token.shape[1],
|
self.req_to_token.shape[1],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert isinstance(spec_info, EagleDraftInput)
|
|
||||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
bs = kv_indptr.shape[0] - 1
|
bs = kv_indptr.shape[0] - 1
|
||||||
|
|
||||||
wrapper.begin_forward(
|
wrapper.begin_forward(
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
@@ -699,7 +707,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tnesor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
@@ -713,7 +721,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
|
|
||||||
def update_single_wrapper(
|
def update_single_wrapper(
|
||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tnesor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
@@ -858,7 +866,6 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
self.req_to_token.shape[1],
|
self.req_to_token.shape[1],
|
||||||
)
|
)
|
||||||
|
|
||||||
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||||
qo_indptr = qo_indptr[: bs + 1]
|
qo_indptr = qo_indptr[: bs + 1]
|
||||||
custom_mask = None
|
custom_mask = None
|
||||||
@@ -954,7 +961,10 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||||
|
|
||||||
def common_template(
|
def common_template(
|
||||||
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
|
self,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
kv_indices_buffer: torch.Tensor,
|
||||||
|
call_fn: Callable,
|
||||||
):
|
):
|
||||||
num_seqs = forward_batch.batch_size
|
num_seqs = forward_batch.batch_size
|
||||||
bs = self.topk * num_seqs
|
bs = self.topk * num_seqs
|
||||||
@@ -1042,17 +1052,15 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
spec_info=forward_batch.spec_info,
|
spec_info=forward_batch.spec_info,
|
||||||
)
|
)
|
||||||
decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
|
|
||||||
forward_batch.batch_size
|
|
||||||
][0]
|
|
||||||
decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
|
|
||||||
|
|
||||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self, forward_batch: ForwardBatch, bs: int
|
||||||
|
):
|
||||||
def call_fn(i, forward_batch):
|
def call_fn(i, forward_batch):
|
||||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||||
forward_batch.batch_size,
|
bs,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
seq_lens_sum=-1,
|
seq_lens_sum=-1,
|
||||||
@@ -1113,6 +1121,11 @@ def should_use_tensor_core(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Use as a fast path to override the indptr in flashinfer's plan function
|
||||||
|
# This is used to remove some host-to-device copy overhead.
|
||||||
|
global_override_indptr_cpu = None
|
||||||
|
|
||||||
|
|
||||||
def fast_decode_plan(
|
def fast_decode_plan(
|
||||||
self,
|
self,
|
||||||
indptr: torch.Tensor,
|
indptr: torch.Tensor,
|
||||||
@@ -1142,6 +1155,9 @@ def fast_decode_plan(
|
|||||||
if logits_soft_cap is None:
|
if logits_soft_cap is None:
|
||||||
logits_soft_cap = 0.0
|
logits_soft_cap = 0.0
|
||||||
|
|
||||||
|
if self.use_tensor_cores:
|
||||||
|
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
|
||||||
|
|
||||||
if self.is_cuda_graph_enabled:
|
if self.is_cuda_graph_enabled:
|
||||||
if batch_size != self._fixed_batch_size:
|
if batch_size != self._fixed_batch_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1154,7 +1170,7 @@ def fast_decode_plan(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The size of indices should be less than or equal to the allocated buffer"
|
"The size of indices should be less than or equal to the allocated buffer"
|
||||||
)
|
)
|
||||||
# Skip these copies
|
# Skip these copies because we directly write to them during prepartion
|
||||||
# self._paged_kv_indptr_buf.copy_(indptr)
|
# self._paged_kv_indptr_buf.copy_(indptr)
|
||||||
# self._paged_kv_indices_buf[: len(indices)] = indices
|
# self._paged_kv_indices_buf[: len(indices)] = indices
|
||||||
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
|
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
|
||||||
@@ -1162,6 +1178,7 @@ def fast_decode_plan(
|
|||||||
self._paged_kv_indptr_buf = indptr
|
self._paged_kv_indptr_buf = indptr
|
||||||
self._paged_kv_indices_buf = indices
|
self._paged_kv_indices_buf = indices
|
||||||
self._paged_kv_last_page_len_buf = last_page_len
|
self._paged_kv_last_page_len_buf = last_page_len
|
||||||
|
self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
|
||||||
|
|
||||||
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
|
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
|
||||||
if not q_data_type:
|
if not q_data_type:
|
||||||
@@ -1184,27 +1201,55 @@ def fast_decode_plan(
|
|||||||
)
|
)
|
||||||
self.last_page_len = torch.ones(32768, dtype=torch.int32)
|
self.last_page_len = torch.ones(32768, dtype=torch.int32)
|
||||||
|
|
||||||
empty_q_data = self.empty_q_data
|
indptr_host = (
|
||||||
empty_kv_cache = self.empty_kv_cache
|
global_override_indptr_cpu
|
||||||
stream = torch.cuda.current_stream()
|
if global_override_indptr_cpu is not None
|
||||||
self._cached_module.plan(
|
else indptr.cpu()
|
||||||
self._float_workspace_buffer,
|
|
||||||
self._int_workspace_buffer,
|
|
||||||
self._pin_memory_int_workspace_buffer,
|
|
||||||
indptr.to("cpu"),
|
|
||||||
batch_size,
|
|
||||||
num_qo_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
page_size,
|
|
||||||
self.is_cuda_graph_enabled,
|
|
||||||
window_left,
|
|
||||||
logits_soft_cap,
|
|
||||||
head_dim,
|
|
||||||
head_dim,
|
|
||||||
empty_q_data,
|
|
||||||
empty_kv_cache,
|
|
||||||
stream.cuda_stream,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.use_tensor_cores:
|
||||||
|
kv_lens_arr_host = get_seq_lens(
|
||||||
|
indptr_host, self.last_page_len[:batch_size], page_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self._plan_info = self._cached_module.plan(
|
||||||
|
self._float_workspace_buffer,
|
||||||
|
self._int_workspace_buffer,
|
||||||
|
self._pin_memory_int_workspace_buffer,
|
||||||
|
qo_indptr_host,
|
||||||
|
indptr_host,
|
||||||
|
kv_lens_arr_host,
|
||||||
|
batch_size, # total_num_rows
|
||||||
|
batch_size,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
page_size,
|
||||||
|
self.is_cuda_graph_enabled,
|
||||||
|
head_dim,
|
||||||
|
head_dim,
|
||||||
|
False, # causal
|
||||||
|
torch.cuda.current_stream().cuda_stream,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._plan_info = self._cached_module.plan(
|
||||||
|
self._float_workspace_buffer,
|
||||||
|
self._int_workspace_buffer,
|
||||||
|
self._pin_memory_int_workspace_buffer,
|
||||||
|
indptr_host,
|
||||||
|
batch_size,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
page_size,
|
||||||
|
self.is_cuda_graph_enabled,
|
||||||
|
window_left,
|
||||||
|
logits_soft_cap,
|
||||||
|
head_dim,
|
||||||
|
head_dim,
|
||||||
|
self.empty_q_data,
|
||||||
|
self.empty_kv_cache,
|
||||||
|
torch.cuda.current_stream().cuda_stream,
|
||||||
|
)
|
||||||
|
|
||||||
self._pos_encoding_mode = pos_encoding_mode
|
self._pos_encoding_mode = pos_encoding_mode
|
||||||
self._window_left = window_left
|
self._window_left = window_left
|
||||||
self._logits_soft_cap = logits_soft_cap
|
self._logits_soft_cap = logits_soft_cap
|
||||||
|
|||||||
@@ -578,10 +578,12 @@ class TritonMultiStepDraftBackend:
|
|||||||
|
|
||||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self, forward_batch: ForwardBatch, bs: int
|
||||||
|
):
|
||||||
def call_fn(i, forward_batch):
|
def call_fn(i, forward_batch):
|
||||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||||
forward_batch.batch_size,
|
bs,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
seq_lens_sum=-1,
|
seq_lens_sum=-1,
|
||||||
|
|||||||
@@ -396,16 +396,10 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
run_once()
|
run_once()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.model_runner.tp_group.barrier()
|
|
||||||
|
|
||||||
global global_graph_memory_pool
|
global global_graph_memory_pool
|
||||||
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
|
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
|
||||||
out = run_once()
|
out = run_once()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.model_runner.tp_group.barrier()
|
|
||||||
|
|
||||||
global_graph_memory_pool = graph.pool()
|
global_graph_memory_pool = graph.pool()
|
||||||
return graph, out
|
return graph, out
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,12 @@ def build_tree_kernel_efficient_preprocess(
|
|||||||
|
|
||||||
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
|
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
|
||||||
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
|
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
|
||||||
parent_list = torch.cat(parents_list[:-1], dim=1)
|
|
||||||
|
if len(parents_list) > 1:
|
||||||
|
parent_list = torch.cat(parents_list[:-1], dim=1)
|
||||||
|
else:
|
||||||
|
batch_size = parents_list[0].shape[0]
|
||||||
|
parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)
|
||||||
|
|
||||||
return parent_list, top_scores_index, draft_tokens
|
return parent_list, top_scores_index, draft_tokens
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
import time
|
|
||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -162,20 +161,11 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
|
|
||||||
run_once()
|
run_once()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.model_runner.tp_group.barrier()
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.model_runner.tp_group.barrier()
|
|
||||||
|
|
||||||
with torch.cuda.graph(
|
with torch.cuda.graph(
|
||||||
graph, pool=get_global_graph_memory_pool(), stream=stream
|
graph, pool=get_global_graph_memory_pool(), stream=stream
|
||||||
):
|
):
|
||||||
out = run_once()
|
out = run_once()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.model_runner.tp_group.barrier()
|
|
||||||
|
|
||||||
set_global_graph_memory_pool(graph.pool())
|
set_global_graph_memory_pool(graph.pool())
|
||||||
return graph, out
|
return graph, out
|
||||||
|
|
||||||
@@ -204,7 +194,7 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
|
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
forward_batch
|
forward_batch, forward_batch.batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replay
|
# Replay
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -62,6 +62,7 @@ class EagleDraftInput:
|
|||||||
batch.input_ids[pt : pt + extend_len] = torch.concat(
|
batch.input_ids[pt : pt + extend_len] = torch.concat(
|
||||||
(input_ids[1:], self.verified_id[i].reshape(1))
|
(input_ids[1:], self.verified_id[i].reshape(1))
|
||||||
)
|
)
|
||||||
|
pt += extend_len
|
||||||
|
|
||||||
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
|
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
|
||||||
assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
|
assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
|
||||||
|
|||||||
@@ -1,20 +1,19 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
CaptureHiddenMode,
|
CaptureHiddenMode,
|
||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
||||||
EAGLEDraftCudaGraphRunner,
|
EAGLEDraftCudaGraphRunner,
|
||||||
@@ -27,7 +26,6 @@ from sglang.srt.speculative.eagle_utils import (
|
|||||||
fast_topk,
|
fast_topk,
|
||||||
select_top_k_tokens,
|
select_top_k_tokens,
|
||||||
)
|
)
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
|
||||||
from sglang.srt.utils import get_available_gpu_memory
|
from sglang.srt.utils import get_available_gpu_memory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -44,16 +42,30 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
target_worker: TpModelWorker,
|
target_worker: TpModelWorker,
|
||||||
):
|
):
|
||||||
|
# Parse arguments
|
||||||
|
self.server_args = server_args
|
||||||
|
self.topk = server_args.speculative_eagle_topk
|
||||||
|
self.speculative_num_steps = server_args.speculative_num_steps
|
||||||
|
self.padded_static_len = self.speculative_num_steps + 1
|
||||||
|
self.enable_nan_detection = server_args.enable_nan_detection
|
||||||
|
self.gpu_id = gpu_id
|
||||||
|
self.device = server_args.device
|
||||||
|
self.target_worker = target_worker
|
||||||
|
|
||||||
# Override context length with target model's context length
|
# Override context length with target model's context length
|
||||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||||
os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
|
|
||||||
|
|
||||||
# Do not capture cuda graph in `super().__init__()`
|
# Do not capture cuda graph in `super().__init__()`
|
||||||
# We will capture it later
|
# It will be captured later.
|
||||||
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
||||||
server_args.disable_cuda_graph = True
|
server_args.disable_cuda_graph = True
|
||||||
|
# Share the allocator with a target worker.
|
||||||
|
# Draft and target worker own their own KV cache pools.
|
||||||
|
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
||||||
|
target_worker.get_memory_pool()
|
||||||
|
)
|
||||||
|
|
||||||
# Lossy optimization by using hot tokens
|
# Load hot token ids
|
||||||
if server_args.speculative_token_map is not None:
|
if server_args.speculative_token_map is not None:
|
||||||
self.hot_token_id = load_token_map(server_args.speculative_token_map)
|
self.hot_token_id = load_token_map(server_args.speculative_token_map)
|
||||||
server_args.json_model_override_args = (
|
server_args.json_model_override_args = (
|
||||||
@@ -62,13 +74,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
else:
|
else:
|
||||||
self.hot_token_id = None
|
self.hot_token_id = None
|
||||||
|
|
||||||
# We share the allocator with a target worker. Draft/target worker
|
# Init draft worker
|
||||||
# owns its own KV cache.
|
|
||||||
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
|
||||||
target_worker.get_memory_pool()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Init target worker
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
@@ -79,18 +85,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
req_to_token_pool=self.req_to_token_pool,
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
self.target_worker = target_worker
|
|
||||||
|
|
||||||
# Parse arguments
|
|
||||||
self.topk = server_args.speculative_eagle_topk
|
|
||||||
self.speculative_num_steps = server_args.speculative_num_steps
|
|
||||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
|
||||||
server_args.speculative_algorithm
|
|
||||||
)
|
|
||||||
self.server_args = server_args
|
|
||||||
self.use_nan_detection = self.server_args.enable_nan_detection
|
|
||||||
self.device = self.model_runner.device
|
|
||||||
self.gpu_id = self.model_runner.gpu_id
|
|
||||||
|
|
||||||
# Share the embedding and lm_head
|
# Share the embedding and lm_head
|
||||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||||
@@ -103,8 +97,12 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
backup_disable_cuda_graph
|
backup_disable_cuda_graph
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.init_attention_backend()
|
||||||
|
self.init_cuda_graphs()
|
||||||
|
|
||||||
|
def init_attention_backend(self):
|
||||||
# Create multi-step attn backends and cuda graph runners
|
# Create multi-step attn backends and cuda graph runners
|
||||||
if server_args.attention_backend == "flashinfer":
|
if self.server_args.attention_backend == "flashinfer":
|
||||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||||
FlashInferMultiStepDraftBackend,
|
FlashInferMultiStepDraftBackend,
|
||||||
)
|
)
|
||||||
@@ -114,7 +112,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.topk,
|
self.topk,
|
||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
)
|
)
|
||||||
elif server_args.attention_backend == "triton":
|
elif self.server_args.attention_backend == "triton":
|
||||||
from sglang.srt.layers.attention.triton_backend import (
|
from sglang.srt.layers.attention.triton_backend import (
|
||||||
TritonMultiStepDraftBackend,
|
TritonMultiStepDraftBackend,
|
||||||
)
|
)
|
||||||
@@ -126,11 +124,9 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
|
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
|
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
|
||||||
self.init_cuda_graphs()
|
|
||||||
|
|
||||||
def init_cuda_graphs(self):
|
def init_cuda_graphs(self):
|
||||||
"""Capture cuda graphs."""
|
"""Capture cuda graphs."""
|
||||||
@@ -356,6 +352,41 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.forward_mode = ForwardMode.DECODE
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
batch.spec_info = res.draft_input
|
batch.spec_info = res.draft_input
|
||||||
|
|
||||||
|
if batch.return_logprob:
|
||||||
|
# Compute output logprobs using the sampler.
|
||||||
|
num_tokens_per_req = [
|
||||||
|
accept + 1 for accept in res.accept_length_per_req_cpu
|
||||||
|
]
|
||||||
|
self.target_worker.model_runner.update_output_logprobs(
|
||||||
|
logits_output,
|
||||||
|
batch.sampling_info,
|
||||||
|
batch.top_logprobs_nums,
|
||||||
|
batch.token_ids_logprobs,
|
||||||
|
res.verified_id,
|
||||||
|
# +1 for bonus token.
|
||||||
|
num_tokens_per_req=num_tokens_per_req,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add output logprobs to the request.
|
||||||
|
pt = 0
|
||||||
|
# NOTE: tolist() of these values are skipped when output is processed
|
||||||
|
next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
|
||||||
|
verified_ids = res.verified_id.tolist()
|
||||||
|
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
|
||||||
|
for _ in range(num_tokens):
|
||||||
|
if req.return_logprob:
|
||||||
|
token_id = verified_ids[pt]
|
||||||
|
req.output_token_logprobs_val.append(next_token_logprobs[pt])
|
||||||
|
req.output_token_logprobs_idx.append(token_id)
|
||||||
|
if req.top_logprobs_num > 0:
|
||||||
|
req.output_top_logprobs_val.append(
|
||||||
|
res.logits_output.next_token_top_logprobs_val[pt]
|
||||||
|
)
|
||||||
|
req.output_top_logprobs_idx.append(
|
||||||
|
res.logits_output.next_token_top_logprobs_idx[pt]
|
||||||
|
)
|
||||||
|
pt += 1
|
||||||
|
|
||||||
return logits_output, res, model_worker_batch
|
return logits_output, res, model_worker_batch
|
||||||
|
|
||||||
def forward_draft_extend(
|
def forward_draft_extend(
|
||||||
@@ -381,6 +412,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
forward_batch = ForwardBatch.init_new(
|
forward_batch = ForwardBatch.init_new(
|
||||||
model_worker_batch, self.draft_model_runner
|
model_worker_batch, self.draft_model_runner
|
||||||
)
|
)
|
||||||
|
forward_batch.return_logprob = False
|
||||||
logits_output = self.draft_model_runner.forward(forward_batch)
|
logits_output = self.draft_model_runner.forward(forward_batch)
|
||||||
self._detect_nan_if_needed(logits_output)
|
self._detect_nan_if_needed(logits_output)
|
||||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||||
@@ -393,6 +425,8 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
|
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
|
||||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
# We don't need logprob for this extend.
|
# We don't need logprob for this extend.
|
||||||
|
original_return_logprob = batch.return_logprob
|
||||||
|
batch.return_logprob = False
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(
|
forward_batch = ForwardBatch.init_new(
|
||||||
model_worker_batch, self.draft_model_runner
|
model_worker_batch, self.draft_model_runner
|
||||||
@@ -404,6 +438,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
|
|
||||||
# Restore backup.
|
# Restore backup.
|
||||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||||
|
batch.return_logprob = original_return_logprob
|
||||||
batch.forward_mode = ForwardMode.DECODE
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
batch.seq_lens = seq_lens_backup
|
batch.seq_lens = seq_lens_backup
|
||||||
|
|
||||||
@@ -415,7 +450,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
draft_input.hidden_states = logits_output.hidden_states
|
draft_input.hidden_states = logits_output.hidden_states
|
||||||
|
|
||||||
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
|
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
|
||||||
if self.use_nan_detection:
|
if self.enable_nan_detection:
|
||||||
logits = logits_output.next_token_logits
|
logits = logits_output.next_token_logits
|
||||||
if torch.any(torch.isnan(logits)):
|
if torch.any(torch.isnan(logits)):
|
||||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ class TestBenchServing(unittest.TestCase):
|
|||||||
f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
|
f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
|
||||||
f'accept_length : {res["accept_length"]:.2f} \n'
|
f'accept_length : {res["accept_length"]:.2f} \n'
|
||||||
)
|
)
|
||||||
self.assertLess(res["median_e2e_latency_ms"], 1100)
|
self.assertLess(res["median_e2e_latency_ms"], 900)
|
||||||
self.assertGreater(res["accept_length"], 2.99)
|
self.assertGreater(res["accept_length"], 2.99)
|
||||||
|
|
||||||
def test_moe_offline_throughput_default(self):
|
def test_moe_offline_throughput_default(self):
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
|
import json
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -21,6 +25,7 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
|
run_logprob_check,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_dtype = torch.float16
|
torch_dtype = torch.float16
|
||||||
@@ -260,11 +265,132 @@ class TestEAGLEServer(unittest.TestCase):
|
|||||||
server_info = requests.get(self.base_url + "/get_server_info")
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
self.assertGreater(avg_spec_accept_length, 2.9)
|
self.assertGreater(avg_spec_accept_length, 3.5)
|
||||||
|
|
||||||
# Wait a little bit so that the memory check happens.
|
# Wait a little bit so that the memory check happens.
|
||||||
time.sleep(4)
|
time.sleep(4)
|
||||||
|
|
||||||
|
def test_logprob_start_len(self):
|
||||||
|
logprob_start_len = 4
|
||||||
|
new_tokens = 4
|
||||||
|
prompts = [
|
||||||
|
"I have a very good idea on",
|
||||||
|
"Today is a sunndy day and",
|
||||||
|
]
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": prompts,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": new_tokens,
|
||||||
|
},
|
||||||
|
"return_logprob": True,
|
||||||
|
"top_logprobs_num": 5,
|
||||||
|
"logprob_start_len": logprob_start_len,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
print(json.dumps(response_json, indent=2))
|
||||||
|
|
||||||
|
for res in response_json:
|
||||||
|
self.assertEqual(
|
||||||
|
res["meta_info"]["prompt_tokens"],
|
||||||
|
logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
|
||||||
|
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
|
||||||
|
|
||||||
|
def test_logprob_match(self):
|
||||||
|
"""Test the output logprobs are close to the input logprobs if we run a prefill again."""
|
||||||
|
|
||||||
|
def run_generate(
|
||||||
|
prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1
|
||||||
|
):
|
||||||
|
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt_kwargs = {"text": prompt}
|
||||||
|
else:
|
||||||
|
prompt_kwargs = {"input_ids": prompt}
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
**prompt_kwargs,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 1.0,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"ignore_eos": True,
|
||||||
|
},
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"return_text_in_logprobs": True,
|
||||||
|
"logprob_start_len": logprob_start_len,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
prompt = "I have a very good idea on how to"
|
||||||
|
|
||||||
|
gen = run_generate(prompt, return_logprob=True, logprob_start_len=0)
|
||||||
|
output_logprobs = np.array(
|
||||||
|
[x[0] for x in gen["meta_info"]["output_token_logprobs"]]
|
||||||
|
)
|
||||||
|
num_prompts_tokens = gen["meta_info"]["prompt_tokens"]
|
||||||
|
|
||||||
|
input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]]
|
||||||
|
output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]]
|
||||||
|
|
||||||
|
new_prompt = input_tokens + output_tokens
|
||||||
|
score = run_generate(
|
||||||
|
new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0
|
||||||
|
)
|
||||||
|
output_logprobs_score = np.array(
|
||||||
|
[
|
||||||
|
x[0]
|
||||||
|
for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"{output_logprobs[-10:]=}")
|
||||||
|
print(f"{output_logprobs_score[-10:]=}")
|
||||||
|
|
||||||
|
diff = np.abs(output_logprobs - output_logprobs_score)
|
||||||
|
max_diff = np.max(diff)
|
||||||
|
self.assertLess(max_diff, 0.25)
|
||||||
|
|
||||||
|
def test_logprob_mixed(self):
|
||||||
|
args = []
|
||||||
|
temperature = 0
|
||||||
|
# input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
|
||||||
|
# Llama 2 context length seems to be only 2k, so we can only test small length.
|
||||||
|
for input_len in [200, 500, 1000, 2000]:
|
||||||
|
for output_len in [4, 8]:
|
||||||
|
for logprob_start_len in [0, 100, 300, 800, 1998]:
|
||||||
|
for return_logprob in [True, False]:
|
||||||
|
for top_logprobs_num in [0, 5]:
|
||||||
|
|
||||||
|
if logprob_start_len >= input_len:
|
||||||
|
continue
|
||||||
|
|
||||||
|
args.append(
|
||||||
|
(
|
||||||
|
input_len,
|
||||||
|
output_len,
|
||||||
|
temperature,
|
||||||
|
logprob_start_len,
|
||||||
|
return_logprob,
|
||||||
|
top_logprobs_num,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
random.shuffle(args)
|
||||||
|
|
||||||
|
func = partial(run_logprob_check, self)
|
||||||
|
with ThreadPoolExecutor(8) as executor:
|
||||||
|
list(executor.map(func, args))
|
||||||
|
|
||||||
|
|
||||||
class TestEAGLERetract(TestEAGLEServer):
|
class TestEAGLERetract(TestEAGLEServer):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -143,11 +143,11 @@ class TestGPTQModelDynamic(unittest.TestCase):
|
|||||||
|
|
||||||
print(f"result = `{result}`")
|
print(f"result = `{result}`")
|
||||||
|
|
||||||
assert "paris" in result["text"].lower()
|
self.assertIn("paris", result["text"].lower())
|
||||||
|
|
||||||
throughput = max_tokens / (tok - tic)
|
throughput = max_tokens / (tok - tic)
|
||||||
print(f"Throughput: {throughput} tokens/s")
|
print(f"Throughput: {throughput} tokens/s")
|
||||||
assert throughput >= 140
|
self.assertGreaterEqual(throughput, 140)
|
||||||
|
|
||||||
def test_gptq_module(self):
|
def test_gptq_module(self):
|
||||||
check_quant_method(self.MODEL_PATH, use_marlin_kernel=False)
|
check_quant_method(self.MODEL_PATH, use_marlin_kernel=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user