fix: preserve valid cached OAuth tokens

This commit is contained in:
Peter Steinberger 2026-05-20 17:21:05 +01:00
parent 3ca4b5bae8
commit 31bbaa804f
No known key found for this signature in database
3 changed files with 99 additions and 0 deletions

View File

@ -6,6 +6,7 @@
- Make `generate-cli --bundle` artifacts deterministic by removing bundle-only paths/timestamps from embedded metadata and sorting generated tool/schema output. (Issue #180, thanks @imroc)
- Let daemon-managed OAuth servers reuse cached credentials for tool calls and tool listing after token expiry. (PR #182 / issue #181, thanks @bradhallett)
- Avoid restarting browser OAuth when an already-connected server has a still-valid cached access token. (Issue #179, thanks @jaigew and @StanAngeloff)
## [0.11.1] - 2026-05-14

View File

@ -1,5 +1,6 @@
import { auth as sdkAuth } from '@modelcontextprotocol/sdk/client/auth.js';
import type { Client } from '@modelcontextprotocol/sdk/client/index.js';
import type { OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { Logger } from '../logging.js';
import type { OAuthSession } from '../oauth.js';
@ -9,6 +10,7 @@ export const DEFAULT_OAUTH_CODE_TIMEOUT_MS = 300_000;
const OAUTH_FLOW_ERROR = Symbol('oauth-flow-error');
const POST_AUTH_CONNECT_ERROR = Symbol('post-auth-connect-error');
const MAX_OAUTH_ERROR_DETAIL_LENGTH = 1_200;
const PROACTIVE_TOKEN_SKEW_SECONDS = 60;
export interface OAuthCapableTransport extends Transport {
close(): Promise<void>;
@ -109,6 +111,15 @@ function hasErrorMarker(error: unknown, marker: symbol): boolean {
);
}
function hasUsableCachedAccessToken(tokens: OAuthTokens | undefined): boolean {
if (!tokens || typeof tokens.access_token !== 'string' || tokens.access_token.trim().length === 0) {
return false;
}
const stored = tokens as OAuthTokens & { expires_at?: number; expiresAt?: number };
const expiresAt = typeof stored.expires_at === 'number' ? stored.expires_at : stored.expiresAt;
return typeof expiresAt === 'number' && expiresAt > Math.floor(Date.now() / 1000) + PROACTIVE_TOKEN_SKEW_SECONDS;
}
export async function connectWithAuth(
client: Client,
transport: OAuthCapableTransport,
@ -239,6 +250,10 @@ async function completeProactiveAuthorization(
return;
}
try {
const cachedTokens = await session.provider.tokens?.();
if (hasUsableCachedAccessToken(cachedTokens)) {
return;
}
const result = await sdkAuth(session.provider, {
serverUrl: options.serverUrl,
fetchFn: options.fetchFn,

View File

@ -241,6 +241,67 @@ describe('connectWithAuth', () => {
expect(connectedTransport).toBe(transport);
});
it('skips proactive OAuth when cached tokens are still usable', async () => {
const connect = vi.fn().mockResolvedValueOnce(undefined);
const client = { connect } as unknown as Client;
const { session, waitForAuthorizationCode } = createPendingAuthorizationSession();
const tokens = vi.fn(async () => ({
access_token: 'cached-token',
token_type: 'Bearer',
expires_at: Math.floor(Date.now() / 1000) + 3600,
}));
session.provider.tokens = tokens;
mocks.sdkAuth.mockResolvedValueOnce('REDIRECT');
const transport = new MockTransport();
const logger = createLogger();
const connectedTransport = await connectWithAuth(client, transport, session, logger, {
serverName: 'courtlistener',
maxAttempts: 1,
oauthTimeoutMs: 5000,
serverUrl: 'https://courtlistener.example/mcp',
});
expect(tokens).toHaveBeenCalledTimes(1);
expect(mocks.sdkAuth).not.toHaveBeenCalled();
expect(waitForAuthorizationCode).not.toHaveBeenCalled();
expect(transport.calls).toEqual([]);
expect(session.close).toHaveBeenCalled();
expect(connectedTransport).toBe(transport);
});
it('runs proactive OAuth when cached tokens are expired', async () => {
const connect = vi.fn().mockResolvedValueOnce(undefined);
const client = { connect } as unknown as Client;
const { session, waitForAuthorizationCode } = createPendingAuthorizationSession();
session.provider.tokens = vi.fn(async () => ({
access_token: 'expired-token',
token_type: 'Bearer',
expires_at: Math.floor(Date.now() / 1000) - 60,
}));
mocks.sdkAuth.mockResolvedValueOnce('AUTHORIZED');
const transport = new MockTransport();
const logger = createLogger();
const connectedTransport = await connectWithAuth(client, transport, session, logger, {
serverName: 'courtlistener',
maxAttempts: 1,
oauthTimeoutMs: 5000,
serverUrl: 'https://courtlistener.example/mcp',
});
expect(mocks.sdkAuth).toHaveBeenCalledWith(session.provider, {
serverUrl: 'https://courtlistener.example/mcp',
fetchFn: undefined,
});
expect(waitForAuthorizationCode).not.toHaveBeenCalled();
expect(transport.calls).toEqual([]);
expect(session.close).toHaveBeenCalled();
expect(connectedTransport).toBe(transport);
});
it('marks proactive OAuth failures as OAuth flow errors', async () => {
const connect = vi.fn().mockResolvedValueOnce(undefined);
const client = { connect } as unknown as Client;
@ -261,6 +322,28 @@ describe('connectWithAuth', () => {
).rejects.toSatisfy((error: unknown) => error === authError && isOAuthFlowError(error));
});
it('marks cached token read failures during proactive OAuth as OAuth flow errors', async () => {
const connect = vi.fn().mockResolvedValueOnce(undefined);
const client = { connect } as unknown as Client;
const { session } = createPendingAuthorizationSession();
const tokenError = new Error('malformed token cache');
session.provider.tokens = vi.fn(async () => {
throw tokenError;
});
const transport = new MockTransport();
const logger = createLogger();
await expect(
connectWithAuth(client, transport, session, logger, {
serverName: 'calendar',
maxAttempts: 1,
oauthTimeoutMs: 5000,
serverUrl: 'https://calendar.example/mcp',
})
).rejects.toSatisfy((error: unknown) => error === tokenError && isOAuthFlowError(error));
});
it('marks finishAuth failures as oauth flow errors', async () => {
const connect = vi.fn().mockRejectedValueOnce(new UnauthorizedError('auth needed'));
const client = { connect } as unknown as Client;