clawbench/tests/test_client.py
Vincent Koc e7bd37c1b5
Some checks failed
CI / Python ${{ matrix.python-version }} test suite (3.11) (push) Has been cancelled
CI / Python ${{ matrix.python-version }} test suite (3.12) (push) Has been cancelled
fix(client): clean pending rpc on send failure
2026-04-29 00:07:40 -07:00

229 lines
7.6 KiB
Python

from __future__ import annotations
import asyncio
import json
import pytest
from websockets.datastructures import Headers
from websockets.exceptions import InvalidMessage, InvalidStatus
from websockets.http11 import Response
from clawbench.client import GatewayClient, GatewayConfig, _correlate_transcript, _parse_single_message
from clawbench.schemas import Transcript
def test_gateway_config_defaults(monkeypatch: pytest.MonkeyPatch):
    """GatewayConfig uses its built-in timeouts when no env override is set."""
    # GatewayConfig honors these variables (see the env-override test), so an
    # override leaking in from the developer's shell or CI would make this
    # test assert against the ambient environment instead of the defaults.
    monkeypatch.delenv("CLAWBENCH_CONNECT_TIMEOUT", raising=False)
    monkeypatch.delenv("CLAWBENCH_REQUEST_TIMEOUT", raising=False)
    cfg = GatewayConfig()
    # Connect default raised from 15s to 30s (request timeout unchanged at
    # 60s) -- see GatewayConfig docstring for the rationale; 15s used to
    # race gateway cold-start and produce spurious empty_response failures.
    assert cfg.connect_timeout == 30.0
    assert cfg.request_timeout == 60.0
def test_gateway_config_env_overrides(monkeypatch):
    """Both GatewayConfig timeouts can be overridden from the environment."""
    overrides = {
        "CLAWBENCH_CONNECT_TIMEOUT": "45",
        "CLAWBENCH_REQUEST_TIMEOUT": "120",
    }
    for env_name, env_value in overrides.items():
        monkeypatch.setenv(env_name, env_value)

    config = GatewayConfig()

    # String values from the environment end up as float seconds.
    assert (config.connect_timeout, config.request_timeout) == (45.0, 120.0)
@pytest.mark.parametrize("raw", ["not-a-number", "nan", "inf", "0", "-1"])
def test_gateway_config_invalid_env_falls_back_to_default(monkeypatch, caplog, raw):
    """Unparseable, non-finite, or non-positive overrides are rejected.

    The connect timeout must fall back to its 30s default, and a WARNING
    naming the offending variable must be logged.
    """
    env_name = "CLAWBENCH_CONNECT_TIMEOUT"
    monkeypatch.setenv(env_name, raw)

    with caplog.at_level("WARNING"):
        config = GatewayConfig()

    assert config.connect_timeout == 30.0
    mentions = [record for record in caplog.records if env_name in record.getMessage()]
    assert mentions
def test_tool_results_are_correlated_back_to_tool_calls():
    """A ``tool_result`` block is attached to the ``toolCall`` it references.

    The assistant message issues call ``call-1``. The following user message
    carries a ``tool_result`` with ``tool_use_id == "call-1"`` whose content
    starts with ``ERROR``. After correlation, the call is expected to expose
    that content as its output and to be marked unsuccessful, with the same
    text as its error.
    """
    tool_message = _parse_single_message(
        {
            "role": "assistant",
            "content": [
                {"type": "toolCall", "id": "call-1", "name": "exec", "arguments": {"command": "pytest -q"}},
            ],
        }
    )
    result_message = _parse_single_message(
        {
            "role": "user",
            "content": [
                {"type": "tool_result", "tool_use_id": "call-1", "content": "ERROR failed test"},
            ],
        }
    )
    # The parser's result is optional (the usage test guards it the same
    # way). Assert explicitly, so a parse regression fails here with a clear
    # message instead of being hidden behind a ``type: ignore``.
    assert tool_message is not None
    assert result_message is not None

    transcript = _correlate_transcript(Transcript(messages=[tool_message, result_message]))
    calls = transcript.tool_call_sequence
    # Fail with a readable message rather than an IndexError if correlation
    # produced no tool calls at all.
    assert calls, "expected the toolCall block to yield at least one tool call"
    call = calls[0]
    assert call.output == "ERROR failed test"
    assert call.success is False
    assert call.error == "ERROR failed test"
def test_message_usage_is_parsed_into_transcript_usage():
    """A message's raw ``usage`` dict is mapped onto its typed usage fields."""
    raw_usage = {
        "input": 10,
        "output": 20,
        "reasoning": 5,
        "cacheRead": 3,
        "cacheWrite": 2,
        "totalTokens": 40,
        "cost": {"total": 0.0125},
    }
    message = _parse_single_message(
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "Done."}],
            "usage": raw_usage,
        }
    )
    assert message is not None

    # Expected attribute on the parsed usage -> value from raw_usage above.
    expected = {
        "input_tokens": 10,
        "output_tokens": 20,
        "reasoning_tokens": 5,
        "total_tokens": 40,
        "total_cost_usd": 0.0125,
    }
    usage = message.usage
    for attr_name, want in expected.items():
        assert getattr(usage, attr_name) == want, attr_name
