refactor: simplify oauth transport flow

This commit is contained in:
Peter Steinberger 2026-03-28 20:01:24 +00:00
parent 72e3a16f0a
commit 41e049ddc0
No known key found for this signature in database
4 changed files with 273 additions and 285 deletions

View File

@ -8,6 +8,18 @@ export const DEFAULT_OAUTH_CODE_TIMEOUT_MS = 60_000;
const OAUTH_FLOW_ERROR = Symbol('oauth-flow-error');
const POST_AUTH_CONNECT_ERROR = Symbol('post-auth-connect-error');
export interface OAuthCapableTransport extends Transport {
close(): Promise<void>;
finishAuth?: (authorizationCode: string) => Promise<void>;
}
export interface ConnectWithAuthOptions {
serverName?: string;
maxAttempts?: number;
oauthTimeoutMs?: number;
recreateTransport?: (transport: OAuthCapableTransport) => Promise<OAuthCapableTransport>;
}
export class OAuthTimeoutError extends Error {
public readonly timeoutMs: number;
public readonly serverName: string;
@ -60,46 +72,16 @@ function hasErrorMarker(error: unknown, marker: symbol): boolean {
export async function connectWithAuth(
client: Client,
transport: Transport & {
close(): Promise<void>;
finishAuth?: (authorizationCode: string) => Promise<void>;
},
transport: OAuthCapableTransport,
session: OAuthSession | undefined,
logger: Logger,
options: {
serverName?: string;
maxAttempts?: number;
oauthTimeoutMs?: number;
recreateTransport?: (
transport: Transport & {
close(): Promise<void>;
finishAuth?: (authorizationCode: string) => Promise<void>;
}
) => Promise<
Transport & {
close(): Promise<void>;
finishAuth?: (authorizationCode: string) => Promise<void>;
}
>;
} = {}
): Promise<
Transport & {
close(): Promise<void>;
finishAuth?: (authorizationCode: string) => Promise<void>;
}
> {
options: ConnectWithAuthOptions = {}
): Promise<OAuthCapableTransport> {
const { serverName, maxAttempts = 3, oauthTimeoutMs = DEFAULT_OAUTH_CODE_TIMEOUT_MS, recreateTransport } = options;
let activeTransport = transport;
let attempt = 0;
let hasCompletedAuthFlow = false;
const closeReplacementTransport = async (): Promise<void> => {
if (activeTransport === transport) {
return;
}
await activeTransport.close().catch(() => {});
};
while (true) {
try {
await client.connect(activeTransport);
@ -107,48 +89,72 @@ export async function connectWithAuth(
} catch (error) {
const unauthorized = isUnauthorizedError(error);
if (hasCompletedAuthFlow && !unauthorized) {
await closeReplacementTransport();
await closeReplacementTransport(transport, activeTransport);
throw markPostAuthConnectError(error);
}
if (!unauthorized || !session) {
await closeReplacementTransport();
await closeReplacementTransport(transport, activeTransport);
throw error;
}
attempt += 1;
if (attempt > maxAttempts) {
await closeReplacementTransport();
await closeReplacementTransport(transport, activeTransport);
throw hasCompletedAuthFlow ? markPostAuthConnectError(error) : error;
}
logger.warn(`OAuth authorization required for '${serverName ?? 'unknown'}'. Waiting for browser approval...`);
try {
const code = await waitForAuthorizationCodeWithTimeout(
session,
logger,
activeTransport = await completeAuthorizationChallenge(activeTransport, session, logger, error, {
serverName,
oauthTimeoutMs ?? DEFAULT_OAUTH_CODE_TIMEOUT_MS
);
if (typeof activeTransport.finishAuth === 'function') {
await activeTransport.finishAuth(code);
if (recreateTransport) {
const nextTransport = await recreateTransport(activeTransport);
await activeTransport.close().catch(() => {});
activeTransport = nextTransport;
}
hasCompletedAuthFlow = true;
logger.info('Authorization code accepted. Retrying connection...');
} else {
logger.warn('Transport does not support finishAuth; cannot complete OAuth flow automatically.');
throw error;
}
oauthTimeoutMs,
recreateTransport,
});
hasCompletedAuthFlow = true;
logger.info('Authorization code accepted. Retrying connection...');
} catch (authError) {
logger.error('OAuth authorization failed while waiting for callback.', authError);
await closeReplacementTransport();
await closeReplacementTransport(transport, activeTransport);
throw markOAuthFlowError(authError);
}
}
}
}
async function closeReplacementTransport(
originalTransport: OAuthCapableTransport,
activeTransport: OAuthCapableTransport
): Promise<void> {
if (activeTransport === originalTransport) {
return;
}
await activeTransport.close().catch(() => {});
}
async function completeAuthorizationChallenge(
transport: OAuthCapableTransport,
session: OAuthSession,
logger: Logger,
connectError: unknown,
options: Pick<ConnectWithAuthOptions, 'serverName' | 'oauthTimeoutMs' | 'recreateTransport'>
): Promise<OAuthCapableTransport> {
const code = await waitForAuthorizationCodeWithTimeout(
session,
logger,
options.serverName,
options.oauthTimeoutMs ?? DEFAULT_OAUTH_CODE_TIMEOUT_MS
);
if (typeof transport.finishAuth !== 'function') {
logger.warn('Transport does not support finishAuth; cannot complete OAuth flow automatically.');
throw connectError;
}
await transport.finishAuth(code);
if (!options.recreateTransport) {
return transport;
}
const nextTransport = await options.recreateTransport(transport);
await transport.close().catch(() => {});
return nextTransport;
}
// Race the pending OAuth browser handshake so the runtime can't sit on an unresolved promise forever.
export function waitForAuthorizationCodeWithTimeout(
session: OAuthSession,

View File

@ -12,7 +12,13 @@ import { readCachedAccessToken } from '../oauth-persistence.js';
import { materializeHeaders } from '../runtime-header-utils.js';
import { isUnauthorizedError, maybeEnableOAuth } from '../runtime-oauth-support.js';
import { closeTransportAndWait } from '../runtime-process-utils.js';
import { connectWithAuth, isOAuthFlowError, isPostAuthConnectError, OAuthTimeoutError } from './oauth.js';
import {
connectWithAuth,
isOAuthFlowError,
isPostAuthConnectError,
type OAuthCapableTransport,
OAuthTimeoutError,
} from './oauth.js';
import { resolveCommandArgument, resolveCommandArguments } from './utils.js';
const STDIO_TRACE_ENABLED = process.env.MCPORTER_STDIO_TRACE === '1';
@ -48,6 +54,11 @@ function isLegacySseTransportMismatch(error: unknown): boolean {
return issue.kind === 'http' && (issue.statusCode === 404 || issue.statusCode === 405);
}
interface ResolvedHttpTransportOptions {
requestInit?: RequestInit;
authProvider?: OAuthSession['provider'];
}
function attachStdioTraceLogging(_transport: StdioClientTransport, _label?: string): void {
// STDIO instrumentation is handled via sdk-patches side effects. This helper remains
// so runtime callers can opt-in without sprinkling conditional checks everywhere.
@ -67,6 +78,72 @@ export interface CreateClientContextOptions {
readonly allowCachedAuth?: boolean;
}
function removeAuthorizationHeader(headers: Record<string, string> | undefined): Record<string, string> | undefined {
if (!headers) {
return undefined;
}
for (const key of Object.keys(headers)) {
if (key.toLowerCase() === 'authorization') {
delete headers[key];
}
}
return Object.keys(headers).length > 0 ? headers : undefined;
}
function createHttpTransportOptions(
definition: ServerDefinition,
oauthSession: OAuthSession | undefined,
shouldEstablishOAuth: boolean
): ResolvedHttpTransportOptions {
const command = definition.command;
if (command.kind !== 'http') {
throw new Error(`Server '${definition.name}' is not configured for HTTP transport.`);
}
const resolvedHeaders = materializeHeaders(command.headers, definition.name);
const effectiveHeaders = shouldEstablishOAuth ? removeAuthorizationHeader(resolvedHeaders) : resolvedHeaders;
return {
requestInit: effectiveHeaders ? { headers: effectiveHeaders as HeadersInit } : undefined,
authProvider: oauthSession?.provider,
};
}
async function closeOAuthSession(oauthSession?: OAuthSession): Promise<void> {
await oauthSession?.close().catch(() => {});
}
function shouldAbortSseFallback(error: unknown): boolean {
if (isPostAuthConnectError(error)) {
return !isLegacySseTransportMismatch(error);
}
return isOAuthFlowError(error) || error instanceof OAuthTimeoutError;
}
function maybePromoteHttpDefinition(
definition: ServerDefinition,
logger: Logger,
options: CreateClientContextOptions
): ServerDefinition | undefined {
if (options.maxOAuthAttempts === 0) {
return undefined;
}
return maybeEnableOAuth(definition, logger);
}
async function connectHttpTransport<TTransport extends OAuthCapableTransport>(
client: Client,
transport: TTransport,
oauthSession: OAuthSession | undefined,
logger: Logger,
connectOptions: Parameters<typeof connectWithAuth>[4]
): Promise<TTransport> {
try {
return (await connectWithAuth(client, transport, oauthSession, logger, connectOptions)) as TTransport;
} catch (error) {
await closeTransportAndWait(logger, transport).catch(() => {});
throw error;
}
}
export async function createClientContext(
definition: ServerDefinition,
logger: Logger,
@ -146,91 +223,66 @@ export async function createClientContext(
if (shouldEstablishOAuth) {
oauthSession = await createOAuthSession(activeDefinition, logger);
}
const transportOptions = createHttpTransportOptions(activeDefinition, oauthSession, shouldEstablishOAuth);
const resolvedHeaders = materializeHeaders(command.headers, activeDefinition.name);
if (shouldEstablishOAuth && resolvedHeaders) {
for (const key of Object.keys(resolvedHeaders)) {
if (key.toLowerCase() === 'authorization') {
delete resolvedHeaders[key];
}
}
}
const requestInit: RequestInit | undefined =
resolvedHeaders && Object.keys(resolvedHeaders).length > 0
? { headers: resolvedHeaders as HeadersInit }
: undefined;
const baseOptions = {
requestInit,
authProvider: oauthSession?.provider,
};
const attemptConnect = async () => {
const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, baseOptions);
let streamableTransport = createStreamableTransport();
try {
streamableTransport = (await connectWithAuth(client, streamableTransport, oauthSession, logger, {
try {
const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, transportOptions);
const streamableTransport = await connectHttpTransport(
client,
createStreamableTransport(),
oauthSession,
logger,
{
serverName: activeDefinition.name,
maxAttempts: options.maxOAuthAttempts,
oauthTimeoutMs: options.oauthTimeoutMs,
recreateTransport: async () => createStreamableTransport(),
})) as StreamableHTTPClientTransport;
return {
client,
transport: streamableTransport,
definition: activeDefinition,
oauthSession,
} as ClientContext;
} catch (error) {
await closeTransportAndWait(logger, streamableTransport).catch(() => {});
throw error;
}
};
try {
return await attemptConnect();
} catch (primaryError) {
if (isPostAuthConnectError(primaryError)) {
if (!isLegacySseTransportMismatch(primaryError)) {
await oauthSession?.close().catch(() => {});
throw primaryError;
}
} else if (isOAuthFlowError(primaryError) || primaryError instanceof OAuthTimeoutError) {
await oauthSession?.close().catch(() => {});
);
return {
client,
transport: streamableTransport,
definition: activeDefinition,
oauthSession,
};
} catch (primaryError) {
if (shouldAbortSseFallback(primaryError)) {
await closeOAuthSession(oauthSession);
throw primaryError;
}
if (isUnauthorizedError(primaryError)) {
await oauthSession?.close().catch(() => {});
await closeOAuthSession(oauthSession);
oauthSession = undefined;
if (options.maxOAuthAttempts !== 0) {
const promoted = maybeEnableOAuth(activeDefinition, logger);
if (promoted) {
activeDefinition = promoted;
options.onDefinitionPromoted?.(promoted);
continue;
}
const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options);
if (promoted) {
activeDefinition = promoted;
options.onDefinitionPromoted?.(promoted);
continue;
}
}
if (primaryError instanceof Error) {
logger.info(`Falling back to SSE transport for '${activeDefinition.name}': ${primaryError.message}`);
}
const sseTransport = new SSEClientTransport(command.url, {
...baseOptions,
});
try {
const connectedTransport = (await connectWithAuth(client, sseTransport, oauthSession, logger, {
serverName: activeDefinition.name,
maxAttempts: options.maxOAuthAttempts,
oauthTimeoutMs: options.oauthTimeoutMs,
})) as SSEClientTransport;
const connectedTransport = await connectHttpTransport(
client,
new SSEClientTransport(command.url, transportOptions),
oauthSession,
logger,
{
serverName: activeDefinition.name,
maxAttempts: options.maxOAuthAttempts,
oauthTimeoutMs: options.oauthTimeoutMs,
}
);
return { client, transport: connectedTransport, definition: activeDefinition, oauthSession };
} catch (sseError) {
await closeTransportAndWait(logger, sseTransport).catch(() => {});
await oauthSession?.close().catch(() => {});
await closeOAuthSession(oauthSession);
if (sseError instanceof OAuthTimeoutError) {
throw sseError;
}
if (isUnauthorizedError(sseError) && options.maxOAuthAttempts !== 0) {
const promoted = maybeEnableOAuth(activeDefinition, logger);
if (isUnauthorizedError(sseError)) {
const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options);
if (promoted) {
activeDefinition = promoted;
options.onDefinitionPromoted?.(promoted);

View File

@ -22,6 +22,45 @@ class MockTransport implements Transport {
}
}
function createLogger(): Logger {
return {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
};
}
function createPendingAuthorizationSession() {
const pendingResolvers: Array<(code: string) => void> = [];
const waitForAuthorizationCode = vi.fn(
() =>
new Promise<string>((resolve) => {
pendingResolvers.push(resolve);
})
);
const session: OAuthSession = {
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
waitForAuthorizationCode,
close: vi.fn(async () => {}),
};
return {
session,
waitForAuthorizationCode,
pendingResolvers,
resolveNextCode: (code: string) => {
const resolve = pendingResolvers.shift();
if (!resolve) {
throw new Error(`Missing pending authorization resolver for '${code}'.`);
}
resolve(code);
},
};
}
async function flushAuthLoop(): Promise<void> {
await new Promise((resolve) => setImmediate(resolve));
}
describe('connectWithAuth', () => {
it('waits for authorization code and retries connection', async () => {
const connect = vi
@ -30,26 +69,10 @@ describe('connectWithAuth', () => {
.mockResolvedValueOnce(undefined);
const client = { connect } as unknown as Client;
let resolveCode: (code: string) => void = () => {};
const waitForAuthorizationCode = vi.fn(
() =>
new Promise<string>((resolve) => {
resolveCode = resolve;
})
);
const close = vi.fn(async () => {});
const session: OAuthSession = {
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
waitForAuthorizationCode,
close,
};
const { session, waitForAuthorizationCode, resolveNextCode } = createPendingAuthorizationSession();
const transport = new MockTransport();
const logger: Logger = {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
};
const logger = createLogger();
const promise = connectWithAuth(client, transport, session, logger, {
serverName: 'test-server',
@ -57,8 +80,8 @@ describe('connectWithAuth', () => {
oauthTimeoutMs: 5000,
});
await new Promise((resolve) => setImmediate(resolve));
resolveCode('oauth-code-123');
await flushAuthLoop();
resolveNextCode('oauth-code-123');
const connectedTransport = await promise;
@ -75,25 +98,10 @@ describe('connectWithAuth', () => {
.mockResolvedValueOnce(undefined);
const client = { connect } as unknown as Client;
let resolveCode: (code: string) => void = () => {};
const waitForAuthorizationCode = vi.fn(
() =>
new Promise<string>((resolve) => {
resolveCode = resolve;
})
);
const session: OAuthSession = {
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
waitForAuthorizationCode,
close: vi.fn(async () => {}),
};
const { session, waitForAuthorizationCode, resolveNextCode } = createPendingAuthorizationSession();
const transport = new MockTransport();
const logger: Logger = {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
};
const logger = createLogger();
const promise = connectWithAuth(client, transport, session, logger, {
serverName: 'test-server',
@ -101,8 +109,8 @@ describe('connectWithAuth', () => {
oauthTimeoutMs: 5000,
});
await new Promise((resolve) => setImmediate(resolve));
resolveCode('oauth-code-123');
await flushAuthLoop();
resolveNextCode('oauth-code-123');
const connectedTransport = await promise;
@ -119,28 +127,12 @@ describe('connectWithAuth', () => {
.mockResolvedValueOnce(undefined);
const client = { connect } as unknown as Client;
let resolveCode: (code: string) => void = () => {};
const waitForAuthorizationCode = vi.fn(
() =>
new Promise<string>((resolve) => {
resolveCode = resolve;
})
);
const close = vi.fn(async () => {});
const session: OAuthSession = {
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
waitForAuthorizationCode,
close,
};
const { session, resolveNextCode } = createPendingAuthorizationSession();
const transport = new MockTransport();
const replacement = new MockTransport();
const recreateTransport = vi.fn(async () => replacement);
const logger: Logger = {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
};
const logger = createLogger();
const promise = connectWithAuth(client, transport, session, logger, {
serverName: 'test-server',
@ -149,8 +141,8 @@ describe('connectWithAuth', () => {
recreateTransport,
});
await new Promise((resolve) => setImmediate(resolve));
resolveCode('oauth-code-123');
await flushAuthLoop();
resolveNextCode('oauth-code-123');
const connectedTransport = await promise;
@ -169,25 +161,10 @@ describe('connectWithAuth', () => {
.mockRejectedValueOnce(reconnectError);
const client = { connect } as unknown as Client;
let resolveCode: (code: string) => void = () => {};
const waitForAuthorizationCode = vi.fn(
() =>
new Promise<string>((resolve) => {
resolveCode = resolve;
})
);
const session: OAuthSession = {
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
waitForAuthorizationCode,
close: vi.fn(async () => {}),
};
const { session, resolveNextCode } = createPendingAuthorizationSession();
const transport = new MockTransport();
const logger: Logger = {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
};
const logger = createLogger();
const promise = connectWithAuth(client, transport, session, logger, {
serverName: 'test-server',
@ -195,8 +172,8 @@ describe('connectWithAuth', () => {
oauthTimeoutMs: 5000,
});
await new Promise((resolve) => setImmediate(resolve));
resolveCode('oauth-code-123');
await flushAuthLoop();
resolveNextCode('oauth-code-123');
await expect(promise).rejects.toSatisfy(
(error: unknown) => error === reconnectError && isPostAuthConnectError(error)
@ -211,25 +188,11 @@ describe('connectWithAuth', () => {
.mockResolvedValueOnce(undefined);
const client = { connect } as unknown as Client;
const pendingResolvers: Array<(code: string) => void> = [];
const waitForAuthorizationCode = vi.fn(
() =>
new Promise<string>((resolve) => {
pendingResolvers.push(resolve);
})
);
const session: OAuthSession = {
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
waitForAuthorizationCode,
close: vi.fn(async () => {}),
};
const { session, waitForAuthorizationCode, pendingResolvers, resolveNextCode } =
createPendingAuthorizationSession();
const transport = new MockTransport();
const logger: Logger = {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
};
const logger = createLogger();
const promise = connectWithAuth(client, transport, session, logger, {
serverName: 'test-server',
@ -237,13 +200,13 @@ describe('connectWithAuth', () => {
oauthTimeoutMs: 5000,
});
await new Promise((resolve) => setImmediate(resolve));
await flushAuthLoop();
expect(pendingResolvers).toHaveLength(1);
pendingResolvers.shift()?.('oauth-code-1');
resolveNextCode('oauth-code-1');
await new Promise((resolve) => setImmediate(resolve));
await flushAuthLoop();
expect(pendingResolvers).toHaveLength(1);
pendingResolvers.shift()?.('oauth-code-2');
resolveNextCode('oauth-code-2');
const connectedTransport = await promise;
@ -257,28 +220,13 @@ describe('connectWithAuth', () => {
const connect = vi.fn().mockRejectedValueOnce(new UnauthorizedError('auth needed'));
const client = { connect } as unknown as Client;
let resolveCode: (code: string) => void = () => {};
const waitForAuthorizationCode = vi.fn(
() =>
new Promise<string>((resolve) => {
resolveCode = resolve;
})
);
const session: OAuthSession = {
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
waitForAuthorizationCode,
close: vi.fn(async () => {}),
};
const { session, resolveNextCode } = createPendingAuthorizationSession();
const finishAuthError = new Error('token endpoint returned 405');
const transport = new MockTransport(async () => {
throw finishAuthError;
});
const logger: Logger = {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
};
const logger = createLogger();
const promise = connectWithAuth(client, transport, session, logger, {
serverName: 'test-server',
@ -286,8 +234,8 @@ describe('connectWithAuth', () => {
oauthTimeoutMs: 5000,
});
await new Promise((resolve) => setImmediate(resolve));
resolveCode('oauth-code-123');
await flushAuthLoop();
resolveNextCode('oauth-code-123');
await expect(promise).rejects.toSatisfy((error: unknown) => error === finishAuthError && isOAuthFlowError(error));
});

View File

@ -1,4 +1,6 @@
import { StreamableHTTPError } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { StreamableHTTPClientTransport, StreamableHTTPError } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
const mocks = vi.hoisted(() => ({
@ -63,12 +65,26 @@ function stubHttpDefinition(url: string): ServerDefinition {
};
}
function stubOAuthHttpDefinition(url: string): ServerDefinition {
return {
...stubHttpDefinition(url),
auth: 'oauth',
};
}
function createPromotionRecorder() {
const promotedDefinitions: ServerDefinition[] = [];
return {
promotedDefinitions,
onDefinitionPromoted: (promoted: ServerDefinition) => {
promotedDefinitions.push(promoted);
},
};
}
describe('createClientContext (HTTP)', () => {
it('falls back to SSE when primary connect fails', async () => {
const definition = stubHttpDefinition('https://example.com/mcp');
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
const clientConnect = vi
.spyOn(Client.prototype, 'connect')
@ -87,12 +103,7 @@ describe('createClientContext (HTTP)', () => {
});
it('does not fall back to SSE after the OAuth flow fails', async () => {
const definition: ServerDefinition = {
...stubHttpDefinition('https://example.com/secure'),
auth: 'oauth',
};
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
const definition = stubOAuthHttpDefinition('https://example.com/secure');
mocks.connectWithAuth
.mockImplementationOnce(async (_client, transport) => {
@ -115,12 +126,7 @@ describe('createClientContext (HTTP)', () => {
});
it('still falls back to SSE after auth when Streamable HTTP reveals a 405 transport mismatch', async () => {
const definition: ServerDefinition = {
...stubHttpDefinition('https://example.com/legacy-sse'),
auth: 'oauth',
};
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
const definition = stubOAuthHttpDefinition('https://example.com/legacy-sse');
mocks.connectWithAuth
.mockImplementationOnce(async (_client, transport) => {
@ -139,12 +145,7 @@ describe('createClientContext (HTTP)', () => {
});
it('surfaces provider 405 errors after auth instead of falling back to SSE', async () => {
const definition: ServerDefinition = {
...stubHttpDefinition('https://example.com/provider-405'),
auth: 'oauth',
};
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
const definition = stubOAuthHttpDefinition('https://example.com/provider-405');
mocks.connectWithAuth
.mockImplementationOnce(async (_client, transport) => {
@ -166,12 +167,7 @@ describe('createClientContext (HTTP)', () => {
});
it('still falls back to SSE after auth for generic 405 transport errors', async () => {
const definition: ServerDefinition = {
...stubHttpDefinition('https://example.com/legacy-sse-proxy'),
auth: 'oauth',
};
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
const definition = stubOAuthHttpDefinition('https://example.com/legacy-sse-proxy');
mocks.connectWithAuth
.mockImplementationOnce(async (_client, transport) => {
@ -192,12 +188,7 @@ describe('createClientContext (HTTP)', () => {
});
it('still falls back to SSE for oauth servers when no Streamable auth challenge was observed', async () => {
const definition: ServerDefinition = {
...stubHttpDefinition('https://example.com/sse-only'),
auth: 'oauth',
};
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
const definition = stubOAuthHttpDefinition('https://example.com/sse-only');
mocks.connectWithAuth
.mockImplementationOnce(async (_client, transport) => {
@ -217,7 +208,6 @@ describe('createClientContext (HTTP)', () => {
it('promotes ad-hoc HTTP servers after generic 401 errors from Streamable HTTP', async () => {
const definition = stubHttpDefinition('https://example.com/secure');
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
mocks.connectWithAuth
.mockImplementationOnce(async (_client, transport) => {
@ -230,12 +220,10 @@ describe('createClientContext (HTTP)', () => {
return transport;
});
const promotedDefinitions: ServerDefinition[] = [];
const { promotedDefinitions, onDefinitionPromoted } = createPromotionRecorder();
const context = await createClientContext(definition, logger, clientInfo, {
maxOAuthAttempts: 1,
onDefinitionPromoted: (promoted) => {
promotedDefinitions.push(promoted);
},
onDefinitionPromoted,
});
expect(context.definition.auth).toBe('oauth');
@ -246,8 +234,6 @@ describe('createClientContext (HTTP)', () => {
it('promotes ad-hoc HTTP servers after generic 401 errors from the SSE fallback path', async () => {
const definition = stubHttpDefinition('https://example.com/sse-auth');
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
mocks.connectWithAuth
.mockImplementationOnce(async (_client, transport) => {
@ -264,12 +250,10 @@ describe('createClientContext (HTTP)', () => {
return transport;
});
const promotedDefinitions: ServerDefinition[] = [];
const { promotedDefinitions, onDefinitionPromoted } = createPromotionRecorder();
const context = await createClientContext(definition, logger, clientInfo, {
maxOAuthAttempts: 1,
onDefinitionPromoted: (promoted) => {
promotedDefinitions.push(promoted);
},
onDefinitionPromoted,
});
expect(context.definition.auth).toBe('oauth');
@ -291,8 +275,6 @@ describe('createClientContext (HTTP)', () => {
},
},
};
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
const createOAuthSessionSpy = vi.spyOn(oauthModule, 'createOAuthSession').mockResolvedValue({
provider: {} as never,
waitForAuthorizationCode: vi.fn(),