9 changes: 8 additions & 1 deletion lmdeploy/turbomind/deploy/converter.py
@@ -164,9 +164,16 @@ def get_tm_model(model_path,

     input_model_name = get_input_model_registered_name(model_path, engine_config.model_format)
     input_policy = get_input_policy(engine_config.model_format)
 
+    if engine_config.model_format == 'fp8' and not quant_config:
+        use_quant_online = True
+    else:
+        use_quant_online = False
+
     input_model = INPUT_MODELS.get(input_model_name)(model_path=model_path,
                                                      tokenizer_path=model_path,
-                                                     input_policy=input_policy)
+                                                     input_policy=input_policy,
+                                                     use_quant_online=use_quant_online)
 
     output_model_name, tm_cfg = \
         get_output_model_registered_name_and_config(
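With this change, `get_tm_model` enables online FP8 quantization whenever the engine is configured with `model_format='fp8'` and no offline quantization config is supplied. A minimal usage sketch, assuming `'fp8'` is an accepted `model_format` value on this branch (the model path and prompt are illustrative):

```python
# Sketch only: exercises the new use_quant_online path added above.
# Assumes this branch accepts model_format='fp8' in TurbomindEngineConfig.
from lmdeploy import TurbomindEngineConfig, pipeline

backend_config = TurbomindEngineConfig(model_format='fp8')  # no quant_config -> online FP8
pipe = pipeline('meta-llama/Meta-Llama-3-8B-Instruct',      # illustrative BF16 checkpoint
                backend_config=backend_config)
print(pipe('Hello, FP8 TurboMind!'))
```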
35 changes: 33 additions & 2 deletions lmdeploy/turbomind/deploy/source_model/llama.py
@@ -5,6 +5,7 @@
 import torch
 
 from lmdeploy.archs import get_model_arch
+from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8
 
 from ..config import RopeParam
 from ..loader import create_loader
@@ -23,7 +24,16 @@ class LlamaReader(BaseReader):
     attn_pattern = r'self_attn'
     ffn_pattern = r'mlp'
 
-    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, policy):
+    proj_pattern = 'proj'
+    scale_inv_prefix = 'scale_inv'
+
+    def __init__(self,
+                 new_params: dict,
+                 unused_params: dict,
+                 last_bin: bool,
+                 model_cfg: dict,
+                 policy,
+                 use_quant_online: bool = False):
         super().__init__()
         self.params = unused_params
         self.params.update(new_params)
@@ -33,6 +43,22 @@ def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_
         if tie_word_embeddings:
             self.output_weight_key = self.tok_embeddings_key
         self.processor = policy
+        if use_quant_online:
+            quant_params = self.quant_weight_fp8()
+            self.params.update(quant_params)
+
+    def quant_weight_fp8(self):
+        pattern_str = f'({self.attn_pattern}|{self.ffn_pattern}).*{self.proj_pattern}'
+        target_pattern = re.compile(pattern_str)
+
+        quant_params = {}
+        for name, weight in self.params.items():
+            if target_pattern.search(name):
+                q_weight, scale = quant_blocked_fp8(weight, torch.float8_e4m3fn, block_size=128)
+                quant_params[name] = q_weight
+                quant_params[f'{name}_{self.scale_inv_prefix}'] = scale.to(weight.dtype)
+
+        return quant_params

     def filter(self, pattern: str):
         params = []
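The name filter built in `quant_weight_fp8` combines the reader's `attn_pattern`, `ffn_pattern`, and the new `proj_pattern`, so only the attention and MLP projection weights are quantized. A quick, self-contained check against typical HF Llama tensor names (the names below are illustrative, not taken from this PR):

```python
import re

# Same pattern string that quant_weight_fp8 builds from the class attributes:
# attn_pattern='self_attn', ffn_pattern='mlp', proj_pattern='proj'.
pattern = re.compile(r'(self_attn|mlp).*proj')

names = [
    'model.layers.0.self_attn.q_proj.weight',  # matched -> quantized to FP8
    'model.layers.0.self_attn.o_proj.weight',  # matched -> quantized to FP8
    'model.layers.0.mlp.gate_proj.weight',     # matched -> quantized to FP8
    'model.layers.0.input_layernorm.weight',   # not matched -> left untouched
    'model.embed_tokens.weight',               # not matched -> left untouched
    'lm_head.weight',                          # not matched -> left untouched
]
for name in names:
    print(f'{name}: {bool(pattern.search(name))}')
```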
@@ -104,14 +130,19 @@ class LlamaModel(BaseInputModel):
     def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict):
         super().__init__(model_path, tokenizer_path)
         self.policy = kwargs.get('input_policy')
+        self.use_quant_online = kwargs.get('use_quant_online', False)
         _, self.model_config = get_model_arch(model_path)
         self.model_config = self.model_config.to_dict()
 
     def readers(self):
         mappings = getattr(self.Reader, 'mappings', [])
         loader = create_loader(self.model_path, self.Reader.attn_layer_patten, mappings)
         for i, param in loader.items():
-            reader = self.Reader(param, {}, False, self.model_config, policy=self.policy)
+            reader = self.Reader(param, {},
+                                 False,
+                                 self.model_config,
+                                 policy=self.policy,
+                                 use_quant_online=self.use_quant_online)
             yield i, reader
         torch.cuda.empty_cache()

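For reference, block-wise FP8 quantization stores one scale per `block_size x block_size` tile next to the FP8 tensor; the reader above writes that scale under the `<name>_scale_inv` key. Below is a self-contained sketch of the idea using a simple amax-based scale; the actual `quant_blocked_fp8` helper in `lmdeploy.lite` may differ in details such as padding, rounding, and scale dtype/layout:

```python
import torch

def blocked_fp8_quant_sketch(weight: torch.Tensor, block_size: int = 128):
    """Quantize a 2-D weight to float8_e4m3fn with one scale per
    (block_size x block_size) tile. Simplified stand-in for
    quant_blocked_fp8; edge blocks are handled naively."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    rows, cols = weight.shape
    n_row_blocks = (rows + block_size - 1) // block_size
    n_col_blocks = (cols + block_size - 1) // block_size
    q_weight = torch.empty_like(weight, dtype=torch.float8_e4m3fn)
    scale = torch.empty(n_row_blocks, n_col_blocks, dtype=torch.float32)

    for i in range(n_row_blocks):
        for j in range(n_col_blocks):
            rs, cs = i * block_size, j * block_size
            block = weight[rs:rs + block_size, cs:cs + block_size].float()
            s = block.abs().amax().clamp(min=1e-12) / fp8_max  # per-tile scale
            scale[i, j] = s
            q_weight[rs:rs + block_size, cs:cs + block_size] = (block / s).to(torch.float8_e4m3fn)
    return q_weight, scale

w = torch.randn(256, 512, dtype=torch.bfloat16)
qw, s = blocked_fp8_quant_sketch(w)
print(qw.dtype, s.shape)  # torch.float8_e4m3fn torch.Size([2, 4])
```

In this sketch a tile is recovered as `q_block.float() * scale[i, j]`, which is presumably why the per-block scale is kept alongside the quantized weight under the `_scale_inv` suffix.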