clawbench/tests/test_client.py
scoootscooob abf3500f69
Some checks failed
CI / Python ${{ matrix.python-version }} test suite (3.11) (push) Has been cancelled
CI / Python ${{ matrix.python-version }} test suite (3.12) (push) Has been cancelled
fix(harness): keep gateway RPC sockets alive
2026-05-02 14:51:52 -07:00

299 lines
10 KiB
Python

from __future__ import annotations
import asyncio
import pytest
from websockets.datastructures import Headers
from websockets.exceptions import InvalidMessage, InvalidStatus
from websockets.http11 import Response
from clawbench.client import GatewayClient, GatewayConfig, _correlate_transcript, _parse_single_message
from clawbench.schemas import Transcript
def test_gateway_config_defaults(monkeypatch):
    """GatewayConfig uses its built-in timeouts when no env overrides are set."""
    # GatewayConfig reads these env vars (see the override tests), so clear
    # them; otherwise a value exported in the shell or CI would mask the
    # defaults and make this test environment-dependent.
    monkeypatch.delenv("CLAWBENCH_CONNECT_TIMEOUT", raising=False)
    monkeypatch.delenv("CLAWBENCH_REQUEST_TIMEOUT", raising=False)
    cfg = GatewayConfig()
    # The old defaults were 15s connect / 60s request; only the connect
    # default was raised (to 30s) -- see the GatewayConfig docstring for the
    # rationale. 15s used to race gateway cold-start and produce spurious
    # empty_response failures.
    assert cfg.connect_timeout == 30.0
    assert cfg.request_timeout == 60.0
def test_gateway_config_env_overrides(monkeypatch):
    """Numeric CLAWBENCH_* env vars override both GatewayConfig timeouts."""
    overrides = {
        "CLAWBENCH_CONNECT_TIMEOUT": "45",
        "CLAWBENCH_REQUEST_TIMEOUT": "120",
    }
    for env_name, env_value in overrides.items():
        monkeypatch.setenv(env_name, env_value)

    config = GatewayConfig()

    # Values arrive as strings from the environment and are exposed as floats.
    assert (config.connect_timeout, config.request_timeout) == (45.0, 120.0)
@pytest.mark.parametrize("raw", ["not-a-number", "nan", "inf", "0", "-1"])
def test_gateway_config_invalid_env_falls_back_to_default(monkeypatch, caplog, raw):
    """Unparseable, non-finite, or non-positive overrides are ignored.

    The connect timeout must stay at its 30s default, and a record at WARNING
    level or above must mention the offending variable name.
    """
    monkeypatch.setenv("CLAWBENCH_CONNECT_TIMEOUT", raw)

    with caplog.at_level("WARNING"):
        cfg = GatewayConfig()

    assert cfg.connect_timeout == 30.0
    mentions = [
        record
        for record in caplog.records
        if "CLAWBENCH_CONNECT_TIMEOUT" in record.getMessage()
    ]
    assert mentions
@pytest.mark.asyncio
async def test_gateway_client_disables_websocket_keepalive_for_long_rpc(
    monkeypatch: pytest.MonkeyPatch,
):
    """connect() passes ping_interval=None and ping_timeout=None to websockets.connect.

    Presumably this keeps a long-running RPC from being torn down by a missed
    keepalive pong (per the test name); the handshake pieces are stubbed so
    only the connect kwargs are observed.
    """
    seen_kwargs: dict[str, object] = {}

    class _StubSocket:
        async def close(self) -> None:
            return None

    async def stub_connect(*args, **kwargs):
        # Record every keyword the client hands to websockets.connect.
        seen_kwargs.update(kwargs)
        return _StubSocket()

    async def stub_wait_event(self, event_name: str, *, timeout: float):
        # Stand-in for the gateway's first event; an empty nonce (assumed
        # to be the connect challenge) is enough for the stubbed handshake.
        return {"payload": {"nonce": ""}}

    async def stub_rpc(self, method: str, params=None, **kwargs):
        return {"payload": {"type": "hello-ok", "protocol": 3}}

    async def stub_listener(self):
        # Park the background listener; teardown is left to client.close().
        await asyncio.sleep(60)

    monkeypatch.setattr("clawbench.client.websockets.connect", stub_connect)
    client_patches = {
        "_wait_event": stub_wait_event,
        "_rpc": stub_rpc,
        "_listener": stub_listener,
    }
    for attr_name, replacement in client_patches.items():
        monkeypatch.setattr(GatewayClient, attr_name, replacement)

    client = GatewayClient(GatewayConfig(connect_timeout=2))
    await client.connect()
    await client.close()

    for kwarg_name in ("ping_interval", "ping_timeout"):
        assert seen_kwargs[kwarg_name] is None
