Merge pull request #182 from bradhallett/fix/daemon-allowCachedAuth

fix(daemon): pass allowCachedAuth to runtime for OAuth token reuse
This commit is contained in:
Peter Steinberger 2026-05-20 17:10:33 +01:00 committed by GitHub
commit ccfaa2f4f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 201 additions and 32 deletions

View File

@ -12,7 +12,7 @@ const runtimeCache = new WeakMap<Runtime, Map<string, Promise<ToolMetadata[]>>>(
function cacheKey(serverName: string, options: LoadToolMetadataOptions): string {
const includeSchema = options.includeSchema !== false;
const autoAuthorize = options.autoAuthorize !== false;
const allowCachedAuth = options.allowCachedAuth === true;
const allowCachedAuth = options.allowCachedAuth !== false;
return `${serverName}::schema:${includeSchema ? '1' : '0'}::auth:${autoAuthorize ? '1' : '0'}::cached-auth:${allowCachedAuth ? '1' : '0'}`;
}
@ -33,10 +33,11 @@ export async function loadToolMetadata(
}
const includeSchema = options.includeSchema !== false;
const autoAuthorize = options.autoAuthorize !== false;
const listOptions: ListToolsOptions =
options.allowCachedAuth === undefined
? { includeSchema, autoAuthorize }
: { includeSchema, autoAuthorize, allowCachedAuth: options.allowCachedAuth };
const listOptions: ListToolsOptions = {
includeSchema,
autoAuthorize,
allowCachedAuth: options.allowCachedAuth ?? true,
};
const promise = runtime
.listTools(serverName, listOptions)
.then((tools) => tools.map((tool) => buildToolMetadata(tool)))

View File

@ -327,6 +327,7 @@ async function processRequest(
const result = await runtime.listTools(params.server, {
includeSchema: params.includeSchema,
autoAuthorize: params.autoAuthorize,
allowCachedAuth: params.allowCachedAuth ?? true,
});
markActivity(params.server, activity);
if (loggable) {

View File

@ -34,6 +34,7 @@ export interface ListToolsParams {
readonly server: string;
readonly includeSchema?: boolean;
readonly autoAuthorize?: boolean;
readonly allowCachedAuth?: boolean;
}
export interface ListResourcesParams {

View File

@ -61,6 +61,7 @@ class KeepAliveRuntime implements Runtime {
server,
includeSchema: options?.includeSchema,
autoAuthorize: options?.autoAuthorize,
allowCachedAuth: options?.allowCachedAuth ?? true,
})
)) as Awaited<ReturnType<Runtime['listTools']>>;
}

View File

