diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index baacfea56e..64bbcb5f7b 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -350,6 +350,7 @@ class PytorchEngineConfig:
         dllm_denoising_steps (int): Dllm denoising steps.
         dllm_confidence_threshold (float): dllm unmasking threshold for dynamic unmasking.
+        enforce_fp32_head (bool): Enforce lm_head to use fp32 in forward.
     """
     dtype: str = 'auto'
     tp: int = 1
@@ -387,6 +388,7 @@ class PytorchEngineConfig:
     hf_overrides: Optional[Dict[str, Any]] = None
     disable_vision_encoder: bool = False
     logprobs_mode: str = None
+    enforce_fp32_head: bool = False
     # router replay
     enable_return_routed_experts: bool = False
     enable_transfer_obj_ref: bool = False
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index 1fecc29c92..ed3f021f58 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -435,6 +435,7 @@ class MiscConfig:
     logprobs_mode: str = None
     dllm_config: DLLMConfig = None
     enable_return_routed_experts: bool = False
+    enforce_fp32_head: bool = False

     @classmethod
     def from_engine_config(cls, engine_config: PytorchEngineConfig):
@@ -454,6 +455,7 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig):
                           logprobs_mode=engine_config.logprobs_mode,
                           dllm_config=dllm_config,
                           enable_return_routed_experts=engine_config.enable_return_routed_experts,
+                          enforce_fp32_head=engine_config.enforce_fp32_head,
                           )
         return misc_config
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index c6c9345539..cf7b594785 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -1019,6 +1019,7 @@ def _build_model(self):
             dllm_config=self.misc_config.dllm_config,
             strategy_factory=self.strategy_factory,
             enable_return_routed_experts=enable_return_routed_experts,
+            enforce_fp32_head=self.misc_config.enforce_fp32_head,
         )
         patched_model = build_patched_model(self.model_config,
                                             device=device,
diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py
index 298ac8e18c..cec6c85456 100644
--- a/lmdeploy/pytorch/model_inputs.py
+++ b/lmdeploy/pytorch/model_inputs.py
@@ -492,6 +492,7 @@ class BuildModelContext:
     dllm_config: DLLMConfig = None
     strategy_factory: 'StrategyFactoryBase' = None
     enable_return_routed_experts: bool = False
+    enforce_fp32_head: bool = False


 class StepContextManager:
diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py
index ffe693f9e7..9bf8726030 100644
--- a/lmdeploy/pytorch/models/internlm2.py
+++ b/lmdeploy/pytorch/models/internlm2.py
@@ -7,12 +7,13 @@
 from transformers.configuration_utils import PretrainedConfig

 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul,
+                                 build_rotary_embedding_from_config)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1


 class InternLM2Attention(nn.Module):
@@ -208,11 +209,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.tok_embeddings = nn.Embedding(config.vocab_size,
-                                           config.hidden_size,
-                                           self.padding_idx,
-                                           dtype=dtype,
-                                           device=device)
+        self.tok_embeddings = Embedding(config.vocab_size,
+                                        config.hidden_size,
+                                        self.padding_idx,
+                                        dtype=dtype,
+                                        device=device)

         # build all decode layers
         self.layers = nn.ModuleList([
@@ -269,7 +270,7 @@ def get_input_embeddings(self):
         return self.tok_embeddings


-class InternLM2ForCausalLM(nn.Module, CudaGraphMixin):
+class InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """Rewrote model of InternLM2ForCausalLM."""

     packed_modules_mapping = {
@@ -290,11 +291,7 @@ def __init__(self,
         # build Model
         self.model = InternLM2Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.output = build_rowwise_linear(config.hidden_size,
-                                           config.vocab_size,
-                                           bias=False,
-                                           dtype=dtype,
-                                           device=device)
+        self.output = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

     def forward(
         self,
@@ -315,9 +312,9 @@ def forward(
         )
         return hidden_states

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.output(hidden_states)
+    def get_lm_head(self):
+        """Get lm_head."""
+        return self.output

     def get_input_embeddings(self):
         """Get input embeddings."""
diff --git a/lmdeploy/pytorch/models/internlm3.py b/lmdeploy/pytorch/models/internlm3.py
index d3bbc6830b..d565906e1c 100644
--- a/lmdeploy/pytorch/models/internlm3.py
+++ b/lmdeploy/pytorch/models/internlm3.py
@@ -7,12 +7,13 @@
 from transformers.configuration_utils import PretrainedConfig

 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul,
+                                 build_rotary_embedding_from_config)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1


 class InternLM3Attention(nn.Module):
@@ -210,11 +211,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = Embedding(config.vocab_size,
+                                      config.hidden_size,
+                                      self.padding_idx,
+                                      dtype=dtype,
+                                      device=device)

         # build all decode layers
         self.layers = nn.ModuleList([
@@ -271,7 +272,7 @@ def get_input_embeddings(self):
         return self.embed_tokens


-class InternLM3ForCausalLM(nn.Module, CudaGraphMixin):
+class InternLM3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """Rewrote model of InternLM3ForCausalLM."""

     packed_modules_mapping = {
@@ -297,11 +298,7 @@ def __init__(self,
         # build InternLM3Model
         self.model = InternLM3Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

     def forward(
         self,
@@ -322,15 +319,6 @@ def forward(
         )
         return hidden_states

-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index 2dbd9f9f3e..233ebabb99 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -19,7 +19,7 @@

 from .patch import build_model_from_hf_config
 from .utils.cudagraph import CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model


 class Gating(nn.Module):
@@ -444,7 +444,7 @@ def forward(
         return last_hidden_state


-class InternVLChatModel(nn.Module, DeployModelMixin, CudaGraphMixin):
+class InternVLChatModel(nn.Module, DeployModelMixinV1, CudaGraphMixin):

     def __init__(self,
                  config: PretrainedConfig,
@@ -801,9 +801,9 @@ def forward(
                                            position_ids=position_ids,
                                            attn_metadata=attn_metadata)

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.language_model.get_logits(hidden_states)
+    def get_lm_head(self):
+        """Get lm_head."""
+        return self.language_model.get_lm_head()

     def get_input_embeddings(self):
         """Get input embeddings."""
diff --git a/lmdeploy/pytorch/models/internvl3_hf.py b/lmdeploy/pytorch/models/internvl3_hf.py
index 6e760dbeac..0a151d8ed3 100644
--- a/lmdeploy/pytorch/models/internvl3_hf.py
+++ b/lmdeploy/pytorch/models/internvl3_hf.py
@@ -20,7 +20,7 @@

 from .patch import build_model_from_hf_config
 from .utils.cudagraph import CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model


 @torch.compile(dynamic=True)
@@ -439,7 +439,7 @@ def forward(self, image_features):
         return hidden_states


-class InternVLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
+class InternVLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):

     def __init__(self,
                  config: PretrainedConfig,
@@ -485,9 +485,9 @@ def _mark_dynamic_once(self, pixel_values, dims):
             torch._dynamo.mark_dynamic(pixel_values, dims)
         self.has_compiled_vit = True

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.language_model.get_logits(hidden_states)
+    def get_lm_head(self):
+        """Get lm_head."""
+        return self.language_model.get_lm_head()

     def get_input_embeddings(self):
         """Get input embeddings."""
diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py
index 24dd504522..600b1aa37e 100644
--- a/lmdeploy/pytorch/models/phi3.py
+++ b/lmdeploy/pytorch/models/phi3.py
@@ -7,13 +7,13 @@
 from transformers.configuration_utils import PretrainedConfig

 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.nn.rotary_embedding import build_rotary_embedding_from_config
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1


 class Phi3Attention(nn.Module):
@@ -211,11 +211,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = Embedding(config.vocab_size,
+                                      config.hidden_size,
+                                      self.padding_idx,
+                                      dtype=dtype,
+                                      device=device)

         # build all decode layers
         self.layers = nn.ModuleList([
@@ -272,7 +272,7 @@ def get_input_embeddings(self):
         return self.embed_tokens


-class Phi3ForCausalLM(nn.Module, CudaGraphMixin):
+class Phi3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""

     packed_modules_mapping = {
@@ -293,11 +293,7 @@ def __init__(self,
         # build model
         self.model = Phi3Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

     def forward(
         self,
@@ -318,15 +314,6 @@ def forward(
         )
         return hidden_states

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/phi3_v.py b/lmdeploy/pytorch/models/phi3_v.py
index a70d0bfecc..c6804d5586 100644
--- a/lmdeploy/pytorch/models/phi3_v.py
+++ b/lmdeploy/pytorch/models/phi3_v.py
@@ -9,11 +9,10 @@
 from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
-from lmdeploy.pytorch.nn.linear import build_rowwise_linear
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .phi3 import Phi3ForCausalLM, Phi3Model
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import vlm_model

 CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(attention_dropout=0.0,
                                                      dropout=0.0,
@@ -264,7 +263,7 @@ def forward(
         )


-class Phi3VForCausalLM(Phi3ForCausalLM, DeployModelMixin):
+class Phi3VForCausalLM(Phi3ForCausalLM):

     def __init__(self,
                  config: PretrainedConfig,
@@ -277,11 +276,7 @@ def __init__(self,
         # build model
         self.model = Phi3VModel(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

         self.input_processor = Phi3VInputProcessor(config, dtype)

diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py
index c13c805f88..0fc2311eb1 100644
--- a/lmdeploy/pytorch/models/qwen2.py
+++ b/lmdeploy/pytorch/models/qwen2.py
@@ -7,12 +7,13 @@
 from transformers.configuration_utils import PretrainedConfig

 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul,
+                                 build_rotary_embedding_from_config)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1


 class Qwen2Attention(nn.Module):
@@ -208,11 +209,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = Embedding(config.vocab_size,
+                                      config.hidden_size,
+                                      self.padding_idx,
+                                      dtype=dtype,
+                                      device=device)

         # build all decode layers
         self.layers = nn.ModuleList([
@@ -269,7 +270,7 @@ def get_input_embeddings(self):
         return self.embed_tokens


-class Qwen2ForCausalLM(nn.Module, CudaGraphMixin):
+class Qwen2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""

     packed_modules_mapping = {
@@ -295,11 +296,7 @@ def __init__(self,
         # build model
         self.model = Qwen2Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

     def forward(
         self,
@@ -320,15 +317,6 @@ def forward(
         )
         return hidden_states

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen2_5_vl.py b/lmdeploy/pytorch/models/qwen2_5_vl.py
index 11d2948b7a..b934d4142f 100644
--- a/lmdeploy/pytorch/models/qwen2_5_vl.py
+++ b/lmdeploy/pytorch/models/qwen2_5_vl.py
@@ -18,7 +18,7 @@
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model


 class Qwen2_5_PatchEmbed(nn.Module):
@@ -366,7 +366,7 @@ def forward(self,
         return hidden_states


-class Qwen2_5_VLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
+class Qwen2_5_VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""

     packed_modules_mapping = {
@@ -402,11 +402,7 @@ def __init__(self,
         # build model
         self.model = Qwen2Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

     def forward(
         self,
@@ -447,15 +443,6 @@ def forward(
         )
         return hidden_states

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py
index b62e5df4bf..c1acd0c972 100644
--- a/lmdeploy/pytorch/models/qwen2_moe.py
+++ b/lmdeploy/pytorch/models/qwen2_moe.py
@@ -10,12 +10,14 @@
 import lmdeploy.pytorch.distributed as dist
 from lmdeploy.pytorch.distributed import get_tp_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul,
+                                 build_rotary_embedding_from_config)
 from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
 from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1


 class Qwen2MoeAttention(nn.Module):
@@ -309,11 +311,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = Embedding(config.vocab_size,
+                                      config.hidden_size,
+                                      self.padding_idx,
+                                      dtype=dtype,
+                                      device=device)

         # build all decode layers
         self.layers = nn.ModuleList([
@@ -370,7 +372,7 @@ def get_input_embeddings(self):
         return self.embed_tokens


-class Qwen2MoeForCausalLM(nn.Module, CudaGraphMixin):
+class Qwen2MoeForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""

     packed_modules_mapping = {
@@ -396,11 +398,7 @@ def __init__(self,
         # build model
         self.model = Qwen2MoeModel(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

     def forward(
         self,
@@ -421,10 +419,6 @@ def forward(
         )
         return hidden_states

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py
index 77e025c638..84f87454f9 100644
--- a/lmdeploy/pytorch/models/qwen2_vl.py
+++ b/lmdeploy/pytorch/models/qwen2_vl.py
@@ -9,14 +9,14 @@
 from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
-from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, FlashAttention, LayerNorm, RMSNorm, SiluAndMul,
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, Embedding, FlashAttention, LayerNorm, RMSNorm, SiluAndMul,
                                  build_rotary_embedding_from_config)
 from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj,
                                         build_rowwise_linear)
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model


 def _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: List[int],
@@ -235,11 +235,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.vocab_size = config.vocab_size
         self.mrope_section = config.rope_scaling['mrope_section']

-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = Embedding(config.vocab_size,
+                                      config.hidden_size,
+                                      self.padding_idx,
+                                      dtype=dtype,
+                                      device=device)

         # build all decode layers
         self.layers = nn.ModuleList([
@@ -592,7 +592,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
         return self.merger(hidden_states)


-class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
+class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""

     packed_modules_mapping = {
@@ -628,11 +628,7 @@ def __init__(self,
         # build model
         self.model = Qwen2Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

     def forward(
         self,
@@ -668,15 +664,6 @@ def forward(
         )
         return hidden_states

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen3.py b/lmdeploy/pytorch/models/qwen3.py
index 381bfb72cb..4617b31c05 100644
--- a/lmdeploy/pytorch/models/qwen3.py
+++ b/lmdeploy/pytorch/models/qwen3.py
@@ -7,12 +7,13 @@
 from transformers.configuration_utils import PretrainedConfig

 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul,
+                                 build_rotary_embedding_from_config)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1


 class Qwen3Attention(nn.Module):
@@ -217,11 +218,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = Embedding(config.vocab_size,
+                                      config.hidden_size,
+                                      self.padding_idx,
+                                      dtype=dtype,
+                                      device=device)

         # build all decode layers
         self.layers = nn.ModuleList([
@@ -278,7 +279,7 @@ def get_input_embeddings(self):
         return self.embed_tokens


-class Qwen3ForCausalLM(nn.Module, CudaGraphMixin):
+class Qwen3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""

     packed_modules_mapping = {
@@ -304,11 +305,7 @@ def __init__(self,
         # build model
         self.model = Qwen3model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

     def forward(
         self,
@@ -329,15 +326,6 @@ def forward(
         )
         return hidden_states

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen3_moe.py b/lmdeploy/pytorch/models/qwen3_moe.py
index f170191d72..c76b62f83e 100644
--- a/lmdeploy/pytorch/models/qwen3_moe.py
+++ b/lmdeploy/pytorch/models/qwen3_moe.py
@@ -8,7 +8,8 @@

 from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul,
+                                 build_rotary_embedding_from_config)
 from lmdeploy.pytorch.nn.eplb import EPLBManager
 from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
 from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
@@ -16,6 +17,7 @@

 from .patch import get_build_model_context
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1


 class Qwen3MoeAttention(nn.Module):
@@ -317,11 +319,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = Embedding(config.vocab_size,
+                                      config.hidden_size,
+                                      self.padding_idx,
+                                      dtype=dtype,
+                                      device=device)

         if get_dist_manager().current_context().dist_config.enable_eplb:
             ep_size, _ = get_ep_world_rank()
@@ -392,7 +394,7 @@ def get_input_embeddings(self):
         return self.embed_tokens


-class Qwen3MoeForCausalLM(nn.Module, CudaGraphMixin):
+class Qwen3MoeForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""

     packed_modules_mapping = {
@@ -419,11 +421,7 @@ def __init__(self,
         # build model
         self.model = Qwen3MoeModel(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
         # for router replay
         bm_ctx = get_build_model_context()
         self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts
@@ -461,10 +459,6 @@ def forward(
             return hidden_states
         return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen3_next.py b/lmdeploy/pytorch/models/qwen3_next.py
index 49420bcb4e..f03bffdcea 100644
--- a/lmdeploy/pytorch/models/qwen3_next.py
+++ b/lmdeploy/pytorch/models/qwen3_next.py
@@ -10,13 +10,15 @@
 import lmdeploy.pytorch.distributed as dist
 from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul,
+                                 build_rotary_embedding_from_config)
 from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_o_proj,
                                         build_qkv_proj, build_rowwise_linear)
 from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
 from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight

 from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
+from .utils.model import DeployModelMixinV1


 class GatedDeltaMeta:
@@ -812,11 +814,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = Embedding(config.vocab_size,
+                                      config.hidden_size,
+                                      self.padding_idx,
+                                      dtype=dtype,
+                                      device=device)

         # build all decode layers
         # TODO: use full config.num_hidden_layers
@@ -879,7 +881,7 @@ def get_input_embeddings(self):
         return self.embed_tokens


-class Qwen3NextForCausalLM(nn.Module, CudaGraphMixin):
+class Qwen3NextForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""

     packed_modules_mapping = {
@@ -905,11 +907,7 @@ def __init__(self,
         # build model
         self.model = Qwen3NextModel(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

     def forward(
         self,
@@ -932,10 +930,6 @@ def forward(
         )
         return hidden_states

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py
index 6844c6b8d0..931fc68241 100644
--- a/lmdeploy/pytorch/models/qwen3_vl.py
+++ b/lmdeploy/pytorch/models/qwen3_vl.py
@@ -18,7 +18,7 @@
 from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3VLVisionAttention
 from .qwen3 import Qwen3model
 from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model


 class Qwen3VLTextRotaryEmbedding(nn.Module):
@@ -443,7 +443,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_
         return hidden_states, deepstack_feature_lists


-class Qwen3VLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
+class Qwen3VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""

     packed_modules_mapping = {
@@ -481,11 +481,11 @@ def __init__(self,
         self.language_model = Qwen3VLTextModel(config.text_config, dtype=dtype, device=device)

         # build lm_head
-        self.lm_head = build_rowwise_linear(config.text_config.hidden_size,
-                                            config.text_config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.text_config.hidden_size,
+                                          config.text_config.vocab_size,
+                                          bias=False,
+                                          dtype=dtype,
+                                          device=device)

     def forward(
         self,
@@ -545,15 +545,6 @@ def forward(
         )
         return hidden_states

-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.language_model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.language_model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/utils/model.py b/lmdeploy/pytorch/models/utils/model.py
index 0b7938db49..03f3b6799b 100644
--- a/lmdeploy/pytorch/models/utils/model.py
+++ b/lmdeploy/pytorch/models/utils/model.py
@@ -6,6 +6,9 @@

 from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor
 from lmdeploy.pytorch.model_inputs import StepContext
+from lmdeploy.pytorch.nn.linear import build_rowwise_linear
+
+from ..patch import get_build_model_context


 class DeployModelMixin:
@@ -51,6 +54,52 @@ def get_input_processor(self) -> BaseModelInputProcessor:
         return None


+class DeployModelMixinV1(DeployModelMixin):
+
+    def get_logits(self, hidden_states: torch.Tensor):
+        """Compute logits of the model output."""
+        if hidden_states.dtype != self.get_lm_head().weight.dtype:
+            hidden_states = hidden_states.to(dtype=self.get_lm_head().weight.dtype)
+        hidden_states = self.get_lm_head()(hidden_states)
+        return hidden_states
+
+    def get_lm_head(self):
+        """Get lm_head."""
+        return self.lm_head
+
+    def get_input_embeddings(self):
+        """Get embeds."""
+        raise NotImplementedError('Not Implemented')
+
+    def update_weights(self):
+        """Update weights."""
+        if getattr(self.config, 'tie_word_embeddings', False):
+            if getattr(self.config, 'enforce_fp32_head', False) is True:
+                self.get_input_embeddings().update_weight_dtype(torch.float32)
+            self.get_lm_head().weight = self.get_input_embeddings().weight
+
+    def build_lm_head(self,
+                      hidden_size: int,
+                      vocab_size: int,
+                      bias: bool = False,
+                      dtype: Optional[torch.dtype] = None,
+                      device: Optional[torch.device] = None,
+                      **kwargs):
+        """Build LM Head."""
+        bm_ctx = get_build_model_context()
+        self.config.enforce_fp32_head = bm_ctx.enforce_fp32_head
+        dtype = torch.float32 if self.config.enforce_fp32_head else dtype
+        lm_head = build_rowwise_linear(
+            hidden_size,
+            vocab_size,
+            bias,
+            dtype=dtype,
+            device=device,
+            **kwargs,
+        )
+        return lm_head
+
+
 def vlm_model(vlm_cls):
     if not issubclass(vlm_cls, torch.nn.Module):
         raise ValueError('Only subclasses of nn.Module can be decorated with @vlm_model.')
diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py
index 96cbee873a..8b04074c07 100644
--- a/lmdeploy/pytorch/nn/__init__.py
+++ b/lmdeploy/pytorch/nn/__init__.py
@@ -3,6 +3,7 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/attention/
 from .activation import GeluAndMul, SiluAndMul  # noqa: F401
 from .attention import Attention, FlashAttention  # noqa: F401
+from .emb import Embedding  # noqa: F401
 from .norm import LayerNorm, RMSNorm  # noqa: F401
 from .rotary_embedding import ApplyRotaryEmb  # noqa: F401
 from .rotary_embedding import RopeType  # noqa: F401
diff --git a/lmdeploy/pytorch/nn/emb.py b/lmdeploy/pytorch/nn/emb.py
new file mode 100644
index 0000000000..9ef72ba9c3
--- /dev/null
+++ b/lmdeploy/pytorch/nn/emb.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+
+class Embedding(nn.Embedding):
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: Optional[int] = None,
+                 device=None,
+                 dtype=None,
+                 **kwargs):
+        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, dtype=dtype, device=device, **kwargs)
+        self.orig_dtype = self.weight.dtype
+
+    def update_weight_dtype(self, dtype):
+        """Update weight dtype."""
+        if self.weight.dtype != dtype:
+            self.weight.data = self.weight.data.to(dtype=dtype)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """Forward."""
+        out = super().forward(input)
+        if out.dtype != self.orig_dtype:
+            out = out.to(dtype=self.orig_dtype)
+        return out
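The new flag is surfaced through PytorchEngineConfig, so it can be enabled directly from user code. Below is a minimal usage sketch, assuming the standard lmdeploy.pipeline entry point; the model path is only an example, and enforce_fp32_head is the single parameter introduced by this change.

# Minimal sketch: enable the fp32 lm_head path added above.
# Assumes the standard `lmdeploy.pipeline` API; the model path is an example.
from lmdeploy import PytorchEngineConfig, pipeline

backend_config = PytorchEngineConfig(
    dtype='auto',
    enforce_fp32_head=True,  # lm_head is built in fp32; get_logits casts hidden states to fp32 before the head
)
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)
print(pipe(['Hello, world!']))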