refactor: split call and runtime helpers

This commit is contained in:
Peter Steinberger 2026-03-29 09:36:35 +09:00
parent 36b2584c17
commit 83cc3b9a4c
No known key found for this signature in database
8 changed files with 524 additions and 329 deletions

View File

@ -6,6 +6,7 @@ import {
parseKeyValueToken,
shouldPromoteSelectorToCommand,
} from './call-argument-values.js';
import { buildUnknownCallFlagMessage } from './call-help.js';
import { extractEphemeralServerFlags } from './ephemeral-flags.js';
import { CliUsageError } from './errors.js';
import { consumeOutputFormat } from './output-format.js';
@ -39,6 +40,16 @@ interface FlagHandlerContext {
type FlagHandler = (context: FlagHandlerContext) => number;
interface ScannedCallTokens {
positional: string[];
literalPositional: string[];
}
interface CallExpressionResolution {
callExpressionProvidedServer: boolean;
callExpressionProvidedTool: boolean;
}
const FLAG_HANDLERS: Record<string, FlagHandler> = {
'--server': handleServerFlag,
'--mcp': handleServerFlag,
@ -60,6 +71,15 @@ export function parseCallArguments(args: string[]): CallArgsParseResult {
result.output = consumeOutputFormat(args, {
defaultFormat: 'auto',
});
const { positional, literalPositional } = scanCallTokens(args, result, flagState);
const { callExpressionProvidedServer, callExpressionProvidedTool } = applyLeadingCallExpression(positional, result);
resolveSelectorAndTool(positional, result, callExpressionProvidedServer, callExpressionProvidedTool);
applyTrailingArguments(positional, result, flagState);
appendLiteralPositionalArguments(literalPositional, result, flagState);
return result;
}
function scanCallTokens(args: string[], result: CallArgsParseResult, state: FlagParseState): ScannedCallTokens {
const positional: string[] = [];
const literalPositional: string[] = [];
let index = 0;
@ -75,70 +95,70 @@ export function parseCallArguments(args: string[]): CallArgsParseResult {
}
const flagHandler = FLAG_HANDLERS[token];
if (flagHandler) {
index = flagHandler({ args, index, result, state: flagState });
index = flagHandler({ args, index, result, state });
continue;
}
if (token.startsWith('--')) {
throw new CliUsageError(
[
`Unknown flag '${token}' passed to call command.`,
`If you intended to pass a tool argument, use '${token.slice(2)}=<value>' or --args '{"${token.slice(2)}": ...}'.`,
"If you intended to pass a literal positional value, insert '--' before it.",
"Run 'mcporter call --help' to see available flags.",
].join('\n')
);
throw new CliUsageError(buildUnknownCallFlagMessage(token));
}
positional.push(token);
index += 1;
}
return { positional, literalPositional };
}
let callExpressionProvidedServer = false;
let callExpressionProvidedTool = false;
if (positional.length > 0) {
const rawToken = positional[0] ?? '';
const callExpression = parseLeadingCallExpression(rawToken);
if (callExpression) {
positional.shift();
callExpressionProvidedServer = Boolean(callExpression.server);
callExpressionProvidedTool = Boolean(callExpression.tool);
if (callExpression.server) {
if (result.server && result.server !== callExpression.server) {
throw new Error(
`Conflicting server names: '${result.server}' from flags and '${callExpression.server}' from call expression.`
);
}
result.server = result.server ?? callExpression.server;
}
if (result.tool && result.tool !== callExpression.tool) {
throw new Error(
`Conflicting tool names: '${result.tool}' from flags and '${callExpression.tool}' from call expression.`
);
}
result.tool = callExpression.tool;
Object.assign(result.args, callExpression.args);
if (callExpression.positionalArgs && callExpression.positionalArgs.length > 0) {
result.positionalArgs = [...(result.positionalArgs ?? []), ...callExpression.positionalArgs];
}
}
function applyLeadingCallExpression(positional: string[], result: CallArgsParseResult): CallExpressionResolution {
if (positional.length === 0) {
return { callExpressionProvidedServer: false, callExpressionProvidedTool: false };
}
const rawToken = positional[0] ?? '';
const callExpression = parseLeadingCallExpression(rawToken);
if (!callExpression) {
return { callExpressionProvidedServer: false, callExpressionProvidedTool: false };
}
positional.shift();
if (callExpression.server) {
if (result.server && result.server !== callExpression.server) {
throw new Error(
`Conflicting server names: '${result.server}' from flags and '${callExpression.server}' from call expression.`
);
}
result.server = result.server ?? callExpression.server;
}
if (result.tool && result.tool !== callExpression.tool) {
throw new Error(
`Conflicting tool names: '${result.tool}' from flags and '${callExpression.tool}' from call expression.`
);
}
result.tool = callExpression.tool;
Object.assign(result.args, callExpression.args);
if (callExpression.positionalArgs && callExpression.positionalArgs.length > 0) {
result.positionalArgs = [...(result.positionalArgs ?? []), ...callExpression.positionalArgs];
}
return {
callExpressionProvidedServer: Boolean(callExpression.server),
callExpressionProvidedTool: Boolean(callExpression.tool),
};
}
function resolveSelectorAndTool(
positional: string[],
result: CallArgsParseResult,
callExpressionProvidedServer: boolean,
callExpressionProvidedTool: boolean
): void {
if (!result.selector && positional.length > 0 && !callExpressionProvidedServer && !result.server) {
result.selector = positional.shift();
}
if (
!result.server &&
result.selector &&
shouldPromoteSelectorToCommand(result.selector) &&
!result.ephemeral?.stdioCommand
) {
// Treat the first positional token as an ad-hoc stdio command when it looks like
// `npx ...`/`./script`/etc., so users can skip `--stdio` entirely.
result.ephemeral = { ...result.ephemeral, stdioCommand: result.selector };
result.selector = undefined;
}
const nextPositional = positional[0];
if (
!result.tool &&
@ -149,7 +169,9 @@ export function parseCallArguments(args: string[]): CallArgsParseResult {
) {
result.tool = positional.shift();
}
}
function applyTrailingArguments(positional: string[], result: CallArgsParseResult, state: FlagParseState): void {
const trailingPositional: unknown[] = [];
for (let index = 0; index < positional.length; ) {
const token = positional[index];
@ -159,12 +181,12 @@ export function parseCallArguments(args: string[]): CallArgsParseResult {
}
const parsed = parseKeyValueToken(token, positional[index + 1]);
if (!parsed) {
trailingPositional.push(coerceValue(token, flagState.coercionMode));
trailingPositional.push(coerceValue(token, state.coercionMode));
index += 1;
continue;
}
index += parsed.consumed;
const value = coerceValue(parsed.rawValue, flagState.coercionMode);
const value = coerceValue(parsed.rawValue, state.coercionMode);
if (parsed.key === 'tool' && !result.tool) {
if (typeof value !== 'string') {
throw new Error("Argument 'tool' must be a string value.");
@ -184,13 +206,20 @@ export function parseCallArguments(args: string[]): CallArgsParseResult {
if (trailingPositional.length > 0) {
result.positionalArgs = [...(result.positionalArgs ?? []), ...trailingPositional];
}
if (literalPositional.length > 0) {
result.positionalArgs = [
...(result.positionalArgs ?? []),
...literalPositional.map((token) => coerceValue(token, flagState.coercionMode)),
];
}
function appendLiteralPositionalArguments(
literalPositional: string[],
result: CallArgsParseResult,
state: FlagParseState
): void {
if (literalPositional.length === 0) {
return;
}
return result;
result.positionalArgs = [
...(result.positionalArgs ?? []),
...literalPositional.map((token) => coerceValue(token, state.coercionMode)),
];
}
function handleServerFlag(context: FlagHandlerContext): number {

View File

@ -1,6 +1,12 @@
import { analyzeConnectionError, type ConnectionIssue } from '../error-classifier.js';
import { wrapCallResult } from '../result-utils.js';
import { type CallArgsParseResult, parseCallArguments } from './call-arguments.js';
import {
CALL_HELP_ADHOC_SERVER_LINES,
CALL_HELP_ARGUMENT_LINES,
CALL_HELP_EXAMPLE_LINES,
CALL_HELP_RUNTIME_FLAG_LINES,
} from './call-help.js';
import { prepareEphemeralServerTarget } from './ephemeral-target.js';
import { looksLikeHttpUrl, normalizeHttpUrlCandidate } from './http-utils.js';
import type { IdentifierResolution } from './identifier-helpers.js';
@ -170,37 +176,16 @@ export function printCallHelp(): void {
' --tool <name> Override the tool name.',
'',
'Arguments:',
' key=value / key:value Flag-style named arguments.',
' function-call syntax \'server.tool(arg: "value", other: 1)\'.',
' --args <json> Provide a JSON object payload.',
' positional values Accepted when schema order is known.',
' -- Treat remaining tokens as literal positional values.',
...CALL_HELP_ARGUMENT_LINES,
'',
'Runtime flags:',
' --timeout <ms> Override the call timeout.',
' --output text|markdown|json|raw Control formatting.',
' --save-images <dir> Save image content blocks to a directory.',
' --raw-strings Keep numeric-looking argument values as strings.',
' --no-coerce Keep all key/value and positional arguments as raw strings.',
' --tail-log Stream returned log handles.',
...CALL_HELP_RUNTIME_FLAG_LINES,
'',
'Ad-hoc servers:',
' --http-url <url> Register an HTTP server for this run.',
' --allow-http Permit plain http:// URLs with --http-url.',
' --stdio <command> Run a stdio MCP server (repeat --stdio-arg for args).',
' --stdio-arg <value> Append args to the stdio command (repeatable).',
' --env KEY=value Inject env vars for stdio servers (repeatable).',
' --cwd <path> Working directory for stdio servers.',
' --name <value> Override the display name for ad-hoc servers.',
' --description <text> Override the description for ad-hoc servers.',
' --persist <path> Write the ad-hoc definition to config/mcporter.json.',
' --yes Skip confirmation prompts when persisting.',
...CALL_HELP_ADHOC_SERVER_LINES,
'',
'Examples:',
' mcporter call linear.list_issues team=ENG limit:5',
' mcporter call "linear.create_issue(title: \\"Bug\\", team: \\"ENG\\")"',
' mcporter call https://api.example.com/mcp.fetch url:https://example.com',
' mcporter call --stdio "bun run ./server.ts" scrape url=https://example.com',
...CALL_HELP_EXAMPLE_LINES,
];
console.error(lines.join('\n'));
}

46
src/cli/call-help.ts Normal file
View File

@ -0,0 +1,46 @@
export const CALL_HELP_ARGUMENT_LINES = [
' key=value / key:value Flag-style named arguments.',
' function-call syntax \'server.tool(arg: "value", other: 1)\'.',
' --args <json> Provide a JSON object payload.',
' positional values Accepted when schema order is known.',
' -- Treat remaining tokens as literal positional values.',
] as const;
export const CALL_HELP_RUNTIME_FLAG_LINES = [
' --timeout <ms> Override the call timeout.',
' --output text|markdown|json|raw Control formatting.',
' --save-images <dir> Save image content blocks to a directory.',
' --raw-strings Keep numeric-looking argument values as strings.',
' --no-coerce Keep all key/value and positional arguments as raw strings.',
' --tail-log Stream returned log handles.',
] as const;
export const CALL_HELP_ADHOC_SERVER_LINES = [
' --http-url <url> Register an HTTP server for this run.',
' --allow-http Permit plain http:// URLs with --http-url.',
' --stdio <command> Run a stdio MCP server (repeat --stdio-arg for args).',
' --stdio-arg <value> Append args to the stdio command (repeatable).',
' --env KEY=value Inject env vars for stdio servers (repeatable).',
' --cwd <path> Working directory for stdio servers.',
' --name <value> Override the display name for ad-hoc servers.',
' --description <text> Override the description for ad-hoc servers.',
' --persist <path> Write the ad-hoc definition to config/mcporter.json.',
' --yes Skip confirmation prompts when persisting.',
] as const;
export const CALL_HELP_EXAMPLE_LINES = [
' mcporter call linear.list_issues team=ENG limit:5',
' mcporter call "linear.create_issue(title: \\"Bug\\", team: \\"ENG\\")"',
' mcporter call https://api.example.com/mcp.fetch url:https://example.com',
' mcporter call --stdio "bun run ./server.ts" scrape url=https://example.com',
] as const;
export function buildUnknownCallFlagMessage(token: string): string {
const argumentName = token.startsWith('--') ? token.slice(2) : token;
return [
`Unknown flag '${token}' passed to call command.`,
`If you intended to pass a tool argument, use '${argumentName}=<value>' or --args '{"${argumentName}": ...}'.`,
"If you intended to pass a literal positional value, insert '--' before it.",
"Run 'mcporter call --help' to see available flags.",
].join('\n');
}

View File

@ -20,6 +20,12 @@ export interface ConnectWithAuthOptions {
recreateTransport?: (transport: OAuthCapableTransport) => Promise<OAuthCapableTransport>;
}
interface OAuthConnectState {
activeTransport: OAuthCapableTransport;
attempt: number;
hasCompletedAuthFlow: boolean;
}
export class OAuthTimeoutError extends Error {
public readonly timeoutMs: number;
public readonly serverName: string;
@ -78,47 +84,60 @@ export async function connectWithAuth(
options: ConnectWithAuthOptions = {}
): Promise<OAuthCapableTransport> {
const { serverName, maxAttempts = 3, oauthTimeoutMs = DEFAULT_OAUTH_CODE_TIMEOUT_MS, recreateTransport } = options;
let activeTransport = transport;
let attempt = 0;
let hasCompletedAuthFlow = false;
const state: OAuthConnectState = {
activeTransport: transport,
attempt: 0,
hasCompletedAuthFlow: false,
};
while (true) {
try {
await client.connect(activeTransport);
return activeTransport;
return await attemptTransportConnect(client, state);
} catch (error) {
const unauthorized = isUnauthorizedError(error);
if (hasCompletedAuthFlow && !unauthorized) {
await closeReplacementTransport(transport, activeTransport);
throw markPostAuthConnectError(error);
if (!shouldRetryAuthorization(state, unauthorized, session)) {
await closeReplacementTransport(transport, state.activeTransport);
throw state.hasCompletedAuthFlow && !unauthorized ? markPostAuthConnectError(error) : error;
}
if (!unauthorized || !session) {
await closeReplacementTransport(transport, activeTransport);
throw error;
}
attempt += 1;
if (attempt > maxAttempts) {
await closeReplacementTransport(transport, activeTransport);
throw hasCompletedAuthFlow ? markPostAuthConnectError(error) : error;
state.attempt += 1;
if (state.attempt > maxAttempts) {
await closeReplacementTransport(transport, state.activeTransport);
throw state.hasCompletedAuthFlow ? markPostAuthConnectError(error) : error;
}
logger.warn(`OAuth authorization required for '${serverName ?? 'unknown'}'. Waiting for browser approval...`);
try {
activeTransport = await completeAuthorizationChallenge(activeTransport, session, logger, error, {
state.activeTransport = await completeAuthorizationChallenge(state.activeTransport, session, logger, error, {
serverName,
oauthTimeoutMs,
recreateTransport,
});
hasCompletedAuthFlow = true;
state.hasCompletedAuthFlow = true;
logger.info('Authorization code accepted. Retrying connection...');
} catch (authError) {
logger.error('OAuth authorization failed while waiting for callback.', authError);
await closeReplacementTransport(transport, activeTransport);
await closeReplacementTransport(transport, state.activeTransport);
throw markOAuthFlowError(authError);
}
}
}
}
async function attemptTransportConnect(client: Client, state: OAuthConnectState): Promise<OAuthCapableTransport> {
await client.connect(state.activeTransport);
return state.activeTransport;
}
function shouldRetryAuthorization(
_state: OAuthConnectState,
unauthorized: boolean,
session: OAuthSession | undefined
): session is OAuthSession {
if (!session || !unauthorized) {
return false;
}
return true;
}
async function closeReplacementTransport(
originalTransport: OAuthCapableTransport,
activeTransport: OAuthCapableTransport

View File

@ -59,6 +59,10 @@ interface ResolvedHttpTransportOptions {
authProvider?: OAuthSession['provider'];
}
type HttpClientContextAttempt =
| { context: ClientContext; nextDefinition?: undefined }
| { context?: undefined; nextDefinition: ServerDefinition };
function attachStdioTraceLogging(_transport: StdioClientTransport, _label?: string): void {
// STDIO instrumentation is handled via sdk-patches side effects. This helper remains
// so runtime callers can opt-in without sprinkling conditional checks everywhere.
@ -144,6 +148,216 @@ async function connectHttpTransport<TTransport extends OAuthCapableTransport>(
}
}
async function applyCachedOAuthHeaderIfAvailable(
definition: ServerDefinition,
logger: Logger,
allowCachedAuth: boolean | undefined
): Promise<ServerDefinition> {
if (!allowCachedAuth || definition.auth !== 'oauth' || definition.command.kind !== 'http') {
return definition;
}
try {
const cached = await readCachedAccessToken(definition, logger);
if (!cached) {
return definition;
}
const existingHeaders = definition.command.headers ?? {};
if ('Authorization' in existingHeaders) {
return definition;
}
logger.debug?.(`Using cached OAuth access token for '${definition.name}' (non-interactive).`);
return {
...definition,
command: {
...definition.command,
headers: {
...existingHeaders,
Authorization: `Bearer ${cached}`,
},
},
};
} catch (error) {
logger.debug?.(
`Failed to read cached OAuth token for '${definition.name}': ${
error instanceof Error ? error.message : String(error)
}`
);
return definition;
}
}
async function createStdioClientContext(
client: Client,
definition: ServerDefinition & { command: Extract<ServerDefinition['command'], { kind: 'stdio' }> },
logger: Logger
): Promise<ClientContext> {
const resolvedEnvOverrides =
definition.env && Object.keys(definition.env).length > 0
? Object.fromEntries(
Object.entries(definition.env)
.map(([key, raw]) => [key, resolveEnvValue(raw)])
.filter(([, value]) => value !== '')
)
: undefined;
const mergedEnv =
resolvedEnvOverrides && Object.keys(resolvedEnvOverrides).length > 0
? { ...process.env, ...resolvedEnvOverrides }
: { ...process.env };
const transport = new StdioClientTransport({
command: resolveCommandArgument(definition.command.command),
args: resolveCommandArguments(definition.command.args),
cwd: definition.command.cwd,
env: mergedEnv,
});
if (STDIO_TRACE_ENABLED) {
attachStdioTraceLogging(transport, definition.name ?? definition.command.command);
}
try {
await client.connect(transport);
} catch (error) {
await closeTransportAndWait(logger, transport).catch(() => {});
throw error;
}
return { client, transport, definition, oauthSession: undefined };
}
async function retryHttpTransportWithFallback(
client: Client,
definition: ServerDefinition,
logger: Logger,
options: CreateClientContextOptions
): Promise<ClientContext> {
let activeDefinition = definition;
while (true) {
const attempt = await attemptHttpClientContext(client, activeDefinition, logger, options);
if (!attempt.nextDefinition) {
return attempt.context;
}
activeDefinition = attempt.nextDefinition;
options.onDefinitionPromoted?.(activeDefinition);
}
}
async function attemptHttpClientContext(
client: Client,
activeDefinition: ServerDefinition,
logger: Logger,
options: CreateClientContextOptions
): Promise<HttpClientContextAttempt> {
const command = activeDefinition.command;
if (command.kind !== 'http') {
throw new Error(`Server '${activeDefinition.name}' is not configured for HTTP transport.`);
}
let oauthSession: OAuthSession | undefined;
const shouldEstablishOAuth = activeDefinition.auth === 'oauth' && options.maxOAuthAttempts !== 0;
if (shouldEstablishOAuth) {
oauthSession = await createOAuthSession(activeDefinition, logger);
}
const transportOptions = createHttpTransportOptions(activeDefinition, oauthSession, shouldEstablishOAuth);
try {
const context = await connectPrimaryHttpTransport(
client,
activeDefinition,
command,
transportOptions,
oauthSession,
logger,
options
);
return { context };
} catch (primaryError) {
if (shouldAbortSseFallback(primaryError)) {
await closeOAuthSession(oauthSession);
throw primaryError;
}
if (isUnauthorizedError(primaryError)) {
await closeOAuthSession(oauthSession);
const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options);
if (promoted) {
return { nextDefinition: promoted };
}
oauthSession = undefined;
}
if (primaryError instanceof Error) {
logger.info(`Falling back to SSE transport for '${activeDefinition.name}': ${primaryError.message}`);
}
return {
context: await connectSseFallbackTransport(
client,
activeDefinition,
command,
transportOptions,
oauthSession,
logger,
options
),
};
}
}
async function connectPrimaryHttpTransport(
client: Client,
definition: ServerDefinition,
command: Extract<ServerDefinition['command'], { kind: 'http' }>,
transportOptions: ResolvedHttpTransportOptions,
oauthSession: OAuthSession | undefined,
logger: Logger,
options: CreateClientContextOptions
): Promise<ClientContext> {
const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, transportOptions);
const transport = await connectHttpTransport(client, createStreamableTransport(), oauthSession, logger, {
serverName: definition.name,
maxAttempts: options.maxOAuthAttempts,
oauthTimeoutMs: options.oauthTimeoutMs,
recreateTransport: async () => createStreamableTransport(),
});
return {
client,
transport,
definition,
oauthSession,
};
}
async function connectSseFallbackTransport(
client: Client,
definition: ServerDefinition,
command: Extract<ServerDefinition['command'], { kind: 'http' }>,
transportOptions: ResolvedHttpTransportOptions,
oauthSession: OAuthSession | undefined,
logger: Logger,
options: CreateClientContextOptions
): Promise<ClientContext> {
try {
const transport = await connectHttpTransport(
client,
new SSEClientTransport(command.url, transportOptions),
oauthSession,
logger,
{
serverName: definition.name,
maxAttempts: options.maxOAuthAttempts,
oauthTimeoutMs: options.oauthTimeoutMs,
}
);
return { client, transport, definition, oauthSession };
} catch (sseError) {
await closeOAuthSession(oauthSession);
if (sseError instanceof OAuthTimeoutError) {
throw sseError;
}
if (isUnauthorizedError(sseError)) {
const promoted = maybePromoteHttpDefinition(definition, logger, options);
if (promoted) {
options.onDefinitionPromoted?.(promoted);
return retryHttpTransportWithFallback(client, promoted, logger, options);
}
}
throw sseError;
}
}
export async function createClientContext(
definition: ServerDefinition,
logger: Logger,
@ -151,147 +365,16 @@ export async function createClientContext(
options: CreateClientContextOptions = {}
): Promise<ClientContext> {
const client = new Client(clientInfo);
let activeDefinition = definition;
if (options.allowCachedAuth && activeDefinition.auth === 'oauth' && activeDefinition.command.kind === 'http') {
try {
const cached = await readCachedAccessToken(activeDefinition, logger);
if (cached) {
const existingHeaders = activeDefinition.command.headers ?? {};
if (!('Authorization' in existingHeaders)) {
activeDefinition = {
...activeDefinition,
command: {
...activeDefinition.command,
headers: {
...existingHeaders,
Authorization: `Bearer ${cached}`,
},
},
};
logger.debug?.(`Using cached OAuth access token for '${activeDefinition.name}' (non-interactive).`);
}
}
} catch (error) {
logger.debug?.(
`Failed to read cached OAuth token for '${activeDefinition.name}': ${
error instanceof Error ? error.message : String(error)
}`
);
}
}
const activeDefinition = await applyCachedOAuthHeaderIfAvailable(definition, logger, options.allowCachedAuth);
return withEnvOverrides(activeDefinition.env, async () => {
if (activeDefinition.command.kind === 'stdio') {
const resolvedEnvOverrides =
activeDefinition.env && Object.keys(activeDefinition.env).length > 0
? Object.fromEntries(
Object.entries(activeDefinition.env)
.map(([key, raw]) => [key, resolveEnvValue(raw)])
.filter(([, value]) => value !== '')
)
: undefined;
const mergedEnv =
resolvedEnvOverrides && Object.keys(resolvedEnvOverrides).length > 0
? { ...process.env, ...resolvedEnvOverrides }
: { ...process.env };
const transport = new StdioClientTransport({
command: resolveCommandArgument(activeDefinition.command.command),
args: resolveCommandArguments(activeDefinition.command.args),
cwd: activeDefinition.command.cwd,
env: mergedEnv,
});
if (STDIO_TRACE_ENABLED) {
attachStdioTraceLogging(transport, activeDefinition.name ?? activeDefinition.command.command);
}
try {
await client.connect(transport);
} catch (error) {
await closeTransportAndWait(logger, transport).catch(() => {});
throw error;
}
return { client, transport, definition: activeDefinition, oauthSession: undefined };
}
while (true) {
const command = activeDefinition.command;
if (command.kind !== 'http') {
throw new Error(`Server '${activeDefinition.name}' is not configured for HTTP transport.`);
}
let oauthSession: OAuthSession | undefined;
const shouldEstablishOAuth = activeDefinition.auth === 'oauth' && options.maxOAuthAttempts !== 0;
if (shouldEstablishOAuth) {
oauthSession = await createOAuthSession(activeDefinition, logger);
}
const transportOptions = createHttpTransportOptions(activeDefinition, oauthSession, shouldEstablishOAuth);
try {
const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, transportOptions);
const streamableTransport = await connectHttpTransport(
client,
createStreamableTransport(),
oauthSession,
logger,
{
serverName: activeDefinition.name,
maxAttempts: options.maxOAuthAttempts,
oauthTimeoutMs: options.oauthTimeoutMs,
recreateTransport: async () => createStreamableTransport(),
}
);
return {
client,
transport: streamableTransport,
definition: activeDefinition,
oauthSession,
};
} catch (primaryError) {
if (shouldAbortSseFallback(primaryError)) {
await closeOAuthSession(oauthSession);
throw primaryError;
}
if (isUnauthorizedError(primaryError)) {
await closeOAuthSession(oauthSession);
oauthSession = undefined;
const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options);
if (promoted) {
activeDefinition = promoted;
options.onDefinitionPromoted?.(promoted);
continue;
}
}
if (primaryError instanceof Error) {
logger.info(`Falling back to SSE transport for '${activeDefinition.name}': ${primaryError.message}`);
}
try {
const connectedTransport = await connectHttpTransport(
client,
new SSEClientTransport(command.url, transportOptions),
oauthSession,
logger,
{
serverName: activeDefinition.name,
maxAttempts: options.maxOAuthAttempts,
oauthTimeoutMs: options.oauthTimeoutMs,
}
);
return { client, transport: connectedTransport, definition: activeDefinition, oauthSession };
} catch (sseError) {
await closeOAuthSession(oauthSession);
if (sseError instanceof OAuthTimeoutError) {
throw sseError;
}
if (isUnauthorizedError(sseError)) {
const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options);
if (promoted) {
activeDefinition = promoted;
options.onDefinitionPromoted?.(promoted);
continue;
}
}
throw sseError;
}
}
return createStdioClientContext(
client,
activeDefinition as ServerDefinition & { command: Extract<ServerDefinition['command'], { kind: 'stdio' }> },
logger
);
}
return retryHttpTransportWithFallback(client, activeDefinition, logger, options);
});
}

View File

@ -0,0 +1,112 @@
import type { Transport, TransportSendOptions } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import { vi } from 'vitest';
import type { ServerDefinition } from '../../src/config.js';
import type { Logger } from '../../src/logging.js';
import type { OAuthSession } from '../../src/oauth.js';
export const clientInfo = { name: 'mcporter', version: '0.0.0-test' };
export interface LoggerSpy extends Logger {
info: ReturnType<typeof vi.fn<(message: string) => void>>;
warn: ReturnType<typeof vi.fn<(message: string) => void>>;
error: ReturnType<typeof vi.fn<(message: string, error?: unknown) => void>>;
}
export function createLogger(): LoggerSpy {
return {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
};
}
export function resetLogger(logger: LoggerSpy): void {
logger.info.mockReset();
logger.warn.mockReset();
logger.error.mockReset();
}
export function stubHttpDefinition(url: string): ServerDefinition {
return {
name: 'http-server',
command: { kind: 'http', url: new URL(url) },
source: { kind: 'local', path: '<adhoc>' },
};
}
export function stubOAuthHttpDefinition(url: string): ServerDefinition {
return {
...stubHttpDefinition(url),
auth: 'oauth',
};
}
export function createPromotionRecorder() {
const promotedDefinitions: ServerDefinition[] = [];
return {
promotedDefinitions,
onDefinitionPromoted: (promoted: ServerDefinition) => {
promotedDefinitions.push(promoted);
},
};
}
export function createMockOAuthSession(): OAuthSession {
return {
provider: {
waitForAuthorizationCode: vi.fn(),
} as unknown as OAuthSession['provider'],
waitForAuthorizationCode: vi.fn(),
close: vi.fn(async () => {}),
};
}
export class MockTransport implements Transport {
public readonly calls: string[] = [];
public readonly close = vi.fn(async () => {});
constructor(private readonly finishAuthImpl?: (code: string) => Promise<void>) {}
async start(): Promise<void> {}
async send(_message: JSONRPCMessage, _options?: TransportSendOptions): Promise<void> {}
async finishAuth(code: string): Promise<void> {
this.calls.push(code);
if (this.finishAuthImpl) {
await this.finishAuthImpl(code);
}
}
}
export function createPendingAuthorizationSession() {
const pendingResolvers: Array<(code: string) => void> = [];
const waitForAuthorizationCode = vi.fn(
() =>
new Promise<string>((resolve) => {
pendingResolvers.push(resolve);
})
);
const session: OAuthSession = {
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
waitForAuthorizationCode,
close: vi.fn(async () => {}),
};
return {
session,
waitForAuthorizationCode,
pendingResolvers,
resolveNextCode: (code: string) => {
const resolve = pendingResolvers.shift();
if (!resolve) {
throw new Error(`Missing pending authorization resolver for '${code}'.`);
}
resolve(code);
},
};
}
export async function flushAuthLoop(): Promise<void> {
await new Promise((resolve) => setImmediate(resolve));
}

View File

@ -1,65 +1,13 @@
import type { Client } from '@modelcontextprotocol/sdk/client';
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js';
import type { Transport, TransportSendOptions } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import { describe, expect, it, vi } from 'vitest';
import type { Logger } from '../src/logging.js';
import type { OAuthSession } from '../src/oauth.js';
import { connectWithAuth, isOAuthFlowError, isPostAuthConnectError } from '../src/runtime/oauth.js';
class MockTransport implements Transport {
public readonly calls: string[] = [];
public readonly close = vi.fn(async () => {});
constructor(private readonly finishAuthImpl?: (code: string) => Promise<void>) {}
async start(): Promise<void> {}
async send(_message: JSONRPCMessage, _options?: TransportSendOptions): Promise<void> {}
async finishAuth(code: string): Promise<void> {
this.calls.push(code);
if (this.finishAuthImpl) {
await this.finishAuthImpl(code);
}
}
}
function createLogger(): Logger {
return {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
};
}
function createPendingAuthorizationSession() {
const pendingResolvers: Array<(code: string) => void> = [];
const waitForAuthorizationCode = vi.fn(
() =>
new Promise<string>((resolve) => {
pendingResolvers.push(resolve);
})
);
const session: OAuthSession = {
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
waitForAuthorizationCode,
close: vi.fn(async () => {}),
};
return {
session,
waitForAuthorizationCode,
pendingResolvers,
resolveNextCode: (code: string) => {
const resolve = pendingResolvers.shift();
if (!resolve) {
throw new Error(`Missing pending authorization resolver for '${code}'.`);
}
resolve(code);
},
};
}
async function flushAuthLoop(): Promise<void> {
await new Promise((resolve) => setImmediate(resolve));
}
import {
createLogger,
createPendingAuthorizationSession,
flushAuthLoop,
MockTransport,
} from './helpers/runtime-test-helpers.js';
describe('connectWithAuth', () => {
it('waits for authorization code and retries connection', async () => {

View File

@ -28,60 +28,33 @@ import type { ServerDefinition } from '../src/config.js';
import * as oauthModule from '../src/oauth.js';
import { markOAuthFlowError, markPostAuthConnectError } from '../src/runtime/oauth.js';
import { createClientContext } from '../src/runtime/transport.js';
import {
clientInfo,
createLogger,
createMockOAuthSession,
createPromotionRecorder,
resetLogger,
stubHttpDefinition,
stubOAuthHttpDefinition,
} from './helpers/runtime-test-helpers.js';
const logger = {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
};
const clientInfo = { name: 'mcporter', version: '0.0.0-test' };
const logger = createLogger();
beforeEach(() => {
resetLogger(logger);
mocks.connectWithAuth.mockReset();
mocks.connectWithAuth.mockImplementation(async (client, transport) => {
await client.connect(transport);
return transport;
});
mocks.createOAuthSession.mockReset();
mocks.createOAuthSession.mockResolvedValue({
provider: {
waitForAuthorizationCode: vi.fn(),
},
waitForAuthorizationCode: vi.fn(),
close: vi.fn(async () => {}),
});
mocks.createOAuthSession.mockResolvedValue(createMockOAuthSession());
});
afterEach(() => {
vi.restoreAllMocks();
});
function stubHttpDefinition(url: string): ServerDefinition {
return {
name: 'http-server',
command: { kind: 'http', url: new URL(url) },
source: { kind: 'local', path: '<adhoc>' },
};
}
function stubOAuthHttpDefinition(url: string): ServerDefinition {
return {
...stubHttpDefinition(url),
auth: 'oauth',
};
}
function createPromotionRecorder() {
const promotedDefinitions: ServerDefinition[] = [];
return {
promotedDefinitions,
onDefinitionPromoted: (promoted: ServerDefinition) => {
promotedDefinitions.push(promoted);
},
};
}
describe('createClientContext (HTTP)', () => {
it('falls back to SSE when primary connect fails', async () => {
const definition = stubHttpDefinition('https://example.com/mcp');