-
Notifications
You must be signed in to change notification settings - Fork 466
Implement monitor rubrics #653
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
58c0e48
84f0bbd
34324dc
dbf7a23
1fd31f4
8510a1a
21f75a0
4d26051
edddce4
099e025
14b566f
f85bdaa
e07e029
4be2f68
6a54dfe
573d68f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,10 +4,57 @@ | |
| from openai.types.chat import ChatCompletionAssistantMessageParam | ||
|
|
||
| import verifiers as vf | ||
| from verifiers.types import Messages | ||
| from verifiers.utils.async_utils import maybe_await | ||
| from verifiers.utils.tool_utils import convert_func_to_oai_tool | ||
|
|
||
|
|
||
class ToolMonitorRubric(vf.Rubric):
    """Rubric that reports tool-usage metrics for a completion.

    Registers one metric counting every tool call made by assistant messages,
    plus one metric per known tool counting the calls addressed to that tool
    by name.
    """

    def __init__(self, tools: list[Callable] | None = None, **kwargs):
        """Register the tool-call metrics.

        Args:
            tools: Callables whose ``__name__`` identifies the tool to count.
                ``None`` is treated as an empty list.
            **kwargs: Forwarded unchanged to the parent rubric constructor.
        """
        super().__init__(**kwargs)

        self.tools = tools or []
        self.tool_names = [tool.__name__ for tool in self.tools]  # type: ignore[union-attr]

        # add tool metrics
        self.add_metric(self.total_tool_calls)
        for tool_name in self.tool_names:
            self.add_metric(self.get_tool_call_count_func(tool_name))

    @staticmethod
    def _assistant_tool_calls(completion: Messages) -> list:
        """Collect, in order, the tool calls of every assistant message.

        Shared by all metrics so they agree on which messages count: only
        messages with role ``"assistant"`` that carry a ``"tool_calls"`` key
        whose value is a list. Any other value (e.g. ``None``) contributes
        no calls instead of raising.
        """
        assert isinstance(completion, list)
        calls: list = []
        for msg in completion:
            if msg["role"] == "assistant" and "tool_calls" in msg:
                assistant_msg = cast(ChatCompletionAssistantMessageParam, msg)  # type: ignore[redundant-cast]
                tool_calls = assistant_msg.get("tool_calls", [])
                if isinstance(tool_calls, list):
                    calls.extend(tool_calls)
        return calls

    async def total_tool_calls(self, completion: Messages) -> float:
        """Count the total number of tool calls."""
        return float(len(self._assistant_tool_calls(completion)))

    def get_tool_call_count_func(self, tool_name: str) -> Callable:
        """Create a metric that counts calls to a specific tool.

        Args:
            tool_name: Function name to match against each tool call's
                ``function.name`` entry.

        Returns:
            An async metric function named ``"{tool_name}_calls"`` that takes
            ``completion`` and returns the number of matching tool calls.
        """
        collect_tool_calls = self._assistant_tool_calls

        async def tool_call_count_func(completion: Messages) -> int:
            """Count calls to a specific tool."""
            return sum(
                1
                for tool_call in collect_tool_calls(completion)
                if tool_call.get("function", {}).get("name") == tool_name
            )

        tool_call_count_func.__name__ = f"{tool_name}_calls"
        # A docstring literal cannot interpolate, so the per-tool description
        # is attached after the function is created.
        tool_call_count_func.__doc__ = f"Count calls to {tool_name} tool."
        return tool_call_count_func
|
|
||
|
|
||
| class ToolEnv(vf.MultiTurnEnv): | ||
| def __init__( | ||
| self, | ||
|
|
@@ -28,6 +75,8 @@ def __init__( | |
| } | ||
| super().__init__(oai_tools=self.oai_tools, max_turns=max_turns, **kwargs) | ||
|
|
||
| self.add_rubric(ToolMonitorRubric(tools=self.tools)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Per-tool metrics are not tracked for dynamically added tools.
Additional Locations (1)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah, yea this is fair but i think this wasnt supported previously with the
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, we should probably fix this (can be a separate PR); it shouldn't be too hard. Previously we weren't using tool_rubric that heavily, and it already forced you to pass the list of tools, so it wasn't too big of a problem to just have the added tool in that list as well. |
||
|
|
||
| def _should_stop_for_error(self, err: Exception) -> bool: | ||
| """Check if error is in stop_errors.""" | ||
| return any(isinstance(err, err_type) for err_type in self.stop_errors) | ||
|
|
||
This file was deleted.
Uh oh!
There was an error while loading. Please reload this page.