[bug] fix errors related to context length in SD (#9388)
This commit is contained in:
@@ -41,6 +41,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
# Parse args
|
||||
self.eagle_worker = eagle_worker
|
||||
self.model_runner = model_runner = eagle_worker.model_runner
|
||||
self.model_runner: EAGLEWorker
|
||||
self.graphs = {}
|
||||
self.output_buffers = {}
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
|
||||
@@ -9,7 +9,6 @@ from huggingface_hub import snapshot_download
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
GroupCoordinator,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
patch_tensor_parallel_group,
|
||||
)
|
||||
@@ -92,7 +91,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
self.padded_static_len = -1
|
||||
|
||||
# Override context length with target model's context length
|
||||
# Override the context length of the draft model to be the same as the target model.
|
||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||
|
||||
# Do not capture cuda graph in `super().__init__()`
|
||||
|
||||
Reference in New Issue
Block a user