Sync cuda graph runners (#6976)
This commit is contained in:
@@ -127,7 +127,7 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
seq_lens_sum=seq_lens.sum(),
|
seq_lens_sum=seq_lens.sum().item(),
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
@@ -209,7 +209,7 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
forward_batch.positions = self.positions[:num_tokens]
|
forward_batch.positions = self.positions[:num_tokens]
|
||||||
|
|
||||||
# Special handle for seq_len_cpu used when flashinfer mla is used
|
# Special handle for seq_len_cpu used when flashinfer mla is used
|
||||||
if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
|
if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
|
||||||
self.seq_lens_cpu.fill_(1)
|
self.seq_lens_cpu.fill_(1)
|
||||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||||
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
|
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
seq_lens_sum=seq_lens.sum(),
|
seq_lens_sum=seq_lens.sum().item(),
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -12,6 +14,7 @@ import triton.language as tl
|
|||||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
|
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
Req,
|
Req,
|
||||||
ScheduleBatch,
|
ScheduleBatch,
|
||||||
@@ -20,7 +23,6 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
|
||||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
||||||
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
|
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
|
||||||
|
|
||||||
@@ -34,15 +36,15 @@ if is_cuda():
|
|||||||
elif is_hip():
|
elif is_hip():
|
||||||
from sgl_kernel import verify_tree_greedy
|
from sgl_kernel import verify_tree_greedy
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Simulate acceptance length for benchmarking purposes
|
||||||
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
|
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
|
||||||
|
SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
|
||||||
|
|
||||||
|
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -84,9 +86,9 @@ class EagleDraftInput:
|
|||||||
self,
|
self,
|
||||||
batch: ScheduleBatch,
|
batch: ScheduleBatch,
|
||||||
speculative_num_steps: int,
|
speculative_num_steps: int,
|
||||||
|
context_length: int,
|
||||||
pad_input: bool = False,
|
pad_input: bool = False,
|
||||||
):
|
):
|
||||||
assert len(self.verified_id) == len(batch.out_cache_loc)
|
|
||||||
accept_length_cpu = batch.spec_info.accept_length_cpu
|
accept_length_cpu = batch.spec_info.accept_length_cpu
|
||||||
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
||||||
batch.extend_num_tokens = sum(batch.extend_lens)
|
batch.extend_num_tokens = sum(batch.extend_lens)
|
||||||
@@ -112,49 +114,49 @@ class EagleDraftInput:
|
|||||||
batch.input_ids = self.verified_id
|
batch.input_ids = self.verified_id
|
||||||
self.verified_id = new_verified_id
|
self.verified_id = new_verified_id
|
||||||
|
|
||||||
if pad_input:
|
if not pad_input:
|
||||||
batch_size = sum(not req.finished() for req in batch.reqs)
|
return
|
||||||
# Total constant input length after padding
|
|
||||||
static_len = speculative_num_steps + 1
|
|
||||||
# Total size after padding
|
|
||||||
padded_input_size = batch_size * static_len
|
|
||||||
|
|
||||||
padded_len = padded_input_size - batch.input_ids.shape[0]
|
batch_size = sum(not req.finished() for req in batch.reqs)
|
||||||
if padded_len > 0:
|
# Total constant input length after padding
|
||||||
new_input_ids = torch.nn.functional.pad(
|
static_len = speculative_num_steps + 1
|
||||||
batch.input_ids, (0, padded_len), value=0
|
# Total size after padding
|
||||||
)
|
padded_input_size = batch_size * static_len
|
||||||
position_padding = torch.arange(
|
|
||||||
padded_len, device=self.positions.device
|
|
||||||
)
|
|
||||||
new_positions = torch.cat([self.positions, position_padding])
|
|
||||||
|
|
||||||
# need dummy hidden states for the padded positions
|
padded_len = padded_input_size - batch.input_ids.shape[0]
|
||||||
hidden_states_dim = self.hidden_states.shape[-1]
|
if padded_len > 0:
|
||||||
new_hidden_states = torch.cat(
|
new_input_ids = torch.nn.functional.pad(
|
||||||
[
|
batch.input_ids, (0, padded_len), value=0
|
||||||
self.hidden_states,
|
)
|
||||||
torch.zeros(
|
position_padding = torch.arange(padded_len, device=self.positions.device)
|
||||||
(padded_len, hidden_states_dim),
|
new_positions = torch.cat([self.positions, position_padding])
|
||||||
dtype=self.hidden_states.dtype,
|
|
||||||
device=self.hidden_states.device,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# allocate KV cache location for the padded tokens
|
# need dummy hidden states for the padded positions
|
||||||
padded_cache_loc = torch.zeros(
|
hidden_states_dim = self.hidden_states.shape[-1]
|
||||||
padded_len,
|
new_hidden_states = torch.cat(
|
||||||
dtype=batch.out_cache_loc.dtype,
|
[
|
||||||
device=batch.out_cache_loc.device,
|
self.hidden_states,
|
||||||
)
|
torch.zeros(
|
||||||
new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
|
(padded_len, hidden_states_dim),
|
||||||
|
dtype=self.hidden_states.dtype,
|
||||||
|
device=self.hidden_states.device,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
batch.input_ids = new_input_ids
|
# allocate KV cache location for the padded tokens
|
||||||
self.hidden_states = new_hidden_states
|
padded_cache_loc = torch.zeros(
|
||||||
self.positions = new_positions
|
padded_len,
|
||||||
batch.out_cache_loc = new_out_cache_loc
|
dtype=batch.out_cache_loc.dtype,
|
||||||
|
device=batch.out_cache_loc.device,
|
||||||
|
)
|
||||||
|
new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
|
||||||
|
|
||||||
|
batch.input_ids = new_input_ids
|
||||||
|
self.hidden_states = new_hidden_states
|
||||||
|
self.positions = new_positions
|
||||||
|
batch.out_cache_loc = new_out_cache_loc
|
||||||
|
|
||||||
def generate_attn_arg_prefill(
|
def generate_attn_arg_prefill(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -687,6 +687,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.spec_info.prepare_extend_after_decode(
|
batch.spec_info.prepare_extend_after_decode(
|
||||||
batch,
|
batch,
|
||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
|
self.server_args.context_length,
|
||||||
pad_input=self.cuda_graph_runner_for_draft_extend is not None,
|
pad_input=self.cuda_graph_runner_for_draft_extend is not None,
|
||||||
)
|
)
|
||||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
CustomTestCase,
|
CustomTestCase,
|
||||||
|
is_in_ci,
|
||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
run_logprob_check,
|
run_logprob_check,
|
||||||
)
|
)
|
||||||
@@ -578,6 +579,7 @@ class TestEAGLEServerTriton(TestEAGLEServer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||||
class TestEAGLEDraftExtend(CustomTestCase):
|
class TestEAGLEDraftExtend(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -669,6 +671,7 @@ class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
|
|||||||
cls.accept_len_threshold = 1.50
|
cls.accept_len_threshold = 1.50
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||||
class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
|
class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -697,6 +700,7 @@ class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
|
|||||||
cls.accept_len_threshold = 1.50
|
cls.accept_len_threshold = 1.50
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||||
class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
|
class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
|||||||
Reference in New Issue
Block a user