diff --git a/CHANGELOG.md b/CHANGELOG.md index e3c7ce1..35102ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Make `generate-cli --bundle` artifacts deterministic by removing bundle-only paths/timestamps from embedded metadata and sorting generated tool/schema output. (Issue #180, thanks @imroc) - Let daemon-managed OAuth servers reuse cached credentials for tool calls and tool listing after token expiry. (PR #182 / issue #181, thanks @bradhallett) +- Avoid restarting browser OAuth when an already-connected server has a still-valid cached access token. (Issue #179, thanks @jaigew and @StanAngeloff) ## [0.11.1] - 2026-05-14 diff --git a/src/runtime/oauth.ts b/src/runtime/oauth.ts index 6038d21..f7210d7 100644 --- a/src/runtime/oauth.ts +++ b/src/runtime/oauth.ts @@ -1,5 +1,6 @@ import { auth as sdkAuth } from '@modelcontextprotocol/sdk/client/auth.js'; import type { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import type { OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import type { Logger } from '../logging.js'; import type { OAuthSession } from '../oauth.js'; @@ -9,6 +10,7 @@ export const DEFAULT_OAUTH_CODE_TIMEOUT_MS = 300_000; const OAUTH_FLOW_ERROR = Symbol('oauth-flow-error'); const POST_AUTH_CONNECT_ERROR = Symbol('post-auth-connect-error'); const MAX_OAUTH_ERROR_DETAIL_LENGTH = 1_200; +const PROACTIVE_TOKEN_SKEW_SECONDS = 60; export interface OAuthCapableTransport extends Transport { close(): Promise; @@ -109,6 +111,15 @@ function hasErrorMarker(error: unknown, marker: symbol): boolean { ); } +function hasUsableCachedAccessToken(tokens: OAuthTokens | undefined): boolean { + if (!tokens || typeof tokens.access_token !== 'string' || tokens.access_token.trim().length === 0) { + return false; + } + const stored = tokens as OAuthTokens & { expires_at?: number; expiresAt?: number }; + const expiresAt = typeof stored.expires_at === 'number' ? stored.expires_at : stored.expiresAt; + return typeof expiresAt === 'number' && expiresAt > Math.floor(Date.now() / 1000) + PROACTIVE_TOKEN_SKEW_SECONDS; +} + export async function connectWithAuth( client: Client, transport: OAuthCapableTransport, @@ -239,6 +250,10 @@ async function completeProactiveAuthorization( return; } try { + const cachedTokens = await session.provider.tokens?.(); + if (hasUsableCachedAccessToken(cachedTokens)) { + return; + } const result = await sdkAuth(session.provider, { serverUrl: options.serverUrl, fetchFn: options.fetchFn, diff --git a/tests/runtime-oauth-connect.test.ts b/tests/runtime-oauth-connect.test.ts index 11245c9..cc9583f 100644 --- a/tests/runtime-oauth-connect.test.ts +++ b/tests/runtime-oauth-connect.test.ts @@ -241,6 +241,67 @@ describe('connectWithAuth', () => { expect(connectedTransport).toBe(transport); }); + it('skips proactive OAuth when cached tokens are still usable', async () => { + const connect = vi.fn().mockResolvedValueOnce(undefined); + const client = { connect } as unknown as Client; + const { session, waitForAuthorizationCode } = createPendingAuthorizationSession(); + const tokens = vi.fn(async () => ({ + access_token: 'cached-token', + token_type: 'Bearer', + expires_at: Math.floor(Date.now() / 1000) + 3600, + })); + session.provider.tokens = tokens; + mocks.sdkAuth.mockResolvedValueOnce('REDIRECT'); + + const transport = new MockTransport(); + const logger = createLogger(); + + const connectedTransport = await connectWithAuth(client, transport, session, logger, { + serverName: 'courtlistener', + maxAttempts: 1, + oauthTimeoutMs: 5000, + serverUrl: 'https://courtlistener.example/mcp', + }); + + expect(tokens).toHaveBeenCalledTimes(1); + expect(mocks.sdkAuth).not.toHaveBeenCalled(); + expect(waitForAuthorizationCode).not.toHaveBeenCalled(); + expect(transport.calls).toEqual([]); + expect(session.close).toHaveBeenCalled(); + expect(connectedTransport).toBe(transport); + }); + + it('runs proactive OAuth when cached tokens are expired', async () => { + const connect = vi.fn().mockResolvedValueOnce(undefined); + const client = { connect } as unknown as Client; + const { session, waitForAuthorizationCode } = createPendingAuthorizationSession(); + session.provider.tokens = vi.fn(async () => ({ + access_token: 'expired-token', + token_type: 'Bearer', + expires_at: Math.floor(Date.now() / 1000) - 60, + })); + mocks.sdkAuth.mockResolvedValueOnce('AUTHORIZED'); + + const transport = new MockTransport(); + const logger = createLogger(); + + const connectedTransport = await connectWithAuth(client, transport, session, logger, { + serverName: 'courtlistener', + maxAttempts: 1, + oauthTimeoutMs: 5000, + serverUrl: 'https://courtlistener.example/mcp', + }); + + expect(mocks.sdkAuth).toHaveBeenCalledWith(session.provider, { + serverUrl: 'https://courtlistener.example/mcp', + fetchFn: undefined, + }); + expect(waitForAuthorizationCode).not.toHaveBeenCalled(); + expect(transport.calls).toEqual([]); + expect(session.close).toHaveBeenCalled(); + expect(connectedTransport).toBe(transport); + }); + it('marks proactive OAuth failures as OAuth flow errors', async () => { const connect = vi.fn().mockResolvedValueOnce(undefined); const client = { connect } as unknown as Client; @@ -261,6 +322,28 @@ describe('connectWithAuth', () => { ).rejects.toSatisfy((error: unknown) => error === authError && isOAuthFlowError(error)); }); + it('marks cached token read failures during proactive OAuth as OAuth flow errors', async () => { + const connect = vi.fn().mockResolvedValueOnce(undefined); + const client = { connect } as unknown as Client; + const { session } = createPendingAuthorizationSession(); + const tokenError = new Error('malformed token cache'); + session.provider.tokens = vi.fn(async () => { + throw tokenError; + }); + + const transport = new MockTransport(); + const logger = createLogger(); + + await expect( + connectWithAuth(client, transport, session, logger, { + serverName: 'calendar', + maxAttempts: 1, + oauthTimeoutMs: 5000, + serverUrl: 'https://calendar.example/mcp', + }) + ).rejects.toSatisfy((error: unknown) => error === tokenError && isOAuthFlowError(error)); + }); + it('marks finishAuth failures as oauth flow errors', async () => { const connect = vi.fn().mockRejectedValueOnce(new UnauthorizedError('auth needed')); const client = { connect } as unknown as Client;