fix: use cached auth for daemon OAuth calls
This commit is contained in:
parent
1e6ce66d22
commit
86e19f4413
@ -12,7 +12,7 @@ const runtimeCache = new WeakMap<Runtime, Map<string, Promise<ToolMetadata[]>>>(
|
||||
function cacheKey(serverName: string, options: LoadToolMetadataOptions): string {
|
||||
const includeSchema = options.includeSchema !== false;
|
||||
const autoAuthorize = options.autoAuthorize !== false;
|
||||
const allowCachedAuth = options.allowCachedAuth === true;
|
||||
const allowCachedAuth = options.allowCachedAuth !== false;
|
||||
return `${serverName}::schema:${includeSchema ? '1' : '0'}::auth:${autoAuthorize ? '1' : '0'}::cached-auth:${allowCachedAuth ? '1' : '0'}`;
|
||||
}
|
||||
|
||||
@ -33,10 +33,11 @@ export async function loadToolMetadata(
|
||||
}
|
||||
const includeSchema = options.includeSchema !== false;
|
||||
const autoAuthorize = options.autoAuthorize !== false;
|
||||
const listOptions: ListToolsOptions =
|
||||
options.allowCachedAuth === undefined
|
||||
? { includeSchema, autoAuthorize }
|
||||
: { includeSchema, autoAuthorize, allowCachedAuth: options.allowCachedAuth };
|
||||
const listOptions: ListToolsOptions = {
|
||||
includeSchema,
|
||||
autoAuthorize,
|
||||
allowCachedAuth: options.allowCachedAuth ?? true,
|
||||
};
|
||||
const promise = runtime
|
||||
.listTools(serverName, listOptions)
|
||||
.then((tools) => tools.map((tool) => buildToolMetadata(tool)))
|
||||
|
||||
@ -302,7 +302,6 @@ async function processRequest(
|
||||
const result = await runtime.callTool(params.server, params.tool, {
|
||||
args: params.args ?? {},
|
||||
timeoutMs: params.timeoutMs,
|
||||
allowCachedAuth: params.allowCachedAuth ?? true,
|
||||
});
|
||||
markActivity(params.server, activity);
|
||||
if (loggable) {
|
||||
@ -328,7 +327,7 @@ async function processRequest(
|
||||
const result = await runtime.listTools(params.server, {
|
||||
includeSchema: params.includeSchema,
|
||||
autoAuthorize: params.autoAuthorize,
|
||||
allowCachedAuth: true,
|
||||
allowCachedAuth: params.allowCachedAuth ?? true,
|
||||
});
|
||||
markActivity(params.server, activity);
|
||||
if (loggable) {
|
||||
|
||||
@ -28,13 +28,13 @@ export interface CallToolParams {
|
||||
readonly tool: string;
|
||||
readonly args?: Record<string, unknown>;
|
||||
readonly timeoutMs?: number;
|
||||
readonly allowCachedAuth?: boolean;
|
||||
}
|
||||
|
||||
export interface ListToolsParams {
|
||||
readonly server: string;
|
||||
readonly includeSchema?: boolean;
|
||||
readonly autoAuthorize?: boolean;
|
||||
readonly allowCachedAuth?: boolean;
|
||||
}
|
||||
|
||||
export interface ListResourcesParams {
|
||||
|
||||
@ -61,6 +61,7 @@ class KeepAliveRuntime implements Runtime {
|
||||
server,
|
||||
includeSchema: options?.includeSchema,
|
||||
autoAuthorize: options?.autoAuthorize,
|
||||
allowCachedAuth: options?.allowCachedAuth ?? true,
|
||||
})
|
||||
)) as Awaited<ReturnType<Runtime['listTools']>>;
|
||||
}
|
||||
@ -75,7 +76,6 @@ class KeepAliveRuntime implements Runtime {
|
||||
tool: toolName,
|
||||
args: options?.args,
|
||||
timeoutMs: options?.timeoutMs,
|
||||
allowCachedAuth: options?.allowCachedAuth ?? true,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
@ -103,7 +103,13 @@ export async function callOnce(params: {
|
||||
|
||||
class McpRuntime implements Runtime {
|
||||
private readonly definitions: Map<string, ServerDefinition>;
|
||||
private readonly clients = new Map<string, Promise<ClientContext>>();
|
||||
private readonly clients = new Map<
|
||||
string,
|
||||
{
|
||||
readonly promise: Promise<ClientContext>;
|
||||
readonly allowCachedAuth: boolean | undefined;
|
||||
}
|
||||
>();
|
||||
private readonly logger: RuntimeLogger;
|
||||
private readonly clientInfo: { name: string; version: string };
|
||||
private readonly oauthTimeoutMs?: number;
|
||||
@ -150,12 +156,12 @@ class McpRuntime implements Runtime {
|
||||
}
|
||||
|
||||
async getInstructions(server: string): Promise<string | undefined> {
|
||||
const contextPromise = this.clients.get(server.trim());
|
||||
if (!contextPromise) {
|
||||
const cached = this.clients.get(server.trim());
|
||||
if (!cached) {
|
||||
return undefined;
|
||||
}
|
||||
try {
|
||||
const context = await contextPromise;
|
||||
const context = await cached.promise;
|
||||
const instructions =
|
||||
typeof context.client.getInstructions === 'function' ? context.client.getInstructions() : undefined;
|
||||
if (typeof instructions !== 'string') {
|
||||
@ -175,7 +181,7 @@ class McpRuntime implements Runtime {
|
||||
const context = await this.connect(server, {
|
||||
maxOAuthAttempts: autoAuthorize ? undefined : 0,
|
||||
skipCache: !autoAuthorize,
|
||||
allowCachedAuth: options.allowCachedAuth,
|
||||
allowCachedAuth: options.allowCachedAuth ?? true,
|
||||
oauthSessionOptions: options.oauthSessionOptions,
|
||||
});
|
||||
try {
|
||||
@ -218,7 +224,9 @@ class McpRuntime implements Runtime {
|
||||
);
|
||||
}
|
||||
try {
|
||||
const { client } = await this.connect(server);
|
||||
const { client } = await this.connect(server, {
|
||||
allowCachedAuth: true,
|
||||
});
|
||||
const params: CallToolRequest['params'] = {
|
||||
name: toolName,
|
||||
arguments: options.args ?? {},
|
||||
@ -276,7 +284,10 @@ class McpRuntime implements Runtime {
|
||||
if (useCache) {
|
||||
const existing = this.clients.get(normalized);
|
||||
if (existing) {
|
||||
return existing;
|
||||
if (existing.allowCachedAuth === options.allowCachedAuth || options.allowCachedAuth === undefined) {
|
||||
return existing.promise;
|
||||
}
|
||||
await this.close(normalized).catch(() => {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -294,7 +305,7 @@ class McpRuntime implements Runtime {
|
||||
});
|
||||
|
||||
if (useCache) {
|
||||
this.clients.set(normalized, connection);
|
||||
this.clients.set(normalized, { promise: connection, allowCachedAuth: options.allowCachedAuth });
|
||||
try {
|
||||
return await connection;
|
||||
} catch (error) {
|
||||
@ -310,10 +321,11 @@ class McpRuntime implements Runtime {
|
||||
async close(server?: string): Promise<void> {
|
||||
if (server) {
|
||||
const normalized = server.trim();
|
||||
const context = await this.clients.get(normalized);
|
||||
if (!context) {
|
||||
const cached = this.clients.get(normalized);
|
||||
if (!cached) {
|
||||
return;
|
||||
}
|
||||
const context = await cached.promise;
|
||||
await context.client.close().catch(() => {});
|
||||
await closeTransportAndWait(this.logger, context.transport).catch(() => {});
|
||||
await context.oauthSession?.close().catch(() => {});
|
||||
@ -321,9 +333,9 @@ class McpRuntime implements Runtime {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const [name, promise] of this.clients.entries()) {
|
||||
for (const [name, cached] of this.clients.entries()) {
|
||||
try {
|
||||
const context = await promise;
|
||||
const context = await cached.promise;
|
||||
await context.client.close().catch(() => {});
|
||||
await closeTransportAndWait(this.logger, context.transport).catch(() => {});
|
||||
await context.oauthSession?.close().catch(() => {});
|
||||
|
||||
@ -82,7 +82,11 @@ describe('CLI call execution behavior', () => {
|
||||
},
|
||||
})
|
||||
);
|
||||
expect(listTools).toHaveBeenCalledWith('slack', { autoAuthorize: true, includeSchema: true });
|
||||
expect(listTools).toHaveBeenCalledWith('slack', {
|
||||
autoAuthorize: true,
|
||||
includeSchema: true,
|
||||
allowCachedAuth: true,
|
||||
});
|
||||
logSpy.mockRestore();
|
||||
});
|
||||
|
||||
@ -117,7 +121,11 @@ describe('CLI call execution behavior', () => {
|
||||
},
|
||||
})
|
||||
);
|
||||
expect(listTools).toHaveBeenCalledWith('email', { autoAuthorize: true, includeSchema: true });
|
||||
expect(listTools).toHaveBeenCalledWith('email', {
|
||||
autoAuthorize: true,
|
||||
includeSchema: true,
|
||||
allowCachedAuth: true,
|
||||
});
|
||||
logSpy.mockRestore();
|
||||
});
|
||||
|
||||
@ -326,7 +334,11 @@ describe('CLI call execution behavior', () => {
|
||||
expect(callTool).toHaveBeenCalledTimes(2);
|
||||
expect(callTool).toHaveBeenNthCalledWith(1, 'linear', 'listIssues', expect.objectContaining({ args: {} }));
|
||||
expect(callTool).toHaveBeenNthCalledWith(2, 'linear', 'list_issues', expect.objectContaining({ args: {} }));
|
||||
expect(listTools).toHaveBeenCalledWith('linear', { autoAuthorize: true, includeSchema: false });
|
||||
expect(listTools).toHaveBeenCalledWith('linear', {
|
||||
autoAuthorize: true,
|
||||
includeSchema: false,
|
||||
allowCachedAuth: true,
|
||||
});
|
||||
|
||||
logSpy.mockRestore();
|
||||
});
|
||||
|
||||
@ -1,21 +1,21 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import type { ServerDefinition } from '../src/config.js';
|
||||
import { __testProcessRequest } from '../src/daemon/host.js';
|
||||
import type { DaemonRequest } from '../src/daemon/protocol.js';
|
||||
import type { Runtime } from '../src/runtime.js';
|
||||
|
||||
describe('daemon host request handling', () => {
|
||||
it('reuses pre-parsed requests without reparsing payloads', async () => {
|
||||
const metadata = {
|
||||
configPath: '/tmp/config.json',
|
||||
configLayers: [],
|
||||
configMtimeMs: Date.now(),
|
||||
socketPath: '/tmp/socket',
|
||||
startedAt: Date.now(),
|
||||
logPath: null,
|
||||
};
|
||||
const logContext = { enabled: false, logAllServers: false, servers: new Set<string>() };
|
||||
const metadata = {
|
||||
configPath: '/tmp/config.json',
|
||||
configLayers: [],
|
||||
configMtimeMs: Date.now(),
|
||||
socketPath: '/tmp/socket',
|
||||
startedAt: Date.now(),
|
||||
logPath: null,
|
||||
};
|
||||
const logContext = { enabled: false, logAllServers: false, servers: new Set<string>() };
|
||||
|
||||
it('reuses pre-parsed requests without reparsing payloads', async () => {
|
||||
const parsedRequest: DaemonRequest = { id: '1', method: 'status', params: {} };
|
||||
const result = await __testProcessRequest(
|
||||
'!!!invalid-json!!!',
|
||||
@ -30,4 +30,69 @@ describe('daemon host request handling', () => {
|
||||
expect(result.response.ok).toBe(true);
|
||||
expect(result.shouldShutdown).toBe(false);
|
||||
});
|
||||
|
||||
it('defaults daemon callTool and listTools requests to cached auth', async () => {
|
||||
const runtime = createRuntimeDouble();
|
||||
const managedServers = createManagedServers();
|
||||
|
||||
await __testProcessRequest('', runtime as unknown as Runtime, managedServers, new Map(), metadata, logContext, {
|
||||
id: 'call',
|
||||
method: 'callTool',
|
||||
params: { server: 'oauth', tool: 'ping' },
|
||||
});
|
||||
|
||||
expect(runtime.callTool).toHaveBeenCalledWith('oauth', 'ping', {
|
||||
args: {},
|
||||
timeoutMs: undefined,
|
||||
});
|
||||
|
||||
await __testProcessRequest('', runtime as unknown as Runtime, managedServers, new Map(), metadata, logContext, {
|
||||
id: 'list',
|
||||
method: 'listTools',
|
||||
params: { server: 'oauth', includeSchema: true },
|
||||
});
|
||||
|
||||
expect(runtime.listTools).toHaveBeenCalledWith('oauth', {
|
||||
includeSchema: true,
|
||||
autoAuthorize: undefined,
|
||||
allowCachedAuth: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('preserves explicit listTools cached-auth opt out on daemon requests', async () => {
|
||||
const runtime = createRuntimeDouble();
|
||||
const managedServers = createManagedServers();
|
||||
|
||||
await __testProcessRequest('', runtime as unknown as Runtime, managedServers, new Map(), metadata, logContext, {
|
||||
id: 'list',
|
||||
method: 'listTools',
|
||||
params: { server: 'oauth', allowCachedAuth: false },
|
||||
});
|
||||
|
||||
expect(runtime.listTools).toHaveBeenCalledWith('oauth', {
|
||||
includeSchema: undefined,
|
||||
autoAuthorize: undefined,
|
||||
allowCachedAuth: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
function createRuntimeDouble(): Pick<Runtime, 'callTool' | 'listTools'> {
|
||||
return {
|
||||
callTool: vi.fn().mockResolvedValue({ ok: true }),
|
||||
listTools: vi.fn().mockResolvedValue([]),
|
||||
};
|
||||
}
|
||||
|
||||
function createManagedServers(): Map<string, ServerDefinition> {
|
||||
return new Map([
|
||||
[
|
||||
'oauth',
|
||||
{
|
||||
name: 'oauth',
|
||||
command: { kind: 'http', url: new URL('https://oauth.example.com/mcp') },
|
||||
lifecycle: { mode: 'keep-alive' },
|
||||
},
|
||||
],
|
||||
]);
|
||||
}
|
||||
|
||||
@ -105,7 +105,20 @@ describe('createKeepAliveRuntime', () => {
|
||||
});
|
||||
|
||||
await keepAliveRuntime.listTools('alpha', { includeSchema: true });
|
||||
expect(daemon.listTools).toHaveBeenCalledWith({ server: 'alpha', includeSchema: true, autoAuthorize: undefined });
|
||||
expect(daemon.listTools).toHaveBeenCalledWith({
|
||||
server: 'alpha',
|
||||
includeSchema: true,
|
||||
autoAuthorize: undefined,
|
||||
allowCachedAuth: true,
|
||||
});
|
||||
|
||||
await keepAliveRuntime.listTools('alpha', { allowCachedAuth: false });
|
||||
expect(daemon.listTools).toHaveBeenLastCalledWith({
|
||||
server: 'alpha',
|
||||
includeSchema: undefined,
|
||||
autoAuthorize: undefined,
|
||||
allowCachedAuth: false,
|
||||
});
|
||||
|
||||
await keepAliveRuntime.listResources('alpha', { cursor: '1' });
|
||||
expect(daemon.listResources).toHaveBeenCalledWith({ server: 'alpha', params: { cursor: '1' } });
|
||||
|
||||
@ -5,6 +5,7 @@ const mocks = vi.hoisted(() => {
|
||||
const listToolsMock = vi.fn();
|
||||
const callToolMock = vi.fn();
|
||||
const listResourcesMock = vi.fn();
|
||||
const readCachedAccessTokenMock = vi.fn();
|
||||
const clientInstances: unknown[] = [];
|
||||
const streamableInstances: unknown[] = [];
|
||||
const stdioInstances: unknown[] = [];
|
||||
@ -71,6 +72,7 @@ const mocks = vi.hoisted(() => {
|
||||
listToolsMock,
|
||||
callToolMock,
|
||||
listResourcesMock,
|
||||
readCachedAccessTokenMock,
|
||||
clientInstances,
|
||||
streamableInstances,
|
||||
stdioInstances,
|
||||
@ -102,6 +104,10 @@ vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({
|
||||
UnauthorizedError: mocks.MockUnauthorizedError,
|
||||
}));
|
||||
|
||||
vi.mock('../src/oauth-persistence.js', () => ({
|
||||
readCachedAccessToken: mocks.readCachedAccessTokenMock,
|
||||
}));
|
||||
|
||||
import { createRuntime } from '../src/runtime.js';
|
||||
|
||||
describe('mcporter composability', () => {
|
||||
@ -110,6 +116,7 @@ describe('mcporter composability', () => {
|
||||
mocks.listToolsMock.mockReset();
|
||||
mocks.callToolMock.mockReset();
|
||||
mocks.listResourcesMock.mockReset();
|
||||
mocks.readCachedAccessTokenMock.mockReset();
|
||||
mocks.clientInstances.length = 0;
|
||||
mocks.streamableInstances.length = 0;
|
||||
mocks.stdioInstances.length = 0;
|
||||
@ -117,6 +124,7 @@ describe('mcporter composability', () => {
|
||||
mocks.listToolsMock.mockResolvedValue({ tools: [] });
|
||||
mocks.callToolMock.mockResolvedValue({ ok: true });
|
||||
mocks.listResourcesMock.mockResolvedValue({ resources: [] });
|
||||
mocks.readCachedAccessTokenMock.mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@ -237,6 +245,60 @@ describe('mcporter composability', () => {
|
||||
expect(instance?.options?.env?.MCPORTER_STDIO_TEST).toBe('from-config');
|
||||
expect(instance?.options?.env?.EXTRA).toBe('42');
|
||||
});
|
||||
|
||||
it('applies cached auth for callTool connections', async () => {
|
||||
mocks.readCachedAccessTokenMock.mockResolvedValue('cached-token');
|
||||
const runtime = await createRuntime({
|
||||
servers: [
|
||||
{
|
||||
name: 'oauth',
|
||||
command: { kind: 'http' as const, url: new URL('https://oauth.example.com/mcp') },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
try {
|
||||
await runtime.callTool('oauth', 'ping');
|
||||
expect(mocks.readCachedAccessTokenMock).toHaveBeenCalledOnce();
|
||||
const streamableTransport = mocks.streamableInstances[0] as {
|
||||
options?: { requestInit?: { headers?: Record<string, string> } };
|
||||
};
|
||||
expect(streamableTransport.options?.requestInit?.headers).toEqual({
|
||||
Authorization: 'Bearer cached-token',
|
||||
});
|
||||
} finally {
|
||||
await runtime.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('reconnects when callTool needs cached auth after an uncached connection', async () => {
|
||||
const runtime = await createRuntime({
|
||||
servers: [
|
||||
{
|
||||
name: 'oauth',
|
||||
command: { kind: 'http' as const, url: new URL('https://oauth.example.com/mcp') },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
try {
|
||||
await runtime.listTools('oauth', { allowCachedAuth: false });
|
||||
expect(mocks.streamableInstances).toHaveLength(1);
|
||||
|
||||
mocks.readCachedAccessTokenMock.mockResolvedValue('cached-token');
|
||||
await runtime.callTool('oauth', 'ping');
|
||||
|
||||
expect(mocks.streamableInstances).toHaveLength(2);
|
||||
const streamableTransport = mocks.streamableInstances[1] as {
|
||||
options?: { requestInit?: { headers?: Record<string, string> } };
|
||||
};
|
||||
expect(streamableTransport.options?.requestInit?.headers).toEqual({
|
||||
Authorization: 'Bearer cached-token',
|
||||
});
|
||||
} finally {
|
||||
await runtime.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('stdio transport environment', () => {
|
||||
|
||||
@ -200,6 +200,7 @@ describe('mcporter serve bridge', () => {
|
||||
server: 'alpha',
|
||||
includeSchema: true,
|
||||
autoAuthorize: true,
|
||||
allowCachedAuth: true,
|
||||
});
|
||||
expect(baseRuntime.listTools).not.toHaveBeenCalled();
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user