async def _assert_connect_retries_once(
    monkeypatch: pytest.MonkeyPatch, first_error: Exception
) -> None:
    """Fail the first websocket dial with *first_error*, then expect a retry.

    ``websockets.connect`` is replaced so that attempt 1 raises *first_error*
    and attempt 2 returns a stub socket. ``_wait_event``, ``_rpc`` and
    ``_listener`` are stubbed too, so the only behavior under test is
    whether ``GatewayClient.connect()`` retries after that error and then
    succeeds on the second attempt.
    """
    attempts = 0

    class FakeWebSocket:
        async def close(self) -> None:
            return None

    async def fake_connect(*args, **kwargs):
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise first_error
        return FakeWebSocket()

    async def fake_wait_event(self, event_name: str, *, timeout: float):
        return {"payload": {"nonce": ""}}

    async def fake_rpc(self, method: str, params=None):
        return {"payload": {"type": "hello-ok", "protocol": 3}}

    async def fake_listener(self):
        # Park instead of reading frames from the stub socket.
        await asyncio.sleep(60)

    monkeypatch.setattr("clawbench.client.websockets.connect", fake_connect)
    monkeypatch.setattr(GatewayClient, "_wait_event", fake_wait_event)
    monkeypatch.setattr(GatewayClient, "_rpc", fake_rpc)
    monkeypatch.setattr(GatewayClient, "_listener", fake_listener)

    client = GatewayClient(GatewayConfig(connect_timeout=2))
    await client.connect()
    try:
        assert attempts == 2
    finally:
        # Tear the connected client down even when the assertion fails, so
        # whatever connect() started is released before the test's loop
        # closes.
        await client.close()


@pytest.mark.asyncio
async def test_gateway_client_retries_transient_drain_errors(monkeypatch: pytest.MonkeyPatch):
    """A 503 Service Unavailable from the upgrade request is retried."""
    await _assert_connect_retries_once(
        monkeypatch,
        InvalidStatus(Response(503, "Service Unavailable", Headers())),
    )


@pytest.mark.asyncio
async def test_gateway_client_retries_half_closed_handshake_errors(
    monkeypatch: pytest.MonkeyPatch,
):
    """A handshake that yields no valid HTTP response is retried."""
    await _assert_connect_retries_once(
        monkeypatch,
        InvalidMessage("did not receive a valid HTTP response"),
    )
@pytest.mark.asyncio
async def test_send_and_wait_collects_messages_that_arrive_after_final_state():
    """An assistant message delivered after the chat ``final`` event is kept.

    The stubbed ``sessions.send`` RPC starts a background emitter. The
    emitter pushes ``state: final`` onto the chat queue first and delivers
    the assistant message about 200ms later. ``send_and_wait`` must still
    include that late message in the transcript.
    """
    client = GatewayClient(GatewayConfig(request_timeout=1))
    session_key = "session-1"
    # Strong references to background tasks: the event loop only keeps weak
    # references, so an unreferenced task can be garbage-collected mid-run.
    background: list[asyncio.Task[None]] = []

    async def emit() -> None:
        await asyncio.sleep(0.01)
        await client._event_queues[f"chat:{session_key}"].put({"payload": {"state": "final"}})
        await asyncio.sleep(0.2)
        await client._event_queues[f"session.message:{session_key}"].put(
            {
                "payload": {
                    "message": {
                        "role": "assistant",
                        "content": [{"type": "text", "text": "Late but valid."}],
                        "usage": {"input": 1, "output": 2, "totalTokens": 3},
                    }
                }
            }
        )

    async def fake_rpc(method: str, params=None):
        assert method == "sessions.send"
        background.append(asyncio.create_task(emit()))
        return {"ok": True, "payload": {}}

    client._rpc = fake_rpc  # type: ignore[method-assign]
    try:
        transcript = await client.send_and_wait(session_key, "hello", timeout=1.0)
        # Re-raise anything that failed inside the emitter; otherwise it
        # would only appear as a "Task exception was never retrieved" warning.
        await asyncio.gather(*background)
    finally:
        for task in background:
            task.cancel()  # no-op for tasks that already finished
    assert [message.text for message in transcript.assistant_messages] == ["Late but valid."]
@pytest.mark.asyncio
async def test_rpc_send_failure_cleans_pending_request():
    """A send error propagates out of ``_rpc`` and leaves ``_pending`` empty."""

    class FailingWebSocket:
        # Every outgoing frame fails as if the socket were already closed.
        async def send(self, payload: str) -> None:  # noqa: ARG002
            raise ConnectionError("socket closed")

    gateway = GatewayClient(GatewayConfig(request_timeout=0.01))
    gateway._ws = FailingWebSocket()  # type: ignore[assignment]

    with pytest.raises(ConnectionError, match="socket closed"):
        await gateway._rpc("sessions.create", {"model": "test-model"})

    # No request bookkeeping may survive the failed send.
    assert gateway._pending == {}
@pytest.mark.asyncio
async def test_rpc_timeout_cleans_pending_request():
    """An unanswered RPC times out and its pending entry is removed."""
    captured: list[dict[str, object]] = []

    class SilentWebSocket:
        # Records outgoing frames; nothing in this test ever replies.
        async def send(self, payload: str) -> None:
            captured.append(json.loads(payload))

    gateway = GatewayClient(GatewayConfig(request_timeout=0.01))
    gateway._ws = SilentWebSocket()  # type: ignore[assignment]

    with pytest.raises(TimeoutError, match="RPC sessions.create timed out"):
        await gateway._rpc("sessions.create", {"model": "test-model"})

    first_frame = captured[0]
    assert first_frame["method"] == "sessions.create"
    assert gateway._pending == {}