Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 29 additions & 16 deletions dreadnode/scorers/judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def llm_judge(
input: t.Any | None = None,
expected_output: t.Any | None = None,
model_params: rg.GenerateParams | AnyDict | None = None,
fallback_model: str | rg.Generator | None = None,
passing: t.Callable[[float], bool] | None = None,
min_score: float | None = None,
max_score: float | None = None,
Expand All @@ -56,12 +57,30 @@ def llm_judge(
input: The input which produced the output for context, if applicable.
expected_output: The expected output to compare against, if applicable.
model_params: Optional parameters for the model.
fallback_model: Optional fallback model to use if the primary model fails.
passing: Optional callback to determine if the score is passing based on the score value - overrides any model-specified value.
min_score: Optional minimum score for the judgement - if provided, the score will be clamped to this value.
max_score: Optional maximum score for the judgement - if provided, the score will be clamped to this value.
name: The name of the scorer.
"""

def _get_generator(
model_input: str | rg.Generator, params: rg.GenerateParams | AnyDict | None
) -> rg.Generator:
"""Helper to create a generator from model string or return existing generator."""
if isinstance(model_input, str):
return rg.get_generator(
model_input,
params=params
if isinstance(params, rg.GenerateParams)
else rg.GenerateParams.model_validate(params)
if params
else None,
)
if isinstance(model_input, rg.Generator):
return model_input
raise TypeError("Model must be a string identifier or a Generator instance.")

async def evaluate(
data: t.Any,
*,
Expand All @@ -72,32 +91,26 @@ async def evaluate(
input: t.Any | None = input,
expected_output: t.Any | None = expected_output,
model_params: rg.GenerateParams | AnyDict | None = model_params,
fallback_model: str | rg.Generator | None = fallback_model,
min_score: float | None = min_score,
max_score: float | None = max_score,
) -> list[Metric]:
generator: rg.Generator
if isinstance(model, str):
generator = rg.get_generator(
model,
params=model_params
if isinstance(model_params, rg.GenerateParams)
else rg.GenerateParams.model_validate(model_params)
if model_params
else None,
)
elif isinstance(model, rg.Generator):
generator = model
else:
raise TypeError("Model must be a string identifier or a Generator instance.")

input_data = JudgeInput(
input=str(input) if input is not None else None,
expected_output=str(expected_output) if expected_output is not None else None,
output=str(data),
rubric=rubric,
)

judgement = await judge.bind(generator)(input_data)
# Try primary model, fallback if needed
try:
generator = _get_generator(model, model_params)
judgement = await judge.bind(generator)(input_data)
except Exception:
if fallback_model is None:
raise
generator = _get_generator(fallback_model, model_params)
judgement = await judge.bind(generator)(input_data)

if min_score is not None:
judgement.score = max(min_score, judgement.score)
Expand Down
Loading