Compare commits

...

1 Commits

Author SHA1 Message Date
Scott Hanselman
a1a114a319 refactor: extract WebSocketClientBase from Gateway/Node clients (#63)
Extract ~200 lines of duplicated WebSocket lifecycle code into a shared
abstract base class. Both OpenClawGatewayClient and WindowsNodeClient
now inherit from WebSocketClientBase.

Shared in base class:
- Connection lifecycle: ConnectAsync, ListenForMessagesAsync, ReconnectWithBackoffAsync
- SendRawAsync (thread-safe with TOCTOU protection)
- Dispose (defensive pattern, skip CTS dispose)
- Fields: WebSocket, URL, token, credentials, CTS, backoff array
- StatusChanged event with RaiseStatusChanged helper
- Constructor: URL normalization, credential extraction, validation

Subclass hooks (template method pattern):
- ProcessMessageAsync (abstract) - Gateway wraps sync, Node uses async
- ReceiveBufferSize (abstract) - Gateway 16KB, Node 64KB
- ClientRole (abstract) - for log messages
- OnConnectedAsync, OnDisconnected, OnError, OnDisposing (virtual)

Safety improvements (Node's safer patterns adopted everywhere):
- ObjectDisposedException catch in listen loop
- Post-delay CTS check in reconnect
- Try-catch around reconnect guard
- Constructor argument validation

20 new tests via TestWebSocketClient test double.
596 total tests pass (503 shared + 93 tray).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-17 22:22:11 -07:00
6 changed files with 573 additions and 399 deletions

View File

@ -1,27 +1,13 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace OpenClaw.Shared;
public class OpenClawGatewayClient : IDisposable
public class OpenClawGatewayClient : WebSocketClientBase
{
private ClientWebSocket? _webSocket;
private readonly string _gatewayUrl;
private readonly string _gatewayUrlForDisplay;
private readonly string _token;
private readonly string? _credentials;
private readonly IOpenClawLogger _logger;
private CancellationTokenSource _cts;
private bool _disposed;
private int _reconnectAttempts;
private static readonly int[] BackoffMs = { 1000, 2000, 4000, 8000, 15000, 30000, 60000 };
// Tracked state
private readonly Dictionary<string, SessionInfo> _sessions = new();
private readonly Dictionary<string, GatewayNodeInfo> _nodes = new();
@ -45,8 +31,32 @@ public class OpenClawGatewayClient : IDisposable
_nodeListUnsupported = false;
}
protected override int ReceiveBufferSize => 16384;
protected override string ClientRole => "gateway";
protected override Task ProcessMessageAsync(string json)
{
ProcessMessage(json);
return Task.CompletedTask;
}
protected override Task OnConnectedAsync()
{
ResetUnsupportedMethodFlags();
return Task.CompletedTask;
}
protected override void OnDisconnected()
{
ClearPendingRequests();
}
protected override void OnDisposing()
{
ClearPendingRequests();
}
// Events
public event EventHandler<ConnectionStatus>? StatusChanged;
public event EventHandler<OpenClawNotification>? NotificationReceived;
public event EventHandler<AgentActivity>? ActivityChanged;
public event EventHandler<ChannelHealth[]>? ChannelHealthUpdated;
@ -59,63 +69,17 @@ public class OpenClawGatewayClient : IDisposable
public event EventHandler<SessionCommandResult>? SessionCommandCompleted;
public OpenClawGatewayClient(string gatewayUrl, string token, IOpenClawLogger? logger = null)
: base(gatewayUrl, token, logger)
{
_gatewayUrl = GatewayUrlHelper.NormalizeForWebSocket(gatewayUrl);
_gatewayUrlForDisplay = GatewayUrlHelper.SanitizeForDisplay(_gatewayUrl);
_token = token;
_credentials = GatewayUrlHelper.ExtractCredentials(gatewayUrl);
_logger = logger ?? NullLogger.Instance;
_cts = new CancellationTokenSource();
}
public async Task ConnectAsync()
{
try
{
StatusChanged?.Invoke(this, ConnectionStatus.Connecting);
_logger.Info($"Connecting to gateway: {_gatewayUrlForDisplay}");
_webSocket = new ClientWebSocket();
_webSocket.Options.KeepAliveInterval = TimeSpan.FromSeconds(30);
// Set Origin header based on gateway URL (convert ws/wss to http/https)
var uri = new Uri(_gatewayUrl);
var originScheme = uri.Scheme == "wss" ? "https" : "http";
var origin = $"{originScheme}://{uri.Host}:{uri.Port}";
_webSocket.Options.SetRequestHeader("Origin", origin);
if (!string.IsNullOrEmpty(_credentials))
{
var credentialsToEncode = GatewayUrlHelper.DecodeCredentials(_credentials);
_webSocket.Options.SetRequestHeader(
"Authorization",
$"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes(credentialsToEncode))}");
}
await _webSocket.ConnectAsync(uri, _cts.Token);
ResetUnsupportedMethodFlags();
_reconnectAttempts = 0;
_logger.Info("Gateway connected, waiting for challenge...");
// Don't send connect yet - wait for challenge event in ListenForMessagesAsync
_ = Task.Run(() => ListenForMessagesAsync(), _cts.Token);
}
catch (Exception ex)
{
_logger.Error("Connection failed", ex);
StatusChanged?.Invoke(this, ConnectionStatus.Error);
}
}
public async Task DisconnectAsync()
{
if (_webSocket?.State == WebSocketState.Open)
if (IsConnected)
{
try
{
await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Disconnecting", CancellationToken.None);
await CloseWebSocketAsync();
}
catch (Exception ex)
{
@ -123,13 +87,13 @@ public class OpenClawGatewayClient : IDisposable
}
}
ClearPendingRequests();
StatusChanged?.Invoke(this, ConnectionStatus.Disconnected);
RaiseStatusChanged(ConnectionStatus.Disconnected);
_logger.Info("Disconnected");
}
public async Task CheckHealthAsync()
{
if (_webSocket?.State != WebSocketState.Open)
if (!IsConnected)
{
await ReconnectWithBackoffAsync();
return;
@ -149,14 +113,14 @@ public class OpenClawGatewayClient : IDisposable
catch (Exception ex)
{
_logger.Error("Health check failed", ex);
StatusChanged?.Invoke(this, ConnectionStatus.Error);
RaiseStatusChanged(ConnectionStatus.Error);
await ReconnectWithBackoffAsync();
}
}
public async Task SendChatMessageAsync(string message)
{
if (_webSocket?.State != WebSocketState.Open)
if (!IsConnected)
throw new InvalidOperationException("Gateway connection is not open");
var req = new
@ -179,7 +143,7 @@ public class OpenClawGatewayClient : IDisposable
/// <summary>Request usage/context info from gateway (may not be supported on all gateways).</summary>
public async Task RequestUsageAsync()
{
if (_webSocket?.State != WebSocketState.Open) return;
if (!IsConnected) return;
try
{
if (_usageStatusUnsupported)
@ -270,7 +234,7 @@ public class OpenClawGatewayClient : IDisposable
/// <summary>Start a channel (telegram, whatsapp, etc).</summary>
public async Task<bool> StartChannelAsync(string channelName)
{
if (_webSocket?.State != WebSocketState.Open) return false;
if (!IsConnected) return false;
try
{
var req = new
@ -294,7 +258,7 @@ public class OpenClawGatewayClient : IDisposable
/// <summary>Stop a channel (telegram, whatsapp, etc).</summary>
public async Task<bool> StopChannelAsync(string channelName)
{
if (_webSocket?.State != WebSocketState.Open) return false;
if (!IsConnected) return false;
try
{
var req = new
@ -315,31 +279,6 @@ public class OpenClawGatewayClient : IDisposable
}
}
// --- Connection management ---
private async Task ReconnectWithBackoffAsync()
{
var delay = BackoffMs[Math.Min(_reconnectAttempts, BackoffMs.Length - 1)];
_reconnectAttempts++;
_logger.Warn($"Reconnecting in {delay}ms (attempt {_reconnectAttempts})");
StatusChanged?.Invoke(this, ConnectionStatus.Connecting);
try
{
await Task.Delay(delay, _cts.Token);
_webSocket?.Dispose();
_webSocket = null;
await ConnectAsync();
}
catch (OperationCanceledException) { }
catch (Exception ex)
{
_logger.Error("Reconnect failed", ex);
StatusChanged?.Invoke(this, ConnectionStatus.Error);
// Don't recurse — the listen loop will trigger reconnect again
}
}
private async Task SendConnectMessageAsync(string? nonce = null)
{
// Use "cli" client ID for native apps - no browser security checks
@ -373,31 +312,9 @@ public class OpenClawGatewayClient : IDisposable
await SendRawAsync(JsonSerializer.Serialize(msg));
}
private async Task SendRawAsync(string message)
{
// Capture local reference to avoid TOCTOU race with reconnect/dispose
var ws = _webSocket;
if (ws?.State != WebSocketState.Open) return;
try
{
var bytes = Encoding.UTF8.GetBytes(message);
await ws.SendAsync(new ArraySegment<byte>(bytes),
WebSocketMessageType.Text, true, _cts.Token);
}
catch (ObjectDisposedException)
{
// WebSocket was disposed between state check and send
}
catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.InvalidState)
{
_logger.Warn($"WebSocket send failed (state changed): {ex.Message}");
}
}
private async Task SendTrackedRequestAsync(string method, object? parameters = null)
{
if (_webSocket?.State != WebSocketState.Open) return;
if (!IsConnected) return;
var requestId = Guid.NewGuid().ToString();
TrackPendingRequest(requestId, method);
@ -482,60 +399,6 @@ public class OpenClawGatewayClient : IDisposable
}
}
// --- Message loop ---
private async Task ListenForMessagesAsync()
{
var buffer = new byte[16384]; // Larger buffer for big events
var sb = new StringBuilder();
try
{
while (_webSocket?.State == WebSocketState.Open && !_cts.Token.IsCancellationRequested)
{
var result = await _webSocket.ReceiveAsync(
new ArraySegment<byte>(buffer), _cts.Token);
if (result.MessageType == WebSocketMessageType.Text)
{
sb.Append(Encoding.UTF8.GetString(buffer, 0, result.Count));
if (result.EndOfMessage)
{
ProcessMessage(sb.ToString());
sb.Clear();
}
}
else if (result.MessageType == WebSocketMessageType.Close)
{
var closeStatus = _webSocket.CloseStatus?.ToString() ?? "unknown";
var closeDesc = _webSocket.CloseStatusDescription ?? "no description";
_logger.Info($"Server closed connection: {closeStatus} - {closeDesc}");
ClearPendingRequests();
StatusChanged?.Invoke(this, ConnectionStatus.Disconnected);
break;
}
}
}
catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
{
_logger.Warn("Connection closed prematurely");
ClearPendingRequests();
StatusChanged?.Invoke(this, ConnectionStatus.Disconnected);
}
catch (OperationCanceledException) { }
catch (Exception ex)
{
_logger.Error("Listen error", ex);
StatusChanged?.Invoke(this, ConnectionStatus.Error);
}
// Auto-reconnect if not intentionally disposed
if (!_disposed && !_cts.Token.IsCancellationRequested)
{
await ReconnectWithBackoffAsync();
}
}
// --- Message processing ---
private void ProcessMessage(string json)
@ -594,7 +457,7 @@ public class OpenClawGatewayClient : IDisposable
if (payload.TryGetProperty("type", out var t) && t.GetString() == "hello-ok")
{
_logger.Info("Handshake complete (hello-ok)");
StatusChanged?.Invoke(this, ConnectionStatus.Connected);
RaiseStatusChanged(ConnectionStatus.Connected);
// Request initial state after handshake
_ = Task.Run(async () =>
@ -1738,21 +1601,4 @@ public class OpenClawGatewayClient : IDisposable
if (string.IsNullOrEmpty(text) || text.Length <= maxLen) return text;
return text[..(maxLen - 1)] + "…";
}
public void Dispose()
{
if (_disposed) return;
_disposed = true;
try { _cts.Cancel(); } catch { }
ClearPendingRequests();
var ws = _webSocket;
_webSocket = null;
try { ws?.Dispose(); } catch { }
// Don't dispose _cts immediately — listen loop or reconnect may still reference it.
// It will be GC'd after all pending tasks complete.
}
}

View File

@ -0,0 +1,269 @@
using System;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace OpenClaw.Shared;
/// <summary>
/// Abstract base class for WebSocket-based gateway clients.
/// Extracts shared connection lifecycle: connect, listen, reconnect, send, dispose.
/// Subclasses implement message processing and provide configuration via abstract members.
/// </summary>
public abstract class WebSocketClientBase : IDisposable
{
private ClientWebSocket? _webSocket;
private readonly string _gatewayUrl;
private readonly string? _credentials;
private CancellationTokenSource _cts;
private bool _disposed;
private int _reconnectAttempts;
private static readonly int[] BackoffMs = { 1000, 2000, 4000, 8000, 15000, 30000, 60000 };
protected readonly string _token;
protected readonly IOpenClawLogger _logger;
/// <summary>Gateway URL with credentials stripped, safe for logging/display.</summary>
protected string GatewayUrlForDisplay { get; }
/// <summary>Whether Dispose has been called.</summary>
protected bool IsDisposed => _disposed;
/// <summary>Whether the WebSocket is currently open and connected.</summary>
protected bool IsConnected => _webSocket?.State == WebSocketState.Open;
/// <summary>Cancellation token tied to this client's lifetime.</summary>
protected CancellationToken CancellationToken => _cts.Token;
// Events
public event EventHandler<ConnectionStatus>? StatusChanged;
// --- Abstract members (subclass MUST implement) ---
/// <summary>
/// Process a received WebSocket text message. Called from the listen loop.
/// Gateway wraps its sync ProcessMessage with Task.CompletedTask;
/// Node directly uses its async implementation.
/// </summary>
protected abstract Task ProcessMessageAsync(string json);
/// <summary>Receive buffer size in bytes. Gateway: 16384, Node: 65536.</summary>
protected abstract int ReceiveBufferSize { get; }
/// <summary>Client role for log messages, e.g. "gateway" or "node".</summary>
protected abstract string ClientRole { get; }
// --- Virtual hooks (subclass MAY override) ---
/// <summary>Called after WebSocket connects, before the listen loop starts.</summary>
protected virtual Task OnConnectedAsync() => Task.CompletedTask;
/// <summary>Called when the server closes the connection or it drops.</summary>
protected virtual void OnDisconnected() { }
/// <summary>Called on unrecoverable listen-loop errors.</summary>
protected virtual void OnError(Exception ex) { }
/// <summary>Called at the start of Dispose, before CTS cancellation.</summary>
protected virtual void OnDisposing() { }
protected WebSocketClientBase(string gatewayUrl, string token, IOpenClawLogger? logger = null)
{
if (string.IsNullOrEmpty(gatewayUrl))
throw new ArgumentException("Gateway URL is required.", nameof(gatewayUrl));
if (string.IsNullOrEmpty(token))
throw new ArgumentException("Token is required.", nameof(token));
_gatewayUrl = GatewayUrlHelper.NormalizeForWebSocket(gatewayUrl);
GatewayUrlForDisplay = GatewayUrlHelper.SanitizeForDisplay(_gatewayUrl);
_token = token;
_credentials = GatewayUrlHelper.ExtractCredentials(gatewayUrl);
_logger = logger ?? NullLogger.Instance;
_cts = new CancellationTokenSource();
}
public async Task ConnectAsync()
{
try
{
RaiseStatusChanged(ConnectionStatus.Connecting);
_logger.Info($"Connecting to {ClientRole}: {GatewayUrlForDisplay}");
_webSocket = new ClientWebSocket();
_webSocket.Options.KeepAliveInterval = TimeSpan.FromSeconds(30);
// Set Origin header (convert ws/wss to http/https)
var uri = new Uri(_gatewayUrl);
var originScheme = uri.Scheme == "wss" ? "https" : "http";
var origin = $"{originScheme}://{uri.Host}:{uri.Port}";
_webSocket.Options.SetRequestHeader("Origin", origin);
if (!string.IsNullOrEmpty(_credentials))
{
var credentialsToEncode = GatewayUrlHelper.DecodeCredentials(_credentials);
_webSocket.Options.SetRequestHeader(
"Authorization",
$"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes(credentialsToEncode))}");
}
await _webSocket.ConnectAsync(uri, _cts.Token);
_reconnectAttempts = 0;
_logger.Info($"{ClientRole} connected, waiting for challenge...");
await OnConnectedAsync();
_ = Task.Run(() => ListenForMessagesAsync(), _cts.Token);
}
catch (Exception ex)
{
_logger.Error($"{ClientRole} connection failed", ex);
RaiseStatusChanged(ConnectionStatus.Error);
}
}
private async Task ListenForMessagesAsync()
{
var buffer = new byte[ReceiveBufferSize];
var sb = new StringBuilder();
try
{
while (_webSocket?.State == WebSocketState.Open && !_cts.Token.IsCancellationRequested)
{
var result = await _webSocket.ReceiveAsync(
new ArraySegment<byte>(buffer), _cts.Token);
if (result.MessageType == WebSocketMessageType.Text)
{
sb.Append(Encoding.UTF8.GetString(buffer, 0, result.Count));
if (result.EndOfMessage)
{
await ProcessMessageAsync(sb.ToString());
sb.Clear();
}
}
else if (result.MessageType == WebSocketMessageType.Close)
{
var closeStatus = _webSocket.CloseStatus?.ToString() ?? "unknown";
var closeDesc = _webSocket.CloseStatusDescription ?? "no description";
_logger.Info($"Server closed connection: {closeStatus} - {closeDesc}");
OnDisconnected();
RaiseStatusChanged(ConnectionStatus.Disconnected);
break;
}
}
}
catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
{
_logger.Warn("Connection closed prematurely");
OnDisconnected();
RaiseStatusChanged(ConnectionStatus.Disconnected);
}
catch (OperationCanceledException) { }
catch (ObjectDisposedException) { /* CTS or WebSocket disposed during shutdown */ }
catch (Exception ex)
{
_logger.Error($"{ClientRole} listen error", ex);
OnError(ex);
RaiseStatusChanged(ConnectionStatus.Error);
}
// Auto-reconnect if not intentionally disposed
if (!_disposed)
{
try
{
if (!_cts.Token.IsCancellationRequested)
{
await ReconnectWithBackoffAsync();
}
}
catch (ObjectDisposedException) { /* CTS disposed during check */ }
}
}
protected async Task ReconnectWithBackoffAsync()
{
var delay = BackoffMs[Math.Min(_reconnectAttempts, BackoffMs.Length - 1)];
_reconnectAttempts++;
_logger.Warn($"{ClientRole} reconnecting in {delay}ms (attempt {_reconnectAttempts})");
RaiseStatusChanged(ConnectionStatus.Connecting);
try
{
await Task.Delay(delay, _cts.Token);
// Check cancellation after delay
if (_cts.Token.IsCancellationRequested) return;
// Safely dispose old socket
var oldSocket = _webSocket;
_webSocket = null;
try { oldSocket?.Dispose(); } catch { /* ignore dispose errors */ }
await ConnectAsync();
}
catch (OperationCanceledException) { }
catch (Exception ex)
{
_logger.Error($"{ClientRole} reconnect failed", ex);
RaiseStatusChanged(ConnectionStatus.Error);
}
}
/// <summary>Send a text message over the WebSocket. Thread-safe.</summary>
protected async Task SendRawAsync(string message)
{
// Capture local reference to avoid TOCTOU race with reconnect/dispose
var ws = _webSocket;
if (ws?.State != WebSocketState.Open) return;
try
{
var bytes = Encoding.UTF8.GetBytes(message);
await ws.SendAsync(new ArraySegment<byte>(bytes),
WebSocketMessageType.Text, true, _cts.Token);
}
catch (ObjectDisposedException)
{
// WebSocket was disposed between state check and send
}
catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.InvalidState)
{
_logger.Warn($"WebSocket send failed (state changed): {ex.Message}");
}
}
/// <summary>Gracefully close the WebSocket connection.</summary>
protected async Task CloseWebSocketAsync()
{
var ws = _webSocket;
if (ws?.State == WebSocketState.Open)
{
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "Disconnecting", System.Threading.CancellationToken.None);
}
}
/// <summary>Fire the StatusChanged event. Use this instead of directly invoking the event.</summary>
protected void RaiseStatusChanged(ConnectionStatus status)
=> StatusChanged?.Invoke(this, status);
public void Dispose()
{
if (_disposed) return;
_disposed = true;
OnDisposing();
try { _cts.Cancel(); } catch { }
var ws = _webSocket;
_webSocket = null;
try { ws?.Dispose(); } catch { }
// Don't dispose _cts immediately — listen loop or reconnect may still reference it.
// It will be GC'd after all pending tasks complete.
}
}

View File

@ -1,10 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace OpenClaw.Shared;
@ -13,19 +10,9 @@ namespace OpenClaw.Shared;
/// Windows Node client - extends gateway connection to act as a node
/// Supports both operator (existing) and node (new) roles
/// </summary>
public class WindowsNodeClient : IDisposable
public class WindowsNodeClient : WebSocketClientBase
{
private ClientWebSocket? _webSocket;
private readonly string _gatewayUrl;
private readonly string _gatewayUrlForDisplay;
private readonly string _token;
private readonly string? _credentials;
private readonly IOpenClawLogger _logger;
private readonly DeviceIdentity _deviceIdentity;
private CancellationTokenSource _cts;
private bool _disposed;
private int _reconnectAttempts;
private static readonly int[] BackoffMs = { 1000, 2000, 4000, 8000, 15000, 30000, 60000 };
// Node capabilities registry
private readonly List<INodeCapability> _capabilities = new();
@ -38,13 +25,12 @@ public class WindowsNodeClient : IDisposable
private bool _isPendingApproval; // True when connected but awaiting pairing approval
// Events
public event EventHandler<ConnectionStatus>? StatusChanged;
public event EventHandler<NodeInvokeRequest>? InvokeReceived;
public event EventHandler<PairingStatusEventArgs>? PairingStatusChanged;
public bool IsConnected => _isConnected;
public new bool IsConnected => _isConnected;
public string? NodeId => _nodeId;
public string GatewayUrl => _gatewayUrlForDisplay;
public string GatewayUrl => GatewayUrlForDisplay;
public IReadOnlyList<INodeCapability> Capabilities => _capabilities;
/// <summary>True if connected but waiting for pairing approval on gateway</summary>
@ -61,15 +47,12 @@ public class WindowsNodeClient : IDisposable
/// <summary>Full device ID for approval command</summary>
public string FullDeviceId => _deviceIdentity.DeviceId;
protected override int ReceiveBufferSize => 65536;
protected override string ClientRole => "node";
public WindowsNodeClient(string gatewayUrl, string token, string dataPath, IOpenClawLogger? logger = null)
: base(gatewayUrl, token, logger)
{
_gatewayUrl = GatewayUrlHelper.NormalizeForWebSocket(gatewayUrl);
_gatewayUrlForDisplay = GatewayUrlHelper.SanitizeForDisplay(_gatewayUrl);
_token = token;
_credentials = GatewayUrlHelper.ExtractCredentials(gatewayUrl);
_logger = logger ?? NullLogger.Instance;
_cts = new CancellationTokenSource();
// Initialize device identity
_deviceIdentity = new DeviceIdentity(dataPath, _logger);
_deviceIdentity.Initialize();
@ -77,7 +60,7 @@ public class WindowsNodeClient : IDisposable
// Initialize registration
_registration = new NodeRegistration
{
Id = _deviceIdentity.DeviceId, // Use device ID from keypair
Id = _deviceIdentity.DeviceId,
Version = "1.0.0",
Platform = "windows",
DisplayName = $"Windows Node ({Environment.MachineName})"
@ -115,132 +98,21 @@ public class WindowsNodeClient : IDisposable
_registration.Permissions[permission] = value;
}
/// <summary>
/// Connect to gateway as a node
/// </summary>
public async Task ConnectAsync()
{
try
{
StatusChanged?.Invoke(this, ConnectionStatus.Connecting);
_logger.Info($"Connecting to gateway as node: {_gatewayUrlForDisplay}");
_webSocket = new ClientWebSocket();
_webSocket.Options.KeepAliveInterval = TimeSpan.FromSeconds(30);
// Set Origin header
var uri = new Uri(_gatewayUrl);
var originScheme = uri.Scheme == "wss" ? "https" : "http";
var origin = $"{originScheme}://{uri.Host}:{uri.Port}";
_webSocket.Options.SetRequestHeader("Origin", origin);
if (!string.IsNullOrEmpty(_credentials))
{
var authCredentials = GatewayUrlHelper.DecodeCredentials(_credentials);
_webSocket.Options.SetRequestHeader(
"Authorization",
$"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes(authCredentials))}");
}
await _webSocket.ConnectAsync(uri, _cts.Token);
_reconnectAttempts = 0;
_logger.Info("Node connected, waiting for challenge...");
// Start message loop
_ = Task.Run(() => ListenForMessagesAsync(), _cts.Token);
}
catch (Exception ex)
{
_logger.Error("Node connection failed", ex);
StatusChanged?.Invoke(this, ConnectionStatus.Error);
}
}
/// <summary>
/// Disconnect from gateway
/// </summary>
public async Task DisconnectAsync()
public Task DisconnectAsync()
{
_isConnected = false;
if (_webSocket?.State == WebSocketState.Open)
{
try
{
await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Disconnecting", CancellationToken.None);
}
catch (Exception ex)
{
_logger.Warn($"Error during disconnect: {ex.Message}");
}
}
StatusChanged?.Invoke(this, ConnectionStatus.Disconnected);
Dispose();
RaiseStatusChanged(ConnectionStatus.Disconnected);
_logger.Info("Node disconnected");
return Task.CompletedTask;
}
// --- Message handling ---
private async Task ListenForMessagesAsync()
{
var buffer = new byte[65536]; // Large buffer for image data
var sb = new StringBuilder();
try
{
while (_webSocket?.State == WebSocketState.Open && !_cts.Token.IsCancellationRequested)
{
var result = await _webSocket.ReceiveAsync(
new ArraySegment<byte>(buffer), _cts.Token);
if (result.MessageType == WebSocketMessageType.Text)
{
sb.Append(Encoding.UTF8.GetString(buffer, 0, result.Count));
if (result.EndOfMessage)
{
await ProcessMessageAsync(sb.ToString());
sb.Clear();
}
}
else if (result.MessageType == WebSocketMessageType.Close)
{
_logger.Info("Server closed connection");
_isConnected = false;
StatusChanged?.Invoke(this, ConnectionStatus.Disconnected);
break;
}
}
}
catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
{
_logger.Warn("Connection closed prematurely");
_isConnected = false;
StatusChanged?.Invoke(this, ConnectionStatus.Disconnected);
}
catch (OperationCanceledException) { }
catch (ObjectDisposedException) { /* CTS was disposed */ }
catch (Exception ex)
{
_logger.Error("Node listen error", ex);
_isConnected = false;
StatusChanged?.Invoke(this, ConnectionStatus.Error);
}
// Auto-reconnect (with extra safety checks)
if (!_disposed)
{
try
{
if (!_cts.Token.IsCancellationRequested)
{
await ReconnectWithBackoffAsync();
}
}
catch (ObjectDisposedException) { /* CTS was disposed during check */ }
}
}
private async Task ProcessMessageAsync(string json)
protected override async Task ProcessMessageAsync(string json)
{
try
{
@ -618,7 +490,7 @@ public class WindowsNodeClient : IDisposable
_deviceIdentity.DeviceId));
}
StatusChanged?.Invoke(this, ConnectionStatus.Connected);
RaiseStatusChanged(ConnectionStatus.Connected);
}
// Handle errors
@ -638,7 +510,7 @@ public class WindowsNodeClient : IDisposable
}
}
_logger.Error($"Node registration failed: {error} (code: {errorCode})");
StatusChanged?.Invoke(this, ConnectionStatus.Error);
RaiseStatusChanged(ConnectionStatus.Error);
}
}
@ -790,70 +662,13 @@ public class WindowsNodeClient : IDisposable
await SendRawAsync(JsonSerializer.Serialize(msg));
}
private async Task SendRawAsync(string message)
protected override void OnDisconnected()
{
// Capture local reference to avoid race conditions
var ws = _webSocket;
if (ws?.State != WebSocketState.Open) return;
try
{
var bytes = Encoding.UTF8.GetBytes(message);
await ws.SendAsync(new ArraySegment<byte>(bytes),
WebSocketMessageType.Text, true, _cts.Token);
}
catch (ObjectDisposedException)
{
// WebSocket was disposed between check and send - ignore
}
catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.InvalidState)
{
// WebSocket state changed - ignore
_logger.Warn($"WebSocket send failed (state changed): {ex.Message}");
}
_isConnected = false;
}
private async Task ReconnectWithBackoffAsync()
protected override void OnError(Exception ex)
{
var delay = BackoffMs[Math.Min(_reconnectAttempts, BackoffMs.Length - 1)];
_reconnectAttempts++;
_logger.Warn($"Node reconnecting in {delay}ms (attempt {_reconnectAttempts})");
StatusChanged?.Invoke(this, ConnectionStatus.Connecting);
try
{
await Task.Delay(delay, _cts.Token);
// Check cancellation after delay
if (_cts.Token.IsCancellationRequested) return;
// Safely dispose old socket
var oldSocket = _webSocket;
_webSocket = null;
try { oldSocket?.Dispose(); } catch { /* ignore dispose errors */ }
await ConnectAsync();
}
catch (OperationCanceledException) { }
catch (Exception ex)
{
_logger.Error("Node reconnect failed", ex);
StatusChanged?.Invoke(this, ConnectionStatus.Error);
}
}
public void Dispose()
{
if (_disposed) return;
_disposed = true;
try { _cts.Cancel(); } catch { /* ignore */ }
var ws = _webSocket;
_webSocket = null;
try { ws?.Dispose(); } catch { /* ignore */ }
// Don't dispose _cts immediately — reconnect loop may still reference it.
// It will be GC'd after all pending tasks complete.
_isConnected = false;
}
}

