Simplify the nan detection and greedy check in sampler (#1709)
This commit is contained in:
@@ -245,10 +245,10 @@ class CudaGraphRunner:
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
self.input_ids[:raw_bs] = forward_batch.input_ids
|
||||
self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
|
||||
self.seq_lens[:raw_bs] = forward_batch.seq_lens
|
||||
self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
|
||||
self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
|
||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
|
||||
@@ -137,6 +137,7 @@ class ModelRunner:
|
||||
"disable_mla": server_args.disable_mla,
|
||||
"torchao_config": server_args.torchao_config,
|
||||
"disable_penalizer": server_args.disable_penalizer,
|
||||
"disable_nan_detection": server_args.disable_nan_detection,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user