Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions tests/test_stateful_tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,77 @@ def test_stateful_tool_env_add_tool_skips_args(self, mock_stateful_tool_env):
assert mock_stateful_tool_env.skipped_args["secret_tool"] == ["secret"]
assert "secret_tool" in mock_stateful_tool_env.tool_map

def test_add_tool_skips_dict_type_args(self, mock_stateful_tool_env):
def tool_with_dict(command: str, state: dict | None = None) -> str:
return command

mock_stateful_tool_env.add_tool(tool_with_dict, args_to_skip=["state"])

schema = next(
t
for t in mock_stateful_tool_env.oai_tools
if t["function"]["name"] == "tool_with_dict"
)
assert "state" not in schema["function"]["parameters"]["properties"]

def test_add_tool_does_not_mutate_original_signature(self, mock_stateful_tool_env):
"""Verify that add_tool with args_to_skip doesn't mutate the original function."""
import inspect

def my_tool(command: str, hidden: int, visible: bool = True) -> str:
"""A tool with multiple parameters."""
return command

original_params = list(inspect.signature(my_tool).parameters.keys())
original_annotations = dict(my_tool.__annotations__)

mock_stateful_tool_env.add_tool(my_tool, args_to_skip=["hidden"])

# Original function signature should be unchanged
assert list(inspect.signature(my_tool).parameters.keys()) == original_params
assert my_tool.__annotations__ == original_annotations
assert "hidden" in inspect.signature(my_tool).parameters

# But schema should have hidden removed
schema = next(
t
for t in mock_stateful_tool_env.oai_tools
if t["function"]["name"] == "my_tool"
)
assert "hidden" not in schema["function"]["parameters"]["properties"]
assert "command" in schema["function"]["parameters"]["properties"]

def test_add_tool_does_not_mutate_bound_method_signature(
self, mock_stateful_tool_env
):
"""Verify that add_tool with args_to_skip doesn't mutate bound method signatures."""
import inspect

class ToolProvider:
def my_tool(self, command: str, hidden: int, visible: bool = True) -> str:
"""A tool with multiple parameters."""
return command

bound_method = ToolProvider().my_tool
original_params = list(inspect.signature(bound_method).parameters.keys())

mock_stateful_tool_env.add_tool(bound_method, args_to_skip=["hidden"])

# Original bound method signature should be unchanged
assert (
list(inspect.signature(bound_method).parameters.keys()) == original_params
)
assert "hidden" in inspect.signature(bound_method).parameters

# But schema should have hidden removed
schema = next(
t
for t in mock_stateful_tool_env.oai_tools
if t["function"]["name"] == "my_tool"
)
assert "hidden" not in schema["function"]["parameters"]["properties"]
assert "command" in schema["function"]["parameters"]["properties"]

@pytest.mark.asyncio
async def test_tool_env_tool_invalid_json_arguments(
self, mock_openai_client, sample_chat_dataset
Expand Down
34 changes: 33 additions & 1 deletion verifiers/envs/stateful_tool_env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import json
from abc import abstractmethod
from typing import Callable, cast
Expand All @@ -11,6 +12,37 @@
from verifiers.utils.tool_utils import convert_func_to_oai_tool


def filter_signature(func, args_to_skip):
"""Return a wrapper with filtered signature for schema generation.

Does not mutate the original function.
"""
if not args_to_skip:
return func
sig = inspect.signature(func)
filtered_sig = sig.replace(
parameters=[
p
for n, p in sig.parameters.items()
if n not in args_to_skip and n != "self"
]
)
filtered_annotations = {
k: v
for k, v in getattr(func, "__annotations__", {}).items()
if k not in args_to_skip
}

def wrapper(*args, **kwargs):
return func(*args, **kwargs)

wrapper.__name__ = getattr(func, "__name__", "unknown")
wrapper.__doc__ = getattr(func, "__doc__", None)
wrapper.__signature__ = filtered_sig
wrapper.__annotations__ = filtered_annotations
return wrapper


class StatefulToolEnv(vf.ToolEnv):
def __init__(
self,
Expand Down Expand Up @@ -48,7 +80,7 @@ def add_tool(self, tool: Callable, args_to_skip: list[str] = []):
Assumes all non-skipped args use standard JSON types (no remaining $ref/$defs).
"""
self.tools.append(tool)
oai_tool = convert_func_to_oai_tool(tool)
oai_tool = convert_func_to_oai_tool(filter_signature(tool, args_to_skip))
assert "function" in oai_tool
assert "parameters" in oai_tool["function"]
params = oai_tool["function"]["parameters"]
Expand Down
Loading