2 changes: 2 additions & 0 deletions lmdeploy/messages.py
@@ -350,6 +350,7 @@ class PytorchEngineConfig:
         dllm_denoising_steps (int): Dllm denoising steps.
         dllm_confidence_threshold (float): dllm unmasking threshold for
             dynamic unmasking.
+        enforce_fp32_head (bool): Enforce lm_head to use fp32 in forward.
     """
     dtype: str = 'auto'
     tp: int = 1
@@ -387,6 +388,7 @@ class PytorchEngineConfig:
     hf_overrides: Optional[Dict[str, Any]] = None
     disable_vision_encoder: bool = False
     logprobs_mode: str = None
+    enforce_fp32_head: bool = False
     # router replay
     enable_return_routed_experts: bool = False
     enable_transfer_obj_ref: bool = False
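
The new option is exposed on the user-facing PytorchEngineConfig, so it can be switched on when building a pipeline. A minimal usage sketch (the model path below is a placeholder, not part of this PR):

from lmdeploy import pipeline, PytorchEngineConfig

# enforce_fp32_head=True asks the PyTorch engine to run the lm_head in fp32.
backend_config = PytorchEngineConfig(enforce_fp32_head=True)
pipe = pipeline('internlm/internlm2_5-7b-chat',  # placeholder model path
                backend_config=backend_config)
print(pipe(['Hello, world!']))
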
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/config.py
@@ -435,6 +435,7 @@ class MiscConfig:
     logprobs_mode: str = None
     dllm_config: DLLMConfig = None
     enable_return_routed_experts: bool = False
+    enforce_fp32_head: bool = False
 
     @classmethod
     def from_engine_config(cls, engine_config: PytorchEngineConfig):
@@ -454,6 +455,7 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig):
             logprobs_mode=engine_config.logprobs_mode,
             dllm_config=dllm_config,
             enable_return_routed_experts=engine_config.enable_return_routed_experts,
+            enforce_fp32_head=engine_config.enforce_fp32_head,
         )
         return misc_config
 
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -1019,6 +1019,7 @@ def _build_model(self):
             dllm_config=self.misc_config.dllm_config,
             strategy_factory=self.strategy_factory,
             enable_return_routed_experts=enable_return_routed_experts,
+            enforce_fp32_head=self.misc_config.enforce_fp32_head,
         )
         patched_model = build_patched_model(self.model_config,
                                             device=device,
1 change: 1 addition & 0 deletions lmdeploy/pytorch/model_inputs.py
@@ -492,6 +492,7 @@ class BuildModelContext:
     dllm_config: DLLMConfig = None
     strategy_factory: 'StrategyFactoryBase' = None
     enable_return_routed_experts: bool = False
+    enforce_fp32_head: bool = False
 
 
 class StepContextManager:
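
With the pieces above in place, the flag flows from the engine config through MiscConfig into the BuildModelContext that model construction sees. A simplified sketch of that plumbing (the BuildModelContext call is abbreviated; in the code it happens inside model_agent._build_model):

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.config import MiscConfig

engine_config = PytorchEngineConfig(enforce_fp32_head=True)
misc_config = MiscConfig.from_engine_config(engine_config)
assert misc_config.enforce_fp32_head is True
# model_agent._build_model then forwards it into the build context:
#   BuildModelContext(..., enforce_fp32_head=self.misc_config.enforce_fp32_head)
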
31 changes: 14 additions & 17 deletions lmdeploy/pytorch/models/internlm2.py
@@ -7,12 +7,13 @@
 from transformers.configuration_utils import PretrainedConfig
 
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul,
+                                 build_rotary_embedding_from_config)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1
 
 
 class InternLM2Attention(nn.Module):
@@ -208,11 +209,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.tok_embeddings = nn.Embedding(config.vocab_size,
-                                           config.hidden_size,
-                                           self.padding_idx,
-                                           dtype=dtype,
-                                           device=device)
+        self.tok_embeddings = Embedding(config.vocab_size,
+                                        config.hidden_size,
+                                        self.padding_idx,
+                                        dtype=dtype,
+                                        device=device)
 
         # build all decode layers
         self.layers = nn.ModuleList([
@@ -269,7 +270,7 @@ def get_input_embeddings(self):
         return self.tok_embeddings
 
 
-class InternLM2ForCausalLM(nn.Module, CudaGraphMixin):
+class InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """Rewrote model of InternLM2ForCausalLM."""
 
     packed_modules_mapping = {
@@ -290,11 +291,7 @@ def __init__(self,
         # build Model
         self.model = InternLM2Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.output = build_rowwise_linear(config.hidden_size,
-                                           config.vocab_size,
-                                           bias=False,
-                                           dtype=dtype,
-                                           device=device)
+        self.output = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -315,9 +312,9 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.output(hidden_states)
+    def get_lm_head(self):
+        """Get lm_head."""
+        return self.output
 
     def get_input_embeddings(self):
         """Get input embeddings."""
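
These models now build the output projection through DeployModelMixinV1.build_lm_head() and expose it via get_lm_head() instead of implementing get_logits() themselves. The mixin itself (lmdeploy/pytorch/models/utils/model.py) is not part of this diff; the sketch below is only an assumed outline of the contract these call sites rely on, including how enforce_fp32_head from the build context might be honored:

# Illustrative only: an assumed shape of DeployModelMixinV1, which is not
# shown in this diff. Anything beyond the build_lm_head/get_lm_head/get_logits
# names is a guess.
import torch
from lmdeploy.pytorch.nn.linear import build_rowwise_linear


class DeployModelMixinV1Sketch:

    def build_lm_head(self, hidden_size, vocab_size, bias=False, dtype=None, device=None):
        """Build the output projection, optionally forcing fp32 (assumption)."""
        if getattr(self, 'enforce_fp32_head', False):  # assumed to originate from BuildModelContext
            dtype = torch.float32
        return build_rowwise_linear(hidden_size, vocab_size, bias=bias, dtype=dtype, device=device)

    def get_lm_head(self):
        """Default accessor; models with a differently named head override it."""
        return self.lm_head

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits through whatever module get_lm_head() returns."""
        return self.get_lm_head()(hidden_states)
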
34 changes: 11 additions & 23 deletions lmdeploy/pytorch/models/internlm3.py
@@ -7,12 +7,13 @@
 from transformers.configuration_utils import PretrainedConfig
 
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul,
+                                 build_rotary_embedding_from_config)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1
 
 
 class InternLM3Attention(nn.Module):
@@ -210,11 +211,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = Embedding(config.vocab_size,
+                                      config.hidden_size,
+                                      self.padding_idx,
+                                      dtype=dtype,
+                                      device=device)
 
         # build all decode layers
         self.layers = nn.ModuleList([
@@ -271,7 +272,7 @@ def get_input_embeddings(self):
         return self.embed_tokens
 
 
-class InternLM3ForCausalLM(nn.Module, CudaGraphMixin):
+class InternLM3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """Rewrote model of InternLM3ForCausalLM."""
 
     packed_modules_mapping = {
@@ -297,11 +298,7 @@ def __init__(self,
         # build InternLM3Model
         self.model = InternLM3Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -322,15 +319,6 @@ def forward(
         )
         return hidden_states
 
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
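
For InternLM3 (and Phi-3 below) the per-model update_weights() hook that tied lm_head.weight to embed_tokens.weight is removed along with get_logits(). The diff does not show where that responsibility moves; a plausible guess is a shared default on DeployModelMixinV1, roughly like the hedged sketch below (same logic as the deleted code, just centralized):

# Illustrative guess, not shown in this diff: a shared default that replaces
# the deleted per-model update_weights() overrides.
class TiedHeadMixinSketch:

    def update_weights(self):
        """Tie the output head to the input embedding when the config asks for it."""
        if getattr(self.config, 'tie_word_embeddings', False):
            self.get_lm_head().weight = self.get_input_embeddings().weight
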
10 changes: 5 additions & 5 deletions lmdeploy/pytorch/models/internvl.py
@@ -19,7 +19,7 @@
 
 from .patch import build_model_from_hf_config
 from .utils.cudagraph import CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model
 
 
 class Gating(nn.Module):
@@ -444,7 +444,7 @@ def forward(
         return last_hidden_state
 
 
-class InternVLChatModel(nn.Module, DeployModelMixin, CudaGraphMixin):
+class InternVLChatModel(nn.Module, DeployModelMixinV1, CudaGraphMixin):
 
     def __init__(self,
                  config: PretrainedConfig,
@@ -801,9 +801,9 @@ def forward(
                                           position_ids=position_ids,
                                           attn_metadata=attn_metadata)
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.language_model.get_logits(hidden_states)
+    def get_lm_head(self):
+        """Get lm_head."""
+        return self.language_model.get_lm_head()
 
     def get_input_embeddings(self):
         """Get input embeddings."""
10 changes: 5 additions & 5 deletions lmdeploy/pytorch/models/internvl3_hf.py
@@ -20,7 +20,7 @@
 
 from .patch import build_model_from_hf_config
 from .utils.cudagraph import CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model
 
 
 @torch.compile(dynamic=True)
@@ -439,7 +439,7 @@ def forward(self, image_features):
         return hidden_states
 
 
-class InternVLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
+class InternVLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
 
     def __init__(self,
                  config: PretrainedConfig,
@@ -485,9 +485,9 @@ def _mark_dynamic_once(self, pixel_values, dims):
         torch._dynamo.mark_dynamic(pixel_values, dims)
         self.has_compiled_vit = True
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.language_model.get_logits(hidden_states)
+    def get_lm_head(self):
+        """Get lm_head."""
+        return self.language_model.get_lm_head()
 
     def get_input_embeddings(self):
         """Get input embeddings."""
33 changes: 10 additions & 23 deletions lmdeploy/pytorch/models/phi3.py
@@ -7,13 +7,13 @@
 from transformers.configuration_utils import PretrainedConfig
 
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, Embedding, RMSNorm, SiluAndMul
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.nn.rotary_embedding import build_rotary_embedding_from_config
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1
 
 
 class Phi3Attention(nn.Module):
@@ -211,11 +211,11 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = Embedding(config.vocab_size,
+                                      config.hidden_size,
+                                      self.padding_idx,
+                                      dtype=dtype,
+                                      device=device)
 
         # build all decode layers
         self.layers = nn.ModuleList([
@@ -272,7 +272,7 @@ def get_input_embeddings(self):
         return self.embed_tokens
 
 
-class Phi3ForCausalLM(nn.Module, CudaGraphMixin):
+class Phi3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -293,11 +293,7 @@ def __init__(self,
         # build model
         self.model = Phi3Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -318,15 +314,6 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
11 changes: 3 additions & 8 deletions lmdeploy/pytorch/models/phi3_v.py
@@ -9,11 +9,10 @@
 from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
-from lmdeploy.pytorch.nn.linear import build_rowwise_linear
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .phi3 import Phi3ForCausalLM, Phi3Model
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import vlm_model
 
 CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(attention_dropout=0.0,
                                                      dropout=0.0,
@@ -264,7 +263,7 @@ def forward(
         )
 
 
-class Phi3VForCausalLM(Phi3ForCausalLM, DeployModelMixin):
+class Phi3VForCausalLM(Phi3ForCausalLM):
 
     def __init__(self,
                  config: PretrainedConfig,
@@ -277,11 +276,7 @@ def __init__(self,
         # build model
         self.model = Phi3VModel(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
         self.input_processor = Phi3VInputProcessor(config, dtype)
 