refactor: split call and runtime helpers
This commit is contained in:
parent
36b2584c17
commit
83cc3b9a4c
@ -6,6 +6,7 @@ import {
|
||||
parseKeyValueToken,
|
||||
shouldPromoteSelectorToCommand,
|
||||
} from './call-argument-values.js';
|
||||
import { buildUnknownCallFlagMessage } from './call-help.js';
|
||||
import { extractEphemeralServerFlags } from './ephemeral-flags.js';
|
||||
import { CliUsageError } from './errors.js';
|
||||
import { consumeOutputFormat } from './output-format.js';
|
||||
@ -39,6 +40,16 @@ interface FlagHandlerContext {
|
||||
|
||||
type FlagHandler = (context: FlagHandlerContext) => number;
|
||||
|
||||
interface ScannedCallTokens {
|
||||
positional: string[];
|
||||
literalPositional: string[];
|
||||
}
|
||||
|
||||
interface CallExpressionResolution {
|
||||
callExpressionProvidedServer: boolean;
|
||||
callExpressionProvidedTool: boolean;
|
||||
}
|
||||
|
||||
const FLAG_HANDLERS: Record<string, FlagHandler> = {
|
||||
'--server': handleServerFlag,
|
||||
'--mcp': handleServerFlag,
|
||||
@ -60,6 +71,15 @@ export function parseCallArguments(args: string[]): CallArgsParseResult {
|
||||
result.output = consumeOutputFormat(args, {
|
||||
defaultFormat: 'auto',
|
||||
});
|
||||
const { positional, literalPositional } = scanCallTokens(args, result, flagState);
|
||||
const { callExpressionProvidedServer, callExpressionProvidedTool } = applyLeadingCallExpression(positional, result);
|
||||
resolveSelectorAndTool(positional, result, callExpressionProvidedServer, callExpressionProvidedTool);
|
||||
applyTrailingArguments(positional, result, flagState);
|
||||
appendLiteralPositionalArguments(literalPositional, result, flagState);
|
||||
return result;
|
||||
}
|
||||
|
||||
function scanCallTokens(args: string[], result: CallArgsParseResult, state: FlagParseState): ScannedCallTokens {
|
||||
const positional: string[] = [];
|
||||
const literalPositional: string[] = [];
|
||||
let index = 0;
|
||||
@ -75,70 +95,70 @@ export function parseCallArguments(args: string[]): CallArgsParseResult {
|
||||
}
|
||||
const flagHandler = FLAG_HANDLERS[token];
|
||||
if (flagHandler) {
|
||||
index = flagHandler({ args, index, result, state: flagState });
|
||||
index = flagHandler({ args, index, result, state });
|
||||
continue;
|
||||
}
|
||||
if (token.startsWith('--')) {
|
||||
throw new CliUsageError(
|
||||
[
|
||||
`Unknown flag '${token}' passed to call command.`,
|
||||
`If you intended to pass a tool argument, use '${token.slice(2)}=<value>' or --args '{"${token.slice(2)}": ...}'.`,
|
||||
"If you intended to pass a literal positional value, insert '--' before it.",
|
||||
"Run 'mcporter call --help' to see available flags.",
|
||||
].join('\n')
|
||||
);
|
||||
throw new CliUsageError(buildUnknownCallFlagMessage(token));
|
||||
}
|
||||
positional.push(token);
|
||||
index += 1;
|
||||
}
|
||||
return { positional, literalPositional };
|
||||
}
|
||||
|
||||
let callExpressionProvidedServer = false;
|
||||
let callExpressionProvidedTool = false;
|
||||
|
||||
if (positional.length > 0) {
|
||||
const rawToken = positional[0] ?? '';
|
||||
const callExpression = parseLeadingCallExpression(rawToken);
|
||||
if (callExpression) {
|
||||
positional.shift();
|
||||
callExpressionProvidedServer = Boolean(callExpression.server);
|
||||
callExpressionProvidedTool = Boolean(callExpression.tool);
|
||||
if (callExpression.server) {
|
||||
if (result.server && result.server !== callExpression.server) {
|
||||
throw new Error(
|
||||
`Conflicting server names: '${result.server}' from flags and '${callExpression.server}' from call expression.`
|
||||
);
|
||||
}
|
||||
result.server = result.server ?? callExpression.server;
|
||||
}
|
||||
if (result.tool && result.tool !== callExpression.tool) {
|
||||
throw new Error(
|
||||
`Conflicting tool names: '${result.tool}' from flags and '${callExpression.tool}' from call expression.`
|
||||
);
|
||||
}
|
||||
result.tool = callExpression.tool;
|
||||
Object.assign(result.args, callExpression.args);
|
||||
if (callExpression.positionalArgs && callExpression.positionalArgs.length > 0) {
|
||||
result.positionalArgs = [...(result.positionalArgs ?? []), ...callExpression.positionalArgs];
|
||||
}
|
||||
}
|
||||
function applyLeadingCallExpression(positional: string[], result: CallArgsParseResult): CallExpressionResolution {
|
||||
if (positional.length === 0) {
|
||||
return { callExpressionProvidedServer: false, callExpressionProvidedTool: false };
|
||||
}
|
||||
const rawToken = positional[0] ?? '';
|
||||
const callExpression = parseLeadingCallExpression(rawToken);
|
||||
if (!callExpression) {
|
||||
return { callExpressionProvidedServer: false, callExpressionProvidedTool: false };
|
||||
}
|
||||
positional.shift();
|
||||
if (callExpression.server) {
|
||||
if (result.server && result.server !== callExpression.server) {
|
||||
throw new Error(
|
||||
`Conflicting server names: '${result.server}' from flags and '${callExpression.server}' from call expression.`
|
||||
);
|
||||
}
|
||||
result.server = result.server ?? callExpression.server;
|
||||
}
|
||||
if (result.tool && result.tool !== callExpression.tool) {
|
||||
throw new Error(
|
||||
`Conflicting tool names: '${result.tool}' from flags and '${callExpression.tool}' from call expression.`
|
||||
);
|
||||
}
|
||||
result.tool = callExpression.tool;
|
||||
Object.assign(result.args, callExpression.args);
|
||||
if (callExpression.positionalArgs && callExpression.positionalArgs.length > 0) {
|
||||
result.positionalArgs = [...(result.positionalArgs ?? []), ...callExpression.positionalArgs];
|
||||
}
|
||||
return {
|
||||
callExpressionProvidedServer: Boolean(callExpression.server),
|
||||
callExpressionProvidedTool: Boolean(callExpression.tool),
|
||||
};
|
||||
}
|
||||
|
||||
function resolveSelectorAndTool(
|
||||
positional: string[],
|
||||
result: CallArgsParseResult,
|
||||
callExpressionProvidedServer: boolean,
|
||||
callExpressionProvidedTool: boolean
|
||||
): void {
|
||||
if (!result.selector && positional.length > 0 && !callExpressionProvidedServer && !result.server) {
|
||||
result.selector = positional.shift();
|
||||
}
|
||||
|
||||
if (
|
||||
!result.server &&
|
||||
result.selector &&
|
||||
shouldPromoteSelectorToCommand(result.selector) &&
|
||||
!result.ephemeral?.stdioCommand
|
||||
) {
|
||||
// Treat the first positional token as an ad-hoc stdio command when it looks like
|
||||
// `npx ...`/`./script`/etc., so users can skip `--stdio` entirely.
|
||||
result.ephemeral = { ...result.ephemeral, stdioCommand: result.selector };
|
||||
result.selector = undefined;
|
||||
}
|
||||
|
||||
const nextPositional = positional[0];
|
||||
if (
|
||||
!result.tool &&
|
||||
@ -149,7 +169,9 @@ export function parseCallArguments(args: string[]): CallArgsParseResult {
|
||||
) {
|
||||
result.tool = positional.shift();
|
||||
}
|
||||
}
|
||||
|
||||
function applyTrailingArguments(positional: string[], result: CallArgsParseResult, state: FlagParseState): void {
|
||||
const trailingPositional: unknown[] = [];
|
||||
for (let index = 0; index < positional.length; ) {
|
||||
const token = positional[index];
|
||||
@ -159,12 +181,12 @@ export function parseCallArguments(args: string[]): CallArgsParseResult {
|
||||
}
|
||||
const parsed = parseKeyValueToken(token, positional[index + 1]);
|
||||
if (!parsed) {
|
||||
trailingPositional.push(coerceValue(token, flagState.coercionMode));
|
||||
trailingPositional.push(coerceValue(token, state.coercionMode));
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
index += parsed.consumed;
|
||||
const value = coerceValue(parsed.rawValue, flagState.coercionMode);
|
||||
const value = coerceValue(parsed.rawValue, state.coercionMode);
|
||||
if (parsed.key === 'tool' && !result.tool) {
|
||||
if (typeof value !== 'string') {
|
||||
throw new Error("Argument 'tool' must be a string value.");
|
||||
@ -184,13 +206,20 @@ export function parseCallArguments(args: string[]): CallArgsParseResult {
|
||||
if (trailingPositional.length > 0) {
|
||||
result.positionalArgs = [...(result.positionalArgs ?? []), ...trailingPositional];
|
||||
}
|
||||
if (literalPositional.length > 0) {
|
||||
result.positionalArgs = [
|
||||
...(result.positionalArgs ?? []),
|
||||
...literalPositional.map((token) => coerceValue(token, flagState.coercionMode)),
|
||||
];
|
||||
}
|
||||
|
||||
function appendLiteralPositionalArguments(
|
||||
literalPositional: string[],
|
||||
result: CallArgsParseResult,
|
||||
state: FlagParseState
|
||||
): void {
|
||||
if (literalPositional.length === 0) {
|
||||
return;
|
||||
}
|
||||
return result;
|
||||
result.positionalArgs = [
|
||||
...(result.positionalArgs ?? []),
|
||||
...literalPositional.map((token) => coerceValue(token, state.coercionMode)),
|
||||
];
|
||||
}
|
||||
|
||||
function handleServerFlag(context: FlagHandlerContext): number {
|
||||
|
||||
@ -1,6 +1,12 @@
|
||||
import { analyzeConnectionError, type ConnectionIssue } from '../error-classifier.js';
|
||||
import { wrapCallResult } from '../result-utils.js';
|
||||
import { type CallArgsParseResult, parseCallArguments } from './call-arguments.js';
|
||||
import {
|
||||
CALL_HELP_ADHOC_SERVER_LINES,
|
||||
CALL_HELP_ARGUMENT_LINES,
|
||||
CALL_HELP_EXAMPLE_LINES,
|
||||
CALL_HELP_RUNTIME_FLAG_LINES,
|
||||
} from './call-help.js';
|
||||
import { prepareEphemeralServerTarget } from './ephemeral-target.js';
|
||||
import { looksLikeHttpUrl, normalizeHttpUrlCandidate } from './http-utils.js';
|
||||
import type { IdentifierResolution } from './identifier-helpers.js';
|
||||
@ -170,37 +176,16 @@ export function printCallHelp(): void {
|
||||
' --tool <name> Override the tool name.',
|
||||
'',
|
||||
'Arguments:',
|
||||
' key=value / key:value Flag-style named arguments.',
|
||||
' function-call syntax \'server.tool(arg: "value", other: 1)\'.',
|
||||
' --args <json> Provide a JSON object payload.',
|
||||
' positional values Accepted when schema order is known.',
|
||||
' -- Treat remaining tokens as literal positional values.',
|
||||
...CALL_HELP_ARGUMENT_LINES,
|
||||
'',
|
||||
'Runtime flags:',
|
||||
' --timeout <ms> Override the call timeout.',
|
||||
' --output text|markdown|json|raw Control formatting.',
|
||||
' --save-images <dir> Save image content blocks to a directory.',
|
||||
' --raw-strings Keep numeric-looking argument values as strings.',
|
||||
' --no-coerce Keep all key/value and positional arguments as raw strings.',
|
||||
' --tail-log Stream returned log handles.',
|
||||
...CALL_HELP_RUNTIME_FLAG_LINES,
|
||||
'',
|
||||
'Ad-hoc servers:',
|
||||
' --http-url <url> Register an HTTP server for this run.',
|
||||
' --allow-http Permit plain http:// URLs with --http-url.',
|
||||
' --stdio <command> Run a stdio MCP server (repeat --stdio-arg for args).',
|
||||
' --stdio-arg <value> Append args to the stdio command (repeatable).',
|
||||
' --env KEY=value Inject env vars for stdio servers (repeatable).',
|
||||
' --cwd <path> Working directory for stdio servers.',
|
||||
' --name <value> Override the display name for ad-hoc servers.',
|
||||
' --description <text> Override the description for ad-hoc servers.',
|
||||
' --persist <path> Write the ad-hoc definition to config/mcporter.json.',
|
||||
' --yes Skip confirmation prompts when persisting.',
|
||||
...CALL_HELP_ADHOC_SERVER_LINES,
|
||||
'',
|
||||
'Examples:',
|
||||
' mcporter call linear.list_issues team=ENG limit:5',
|
||||
' mcporter call "linear.create_issue(title: \\"Bug\\", team: \\"ENG\\")"',
|
||||
' mcporter call https://api.example.com/mcp.fetch url:https://example.com',
|
||||
' mcporter call --stdio "bun run ./server.ts" scrape url=https://example.com',
|
||||
...CALL_HELP_EXAMPLE_LINES,
|
||||
];
|
||||
console.error(lines.join('\n'));
|
||||
}
|
||||
|
||||
46
src/cli/call-help.ts
Normal file
46
src/cli/call-help.ts
Normal file
@ -0,0 +1,46 @@
|
||||
export const CALL_HELP_ARGUMENT_LINES = [
|
||||
' key=value / key:value Flag-style named arguments.',
|
||||
' function-call syntax \'server.tool(arg: "value", other: 1)\'.',
|
||||
' --args <json> Provide a JSON object payload.',
|
||||
' positional values Accepted when schema order is known.',
|
||||
' -- Treat remaining tokens as literal positional values.',
|
||||
] as const;
|
||||
|
||||
export const CALL_HELP_RUNTIME_FLAG_LINES = [
|
||||
' --timeout <ms> Override the call timeout.',
|
||||
' --output text|markdown|json|raw Control formatting.',
|
||||
' --save-images <dir> Save image content blocks to a directory.',
|
||||
' --raw-strings Keep numeric-looking argument values as strings.',
|
||||
' --no-coerce Keep all key/value and positional arguments as raw strings.',
|
||||
' --tail-log Stream returned log handles.',
|
||||
] as const;
|
||||
|
||||
export const CALL_HELP_ADHOC_SERVER_LINES = [
|
||||
' --http-url <url> Register an HTTP server for this run.',
|
||||
' --allow-http Permit plain http:// URLs with --http-url.',
|
||||
' --stdio <command> Run a stdio MCP server (repeat --stdio-arg for args).',
|
||||
' --stdio-arg <value> Append args to the stdio command (repeatable).',
|
||||
' --env KEY=value Inject env vars for stdio servers (repeatable).',
|
||||
' --cwd <path> Working directory for stdio servers.',
|
||||
' --name <value> Override the display name for ad-hoc servers.',
|
||||
' --description <text> Override the description for ad-hoc servers.',
|
||||
' --persist <path> Write the ad-hoc definition to config/mcporter.json.',
|
||||
' --yes Skip confirmation prompts when persisting.',
|
||||
] as const;
|
||||
|
||||
export const CALL_HELP_EXAMPLE_LINES = [
|
||||
' mcporter call linear.list_issues team=ENG limit:5',
|
||||
' mcporter call "linear.create_issue(title: \\"Bug\\", team: \\"ENG\\")"',
|
||||
' mcporter call https://api.example.com/mcp.fetch url:https://example.com',
|
||||
' mcporter call --stdio "bun run ./server.ts" scrape url=https://example.com',
|
||||
] as const;
|
||||
|
||||
export function buildUnknownCallFlagMessage(token: string): string {
|
||||
const argumentName = token.startsWith('--') ? token.slice(2) : token;
|
||||
return [
|
||||
`Unknown flag '${token}' passed to call command.`,
|
||||
`If you intended to pass a tool argument, use '${argumentName}=<value>' or --args '{"${argumentName}": ...}'.`,
|
||||
"If you intended to pass a literal positional value, insert '--' before it.",
|
||||
"Run 'mcporter call --help' to see available flags.",
|
||||
].join('\n');
|
||||
}
|
||||
@ -20,6 +20,12 @@ export interface ConnectWithAuthOptions {
|
||||
recreateTransport?: (transport: OAuthCapableTransport) => Promise<OAuthCapableTransport>;
|
||||
}
|
||||
|
||||
interface OAuthConnectState {
|
||||
activeTransport: OAuthCapableTransport;
|
||||
attempt: number;
|
||||
hasCompletedAuthFlow: boolean;
|
||||
}
|
||||
|
||||
export class OAuthTimeoutError extends Error {
|
||||
public readonly timeoutMs: number;
|
||||
public readonly serverName: string;
|
||||
@ -78,47 +84,60 @@ export async function connectWithAuth(
|
||||
options: ConnectWithAuthOptions = {}
|
||||
): Promise<OAuthCapableTransport> {
|
||||
const { serverName, maxAttempts = 3, oauthTimeoutMs = DEFAULT_OAUTH_CODE_TIMEOUT_MS, recreateTransport } = options;
|
||||
let activeTransport = transport;
|
||||
let attempt = 0;
|
||||
let hasCompletedAuthFlow = false;
|
||||
const state: OAuthConnectState = {
|
||||
activeTransport: transport,
|
||||
attempt: 0,
|
||||
hasCompletedAuthFlow: false,
|
||||
};
|
||||
|
||||
while (true) {
|
||||
try {
|
||||
await client.connect(activeTransport);
|
||||
return activeTransport;
|
||||
return await attemptTransportConnect(client, state);
|
||||
} catch (error) {
|
||||
const unauthorized = isUnauthorizedError(error);
|
||||
if (hasCompletedAuthFlow && !unauthorized) {
|
||||
await closeReplacementTransport(transport, activeTransport);
|
||||
throw markPostAuthConnectError(error);
|
||||
if (!shouldRetryAuthorization(state, unauthorized, session)) {
|
||||
await closeReplacementTransport(transport, state.activeTransport);
|
||||
throw state.hasCompletedAuthFlow && !unauthorized ? markPostAuthConnectError(error) : error;
|
||||
}
|
||||
if (!unauthorized || !session) {
|
||||
await closeReplacementTransport(transport, activeTransport);
|
||||
throw error;
|
||||
}
|
||||
attempt += 1;
|
||||
if (attempt > maxAttempts) {
|
||||
await closeReplacementTransport(transport, activeTransport);
|
||||
throw hasCompletedAuthFlow ? markPostAuthConnectError(error) : error;
|
||||
state.attempt += 1;
|
||||
if (state.attempt > maxAttempts) {
|
||||
await closeReplacementTransport(transport, state.activeTransport);
|
||||
throw state.hasCompletedAuthFlow ? markPostAuthConnectError(error) : error;
|
||||
}
|
||||
logger.warn(`OAuth authorization required for '${serverName ?? 'unknown'}'. Waiting for browser approval...`);
|
||||
try {
|
||||
activeTransport = await completeAuthorizationChallenge(activeTransport, session, logger, error, {
|
||||
state.activeTransport = await completeAuthorizationChallenge(state.activeTransport, session, logger, error, {
|
||||
serverName,
|
||||
oauthTimeoutMs,
|
||||
recreateTransport,
|
||||
});
|
||||
hasCompletedAuthFlow = true;
|
||||
state.hasCompletedAuthFlow = true;
|
||||
logger.info('Authorization code accepted. Retrying connection...');
|
||||
} catch (authError) {
|
||||
logger.error('OAuth authorization failed while waiting for callback.', authError);
|
||||
await closeReplacementTransport(transport, activeTransport);
|
||||
await closeReplacementTransport(transport, state.activeTransport);
|
||||
throw markOAuthFlowError(authError);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function attemptTransportConnect(client: Client, state: OAuthConnectState): Promise<OAuthCapableTransport> {
|
||||
await client.connect(state.activeTransport);
|
||||
return state.activeTransport;
|
||||
}
|
||||
|
||||
function shouldRetryAuthorization(
|
||||
_state: OAuthConnectState,
|
||||
unauthorized: boolean,
|
||||
session: OAuthSession | undefined
|
||||
): session is OAuthSession {
|
||||
if (!session || !unauthorized) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
async function closeReplacementTransport(
|
||||
originalTransport: OAuthCapableTransport,
|
||||
activeTransport: OAuthCapableTransport
|
||||
|
||||
@ -59,6 +59,10 @@ interface ResolvedHttpTransportOptions {
|
||||
authProvider?: OAuthSession['provider'];
|
||||
}
|
||||
|
||||
type HttpClientContextAttempt =
|
||||
| { context: ClientContext; nextDefinition?: undefined }
|
||||
| { context?: undefined; nextDefinition: ServerDefinition };
|
||||
|
||||
function attachStdioTraceLogging(_transport: StdioClientTransport, _label?: string): void {
|
||||
// STDIO instrumentation is handled via sdk-patches side effects. This helper remains
|
||||
// so runtime callers can opt-in without sprinkling conditional checks everywhere.
|
||||
@ -144,6 +148,216 @@ async function connectHttpTransport<TTransport extends OAuthCapableTransport>(
|
||||
}
|
||||
}
|
||||
|
||||
async function applyCachedOAuthHeaderIfAvailable(
|
||||
definition: ServerDefinition,
|
||||
logger: Logger,
|
||||
allowCachedAuth: boolean | undefined
|
||||
): Promise<ServerDefinition> {
|
||||
if (!allowCachedAuth || definition.auth !== 'oauth' || definition.command.kind !== 'http') {
|
||||
return definition;
|
||||
}
|
||||
try {
|
||||
const cached = await readCachedAccessToken(definition, logger);
|
||||
if (!cached) {
|
||||
return definition;
|
||||
}
|
||||
const existingHeaders = definition.command.headers ?? {};
|
||||
if ('Authorization' in existingHeaders) {
|
||||
return definition;
|
||||
}
|
||||
logger.debug?.(`Using cached OAuth access token for '${definition.name}' (non-interactive).`);
|
||||
return {
|
||||
...definition,
|
||||
command: {
|
||||
...definition.command,
|
||||
headers: {
|
||||
...existingHeaders,
|
||||
Authorization: `Bearer ${cached}`,
|
||||
},
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
logger.debug?.(
|
||||
`Failed to read cached OAuth token for '${definition.name}': ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
return definition;
|
||||
}
|
||||
}
|
||||
|
||||
async function createStdioClientContext(
|
||||
client: Client,
|
||||
definition: ServerDefinition & { command: Extract<ServerDefinition['command'], { kind: 'stdio' }> },
|
||||
logger: Logger
|
||||
): Promise<ClientContext> {
|
||||
const resolvedEnvOverrides =
|
||||
definition.env && Object.keys(definition.env).length > 0
|
||||
? Object.fromEntries(
|
||||
Object.entries(definition.env)
|
||||
.map(([key, raw]) => [key, resolveEnvValue(raw)])
|
||||
.filter(([, value]) => value !== '')
|
||||
)
|
||||
: undefined;
|
||||
const mergedEnv =
|
||||
resolvedEnvOverrides && Object.keys(resolvedEnvOverrides).length > 0
|
||||
? { ...process.env, ...resolvedEnvOverrides }
|
||||
: { ...process.env };
|
||||
const transport = new StdioClientTransport({
|
||||
command: resolveCommandArgument(definition.command.command),
|
||||
args: resolveCommandArguments(definition.command.args),
|
||||
cwd: definition.command.cwd,
|
||||
env: mergedEnv,
|
||||
});
|
||||
if (STDIO_TRACE_ENABLED) {
|
||||
attachStdioTraceLogging(transport, definition.name ?? definition.command.command);
|
||||
}
|
||||
try {
|
||||
await client.connect(transport);
|
||||
} catch (error) {
|
||||
await closeTransportAndWait(logger, transport).catch(() => {});
|
||||
throw error;
|
||||
}
|
||||
return { client, transport, definition, oauthSession: undefined };
|
||||
}
|
||||
|
||||
async function retryHttpTransportWithFallback(
|
||||
client: Client,
|
||||
definition: ServerDefinition,
|
||||
logger: Logger,
|
||||
options: CreateClientContextOptions
|
||||
): Promise<ClientContext> {
|
||||
let activeDefinition = definition;
|
||||
while (true) {
|
||||
const attempt = await attemptHttpClientContext(client, activeDefinition, logger, options);
|
||||
if (!attempt.nextDefinition) {
|
||||
return attempt.context;
|
||||
}
|
||||
activeDefinition = attempt.nextDefinition;
|
||||
options.onDefinitionPromoted?.(activeDefinition);
|
||||
}
|
||||
}
|
||||
|
||||
async function attemptHttpClientContext(
|
||||
client: Client,
|
||||
activeDefinition: ServerDefinition,
|
||||
logger: Logger,
|
||||
options: CreateClientContextOptions
|
||||
): Promise<HttpClientContextAttempt> {
|
||||
const command = activeDefinition.command;
|
||||
if (command.kind !== 'http') {
|
||||
throw new Error(`Server '${activeDefinition.name}' is not configured for HTTP transport.`);
|
||||
}
|
||||
let oauthSession: OAuthSession | undefined;
|
||||
const shouldEstablishOAuth = activeDefinition.auth === 'oauth' && options.maxOAuthAttempts !== 0;
|
||||
if (shouldEstablishOAuth) {
|
||||
oauthSession = await createOAuthSession(activeDefinition, logger);
|
||||
}
|
||||
const transportOptions = createHttpTransportOptions(activeDefinition, oauthSession, shouldEstablishOAuth);
|
||||
|
||||
try {
|
||||
const context = await connectPrimaryHttpTransport(
|
||||
client,
|
||||
activeDefinition,
|
||||
command,
|
||||
transportOptions,
|
||||
oauthSession,
|
||||
logger,
|
||||
options
|
||||
);
|
||||
return { context };
|
||||
} catch (primaryError) {
|
||||
if (shouldAbortSseFallback(primaryError)) {
|
||||
await closeOAuthSession(oauthSession);
|
||||
throw primaryError;
|
||||
}
|
||||
if (isUnauthorizedError(primaryError)) {
|
||||
await closeOAuthSession(oauthSession);
|
||||
const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options);
|
||||
if (promoted) {
|
||||
return { nextDefinition: promoted };
|
||||
}
|
||||
oauthSession = undefined;
|
||||
}
|
||||
if (primaryError instanceof Error) {
|
||||
logger.info(`Falling back to SSE transport for '${activeDefinition.name}': ${primaryError.message}`);
|
||||
}
|
||||
return {
|
||||
context: await connectSseFallbackTransport(
|
||||
client,
|
||||
activeDefinition,
|
||||
command,
|
||||
transportOptions,
|
||||
oauthSession,
|
||||
logger,
|
||||
options
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async function connectPrimaryHttpTransport(
|
||||
client: Client,
|
||||
definition: ServerDefinition,
|
||||
command: Extract<ServerDefinition['command'], { kind: 'http' }>,
|
||||
transportOptions: ResolvedHttpTransportOptions,
|
||||
oauthSession: OAuthSession | undefined,
|
||||
logger: Logger,
|
||||
options: CreateClientContextOptions
|
||||
): Promise<ClientContext> {
|
||||
const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, transportOptions);
|
||||
const transport = await connectHttpTransport(client, createStreamableTransport(), oauthSession, logger, {
|
||||
serverName: definition.name,
|
||||
maxAttempts: options.maxOAuthAttempts,
|
||||
oauthTimeoutMs: options.oauthTimeoutMs,
|
||||
recreateTransport: async () => createStreamableTransport(),
|
||||
});
|
||||
return {
|
||||
client,
|
||||
transport,
|
||||
definition,
|
||||
oauthSession,
|
||||
};
|
||||
}
|
||||
|
||||
async function connectSseFallbackTransport(
|
||||
client: Client,
|
||||
definition: ServerDefinition,
|
||||
command: Extract<ServerDefinition['command'], { kind: 'http' }>,
|
||||
transportOptions: ResolvedHttpTransportOptions,
|
||||
oauthSession: OAuthSession | undefined,
|
||||
logger: Logger,
|
||||
options: CreateClientContextOptions
|
||||
): Promise<ClientContext> {
|
||||
try {
|
||||
const transport = await connectHttpTransport(
|
||||
client,
|
||||
new SSEClientTransport(command.url, transportOptions),
|
||||
oauthSession,
|
||||
logger,
|
||||
{
|
||||
serverName: definition.name,
|
||||
maxAttempts: options.maxOAuthAttempts,
|
||||
oauthTimeoutMs: options.oauthTimeoutMs,
|
||||
}
|
||||
);
|
||||
return { client, transport, definition, oauthSession };
|
||||
} catch (sseError) {
|
||||
await closeOAuthSession(oauthSession);
|
||||
if (sseError instanceof OAuthTimeoutError) {
|
||||
throw sseError;
|
||||
}
|
||||
if (isUnauthorizedError(sseError)) {
|
||||
const promoted = maybePromoteHttpDefinition(definition, logger, options);
|
||||
if (promoted) {
|
||||
options.onDefinitionPromoted?.(promoted);
|
||||
return retryHttpTransportWithFallback(client, promoted, logger, options);
|
||||
}
|
||||
}
|
||||
throw sseError;
|
||||
}
|
||||
}
|
||||
|
||||
export async function createClientContext(
|
||||
definition: ServerDefinition,
|
||||
logger: Logger,
|
||||
@ -151,147 +365,16 @@ export async function createClientContext(
|
||||
options: CreateClientContextOptions = {}
|
||||
): Promise<ClientContext> {
|
||||
const client = new Client(clientInfo);
|
||||
let activeDefinition = definition;
|
||||
|
||||
if (options.allowCachedAuth && activeDefinition.auth === 'oauth' && activeDefinition.command.kind === 'http') {
|
||||
try {
|
||||
const cached = await readCachedAccessToken(activeDefinition, logger);
|
||||
if (cached) {
|
||||
const existingHeaders = activeDefinition.command.headers ?? {};
|
||||
if (!('Authorization' in existingHeaders)) {
|
||||
activeDefinition = {
|
||||
...activeDefinition,
|
||||
command: {
|
||||
...activeDefinition.command,
|
||||
headers: {
|
||||
...existingHeaders,
|
||||
Authorization: `Bearer ${cached}`,
|
||||
},
|
||||
},
|
||||
};
|
||||
logger.debug?.(`Using cached OAuth access token for '${activeDefinition.name}' (non-interactive).`);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.debug?.(
|
||||
`Failed to read cached OAuth token for '${activeDefinition.name}': ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
}
|
||||
}
|
||||
const activeDefinition = await applyCachedOAuthHeaderIfAvailable(definition, logger, options.allowCachedAuth);
|
||||
|
||||
return withEnvOverrides(activeDefinition.env, async () => {
|
||||
if (activeDefinition.command.kind === 'stdio') {
|
||||
const resolvedEnvOverrides =
|
||||
activeDefinition.env && Object.keys(activeDefinition.env).length > 0
|
||||
? Object.fromEntries(
|
||||
Object.entries(activeDefinition.env)
|
||||
.map(([key, raw]) => [key, resolveEnvValue(raw)])
|
||||
.filter(([, value]) => value !== '')
|
||||
)
|
||||
: undefined;
|
||||
const mergedEnv =
|
||||
resolvedEnvOverrides && Object.keys(resolvedEnvOverrides).length > 0
|
||||
? { ...process.env, ...resolvedEnvOverrides }
|
||||
: { ...process.env };
|
||||
const transport = new StdioClientTransport({
|
||||
command: resolveCommandArgument(activeDefinition.command.command),
|
||||
args: resolveCommandArguments(activeDefinition.command.args),
|
||||
cwd: activeDefinition.command.cwd,
|
||||
env: mergedEnv,
|
||||
});
|
||||
if (STDIO_TRACE_ENABLED) {
|
||||
attachStdioTraceLogging(transport, activeDefinition.name ?? activeDefinition.command.command);
|
||||
}
|
||||
try {
|
||||
await client.connect(transport);
|
||||
} catch (error) {
|
||||
await closeTransportAndWait(logger, transport).catch(() => {});
|
||||
throw error;
|
||||
}
|
||||
return { client, transport, definition: activeDefinition, oauthSession: undefined };
|
||||
}
|
||||
|
||||
while (true) {
|
||||
const command = activeDefinition.command;
|
||||
if (command.kind !== 'http') {
|
||||
throw new Error(`Server '${activeDefinition.name}' is not configured for HTTP transport.`);
|
||||
}
|
||||
let oauthSession: OAuthSession | undefined;
|
||||
const shouldEstablishOAuth = activeDefinition.auth === 'oauth' && options.maxOAuthAttempts !== 0;
|
||||
if (shouldEstablishOAuth) {
|
||||
oauthSession = await createOAuthSession(activeDefinition, logger);
|
||||
}
|
||||
const transportOptions = createHttpTransportOptions(activeDefinition, oauthSession, shouldEstablishOAuth);
|
||||
|
||||
try {
|
||||
const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, transportOptions);
|
||||
const streamableTransport = await connectHttpTransport(
|
||||
client,
|
||||
createStreamableTransport(),
|
||||
oauthSession,
|
||||
logger,
|
||||
{
|
||||
serverName: activeDefinition.name,
|
||||
maxAttempts: options.maxOAuthAttempts,
|
||||
oauthTimeoutMs: options.oauthTimeoutMs,
|
||||
recreateTransport: async () => createStreamableTransport(),
|
||||
}
|
||||
);
|
||||
return {
|
||||
client,
|
||||
transport: streamableTransport,
|
||||
definition: activeDefinition,
|
||||
oauthSession,
|
||||
};
|
||||
} catch (primaryError) {
|
||||
if (shouldAbortSseFallback(primaryError)) {
|
||||
await closeOAuthSession(oauthSession);
|
||||
throw primaryError;
|
||||
}
|
||||
if (isUnauthorizedError(primaryError)) {
|
||||
await closeOAuthSession(oauthSession);
|
||||
oauthSession = undefined;
|
||||
const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options);
|
||||
if (promoted) {
|
||||
activeDefinition = promoted;
|
||||
options.onDefinitionPromoted?.(promoted);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (primaryError instanceof Error) {
|
||||
logger.info(`Falling back to SSE transport for '${activeDefinition.name}': ${primaryError.message}`);
|
||||
}
|
||||
try {
|
||||
const connectedTransport = await connectHttpTransport(
|
||||
client,
|
||||
new SSEClientTransport(command.url, transportOptions),
|
||||
oauthSession,
|
||||
logger,
|
||||
{
|
||||
serverName: activeDefinition.name,
|
||||
maxAttempts: options.maxOAuthAttempts,
|
||||
oauthTimeoutMs: options.oauthTimeoutMs,
|
||||
}
|
||||
);
|
||||
return { client, transport: connectedTransport, definition: activeDefinition, oauthSession };
|
||||
} catch (sseError) {
|
||||
await closeOAuthSession(oauthSession);
|
||||
if (sseError instanceof OAuthTimeoutError) {
|
||||
throw sseError;
|
||||
}
|
||||
if (isUnauthorizedError(sseError)) {
|
||||
const promoted = maybePromoteHttpDefinition(activeDefinition, logger, options);
|
||||
if (promoted) {
|
||||
activeDefinition = promoted;
|
||||
options.onDefinitionPromoted?.(promoted);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
throw sseError;
|
||||
}
|
||||
}
|
||||
return createStdioClientContext(
|
||||
client,
|
||||
activeDefinition as ServerDefinition & { command: Extract<ServerDefinition['command'], { kind: 'stdio' }> },
|
||||
logger
|
||||
);
|
||||
}
|
||||
return retryHttpTransportWithFallback(client, activeDefinition, logger, options);
|
||||
});
|
||||
}
|
||||
|
||||
112
tests/helpers/runtime-test-helpers.ts
Normal file
112
tests/helpers/runtime-test-helpers.ts
Normal file
@ -0,0 +1,112 @@
|
||||
import type { Transport, TransportSendOptions } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { vi } from 'vitest';
|
||||
import type { ServerDefinition } from '../../src/config.js';
|
||||
import type { Logger } from '../../src/logging.js';
|
||||
import type { OAuthSession } from '../../src/oauth.js';
|
||||
|
||||
export const clientInfo = { name: 'mcporter', version: '0.0.0-test' };
|
||||
|
||||
export interface LoggerSpy extends Logger {
|
||||
info: ReturnType<typeof vi.fn<(message: string) => void>>;
|
||||
warn: ReturnType<typeof vi.fn<(message: string) => void>>;
|
||||
error: ReturnType<typeof vi.fn<(message: string, error?: unknown) => void>>;
|
||||
}
|
||||
|
||||
export function createLogger(): LoggerSpy {
|
||||
return {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
}
|
||||
|
||||
export function resetLogger(logger: LoggerSpy): void {
|
||||
logger.info.mockReset();
|
||||
logger.warn.mockReset();
|
||||
logger.error.mockReset();
|
||||
}
|
||||
|
||||
export function stubHttpDefinition(url: string): ServerDefinition {
|
||||
return {
|
||||
name: 'http-server',
|
||||
command: { kind: 'http', url: new URL(url) },
|
||||
source: { kind: 'local', path: '<adhoc>' },
|
||||
};
|
||||
}
|
||||
|
||||
export function stubOAuthHttpDefinition(url: string): ServerDefinition {
|
||||
return {
|
||||
...stubHttpDefinition(url),
|
||||
auth: 'oauth',
|
||||
};
|
||||
}
|
||||
|
||||
export function createPromotionRecorder() {
|
||||
const promotedDefinitions: ServerDefinition[] = [];
|
||||
return {
|
||||
promotedDefinitions,
|
||||
onDefinitionPromoted: (promoted: ServerDefinition) => {
|
||||
promotedDefinitions.push(promoted);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function createMockOAuthSession(): OAuthSession {
|
||||
return {
|
||||
provider: {
|
||||
waitForAuthorizationCode: vi.fn(),
|
||||
} as unknown as OAuthSession['provider'],
|
||||
waitForAuthorizationCode: vi.fn(),
|
||||
close: vi.fn(async () => {}),
|
||||
};
|
||||
}
|
||||
|
||||
export class MockTransport implements Transport {
|
||||
public readonly calls: string[] = [];
|
||||
public readonly close = vi.fn(async () => {});
|
||||
|
||||
constructor(private readonly finishAuthImpl?: (code: string) => Promise<void>) {}
|
||||
|
||||
async start(): Promise<void> {}
|
||||
|
||||
async send(_message: JSONRPCMessage, _options?: TransportSendOptions): Promise<void> {}
|
||||
|
||||
async finishAuth(code: string): Promise<void> {
|
||||
this.calls.push(code);
|
||||
if (this.finishAuthImpl) {
|
||||
await this.finishAuthImpl(code);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function createPendingAuthorizationSession() {
|
||||
const pendingResolvers: Array<(code: string) => void> = [];
|
||||
const waitForAuthorizationCode = vi.fn(
|
||||
() =>
|
||||
new Promise<string>((resolve) => {
|
||||
pendingResolvers.push(resolve);
|
||||
})
|
||||
);
|
||||
const session: OAuthSession = {
|
||||
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
|
||||
waitForAuthorizationCode,
|
||||
close: vi.fn(async () => {}),
|
||||
};
|
||||
return {
|
||||
session,
|
||||
waitForAuthorizationCode,
|
||||
pendingResolvers,
|
||||
resolveNextCode: (code: string) => {
|
||||
const resolve = pendingResolvers.shift();
|
||||
if (!resolve) {
|
||||
throw new Error(`Missing pending authorization resolver for '${code}'.`);
|
||||
}
|
||||
resolve(code);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export async function flushAuthLoop(): Promise<void> {
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
}
|
||||
@ -1,65 +1,13 @@
|
||||
import type { Client } from '@modelcontextprotocol/sdk/client';
|
||||
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
import type { Transport, TransportSendOptions } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import type { Logger } from '../src/logging.js';
|
||||
import type { OAuthSession } from '../src/oauth.js';
|
||||
import { connectWithAuth, isOAuthFlowError, isPostAuthConnectError } from '../src/runtime/oauth.js';
|
||||
|
||||
class MockTransport implements Transport {
|
||||
public readonly calls: string[] = [];
|
||||
public readonly close = vi.fn(async () => {});
|
||||
|
||||
constructor(private readonly finishAuthImpl?: (code: string) => Promise<void>) {}
|
||||
async start(): Promise<void> {}
|
||||
async send(_message: JSONRPCMessage, _options?: TransportSendOptions): Promise<void> {}
|
||||
async finishAuth(code: string): Promise<void> {
|
||||
this.calls.push(code);
|
||||
if (this.finishAuthImpl) {
|
||||
await this.finishAuthImpl(code);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function createLogger(): Logger {
|
||||
return {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
}
|
||||
|
||||
function createPendingAuthorizationSession() {
|
||||
const pendingResolvers: Array<(code: string) => void> = [];
|
||||
const waitForAuthorizationCode = vi.fn(
|
||||
() =>
|
||||
new Promise<string>((resolve) => {
|
||||
pendingResolvers.push(resolve);
|
||||
})
|
||||
);
|
||||
const session: OAuthSession = {
|
||||
provider: { waitForAuthorizationCode } as unknown as OAuthSession['provider'],
|
||||
waitForAuthorizationCode,
|
||||
close: vi.fn(async () => {}),
|
||||
};
|
||||
return {
|
||||
session,
|
||||
waitForAuthorizationCode,
|
||||
pendingResolvers,
|
||||
resolveNextCode: (code: string) => {
|
||||
const resolve = pendingResolvers.shift();
|
||||
if (!resolve) {
|
||||
throw new Error(`Missing pending authorization resolver for '${code}'.`);
|
||||
}
|
||||
resolve(code);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async function flushAuthLoop(): Promise<void> {
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
}
|
||||
import {
|
||||
createLogger,
|
||||
createPendingAuthorizationSession,
|
||||
flushAuthLoop,
|
||||
MockTransport,
|
||||
} from './helpers/runtime-test-helpers.js';
|
||||
|
||||
describe('connectWithAuth', () => {
|
||||
it('waits for authorization code and retries connection', async () => {
|
||||
|
||||
@ -28,60 +28,33 @@ import type { ServerDefinition } from '../src/config.js';
|
||||
import * as oauthModule from '../src/oauth.js';
|
||||
import { markOAuthFlowError, markPostAuthConnectError } from '../src/runtime/oauth.js';
|
||||
import { createClientContext } from '../src/runtime/transport.js';
|
||||
import {
|
||||
clientInfo,
|
||||
createLogger,
|
||||
createMockOAuthSession,
|
||||
createPromotionRecorder,
|
||||
resetLogger,
|
||||
stubHttpDefinition,
|
||||
stubOAuthHttpDefinition,
|
||||
} from './helpers/runtime-test-helpers.js';
|
||||
|
||||
const logger = {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
|
||||
const clientInfo = { name: 'mcporter', version: '0.0.0-test' };
|
||||
const logger = createLogger();
|
||||
|
||||
beforeEach(() => {
|
||||
resetLogger(logger);
|
||||
mocks.connectWithAuth.mockReset();
|
||||
mocks.connectWithAuth.mockImplementation(async (client, transport) => {
|
||||
await client.connect(transport);
|
||||
return transport;
|
||||
});
|
||||
mocks.createOAuthSession.mockReset();
|
||||
mocks.createOAuthSession.mockResolvedValue({
|
||||
provider: {
|
||||
waitForAuthorizationCode: vi.fn(),
|
||||
},
|
||||
waitForAuthorizationCode: vi.fn(),
|
||||
close: vi.fn(async () => {}),
|
||||
});
|
||||
mocks.createOAuthSession.mockResolvedValue(createMockOAuthSession());
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
function stubHttpDefinition(url: string): ServerDefinition {
|
||||
return {
|
||||
name: 'http-server',
|
||||
command: { kind: 'http', url: new URL(url) },
|
||||
source: { kind: 'local', path: '<adhoc>' },
|
||||
};
|
||||
}
|
||||
|
||||
function stubOAuthHttpDefinition(url: string): ServerDefinition {
|
||||
return {
|
||||
...stubHttpDefinition(url),
|
||||
auth: 'oauth',
|
||||
};
|
||||
}
|
||||
|
||||
function createPromotionRecorder() {
|
||||
const promotedDefinitions: ServerDefinition[] = [];
|
||||
return {
|
||||
promotedDefinitions,
|
||||
onDefinitionPromoted: (promoted: ServerDefinition) => {
|
||||
promotedDefinitions.push(promoted);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe('createClientContext (HTTP)', () => {
|
||||
it('falls back to SSE when primary connect fails', async () => {
|
||||
const definition = stubHttpDefinition('https://example.com/mcp');
|
||||
|
||||
Loading…
Reference in New Issue
Block a user