# clawbench/tests/test_environment.py
# 2026-04-09 11:15:30 -07:00
#
# 88 lines
# 2.5 KiB
# Python

from pathlib import Path
import pytest
from clawbench.environment import verify_completion
from clawbench.schemas import CompletionSpec, MemoryState, ToolCall, Transcript, TranscriptMessage
class MemoryFallbackClient:
    """Test double passed as the ``client`` argument of ``verify_completion``.

    The stub rejects the ``memory.search`` RPC and instead exposes the
    expected memory text through a canned ``MEMORY.md`` agent file.
    """

    async def _rpc(self, method: str, params=None):  # noqa: ANN001
        """Reject every RPC; this coroutine never returns normally.

        Args:
            method: Name of the RPC being invoked.
            params: RPC parameters; only echoed in the failure message.

        Raises:
            RuntimeError: If ``method`` is ``"memory.search"``.
            AssertionError: For any other method, so an unexpected RPC
                surfaces as a test failure.
        """
        if method == "memory.search":
            raise RuntimeError("unknown method: memory.search")
        raise AssertionError(f"Unexpected RPC: {method} {params}")

    async def get_agent_file(self, agent_id: str, name: str):  # noqa: ARG002
        """Return a canned agent-file payload.

        Args:
            agent_id: Ignored; the same payload is served for every agent.
            name: File name being requested.

        Returns:
            A ``{"file": {"content": ...}}`` dict. The content is the memory
            line for ``"MEMORY.md"`` and an empty string for any other name.
        """
        if name == "MEMORY.md":
            return {
                "file": {
                    "content": "beta rollout regions: us, eu; retry budget: 3\n",
                }
            }
        return {"file": {"content": ""}}
@pytest.mark.asyncio
async def test_memory_completion_falls_back_to_agent_memory_files(tmp_path: Path):
    """Expect a full score when the memory text is served only via ``MEMORY.md``.

    The stub client raises on ``memory.search``, no transcript is supplied and
    the workspace is an empty temporary directory, so the matching text is
    available only through ``MemoryFallbackClient.get_agent_file``.
    """
    completion = CompletionSpec(
        memory=[
            MemoryState(
                key_pattern="beta rollout regions",
                value_contains=["us", "eu", "3"],
            )
        ]
    )

    result = await verify_completion(
        completion,
        workspace=tmp_path,
        client=MemoryFallbackClient(),  # type: ignore[arg-type]
        session_key="session-test",
        agent_id="agent-test",
        runtime_values={},
    )

    assert result.score == 1.0
@pytest.mark.asyncio
async def test_memory_completion_falls_back_to_transcript_when_memory_rpc_is_unavailable(tmp_path: Path):
    """Expect a full score when the transcript records a matching memory write.

    The transcript holds one successful ``write`` tool call to
    ``memory/notes.md`` whose content satisfies the memory expectation, and
    the stub client raises on ``memory.search``.
    """
    # NOTE(review): MemoryFallbackClient.get_agent_file also returns the same
    # text for "MEMORY.md", so this test may pass without the transcript path
    # being exercised. Confirm against verify_completion's lookup order.
    completion = CompletionSpec(
        memory=[
            MemoryState(
                key_pattern="beta rollout regions",
                value_contains=["us", "eu", "3"],
            )
        ]
    )
    transcript = Transcript(
        messages=[
            TranscriptMessage(
                role="assistant",
                tool_calls=[
                    ToolCall(
                        name="write",
                        family="edit",
                        input={
                            "path": "memory/notes.md",
                            "content": "beta rollout regions: us, eu; retry budget: 3\n",
                        },
                        success=True,
                    )
                ],
            )
        ]
    )

    result = await verify_completion(
        completion,
        workspace=tmp_path,
        client=MemoryFallbackClient(),  # type: ignore[arg-type]
        session_key="session-test",
        agent_id="agent-test",
        runtime_values={},
        transcript=transcript,
    )

    assert result.score == 1.0