fix: preserve valid cached OAuth tokens
This commit is contained in:
parent
3ca4b5bae8
commit
31bbaa804f
@ -6,6 +6,7 @@
|
||||
|
||||
- Make `generate-cli --bundle` artifacts deterministic by removing bundle-only paths/timestamps from embedded metadata and sorting generated tool/schema output. (Issue #180, thanks @imroc)
|
||||
- Let daemon-managed OAuth servers reuse cached credentials for tool calls and tool listing after token expiry. (PR #182 / issue #181, thanks @bradhallett)
|
||||
- Avoid restarting browser OAuth when an already-connected server has a still-valid cached access token. (Issue #179, thanks @jaigew and @StanAngeloff)
|
||||
|
||||
## [0.11.1] - 2026-05-14
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import { auth as sdkAuth } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
import type { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import type { OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type { Logger } from '../logging.js';
|
||||
import type { OAuthSession } from '../oauth.js';
|
||||
@ -9,6 +10,7 @@ export const DEFAULT_OAUTH_CODE_TIMEOUT_MS = 300_000;
|
||||
const OAUTH_FLOW_ERROR = Symbol('oauth-flow-error');
|
||||
const POST_AUTH_CONNECT_ERROR = Symbol('post-auth-connect-error');
|
||||
const MAX_OAUTH_ERROR_DETAIL_LENGTH = 1_200;
|
||||
const PROACTIVE_TOKEN_SKEW_SECONDS = 60;
|
||||
|
||||
export interface OAuthCapableTransport extends Transport {
|
||||
close(): Promise<void>;
|
||||
@ -109,6 +111,15 @@ function hasErrorMarker(error: unknown, marker: symbol): boolean {
|
||||
);
|
||||
}
|
||||
|
||||
function hasUsableCachedAccessToken(tokens: OAuthTokens | undefined): boolean {
|
||||
if (!tokens || typeof tokens.access_token !== 'string' || tokens.access_token.trim().length === 0) {
|
||||
return false;
|
||||
}
|
||||
const stored = tokens as OAuthTokens & { expires_at?: number; expiresAt?: number };
|
||||
const expiresAt = typeof stored.expires_at === 'number' ? stored.expires_at : stored.expiresAt;
|
||||
return typeof expiresAt === 'number' && expiresAt > Math.floor(Date.now() / 1000) + PROACTIVE_TOKEN_SKEW_SECONDS;
|
||||
}
|
||||
|
||||
export async function connectWithAuth(
|
||||
client: Client,
|
||||
transport: OAuthCapableTransport,
|
||||
@ -239,6 +250,10 @@ async function completeProactiveAuthorization(
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const cachedTokens = await session.provider.tokens?.();
|
||||
if (hasUsableCachedAccessToken(cachedTokens)) {
|
||||
return;
|
||||
}
|
||||
const result = await sdkAuth(session.provider, {
|
||||
serverUrl: options.serverUrl,
|
||||
fetchFn: options.fetchFn,
|
||||
|
||||
@ -241,6 +241,67 @@ describe('connectWithAuth', () => {
|
||||
expect(connectedTransport).toBe(transport);
|
||||
});
|
||||
|
||||
it('skips proactive OAuth when cached tokens are still usable', async () => {
|
||||
const connect = vi.fn().mockResolvedValueOnce(undefined);
|
||||
const client = { connect } as unknown as Client;
|
||||
const { session, waitForAuthorizationCode } = createPendingAuthorizationSession();
|
||||
const tokens = vi.fn(async () => ({
|
||||
access_token: 'cached-token',
|
||||
token_type: 'Bearer',
|
||||
expires_at: Math.floor(Date.now() / 1000) + 3600,
|
||||
}));
|
||||
session.provider.tokens = tokens;
|
||||
mocks.sdkAuth.mockResolvedValueOnce('REDIRECT');
|
||||
|
||||
const transport = new MockTransport();
|
||||
const logger = createLogger();
|
||||
|
||||
const connectedTransport = await connectWithAuth(client, transport, session, logger, {
|
||||
serverName: 'courtlistener',
|
||||
maxAttempts: 1,
|
||||
oauthTimeoutMs: 5000,
|
||||
serverUrl: 'https://courtlistener.example/mcp',
|
||||
});
|
||||
|
||||
expect(tokens).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.sdkAuth).not.toHaveBeenCalled();
|
||||
expect(waitForAuthorizationCode).not.toHaveBeenCalled();
|
||||
expect(transport.calls).toEqual([]);
|
||||
expect(session.close).toHaveBeenCalled();
|
||||
expect(connectedTransport).toBe(transport);
|
||||
});
|
||||
|
||||
it('runs proactive OAuth when cached tokens are expired', async () => {
|
||||
const connect = vi.fn().mockResolvedValueOnce(undefined);
|
||||
const client = { connect } as unknown as Client;
|
||||
const { session, waitForAuthorizationCode } = createPendingAuthorizationSession();
|
||||
session.provider.tokens = vi.fn(async () => ({
|
||||
access_token: 'expired-token',
|
||||
token_type: 'Bearer',
|
||||
expires_at: Math.floor(Date.now() / 1000) - 60,
|
||||
}));
|
||||
mocks.sdkAuth.mockResolvedValueOnce('AUTHORIZED');
|
||||
|
||||
const transport = new MockTransport();
|
||||
const logger = createLogger();
|
||||
|
||||
const connectedTransport = await connectWithAuth(client, transport, session, logger, {
|
||||
serverName: 'courtlistener',
|
||||
maxAttempts: 1,
|
||||
oauthTimeoutMs: 5000,
|
||||
serverUrl: 'https://courtlistener.example/mcp',
|
||||
});
|
||||
|
||||
expect(mocks.sdkAuth).toHaveBeenCalledWith(session.provider, {
|
||||
serverUrl: 'https://courtlistener.example/mcp',
|
||||
fetchFn: undefined,
|
||||
});
|
||||
expect(waitForAuthorizationCode).not.toHaveBeenCalled();
|
||||
expect(transport.calls).toEqual([]);
|
||||
expect(session.close).toHaveBeenCalled();
|
||||
expect(connectedTransport).toBe(transport);
|
||||
});
|
||||
|
||||
it('marks proactive OAuth failures as OAuth flow errors', async () => {
|
||||
const connect = vi.fn().mockResolvedValueOnce(undefined);
|
||||
const client = { connect } as unknown as Client;
|
||||
@ -261,6 +322,28 @@ describe('connectWithAuth', () => {
|
||||
).rejects.toSatisfy((error: unknown) => error === authError && isOAuthFlowError(error));
|
||||
});
|
||||
|
||||
it('marks cached token read failures during proactive OAuth as OAuth flow errors', async () => {
|
||||
const connect = vi.fn().mockResolvedValueOnce(undefined);
|
||||
const client = { connect } as unknown as Client;
|
||||
const { session } = createPendingAuthorizationSession();
|
||||
const tokenError = new Error('malformed token cache');
|
||||
session.provider.tokens = vi.fn(async () => {
|
||||
throw tokenError;
|
||||
});
|
||||
|
||||
const transport = new MockTransport();
|
||||
const logger = createLogger();
|
||||
|
||||
await expect(
|
||||
connectWithAuth(client, transport, session, logger, {
|
||||
serverName: 'calendar',
|
||||
maxAttempts: 1,
|
||||
oauthTimeoutMs: 5000,
|
||||
serverUrl: 'https://calendar.example/mcp',
|
||||
})
|
||||
).rejects.toSatisfy((error: unknown) => error === tokenError && isOAuthFlowError(error));
|
||||
});
|
||||
|
||||
it('marks finishAuth failures as oauth flow errors', async () => {
|
||||
const connect = vi.fn().mockRejectedValueOnce(new UnauthorizedError('auth needed'));
|
||||
const client = { connect } as unknown as Client;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user