[Fix] Add speculative_draft_model_revision to server_args (#5255)
Signed-off-by: Devashish Lal <devashish@rivosinc.com>
This commit is contained in:
@@ -302,11 +302,16 @@ class ModelConfig:
|
||||
) or getattr(self.hf_config, "image_token_index", None)
|
||||
|
||||
@staticmethod
|
||||
-def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
|
||||
+def from_server_args(
|
||||
+server_args: ServerArgs,
|
||||
+model_path: str = None,
|
||||
+model_revision: str = None,
|
||||
+**kwargs,
|
||||
+):
|
||||
return ModelConfig(
|
||||
model_path=model_path or server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
-revision=server_args.revision,
|
||||
+revision=model_revision or server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
|
||||
@@ -78,6 +78,11 @@ class TpModelWorker:
|
||||
if not is_draft_worker
|
||||
else server_args.speculative_draft_model_path
|
||||
),
|
||||
+model_revision=(
|
||||
+server_args.revision
|
||||
+if not is_draft_worker
|
||||
+else server_args.speculative_draft_model_revision
|
||||
+),
|
||||
is_draft_model=is_draft_worker,
|
||||
)
|
||||
|
||||
|
||||
@@ -249,6 +249,7 @@ class ServerArgs:
|
||||
# Speculative decoding
|
||||
speculative_algorithm: Optional[str] = None
|
||||
speculative_draft_model_path: Optional[str] = None
|
||||
+speculative_draft_model_revision: Optional[str] = None
|
||||
speculative_num_steps: Optional[int] = None
|
||||
speculative_eagle_topk: Optional[int] = None
|
||||
speculative_num_draft_tokens: Optional[int] = None
|
||||
@@ -1498,6 +1499,14 @@ class ServerArgs:
|
||||
type=str,
|
||||
help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
|
||||
)
|
||||
+parser.add_argument(
|
||||
+"--speculative-draft-model-revision",
|
||||
+type=str,
|
||||
+default=None,
|
||||
+help="The specific draft model version to use. It can be a branch "
|
||||
+"name, a tag name, or a commit id. If unspecified, will use "
|
||||
+"the default version.",
|
||||
+)
|
||||
parser.add_argument(
|
||||
"--speculative-num-steps",
|
||||
type=int,
|
||||
|
||||
@@ -505,6 +505,7 @@ class SRTRunner:
|
||||
mem_fraction_static: float = 0.65,
|
||||
trust_remote_code: bool = False,
|
||||
speculative_draft_model_path: Optional[str] = None,
|
||||
+speculative_draft_model_revision: Optional[str] = None,
|
||||
speculative_algorithm: Optional[str] = None,
|
||||
speculative_num_steps: Optional[int] = None,
|
||||
speculative_eagle_topk: Optional[int] = None,
|
||||
@@ -526,6 +527,9 @@ class SRTRunner:
|
||||
spec_kwargs = {}
|
||||
if speculative_draft_model_path:
|
||||
spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
|
||||
+spec_kwargs["speculative_draft_model_revision"] = (
|
||||
+speculative_draft_model_revision
|
||||
+)
|
||||
spec_kwargs["speculative_algorithm"] = speculative_algorithm
|
||||
spec_kwargs["speculative_num_steps"] = speculative_num_steps
|
||||
spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
|
||||
|
||||
Reference in New Issue
Block a user