
Conversation

@43758726 (Collaborator) commented Dec 25, 2025

Thanks for your contribution; we appreciate it a lot. The following instructions will help make your pull request healthier and easier to review. If you do not understand some items, don't worry; just create the pull request and ask the maintainers for help.

Motivation

The goal of this PR is to introduce online bf16-to-fp8 quantization. This allows users to load bf16 checkpoints and automatically quantize them to fp8 in-memory during the loading phase, enabling fp8 inference without the need for offline conversion.

Modification

  1. Configuration (lmdeploy/messages.py & lmdeploy/turbomind/deploy/converter.py):
    • Modified TurbomindEngineConfig: added use_quant and quant_config fields so quantization parameters can be passed in via the API.
    • Updated _process_quant_config: added logic to validate and parse the incoming FP8 configuration.
    • Added get_weight_config: moved dtype, group_size, weight_type and expert_weight_type out of get_output_model_registered_name_and_config into the new get_weight_config function.
    • Updated get_tm_model: added a check for whether online quantization is requested.
  2. Weight Loading (lmdeploy/turbomind/deploy/source_model/llama.py):
    • Added a quant_weight function to LlamaReader: it iterates through the loaded weights, selects the weights to be quantized, quantizes them, and exports the quantized weights and scales (a rough sketch is shown after this list).
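
For reference, a minimal sketch of what this online quantization step could look like. The quant_blocked_fp8 call and the proj / scale_inv naming come from the diff excerpts quoted in the review below; the exact export key layout is an assumption:

import torch

from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8


def quant_weight(self):
    """Quantize eligible bf16 projection weights to blocked fp8 (sketch)."""
    quant_params = {}
    for name, weight in self.params.items():
        # only 2-D projection weights are quantized; norms, biases and
        # embeddings keep their original dtype
        if 'proj' not in name or weight.dim() != 2:
            continue
        q_weight, scale = quant_blocked_fp8(weight, torch.float8_e4m3fn, block_size=128)
        quant_params[name] = q_weight
        # hypothetical key naming for the exported per-block scales
        quant_params[name.replace('weight', 'weight_scale_inv')] = scale
    return quant_params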

Use cases (Optional)

Python API:
Users can now initialize the pipeline with a BF16 model path but specify FP8 quantization in the config:

from lmdeploy import pipeline, TurbomindEngineConfig

# Load a BF16 model and run it with FP8 inference
backend_config = TurbomindEngineConfig(
    cache_max_entry_count=0.2,
    model_format='fp8'
)
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)
response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
print(response)
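
The new fields can also be set explicitly. The following is a sketch based on the fields and defaults described in this PR; whether model_format also needs to be set is not covered by this excerpt, so it is omitted:

from lmdeploy import pipeline, TurbomindEngineConfig

# request online bf16 -> fp8 quantization via the new fields
backend_config = TurbomindEngineConfig(
    cache_max_entry_count=0.2,
    use_quant=True,
    quant_config=dict(
        quant_method='fp8',            # required; only "fp8" is supported
        activation_scheme='dynamic',   # optional, defaults to "dynamic"
        fmt='e4m3',                    # optional, defaults to "e4m3"
        weight_block_size=[128, 128],  # optional, defaults to [128, 128]
    ),
)
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)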

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

Copilot AI (Contributor) left a comment

Pull request overview

This PR introduces online BF16-to-FP8 quantization capability for TurboMind, allowing users to load BF16 model checkpoints and automatically quantize them to FP8 format during the loading phase without requiring offline conversion.

Key Changes:

  • Added use_quant and quant_config fields to TurbomindEngineConfig to enable FP8 quantization configuration via API
  • Implemented build_quant_config() function to process and normalize quantization parameters
  • Added quant_weight() method in LlamaReader to perform online weight quantization with regex-based weight selection

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 14 comments.

  • lmdeploy/messages.py: Extended TurbomindEngineConfig with quantization parameters and added validation logic in __post_init__
  • lmdeploy/turbomind/deploy/converter.py: Added build_quant_config() function and modified get_tm_model() to handle the quantization config flow
  • lmdeploy/turbomind/deploy/source_model/llama.py: Implemented quant_weight() method in LlamaReader to quantize weights to FP8 format and updated the constructor to support quantization


shares a single scaling factor. Value: int or list/tuple (default: [128, 128]).
scale_fmt (Optional): The quantization strategy used in DeepSeekV3.1 is based on "ue8m0".
The principle is to discard the mantissa of the original quantization scale and only
keep the exponent. It must set to "ue8m0" if used.
Copilot AI Dec 25, 2025

Grammatical error in documentation. "It must set to" should be "It must be set to" for correct grammar.

Suggested change
keep the exponent. It must set to "ue8m0" if used.
keep the exponent. It must be set to "ue8m0" if used.
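
As an aside, a minimal sketch of the ue8m0 idea described in the hunk above, assuming the scale is rounded up to the nearest power of two (the exact rounding rule used by DeepSeekV3.1 is an assumption here):

import torch

def to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    # keep only the exponent of each quantization scale (power-of-two scales)
    return torch.exp2(torch.ceil(torch.log2(scale)))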

Comment on lines 27 to 28
proj_pattern = r'proj'
scale_inv_prefix = r'scale_inv'
Copilot AI Dec 25, 2025

The variables 'proj_pattern' and 'scale_inv_prefix' are defined with raw string prefix (r'') but they don't contain any regex special characters that need escaping. These are just literal strings used for concatenation and naming. The 'r' prefix is unnecessary and could be misleading, suggesting they are regex patterns when they're just string literals.

Suggested change
proj_pattern = r'proj'
scale_inv_prefix = r'scale_inv'
proj_pattern = 'proj'
scale_inv_prefix = 'scale_inv'

Copilot uses AI. Check for mistakes.
Comment on lines 114 to 122
"""Build quant config from config in Engine Config.
Args:
src_cfg (dict): The source quantization configuration dict,
typically derived from TurbomindEngineConfig or
PytorchEngineConfig
Return:
The processed and validated quantization configuration.
Copilot AI Dec 25, 2025

The docstring incorrectly states that build_quant_config is "added build quant config". This is redundant and unclear. Consider revising to something like "Builds and validates the quantization configuration from engine config parameters."

Suggested change
"""Build quant config from config in Engine Config.
Args:
src_cfg (dict): The source quantization configuration dict,
typically derived from TurbomindEngineConfig or
PytorchEngineConfig
Return:
The processed and validated quantization configuration.
"""Builds and validates the quantization configuration from engine config parameters.
Args:
src_cfg (dict): Source configuration dict containing quantization-related
fields, typically derived from TurbomindEngineConfig or
PytorchEngineConfig.
Returns:
dict: The processed and validated quantization configuration.

Comment on lines 113 to 143
def build_quant_config(src_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Build quant config from config in Engine Config.
    Args:
        src_cfg (dict): The source quantization configuration dict,
            typically derived from TurbomindEngineConfig or
            PytorchEngineConfig
    Return:
        The processed and validated quantization configuration.
    """
    keys = ['quant_method', 'activation_scheme', 'fmt', 'weight_block_size', 'scale_fmt']
    config = {k: src_cfg.get(k) for k in keys}

    config['activation_scheme'] = config['activation_scheme'] or 'dynamic'

    # special process for fp8 quant
    if config['quant_method'] == 'fp8':
        config['fmt'] = config['fmt'] or 'e4m3'

        match config['weight_block_size']:
            case None:
                config['weight_block_size'] = [128, 128]
            case int(val):
                config['weight_block_size'] = [val, val]
            case [x, y] | (x, y):
                config['weight_block_size'] = [x, y]
            case _:
                raise ValueError('invalid quant_config["weight_block_size"] datatype')

    return config
Copilot AI Dec 25, 2025

Missing required field validation. According to the documentation in messages.py, quant_config is "Required if use_quant is True" and quant_method is marked as "Required" within quant_config. However, build_quant_config doesn't validate that 'quant_method' is present in src_cfg, leading to config['quant_method'] being None, which will cause the fp8-specific logic to be skipped silently.
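
One possible fix, as a sketch (assuming the check belongs near the top of build_quant_config):

    # sketch: fail fast instead of silently skipping the fp8-specific defaults
    if not config.get('quant_method'):
        raise ValueError('quant_config["quant_method"] is required, e.g. "fp8"')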

Comment on lines 301 to 313
        method = self.quant_config and self.quant_config.get('quant_method')
        if method is not None:
            assert method == 'fp8', (f'Unsupported quant method: "{method}". '
                                     f'Expected fp8')

        weight_bs = self.quant_config and self.quant_config.get('weight_block_size')
        assert weight_bs is None or isinstance(
            weight_bs, (int, list, tuple)), (f'Unsupported weight_block_size type: {type(weight_bs).__name__}. '
                                             f'Expected int, list, or tuple')

        scale_fmt = self.quant_config and self.quant_config.get('scale_fmt')
        assert scale_fmt is None or scale_fmt == 'ue8m0', (f'Unsupported scale_fmt: {scale_fmt}. '
                                                           f'Expected ue8m0')
Copilot AI Dec 25, 2025

The validation logic doesn't verify that use_quant=True when quant_config is provided. If a user accidentally sets use_quant=False but provides a quant_config, the configuration would be silently ignored, which could lead to confusion. Consider adding validation to ensure these fields are consistent.
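
A sketch of the kind of consistency check being suggested, assuming it sits next to the other assertions in __post_init__:

        # sketch: reject a quant_config that would otherwise be silently ignored
        assert not (self.quant_config and not self.use_quant), (
            'quant_config was provided but use_quant is False; set use_quant=True '
            'or remove quant_config')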

Comment on lines 314 to 315


Copilot AI Dec 25, 2025

The 'activation_scheme' field in quant_config is documented as optional with a default of 'dynamic', but there's no validation in __post_init__ to ensure that, if provided, it contains a valid value. Consider adding validation to check that activation_scheme is either None or a valid scheme value to catch configuration errors early.

Suggested change
        activation_scheme = self.quant_config and self.quant_config.get('activation_scheme')
        valid_activation_schemes = ('dynamic',)
        assert activation_scheme is None or activation_scheme in valid_activation_schemes, (
            f'Unsupported activation_scheme: {activation_scheme}. '
            f'Expected one of: {", ".join(valid_activation_schemes)}'
        )

raise ValueError(f'Unsupported FP8 format: {fmt}')

fp8_dtype = self.quant_dtype_map[fmt]
group_size = quant_cfg['weight_block_size'][0]
Copilot AI Dec 25, 2025

Missing input validation for 'group_size'. The code accesses quant_cfg['weight_block_size'][0] without checking if weight_block_size is a list/tuple or if it has at least one element. If weight_block_size is an int (which is valid according to the config), this will fail with a TypeError. The validation should occur before this access.

Suggested change
group_size = quant_cfg['weight_block_size'][0]
weight_block_size = quant_cfg.get('weight_block_size')
if isinstance(weight_block_size, (list, tuple)):
if not weight_block_size:
raise ValueError('quantization_config.weight_block_size must not be empty')
group_size = weight_block_size[0]
elif isinstance(weight_block_size, int):
group_size = weight_block_size
else:
raise TypeError(
'quantization_config.weight_block_size must be an int or a non-empty list/tuple of ints'
)

"""
_, cfg = get_model_arch(model_path)
quant_config = search_nested_config(cfg.to_dict(), 'quantization_config')
if engine_config.use_quant and not quant_config:
Copilot AI Dec 25, 2025

Potential AttributeError on a None quant_config. When engine_config.use_quant is True but engine_config.quant_config is None, calling build_quant_config(engine_config.quant_config) passes None as src_cfg, and the subsequent src_cfg.get(k) raises an AttributeError. The validation in __post_init__ doesn't enforce that quant_config must be provided when use_quant is True.

Suggested change
if engine_config.use_quant and not quant_config:
if engine_config.use_quant and not quant_config:
    if engine_config.quant_config is None:
        raise ValueError(
            'engine_config.quant_config must be provided when '
            'engine_config.use_quant is True and the model has no '
            'quantization_config.'
        )

quantization. Required if use_quant is True.
Keys in quant_config:
quant_method (Required): Specifies the quantization algorithm
to use.Must be set to "fp8" currently.
Copilot AI Dec 25, 2025

Missing space after period in documentation. There should be a space between "fp8" and "currently" for proper formatting.

Suggested change
to use.Must be set to "fp8" currently.
to use. Must be set to "fp8" currently.

(inputs) are quantized during inference. Default to "dynamic",
meaning scales are calculated in real time.
fmt (Optional): Defines the specific binary format for the FP8 weights.
"e4m3" provides higher precision, while "e5m2"provides a wider
Copilot AI Dec 25, 2025

Missing space after period in documentation. There should be a space between the closing quote and "provides" for proper formatting.

Suggested change
"e4m3" provides higher precision, while "e5m2"provides a wider
"e4m3" provides higher precision, while "e5m2" provides a wider


from lmdeploy.archs import get_model_arch

from ....lite.quantization.weight.quant_utils import quant_blocked_fp8
Collaborator

Don't use a relative import here. lite is a separate package outside turbomind/deploy. Please import it by its full path.
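
Resolving the relative path from turbomind/deploy/source_model, the absolute form would presumably be:

from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8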

input_model_name = get_input_model_registered_name(model_path, engine_config.model_format)
input_policy = get_input_policy(engine_config.model_format)

weight_config = get_weight_config(model_path=model_path,
Collaborator

weight_config is unnecessary

quant_params = self.quant_weight()
self.params.update(quant_params)

def quant_weight(self):
Collaborator

The function name is unclear. Consider def quantize_weight_fp8 instead.

Comment on lines 51 to 53
weight_cfg = self.model_cfg.get('weight_config')
assert weight_cfg['weight_type'] == 'fp8', (f"Unsupported weight_type: {weight_cfg['weight_type']}. "
                                            f'Expected fp8. ')
Collaborator

We can reduce some redundancy here by removing the check. The following q_weight computation can be updated to

q_weight, scale = quant_blocked_fp8(weight, torch.float8_e4m3fn, block_size=128)

since the TurboMind engine doesn't support other values of block_size.

"""Get the registered name of the turbomind model and its configuration
according to the input model path, format and user-input config. The name
will be used to access the OUTPUT_MODELS registry.
def get_weight_config(model_path: str, model_format: str, dtype: str, group_size: int) -> Dict[str, Any]:
Collaborator

Unnecessary code. Please refer to the comments in "source_model/llama.py"