@ -103,7 +103,13 @@ export async function callOnce(params: {
class McpRuntime implements Runtime {
private readonly definitions: Map<string, ServerDefinition>;
private readonly clients = new Map<string, Promise<ClientContext>>();
private readonly clients = new Map<
string,
{
readonly promise: Promise<ClientContext>;
readonly allowCachedAuth: boolean | undefined;
}
>();
private readonly logger: RuntimeLogger;
private readonly clientInfo: { name: string; version: string };
private readonly oauthTimeoutMs?: number;
@ -150,12 +156,12 @@ class McpRuntime implements Runtime {
}
async getInstructions(server: string): Promise<string | undefined> {
const contextPromise = this.clients.get(server.trim());
if (!contextPromise) {
const cached = this.clients.get(server.trim());
if (!cached) {
return undefined;
}
try {
const context = await contextPromise;
const context = await cached.promise;
const instructions =
typeof context.client.getInstructions === 'function' ? context.client.getInstructions() : undefined;
if (typeof instructions !== 'string') {
@ -175,7 +181,7 @@ class McpRuntime implements Runtime {
const context = await this.connect(server, {
maxOAuthAttempts: autoAuthorize ? undefined : 0,
skipCache: !autoAuthorize,
allowCachedAuth: options.allowCachedAuth,
allowCachedAuth: options.allowCachedAuth ?? true,
oauthSessionOptions: options.oauthSessionOptions,
});
try {
@ -218,7 +224,9 @@ class McpRuntime implements Runtime {
);
}
try {
const { client } = await this.connect(server);
const { client } = await this.connect(server, {
allowCachedAuth: true,
});
const params: CallToolRequest['params'] = {
name: toolName,
arguments: options.args ?? {},
@ -276,7 +284,10 @@ class McpRuntime implements Runtime {
if (useCache) {
const existing = this.clients.get(normalized);
if (existing) {
return existing;
if (existing.allowCachedAuth === options.allowCachedAuth || options.allowCachedAuth === undefined) {
return existing.promise;
}
await this.close(normalized).catch(() => {});
}
}
@ -294,7 +305,7 @@ class McpRuntime implements Runtime {
});
if (useCache) {
this.clients.set(normalized, connection);
this.clients.set(normalized, { promise: connection, allowCachedAuth: options.allowCachedAuth });
try {
return await connection;
} catch (error) {
@ -310,10 +321,11 @@ class McpRuntime implements Runtime {
async close(server?: string): Promise<void> {
if (server) {
const normalized = server.trim();
const context = await this.clients.get(normalized);
if (!context) {
const cached = this.clients.get(normalized);
if (!cached) {
return;
}
const context = await cached.promise;
await context.client.close().catch(() => {});
await closeTransportAndWait(this.logger, context.transport).catch(() => {});
await context.oauthSession?.close().catch(() => {});
@ -321,9 +333,9 @@ class McpRuntime implements Runtime {
return;
}
for (const [name, promise] of this.clients.entries()) {
for (const [name, cached] of this.clients.entries()) {
try {
const context = await promise;
const context = await cached.promise;
await context.client.close().catch(() => {});
await closeTransportAndWait(this.logger, context.transport).catch(() => {});
await context.oauthSession?.close().catch(() => {});

View File

@ -82,7 +82,11 @@ describe('CLI call execution behavior', () => {
},
})
);
expect(listTools).toHaveBeenCalledWith('slack', { autoAuthorize: true, includeSchema: true });
expect(listTools).toHaveBeenCalledWith('slack', {
autoAuthorize: true,
includeSchema: true,
allowCachedAuth: true,
});
logSpy.mockRestore();
});
@ -117,7 +121,11 @@ describe('CLI call execution behavior', () => {
},
})
);
expect(listTools).toHaveBeenCalledWith('email', { autoAuthorize: true, includeSchema: true });
expect(listTools).toHaveBeenCalledWith('email', {
autoAuthorize: true,
includeSchema: true,
allowCachedAuth: true,
});
logSpy.mockRestore();
});
@ -326,7 +334,11 @@ describe('CLI call execution behavior', () => {
expect(callTool).toHaveBeenCalledTimes(2);
expect(callTool).toHaveBeenNthCalledWith(1, 'linear', 'listIssues', expect.objectContaining({ args: {} }));
expect(callTool).toHaveBeenNthCalledWith(2, 'linear', 'list_issues', expect.objectContaining({ args: {} }));
expect(listTools).toHaveBeenCalledWith('linear', { autoAuthorize: true, includeSchema: false });
expect(listTools).toHaveBeenCalledWith('linear', {
autoAuthorize: true,
includeSchema: false,
allowCachedAuth: true,
});
logSpy.mockRestore();
});

View File

@ -1,21 +1,21 @@
import { describe, expect, it } from 'vitest';
import { describe, expect, it, vi } from 'vitest';
import type { ServerDefinition } from '../src/config.js';
import { __testProcessRequest } from '../src/daemon/host.js';
import type { DaemonRequest } from '../src/daemon/protocol.js';
import type { Runtime } from '../src/runtime.js';
describe('daemon host request handling', () => {
it('reuses pre-parsed requests without reparsing payloads', async () => {
const metadata = {
configPath: '/tmp/config.json',
configLayers: [],
configMtimeMs: Date.now(),
socketPath: '/tmp/socket',
startedAt: Date.now(),
logPath: null,
};
const logContext = { enabled: false, logAllServers: false, servers: new Set<string>() };
const metadata = {
configPath: '/tmp/config.json',
configLayers: [],
configMtimeMs: Date.now(),
socketPath: '/tmp/socket',
startedAt: Date.now(),
logPath: null,
};
const logContext = { enabled: false, logAllServers: false, servers: new Set<string>() };
it('reuses pre-parsed requests without reparsing payloads', async () => {
const parsedRequest: DaemonRequest = { id: '1', method: 'status', params: {} };
const result = await __testProcessRequest(
'!!!invalid-json!!!',
@ -30,4 +30,69 @@ describe('daemon host request handling', () => {
expect(result.response.ok).toBe(true);
expect(result.shouldShutdown).toBe(false);
});
it('defaults daemon callTool and listTools requests to cached auth', async () => {
const runtime = createRuntimeDouble();
const managedServers = createManagedServers();
await __testProcessRequest('', runtime as unknown as Runtime, managedServers, new Map(), metadata, logContext, {
id: 'call',
method: 'callTool',
params: { server: 'oauth', tool: 'ping' },
});
expect(runtime.callTool).toHaveBeenCalledWith('oauth', 'ping', {
args: {},
timeoutMs: undefined,
});
await __testProcessRequest('', runtime as unknown as Runtime, managedServers, new Map(), metadata, logContext, {
id: 'list',
method: 'listTools',
params: { server: 'oauth', includeSchema: true },
});
expect(runtime.listTools).toHaveBeenCalledWith('oauth', {
includeSchema: true,
autoAuthorize: undefined,
allowCachedAuth: true,
});
});
it('preserves explicit listTools cached-auth opt out on daemon requests', async () => {
const runtime = createRuntimeDouble();
const managedServers = createManagedServers();
await __testProcessRequest('', runtime as unknown as Runtime, managedServers, new Map(), metadata, logContext, {
id: 'list',
method: 'listTools',
params: { server: 'oauth', allowCachedAuth: false },
});
expect(runtime.listTools).toHaveBeenCalledWith('oauth', {
includeSchema: undefined,
autoAuthorize: undefined,
allowCachedAuth: false,
});
});
});
function createRuntimeDouble(): Pick<Runtime, 'callTool' | 'listTools'> {
return {
callTool: vi.fn().mockResolvedValue({ ok: true }),
listTools: vi.fn().mockResolvedValue([]),
};
}
function createManagedServers(): Map<string, ServerDefinition> {
return new Map([
[
'oauth',
{
name: 'oauth',
command: { kind: 'http', url: new URL('https://oauth.example.com/mcp') },
lifecycle: { mode: 'keep-alive' },
},
],
]);
}

View File

@ -105,7 +105,20 @@ describe('createKeepAliveRuntime', () => {
});
await keepAliveRuntime.listTools('alpha', { includeSchema: true });
expect(daemon.listTools).toHaveBeenCalledWith({ server: 'alpha', includeSchema: true, autoAuthorize: undefined });
expect(daemon.listTools).toHaveBeenCalledWith({
server: 'alpha',
includeSchema: true,
autoAuthorize: undefined,
allowCachedAuth: true,
});
await keepAliveRuntime.listTools('alpha', { allowCachedAuth: false });
expect(daemon.listTools).toHaveBeenLastCalledWith({
server: 'alpha',
includeSchema: undefined,
autoAuthorize: undefined,
allowCachedAuth: false,
});
await keepAliveRuntime.listResources('alpha', { cursor: '1' });
expect(daemon.listResources).toHaveBeenCalledWith({ server: 'alpha', params: { cursor: '1' } });

View File

@ -5,6 +5,7 @@ const mocks = vi.hoisted(() => {
const listToolsMock = vi.fn();
const callToolMock = vi.fn();
const listResourcesMock = vi.fn();
const readCachedAccessTokenMock = vi.fn();
const clientInstances: unknown[] = [];
const streamableInstances: unknown[] = [];
const stdioInstances: unknown[] = [];
@ -71,6 +72,7 @@ const mocks = vi.hoisted(() => {
listToolsMock,
callToolMock,
listResourcesMock,
readCachedAccessTokenMock,
clientInstances,
streamableInstances,
stdioInstances,
@ -102,6 +104,10 @@ vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({
UnauthorizedError: mocks.MockUnauthorizedError,
}));
vi.mock('../src/oauth-persistence.js', () => ({
readCachedAccessToken: mocks.readCachedAccessTokenMock,
}));
import { createRuntime } from '../src/runtime.js';
describe('mcporter composability', () => {
@ -110,6 +116,7 @@ describe('mcporter composability', () => {
mocks.listToolsMock.mockReset();
mocks.callToolMock.mockReset();
mocks.listResourcesMock.mockReset();
mocks.readCachedAccessTokenMock.mockReset();
mocks.clientInstances.length = 0;
mocks.streamableInstances.length = 0;
mocks.stdioInstances.length = 0;
@ -117,6 +124,7 @@ describe('mcporter composability', () => {
mocks.listToolsMock.mockResolvedValue({ tools: [] });
mocks.callToolMock.mockResolvedValue({ ok: true });
mocks.listResourcesMock.mockResolvedValue({ resources: [] });
mocks.readCachedAccessTokenMock.mockResolvedValue(undefined);
});
afterEach(() => {
@ -237,6 +245,60 @@ describe('mcporter composability', () => {
expect(instance?.options?.env?.MCPORTER_STDIO_TEST).toBe('from-config');
expect(instance?.options?.env?.EXTRA).toBe('42');
});
it('applies cached auth for callTool connections', async () => {
mocks.readCachedAccessTokenMock.mockResolvedValue('cached-token');
const runtime = await createRuntime({
servers: [
{
name: 'oauth',
command: { kind: 'http' as const, url: new URL('https://oauth.example.com/mcp') },
},
],
});
try {
await runtime.callTool('oauth', 'ping');
expect(mocks.readCachedAccessTokenMock).toHaveBeenCalledOnce();
const streamableTransport = mocks.streamableInstances[0] as {
options?: { requestInit?: { headers?: Record<string, string> } };
};
expect(streamableTransport.options?.requestInit?.headers).toEqual({
Authorization: 'Bearer cached-token',
});
} finally {
await runtime.close();
}
});
it('reconnects when callTool needs cached auth after an uncached connection', async () => {
const runtime = await createRuntime({
servers: [
{
name: 'oauth',
command: { kind: 'http' as const, url: new URL('https://oauth.example.com/mcp') },
},
],
});
try {
await runtime.listTools('oauth', { allowCachedAuth: false });
expect(mocks.streamableInstances).toHaveLength(1);
mocks.readCachedAccessTokenMock.mockResolvedValue('cached-token');
await runtime.callTool('oauth', 'ping');
expect(mocks.streamableInstances).toHaveLength(2);
const streamableTransport = mocks.streamableInstances[1] as {
options?: { requestInit?: { headers?: Record<string, string> } };
};
expect(streamableTransport.options?.requestInit?.headers).toEqual({
Authorization: 'Bearer cached-token',
});
} finally {
await runtime.close();
}
});
});
describe('stdio transport environment', () => {

View File

@ -200,6 +200,7 @@ describe('mcporter serve bridge', () => {
server: 'alpha',
includeSchema: true,
autoAuthorize: true,
allowCachedAuth: true,
});
expect(baseRuntime.listTools).not.toHaveBeenCalled();