diff --git a/tests/test_stateful_tool_env.py b/tests/test_stateful_tool_env.py index 532ca6913..9ce0b1bfb 100644 --- a/tests/test_stateful_tool_env.py +++ b/tests/test_stateful_tool_env.py @@ -85,6 +85,77 @@ def test_stateful_tool_env_add_tool_skips_args(self, mock_stateful_tool_env): assert mock_stateful_tool_env.skipped_args["secret_tool"] == ["secret"] assert "secret_tool" in mock_stateful_tool_env.tool_map + def test_add_tool_skips_dict_type_args(self, mock_stateful_tool_env): + def tool_with_dict(command: str, state: dict | None = None) -> str: + return command + + mock_stateful_tool_env.add_tool(tool_with_dict, args_to_skip=["state"]) + + schema = next( + t + for t in mock_stateful_tool_env.oai_tools + if t["function"]["name"] == "tool_with_dict" + ) + assert "state" not in schema["function"]["parameters"]["properties"] + + def test_add_tool_does_not_mutate_original_signature(self, mock_stateful_tool_env): + """Verify that add_tool with args_to_skip doesn't mutate the original function.""" + import inspect + + def my_tool(command: str, hidden: int, visible: bool = True) -> str: + """A tool with multiple parameters.""" + return command + + original_params = list(inspect.signature(my_tool).parameters.keys()) + original_annotations = dict(my_tool.__annotations__) + + mock_stateful_tool_env.add_tool(my_tool, args_to_skip=["hidden"]) + + # Original function signature should be unchanged + assert list(inspect.signature(my_tool).parameters.keys()) == original_params + assert my_tool.__annotations__ == original_annotations + assert "hidden" in inspect.signature(my_tool).parameters + + # But schema should have hidden removed + schema = next( + t + for t in mock_stateful_tool_env.oai_tools + if t["function"]["name"] == "my_tool" + ) + assert "hidden" not in schema["function"]["parameters"]["properties"] + assert "command" in schema["function"]["parameters"]["properties"] + + def test_add_tool_does_not_mutate_bound_method_signature( + self, mock_stateful_tool_env + ): + """Verify that add_tool with args_to_skip doesn't mutate bound method signatures.""" + import inspect + + class ToolProvider: + def my_tool(self, command: str, hidden: int, visible: bool = True) -> str: + """A tool with multiple parameters.""" + return command + + bound_method = ToolProvider().my_tool + original_params = list(inspect.signature(bound_method).parameters.keys()) + + mock_stateful_tool_env.add_tool(bound_method, args_to_skip=["hidden"]) + + # Original bound method signature should be unchanged + assert ( + list(inspect.signature(bound_method).parameters.keys()) == original_params + ) + assert "hidden" in inspect.signature(bound_method).parameters + + # But schema should have hidden removed + schema = next( + t + for t in mock_stateful_tool_env.oai_tools + if t["function"]["name"] == "my_tool" + ) + assert "hidden" not in schema["function"]["parameters"]["properties"] + assert "command" in schema["function"]["parameters"]["properties"] + @pytest.mark.asyncio async def test_tool_env_tool_invalid_json_arguments( self, mock_openai_client, sample_chat_dataset diff --git a/verifiers/envs/stateful_tool_env.py b/verifiers/envs/stateful_tool_env.py index 8dbccbd1b..90381bc8d 100644 --- a/verifiers/envs/stateful_tool_env.py +++ b/verifiers/envs/stateful_tool_env.py @@ -1,3 +1,4 @@ +import inspect import json from abc import abstractmethod from typing import Callable, cast @@ -11,6 +12,37 @@ from verifiers.utils.tool_utils import convert_func_to_oai_tool +def filter_signature(func, args_to_skip): + """Return a wrapper with filtered signature for schema generation. + + Does not mutate the original function. + """ + if not args_to_skip: + return func + sig = inspect.signature(func) + filtered_sig = sig.replace( + parameters=[ + p + for n, p in sig.parameters.items() + if n not in args_to_skip and n != "self" + ] + ) + filtered_annotations = { + k: v + for k, v in getattr(func, "__annotations__", {}).items() + if k not in args_to_skip + } + + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + setattr(wrapper, "__name__", getattr(func, "__name__", "unknown")) + setattr(wrapper, "__doc__", getattr(func, "__doc__", None)) + setattr(wrapper, "__signature__", filtered_sig) + setattr(wrapper, "__annotations__", filtered_annotations) + return wrapper + + class StatefulToolEnv(vf.ToolEnv): def __init__( self, @@ -48,7 +80,7 @@ def add_tool(self, tool: Callable, args_to_skip: list[str] = []): Assumes all non-skipped args use standard JSON types (no remaining $ref/$defs). """ self.tools.append(tool) - oai_tool = convert_func_to_oai_tool(tool) + oai_tool = convert_func_to_oai_tool(filter_signature(tool, args_to_skip)) assert "function" in oai_tool assert "parameters" in oai_tool["function"] params = oai_tool["function"]["parameters"]