From 86e19f4413ba66874e6ba1ff617bbb23bafd9dbb Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Wed, 20 May 2026 17:09:14 +0100 Subject: [PATCH] fix: use cached auth for daemon OAuth calls --- src/cli/tool-cache.ts | 11 ++-- src/daemon/host.ts | 3 +- src/daemon/protocol.ts | 2 +- src/daemon/runtime-wrapper.ts | 2 +- src/runtime.ts | 36 ++++++++----- tests/cli-call-execution.test.ts | 18 +++++-- tests/daemon-host.test.ts | 87 ++++++++++++++++++++++++++++---- tests/keep-alive-runtime.test.ts | 15 +++++- tests/runtime-compose.test.ts | 62 +++++++++++++++++++++++ tests/serve.test.ts | 1 + 10 files changed, 201 insertions(+), 36 deletions(-) diff --git a/src/cli/tool-cache.ts b/src/cli/tool-cache.ts index 6d12dbf..16f1cdc 100644 --- a/src/cli/tool-cache.ts +++ b/src/cli/tool-cache.ts @@ -12,7 +12,7 @@ const runtimeCache = new WeakMap>>( function cacheKey(serverName: string, options: LoadToolMetadataOptions): string { const includeSchema = options.includeSchema !== false; const autoAuthorize = options.autoAuthorize !== false; - const allowCachedAuth = options.allowCachedAuth === true; + const allowCachedAuth = options.allowCachedAuth !== false; return `${serverName}::schema:${includeSchema ? '1' : '0'}::auth:${autoAuthorize ? '1' : '0'}::cached-auth:${allowCachedAuth ? '1' : '0'}`; } @@ -33,10 +33,11 @@ export async function loadToolMetadata( } const includeSchema = options.includeSchema !== false; const autoAuthorize = options.autoAuthorize !== false; - const listOptions: ListToolsOptions = - options.allowCachedAuth === undefined - ? { includeSchema, autoAuthorize } - : { includeSchema, autoAuthorize, allowCachedAuth: options.allowCachedAuth }; + const listOptions: ListToolsOptions = { + includeSchema, + autoAuthorize, + allowCachedAuth: options.allowCachedAuth ?? true, + }; const promise = runtime .listTools(serverName, listOptions) .then((tools) => tools.map((tool) => buildToolMetadata(tool))) diff --git a/src/daemon/host.ts b/src/daemon/host.ts index 78e6962..86d7245 100644 --- a/src/daemon/host.ts +++ b/src/daemon/host.ts @@ -302,7 +302,6 @@ async function processRequest( const result = await runtime.callTool(params.server, params.tool, { args: params.args ?? {}, timeoutMs: params.timeoutMs, - allowCachedAuth: params.allowCachedAuth ?? true, }); markActivity(params.server, activity); if (loggable) { @@ -328,7 +327,7 @@ async function processRequest( const result = await runtime.listTools(params.server, { includeSchema: params.includeSchema, autoAuthorize: params.autoAuthorize, - allowCachedAuth: true, + allowCachedAuth: params.allowCachedAuth ?? true, }); markActivity(params.server, activity); if (loggable) { diff --git a/src/daemon/protocol.ts b/src/daemon/protocol.ts index d46a13a..d2c72ea 100644 --- a/src/daemon/protocol.ts +++ b/src/daemon/protocol.ts @@ -28,13 +28,13 @@ export interface CallToolParams { readonly tool: string; readonly args?: Record; readonly timeoutMs?: number; - readonly allowCachedAuth?: boolean; } export interface ListToolsParams { readonly server: string; readonly includeSchema?: boolean; readonly autoAuthorize?: boolean; + readonly allowCachedAuth?: boolean; } export interface ListResourcesParams { diff --git a/src/daemon/runtime-wrapper.ts b/src/daemon/runtime-wrapper.ts index a840258..0304233 100644 --- a/src/daemon/runtime-wrapper.ts +++ b/src/daemon/runtime-wrapper.ts @@ -61,6 +61,7 @@ class KeepAliveRuntime implements Runtime { server, includeSchema: options?.includeSchema, autoAuthorize: options?.autoAuthorize, + allowCachedAuth: options?.allowCachedAuth ?? true, }) )) as Awaited>; } @@ -75,7 +76,6 @@ class KeepAliveRuntime implements Runtime { tool: toolName, args: options?.args, timeoutMs: options?.timeoutMs, - allowCachedAuth: options?.allowCachedAuth ?? true, }) ); } diff --git a/src/runtime.ts b/src/runtime.ts index 6965744..4d6a015 100644 --- a/src/runtime.ts +++ b/src/runtime.ts @@ -103,7 +103,13 @@ export async function callOnce(params: { class McpRuntime implements Runtime { private readonly definitions: Map; - private readonly clients = new Map>(); + private readonly clients = new Map< + string, + { + readonly promise: Promise; + readonly allowCachedAuth: boolean | undefined; + } + >(); private readonly logger: RuntimeLogger; private readonly clientInfo: { name: string; version: string }; private readonly oauthTimeoutMs?: number; @@ -150,12 +156,12 @@ class McpRuntime implements Runtime { } async getInstructions(server: string): Promise { - const contextPromise = this.clients.get(server.trim()); - if (!contextPromise) { + const cached = this.clients.get(server.trim()); + if (!cached) { return undefined; } try { - const context = await contextPromise; + const context = await cached.promise; const instructions = typeof context.client.getInstructions === 'function' ? context.client.getInstructions() : undefined; if (typeof instructions !== 'string') { @@ -175,7 +181,7 @@ class McpRuntime implements Runtime { const context = await this.connect(server, { maxOAuthAttempts: autoAuthorize ? undefined : 0, skipCache: !autoAuthorize, - allowCachedAuth: options.allowCachedAuth, + allowCachedAuth: options.allowCachedAuth ?? true, oauthSessionOptions: options.oauthSessionOptions, }); try { @@ -218,7 +224,9 @@ class McpRuntime implements Runtime { ); } try { - const { client } = await this.connect(server); + const { client } = await this.connect(server, { + allowCachedAuth: true, + }); const params: CallToolRequest['params'] = { name: toolName, arguments: options.args ?? {}, @@ -276,7 +284,10 @@ class McpRuntime implements Runtime { if (useCache) { const existing = this.clients.get(normalized); if (existing) { - return existing; + if (existing.allowCachedAuth === options.allowCachedAuth || options.allowCachedAuth === undefined) { + return existing.promise; + } + await this.close(normalized).catch(() => {}); } } @@ -294,7 +305,7 @@ class McpRuntime implements Runtime { }); if (useCache) { - this.clients.set(normalized, connection); + this.clients.set(normalized, { promise: connection, allowCachedAuth: options.allowCachedAuth }); try { return await connection; } catch (error) { @@ -310,10 +321,11 @@ class McpRuntime implements Runtime { async close(server?: string): Promise { if (server) { const normalized = server.trim(); - const context = await this.clients.get(normalized); - if (!context) { + const cached = this.clients.get(normalized); + if (!cached) { return; } + const context = await cached.promise; await context.client.close().catch(() => {}); await closeTransportAndWait(this.logger, context.transport).catch(() => {}); await context.oauthSession?.close().catch(() => {}); @@ -321,9 +333,9 @@ class McpRuntime implements Runtime { return; } - for (const [name, promise] of this.clients.entries()) { + for (const [name, cached] of this.clients.entries()) { try { - const context = await promise; + const context = await cached.promise; await context.client.close().catch(() => {}); await closeTransportAndWait(this.logger, context.transport).catch(() => {}); await context.oauthSession?.close().catch(() => {}); diff --git a/tests/cli-call-execution.test.ts b/tests/cli-call-execution.test.ts index 5c867d9..be55aca 100644 --- a/tests/cli-call-execution.test.ts +++ b/tests/cli-call-execution.test.ts @@ -82,7 +82,11 @@ describe('CLI call execution behavior', () => { }, }) ); - expect(listTools).toHaveBeenCalledWith('slack', { autoAuthorize: true, includeSchema: true }); + expect(listTools).toHaveBeenCalledWith('slack', { + autoAuthorize: true, + includeSchema: true, + allowCachedAuth: true, + }); logSpy.mockRestore(); }); @@ -117,7 +121,11 @@ describe('CLI call execution behavior', () => { }, }) ); - expect(listTools).toHaveBeenCalledWith('email', { autoAuthorize: true, includeSchema: true }); + expect(listTools).toHaveBeenCalledWith('email', { + autoAuthorize: true, + includeSchema: true, + allowCachedAuth: true, + }); logSpy.mockRestore(); }); @@ -326,7 +334,11 @@ describe('CLI call execution behavior', () => { expect(callTool).toHaveBeenCalledTimes(2); expect(callTool).toHaveBeenNthCalledWith(1, 'linear', 'listIssues', expect.objectContaining({ args: {} })); expect(callTool).toHaveBeenNthCalledWith(2, 'linear', 'list_issues', expect.objectContaining({ args: {} })); - expect(listTools).toHaveBeenCalledWith('linear', { autoAuthorize: true, includeSchema: false }); + expect(listTools).toHaveBeenCalledWith('linear', { + autoAuthorize: true, + includeSchema: false, + allowCachedAuth: true, + }); logSpy.mockRestore(); }); diff --git a/tests/daemon-host.test.ts b/tests/daemon-host.test.ts index 8fc24e9..a87a2ad 100644 --- a/tests/daemon-host.test.ts +++ b/tests/daemon-host.test.ts @@ -1,21 +1,21 @@ -import { describe, expect, it } from 'vitest'; +import { describe, expect, it, vi } from 'vitest'; import type { ServerDefinition } from '../src/config.js'; import { __testProcessRequest } from '../src/daemon/host.js'; import type { DaemonRequest } from '../src/daemon/protocol.js'; import type { Runtime } from '../src/runtime.js'; describe('daemon host request handling', () => { - it('reuses pre-parsed requests without reparsing payloads', async () => { - const metadata = { - configPath: '/tmp/config.json', - configLayers: [], - configMtimeMs: Date.now(), - socketPath: '/tmp/socket', - startedAt: Date.now(), - logPath: null, - }; - const logContext = { enabled: false, logAllServers: false, servers: new Set() }; + const metadata = { + configPath: '/tmp/config.json', + configLayers: [], + configMtimeMs: Date.now(), + socketPath: '/tmp/socket', + startedAt: Date.now(), + logPath: null, + }; + const logContext = { enabled: false, logAllServers: false, servers: new Set() }; + it('reuses pre-parsed requests without reparsing payloads', async () => { const parsedRequest: DaemonRequest = { id: '1', method: 'status', params: {} }; const result = await __testProcessRequest( '!!!invalid-json!!!', @@ -30,4 +30,69 @@ describe('daemon host request handling', () => { expect(result.response.ok).toBe(true); expect(result.shouldShutdown).toBe(false); }); + + it('defaults daemon callTool and listTools requests to cached auth', async () => { + const runtime = createRuntimeDouble(); + const managedServers = createManagedServers(); + + await __testProcessRequest('', runtime as unknown as Runtime, managedServers, new Map(), metadata, logContext, { + id: 'call', + method: 'callTool', + params: { server: 'oauth', tool: 'ping' }, + }); + + expect(runtime.callTool).toHaveBeenCalledWith('oauth', 'ping', { + args: {}, + timeoutMs: undefined, + }); + + await __testProcessRequest('', runtime as unknown as Runtime, managedServers, new Map(), metadata, logContext, { + id: 'list', + method: 'listTools', + params: { server: 'oauth', includeSchema: true }, + }); + + expect(runtime.listTools).toHaveBeenCalledWith('oauth', { + includeSchema: true, + autoAuthorize: undefined, + allowCachedAuth: true, + }); + }); + + it('preserves explicit listTools cached-auth opt out on daemon requests', async () => { + const runtime = createRuntimeDouble(); + const managedServers = createManagedServers(); + + await __testProcessRequest('', runtime as unknown as Runtime, managedServers, new Map(), metadata, logContext, { + id: 'list', + method: 'listTools', + params: { server: 'oauth', allowCachedAuth: false }, + }); + + expect(runtime.listTools).toHaveBeenCalledWith('oauth', { + includeSchema: undefined, + autoAuthorize: undefined, + allowCachedAuth: false, + }); + }); }); + +function createRuntimeDouble(): Pick { + return { + callTool: vi.fn().mockResolvedValue({ ok: true }), + listTools: vi.fn().mockResolvedValue([]), + }; +} + +function createManagedServers(): Map { + return new Map([ + [ + 'oauth', + { + name: 'oauth', + command: { kind: 'http', url: new URL('https://oauth.example.com/mcp') }, + lifecycle: { mode: 'keep-alive' }, + }, + ], + ]); +} diff --git a/tests/keep-alive-runtime.test.ts b/tests/keep-alive-runtime.test.ts index 94a1527..b886131 100644 --- a/tests/keep-alive-runtime.test.ts +++ b/tests/keep-alive-runtime.test.ts @@ -105,7 +105,20 @@ describe('createKeepAliveRuntime', () => { }); await keepAliveRuntime.listTools('alpha', { includeSchema: true }); - expect(daemon.listTools).toHaveBeenCalledWith({ server: 'alpha', includeSchema: true, autoAuthorize: undefined }); + expect(daemon.listTools).toHaveBeenCalledWith({ + server: 'alpha', + includeSchema: true, + autoAuthorize: undefined, + allowCachedAuth: true, + }); + + await keepAliveRuntime.listTools('alpha', { allowCachedAuth: false }); + expect(daemon.listTools).toHaveBeenLastCalledWith({ + server: 'alpha', + includeSchema: undefined, + autoAuthorize: undefined, + allowCachedAuth: false, + }); await keepAliveRuntime.listResources('alpha', { cursor: '1' }); expect(daemon.listResources).toHaveBeenCalledWith({ server: 'alpha', params: { cursor: '1' } }); diff --git a/tests/runtime-compose.test.ts b/tests/runtime-compose.test.ts index cba6349..1d1b02e 100644 --- a/tests/runtime-compose.test.ts +++ b/tests/runtime-compose.test.ts @@ -5,6 +5,7 @@ const mocks = vi.hoisted(() => { const listToolsMock = vi.fn(); const callToolMock = vi.fn(); const listResourcesMock = vi.fn(); + const readCachedAccessTokenMock = vi.fn(); const clientInstances: unknown[] = []; const streamableInstances: unknown[] = []; const stdioInstances: unknown[] = []; @@ -71,6 +72,7 @@ const mocks = vi.hoisted(() => { listToolsMock, callToolMock, listResourcesMock, + readCachedAccessTokenMock, clientInstances, streamableInstances, stdioInstances, @@ -102,6 +104,10 @@ vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({ UnauthorizedError: mocks.MockUnauthorizedError, })); +vi.mock('../src/oauth-persistence.js', () => ({ + readCachedAccessToken: mocks.readCachedAccessTokenMock, +})); + import { createRuntime } from '../src/runtime.js'; describe('mcporter composability', () => { @@ -110,6 +116,7 @@ describe('mcporter composability', () => { mocks.listToolsMock.mockReset(); mocks.callToolMock.mockReset(); mocks.listResourcesMock.mockReset(); + mocks.readCachedAccessTokenMock.mockReset(); mocks.clientInstances.length = 0; mocks.streamableInstances.length = 0; mocks.stdioInstances.length = 0; @@ -117,6 +124,7 @@ describe('mcporter composability', () => { mocks.listToolsMock.mockResolvedValue({ tools: [] }); mocks.callToolMock.mockResolvedValue({ ok: true }); mocks.listResourcesMock.mockResolvedValue({ resources: [] }); + mocks.readCachedAccessTokenMock.mockResolvedValue(undefined); }); afterEach(() => { @@ -237,6 +245,60 @@ describe('mcporter composability', () => { expect(instance?.options?.env?.MCPORTER_STDIO_TEST).toBe('from-config'); expect(instance?.options?.env?.EXTRA).toBe('42'); }); + + it('applies cached auth for callTool connections', async () => { + mocks.readCachedAccessTokenMock.mockResolvedValue('cached-token'); + const runtime = await createRuntime({ + servers: [ + { + name: 'oauth', + command: { kind: 'http' as const, url: new URL('https://oauth.example.com/mcp') }, + }, + ], + }); + + try { + await runtime.callTool('oauth', 'ping'); + expect(mocks.readCachedAccessTokenMock).toHaveBeenCalledOnce(); + const streamableTransport = mocks.streamableInstances[0] as { + options?: { requestInit?: { headers?: Record } }; + }; + expect(streamableTransport.options?.requestInit?.headers).toEqual({ + Authorization: 'Bearer cached-token', + }); + } finally { + await runtime.close(); + } + }); + + it('reconnects when callTool needs cached auth after an uncached connection', async () => { + const runtime = await createRuntime({ + servers: [ + { + name: 'oauth', + command: { kind: 'http' as const, url: new URL('https://oauth.example.com/mcp') }, + }, + ], + }); + + try { + await runtime.listTools('oauth', { allowCachedAuth: false }); + expect(mocks.streamableInstances).toHaveLength(1); + + mocks.readCachedAccessTokenMock.mockResolvedValue('cached-token'); + await runtime.callTool('oauth', 'ping'); + + expect(mocks.streamableInstances).toHaveLength(2); + const streamableTransport = mocks.streamableInstances[1] as { + options?: { requestInit?: { headers?: Record } }; + }; + expect(streamableTransport.options?.requestInit?.headers).toEqual({ + Authorization: 'Bearer cached-token', + }); + } finally { + await runtime.close(); + } + }); }); describe('stdio transport environment', () => { diff --git a/tests/serve.test.ts b/tests/serve.test.ts index 03acf20..e1c5b7b 100644 --- a/tests/serve.test.ts +++ b/tests/serve.test.ts @@ -200,6 +200,7 @@ describe('mcporter serve bridge', () => { server: 'alpha', includeSchema: true, autoAuthorize: true, + allowCachedAuth: true, }); expect(baseRuntime.listTools).not.toHaveBeenCalled();