Fuse more ops & Simplify token mapping (#1758)
This commit is contained in:
@@ -337,7 +337,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
def update(
|
||||
self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
|
||||
):
|
||||
# Keep the signature for type checking, will be initialized during runtime
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_single_wrapper(
|
||||
@@ -432,8 +432,8 @@ class FlashInferIndicesUpdaterDecode:
|
||||
kv_start_idx,
|
||||
):
|
||||
bs = len(req_pool_indices)
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
@@ -497,7 +497,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.update = self.update_single_wrapper
|
||||
|
||||
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
|
||||
# Keep the signature for type checking, will be initialized during runtime
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_single_wrapper(
|
||||
@@ -589,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
use_ragged,
|
||||
):
|
||||
bs = len(req_pool_indices)
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
@@ -602,8 +602,8 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.max_context_len,
|
||||
)
|
||||
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||
|
||||
# extend part
|
||||
if use_ragged:
|
||||
|
||||
@@ -33,56 +33,61 @@ class Sampler(nn.Module):
|
||||
if isinstance(logits, LogitsProcessorOutput):
|
||||
logits = logits.next_token_logits
|
||||
|
||||
# Post process logits
|
||||
logits = logits.contiguous()
|
||||
logits.div_(sampling_info.temperatures)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
logits = None
|
||||
del logits
|
||||
|
||||
if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
|
||||
logger.warning("Detected errors during sampling! NaN in the probability.")
|
||||
probs = torch.where(
|
||||
torch.isnan(probs), torch.full_like(probs, 1e-10), probs
|
||||
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||
logits = torch.where(
|
||||
torch.isnan(logits), torch.full_like(logits, -1e5), logits
|
||||
)
|
||||
|
||||
if sampling_info.is_all_greedy:
|
||||
# Use torch.argmax if all requests use greedy sampling
|
||||
batch_next_token_ids = torch.argmax(probs, -1)
|
||||
elif global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand(
|
||||
(max_top_k_round, batch_size), device=probs.device
|
||||
)
|
||||
if sampling_info.need_min_p_sampling:
|
||||
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
||||
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
||||
batch_next_token_ids, success = min_p_sampling_from_probs(
|
||||
probs, uniform_samples, sampling_info.min_ps
|
||||
batch_next_token_ids = torch.argmax(logits, -1)
|
||||
else:
|
||||
# Post process logits
|
||||
logits.div_(sampling_info.temperatures)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
logits = None
|
||||
del logits
|
||||
|
||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand(
|
||||
(max_top_k_round, batch_size), device=probs.device
|
||||
)
|
||||
else:
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
if sampling_info.need_min_p_sampling:
|
||||
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
||||
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
||||
batch_next_token_ids, success = min_p_sampling_from_probs(
|
||||
probs, uniform_samples, sampling_info.min_ps
|
||||
)
|
||||
else:
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
sampling_info.top_ks,
|
||||
sampling_info.top_ps,
|
||||
filter_apply_order="joint",
|
||||
)
|
||||
|
||||
if not torch.all(success):
|
||||
logger.warning("Detected errors during sampling!")
|
||||
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||
# A slower fallback implementation with torch native operations.
|
||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs,
|
||||
uniform_samples,
|
||||
sampling_info.top_ks,
|
||||
sampling_info.top_ps,
|
||||
filter_apply_order="joint",
|
||||
sampling_info.min_ps,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||
)
|
||||
|
||||
if not torch.all(success):
|
||||
logger.warning("Detected errors during sampling!")
|
||||
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||
# Here we provide a slower fallback implementation.
|
||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||
)
|
||||
|
||||
return batch_next_token_ids
|
||||
return batch_next_token_ids.to(torch.int32)
|
||||
|
||||
|
||||
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
|
||||
@@ -32,6 +32,15 @@ from sglang.srt.server_args import ServerArgs
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def resolve_future_token_ids(input_ids, future_token_ids_map):
|
||||
input_ids[:] = torch.where(
|
||||
input_ids < 0,
|
||||
future_token_ids_map[torch.clamp(-input_ids, min=0)],
|
||||
input_ids,
|
||||
)
|
||||
|
||||
|
||||
class TpModelWorkerClient:
|
||||
"""A tensor parallel model worker."""
|
||||
|
||||
@@ -99,33 +108,25 @@ class TpModelWorkerClient:
|
||||
|
||||
# Resolve future tokens in the input
|
||||
input_ids = model_worker_batch.input_ids
|
||||
input_ids[:] = torch.where(
|
||||
input_ids < 0,
|
||||
self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
|
||||
input_ids,
|
||||
)
|
||||
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
||||
|
||||
# Run forward
|
||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
self.launch_event.set()
|
||||
|
||||
# Update the future token ids map
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
future_next_token_ids = torch.arange(
|
||||
-(future_token_ids_ct + bs),
|
||||
-(future_token_ids_ct),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
|
||||
torch.int32
|
||||
)
|
||||
self.future_token_ids_map[
|
||||
future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
|
||||
] = next_token_ids
|
||||
|
||||
# Copy results to the CPU
|
||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||
copy_event = torch.cuda.Event(blocking=True)
|
||||
copy_event.record()
|
||||
|
||||
self.launch_event.set()
|
||||
self.copy_queue.put((copy_event, next_token_ids))
|
||||
|
||||
def copy_thread_func(self):
|
||||
@@ -149,8 +150,9 @@ class TpModelWorkerClient:
|
||||
# Allocate output future objects
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
future_next_token_ids = torch.arange(
|
||||
-(self.future_token_ids_ct + bs),
|
||||
-(self.future_token_ids_ct),
|
||||
-(self.future_token_ids_ct + 1),
|
||||
-(self.future_token_ids_ct + 1 + bs),
|
||||
-1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@@ -51,7 +51,7 @@ class ReqToTokenPool:
|
||||
self.write = self.write_without_records
|
||||
|
||||
def write(self, indices, values):
|
||||
# Keep the signature for type checking, will be initialized during runtime
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
|
||||
def available_size(self):
|
||||
@@ -221,16 +221,21 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
cache_v: torch.Tensor,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
if cache_v.dtype != self.dtype:
|
||||
cache_v = cache_v.to(self.dtype)
|
||||
if self.store_dtype != self.dtype:
|
||||
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
|
||||
else:
|
||||
self.k_buffer[layer_id][loc] = cache_k
|
||||
self.v_buffer[layer_id][loc] = cache_v
|
||||
copy_two_array(
|
||||
loc,
|
||||
self.k_buffer[layer_id],
|
||||
cache_k,
|
||||
self.v_buffer[layer_id],
|
||||
cache_v,
|
||||
self.dtype,
|
||||
self.store_dtype,
|
||||
)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
||||
dst_1[loc] = src_1.to(dtype).view(store_dtype)
|
||||
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
||||
|
||||
|
||||
class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
@@ -92,6 +92,11 @@ def set_torch_compile_config():
|
||||
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def clamp_position(seq_lens):
|
||||
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
||||
|
||||
|
||||
class CudaGraphRunner:
|
||||
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
||||
|
||||
@@ -112,7 +117,6 @@ class CudaGraphRunner:
|
||||
self.capture_bs = list(range(1, 32)) + [64, 128]
|
||||
else:
|
||||
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||
|
||||
self.capture_bs = [
|
||||
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
|
||||
]
|
||||
@@ -253,7 +257,7 @@ class CudaGraphRunner:
|
||||
encoder_lens=encoder_lens,
|
||||
return_logprob=False,
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
||||
positions=clamp_position(seq_lens),
|
||||
)
|
||||
return forward(input_ids, forward_batch.positions, forward_batch)
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@ def run_eval(args):
|
||||
model=args.model,
|
||||
max_tokens=2048,
|
||||
base_url=base_url,
|
||||
temperature=getattr(args, "temperature", 0.0),
|
||||
)
|
||||
|
||||
# Run eval
|
||||
@@ -119,6 +120,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--eval-name", type=str, default="mmlu")
|
||||
parser.add_argument("--num-examples", type=int)
|
||||
parser.add_argument("--num-threads", type=int, default=512)
|
||||
parser.add_argument("--temperature", type=float, default=0.0)
|
||||
args = parser.parse_args()
|
||||
|
||||
run_eval(args)
|
||||
|
||||
Reference in New Issue
Block a user