View File

@ -609,7 +609,7 @@ public class OpenClawGatewayClientTests
{
var client = new OpenClawGatewayClient(inputUrl, "test-token");
var field = typeof(OpenClawGatewayClient).GetField(
var field = typeof(OpenClawGatewayClient).BaseType?.GetField(
"_gatewayUrl",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
var actualUrl = field?.GetValue(client) as string;

View File

@ -0,0 +1,244 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Xunit;
namespace OpenClaw.Shared.Tests;
/// <summary>
/// Concrete test double for WebSocketClientBase.
/// Exposes hooks and tracking for unit testing base class behavior.
/// </summary>
public class TestWebSocketClient : WebSocketClientBase
{
public List<string> ProcessedMessages { get; } = new();
public int OnConnectedCallCount { get; private set; }
public int OnDisconnectedCallCount { get; private set; }
public int OnErrorCallCount { get; private set; }
public Exception? LastError { get; private set; }
public int OnDisposingCallCount { get; private set; }
protected override int ReceiveBufferSize => 8192;
protected override string ClientRole => "test";
public TestWebSocketClient(string gatewayUrl, string token, IOpenClawLogger? logger = null)
: base(gatewayUrl, token, logger) { }
protected override Task ProcessMessageAsync(string json)
{
ProcessedMessages.Add(json);
return Task.CompletedTask;
}
protected override Task OnConnectedAsync()
{
OnConnectedCallCount++;
return Task.CompletedTask;
}
protected override void OnDisconnected()
{
OnDisconnectedCallCount++;
}
protected override void OnError(Exception ex)
{
OnErrorCallCount++;
LastError = ex;
}
protected override void OnDisposing()
{
OnDisposingCallCount++;
}
// Expose protected members for testing
public void TestRaiseStatusChanged(ConnectionStatus status)
=> RaiseStatusChanged(status);
public bool TestIsDisposed => IsDisposed;
public string TestGatewayUrlForDisplay => GatewayUrlForDisplay;
public string TestToken => _token;
public IOpenClawLogger TestLogger => _logger;
}
public class WebSocketClientBaseTests
{
private readonly TestLogger _logger = new();
[Theory]
[InlineData("http://localhost:18789", "ws://localhost:18789")]
[InlineData("https://gateway.example.com", "wss://gateway.example.com")]
[InlineData("ws://localhost:18789", "ws://localhost:18789")]
[InlineData("wss://gateway.example.com", "wss://gateway.example.com")]
public void Constructor_NormalizesUrl(string input, string _)
{
var client = new TestWebSocketClient(input, "test-token", _logger);
// GatewayUrlForDisplay is the sanitized version — just verify it's set
Assert.NotNull(client.TestGatewayUrlForDisplay);
Assert.DoesNotContain("@", client.TestGatewayUrlForDisplay); // credentials stripped
client.Dispose();
}
[Fact]
public void Constructor_StoresToken()
{
var client = new TestWebSocketClient("ws://localhost:18789", "my-token", _logger);
Assert.Equal("my-token", client.TestToken);
client.Dispose();
}
[Fact]
public void Constructor_UsesNullLoggerWhenNotProvided()
{
var client = new TestWebSocketClient("ws://localhost:18789", "token");
Assert.NotNull(client.TestLogger);
client.Dispose();
}
[Fact]
public void Constructor_ThrowsOnNullUrl()
{
Assert.Throws<ArgumentException>(() =>
new TestWebSocketClient(null!, "token", _logger));
}
[Fact]
public void Constructor_ThrowsOnEmptyUrl()
{
Assert.Throws<ArgumentException>(() =>
new TestWebSocketClient("", "token", _logger));
}
[Fact]
public void Constructor_ThrowsOnNullToken()
{
Assert.Throws<ArgumentException>(() =>
new TestWebSocketClient("ws://localhost", null!, _logger));
}
[Fact]
public void Constructor_ThrowsOnEmptyToken()
{
Assert.Throws<ArgumentException>(() =>
new TestWebSocketClient("ws://localhost", "", _logger));
}
[Fact]
public void Constructor_WithCredentialUrl_StripsFromDisplay()
{
var client = new TestWebSocketClient("ws://user:pass@localhost:18789", "token", _logger);
Assert.DoesNotContain("pass", client.TestGatewayUrlForDisplay);
client.Dispose();
}
[Fact]
public void Dispose_SetsIsDisposed()
{
var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger);
Assert.False(client.TestIsDisposed);
client.Dispose();
Assert.True(client.TestIsDisposed);
}
[Fact]
public void Dispose_IsIdempotent()
{
var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger);
client.Dispose();
client.Dispose(); // second call should not throw
Assert.True(client.TestIsDisposed);
Assert.Equal(1, client.OnDisposingCallCount); // hook called only once
}
[Fact]
public void Dispose_CallsOnDisposingHook()
{
var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger);
client.Dispose();
Assert.Equal(1, client.OnDisposingCallCount);
}
[Fact]
public void RaiseStatusChanged_FiresEvent()
{
var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger);
ConnectionStatus? received = null;
client.StatusChanged += (_, status) => received = status;
client.TestRaiseStatusChanged(ConnectionStatus.Connecting);
Assert.Equal(ConnectionStatus.Connecting, received);
client.Dispose();
}
[Fact]
public void RaiseStatusChanged_WithNoSubscribers_DoesNotThrow()
{
var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger);
client.TestRaiseStatusChanged(ConnectionStatus.Connected); // no subscribers — should not throw
client.Dispose();
}
[Fact]
public void RaiseStatusChanged_MultipleSubscribers_AllNotified()
{
var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger);
var statuses = new List<ConnectionStatus>();
client.StatusChanged += (_, s) => statuses.Add(s);
client.StatusChanged += (_, s) => statuses.Add(s);
client.TestRaiseStatusChanged(ConnectionStatus.Error);
Assert.Equal(2, statuses.Count);
Assert.All(statuses, s => Assert.Equal(ConnectionStatus.Error, s));
client.Dispose();
}
[Fact]
public void IsConnected_FalseBeforeConnect()
{
var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger);
// Reflection to check IsConnected on the base
var prop = typeof(WebSocketClientBase).GetProperty("IsConnected",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
var isConnected = (bool)prop!.GetValue(client)!;
Assert.False(isConnected);
client.Dispose();
}
[Fact]
public void IsConnected_FalseAfterDispose()
{
var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger);
client.Dispose();
var prop = typeof(WebSocketClientBase).GetProperty("IsConnected",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
var isConnected = (bool)prop!.GetValue(client)!;
Assert.False(isConnected);
}
[Fact]
public async Task ConnectAsync_RaisesStatusChangedConnecting()
{
var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger);
var statuses = new List<ConnectionStatus>();
client.StatusChanged += (_, s) => statuses.Add(s);
// ConnectAsync will fail (no real server) but should still fire Connecting then Error
await client.ConnectAsync();
Assert.Contains(ConnectionStatus.Connecting, statuses);
Assert.Contains(ConnectionStatus.Error, statuses);
client.Dispose();
}
}
public class TestLogger : IOpenClawLogger
{
public List<string> Logs { get; } = new();
public void Info(string message) => Logs.Add($"INFO: {message}");
public void Debug(string message) => Logs.Add($"DEBUG: {message}");
public void Warn(string message) => Logs.Add($"WARN: {message}");
public void Error(string message, Exception? ex = null) => Logs.Add($"ERROR: {message}");
}

View File

@ -19,7 +19,7 @@ public class WindowsNodeClientTests
try
{
using var client = new WindowsNodeClient(inputUrl, "test-token", dataPath);
var field = typeof(WindowsNodeClient).GetField(
var field = typeof(WindowsNodeClient).BaseType?.GetField(
"_gatewayUrl",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
var actualUrl = field?.GetValue(client) as string;