diff --git a/src/cli/call-arguments.ts b/src/cli/call-arguments.ts index 619497c..97818cb 100644 --- a/src/cli/call-arguments.ts +++ b/src/cli/call-arguments.ts @@ -6,6 +6,7 @@ import { parseKeyValueToken, shouldPromoteSelectorToCommand, } from './call-argument-values.js'; +import { buildUnknownCallFlagMessage } from './call-help.js'; import { extractEphemeralServerFlags } from './ephemeral-flags.js'; import { CliUsageError } from './errors.js'; import { consumeOutputFormat } from './output-format.js'; @@ -39,6 +40,16 @@ interface FlagHandlerContext { type FlagHandler = (context: FlagHandlerContext) => number; +interface ScannedCallTokens { + positional: string[]; + literalPositional: string[]; +} + +interface CallExpressionResolution { + callExpressionProvidedServer: boolean; + callExpressionProvidedTool: boolean; +} + const FLAG_HANDLERS: Record = { '--server': handleServerFlag, '--mcp': handleServerFlag, @@ -60,6 +71,15 @@ export function parseCallArguments(args: string[]): CallArgsParseResult { result.output = consumeOutputFormat(args, { defaultFormat: 'auto', }); + const { positional, literalPositional } = scanCallTokens(args, result, flagState); + const { callExpressionProvidedServer, callExpressionProvidedTool } = applyLeadingCallExpression(positional, result); + resolveSelectorAndTool(positional, result, callExpressionProvidedServer, callExpressionProvidedTool); + applyTrailingArguments(positional, result, flagState); + appendLiteralPositionalArguments(literalPositional, result, flagState); + return result; +} + +function scanCallTokens(args: string[], result: CallArgsParseResult, state: FlagParseState): ScannedCallTokens { const positional: string[] = []; const literalPositional: string[] = []; let index = 0; @@ -75,70 +95,70 @@ export function parseCallArguments(args: string[]): CallArgsParseResult { } const flagHandler = FLAG_HANDLERS[token]; if (flagHandler) { - index = flagHandler({ args, index, result, state: flagState }); + index = flagHandler({ args, index, result, state }); continue; } if (token.startsWith('--')) { - throw new CliUsageError( - [ - `Unknown flag '${token}' passed to call command.`, - `If you intended to pass a tool argument, use '${token.slice(2)}=' or --args '{"${token.slice(2)}": ...}'.`, - "If you intended to pass a literal positional value, insert '--' before it.", - "Run 'mcporter call --help' to see available flags.", - ].join('\n') - ); + throw new CliUsageError(buildUnknownCallFlagMessage(token)); } positional.push(token); index += 1; } + return { positional, literalPositional }; +} - let callExpressionProvidedServer = false; - let callExpressionProvidedTool = false; - - if (positional.length > 0) { - const rawToken = positional[0] ?? ''; - const callExpression = parseLeadingCallExpression(rawToken); - if (callExpression) { - positional.shift(); - callExpressionProvidedServer = Boolean(callExpression.server); - callExpressionProvidedTool = Boolean(callExpression.tool); - if (callExpression.server) { - if (result.server && result.server !== callExpression.server) { - throw new Error( - `Conflicting server names: '${result.server}' from flags and '${callExpression.server}' from call expression.` - ); - } - result.server = result.server ?? callExpression.server; - } - if (result.tool && result.tool !== callExpression.tool) { - throw new Error( - `Conflicting tool names: '${result.tool}' from flags and '${callExpression.tool}' from call expression.` - ); - } - result.tool = callExpression.tool; - Object.assign(result.args, callExpression.args); - if (callExpression.positionalArgs && callExpression.positionalArgs.length > 0) { - result.positionalArgs = [...(result.positionalArgs ?? []), ...callExpression.positionalArgs]; - } - } +function applyLeadingCallExpression(positional: string[], result: CallArgsParseResult): CallExpressionResolution { + if (positional.length === 0) { + return { callExpressionProvidedServer: false, callExpressionProvidedTool: false }; } + const rawToken = positional[0] ?? ''; + const callExpression = parseLeadingCallExpression(rawToken); + if (!callExpression) { + return { callExpressionProvidedServer: false, callExpressionProvidedTool: false }; + } + positional.shift(); + if (callExpression.server) { + if (result.server && result.server !== callExpression.server) { + throw new Error( + `Conflicting server names: '${result.server}' from flags and '${callExpression.server}' from call expression.` + ); + } + result.server = result.server ?? callExpression.server; + } + if (result.tool && result.tool !== callExpression.tool) { + throw new Error( + `Conflicting tool names: '${result.tool}' from flags and '${callExpression.tool}' from call expression.` + ); + } + result.tool = callExpression.tool; + Object.assign(result.args, callExpression.args); + if (callExpression.positionalArgs && callExpression.positionalArgs.length > 0) { + result.positionalArgs = [...(result.positionalArgs ?? []), ...callExpression.positionalArgs]; + } + return { + callExpressionProvidedServer: Boolean(callExpression.server), + callExpressionProvidedTool: Boolean(callExpression.tool), + }; +} +function resolveSelectorAndTool( + positional: string[], + result: CallArgsParseResult, + callExpressionProvidedServer: boolean, + callExpressionProvidedTool: boolean +): void { if (!result.selector && positional.length > 0 && !callExpressionProvidedServer && !result.server) { result.selector = positional.shift(); } - if ( !result.server && result.selector && shouldPromoteSelectorToCommand(result.selector) && !result.ephemeral?.stdioCommand ) { - // Treat the first positional token as an ad-hoc stdio command when it looks like - // `npx ...`/`./script`/etc., so users can skip `--stdio` entirely. result.ephemeral = { ...result.ephemeral, stdioCommand: result.selector }; result.selector = undefined; } - const nextPositional = positional[0]; if ( !result.tool && @@ -149,7 +169,9 @@ export function parseCallArguments(args: string[]): CallArgsParseResult { ) { result.tool = positional.shift(); } +} +function applyTrailingArguments(positional: string[], result: CallArgsParseResult, state: FlagParseState): void { const trailingPositional: unknown[] = []; for (let index = 0; index < positional.length; ) { const token = positional[index]; @@ -159,12 +181,12 @@ export function parseCallArguments(args: string[]): CallArgsParseResult { } const parsed = parseKeyValueToken(token, positional[index + 1]); if (!parsed) { - trailingPositional.push(coerceValue(token, flagState.coercionMode)); + trailingPositional.push(coerceValue(token, state.coercionMode)); index += 1; continue; } index += parsed.consumed; - const value = coerceValue(parsed.rawValue, flagState.coercionMode); + const value = coerceValue(parsed.rawValue, state.coercionMode); if (parsed.key === 'tool' && !result.tool) { if (typeof value !== 'string') { throw new Error("Argument 'tool' must be a string value."); @@ -184,13 +206,20 @@ export function parseCallArguments(args: string[]): CallArgsParseResult { if (trailingPositional.length > 0) { result.positionalArgs = [...(result.positionalArgs ?? []), ...trailingPositional]; } - if (literalPositional.length > 0) { - result.positionalArgs = [ - ...(result.positionalArgs ?? []), - ...literalPositional.map((token) => coerceValue(token, flagState.coercionMode)), - ]; +} + +function appendLiteralPositionalArguments( + literalPositional: string[], + result: CallArgsParseResult, + state: FlagParseState +): void { + if (literalPositional.length === 0) { + return; } - return result; + result.positionalArgs = [ + ...(result.positionalArgs ?? []), + ...literalPositional.map((token) => coerceValue(token, state.coercionMode)), + ]; } function handleServerFlag(context: FlagHandlerContext): number { diff --git a/src/cli/call-command.ts b/src/cli/call-command.ts index 85f346e..e025729 100644 --- a/src/cli/call-command.ts +++ b/src/cli/call-command.ts @@ -1,6 +1,12 @@ import { analyzeConnectionError, type ConnectionIssue } from '../error-classifier.js'; import { wrapCallResult } from '../result-utils.js'; import { type CallArgsParseResult, parseCallArguments } from './call-arguments.js'; +import { + CALL_HELP_ADHOC_SERVER_LINES, + CALL_HELP_ARGUMENT_LINES, + CALL_HELP_EXAMPLE_LINES, + CALL_HELP_RUNTIME_FLAG_LINES, +} from './call-help.js'; import { prepareEphemeralServerTarget } from './ephemeral-target.js'; import { looksLikeHttpUrl, normalizeHttpUrlCandidate } from './http-utils.js'; import type { IdentifierResolution } from './identifier-helpers.js'; @@ -170,37 +176,16 @@ export function printCallHelp(): void { ' --tool Override the tool name.', '', 'Arguments:', - ' key=value / key:value Flag-style named arguments.', - ' function-call syntax \'server.tool(arg: "value", other: 1)\'.', - ' --args Provide a JSON object payload.', - ' positional values Accepted when schema order is known.', - ' -- Treat remaining tokens as literal positional values.', + ...CALL_HELP_ARGUMENT_LINES, '', 'Runtime flags:', - ' --timeout Override the call timeout.', - ' --output text|markdown|json|raw Control formatting.', - ' --save-images Save image content blocks to a directory.', - ' --raw-strings Keep numeric-looking argument values as strings.', - ' --no-coerce Keep all key/value and positional arguments as raw strings.', - ' --tail-log Stream returned log handles.', + ...CALL_HELP_RUNTIME_FLAG_LINES, '', 'Ad-hoc servers:', - ' --http-url Register an HTTP server for this run.', - ' --allow-http Permit plain http:// URLs with --http-url.', - ' --stdio Run a stdio MCP server (repeat --stdio-arg for args).', - ' --stdio-arg Append args to the stdio command (repeatable).', - ' --env KEY=value Inject env vars for stdio servers (repeatable).', - ' --cwd Working directory for stdio servers.', - ' --name Override the display name for ad-hoc servers.', - ' --description Override the description for ad-hoc servers.', - ' --persist Write the ad-hoc definition to config/mcporter.json.', - ' --yes Skip confirmation prompts when persisting.', + ...CALL_HELP_ADHOC_SERVER_LINES, '', 'Examples:', - ' mcporter call linear.list_issues team=ENG limit:5', - ' mcporter call "linear.create_issue(title: \\"Bug\\", team: \\"ENG\\")"', - ' mcporter call https://api.example.com/mcp.fetch url:https://example.com', - ' mcporter call --stdio "bun run ./server.ts" scrape url=https://example.com', + ...CALL_HELP_EXAMPLE_LINES, ]; console.error(lines.join('\n')); } diff --git a/src/cli/call-help.ts b/src/cli/call-help.ts new file mode 100644 index 0000000..451ae17 --- /dev/null +++ b/src/cli/call-help.ts @@ -0,0 +1,46 @@ +export const CALL_HELP_ARGUMENT_LINES = [ + ' key=value / key:value Flag-style named arguments.', + ' function-call syntax \'server.tool(arg: "value", other: 1)\'.', + ' --args Provide a JSON object payload.', + ' positional values Accepted when schema order is known.', + ' -- Treat remaining tokens as literal positional values.', +] as const; + +export const CALL_HELP_RUNTIME_FLAG_LINES = [ + ' --timeout Override the call timeout.', + ' --output text|markdown|json|raw Control formatting.', + ' --save-images Save image content blocks to a directory.', + ' --raw-strings Keep numeric-looking argument values as strings.', + ' --no-coerce Keep all key/value and positional arguments as raw strings.', + ' --tail-log Stream returned log handles.', +] as const; + +export const CALL_HELP_ADHOC_SERVER_LINES = [ + ' --http-url Register an HTTP server for this run.', + ' --allow-http Permit plain http:// URLs with --http-url.', + ' --stdio Run a stdio MCP server (repeat --stdio-arg for args).', + ' --stdio-arg Append args to the stdio command (repeatable).', + ' --env KEY=value Inject env vars for stdio servers (repeatable).', + ' --cwd Working directory for stdio servers.', + ' --name Override the display name for ad-hoc servers.', + ' --description Override the description for ad-hoc servers.', + ' --persist Write the ad-hoc definition to config/mcporter.json.', + ' --yes Skip confirmation prompts when persisting.', +] as const; + +export const CALL_HELP_EXAMPLE_LINES = [ + ' mcporter call linear.list_issues team=ENG limit:5', + ' mcporter call "linear.create_issue(title: \\"Bug\\", team: \\"ENG\\")"', + ' mcporter call https://api.example.com/mcp.fetch url:https://example.com', + ' mcporter call --stdio "bun run ./server.ts" scrape url=https://example.com', +] as const; + +export function buildUnknownCallFlagMessage(token: string): string { + const argumentName = token.startsWith('--') ? token.slice(2) : token; + return [ + `Unknown flag '${token}' passed to call command.`, + `If you intended to pass a tool argument, use '${argumentName}=' or --args '{"${argumentName}": ...}'.`, + "If you intended to pass a literal positional value, insert '--' before it.", + "Run 'mcporter call --help' to see available flags.", + ].join('\n'); +} diff --git a/src/runtime/oauth.ts b/src/runtime/oauth.ts index 58ad631..b15d31f 100644 --- a/src/runtime/oauth.ts +++ b/src/runtime/oauth.ts @@ -20,6 +20,12 @@ export interface ConnectWithAuthOptions { recreateTransport?: (transport: OAuthCapableTransport) => Promise; } +interface OAuthConnectState { + activeTransport: OAuthCapableTransport; + attempt: number; + hasCompletedAuthFlow: boolean; +} + export class OAuthTimeoutError extends Error { public readonly timeoutMs: number; public readonly serverName: string; @@ -78,47 +84,60 @@ export async function connectWithAuth( options: ConnectWithAuthOptions = {} ): Promise { const { serverName, maxAttempts = 3, oauthTimeoutMs = DEFAULT_OAUTH_CODE_TIMEOUT_MS, recreateTransport } = options; - let activeTransport = transport; - let attempt = 0; - let hasCompletedAuthFlow = false; + const state: OAuthConnectState = { + activeTransport: transport, + attempt: 0, + hasCompletedAuthFlow: false, + }; while (true) { try { - await client.connect(activeTransport); - return activeTransport; + return await attemptTransportConnect(client, state); } catch (error) { const unauthorized = isUnauthorizedError(error); - if (hasCompletedAuthFlow && !unauthorized) { - await closeReplacementTransport(transport, activeTransport); - throw markPostAuthConnectError(error); + if (!shouldRetryAuthorization(state, unauthorized, session)) { + await closeReplacementTransport(transport, state.activeTransport); + throw state.hasCompletedAuthFlow && !unauthorized ? markPostAuthConnectError(error) : error; } - if (!unauthorized || !session) { - await closeReplacementTransport(transport, activeTransport); - throw error; - } - attempt += 1; - if (attempt > maxAttempts) { - await closeReplacementTransport(transport, activeTransport); - throw hasCompletedAuthFlow ? markPostAuthConnectError(error) : error; + state.attempt += 1; + if (state.attempt > maxAttempts) { + await closeReplacementTransport(transport, state.activeTransport); + throw state.hasCompletedAuthFlow ? markPostAuthConnectError(error) : error; } logger.warn(`OAuth authorization required for '${serverName ?? 'unknown'}'. Waiting for browser approval...`); try { - activeTransport = await completeAuthorizationChallenge(activeTransport, session, logger, error, { + state.activeTransport = await completeAuthorizationChallenge(state.activeTransport, session, logger, error, { serverName, oauthTimeoutMs, recreateTransport, }); - hasCompletedAuthFlow = true; + state.hasCompletedAuthFlow = true; logger.info('Authorization code accepted. Retrying connection...'); } catch (authError) { logger.error('OAuth authorization failed while waiting for callback.', authError); - await closeReplacementTransport(transport, activeTransport); + await closeReplacementTransport(transport, state.activeTransport); throw markOAuthFlowError(authError); } } } } +async function attemptTransportConnect(client: Client, state: OAuthConnectState): Promise { + await client.connect(state.activeTransport); + return state.activeTransport; +} + +function shouldRetryAuthorization( + _state: OAuthConnectState, + unauthorized: boolean, + session: OAuthSession | undefined +): session is OAuthSession { + if (!session || !unauthorized) { + return false; + } + return true; +} + async function closeReplacementTransport( originalTransport: OAuthCapableTransport, activeTransport: OAuthCapableTransport diff --git a/src/runtime/transport.ts b/src/runtime/transport.ts index d54a48c..e6281b8 100644 --- a/src/runtime/transport.ts +++ b/src/runtime/transport.ts @@ -59,6 +59,10 @@ interface ResolvedHttpTransportOptions { authProvider?: OAuthSession['provider']; } +type HttpClientContextAttempt = + | { context: ClientContext; nextDefinition?: undefined } + | { context?: undefined; nextDefinition: ServerDefinition }; + function attachStdioTraceLogging(_transport: StdioClientTransport, _label?: string): void { // STDIO instrumentation is handled via sdk-patches side effects. This helper remains // so runtime callers can opt-in without sprinkling conditional checks everywhere. @@ -144,6 +148,216 @@ async function connectHttpTransport( } } +async function applyCachedOAuthHeaderIfAvailable( + definition: ServerDefinition, + logger: Logger, + allowCachedAuth: boolean | undefined +): Promise { + if (!allowCachedAuth || definition.auth !== 'oauth' || definition.command.kind !== 'http') { + return definition; + } + try { + const cached = await readCachedAccessToken(definition, logger); + if (!cached) { + return definition; + } + const existingHeaders = definition.command.headers ?? {}; + if ('Authorization' in existingHeaders) { + return definition; + } + logger.debug?.(`Using cached OAuth access token for '${definition.name}' (non-interactive).`); + return { + ...definition, + command: { + ...definition.command, + headers: { + ...existingHeaders, + Authorization: `Bearer ${cached}`, + }, + }, + }; + } catch (error) { + logger.debug?.( + `Failed to read cached OAuth token for '${definition.name}': ${ + error instanceof Error ? error.message : String(error) + }` + ); + return definition; + } +} + +async function createStdioClientContext( + client: Client, + definition: ServerDefinition & { command: Extract }, + logger: Logger +): Promise { + const resolvedEnvOverrides = + definition.env && Object.keys(definition.env).length > 0 + ? Object.fromEntries( + Object.entries(definition.env) + .map(([key, raw]) => [key, resolveEnvValue(raw)]) + .filter(([, value]) => value !== '') + ) + : undefined; + const mergedEnv = + resolvedEnvOverrides && Object.keys(resolvedEnvOverrides).length > 0 + ? { ...process.env, ...resolvedEnvOverrides } + : { ...process.env }; + const transport = new StdioClientTransport({ + command: resolveCommandArgument(definition.command.command), + args: resolveCommandArguments(definition.command.args), + cwd: definition.command.cwd, + env: mergedEnv, + }); + if (STDIO_TRACE_ENABLED) { + attachStdioTraceLogging(transport, definition.name ?? definition.command.command); + } + try { + await client.connect(transport); + } catch (error) { + await closeTransportAndWait(logger, transport).catch(() => {}); + throw error; + } + return { client, transport, definition, oauthSession: undefined }; +} + +async function retryHttpTransportWithFallback( + client: Client, + definition: ServerDefinition, + logger: Logger, + options: CreateClientContextOptions +): Promise { + let activeDefinition = definition; + while (true) { + const attempt = await attemptHttpClientContext(client, activeDefinition, logger, options); + if (!attempt.nextDefinition) { + return attempt.context; + } + activeDefinition = attempt.nextDefinition; + options.onDefinitionPromoted?.(activeDefinition); + } +} + +async function attemptHttpClientContext( + client: Client, + activeDefinition: ServerDefinition, + logger: Logger, + options: CreateClientContextOptions +): Promise { + const command = activeDefinition.command; + if (command.kind !== 'http') { + throw new Error(`Server '${activeDefinition.name}' is not configured for HTTP transport.`); + } + let oauthSession: OAuthSession | undefined; + const shouldEstablishOAuth = activeDefinition.auth === 'oauth' && options.maxOAuthAttempts !== 0; + if (shouldEstablishOAuth) { + oauthSession = await createOAuthSession(activeDefinition, logger); + } + const transportOptions = createHttpTransportOptions(activeDefinition, oauthSession, shouldEstablishOAuth); + + try { + const context = await connectPrimaryHttpTransport( + client, + activeDefinition, + command, + transportOptions, + oauthSession, + logger, + options + ); + return { context }; + } catch (primaryError) { + if (shouldAbortSseFallback(primaryError)) { + await closeOAuthSession(oauthSession); + throw primaryError; + } + if (isUnauthorizedError(primaryError)) { + await closeOAuthSession(oauthSession); + const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options); + if (promoted) { + return { nextDefinition: promoted }; + } + oauthSession = undefined; + } + if (primaryError instanceof Error) { + logger.info(`Falling back to SSE transport for '${activeDefinition.name}': ${primaryError.message}`); + } + return { + context: await connectSseFallbackTransport( + client, + activeDefinition, + command, + transportOptions, + oauthSession, + logger, + options + ), + }; + } +} + +async function connectPrimaryHttpTransport( + client: Client, + definition: ServerDefinition, + command: Extract, + transportOptions: ResolvedHttpTransportOptions, + oauthSession: OAuthSession | undefined, + logger: Logger, + options: CreateClientContextOptions +): Promise { + const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, transportOptions); + const transport = await connectHttpTransport(client, createStreamableTransport(), oauthSession, logger, { + serverName: definition.name, + maxAttempts: options.maxOAuthAttempts, + oauthTimeoutMs: options.oauthTimeoutMs, + recreateTransport: async () => createStreamableTransport(), + }); + return { + client, + transport, + definition, + oauthSession, + }; +} + +async function connectSseFallbackTransport( + client: Client, + definition: ServerDefinition, + command: Extract, + transportOptions: ResolvedHttpTransportOptions, + oauthSession: OAuthSession | undefined, + logger: Logger, + options: CreateClientContextOptions +): Promise { + try { + const transport = await connectHttpTransport( + client, + new SSEClientTransport(command.url, transportOptions), + oauthSession, + logger, + { + serverName: definition.name, + maxAttempts: options.maxOAuthAttempts, + oauthTimeoutMs: options.oauthTimeoutMs, + } + ); + return { client, transport, definition, oauthSession }; + } catch (sseError) { + await closeOAuthSession(oauthSession); + if (sseError instanceof OAuthTimeoutError) { + throw sseError; + } + if (isUnauthorizedError(sseError)) { + const promoted = maybePromoteHttpDefinition(definition, logger, options); + if (promoted) { + options.onDefinitionPromoted?.(promoted); + return retryHttpTransportWithFallback(client, promoted, logger, options); + } + } + throw sseError; + } +} + export async function createClientContext( definition: ServerDefinition, logger: Logger, @@ -151,147 +365,16 @@ export async function createClientContext( options: CreateClientContextOptions = {} ): Promise { const client = new Client(clientInfo); - let activeDefinition = definition; - - if (options.allowCachedAuth && activeDefinition.auth === 'oauth' && activeDefinition.command.kind === 'http') { - try { - const cached = await readCachedAccessToken(activeDefinition, logger); - if (cached) { - const existingHeaders = activeDefinition.command.headers ?? {}; - if (!('Authorization' in existingHeaders)) { - activeDefinition = { - ...activeDefinition, - command: { - ...activeDefinition.command, - headers: { - ...existingHeaders, - Authorization: `Bearer ${cached}`, - }, - }, - }; - logger.debug?.(`Using cached OAuth access token for '${activeDefinition.name}' (non-interactive).`); - } - } - } catch (error) { - logger.debug?.( - `Failed to read cached OAuth token for '${activeDefinition.name}': ${ - error instanceof Error ? error.message : String(error) - }` - ); - } - } + const activeDefinition = await applyCachedOAuthHeaderIfAvailable(definition, logger, options.allowCachedAuth); return withEnvOverrides(activeDefinition.env, async () => { if (activeDefinition.command.kind === 'stdio') { - const resolvedEnvOverrides = - activeDefinition.env && Object.keys(activeDefinition.env).length > 0 - ? Object.fromEntries( - Object.entries(activeDefinition.env) - .map(([key, raw]) => [key, resolveEnvValue(raw)]) - .filter(([, value]) => value !== '') - ) - : undefined; - const mergedEnv = - resolvedEnvOverrides && Object.keys(resolvedEnvOverrides).length > 0 - ? { ...process.env, ...resolvedEnvOverrides } - : { ...process.env }; - const transport = new StdioClientTransport({ - command: resolveCommandArgument(activeDefinition.command.command), - args: resolveCommandArguments(activeDefinition.command.args), - cwd: activeDefinition.command.cwd, - env: mergedEnv, - }); - if (STDIO_TRACE_ENABLED) { - attachStdioTraceLogging(transport, activeDefinition.name ?? activeDefinition.command.command); - } - try { - await client.connect(transport); - } catch (error) { - await closeTransportAndWait(logger, transport).catch(() => {}); - throw error; - } - return { client, transport, definition: activeDefinition, oauthSession: undefined }; - } - - while (true) { - const command = activeDefinition.command; - if (command.kind !== 'http') { - throw new Error(`Server '${activeDefinition.name}' is not configured for HTTP transport.`); - } - let oauthSession: OAuthSession | undefined; - const shouldEstablishOAuth = activeDefinition.auth === 'oauth' && options.maxOAuthAttempts !== 0; - if (shouldEstablishOAuth) { - oauthSession = await createOAuthSession(activeDefinition, logger); - } - const transportOptions = createHttpTransportOptions(activeDefinition, oauthSession, shouldEstablishOAuth); - - try { - const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, transportOptions); - const streamableTransport = await connectHttpTransport( - client, - createStreamableTransport(), - oauthSession, - logger, - { - serverName: activeDefinition.name, - maxAttempts: options.maxOAuthAttempts, - oauthTimeoutMs: options.oauthTimeoutMs, - recreateTransport: async () => createStreamableTransport(), - } - ); - return { - client, - transport: streamableTransport, - definition: activeDefinition, - oauthSession, - }; - } catch (primaryError) { - if (shouldAbortSseFallback(primaryError)) { - await closeOAuthSession(oauthSession); - throw primaryError; - } - if (isUnauthorizedError(primaryError)) { - await closeOAuthSession(oauthSession); - oauthSession = undefined; - const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options); - if (promoted) { - activeDefinition = promoted; - options.onDefinitionPromoted?.(promoted); - continue; - } - } - if (primaryError instanceof Error) { - logger.info(`Falling back to SSE transport for '${activeDefinition.name}': ${primaryError.message}`); - } - try { - const connectedTransport = await connectHttpTransport( - client, - new SSEClientTransport(command.url, transportOptions), - oauthSession, - logger, - { - serverName: activeDefinition.name, - maxAttempts: options.maxOAuthAttempts, - oauthTimeoutMs: options.oauthTimeoutMs, - } - ); - return { client, transport: connectedTransport, definition: activeDefinition, oauthSession }; - } catch (sseError) { - await closeOAuthSession(oauthSession); - if (sseError instanceof OAuthTimeoutError) { - throw sseError; - } - if (isUnauthorizedError(sseError)) { - const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options); - if (promoted) { - activeDefinition = promoted; - options.onDefinitionPromoted?.(promoted); - continue; - } - } - throw sseError; - } - } + return createStdioClientContext( + client, + activeDefinition as ServerDefinition & { command: Extract }, + logger + ); } + return retryHttpTransportWithFallback(client, activeDefinition, logger, options); }); } diff --git a/tests/helpers/runtime-test-helpers.ts b/tests/helpers/runtime-test-helpers.ts new file mode 100644 index 0000000..27021c5 --- /dev/null +++ b/tests/helpers/runtime-test-helpers.ts @@ -0,0 +1,112 @@ +import type { Transport, TransportSendOptions } from '@modelcontextprotocol/sdk/shared/transport.js'; +import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; +import { vi } from 'vitest'; +import type { ServerDefinition } from '../../src/config.js'; +import type { Logger } from '../../src/logging.js'; +import type { OAuthSession } from '../../src/oauth.js'; + +export const clientInfo = { name: 'mcporter', version: '0.0.0-test' }; + +export interface LoggerSpy extends Logger { + info: ReturnType void>>; + warn: ReturnType void>>; + error: ReturnType void>>; +} + +export function createLogger(): LoggerSpy { + return { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }; +} + +export function resetLogger(logger: LoggerSpy): void { + logger.info.mockReset(); + logger.warn.mockReset(); + logger.error.mockReset(); +} + +export function stubHttpDefinition(url: string): ServerDefinition { + return { + name: 'http-server', + command: { kind: 'http', url: new URL(url) }, + source: { kind: 'local', path: '' }, + }; +} + +export function stubOAuthHttpDefinition(url: string): ServerDefinition { + return { + ...stubHttpDefinition(url), + auth: 'oauth', + }; +} + +export function createPromotionRecorder() { + const promotedDefinitions: ServerDefinition[] = []; + return { + promotedDefinitions, + onDefinitionPromoted: (promoted: ServerDefinition) => { + promotedDefinitions.push(promoted); + }, + }; +} + +export function createMockOAuthSession(): OAuthSession { + return { + provider: { + waitForAuthorizationCode: vi.fn(), + } as unknown as OAuthSession['provider'], + waitForAuthorizationCode: vi.fn(), + close: vi.fn(async () => {}), + }; +} + +export class MockTransport implements Transport { + public readonly calls: string[] = []; + public readonly close = vi.fn(async () => {}); + + constructor(private readonly finishAuthImpl?: (code: string) => Promise) {} + + async start(): Promise {} + + async send(_message: JSONRPCMessage, _options?: TransportSendOptions): Promise {} + + async finishAuth(code: string): Promise { + this.calls.push(code); + if (this.finishAuthImpl) { + await this.finishAuthImpl(code); + } + } +} + +export function createPendingAuthorizationSession() { + const pendingResolvers: Array<(code: string) => void> = []; + const waitForAuthorizationCode = vi.fn( + () => + new Promise((resolve) => { + pendingResolvers.push(resolve); + }) + ); + const session: OAuthSession = { + provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'], + waitForAuthorizationCode, + close: vi.fn(async () => {}), + }; + return { + session, + waitForAuthorizationCode, + pendingResolvers, + resolveNextCode: (code: string) => { + const resolve = pendingResolvers.shift(); + if (!resolve) { + throw new Error(`Missing pending authorization resolver for '${code}'.`); + } + resolve(code); + }, + }; +} + +export async function flushAuthLoop(): Promise { + await new Promise((resolve) => setImmediate(resolve)); +} diff --git a/tests/runtime-oauth-connect.test.ts b/tests/runtime-oauth-connect.test.ts index f128008..82bbc06 100644 --- a/tests/runtime-oauth-connect.test.ts +++ b/tests/runtime-oauth-connect.test.ts @@ -1,65 +1,13 @@ import type { Client } from '@modelcontextprotocol/sdk/client'; import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'; -import type { Transport, TransportSendOptions } from '@modelcontextprotocol/sdk/shared/transport.js'; -import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; import { describe, expect, it, vi } from 'vitest'; -import type { Logger } from '../src/logging.js'; -import type { OAuthSession } from '../src/oauth.js'; import { connectWithAuth, isOAuthFlowError, isPostAuthConnectError } from '../src/runtime/oauth.js'; - -class MockTransport implements Transport { - public readonly calls: string[] = []; - public readonly close = vi.fn(async () => {}); - - constructor(private readonly finishAuthImpl?: (code: string) => Promise) {} - async start(): Promise {} - async send(_message: JSONRPCMessage, _options?: TransportSendOptions): Promise {} - async finishAuth(code: string): Promise { - this.calls.push(code); - if (this.finishAuthImpl) { - await this.finishAuthImpl(code); - } - } -} - -function createLogger(): Logger { - return { - info: vi.fn(), - warn: vi.fn(), - error: vi.fn(), - }; -} - -function createPendingAuthorizationSession() { - const pendingResolvers: Array<(code: string) => void> = []; - const waitForAuthorizationCode = vi.fn( - () => - new Promise((resolve) => { - pendingResolvers.push(resolve); - }) - ); - const session: OAuthSession = { - provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'], - waitForAuthorizationCode, - close: vi.fn(async () => {}), - }; - return { - session, - waitForAuthorizationCode, - pendingResolvers, - resolveNextCode: (code: string) => { - const resolve = pendingResolvers.shift(); - if (!resolve) { - throw new Error(`Missing pending authorization resolver for '${code}'.`); - } - resolve(code); - }, - }; -} - -async function flushAuthLoop(): Promise { - await new Promise((resolve) => setImmediate(resolve)); -} +import { + createLogger, + createPendingAuthorizationSession, + flushAuthLoop, + MockTransport, +} from './helpers/runtime-test-helpers.js'; describe('connectWithAuth', () => { it('waits for authorization code and retries connection', async () => { diff --git a/tests/runtime-transport.test.ts b/tests/runtime-transport.test.ts index bf2f629..24311c2 100644 --- a/tests/runtime-transport.test.ts +++ b/tests/runtime-transport.test.ts @@ -28,60 +28,33 @@ import type { ServerDefinition } from '../src/config.js'; import * as oauthModule from '../src/oauth.js'; import { markOAuthFlowError, markPostAuthConnectError } from '../src/runtime/oauth.js'; import { createClientContext } from '../src/runtime/transport.js'; +import { + clientInfo, + createLogger, + createMockOAuthSession, + createPromotionRecorder, + resetLogger, + stubHttpDefinition, + stubOAuthHttpDefinition, +} from './helpers/runtime-test-helpers.js'; -const logger = { - info: vi.fn(), - warn: vi.fn(), - error: vi.fn(), -}; - -const clientInfo = { name: 'mcporter', version: '0.0.0-test' }; +const logger = createLogger(); beforeEach(() => { + resetLogger(logger); mocks.connectWithAuth.mockReset(); mocks.connectWithAuth.mockImplementation(async (client, transport) => { await client.connect(transport); return transport; }); mocks.createOAuthSession.mockReset(); - mocks.createOAuthSession.mockResolvedValue({ - provider: { - waitForAuthorizationCode: vi.fn(), - }, - waitForAuthorizationCode: vi.fn(), - close: vi.fn(async () => {}), - }); + mocks.createOAuthSession.mockResolvedValue(createMockOAuthSession()); }); afterEach(() => { vi.restoreAllMocks(); }); -function stubHttpDefinition(url: string): ServerDefinition { - return { - name: 'http-server', - command: { kind: 'http', url: new URL(url) }, - source: { kind: 'local', path: '' }, - }; -} - -function stubOAuthHttpDefinition(url: string): ServerDefinition { - return { - ...stubHttpDefinition(url), - auth: 'oauth', - }; -} - -function createPromotionRecorder() { - const promotedDefinitions: ServerDefinition[] = []; - return { - promotedDefinitions, - onDefinitionPromoted: (promoted: ServerDefinition) => { - promotedDefinitions.push(promoted); - }, - }; -} - describe('createClientContext (HTTP)', () => { it('falls back to SSE when primary connect fails', async () => { const definition = stubHttpDefinition('https://example.com/mcp');