Fuse more ops & Simplify token mapping (#1758)
@@ -46,4 +46,8 @@ pip install nvtx
 import nvtx
 with nvtx.annotate("description", color="color"):
     # some critical code
 ```
+
+## Other tips
+
+1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder.
@@ -337,7 +337,7 @@ class FlashInferIndicesUpdaterDecode:
     def update(
         self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
     ):
-        # Keep the signature for type checking, will be initialized during runtime
+        # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

     def update_single_wrapper(
@@ -432,8 +432,8 @@ class FlashInferIndicesUpdaterDecode:
         kv_start_idx,
     ):
         bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indices = torch.empty(
             paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
         )
@@ -497,7 +497,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.update = self.update_single_wrapper

    def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
-        # Keep the signature for type checking, will be initialized during runtime
+        # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

     def update_single_wrapper(
@@ -589,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
         use_ragged,
     ):
         bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -602,8 +602,8 @@ class FlashInferIndicesUpdaterPrefill:
             self.max_context_len,
         )

+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
         qo_indptr = qo_indptr[: bs + 1]
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

         # extend part
         if use_ragged:
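The three `indptr` hunks above share one pattern: the running `cumsum` is now written into the full preallocated buffer (`kv_indptr[1 : bs + 1]` / `qo_indptr[1 : bs + 1]`) before the `[: bs + 1]` view is taken, instead of slicing first and filling the slice. Both orderings write the same elements of the same underlying buffer, so the resulting indptr is unchanged; the reordering appears aimed at keeping the writes in a fusion-friendly form, in line with the commit title. A minimal standalone sketch with invented lengths and buffer names:

```python
import torch

# Invented per-request KV lengths and a preallocated indptr buffer.
bs = 4
paged_kernel_lens = torch.tensor([3, 5, 2, 7], dtype=torch.int32)
buf = torch.zeros(64, dtype=torch.int32)

# Old ordering: take the slice view first, then fill it.
old = buf.clone()
old_view = old[: bs + 1]
old_view[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# New ordering: write into the full buffer, then take the slice view.
new = buf.clone()
new[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
new_view = new[: bs + 1]

assert torch.equal(old_view, new_view)  # both are [0, 3, 8, 10, 17]
```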
@@ -33,56 +33,61 @@ class Sampler(nn.Module):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits

-        # Post process logits
         logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        probs = torch.softmax(logits, dim=-1)
-        logits = None
-        del logits

-        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
-            logger.warning("Detected errors during sampling! NaN in the probability.")
-            probs = torch.where(
-                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
-            )
+        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
+            )

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            batch_next_token_ids = torch.argmax(probs, -1)
-        elif global_server_args_dict["sampling_backend"] == "flashinfer":
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            if sampling_info.need_min_p_sampling:
-                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
-                )
-            else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                    probs,
-                    uniform_samples,
-                    sampling_info.top_ks,
-                    sampling_info.top_ps,
-                    filter_apply_order="joint",
-                )
-
-            if not torch.all(success):
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-            )
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
-
-        return batch_next_token_ids
+            batch_next_token_ids = torch.argmax(logits, -1)
+        else:
+            # Post process logits
+            logits.div_(sampling_info.temperatures)
+            probs = torch.softmax(logits, dim=-1)
+            logits = None
+            del logits
+
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                max_top_k_round, batch_size = 32, probs.shape[0]
+                uniform_samples = torch.rand(
+                    (max_top_k_round, batch_size), device=probs.device
+                )
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids, success = min_p_sampling_from_probs(
+                        probs, uniform_samples, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                        probs,
+                        uniform_samples,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                    )
+
+                if not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                    probs,
+                    sampling_info.top_ks,
+                    sampling_info.top_ps,
+                    sampling_info.min_ps,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+                )
+
+        return batch_next_token_ids.to(torch.int32)


 def top_k_top_p_min_p_sampling_from_probs_torch(
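Two behavioral points in the Sampler hunk above: NaNs are now caught and patched on the raw logits (replaced with -1e5) rather than on the probabilities, and an all-greedy batch takes `torch.argmax(logits, -1)` before any temperature division or softmax, so `probs` is never materialized for it. The sketch below is a simplified, eager-mode rendering of that control flow, not the real class; a plain `torch.multinomial` stands in for the flashinfer / top-k / top-p backends, and the sampling-info fields are reduced to bare arguments.

```python
import torch

def sample_next_tokens(logits: torch.Tensor,
                       temperatures: torch.Tensor,
                       is_all_greedy: bool) -> torch.Tensor:
    # Patch NaNs on the logits themselves, mirroring the new check above.
    if torch.any(torch.isnan(logits)):
        logits = torch.where(
            torch.isnan(logits), torch.full_like(logits, -1e5), logits
        )

    if is_all_greedy:
        # Greedy decoding skips temperature scaling and softmax entirely.
        return torch.argmax(logits, dim=-1).to(torch.int32)

    # Non-greedy path: post-process and sample. A plain multinomial stands in
    # for the flashinfer / top-k / top-p kernels used by the real sampler.
    probs = torch.softmax(logits / temperatures, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1).to(torch.int32)

# Toy usage with a batch of 2 and a vocabulary of 8.
logits = torch.randn(2, 8)
print(sample_next_tokens(logits, torch.ones(2, 1), is_all_greedy=True))
```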
@@ -32,6 +32,15 @@ from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)


+@torch.compile(dynamic=True)
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
 class TpModelWorkerClient:
     """A tensor parallel model worker."""

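`resolve_future_token_ids` is the mapping step pulled out into a module-level helper so that `@torch.compile(dynamic=True)` can fuse the comparison, clamp, gather, and `where` into a small number of kernels. Negative entries in `input_ids` are placeholders for tokens an earlier overlapped batch has not yet produced; `-k` resolves to slot `k` of `future_token_ids_map`. A toy eager-mode illustration with invented values:

```python
import torch

def resolve_future_token_ids(input_ids, future_token_ids_map):
    # -k is a placeholder for slot k of the map; non-negative ids pass through.
    input_ids[:] = torch.where(
        input_ids < 0,
        future_token_ids_map[torch.clamp(-input_ids, min=0)],
        input_ids,
    )

# Slot 0 is never referenced; slots 1..3 hold tokens from earlier batches.
future_token_ids_map = torch.tensor([0, 101, 102, 103], dtype=torch.int32)
input_ids = torch.tensor([7, -1, -3, 9], dtype=torch.int32)
resolve_future_token_ids(input_ids, future_token_ids_map)
print(input_ids)  # tensor([  7, 101, 103,   9], dtype=torch.int32)
```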
@@ -99,33 +108,25 @@ class TpModelWorkerClient:

         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids
-        input_ids[:] = torch.where(
-            input_ids < 0,
-            self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
-            input_ids,
-        )
+        resolve_future_token_ids(input_ids, self.future_token_ids_map)

         # Run forward
         logits_output, next_token_ids = self.worker.forward_batch_generation(
             model_worker_batch
         )
-        self.launch_event.set()

         # Update the future token ids map
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = torch.arange(
-            -(future_token_ids_ct + bs),
-            -(future_token_ids_ct),
-            dtype=torch.int32,
-            device=self.device,
-        )
-        self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
-            torch.int32
-        )
+        self.future_token_ids_map[
+            future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+        ] = next_token_ids

+        # Copy results to the CPU
         next_token_ids = next_token_ids.to("cpu", non_blocking=True)
         copy_event = torch.cuda.Event(blocking=True)
         copy_event.record()
+
+        self.launch_event.set()
         self.copy_queue.put((copy_event, next_token_ids))

     def copy_thread_func(self):
@@ -149,8 +150,9 @@ class TpModelWorkerClient:
         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
         future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + bs),
-            -(self.future_token_ids_ct),
+            -(self.future_token_ids_ct + 1),
+            -(self.future_token_ids_ct + 1 + bs),
+            -1,
             dtype=torch.int32,
             device=self.device,
         )
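Together, the two worker hunks above are the "simplify token mapping" half of the commit. Placeholders are now handed out in request order as `-(ct + 1), -(ct + 2), ..., -(ct + bs)` (hence the explicit `-1` step in `torch.arange`), so filling in the real tokens later is a single contiguous slice assignment `future_token_ids_map[ct + 1 : ct + bs + 1] = next_token_ids` rather than a scatter through a second `arange`. A small round-trip sketch with invented sizes:

```python
import torch

ct, bs = 5, 3                                   # future-token counter, batch size
future_token_ids_map = torch.zeros(16, dtype=torch.int32)

# Allocation: request i receives placeholder -(ct + 1 + i).
placeholders = torch.arange(-(ct + 1), -(ct + 1 + bs), -1, dtype=torch.int32)
# -> tensor([-6, -7, -8], dtype=torch.int32)

# After the forward pass, the real tokens land in one slice assignment.
next_token_ids = torch.tensor([11, 22, 33], dtype=torch.int32)
future_token_ids_map[ct + 1 : ct + bs + 1] = next_token_ids

# Resolution: negating a placeholder indexes straight into that slice.
assert torch.equal(future_token_ids_map[-placeholders], next_token_ids)
```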
@@ -51,7 +51,7 @@ class ReqToTokenPool:
             self.write = self.write_without_records

     def write(self, indices, values):
-        # Keep the signature for type checking, will be initialized during runtime
+        # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

     def available_size(self):
@@ -221,16 +221,21 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         cache_v: torch.Tensor,
     ):
         layer_id = layer.layer_id
-        if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-        if cache_v.dtype != self.dtype:
-            cache_v = cache_v.to(self.dtype)
-        if self.store_dtype != self.dtype:
-            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
-            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
-        else:
-            self.k_buffer[layer_id][loc] = cache_k
-            self.v_buffer[layer_id][loc] = cache_v
+        copy_two_array(
+            loc,
+            self.k_buffer[layer_id],
+            cache_k,
+            self.v_buffer[layer_id],
+            cache_v,
+            self.dtype,
+            self.store_dtype,
+        )


+@torch.compile(dynamic=True)
+def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
+    dst_1[loc] = src_1.to(dtype).view(store_dtype)
+    dst_2[loc] = src_2.to(dtype).view(store_dtype)
+
+
 class MLATokenToKVPool(BaseTokenToKVPool):
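The KV-cache write path is collapsed into one branch-free helper: always cast to the compute dtype, reinterpret as the storage dtype, and scatter into the two buffers. When `dtype == store_dtype`, both `.to()` and `.view()` are no-ops, which is why the old `if` ladder can be dropped, and the `@torch.compile(dynamic=True)` decorator lets the casts and the two indexed writes fuse. Below is a toy eager-mode run; the buffer shapes are made up, and fp16-in/fp16-out is used so it runs anywhere (the real pool can pass, e.g., a float8 compute dtype with a uint8 storage dtype):

```python
import torch

def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
    # Cast to the compute dtype, reinterpret as the storage dtype, scatter.
    dst_1[loc] = src_1.to(dtype).view(store_dtype)
    dst_2[loc] = src_2.to(dtype).view(store_dtype)

# Made-up buffers: 8 cache slots, 2 heads x 4 head-dim, fp16 storage.
k_buffer = torch.zeros(8, 2, 4, dtype=torch.float16)
v_buffer = torch.zeros(8, 2, 4, dtype=torch.float16)
loc = torch.tensor([2, 5])                      # slots being written
cache_k = torch.randn(2, 2, 4)                  # arrives as fp32
cache_v = torch.randn(2, 2, 4)

copy_two_array(loc, k_buffer, cache_k, v_buffer, cache_v,
               torch.float16, torch.float16)
assert k_buffer[2].abs().sum() > 0              # slot 2 now holds the cast keys
```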
@@ -92,6 +92,11 @@ def set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 1024


+@torch.compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

@@ -112,7 +117,6 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
         self.capture_bs = [
             bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
         ]
@@ -253,7 +257,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             top_logprobs_nums=[0] * bs,
-            positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
+            positions=clamp_position(seq_lens),
         )
         return forward(input_ids, forward_batch.positions, forward_batch)

@@ -67,6 +67,7 @@ def run_eval(args):
         model=args.model,
         max_tokens=2048,
         base_url=base_url,
+        temperature=getattr(args, "temperature", 0.0),
     )

     # Run eval
@@ -119,6 +120,7 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
     parser.add_argument("--num-threads", type=int, default=512)
+    parser.add_argument("--temperature", type=float, default=0.0)
     args = parser.parse_args()

     run_eval(args)
@@ -31,6 +31,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
             eval_name="mmlu",
             num_examples=64,
             num_threads=32,
+            temperature=0.1,
         )

         metrics = run_eval(args)
@@ -23,7 +23,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--sampling-backend", "pytorch"],
+            other_args=["--sampling-backend", "pytorch", "--disable-radix-cache"],
         )

     @classmethod
@@ -37,6 +37,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
             eval_name="mmlu",
             num_examples=64,
             num_threads=32,
+            temperature=0.1,
         )

         metrics = run_eval(args)