fix(client): clean pending rpc on send failure
This commit is contained in:
parent
88ab0f5564
commit
ea17c715b3
@ -507,14 +507,17 @@ class GatewayClient:
|
||||
effective_timeout = timeout if timeout is not None else self.config.request_timeout
|
||||
future: asyncio.Future[dict[str, Any]] = asyncio.get_running_loop().create_future()
|
||||
self._pending[request_id] = future
|
||||
await self._ws.send(json.dumps(frame))
|
||||
try:
|
||||
await self._ws.send(json.dumps(frame))
|
||||
response = await asyncio.wait_for(future, timeout=effective_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending.pop(request_id, None)
|
||||
raise TimeoutError(
|
||||
f"RPC {method} timed out after {effective_timeout:.1f}s"
|
||||
)
|
||||
except Exception:
|
||||
self._pending.pop(request_id, None)
|
||||
raise
|
||||
|
||||
if not response.get("ok", False):
|
||||
error = response.get("error", {})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from websockets.datastructures import Headers
|
||||
@ -192,3 +193,36 @@ async def test_send_and_wait_collects_messages_that_arrive_after_final_state():
|
||||
transcript = await client.send_and_wait(session_key, "hello", timeout=1.0)
|
||||
|
||||
assert [message.text for message in transcript.assistant_messages] == ["Late but valid."]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rpc_send_failure_cleans_pending_request():
|
||||
class FailingWebSocket:
|
||||
async def send(self, payload: str) -> None: # noqa: ARG002
|
||||
raise ConnectionError("socket closed")
|
||||
|
||||
client = GatewayClient(GatewayConfig(request_timeout=0.01))
|
||||
client._ws = FailingWebSocket() # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(ConnectionError, match="socket closed"):
|
||||
await client._rpc("sessions.create", {"model": "test-model"})
|
||||
|
||||
assert client._pending == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rpc_timeout_cleans_pending_request():
|
||||
sent_frames: list[dict[str, object]] = []
|
||||
|
||||
class SilentWebSocket:
|
||||
async def send(self, payload: str) -> None:
|
||||
sent_frames.append(json.loads(payload))
|
||||
|
||||
client = GatewayClient(GatewayConfig(request_timeout=0.01))
|
||||
client._ws = SilentWebSocket() # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(TimeoutError, match="RPC sessions.create timed out"):
|
||||
await client._rpc("sessions.create", {"model": "test-model"})
|
||||
|
||||
assert sent_frames[0]["method"] == "sessions.create"
|
||||
assert client._pending == {}
|
||||
|
||||
@ -5,7 +5,7 @@ from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from clawbench.queue import JobQueue
|
||||
from clawbench.queue import Job, JobQueue, JobStatus, SubmissionRequest
|
||||
from clawbench.worker import GATEWAY_PORT, GATEWAY_PORT_SPACING, EvalWorker, JobProgressTracker, ParallelLane
|
||||
|
||||
|
||||
@ -28,6 +28,52 @@ class DummyTask:
|
||||
return [object()] * self._phases
|
||||
|
||||
|
||||
class FakeQueue:
|
||||
def __init__(self) -> None:
|
||||
self.evaluating: list[str] = []
|
||||
self.finished: list[tuple[str, str]] = []
|
||||
self.failed: list[tuple[str, str]] = []
|
||||
self.progress: list[tuple[str, dict[str, object]]] = []
|
||||
|
||||
async def mark_evaluating(self, job_id: str) -> None:
|
||||
self.evaluating.append(job_id)
|
||||
|
||||
async def mark_finished(self, job_id: str, result_id: str) -> None:
|
||||
self.finished.append((job_id, result_id))
|
||||
|
||||
async def mark_failed(self, job_id: str, error: str) -> None:
|
||||
self.failed.append((job_id, error))
|
||||
|
||||
async def update_progress(self, job_id: str, **kwargs) -> None:
|
||||
self.progress.append((job_id, kwargs))
|
||||
|
||||
|
||||
class FakeBenchmarkResult:
|
||||
submission_id = "submission-1"
|
||||
overall_score = 0.82
|
||||
overall_pass_hat_k = 1.0
|
||||
|
||||
def model_dump(self):
|
||||
return {
|
||||
"submission_id": self.submission_id,
|
||||
"overall_score": self.overall_score,
|
||||
"overall_pass_hat_k": self.overall_pass_hat_k,
|
||||
}
|
||||
|
||||
|
||||
def make_job(*, status: JobStatus = JobStatus.PENDING, lanes: int = 1) -> Job:
|
||||
return Job(
|
||||
job_id="job-1",
|
||||
status=status,
|
||||
request=SubmissionRequest(
|
||||
model="anthropic/claude-sonnet-4-6",
|
||||
provider="anthropic",
|
||||
runs_per_task=1,
|
||||
max_parallel_lanes=lanes,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_configure_browser_runtime_sets_benchmark_safe_openclaw_config(monkeypatch):
|
||||
worker = EvalWorker(JobQueue())
|
||||
state_dir = Path("/tmp/test-openclaw-config-basic")
|
||||
@ -171,6 +217,85 @@ def test_materialize_lane_runtime_spaces_ports_and_copies_auth(tmp_path: Path, m
|
||||
assert (lane1.state_dir / "agents" / "main" / "agent" / "auth-profiles.json").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_job_finishes_when_optional_result_upload_fails(tmp_path: Path, monkeypatch):
|
||||
queue = FakeQueue()
|
||||
worker = EvalWorker(queue) # type: ignore[arg-type]
|
||||
cleanup_calls: list[str] = []
|
||||
|
||||
async def fake_run_serial_benchmark(job, tasks, progress): # noqa: ANN001
|
||||
progress.mark_serial(tasks[0].id, 0, stage="running")
|
||||
return FakeBenchmarkResult()
|
||||
|
||||
async def fake_upload_result(result): # noqa: ANN001
|
||||
raise RuntimeError("hub upload unavailable")
|
||||
|
||||
monkeypatch.setattr("clawbench.worker.RESULTS_DIR", tmp_path)
|
||||
monkeypatch.setattr(worker, "_load_job_tasks", lambda job: [DummyTask("t1", "tier1", "coding")])
|
||||
monkeypatch.setattr(worker, "_run_serial_benchmark", fake_run_serial_benchmark)
|
||||
monkeypatch.setattr(worker, "_stop_gateway", lambda: cleanup_calls.append("serial"))
|
||||
monkeypatch.setattr(worker, "_stop_parallel_gateways", lambda: cleanup_calls.append("parallel"))
|
||||
monkeypatch.setattr("clawbench.upload.upload_result", fake_upload_result)
|
||||
|
||||
await worker._process_job(make_job())
|
||||
|
||||
assert queue.evaluating == ["job-1"]
|
||||
assert queue.finished == [("job-1", "submission-1")]
|
||||
assert queue.failed == []
|
||||
assert (tmp_path / "submission-1.json").exists()
|
||||
assert cleanup_calls[-2:] == ["serial", "parallel"]
|
||||
assert worker._active_model == ""
|
||||
assert worker._serial_last_task_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_job_marks_failure_and_cleans_up_after_benchmark_error(monkeypatch):
|
||||
queue = FakeQueue()
|
||||
worker = EvalWorker(queue) # type: ignore[arg-type]
|
||||
cleanup_calls: list[str] = []
|
||||
|
||||
async def fail_run_serial_benchmark(job, tasks, progress): # noqa: ANN001
|
||||
raise RuntimeError("gateway died")
|
||||
|
||||
monkeypatch.setattr(worker, "_load_job_tasks", lambda job: [DummyTask("t1", "tier1", "coding")])
|
||||
monkeypatch.setattr(worker, "_run_serial_benchmark", fail_run_serial_benchmark)
|
||||
monkeypatch.setattr(worker, "_stop_gateway", lambda: cleanup_calls.append("serial"))
|
||||
monkeypatch.setattr(worker, "_stop_parallel_gateways", lambda: cleanup_calls.append("parallel"))
|
||||
|
||||
await worker._process_job(make_job())
|
||||
|
||||
assert queue.evaluating == ["job-1"]
|
||||
assert queue.finished == []
|
||||
assert queue.failed == [("job-1", "gateway died")]
|
||||
assert cleanup_calls[-2:] == ["serial", "parallel"]
|
||||
assert worker._active_model == ""
|
||||
assert worker._serial_last_task_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_job_does_not_reclaim_already_claimed_evaluating_job(tmp_path: Path, monkeypatch):
|
||||
queue = FakeQueue()
|
||||
worker = EvalWorker(queue) # type: ignore[arg-type]
|
||||
|
||||
async def fake_run_serial_benchmark(job, tasks, progress): # noqa: ANN001
|
||||
return FakeBenchmarkResult()
|
||||
|
||||
async def fake_upload_result(result): # noqa: ANN001
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("clawbench.worker.RESULTS_DIR", tmp_path)
|
||||
monkeypatch.setattr(worker, "_load_job_tasks", lambda job: [DummyTask("t1", "tier1", "coding")])
|
||||
monkeypatch.setattr(worker, "_run_serial_benchmark", fake_run_serial_benchmark)
|
||||
monkeypatch.setattr(worker, "_stop_gateway", lambda: None)
|
||||
monkeypatch.setattr(worker, "_stop_parallel_gateways", lambda: None)
|
||||
monkeypatch.setattr("clawbench.upload.upload_result", fake_upload_result)
|
||||
|
||||
await worker._process_job(make_job(status=JobStatus.EVALUATING))
|
||||
|
||||
assert queue.evaluating == []
|
||||
assert queue.finished == [("job-1", "submission-1")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_serial_benchmark_forwards_judge_score_gate(monkeypatch):
|
||||
queue = JobQueue()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user