def test_tool_results_are_correlated_back_to_tool_calls():
    """A tool_result in a later user message is attached to its matching toolCall.

    The result (tool_use_id "call-1", content "ERROR failed test") is expected
    to land on the call with id "call-1", which then reports success=False and
    carries that text as both ``output`` and ``error``.
    """
    tool_message = _parse_single_message(
        {
            "role": "assistant",
            "content": [
                {"type": "toolCall", "id": "call-1", "name": "exec", "arguments": {"command": "pytest -q"}},
            ],
        }
    )
    result_message = _parse_single_message(
        {
            "role": "user",
            "content": [
                {"type": "tool_result", "tool_use_id": "call-1", "content": "ERROR failed test"},
            ],
        }
    )
    # _parse_single_message can return None; fail here with a clear message
    # rather than handing None to Transcript and hiding it behind a
    # type: ignore.
    assert tool_message is not None
    assert result_message is not None
    transcript = _correlate_transcript(Transcript(messages=[tool_message, result_message]))
    call = transcript.tool_call_sequence[0]
    assert call.output == "ERROR failed test"
    assert call.success is False
    assert call.error == "ERROR failed test"
def test_message_usage_is_parsed_into_transcript_usage():
    """A gateway ``usage`` dict (camelCase keys) is mapped onto ``message.usage``."""
    raw_usage = {
        "input": 10,
        "output": 20,
        "reasoning": 5,
        "cacheRead": 3,
        "cacheWrite": 2,
        "totalTokens": 40,
        "cost": {"total": 0.0125},
    }
    message = _parse_single_message(
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "Done."}],
            "usage": raw_usage,
        }
    )
    assert message is not None

    expected_fields = {
        "input_tokens": 10,
        "output_tokens": 20,
        "reasoning_tokens": 5,
        "total_tokens": 40,
        "total_cost_usd": 0.0125,
    }
    for field_name, expected in expected_fields.items():
        assert getattr(message.usage, field_name) == expected, field_name
    # NOTE(review): cacheRead/cacheWrite are supplied but not asserted; the
    # corresponding attribute names on the usage model are not visible here.
@pytest.mark.asyncio
async def test_gateway_client_retries_transient_drain_errors(monkeypatch: pytest.MonkeyPatch):
    """connect() retries after the websocket upgrade is rejected with HTTP 503.

    The first websockets.connect call raises InvalidStatus(503) -- per the
    test name, a draining gateway -- and the second succeeds, so exactly two
    attempts are expected.
    """
    attempts = 0

    class FakeWebSocket:
        async def close(self) -> None:
            return None

    async def fake_connect(*args, **kwargs):
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise InvalidStatus(Response(503, "Service Unavailable", Headers()))
        return FakeWebSocket()

    async def fake_wait_event(self, event_name: str, *, timeout: float):
        return {"payload": {"nonce": ""}}

    async def fake_rpc(self, method: str, params=None, **kwargs):
        return {"payload": {"type": "hello-ok", "protocol": 3}}

    async def fake_listener(self):
        # Park the background listener; teardown is left to client.close().
        await asyncio.sleep(60)

    monkeypatch.setattr("clawbench.client.websockets.connect", fake_connect)
    monkeypatch.setattr(GatewayClient, "_wait_event", fake_wait_event)
    monkeypatch.setattr(GatewayClient, "_rpc", fake_rpc)
    monkeypatch.setattr(GatewayClient, "_listener", fake_listener)
    client = GatewayClient(GatewayConfig(connect_timeout=2))
    await client.connect()
    # Close before asserting so that a failing assertion does not skip the
    # client's teardown (parked listener task, fake socket).
    await client.close()
    assert attempts == 2
@pytest.mark.asyncio
async def test_gateway_client_retries_half_closed_handshake_errors(
    monkeypatch: pytest.MonkeyPatch,
):
    """connect() retries when the handshake fails with InvalidMessage.

    The first websockets.connect call raises InvalidMessage ("did not receive
    a valid HTTP response" -- per the test name, a half-closed socket) and
    the second succeeds, so exactly two attempts are expected.
    """
    attempts = 0

    class FakeWebSocket:
        async def close(self) -> None:
            return None

    async def fake_connect(*args, **kwargs):
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise InvalidMessage("did not receive a valid HTTP response")
        return FakeWebSocket()

    async def fake_wait_event(self, event_name: str, *, timeout: float):
        return {"payload": {"nonce": ""}}

    async def fake_rpc(self, method: str, params=None, **kwargs):
        return {"payload": {"type": "hello-ok", "protocol": 3}}

    async def fake_listener(self):
        # Park the background listener; teardown is left to client.close().
        await asyncio.sleep(60)

    monkeypatch.setattr("clawbench.client.websockets.connect", fake_connect)
    monkeypatch.setattr(GatewayClient, "_wait_event", fake_wait_event)
    monkeypatch.setattr(GatewayClient, "_rpc", fake_rpc)
    monkeypatch.setattr(GatewayClient, "_listener", fake_listener)
    client = GatewayClient(GatewayConfig(connect_timeout=2))
    await client.connect()
    # Close before asserting so that a failing assertion does not skip the
    # client's teardown (parked listener task, fake socket).
    await client.close()
    assert attempts == 2
