refactor: simplify oauth transport flow
This commit is contained in:
parent
72e3a16f0a
commit
41e049ddc0
@ -8,6 +8,18 @@ export const DEFAULT_OAUTH_CODE_TIMEOUT_MS = 60_000;
|
||||
const OAUTH_FLOW_ERROR = Symbol('oauth-flow-error');
|
||||
const POST_AUTH_CONNECT_ERROR = Symbol('post-auth-connect-error');
|
||||
|
||||
export interface OAuthCapableTransport extends Transport {
|
||||
close(): Promise<void>;
|
||||
finishAuth?: (authorizationCode: string) => Promise<void>;
|
||||
}
|
||||
|
||||
export interface ConnectWithAuthOptions {
|
||||
serverName?: string;
|
||||
maxAttempts?: number;
|
||||
oauthTimeoutMs?: number;
|
||||
recreateTransport?: (transport: OAuthCapableTransport) => Promise<OAuthCapableTransport>;
|
||||
}
|
||||
|
||||
export class OAuthTimeoutError extends Error {
|
||||
public readonly timeoutMs: number;
|
||||
public readonly serverName: string;
|
||||
@ -60,46 +72,16 @@ function hasErrorMarker(error: unknown, marker: symbol): boolean {
|
||||
|
||||
export async function connectWithAuth(
|
||||
client: Client,
|
||||
transport: Transport & {
|
||||
close(): Promise<void>;
|
||||
finishAuth?: (authorizationCode: string) => Promise<void>;
|
||||
},
|
||||
transport: OAuthCapableTransport,
|
||||
session: OAuthSession | undefined,
|
||||
logger: Logger,
|
||||
options: {
|
||||
serverName?: string;
|
||||
maxAttempts?: number;
|
||||
oauthTimeoutMs?: number;
|
||||
recreateTransport?: (
|
||||
transport: Transport & {
|
||||
close(): Promise<void>;
|
||||
finishAuth?: (authorizationCode: string) => Promise<void>;
|
||||
}
|
||||
) => Promise<
|
||||
Transport & {
|
||||
close(): Promise<void>;
|
||||
finishAuth?: (authorizationCode: string) => Promise<void>;
|
||||
}
|
||||
>;
|
||||
} = {}
|
||||
): Promise<
|
||||
Transport & {
|
||||
close(): Promise<void>;
|
||||
finishAuth?: (authorizationCode: string) => Promise<void>;
|
||||
}
|
||||
> {
|
||||
options: ConnectWithAuthOptions = {}
|
||||
): Promise<OAuthCapableTransport> {
|
||||
const { serverName, maxAttempts = 3, oauthTimeoutMs = DEFAULT_OAUTH_CODE_TIMEOUT_MS, recreateTransport } = options;
|
||||
let activeTransport = transport;
|
||||
let attempt = 0;
|
||||
let hasCompletedAuthFlow = false;
|
||||
|
||||
const closeReplacementTransport = async (): Promise<void> => {
|
||||
if (activeTransport === transport) {
|
||||
return;
|
||||
}
|
||||
await activeTransport.close().catch(() => {});
|
||||
};
|
||||
|
||||
while (true) {
|
||||
try {
|
||||
await client.connect(activeTransport);
|
||||
@ -107,48 +89,72 @@ export async function connectWithAuth(
|
||||
} catch (error) {
|
||||
const unauthorized = isUnauthorizedError(error);
|
||||
if (hasCompletedAuthFlow && !unauthorized) {
|
||||
await closeReplacementTransport();
|
||||
await closeReplacementTransport(transport, activeTransport);
|
||||
throw markPostAuthConnectError(error);
|
||||
}
|
||||
if (!unauthorized || !session) {
|
||||
await closeReplacementTransport();
|
||||
await closeReplacementTransport(transport, activeTransport);
|
||||
throw error;
|
||||
}
|
||||
attempt += 1;
|
||||
if (attempt > maxAttempts) {
|
||||
await closeReplacementTransport();
|
||||
await closeReplacementTransport(transport, activeTransport);
|
||||
throw hasCompletedAuthFlow ? markPostAuthConnectError(error) : error;
|
||||
}
|
||||
logger.warn(`OAuth authorization required for '${serverName ?? 'unknown'}'. Waiting for browser approval...`);
|
||||
try {
|
||||
const code = await waitForAuthorizationCodeWithTimeout(
|
||||
session,
|
||||
logger,
|
||||
activeTransport = await completeAuthorizationChallenge(activeTransport, session, logger, error, {
|
||||
serverName,
|
||||
oauthTimeoutMs ?? DEFAULT_OAUTH_CODE_TIMEOUT_MS
|
||||
);
|
||||
if (typeof activeTransport.finishAuth === 'function') {
|
||||
await activeTransport.finishAuth(code);
|
||||
if (recreateTransport) {
|
||||
const nextTransport = await recreateTransport(activeTransport);
|
||||
await activeTransport.close().catch(() => {});
|
||||
activeTransport = nextTransport;
|
||||
}
|
||||
hasCompletedAuthFlow = true;
|
||||
logger.info('Authorization code accepted. Retrying connection...');
|
||||
} else {
|
||||
logger.warn('Transport does not support finishAuth; cannot complete OAuth flow automatically.');
|
||||
throw error;
|
||||
}
|
||||
oauthTimeoutMs,
|
||||
recreateTransport,
|
||||
});
|
||||
hasCompletedAuthFlow = true;
|
||||
logger.info('Authorization code accepted. Retrying connection...');
|
||||
} catch (authError) {
|
||||
logger.error('OAuth authorization failed while waiting for callback.', authError);
|
||||
await closeReplacementTransport();
|
||||
await closeReplacementTransport(transport, activeTransport);
|
||||
throw markOAuthFlowError(authError);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function closeReplacementTransport(
|
||||
originalTransport: OAuthCapableTransport,
|
||||
activeTransport: OAuthCapableTransport
|
||||
): Promise<void> {
|
||||
if (activeTransport === originalTransport) {
|
||||
return;
|
||||
}
|
||||
await activeTransport.close().catch(() => {});
|
||||
}
|
||||
|
||||
async function completeAuthorizationChallenge(
|
||||
transport: OAuthCapableTransport,
|
||||
session: OAuthSession,
|
||||
logger: Logger,
|
||||
connectError: unknown,
|
||||
options: Pick<ConnectWithAuthOptions, 'serverName' | 'oauthTimeoutMs' | 'recreateTransport'>
|
||||
): Promise<OAuthCapableTransport> {
|
||||
const code = await waitForAuthorizationCodeWithTimeout(
|
||||
session,
|
||||
logger,
|
||||
options.serverName,
|
||||
options.oauthTimeoutMs ?? DEFAULT_OAUTH_CODE_TIMEOUT_MS
|
||||
);
|
||||
if (typeof transport.finishAuth !== 'function') {
|
||||
logger.warn('Transport does not support finishAuth; cannot complete OAuth flow automatically.');
|
||||
throw connectError;
|
||||
}
|
||||
await transport.finishAuth(code);
|
||||
if (!options.recreateTransport) {
|
||||
return transport;
|
||||
}
|
||||
const nextTransport = await options.recreateTransport(transport);
|
||||
await transport.close().catch(() => {});
|
||||
return nextTransport;
|
||||
}
|
||||
|
||||
// Race the pending OAuth browser handshake so the runtime can't sit on an unresolved promise forever.
|
||||
export function waitForAuthorizationCodeWithTimeout(
|
||||
session: OAuthSession,
|
||||
|
||||
@ -12,7 +12,13 @@ import { readCachedAccessToken } from '../oauth-persistence.js';
|
||||
import { materializeHeaders } from '../runtime-header-utils.js';
|
||||
import { isUnauthorizedError, maybeEnableOAuth } from '../runtime-oauth-support.js';
|
||||
import { closeTransportAndWait } from '../runtime-process-utils.js';
|
||||
import { connectWithAuth, isOAuthFlowError, isPostAuthConnectError, OAuthTimeoutError } from './oauth.js';
|
||||
import {
|
||||
connectWithAuth,
|
||||
isOAuthFlowError,
|
||||
isPostAuthConnectError,
|
||||
type OAuthCapableTransport,
|
||||
OAuthTimeoutError,
|
||||
} from './oauth.js';
|
||||
import { resolveCommandArgument, resolveCommandArguments } from './utils.js';
|
||||
|
||||
const STDIO_TRACE_ENABLED = process.env.MCPORTER_STDIO_TRACE === '1';
|
||||
@ -48,6 +54,11 @@ function isLegacySseTransportMismatch(error: unknown): boolean {
|
||||
return issue.kind === 'http' && (issue.statusCode === 404 || issue.statusCode === 405);
|
||||
}
|
||||
|
||||
interface ResolvedHttpTransportOptions {
|
||||
requestInit?: RequestInit;
|
||||
authProvider?: OAuthSession['provider'];
|
||||
}
|
||||
|
||||
function attachStdioTraceLogging(_transport: StdioClientTransport, _label?: string): void {
|
||||
// STDIO instrumentation is handled via sdk-patches side effects. This helper remains
|
||||
// so runtime callers can opt-in without sprinkling conditional checks everywhere.
|
||||
@ -67,6 +78,72 @@ export interface CreateClientContextOptions {
|
||||
readonly allowCachedAuth?: boolean;
|
||||
}
|
||||
|
||||
function removeAuthorizationHeader(headers: Record<string, string> | undefined): Record<string, string> | undefined {
|
||||
if (!headers) {
|
||||
return undefined;
|
||||
}
|
||||
for (const key of Object.keys(headers)) {
|
||||
if (key.toLowerCase() === 'authorization') {
|
||||
delete headers[key];
|
||||
}
|
||||
}
|
||||
return Object.keys(headers).length > 0 ? headers : undefined;
|
||||
}
|
||||
|
||||
function createHttpTransportOptions(
|
||||
definition: ServerDefinition,
|
||||
oauthSession: OAuthSession | undefined,
|
||||
shouldEstablishOAuth: boolean
|
||||
): ResolvedHttpTransportOptions {
|
||||
const command = definition.command;
|
||||
if (command.kind !== 'http') {
|
||||
throw new Error(`Server '${definition.name}' is not configured for HTTP transport.`);
|
||||
}
|
||||
const resolvedHeaders = materializeHeaders(command.headers, definition.name);
|
||||
const effectiveHeaders = shouldEstablishOAuth ? removeAuthorizationHeader(resolvedHeaders) : resolvedHeaders;
|
||||
return {
|
||||
requestInit: effectiveHeaders ? { headers: effectiveHeaders as HeadersInit } : undefined,
|
||||
authProvider: oauthSession?.provider,
|
||||
};
|
||||
}
|
||||
|
||||
async function closeOAuthSession(oauthSession?: OAuthSession): Promise<void> {
|
||||
await oauthSession?.close().catch(() => {});
|
||||
}
|
||||
|
||||
function shouldAbortSseFallback(error: unknown): boolean {
|
||||
if (isPostAuthConnectError(error)) {
|
||||
return !isLegacySseTransportMismatch(error);
|
||||
}
|
||||
return isOAuthFlowError(error) || error instanceof OAuthTimeoutError;
|
||||
}
|
||||
|
||||
function maybePromoteHttpDefinition(
|
||||
definition: ServerDefinition,
|
||||
logger: Logger,
|
||||
options: CreateClientContextOptions
|
||||
): ServerDefinition | undefined {
|
||||
if (options.maxOAuthAttempts === 0) {
|
||||
return undefined;
|
||||
}
|
||||
return maybeEnableOAuth(definition, logger);
|
||||
}
|
||||
|
||||
async function connectHttpTransport<TTransport extends OAuthCapableTransport>(
|
||||
client: Client,
|
||||
transport: TTransport,
|
||||
oauthSession: OAuthSession | undefined,
|
||||
logger: Logger,
|
||||
connectOptions: Parameters<typeof connectWithAuth>[4]
|
||||
): Promise<TTransport> {
|
||||
try {
|
||||
return (await connectWithAuth(client, transport, oauthSession, logger, connectOptions)) as TTransport;
|
||||
} catch (error) {
|
||||
await closeTransportAndWait(logger, transport).catch(() => {});
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export async function createClientContext(
|
||||
definition: ServerDefinition,
|
||||
logger: Logger,
|
||||
@ -146,91 +223,66 @@ export async function createClientContext(
|
||||
if (shouldEstablishOAuth) {
|
||||
oauthSession = await createOAuthSession(activeDefinition, logger);
|
||||
}
|
||||
const transportOptions = createHttpTransportOptions(activeDefinition, oauthSession, shouldEstablishOAuth);
|
||||
|
||||
const resolvedHeaders = materializeHeaders(command.headers, activeDefinition.name);
|
||||
if (shouldEstablishOAuth && resolvedHeaders) {
|
||||
for (const key of Object.keys(resolvedHeaders)) {
|
||||
if (key.toLowerCase() === 'authorization') {
|
||||
delete resolvedHeaders[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
const requestInit: RequestInit | undefined =
|
||||
resolvedHeaders && Object.keys(resolvedHeaders).length > 0
|
||||
? { headers: resolvedHeaders as HeadersInit }
|
||||
: undefined;
|
||||
const baseOptions = {
|
||||
requestInit,
|
||||
authProvider: oauthSession?.provider,
|
||||
};
|
||||
|
||||
const attemptConnect = async () => {
|
||||
const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, baseOptions);
|
||||
let streamableTransport = createStreamableTransport();
|
||||
try {
|
||||
streamableTransport = (await connectWithAuth(client, streamableTransport, oauthSession, logger, {
|
||||
try {
|
||||
const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, transportOptions);
|
||||
const streamableTransport = await connectHttpTransport(
|
||||
client,
|
||||
createStreamableTransport(),
|
||||
oauthSession,
|
||||
logger,
|
||||
{
|
||||
serverName: activeDefinition.name,
|
||||
maxAttempts: options.maxOAuthAttempts,
|
||||
oauthTimeoutMs: options.oauthTimeoutMs,
|
||||
recreateTransport: async () => createStreamableTransport(),
|
||||
})) as StreamableHTTPClientTransport;
|
||||
return {
|
||||
client,
|
||||
transport: streamableTransport,
|
||||
definition: activeDefinition,
|
||||
oauthSession,
|
||||
} as ClientContext;
|
||||
} catch (error) {
|
||||
await closeTransportAndWait(logger, streamableTransport).catch(() => {});
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
return await attemptConnect();
|
||||
} catch (primaryError) {
|
||||
if (isPostAuthConnectError(primaryError)) {
|
||||
if (!isLegacySseTransportMismatch(primaryError)) {
|
||||
await oauthSession?.close().catch(() => {});
|
||||
throw primaryError;
|
||||
}
|
||||
} else if (isOAuthFlowError(primaryError) || primaryError instanceof OAuthTimeoutError) {
|
||||
await oauthSession?.close().catch(() => {});
|
||||
);
|
||||
return {
|
||||
client,
|
||||
transport: streamableTransport,
|
||||
definition: activeDefinition,
|
||||
oauthSession,
|
||||
};
|
||||
} catch (primaryError) {
|
||||
if (shouldAbortSseFallback(primaryError)) {
|
||||
await closeOAuthSession(oauthSession);
|
||||
throw primaryError;
|
||||
}
|
||||
if (isUnauthorizedError(primaryError)) {
|
||||
await oauthSession?.close().catch(() => {});
|
||||
await closeOAuthSession(oauthSession);
|
||||
oauthSession = undefined;
|
||||
if (options.maxOAuthAttempts !== 0) {
|
||||
const promoted = maybeEnableOAuth(activeDefinition, logger);
|
||||
if (promoted) {
|
||||
activeDefinition = promoted;
|
||||
options.onDefinitionPromoted?.(promoted);
|
||||
continue;
|
||||
}
|
||||
const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options);
|
||||
if (promoted) {
|
||||
activeDefinition = promoted;
|
||||
options.onDefinitionPromoted?.(promoted);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (primaryError instanceof Error) {
|
||||
logger.info(`Falling back to SSE transport for '${activeDefinition.name}': ${primaryError.message}`);
|
||||
}
|
||||
const sseTransport = new SSEClientTransport(command.url, {
|
||||
...baseOptions,
|
||||
});
|
||||
try {
|
||||
const connectedTransport = (await connectWithAuth(client, sseTransport, oauthSession, logger, {
|
||||
serverName: activeDefinition.name,
|
||||
maxAttempts: options.maxOAuthAttempts,
|
||||
oauthTimeoutMs: options.oauthTimeoutMs,
|
||||
})) as SSEClientTransport;
|
||||
const connectedTransport = await connectHttpTransport(
|
||||
client,
|
||||
new SSEClientTransport(command.url, transportOptions),
|
||||
oauthSession,
|
||||
logger,
|
||||
{
|
||||
serverName: activeDefinition.name,
|
||||
maxAttempts: options.maxOAuthAttempts,
|
||||
oauthTimeoutMs: options.oauthTimeoutMs,
|
||||
}
|
||||
);
|
||||
return { client, transport: connectedTransport, definition: activeDefinition, oauthSession };
|
||||
} catch (sseError) {
|
||||
await closeTransportAndWait(logger, sseTransport).catch(() => {});
|
||||
await oauthSession?.close().catch(() => {});
|
||||
await closeOAuthSession(oauthSession);
|
||||
if (sseError instanceof OAuthTimeoutError) {
|
||||
throw sseError;
|
||||
}
|
||||
if (isUnauthorizedError(sseError) && options.maxOAuthAttempts !== 0) {
|
||||
const promoted = maybeEnableOAuth(activeDefinition, logger);
|
||||
if (isUnauthorizedError(sseError)) {
|
||||
const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options);
|
||||
if (promoted) {
|
||||
activeDefinition = promoted;
|
||||
options.onDefinitionPromoted?.(promoted);
|
||||
|
||||
@ -22,6 +22,45 @@ class MockTransport implements Transport {
|
||||
}
|
||||
}
|
||||
|
||||
function createLogger(): Logger {
|
||||
return {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
}
|
||||
|
||||
function createPendingAuthorizationSession() {
|
||||
const pendingResolvers: Array<(code: string) => void> = [];
|
||||
const waitForAuthorizationCode = vi.fn(
|
||||
() =>
|
||||
new Promise<string>((resolve) => {
|
||||
pendingResolvers.push(resolve);
|
||||
})
|
||||
);
|
||||
const session: OAuthSession = {
|
||||
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
|
||||
waitForAuthorizationCode,
|
||||
close: vi.fn(async () => {}),
|
||||
};
|
||||
return {
|
||||
session,
|
||||
waitForAuthorizationCode,
|
||||
pendingResolvers,
|
||||
resolveNextCode: (code: string) => {
|
||||
const resolve = pendingResolvers.shift();
|
||||
if (!resolve) {
|
||||
throw new Error(`Missing pending authorization resolver for '${code}'.`);
|
||||
}
|
||||
resolve(code);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async function flushAuthLoop(): Promise<void> {
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
}
|
||||
|
||||
describe('connectWithAuth', () => {
|
||||
it('waits for authorization code and retries connection', async () => {
|
||||
const connect = vi
|
||||
@ -30,26 +69,10 @@ describe('connectWithAuth', () => {
|
||||
.mockResolvedValueOnce(undefined);
|
||||
const client = { connect } as unknown as Client;
|
||||
|
||||
let resolveCode: (code: string) => void = () => {};
|
||||
const waitForAuthorizationCode = vi.fn(
|
||||
() =>
|
||||
new Promise<string>((resolve) => {
|
||||
resolveCode = resolve;
|
||||
})
|
||||
);
|
||||
const close = vi.fn(async () => {});
|
||||
const session: OAuthSession = {
|
||||
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
|
||||
waitForAuthorizationCode,
|
||||
close,
|
||||
};
|
||||
const { session, waitForAuthorizationCode, resolveNextCode } = createPendingAuthorizationSession();
|
||||
|
||||
const transport = new MockTransport();
|
||||
const logger: Logger = {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
const logger = createLogger();
|
||||
|
||||
const promise = connectWithAuth(client, transport, session, logger, {
|
||||
serverName: 'test-server',
|
||||
@ -57,8 +80,8 @@ describe('connectWithAuth', () => {
|
||||
oauthTimeoutMs: 5000,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
resolveCode('oauth-code-123');
|
||||
await flushAuthLoop();
|
||||
resolveNextCode('oauth-code-123');
|
||||
|
||||
const connectedTransport = await promise;
|
||||
|
||||
@ -75,25 +98,10 @@ describe('connectWithAuth', () => {
|
||||
.mockResolvedValueOnce(undefined);
|
||||
const client = { connect } as unknown as Client;
|
||||
|
||||
let resolveCode: (code: string) => void = () => {};
|
||||
const waitForAuthorizationCode = vi.fn(
|
||||
() =>
|
||||
new Promise<string>((resolve) => {
|
||||
resolveCode = resolve;
|
||||
})
|
||||
);
|
||||
const session: OAuthSession = {
|
||||
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
|
||||
waitForAuthorizationCode,
|
||||
close: vi.fn(async () => {}),
|
||||
};
|
||||
const { session, waitForAuthorizationCode, resolveNextCode } = createPendingAuthorizationSession();
|
||||
|
||||
const transport = new MockTransport();
|
||||
const logger: Logger = {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
const logger = createLogger();
|
||||
|
||||
const promise = connectWithAuth(client, transport, session, logger, {
|
||||
serverName: 'test-server',
|
||||
@ -101,8 +109,8 @@ describe('connectWithAuth', () => {
|
||||
oauthTimeoutMs: 5000,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
resolveCode('oauth-code-123');
|
||||
await flushAuthLoop();
|
||||
resolveNextCode('oauth-code-123');
|
||||
|
||||
const connectedTransport = await promise;
|
||||
|
||||
@ -119,28 +127,12 @@ describe('connectWithAuth', () => {
|
||||
.mockResolvedValueOnce(undefined);
|
||||
const client = { connect } as unknown as Client;
|
||||
|
||||
let resolveCode: (code: string) => void = () => {};
|
||||
const waitForAuthorizationCode = vi.fn(
|
||||
() =>
|
||||
new Promise<string>((resolve) => {
|
||||
resolveCode = resolve;
|
||||
})
|
||||
);
|
||||
const close = vi.fn(async () => {});
|
||||
const session: OAuthSession = {
|
||||
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
|
||||
waitForAuthorizationCode,
|
||||
close,
|
||||
};
|
||||
const { session, resolveNextCode } = createPendingAuthorizationSession();
|
||||
|
||||
const transport = new MockTransport();
|
||||
const replacement = new MockTransport();
|
||||
const recreateTransport = vi.fn(async () => replacement);
|
||||
const logger: Logger = {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
const logger = createLogger();
|
||||
|
||||
const promise = connectWithAuth(client, transport, session, logger, {
|
||||
serverName: 'test-server',
|
||||
@ -149,8 +141,8 @@ describe('connectWithAuth', () => {
|
||||
recreateTransport,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
resolveCode('oauth-code-123');
|
||||
await flushAuthLoop();
|
||||
resolveNextCode('oauth-code-123');
|
||||
|
||||
const connectedTransport = await promise;
|
||||
|
||||
@ -169,25 +161,10 @@ describe('connectWithAuth', () => {
|
||||
.mockRejectedValueOnce(reconnectError);
|
||||
const client = { connect } as unknown as Client;
|
||||
|
||||
let resolveCode: (code: string) => void = () => {};
|
||||
const waitForAuthorizationCode = vi.fn(
|
||||
() =>
|
||||
new Promise<string>((resolve) => {
|
||||
resolveCode = resolve;
|
||||
})
|
||||
);
|
||||
const session: OAuthSession = {
|
||||
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
|
||||
waitForAuthorizationCode,
|
||||
close: vi.fn(async () => {}),
|
||||
};
|
||||
const { session, resolveNextCode } = createPendingAuthorizationSession();
|
||||
|
||||
const transport = new MockTransport();
|
||||
const logger: Logger = {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
const logger = createLogger();
|
||||
|
||||
const promise = connectWithAuth(client, transport, session, logger, {
|
||||
serverName: 'test-server',
|
||||
@ -195,8 +172,8 @@ describe('connectWithAuth', () => {
|
||||
oauthTimeoutMs: 5000,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
resolveCode('oauth-code-123');
|
||||
await flushAuthLoop();
|
||||
resolveNextCode('oauth-code-123');
|
||||
|
||||
await expect(promise).rejects.toSatisfy(
|
||||
(error: unknown) => error === reconnectError && isPostAuthConnectError(error)
|
||||
@ -211,25 +188,11 @@ describe('connectWithAuth', () => {
|
||||
.mockResolvedValueOnce(undefined);
|
||||
const client = { connect } as unknown as Client;
|
||||
|
||||
const pendingResolvers: Array<(code: string) => void> = [];
|
||||
const waitForAuthorizationCode = vi.fn(
|
||||
() =>
|
||||
new Promise<string>((resolve) => {
|
||||
pendingResolvers.push(resolve);
|
||||
})
|
||||
);
|
||||
const session: OAuthSession = {
|
||||
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
|
||||
waitForAuthorizationCode,
|
||||
close: vi.fn(async () => {}),
|
||||
};
|
||||
const { session, waitForAuthorizationCode, pendingResolvers, resolveNextCode } =
|
||||
createPendingAuthorizationSession();
|
||||
|
||||
const transport = new MockTransport();
|
||||
const logger: Logger = {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
const logger = createLogger();
|
||||
|
||||
const promise = connectWithAuth(client, transport, session, logger, {
|
||||
serverName: 'test-server',
|
||||
@ -237,13 +200,13 @@ describe('connectWithAuth', () => {
|
||||
oauthTimeoutMs: 5000,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
await flushAuthLoop();
|
||||
expect(pendingResolvers).toHaveLength(1);
|
||||
pendingResolvers.shift()?.('oauth-code-1');
|
||||
resolveNextCode('oauth-code-1');
|
||||
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
await flushAuthLoop();
|
||||
expect(pendingResolvers).toHaveLength(1);
|
||||
pendingResolvers.shift()?.('oauth-code-2');
|
||||
resolveNextCode('oauth-code-2');
|
||||
|
||||
const connectedTransport = await promise;
|
||||
|
||||
@ -257,28 +220,13 @@ describe('connectWithAuth', () => {
|
||||
const connect = vi.fn().mockRejectedValueOnce(new UnauthorizedError('auth needed'));
|
||||
const client = { connect } as unknown as Client;
|
||||
|
||||
let resolveCode: (code: string) => void = () => {};
|
||||
const waitForAuthorizationCode = vi.fn(
|
||||
() =>
|
||||
new Promise<string>((resolve) => {
|
||||
resolveCode = resolve;
|
||||
})
|
||||
);
|
||||
const session: OAuthSession = {
|
||||
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
|
||||
waitForAuthorizationCode,
|
||||
close: vi.fn(async () => {}),
|
||||
};
|
||||
const { session, resolveNextCode } = createPendingAuthorizationSession();
|
||||
|
||||
const finishAuthError = new Error('token endpoint returned 405');
|
||||
const transport = new MockTransport(async () => {
|
||||
throw finishAuthError;
|
||||
});
|
||||
const logger: Logger = {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
const logger = createLogger();
|
||||
|
||||
const promise = connectWithAuth(client, transport, session, logger, {
|
||||
serverName: 'test-server',
|
||||
@ -286,8 +234,8 @@ describe('connectWithAuth', () => {
|
||||
oauthTimeoutMs: 5000,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
resolveCode('oauth-code-123');
|
||||
await flushAuthLoop();
|
||||
resolveNextCode('oauth-code-123');
|
||||
|
||||
await expect(promise).rejects.toSatisfy((error: unknown) => error === finishAuthError && isOAuthFlowError(error));
|
||||
});
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import { StreamableHTTPError } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { StreamableHTTPClientTransport, StreamableHTTPError } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
@ -63,12 +65,26 @@ function stubHttpDefinition(url: string): ServerDefinition {
|
||||
};
|
||||
}
|
||||
|
||||
function stubOAuthHttpDefinition(url: string): ServerDefinition {
|
||||
return {
|
||||
...stubHttpDefinition(url),
|
||||
auth: 'oauth',
|
||||
};
|
||||
}
|
||||
|
||||
function createPromotionRecorder() {
|
||||
const promotedDefinitions: ServerDefinition[] = [];
|
||||
return {
|
||||
promotedDefinitions,
|
||||
onDefinitionPromoted: (promoted: ServerDefinition) => {
|
||||
promotedDefinitions.push(promoted);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe('createClientContext (HTTP)', () => {
|
||||
it('falls back to SSE when primary connect fails', async () => {
|
||||
const definition = stubHttpDefinition('https://example.com/mcp');
|
||||
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
|
||||
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
|
||||
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
|
||||
|
||||
const clientConnect = vi
|
||||
.spyOn(Client.prototype, 'connect')
|
||||
@ -87,12 +103,7 @@ describe('createClientContext (HTTP)', () => {
|
||||
});
|
||||
|
||||
it('does not fall back to SSE after the OAuth flow fails', async () => {
|
||||
const definition: ServerDefinition = {
|
||||
...stubHttpDefinition('https://example.com/secure'),
|
||||
auth: 'oauth',
|
||||
};
|
||||
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
|
||||
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
|
||||
const definition = stubOAuthHttpDefinition('https://example.com/secure');
|
||||
|
||||
mocks.connectWithAuth
|
||||
.mockImplementationOnce(async (_client, transport) => {
|
||||
@ -115,12 +126,7 @@ describe('createClientContext (HTTP)', () => {
|
||||
});
|
||||
|
||||
it('still falls back to SSE after auth when Streamable HTTP reveals a 405 transport mismatch', async () => {
|
||||
const definition: ServerDefinition = {
|
||||
...stubHttpDefinition('https://example.com/legacy-sse'),
|
||||
auth: 'oauth',
|
||||
};
|
||||
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
|
||||
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
|
||||
const definition = stubOAuthHttpDefinition('https://example.com/legacy-sse');
|
||||
|
||||
mocks.connectWithAuth
|
||||
.mockImplementationOnce(async (_client, transport) => {
|
||||
@ -139,12 +145,7 @@ describe('createClientContext (HTTP)', () => {
|
||||
});
|
||||
|
||||
it('surfaces provider 405 errors after auth instead of falling back to SSE', async () => {
|
||||
const definition: ServerDefinition = {
|
||||
...stubHttpDefinition('https://example.com/provider-405'),
|
||||
auth: 'oauth',
|
||||
};
|
||||
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
|
||||
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
|
||||
const definition = stubOAuthHttpDefinition('https://example.com/provider-405');
|
||||
|
||||
mocks.connectWithAuth
|
||||
.mockImplementationOnce(async (_client, transport) => {
|
||||
@ -166,12 +167,7 @@ describe('createClientContext (HTTP)', () => {
|
||||
});
|
||||
|
||||
it('still falls back to SSE after auth for generic 405 transport errors', async () => {
|
||||
const definition: ServerDefinition = {
|
||||
...stubHttpDefinition('https://example.com/legacy-sse-proxy'),
|
||||
auth: 'oauth',
|
||||
};
|
||||
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
|
||||
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
|
||||
const definition = stubOAuthHttpDefinition('https://example.com/legacy-sse-proxy');
|
||||
|
||||
mocks.connectWithAuth
|
||||
.mockImplementationOnce(async (_client, transport) => {
|
||||
@ -192,12 +188,7 @@ describe('createClientContext (HTTP)', () => {
|
||||
});
|
||||
|
||||
it('still falls back to SSE for oauth servers when no Streamable auth challenge was observed', async () => {
|
||||
const definition: ServerDefinition = {
|
||||
...stubHttpDefinition('https://example.com/sse-only'),
|
||||
auth: 'oauth',
|
||||
};
|
||||
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
|
||||
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
|
||||
const definition = stubOAuthHttpDefinition('https://example.com/sse-only');
|
||||
|
||||
mocks.connectWithAuth
|
||||
.mockImplementationOnce(async (_client, transport) => {
|
||||
@ -217,7 +208,6 @@ describe('createClientContext (HTTP)', () => {
|
||||
|
||||
it('promotes ad-hoc HTTP servers after generic 401 errors from Streamable HTTP', async () => {
|
||||
const definition = stubHttpDefinition('https://example.com/secure');
|
||||
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
|
||||
|
||||
mocks.connectWithAuth
|
||||
.mockImplementationOnce(async (_client, transport) => {
|
||||
@ -230,12 +220,10 @@ describe('createClientContext (HTTP)', () => {
|
||||
return transport;
|
||||
});
|
||||
|
||||
const promotedDefinitions: ServerDefinition[] = [];
|
||||
const { promotedDefinitions, onDefinitionPromoted } = createPromotionRecorder();
|
||||
const context = await createClientContext(definition, logger, clientInfo, {
|
||||
maxOAuthAttempts: 1,
|
||||
onDefinitionPromoted: (promoted) => {
|
||||
promotedDefinitions.push(promoted);
|
||||
},
|
||||
onDefinitionPromoted,
|
||||
});
|
||||
|
||||
expect(context.definition.auth).toBe('oauth');
|
||||
@ -246,8 +234,6 @@ describe('createClientContext (HTTP)', () => {
|
||||
|
||||
it('promotes ad-hoc HTTP servers after generic 401 errors from the SSE fallback path', async () => {
|
||||
const definition = stubHttpDefinition('https://example.com/sse-auth');
|
||||
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
|
||||
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
|
||||
|
||||
mocks.connectWithAuth
|
||||
.mockImplementationOnce(async (_client, transport) => {
|
||||
@ -264,12 +250,10 @@ describe('createClientContext (HTTP)', () => {
|
||||
return transport;
|
||||
});
|
||||
|
||||
const promotedDefinitions: ServerDefinition[] = [];
|
||||
const { promotedDefinitions, onDefinitionPromoted } = createPromotionRecorder();
|
||||
const context = await createClientContext(definition, logger, clientInfo, {
|
||||
maxOAuthAttempts: 1,
|
||||
onDefinitionPromoted: (promoted) => {
|
||||
promotedDefinitions.push(promoted);
|
||||
},
|
||||
onDefinitionPromoted,
|
||||
});
|
||||
|
||||
expect(context.definition.auth).toBe('oauth');
|
||||
@ -291,8 +275,6 @@ describe('createClientContext (HTTP)', () => {
|
||||
},
|
||||
},
|
||||
};
|
||||
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
|
||||
const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js');
|
||||
const createOAuthSessionSpy = vi.spyOn(oauthModule, 'createOAuthSession').mockResolvedValue({
|
||||
provider: {} as never,
|
||||
waitForAuthorizationCode: vi.fn(),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user