6 changes: 1 addition & 5 deletions environments/math_python/math_python.py
@@ -31,7 +31,7 @@ def load_environment(

parser = vf.Parser(extract_fn=extract_boxed_answer)
math_rubric = vf.MathRubric(parser=parser)
vf_env = vf.PythonEnv(
return vf.PythonEnv(
dataset=dataset,
system_prompt=system_prompt,
parser=parser,
@@ -50,7 +50,3 @@ def load_environment(
sandbox_client_max_workers=sandbox_client_max_workers,
**kwargs,
)
assert vf_env.tools is not None
tool_rubric = vf.ToolRubric(tools=vf_env.tools)
vf_env.rubric = vf.RubricGroup(rubrics=[tool_rubric, vf_env.rubric])
return vf_env
2 changes: 1 addition & 1 deletion tests/test_env_group.py
@@ -47,7 +47,7 @@ def func3(completion, **kwargs):

assert rubric.env_map == env_map
# Should have all unique reward function names
assert set(rubric.all_reward_names) == {"func1", "func2", "func3"}
assert set(rubric.all_reward_names) == {"num_turns", "func1", "func2", "func3"}

@pytest.mark.asyncio
async def test_env_group_rubric_score_rollout(self, mock_openai_client):
2 changes: 0 additions & 2 deletions verifiers/__init__.py
@@ -29,7 +29,6 @@
from .parsers.xml_parser import XMLParser
from .rubrics.judge_rubric import JudgeRubric
from .rubrics.rubric_group import RubricGroup
from .rubrics.tool_rubric import ToolRubric
from .utils.data_utils import (
extract_boxed_answer,
extract_hash_answer,
@@ -84,7 +83,6 @@ def setup_logging(
"Rubric",
"JudgeRubric",
"RubricGroup",
"ToolRubric",
"MathRubric",
"TextArenaEnv",
"ReasoningGymEnv",
8 changes: 8 additions & 0 deletions verifiers/envs/environment.py
@@ -1058,6 +1058,14 @@ def set_kwargs(self, **kwargs) -> None:
else:
setattr(self, key, value)

def add_rubric(self, rubric: Rubric) -> None:
if self.rubric is None:
self.rubric = rubric
elif isinstance(self.rubric, vf.RubricGroup):
self.rubric.rubrics.append(rubric)
else:
self.rubric = vf.RubricGroup(rubrics=[self.rubric, rubric])

def set_max_seq_len(self, max_seq_len: int | None) -> None:
"""Set the maximum sequence length for this environment."""
self.max_seq_len = max_seq_len
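For context, here is a minimal standalone sketch of the merging behavior `add_rubric` introduces (the `Rubric`, `RubricGroup`, and `Env` classes below are simplified stand-ins, not the real verifiers classes): the first rubric added becomes the environment's rubric, and later additions are folded into a single group.

```python
# Sketch only: simplified stand-ins for verifiers.Rubric / verifiers.RubricGroup.
class Rubric:
    """Stand-in for a rubric holding reward/metric functions."""

    def __init__(self, name: str):
        self.name = name


class RubricGroup(Rubric):
    """Stand-in for a rubric that aggregates several child rubrics."""

    def __init__(self, rubrics: list[Rubric]):
        super().__init__(name="group")
        self.rubrics = rubrics


class Env:
    def __init__(self):
        self.rubric: Rubric | None = None

    def add_rubric(self, rubric: Rubric) -> None:
        # Same branching as the new Environment.add_rubric: take the first rubric
        # as-is, append to an existing group, otherwise wrap both in a group.
        if self.rubric is None:
            self.rubric = rubric
        elif isinstance(self.rubric, RubricGroup):
            self.rubric.rubrics.append(rubric)
        else:
            self.rubric = RubricGroup(rubrics=[self.rubric, rubric])


env = Env()
env.add_rubric(Rubric("task_reward"))
env.add_rubric(Rubric("turn_monitor"))
env.add_rubric(Rubric("sandbox_monitor"))
assert isinstance(env.rubric, RubricGroup)
assert [r.name for r in env.rubric.rubrics] == ["task_reward", "turn_monitor", "sandbox_monitor"]
```

This mirrors how the monitor rubrics added in this PR (turn, sandbox, Python, and tool monitors) stack on top of whatever task rubric the environment already carries.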
11 changes: 11 additions & 0 deletions verifiers/envs/multiturn_env.py
@@ -23,11 +23,22 @@
logger = logging.getLogger(__name__)


class MultiTurnMonitorRubric(vf.Rubric):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_metric(self.num_turns)

async def num_turns(self, state: State) -> int:
return len(state["trajectory"])


class MultiTurnEnv(vf.Environment):
def __init__(self, max_turns: int = -1, **kwargs):
super().__init__(**kwargs)
self.max_turns = max_turns

self.add_rubric(MultiTurnMonitorRubric())

@abstractmethod
async def env_response(
self, messages: Messages, state: State, **kwargs
27 changes: 22 additions & 5 deletions verifiers/envs/python_env.py
@@ -17,6 +17,7 @@
class PythonWorkerState(TypedDict):
ready: bool
execution_count: int
ready_wait_time: float


class PythonWorkerNotReadyError(vf.SandboxError): ...
@@ -28,6 +29,15 @@ class PythonWorkerRequestError(vf.SandboxError): ...
class PythonWorkerDeadError(vf.SandboxError): ...


class PythonMonitorRubric(vf.Rubric):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_metric(self.python_ready_wait_time)

async def python_ready_wait_time(self, state: vf.State) -> float:
return state["python_state"]["ready_wait_time"]


class PythonEnv(SandboxEnv):
"""Sandbox-backed environment exposing a persistent Python REPL."""

@@ -189,6 +199,7 @@ def __init__(
start_command=start_command,
**kwargs,
)
self.add_rubric(PythonMonitorRubric())
self.add_tool(
self.python, args_to_skip=["sandbox_id", "sandbox_state", "python_state"]
)
@@ -199,6 +210,7 @@ async def setup_state(self, state: vf.State, **kwargs: Any) -> vf.State:
state["python_state"] = {
"ready": False,
"execution_count": 0,
"ready_wait_time": -1.0,
}
return state

@@ -229,7 +241,7 @@ async def python(
) -> str:
"""Execute `code` inside persistent Python REPL."""
if not python_state["ready"]:
await self._wait_for_worker_ready(sandbox_state, sandbox_id)
await self._wait_for_worker_ready(sandbox_id, sandbox_state, python_state)
python_state["ready"] = True
self.logger.debug(f"Executing code\n{code}")
sandbox_response = await self._send_worker_request(
@@ -242,7 +254,10 @@ async def cleanup_python_state(self, state: vf.State):
state.pop("python_state", None)

async def _wait_for_worker_ready(
self, sandbox_state: SandboxState, sandbox_id: str
self,
sandbox_id: str,
sandbox_state: SandboxState,
python_state: PythonWorkerState,
) -> None:
s = time.time()
try:
@@ -260,11 +275,13 @@ async def _wait_for_worker_ready(
)
if result.exit_code != 0:
raise RuntimeError(result.stderr)
self.logger.debug(
f"Waited {time.time() - s:.1f}s for Python worker to be ready"
)
except Exception as e:
raise PythonWorkerNotReadyError from e
ready_wait_time = time.time() - s
python_state["ready_wait_time"] = ready_wait_time
self.logger.debug(
f"Waited {ready_wait_time:.1f}s for Python worker to be ready"
)

async def _send_worker_request(
self,
42 changes: 37 additions & 5 deletions verifiers/envs/sandbox_env.py
@@ -89,6 +89,8 @@ def teardown(self, wait: bool = True) -> None:

class SandboxState(TypedDict):
ready: bool
ready_wait_time: float
command_execution_times: list[float]


class SandboxCreationError(vf.SandboxError): ...
@@ -97,6 +99,24 @@ class SandboxCreationError(vf.SandboxError): ...
class SandboxNotReadyError(vf.SandboxError): ...


class SandboxMonitorRubric(vf.Rubric):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_metric(self.sandbox_ready_wait_time)
self.add_metric(self.sandbox_command_execution_time)

async def sandbox_ready_wait_time(self, state: vf.State) -> float:
return state["sandbox_state"]["ready_wait_time"]

async def sandbox_command_execution_time(self, state: vf.State) -> float:
command_execution_times = state["sandbox_state"]["command_execution_times"]
return (
sum(command_execution_times) / len(command_execution_times)
if len(command_execution_times) > 0
else 0.0
)


class SandboxEnv(vf.StatefulToolEnv):
def __init__(
self,
@@ -127,6 +147,7 @@ def __init__(
stop_errors=stop_errors if stop_errors is not None else [vf.SandboxError],
**kwargs,
)
self.add_rubric(SandboxMonitorRubric())
self.timeout_per_command_seconds = timeout_per_command_seconds
self.sandbox_client = ThreadedAsyncSandboxClient(
max_workers=sandbox_client_max_workers,
@@ -173,7 +194,9 @@ async def _wait_for_sandbox_ready(
sandbox_state["ready"] = True
except Exception as e:
raise SandboxNotReadyError(e)
self.logger.debug(f"Waited {time.time() - s:.1f}s for sandbox to be ready")
ready_wait_time = time.time() - s
sandbox_state["ready_wait_time"] = ready_wait_time
self.logger.debug(f"Waited {ready_wait_time:.1f}s for sandbox to be ready")

async def bash(
self,
@@ -197,13 +220,16 @@ async def bash(
timeout=self.timeout_per_command_seconds,
)
except CommandTimeoutError:
e = time.time()
timeout_msg = f"Command timed out after {self.timeout_per_command_seconds}s"
self.logger.warning(f"{timeout_msg} in sandbox {sandbox_id}")
sandbox_state["command_execution_times"].append(
self.timeout_per_command_seconds
)
return f"Error: {timeout_msg}"
except Exception as e:
raise vf.SandboxError from e
e = time.time()
command_execution_time = time.time() - s
sandbox_state["command_execution_times"].append(command_execution_time)
stdout = results.stdout.strip()
stderr = (results.stderr or "").strip()
combined = stdout
@@ -213,7 +239,9 @@
else:
combined = f"stderr:\n{stderr}"
output = combined or "(no output)"
self.logger.debug(f"Executed command in {e - s:.1f}s. Got output: {output}")
self.logger.debug(
f"Executed command in {command_execution_time:.1f}s. Got output: {output}"
)
return output

async def post_rollout(self, state: vf.State):
@@ -252,7 +280,11 @@ async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
self.active_sandboxes.add(sandbox.id)
self.logger.debug(f"Created sandbox {sandbox.id}")
state["sandbox_id"] = sandbox.id
state["sandbox_state"] = {"ready": False}
state["sandbox_state"] = {
"ready": False,
"ready_wait_time": -1.0,
"command_execution_times": [],
}
return await super().setup_state(state, **kwargs)

def update_tool_args(
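As a rough, self-contained illustration of the timing instrumentation added in this file (the `new_sandbox_state` and `run_command` helpers and the bare-dict state are assumptions of the sketch, not the actual SandboxEnv API): command durations accumulate in the per-rollout sandbox state while commands run, and the monitor metric reads them back as an average.

```python
import asyncio
import time


def new_sandbox_state() -> dict:
    # Mirrors the fields the PR adds to SandboxState.
    return {"ready": False, "ready_wait_time": -1.0, "command_execution_times": []}


async def run_command(sandbox_state: dict, seconds: float) -> None:
    """Pretend to run a sandbox command and record how long it took."""
    start = time.time()
    await asyncio.sleep(seconds)  # stand-in for sandbox_client.execute_command(...)
    sandbox_state["command_execution_times"].append(time.time() - start)


async def sandbox_command_execution_time(state: dict) -> float:
    """Metric: mean command execution time, 0.0 if no commands ran."""
    times = state["sandbox_state"]["command_execution_times"]
    return sum(times) / len(times) if times else 0.0


async def main() -> None:
    state = {"sandbox_state": new_sandbox_state()}
    await run_command(state["sandbox_state"], 0.01)
    await run_command(state["sandbox_state"], 0.02)
    print(await sandbox_command_execution_time(state))  # roughly 0.015


asyncio.run(main())
```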
49 changes: 49 additions & 0 deletions verifiers/envs/tool_env.py
@@ -4,10 +4,57 @@
from openai.types.chat import ChatCompletionAssistantMessageParam

import verifiers as vf
from verifiers.types import Messages
from verifiers.utils.async_utils import maybe_await
from verifiers.utils.tool_utils import convert_func_to_oai_tool


class ToolMonitorRubric(vf.Rubric):
def __init__(self, tools: list[Callable] | None = None, **kwargs):
super().__init__(**kwargs)

self.tools = tools or []
self.tool_names = [tool.__name__ for tool in self.tools] # type: ignore[union-attr]

# add tool metrics
self.add_metric(self.total_tool_calls)
for tool_name in self.tool_names:
self.add_metric(self.get_tool_call_count_func(tool_name))

async def total_tool_calls(self, completion: Messages) -> float:
"""Count the total number of tool calls."""
total = 0
assert isinstance(completion, list)
for msg in completion:
if msg["role"] == "assistant" and "tool_calls" in msg:
assistant_msg = cast(ChatCompletionAssistantMessageParam, msg) # type: ignore[redundant-cast]
tool_calls = assistant_msg.get("tool_calls", [])
if isinstance(tool_calls, list):
total += len(tool_calls)
return float(total)

def get_tool_call_count_func(self, tool_name: str) -> Callable:
"""Create a metric that counts calls to a specific tool."""

async def tool_call_count_func(completion: Messages) -> int:
"""Count calls to {tool_name} tool."""
count = 0
# Find tool calls in assistant messages
assert isinstance(completion, list)
for msg in completion:
if msg["role"] == "assistant" and "tool_calls" in msg:
assistant_msg = cast(ChatCompletionAssistantMessageParam, msg) # type: ignore[redundant-cast]
tool_calls = assistant_msg.get("tool_calls", [])
for tool_call in tool_calls:
if tool_call.get("function", {}).get("name") == tool_name:
count += 1

return count

tool_call_count_func.__name__ = f"{tool_name}_calls"
return tool_call_count_func


class ToolEnv(vf.MultiTurnEnv):
def __init__(
self,
@@ -28,6 +75,8 @@ def __init__(
}
super().__init__(oai_tools=self.oai_tools, max_turns=max_turns, **kwargs)

self.add_rubric(ToolMonitorRubric(tools=self.tools))
Review comment:
Per-tool metrics not tracked for dynamically added tools

ToolMonitorRubric computes tool_names at initialization time from the tools list passed to it. However, ToolEnv.__init__ creates the rubric before subclasses like SandboxEnv and PythonEnv call add_tool() to register their tools (bash, python). Since tool_names is computed once at init as an empty list, per-tool metrics like python_calls and bash_calls are never registered, defeating the purpose of the {tool_name}_calls feature described in the PR.


Member Author:
ah, yeah, this is fair, but I think this wasn't supported previously with the ToolRubric? Adding would be easy to support; removing would be a bit harder because we would have to remove a reward func from a rubric, which we currently don't have support for.

Member:
ah, we should probably fix this (can be a separate PR); it shouldn't be too hard.

previously we weren't using tool_rubric that heavily, and it already forced you to pass the list of tools, so it wasn't too big of a problem to just have the added tool in that list as well
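To make the ordering issue above concrete, here is a standalone sketch (the `ToolMonitor`, `BaseToolEnv`, and `PythonLikeEnv` classes are hypothetical stand-ins, not the real ToolEnv/ToolMonitorRubric): the monitor snapshots the tool list when the base constructor runs, so a tool registered afterwards via `add_tool` never gets a per-tool counter.

```python
from typing import Callable


class ToolMonitor:
    """Hypothetical stand-in: snapshots tool names once, at construction time."""

    def __init__(self, tools: list[Callable] | None = None):
        self.tool_names = [t.__name__ for t in (tools or [])]


class BaseToolEnv:
    def __init__(self, tools: list[Callable] | None = None):
        self.tools = list(tools or [])
        # Monitor is built here, before any subclass has a chance to add tools.
        self.monitor = ToolMonitor(tools=self.tools)

    def add_tool(self, tool: Callable) -> None:
        self.tools.append(tool)  # monitor.tool_names is NOT updated


class PythonLikeEnv(BaseToolEnv):
    def __init__(self):
        super().__init__(tools=[])   # monitor sees an empty tool list
        self.add_tool(self.python)   # registered too late for a per-tool metric

    def python(self, code: str) -> str:
        return "ok"


env = PythonLikeEnv()
print(env.monitor.tool_names)  # [] -- no "python_calls" metric would be created
```

Passing the dynamically added tools up to the constructor, as the comments above suggest, would avoid the gap.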


def _should_stop_for_error(self, err: Exception) -> bool:
"""Check if error is in stop_errors."""
return any(isinstance(err, err_type) for err_type in self.stop_errors)
61 changes: 0 additions & 61 deletions verifiers/rubrics/tool_rubric.py

This file was deleted.