@pytest.mark.asyncio
async def test_send_and_wait_collects_messages_that_arrive_after_final_state():
    """A session.message that arrives after the chat "final" state is still collected.

    The fake gateway puts ``state: final`` on the chat queue ~10ms after
    sessions.send, then an assistant message on the session.message queue
    ~200ms later; the returned transcript must contain that late message.
    """
    client = GatewayClient(GatewayConfig(request_timeout=1))
    session_key = "session-1"
    # The event loop keeps only weak references to tasks, so hold the
    # emitter explicitly or it may be garbage-collected before it finishes.
    background: list[asyncio.Task[None]] = []

    # Accept **kwargs like the other fake RPCs in this file: the client passes
    # keyword options (e.g. timeout=) to _rpc for some calls.
    async def fake_rpc(method: str, params=None, **kwargs):
        assert method == "sessions.send"

        async def emit() -> None:
            await asyncio.sleep(0.01)
            await client._event_queues[f"chat:{session_key}"].put({"payload": {"state": "final"}})
            await asyncio.sleep(0.2)
            await client._event_queues[f"session.message:{session_key}"].put(
                {
                    "payload": {
                        "message": {
                            "role": "assistant",
                            "content": [{"type": "text", "text": "Late but valid."}],
                            "usage": {"input": 1, "output": 2, "totalTokens": 3},
                        }
                    }
                }
            )

        background.append(asyncio.create_task(emit()))
        return {"ok": True, "payload": {}}

    client._rpc = fake_rpc  # type: ignore[method-assign]
    transcript = await client.send_and_wait(session_key, "hello", timeout=1.0)
    # Surface any exception raised inside emit() (e.g. a missing queue) and
    # make sure the emitter does not outlive the test.
    await asyncio.gather(*background)
    assert [message.text for message in transcript.assistant_messages] == ["Late but valid."]
@pytest.mark.asyncio
async def test_send_and_wait_passes_gateway_timeout_and_waits_for_run():
    """send_and_wait forwards its timeout to the gateway and waits on the returned run.

    sessions.send must carry ``timeoutMs``, and since it returns a runId the
    client is expected to call agent.wait for that run before reading the
    transcript via sessions.get.
    """
    client = GatewayClient(GatewayConfig(request_timeout=1))
    session_key = "session-1"
    calls: list[tuple[str, dict | None, dict]] = []

    async def fake_rpc(method: str, params=None, **kwargs):
        calls.append((method, params, kwargs))
        if method == "sessions.send":
            return {"ok": True, "payload": {"runId": "run-1"}}
        if method == "agent.wait":
            return {"ok": True, "payload": {"runId": "run-1", "status": "completed"}}
        if method == "sessions.get":
            return {
                "ok": True,
                "payload": {
                    "messages": [
                        {
                            "role": "assistant",
                            "content": [{"type": "text", "text": "Done."}],
                        }
                    ]
                },
            }
        return {"ok": True, "payload": {}}

    def find_call(name: str) -> tuple[str, dict | None, dict]:
        # A bare next() would raise StopIteration, which surfaces as an opaque
        # RuntimeError inside a coroutine; fail with the observed calls instead.
        match = next((call for call in calls if call[0] == name), None)
        if match is None:
            pytest.fail(f"{name} was never called; saw {[call[0] for call in calls]}")
        return match

    client._rpc = fake_rpc  # type: ignore[method-assign]
    transcript = await client.send_and_wait(session_key, "hello", timeout=1.5)

    send_params = find_call("sessions.send")[1]
    assert send_params is not None
    # The test cannot predict the idempotency key's value, so only require
    # that a non-empty one is sent, and check every other param exactly.
    assert send_params.get("idempotencyKey")
    assert {k: v for k, v in send_params.items() if k != "idempotencyKey"} == {
        "key": session_key,
        "message": "hello",
        "timeoutMs": 1500,
    }

    wait_call = find_call("agent.wait")
    assert wait_call[1] == {"runId": "run-1", "timeoutMs": 1500}
    # The RPC-level timeout appears to be the run timeout plus 10s of slack.
    assert wait_call[2]["timeout"] == 11.5
    assert [message.text for message in transcript.assistant_messages] == ["Done."]
@pytest.mark.asyncio
async def test_send_and_wait_aborts_run_when_no_terminal_state_arrives():
client = GatewayClient(GatewayConfig(request_timeout=1))
session_key = "session-1"
calls: list[tuple[str, dict | None, dict]] = []
async def fake_rpc(method: str, params=None, **kwargs):
calls.append((method, params, kwargs))
if method == "sessions.send":
return {"ok": True, "payload": {"runId": "run-timeout"}}
if method == "agent.wait":
await asyncio.sleep(60)
if method == "sessions.abort":
return {"ok": True, "payload": {"status": "aborted"}}
if method == "sessions.get":
return {"ok": True, "payload": {"messages": []}}
return {"ok": True, "payload": {}}
client._rpc = fake_rpc # type: ignore[method-assign]
await client.send_and_wait(session_key, "hello", timeout=0.01)
assert ("sessions.abort", {"key": session_key, "runId": "run-timeout"}, {"timeout": 1}) in calls