// File: mcporter/tests/runtime-oauth-connect.test.ts
// Listed date: 2026-05-20 17:21:05 +01:00 (TypeScript; listed as 415 lines, 14 KiB)

import type { Client } from '@modelcontextprotocol/sdk/client';
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import type { OAuthSession } from '../src/oauth.js';
import {
connectWithAuth,
isOAuthFlowError,
isPostAuthConnectError,
OAuthAuthorizationNotStartedError,
} from '../src/runtime/oauth.js';
import {
createLogger,
createPendingAuthorizationSession,
flushAuthLoop,
MockTransport,
} from './helpers/runtime-test-helpers.js';
// Created through vi.hoisted so the value already exists when the hoisted vi.mock factory
// below runs (Vitest moves vi.mock calls above the module's imports).
const mocks = vi.hoisted(() => {
  const sdkAuth = vi.fn();
  return { sdkAuth };
});

// Swap only the SDK's `auth` export for the shared mock; every other export of the module
// (including UnauthorizedError, imported above) is the real implementation.
vi.mock('@modelcontextprotocol/sdk/client/auth.js', async () => {
  const actualModule = await vi.importActual('@modelcontextprotocol/sdk/client/auth.js');
  const patched = { ...actualModule, auth: mocks.sdkAuth };
  return patched;
});
/**
 * Behavioral tests for `connectWithAuth`.
 *
 * Reactive path: `client.connect` is rejected with an auth challenge, the test supplies an
 * authorization code via `resolveNextCode`, and the connection is retried.
 * Proactive path: `client.connect` succeeds and, when `serverUrl` is given, the SDK `auth`
 * helper (mocked as `mocks.sdkAuth`) is invoked unless cached tokens are still usable.
 */
