Simplify the nan detection and greedy check in sampler (#1709)

This commit is contained in:
Lianmin Zheng
2024-10-18 20:21:24 -07:00
committed by GitHub
parent 2bcfba1b08
commit f0f8a7699b
6 changed files with 24 additions and 7 deletions

View File

@@ -245,10 +245,10 @@ class CudaGraphRunner:
self.out_cache_loc.zero_()
# Common inputs
-self.input_ids[:raw_bs] = forward_batch.input_ids
-self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
-self.seq_lens[:raw_bs] = forward_batch.seq_lens
-self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
+self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(

View File

@@ -137,6 +137,7 @@ class ModelRunner:
"disable_mla": server_args.disable_mla,
"torchao_config": server_args.torchao_config,
"disable_penalizer": server_args.disable_penalizer,
+"disable_nan_detection": server_args.disable_nan_detection,
}
)