diff --git a/dreadnode/scorers/judge.py b/dreadnode/scorers/judge.py
index a323752f..e5824c16 100644
--- a/dreadnode/scorers/judge.py
+++ b/dreadnode/scorers/judge.py
@@ -1,6 +1,7 @@
 import typing as t
 
 import rigging as rg
+from loguru import logger
 
 from dreadnode.common_types import AnyDict
 from dreadnode.meta import Config
@@ -42,6 +43,7 @@ def llm_judge(
     input: t.Any | None = None,
     expected_output: t.Any | None = None,
     model_params: rg.GenerateParams | AnyDict | None = None,
+    fallback_model: str | rg.Generator | None = None,
     passing: t.Callable[[float], bool] | None = None,
     min_score: float | None = None,
     max_score: float | None = None,
@@ -57,6 +59,7 @@
         input: The input which produced the output for context, if applicable.
         expected_output: The expected output to compare against, if applicable.
         model_params: Optional parameters for the model.
+        fallback_model: Optional fallback model to use if the primary model fails.
         passing: Optional callback to determine if the score is passing based on the score value - overrides any model-specified value.
         min_score: Optional minimum score for the judgement - if provided, the score will be clamped to this value.
         max_score: Optional maximum score for the judgement - if provided, the score will be clamped to this value.
@@ -74,25 +77,31 @@ async def evaluate(
         input: t.Any | None = input,
         expected_output: t.Any | None = expected_output,
         model_params: rg.GenerateParams | AnyDict | None = model_params,
+        fallback_model: str | rg.Generator | None = fallback_model,
         min_score: float | None = min_score,
         max_score: float | None = max_score,
         system_prompt: str | None = system_prompt,
     ) -> list[Metric]:
-        generator: rg.Generator
-        if isinstance(model, str):
-            generator = rg.get_generator(
-                model,
-                params=model_params
-                if isinstance(model_params, rg.GenerateParams)
-                else rg.GenerateParams.model_validate(model_params)
-                if model_params
-                else None,
-            )
-        elif isinstance(model, rg.Generator):
-            generator = model
-        else:
+        def _create_generator(
+            model: str | rg.Generator,
+            params: rg.GenerateParams | AnyDict | None,
+        ) -> rg.Generator:
+            """Create a Generator from a model identifier or return the Generator instance."""
+            if isinstance(model, str):
+                return rg.get_generator(
+                    model,
+                    params=params
+                    if isinstance(params, rg.GenerateParams)
+                    else rg.GenerateParams.model_validate(params)
+                    if params
+                    else None,
+                )
+            if isinstance(model, rg.Generator):
+                return model
             raise TypeError("Model must be a string identifier or a Generator instance.")
 
+        generator = _create_generator(model, model_params)
+
         input_data = JudgeInput(
             input=str(input) if input is not None else None,
             expected_output=str(expected_output) if expected_output is not None else None,
@@ -100,10 +109,36 @@ async def evaluate(
             rubric=rubric,
         )
 
-        pipeline = generator.chat([])
-        if system_prompt:
-            pipeline.chat.inject_system_content(system_prompt)
-        judgement = await judge.bind(pipeline)(input_data)
+        # Track fallback usage for observability
+        used_fallback = False
+        primary_error: str | None = None
+
+        # Try primary model, fallback if needed
+        try:
+            pipeline = generator.chat([])
+            if system_prompt:
+                pipeline.chat.inject_system_content(system_prompt)
+            judgement = await judge.bind(pipeline)(input_data)
+        except Exception as e:
+            if fallback_model is None:
+                raise
+            # Log primary model failure and fallback usage
+            used_fallback = True
+            primary_error = f"{type(e).__name__}: {e}"
+            primary_model_name = model if isinstance(model, str) else type(model).__name__
+            fallback_model_name = (
+                fallback_model if isinstance(fallback_model, str) else type(fallback_model).__name__
+            )
+            logger.warning(
+                f"Primary model '{primary_model_name}' failed with {primary_error}. "
+                f"Using fallback model '{fallback_model_name}'."
+            )
+            # Use fallback model
+            generator = _create_generator(fallback_model, model_params)
+            pipeline = generator.chat([])
+            if system_prompt:
+                pipeline.chat.inject_system_content(system_prompt)
+            judgement = await judge.bind(pipeline)(input_data)
 
         if min_score is not None:
             judgement.score = max(min_score, judgement.score)
@@ -117,6 +152,15 @@ async def evaluate(
             value=judgement.score,
             attributes={
                 "reason": judgement.reason,
+                "used_fallback": used_fallback,
+                "fallback_model": (
+                    str(fallback_model)
+                    if isinstance(fallback_model, str)
+                    else type(fallback_model).__name__
+                )
+                if used_fallback
+                else None,
+                "primary_error": primary_error,
             },
         )
         pass_metric = Metric(value=float(judgement.passing))
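
Usage note (not part of the patch): a minimal sketch of how the new `fallback_model` argument might be passed to `llm_judge`. It assumes `rubric` is accepted as a keyword argument and that the model identifier strings shown are valid inputs for `rg.get_generator` in your environment.

```python
# Sketch only: the model identifiers and the keyword use of `rubric` are assumptions,
# not confirmed by the diff above.
from dreadnode.scorers.judge import llm_judge

scorer = llm_judge(
    "openai/gpt-4o-mini",            # primary judge model (assumed identifier)
    rubric="Score 1.0 if the output fully satisfies the request, 0.0 otherwise.",
    fallback_model="openai/gpt-4o",  # only used if the primary model raises
    min_score=0.0,
    max_score=1.0,
)
```

When the fallback path is taken, the score metric's attributes record `used_fallback`, `fallback_model`, and `primary_error`, as shown in the final hunk of the diff.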