describe('connectWithAuth', () => {
  beforeEach(() => {
    // Default: the mocked SDK `auth` resolves 'AUTHORIZED', i.e. no browser redirect is needed.
    mocks.sdkAuth.mockReset();
    mocks.sdkAuth.mockResolvedValue('AUTHORIZED');
  });

  it('waits for authorization code and retries connection', async () => {
    // First connect hits an auth challenge; the retry after the code is supplied succeeds.
    const connect = vi
      .fn()
      .mockRejectedValueOnce(new UnauthorizedError('auth needed'))
      .mockResolvedValueOnce(undefined);
    const client = { connect } as unknown as Client;
    const { session, waitForAuthorizationCode, resolveNextCode } = createPendingAuthorizationSession();
    const transport = new MockTransport();
    const logger = createLogger();

    const promise = connectWithAuth(client, transport, session, logger, {
      serverName: 'test-server',
      maxAttempts: 1,
      oauthTimeoutMs: 5000,
    });
    // Let the pending auth loop progress before an authorization code is supplied.
    await flushAuthLoop();
    resolveNextCode('oauth-code-123');
    const connectedTransport = await promise;

    expect(waitForAuthorizationCode).toHaveBeenCalledTimes(1);
    // `transport.calls` records the codes handed to the transport.
    expect(transport.calls).toEqual(['oauth-code-123']);
    expect(connect).toHaveBeenCalledTimes(2);
    expect(connectedTransport).toBe(transport);
  });

  it('treats generic 401 transport errors as OAuth challenges', async () => {
    // A plain Error whose message carries a 401 status must trigger the same flow as UnauthorizedError.
    const connect = vi
      .fn()
      .mockRejectedValueOnce(new Error('SSE error: Non-200 status code (401)'))
      .mockResolvedValueOnce(undefined);
    const client = { connect } as unknown as Client;
    const { session, waitForAuthorizationCode, resolveNextCode } = createPendingAuthorizationSession();
    const transport = new MockTransport();
    const logger = createLogger();

    const promise = connectWithAuth(client, transport, session, logger, {
      serverName: 'test-server',
      maxAttempts: 1,
      oauthTimeoutMs: 5000,
    });
    await flushAuthLoop();
    resolveNextCode('oauth-code-123');
    const connectedTransport = await promise;

    expect(waitForAuthorizationCode).toHaveBeenCalledTimes(1);
    expect(transport.calls).toEqual(['oauth-code-123']);
    expect(connect).toHaveBeenCalledTimes(2);
    expect(connectedTransport).toBe(transport);
  });

  it('recreates the transport after finishAuth when requested', async () => {
    const connect = vi
      .fn()
      .mockRejectedValueOnce(new UnauthorizedError('auth needed'))
      .mockResolvedValueOnce(undefined);
    const client = { connect } as unknown as Client;
    const { session, resolveNextCode } = createPendingAuthorizationSession();
    const transport = new MockTransport();
    const replacement = new MockTransport();
    const recreateTransport = vi.fn(async () => replacement);
    const logger = createLogger();

    const promise = connectWithAuth(client, transport, session, logger, {
      serverName: 'test-server',
      maxAttempts: 1,
      oauthTimeoutMs: 5000,
      recreateTransport,
    });
    await flushAuthLoop();
    resolveNextCode('oauth-code-123');
    const connectedTransport = await promise;

    // The code goes to the original transport; the reconnect uses the recreated one.
    expect(recreateTransport).toHaveBeenCalledWith(transport);
    expect(transport.calls).toEqual(['oauth-code-123']);
    expect(connect).toHaveBeenNthCalledWith(1, transport);
    expect(connect).toHaveBeenNthCalledWith(2, replacement);
    expect(connectedTransport).toBe(replacement);
  });

  it('marks reconnect failures after auth as post-auth transport errors', async () => {
    const reconnectError = new Error('HTTP error 405: Method Not Allowed');
    const connect = vi
      .fn()
      .mockRejectedValueOnce(new UnauthorizedError('auth needed'))
      .mockRejectedValueOnce(reconnectError);
    const client = { connect } as unknown as Client;
    const { session, resolveNextCode } = createPendingAuthorizationSession();
    const transport = new MockTransport();
    const logger = createLogger();

    const promise = connectWithAuth(client, transport, session, logger, {
      serverName: 'test-server',
      maxAttempts: 1,
      oauthTimeoutMs: 5000,
    });
    await flushAuthLoop();
    resolveNextCode('oauth-code-123');

    // The original error object is rethrown, tagged so callers can tell it apart from OAuth failures.
    await expect(promise).rejects.toSatisfy(
      (error: unknown) => error === reconnectError && isPostAuthConnectError(error)
    );
  });

  it('retries unauthorized reconnects after completing auth', async () => {
    // Two auth challenges in a row: the second arrives on the reconnect after the first code.
    const connect = vi
      .fn()
      .mockRejectedValueOnce(new UnauthorizedError('initial auth needed'))
      .mockRejectedValueOnce(new UnauthorizedError('token not active yet'))
      .mockResolvedValueOnce(undefined);
    const client = { connect } as unknown as Client;
    const { session, waitForAuthorizationCode, pendingResolvers, resolveNextCode } =
      createPendingAuthorizationSession();
    const transport = new MockTransport();
    const logger = createLogger();

    const promise = connectWithAuth(client, transport, session, logger, {
      serverName: 'test-server',
      maxAttempts: 2,
      oauthTimeoutMs: 5000,
    });
    await flushAuthLoop();
    expect(pendingResolvers).toHaveLength(1);
    resolveNextCode('oauth-code-1');
    await flushAuthLoop();
    // A fresh wait for a second code is pending after the first one was consumed.
    expect(pendingResolvers).toHaveLength(1);
    resolveNextCode('oauth-code-2');
    const connectedTransport = await promise;

    expect(waitForAuthorizationCode).toHaveBeenCalledTimes(2);
    expect(transport.calls).toEqual(['oauth-code-1', 'oauth-code-2']);
    expect(connect).toHaveBeenCalledTimes(3);
    expect(connectedTransport).toBe(transport);
  });

  it('runs proactive OAuth after unauthenticated connect succeeds', async () => {
    const connect = vi.fn().mockResolvedValueOnce(undefined);
    const client = { connect } as unknown as Client;
    const { session, waitForAuthorizationCode, resolveNextCode } = createPendingAuthorizationSession();
    // 'REDIRECT' means the SDK started a browser flow, so a code must be awaited.
    mocks.sdkAuth.mockResolvedValueOnce('REDIRECT');
    const transport = new MockTransport();
    const logger = createLogger();

    const promise = connectWithAuth(client, transport, session, logger, {
      serverName: 'calendar',
      maxAttempts: 1,
      oauthTimeoutMs: 5000,
      serverUrl: new URL('https://calendar.example/mcp'),
    });
    await flushAuthLoop();
    resolveNextCode('proactive-code');
    const connectedTransport = await promise;

    expect(mocks.sdkAuth).toHaveBeenCalledWith(session.provider, {
      serverUrl: new URL('https://calendar.example/mcp'),
      fetchFn: undefined,
    });
    expect(waitForAuthorizationCode).toHaveBeenCalledTimes(1);
    expect(transport.calls).toEqual(['proactive-code']);
    expect(connect).toHaveBeenCalledTimes(1);
    expect(session.close).toHaveBeenCalled();
    expect(connectedTransport).toBe(transport);
  });

  it('closes proactive OAuth sessions when cached tokens already authorize', async () => {
    const connect = vi.fn().mockResolvedValueOnce(undefined);
    const client = { connect } as unknown as Client;
    const { session, waitForAuthorizationCode } = createPendingAuthorizationSession();
    mocks.sdkAuth.mockResolvedValueOnce('AUTHORIZED');
    const transport = new MockTransport();
    const logger = createLogger();

    const connectedTransport = await connectWithAuth(client, transport, session, logger, {
      serverName: 'calendar',
      maxAttempts: 1,
      oauthTimeoutMs: 5000,
      serverUrl: 'https://calendar.example/mcp',
    });

    // Pin that the proactive flow actually ran; otherwise the assertions below would also
    // pass if proactive OAuth were skipped entirely.
    expect(mocks.sdkAuth).toHaveBeenCalledWith(session.provider, {
      serverUrl: 'https://calendar.example/mcp',
      fetchFn: undefined,
    });
    expect(connect).toHaveBeenCalledTimes(1);
    expect(waitForAuthorizationCode).not.toHaveBeenCalled();
    expect(transport.calls).toEqual([]);
    expect(session.close).toHaveBeenCalled();
    expect(connectedTransport).toBe(transport);
  });

  it('skips proactive OAuth when cached tokens are still usable', async () => {
    const connect = vi.fn().mockResolvedValueOnce(undefined);
    const client = { connect } as unknown as Client;
    const { session, waitForAuthorizationCode } = createPendingAuthorizationSession();
    // Token valid for another hour (expires_at is in epoch seconds).
    const tokens = vi.fn(async () => ({
      access_token: 'cached-token',
      token_type: 'Bearer',
      expires_at: Math.floor(Date.now() / 1000) + 3600,
    }));
    session.provider.tokens = tokens;
    // Would force a redirect if reached; the assertions prove it never is.
    mocks.sdkAuth.mockResolvedValueOnce('REDIRECT');
    const transport = new MockTransport();
    const logger = createLogger();

    const connectedTransport = await connectWithAuth(client, transport, session, logger, {
      serverName: 'courtlistener',
      maxAttempts: 1,
      oauthTimeoutMs: 5000,
      serverUrl: 'https://courtlistener.example/mcp',
    });

    expect(tokens).toHaveBeenCalledTimes(1);
    expect(mocks.sdkAuth).not.toHaveBeenCalled();
    expect(connect).toHaveBeenCalledTimes(1);
    expect(waitForAuthorizationCode).not.toHaveBeenCalled();
    expect(transport.calls).toEqual([]);
    expect(session.close).toHaveBeenCalled();
    expect(connectedTransport).toBe(transport);
  });

  it('runs proactive OAuth when cached tokens are expired', async () => {
    const connect = vi.fn().mockResolvedValueOnce(undefined);
    const client = { connect } as unknown as Client;
    const { session, waitForAuthorizationCode } = createPendingAuthorizationSession();
    // Token expired a minute ago (expires_at is in epoch seconds).
    session.provider.tokens = vi.fn(async () => ({
      access_token: 'expired-token',
      token_type: 'Bearer',
      expires_at: Math.floor(Date.now() / 1000) - 60,
    }));
    mocks.sdkAuth.mockResolvedValueOnce('AUTHORIZED');
    const transport = new MockTransport();
    const logger = createLogger();

    const connectedTransport = await connectWithAuth(client, transport, session, logger, {
      serverName: 'courtlistener',
      maxAttempts: 1,
      oauthTimeoutMs: 5000,
      serverUrl: 'https://courtlistener.example/mcp',
    });

    expect(mocks.sdkAuth).toHaveBeenCalledWith(session.provider, {
      serverUrl: 'https://courtlistener.example/mcp',
      fetchFn: undefined,
    });
    expect(connect).toHaveBeenCalledTimes(1);
    expect(waitForAuthorizationCode).not.toHaveBeenCalled();
    expect(transport.calls).toEqual([]);
    expect(session.close).toHaveBeenCalled();
    expect(connectedTransport).toBe(transport);
  });

  it('marks proactive OAuth failures as OAuth flow errors', async () => {
    const connect = vi.fn().mockResolvedValueOnce(undefined);
    const client = { connect } as unknown as Client;
    const { session } = createPendingAuthorizationSession();
    const authError = new Error('dynamic client registration rejected');
    mocks.sdkAuth.mockRejectedValueOnce(authError);
    const transport = new MockTransport();
    const logger = createLogger();

    await expect(
      connectWithAuth(client, transport, session, logger, {
        serverName: 'calendar',
        maxAttempts: 1,
        oauthTimeoutMs: 5000,
        serverUrl: 'https://calendar.example/mcp',
      })
    ).rejects.toSatisfy((error: unknown) => error === authError && isOAuthFlowError(error));
  });

  it('marks cached token read failures during proactive OAuth as OAuth flow errors', async () => {
    const connect = vi.fn().mockResolvedValueOnce(undefined);
    const client = { connect } as unknown as Client;
    const { session } = createPendingAuthorizationSession();
    const tokenError = new Error('malformed token cache');
    session.provider.tokens = vi.fn(async () => {
      throw tokenError;
    });
    const transport = new MockTransport();
    const logger = createLogger();

    await expect(
      connectWithAuth(client, transport, session, logger, {
        serverName: 'calendar',
        maxAttempts: 1,
        oauthTimeoutMs: 5000,
        serverUrl: 'https://calendar.example/mcp',
      })
    ).rejects.toSatisfy((error: unknown) => error === tokenError && isOAuthFlowError(error));
  });

  it('marks finishAuth failures as oauth flow errors', async () => {
    const connect = vi.fn().mockRejectedValueOnce(new UnauthorizedError('auth needed'));
    const client = { connect } as unknown as Client;
    const { session, resolveNextCode } = createPendingAuthorizationSession();
    const finishAuthError = new Error('token endpoint returned 405');
    // The MockTransport constructor argument supplies the behavior invoked with the code.
    const transport = new MockTransport(async () => {
      throw finishAuthError;
    });
    const logger = createLogger();

    const promise = connectWithAuth(client, transport, session, logger, {
      serverName: 'test-server',
      maxAttempts: 1,
      oauthTimeoutMs: 5000,
    });
    await flushAuthLoop();
    resolveNextCode('oauth-code-123');

    await expect(promise).rejects.toSatisfy(
      (error: unknown) => error === finishAuthError && isOAuthFlowError(error)
    );
  });

  it('fails immediately when OAuth never produced an authorization URL', async () => {
    const connectError = new UnauthorizedError('dynamic client registration rejected');
    const connect = vi.fn().mockRejectedValueOnce(connectError);
    const client = { connect } as unknown as Client;
    // Never settles: if the code under test waited on it, this test would hang until timeout.
    const waitForAuthorizationCode = vi.fn(() => new Promise<string>(() => {}));
    // Hand-built session that reports the browser redirect never started.
    const session = {
      provider: {
        waitForAuthorizationCode,
        hasAuthorizationRedirectStarted: () => false,
      },
      waitForAuthorizationCode,
      hasAuthorizationRedirectStarted: () => false,
      close: vi.fn(async () => {}),
    } as unknown as OAuthSession;
    const transport = new MockTransport();
    const logger = createLogger();

    await expect(
      connectWithAuth(client, transport, session, logger, {
        serverName: 'figma',
        maxAttempts: 1,
        oauthTimeoutMs: 5000,
      })
    ).rejects.toSatisfy(
      (error: unknown) => error instanceof OAuthAuthorizationNotStartedError && isOAuthFlowError(error)
    );

    expect(waitForAuthorizationCode).not.toHaveBeenCalled();
    expect(logger.warn).not.toHaveBeenCalledWith(expect.stringContaining('Waiting for browser approval'));
    expect(logger.error).toHaveBeenCalledWith(
      'OAuth authorization could not start.',
      expect.any(OAuthAuthorizationNotStartedError)
    );
  });

  it('truncates oversized OAuth startup error details', () => {
    // A large HTML error body must not be echoed verbatim into the error message.
    const hugeHtml = `<html>${'x'.repeat(5000)}</html>`;
    const error = new OAuthAuthorizationNotStartedError('shadcn', new Error(`HTTP 404 raw body: ${hugeHtml}`));

    expect(error.message.length).toBeLessThan(1800);
    expect(error.message).toContain('[truncated ');
    expect(error.message).not.toContain(hugeHtml);
  });
});