feat: add oauth runtime support and integration tests
This commit is contained in:
parent
a751cadff0
commit
86003eb431
24
.github/workflows/ci.yml
vendored
Normal file
24
.github/workflows/ci.yml
vendored
Normal file
@ -0,0 +1,24 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: pnpm/action-setup@v3
|
||||
with:
|
||||
version: 8
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
cache: 'pnpm'
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm check
|
||||
- run: pnpm build
|
||||
- run: pnpm test
|
||||
24
README.md
24
README.md
@ -51,17 +51,31 @@ const result = await callOnce({
|
||||
### CLI
|
||||
|
||||
```
|
||||
npx mcp-runtime list # show all configured servers
|
||||
npx mcp-runtime list vercel --schema # list tools + schemas for vercel server
|
||||
npx mcp-runtime list # show all configured servers
|
||||
npx mcp-runtime list vercel --schema # list tools + schemas for vercel server
|
||||
npx mcp-runtime call linear.searchIssues --args '{"query":"status:InProgress"}'
|
||||
npx mcp-runtime call signoz.query --tail-log # tail log output when provided
|
||||
```
|
||||
|
||||
Pass `--config <path>` or set `MCP_RUNTIME_CONFIG` to override the config location (defaults to `./config/mcp_servers.json`).
|
||||
|
||||
### OAuth Flow
|
||||
|
||||
For servers that advertise `"auth": "oauth"` in `mcp_servers.json`, `mcp-runtime` launches a local callback server and stores refreshed tokens under `~/.mcp-runtime/<serverName>/`. The first call opens your browser; once the provider redirects back to `http://127.0.0.1:<port>/callback`, the runtime exchanges the code automatically and reuses the session for future requests.
|
||||
|
||||
### Migrating from `pnpm mcp:*`
|
||||
|
||||
- Replace `pnpm mcp:list` with `npx mcp-runtime list`.
|
||||
- Replace `pnpm mcp:call server.tool key=value` with `npx mcp-runtime call server.tool key=value`.
|
||||
- Optionally add `--tail-log` to follow log files returned by tools.
|
||||
|
||||
See [`docs/migration.md`](docs/migration.md) for a full comparison, environment setup checklist, and CLI usage examples.
|
||||
|
||||
## Roadmap
|
||||
|
||||
- OAuth helper parity with the Python wrapper (token caching + browser dance)
|
||||
- Streaming log helpers (tailing tool output)
|
||||
- Type-safe code generation for frequently used tool schemas
|
||||
- Improve the OAuth flow UX (auto-open fallback hints, timeouts, explicit `mcp-runtime auth <server>` helper).
|
||||
- Tail support for structured streaming content (not only file paths).
|
||||
- Type-safe code generation for frequently used tool schemas.
|
||||
- Release & versioning automation (CI, npm publish guardrails).
|
||||
|
||||
See [`docs/spec.md`](docs/spec.md) for the high-level implementation plan and open questions.
|
||||
|
||||
76
docs/migration.md
Normal file
76
docs/migration.md
Normal file
@ -0,0 +1,76 @@
|
||||
---
|
||||
summary: 'How to migrate from pnpm mcp:* wrappers to the mcp-runtime package.'
|
||||
---
|
||||
|
||||
# Migration Guide
|
||||
|
||||
This guide walks through replacing the Python-based `pnpm mcp:*` helpers with the new TypeScript runtime and CLI.
|
||||
|
||||
## 1. Install
|
||||
|
||||
```bash
|
||||
pnpm add mcp-runtime
|
||||
# or
|
||||
yarn add mcp-runtime
|
||||
# or
|
||||
npm install mcp-runtime
|
||||
```
|
||||
|
||||
## 2. Update Scripts
|
||||
|
||||
- Replace `pnpm mcp:list` with `npx mcp-runtime list`.
|
||||
- Replace `pnpm mcp:call <server>.<tool> key=value` with `npx mcp-runtime call <server>.<tool> key=value`.
|
||||
- Add `--config <path>` if your configuration is not under `./config/mcp_servers.json`.
|
||||
- Append `--tail-log` to stream the last 20 lines of any log file returned by the tool.
|
||||
|
||||
## 3. OAuth Tokens
|
||||
|
||||
- Tokens are saved under `~/.mcp-runtime/<server>/` by default.
|
||||
- To force a fresh login, delete that directory and rerun the command; the CLI will relaunch the browser.
|
||||
- Custom `token_cache_dir` entries in `mcp_servers.json` continue to work as explicit overrides.
|
||||
|
||||
## 4. Programmatic Usage
|
||||
|
||||
```ts
|
||||
import { createRuntime } from "mcp-runtime";
|
||||
|
||||
const runtime = await createRuntime({ configPath: "./config/mcp_servers.json" });
|
||||
const tools = await runtime.listTools("chrome-devtools");
|
||||
await runtime.callTool("chrome-devtools", "take_screenshot", { args: { url: "https://x.com" } });
|
||||
await runtime.close();
|
||||
```
|
||||
|
||||
Prefer `createRuntime` for long-lived agents so connections and OAuth tokens can be reused.
|
||||
|
||||
## 5. Single Call Helper
|
||||
|
||||
```ts
|
||||
import { callOnce } from "mcp-runtime";
|
||||
|
||||
await callOnce({
|
||||
server: "firecrawl",
|
||||
toolName: "crawl",
|
||||
args: { url: "https://anthropic.com" },
|
||||
});
|
||||
```
|
||||
|
||||
Use `callOnce` for fire-and-forget invocations.
|
||||
|
||||
## 6. Environment Variables
|
||||
|
||||
- `LINEAR_API_KEY`, `FIRECRAWL_API_KEY`, and similar tokens are read exactly as before via `${VAR}` syntax.
|
||||
- `${VAR:-default}` continues to work; empty values are ignored.
|
||||
- `$env:VAR` placeholders resolve to raw OS environment variables.
|
||||
|
||||
## 7. Troubleshooting
|
||||
|
||||
| Symptom | Fix |
|
||||
| --- | --- |
|
||||
| Browser did not open | Copy the printed OAuth URL manually into a browser. |
|
||||
| Authorization hangs | Ensure the callback URL can bind to `127.0.0.1`; firewalls may block it. |
|
||||
| Tokens are stale | Delete `~/.mcp-runtime/<server>/tokens.json` and retry. |
|
||||
| Stdio command fails | Pass `--root` to point at the repo root so relative paths resolve. |
|
||||
|
||||
---
|
||||
|
||||
For deeper architectural notes and future work, see [`docs/spec.md`](./spec.md).
|
||||
@ -25,7 +25,8 @@
|
||||
"lint": "pnpm check",
|
||||
"test": "vitest run",
|
||||
"clean": "rimraf dist",
|
||||
"dev": "tsc -w -p tsconfig.build.json"
|
||||
"dev": "tsc -w -p tsconfig.build.json",
|
||||
"prepublishOnly": "pnpm check && pnpm test && pnpm build"
|
||||
},
|
||||
"dependencies": {
|
||||
"@modelcontextprotocol/sdk": "^1.10.1",
|
||||
@ -33,7 +34,9 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@biomejs/biome": "^2.3.3",
|
||||
"@types/express": "^4.17.21",
|
||||
"@types/node": "^22.7.4",
|
||||
"express": "^4.21.1",
|
||||
"rimraf": "^6.0.1",
|
||||
"typescript": "^5.6.3",
|
||||
"vitest": "^1.6.0"
|
||||
|
||||
66
src/cli.ts
66
src/cli.ts
@ -1,4 +1,5 @@
|
||||
#!/usr/bin/env node
|
||||
import fs from "node:fs";
|
||||
import { createRuntime } from "./runtime.js";
|
||||
|
||||
type FlagMap = Partial<Record<string, string>>;
|
||||
@ -125,13 +126,16 @@ async function handleCall(
|
||||
try {
|
||||
const decoded = JSON.parse(result);
|
||||
console.log(JSON.stringify(decoded, null, 2));
|
||||
tailLogIfRequested(decoded, parsed.tailLog ?? false);
|
||||
} catch {
|
||||
console.log(result);
|
||||
tailLogIfRequested(result, parsed.tailLog ?? false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(JSON.stringify(result, null, 2));
|
||||
tailLogIfRequested(result, parsed.tailLog ?? false);
|
||||
}
|
||||
|
||||
function extractListFlags(args: string[]): { schema: boolean } {
|
||||
@ -154,10 +158,11 @@ interface CallArgsParseResult {
|
||||
server?: string;
|
||||
tool?: string;
|
||||
args: Record<string, unknown>;
|
||||
tailLog?: boolean;
|
||||
}
|
||||
|
||||
function parseCallArguments(args: string[]): CallArgsParseResult {
|
||||
const result: CallArgsParseResult = { args: {} };
|
||||
const result: CallArgsParseResult = { args: {}, tailLog: false };
|
||||
let index = 0;
|
||||
while (index < args.length) {
|
||||
const token = args[index];
|
||||
@ -200,6 +205,11 @@ function parseCallArguments(args: string[]): CallArgsParseResult {
|
||||
args.splice(index, 2);
|
||||
continue;
|
||||
}
|
||||
if (token === "--tail-log") {
|
||||
result.tailLog = true;
|
||||
args.splice(index, 1);
|
||||
continue;
|
||||
}
|
||||
index += 1;
|
||||
}
|
||||
|
||||
@ -250,6 +260,51 @@ function indent(text: string, pad: string): string {
|
||||
.join("\n");
|
||||
}
|
||||
|
||||
function tailLogIfRequested(result: unknown, enabled: boolean): void {
|
||||
if (!enabled) {
|
||||
return;
|
||||
}
|
||||
const candidates: string[] = [];
|
||||
if (typeof result === "string") {
|
||||
const idx = result.indexOf(":");
|
||||
if (idx !== -1) {
|
||||
const candidate = result.slice(idx + 1).trim();
|
||||
if (candidate) {
|
||||
candidates.push(candidate);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (result && typeof result === "object") {
|
||||
const possibleKeys = ["logPath", "logFile", "logfile", "path"];
|
||||
for (const key of possibleKeys) {
|
||||
const value = (result as Record<string, unknown>)[key];
|
||||
if (typeof value === "string") {
|
||||
candidates.push(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const candidate of candidates) {
|
||||
if (!fs.existsSync(candidate)) {
|
||||
console.warn(`[warn] Log path not found: ${candidate}`);
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
const content = fs.readFileSync(candidate, "utf8");
|
||||
const lines = content.trimEnd().split(/\r?\n/);
|
||||
const tail = lines.slice(-20);
|
||||
console.log(`--- tail ${candidate} ---`);
|
||||
for (const line of tail) {
|
||||
console.log(line);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
`[warn] Failed to read log file ${candidate}: ${(error as Error).message}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function printHelp(message?: string): void {
|
||||
if (message) {
|
||||
console.error(message);
|
||||
@ -258,12 +313,13 @@ function printHelp(message?: string): void {
|
||||
console.error(`Usage: mcp-runtime <command> [options]
|
||||
|
||||
Commands:
|
||||
list [name] [--schema] List configured MCP servers (and tools for a server)
|
||||
call [selector] [flags] Call a tool (selector like server.tool)
|
||||
list [name] [--schema] List configured MCP servers (and tools for a server)
|
||||
call [selector] [flags] Call a tool (selector like server.tool)
|
||||
--tail-log Tail log output when the tool returns a log file path
|
||||
|
||||
Global flags:
|
||||
--config <path> Path to mcp_servers.json (defaults to ./config/mcp_servers.json)
|
||||
--root <path> Root directory for stdio command cwd
|
||||
--config <path> Path to mcp_servers.json (defaults to ./config/mcp_servers.json)
|
||||
--root <path> Root directory for stdio command cwd
|
||||
`);
|
||||
}
|
||||
|
||||
|
||||
328
src/oauth.ts
Normal file
328
src/oauth.ts
Normal file
@ -0,0 +1,328 @@
|
||||
import { spawn } from "node:child_process";
|
||||
import { randomUUID } from "node:crypto";
|
||||
import fs from "node:fs/promises";
|
||||
import http from "node:http";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { URL } from "node:url";
|
||||
import type { OAuthClientProvider } from "@modelcontextprotocol/sdk/client/auth.js";
|
||||
import type {
|
||||
OAuthClientInformationMixed,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
} from "@modelcontextprotocol/sdk/shared/auth.js";
|
||||
import type { ServerDefinition } from "./config.js";
|
||||
|
||||
const CALLBACK_HOST = "127.0.0.1";
|
||||
|
||||
interface Deferred<T> {
|
||||
promise: Promise<T>;
|
||||
resolve: (value: T) => void;
|
||||
reject: (reason?: unknown) => void;
|
||||
}
|
||||
|
||||
function createDeferred<T>(): Deferred<T> {
|
||||
let resolve!: (value: T) => void;
|
||||
let reject!: (reason?: unknown) => void;
|
||||
const promise = new Promise<T>((res, rej) => {
|
||||
resolve = res;
|
||||
reject = rej;
|
||||
});
|
||||
return { promise, resolve, reject };
|
||||
}
|
||||
|
||||
function openExternal(url: string) {
|
||||
const platform = process.platform;
|
||||
const stdio = "ignore";
|
||||
try {
|
||||
if (platform === "darwin") {
|
||||
const child = spawn("open", [url], { stdio, detached: true });
|
||||
child.unref();
|
||||
} else if (platform === "win32") {
|
||||
const child = spawn("cmd", ["/c", "start", '""', url], {
|
||||
stdio,
|
||||
detached: true,
|
||||
});
|
||||
child.unref();
|
||||
} else {
|
||||
const child = spawn("xdg-open", [url], { stdio, detached: true });
|
||||
child.unref();
|
||||
}
|
||||
} catch {
|
||||
// best-effort: fall back to printing URL
|
||||
}
|
||||
}
|
||||
|
||||
async function ensureDirectory(dir: string) {
|
||||
await fs.mkdir(dir, { recursive: true });
|
||||
}
|
||||
|
||||
async function readJsonFile<T>(filePath: string): Promise<T | undefined> {
|
||||
try {
|
||||
const raw = await fs.readFile(filePath, "utf8");
|
||||
return JSON.parse(raw) as T;
|
||||
} catch (error) {
|
||||
if ((error as NodeJS.ErrnoException).code === "ENOENT") {
|
||||
return undefined;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async function writeJsonFile(filePath: string, data: unknown) {
|
||||
await ensureDirectory(path.dirname(filePath));
|
||||
await fs.writeFile(filePath, JSON.stringify(data, null, 2), "utf8");
|
||||
}
|
||||
|
||||
class FileOAuthClientProvider implements OAuthClientProvider {
|
||||
private readonly tokenPath: string;
|
||||
private readonly clientInfoPath: string;
|
||||
private readonly codeVerifierPath: string;
|
||||
private readonly statePath: string;
|
||||
private readonly metadata: OAuthClientMetadata;
|
||||
private readonly logger: OAuthLogger;
|
||||
private redirectUrlValue: URL;
|
||||
private authorizationDeferred: Deferred<string> | null = null;
|
||||
private server?: http.Server;
|
||||
|
||||
private constructor(
|
||||
private readonly definition: ServerDefinition,
|
||||
tokenCacheDir: string,
|
||||
redirectUrl: URL,
|
||||
logger: OAuthLogger,
|
||||
) {
|
||||
this.tokenPath = path.join(tokenCacheDir, "tokens.json");
|
||||
this.clientInfoPath = path.join(tokenCacheDir, "client.json");
|
||||
this.codeVerifierPath = path.join(tokenCacheDir, "code_verifier.txt");
|
||||
this.statePath = path.join(tokenCacheDir, "state.txt");
|
||||
this.redirectUrlValue = redirectUrl;
|
||||
this.logger = logger;
|
||||
this.metadata = {
|
||||
client_name: definition.clientName ?? `mcp-runtime (${definition.name})`,
|
||||
redirect_uris: [this.redirectUrlValue.toString()],
|
||||
grant_types: ["authorization_code", "refresh_token"],
|
||||
response_types: ["code"],
|
||||
token_endpoint_auth_method: "none",
|
||||
scope: "mcp:tools",
|
||||
};
|
||||
}
|
||||
|
||||
static async create(
|
||||
definition: ServerDefinition,
|
||||
logger: OAuthLogger,
|
||||
): Promise<{
|
||||
provider: FileOAuthClientProvider;
|
||||
close: () => Promise<void>;
|
||||
}> {
|
||||
const tokenDir =
|
||||
definition.tokenCacheDir ??
|
||||
path.join(os.homedir(), ".mcp-runtime", definition.name);
|
||||
await ensureDirectory(tokenDir);
|
||||
|
||||
const server = http.createServer();
|
||||
const port = await new Promise<number>((resolve, reject) => {
|
||||
server.listen(0, CALLBACK_HOST, () => {
|
||||
const address = server.address();
|
||||
if (typeof address === "object" && address && "port" in address) {
|
||||
resolve(address.port);
|
||||
} else {
|
||||
reject(new Error("Failed to determine callback port"));
|
||||
}
|
||||
});
|
||||
server.once("error", (error) => reject(error));
|
||||
});
|
||||
|
||||
const redirectUrl = new URL(`http://${CALLBACK_HOST}:${port}/callback`);
|
||||
|
||||
const provider = new FileOAuthClientProvider(
|
||||
definition,
|
||||
tokenDir,
|
||||
redirectUrl,
|
||||
logger,
|
||||
);
|
||||
provider.attachServer(server);
|
||||
return {
|
||||
provider,
|
||||
close: async () => {
|
||||
await provider.close();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private attachServer(server: http.Server) {
|
||||
this.server = server;
|
||||
server.on("request", async (req, res) => {
|
||||
try {
|
||||
const url = req.url ?? "";
|
||||
if (!url.startsWith("/callback")) {
|
||||
res.statusCode = 404;
|
||||
res.end("Not found");
|
||||
return;
|
||||
}
|
||||
const parsed = new URL(url, this.redirectUrlValue);
|
||||
const code = parsed.searchParams.get("code");
|
||||
const error = parsed.searchParams.get("error");
|
||||
if (code) {
|
||||
this.logger.info(
|
||||
`Received OAuth authorization code for ${this.definition.name}`,
|
||||
);
|
||||
res.statusCode = 200;
|
||||
res.setHeader("Content-Type", "text/html");
|
||||
res.end(
|
||||
"<html><body><h1>Authorization successful</h1><p>You can return to the CLI.</p></body></html>",
|
||||
);
|
||||
this.authorizationDeferred?.resolve(code);
|
||||
this.authorizationDeferred = null;
|
||||
} else if (error) {
|
||||
res.statusCode = 400;
|
||||
res.setHeader("Content-Type", "text/html");
|
||||
res.end(
|
||||
`<html><body><h1>Authorization failed</h1><p>${error}</p></body></html>`,
|
||||
);
|
||||
this.authorizationDeferred?.reject(
|
||||
new Error(`OAuth error: ${error}`),
|
||||
);
|
||||
this.authorizationDeferred = null;
|
||||
} else {
|
||||
res.statusCode = 400;
|
||||
res.end("Missing authorization code");
|
||||
this.authorizationDeferred?.reject(
|
||||
new Error("Missing authorization code"),
|
||||
);
|
||||
this.authorizationDeferred = null;
|
||||
}
|
||||
} catch (error) {
|
||||
this.authorizationDeferred?.reject(error);
|
||||
this.authorizationDeferred = null;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
get redirectUrl(): string | URL {
|
||||
return this.redirectUrlValue;
|
||||
}
|
||||
|
||||
get clientMetadata(): OAuthClientMetadata {
|
||||
return this.metadata;
|
||||
}
|
||||
|
||||
async state(): Promise<string> {
|
||||
const existing = await readJsonFile<string>(this.statePath);
|
||||
if (existing) {
|
||||
return existing;
|
||||
}
|
||||
const state = randomUUID();
|
||||
await writeJsonFile(this.statePath, state);
|
||||
return state;
|
||||
}
|
||||
|
||||
async clientInformation(): Promise<OAuthClientInformationMixed | undefined> {
|
||||
return readJsonFile<OAuthClientInformationMixed>(this.clientInfoPath);
|
||||
}
|
||||
|
||||
async saveClientInformation(
|
||||
clientInformation: OAuthClientInformationMixed,
|
||||
): Promise<void> {
|
||||
await writeJsonFile(this.clientInfoPath, clientInformation);
|
||||
}
|
||||
|
||||
async tokens(): Promise<OAuthTokens | undefined> {
|
||||
return readJsonFile<OAuthTokens>(this.tokenPath);
|
||||
}
|
||||
|
||||
async saveTokens(tokens: OAuthTokens): Promise<void> {
|
||||
await writeJsonFile(this.tokenPath, tokens);
|
||||
this.logger.info(
|
||||
`Saved OAuth tokens for ${this.definition.name} to ${this.tokenPath}`,
|
||||
);
|
||||
}
|
||||
|
||||
async redirectToAuthorization(authorizationUrl: URL): Promise<void> {
|
||||
this.logger.info(
|
||||
`Authorization required for ${this.definition.name}. Opening browser...`,
|
||||
);
|
||||
this.authorizationDeferred = createDeferred<string>();
|
||||
openExternal(authorizationUrl.toString());
|
||||
this.logger.info(
|
||||
`If the browser did not open, visit ${authorizationUrl.toString()} manually.`,
|
||||
);
|
||||
}
|
||||
|
||||
async saveCodeVerifier(codeVerifier: string): Promise<void> {
|
||||
await fs.writeFile(this.codeVerifierPath, codeVerifier, "utf8");
|
||||
}
|
||||
|
||||
async codeVerifier(): Promise<string> {
|
||||
const value = await fs.readFile(this.codeVerifierPath, "utf8");
|
||||
return value.trim();
|
||||
}
|
||||
|
||||
async invalidateCredentials(
|
||||
scope: "all" | "client" | "tokens" | "verifier",
|
||||
): Promise<void> {
|
||||
const removals: string[] = [];
|
||||
if (scope === "all" || scope === "tokens") removals.push(this.tokenPath);
|
||||
if (scope === "all" || scope === "client")
|
||||
removals.push(this.clientInfoPath);
|
||||
if (scope === "all" || scope === "verifier")
|
||||
removals.push(this.codeVerifierPath);
|
||||
await Promise.all(
|
||||
removals.map(async (file) => {
|
||||
try {
|
||||
await fs.unlink(file);
|
||||
} catch (error) {
|
||||
if ((error as NodeJS.ErrnoException).code !== "ENOENT") {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
async waitForAuthorizationCode(): Promise<string> {
|
||||
if (!this.authorizationDeferred) {
|
||||
this.authorizationDeferred = createDeferred<string>();
|
||||
}
|
||||
return this.authorizationDeferred.promise;
|
||||
}
|
||||
|
||||
async close(): Promise<void> {
|
||||
if (!this.server) {
|
||||
return;
|
||||
}
|
||||
await new Promise<void>((resolve) => {
|
||||
this.server?.close(() => resolve());
|
||||
});
|
||||
this.server = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
export interface OAuthSession {
|
||||
provider: OAuthClientProvider & {
|
||||
waitForAuthorizationCode: () => Promise<string>;
|
||||
};
|
||||
waitForAuthorizationCode: () => Promise<string>;
|
||||
close: () => Promise<void>;
|
||||
}
|
||||
|
||||
export async function createOAuthSession(
|
||||
definition: ServerDefinition,
|
||||
logger: OAuthLogger,
|
||||
): Promise<OAuthSession> {
|
||||
const { provider, close } = await FileOAuthClientProvider.create(
|
||||
definition,
|
||||
logger,
|
||||
);
|
||||
return {
|
||||
provider: Object.assign(provider, {
|
||||
waitForAuthorizationCode: () => provider.waitForAuthorizationCode(),
|
||||
}),
|
||||
waitForAuthorizationCode: () => provider.waitForAuthorizationCode(),
|
||||
close,
|
||||
};
|
||||
}
|
||||
export interface OAuthLogger {
|
||||
info(message: string): void;
|
||||
warn(message: string): void;
|
||||
error(message: string, error?: unknown): void;
|
||||
}
|
||||
116
src/runtime.ts
116
src/runtime.ts
@ -12,6 +12,7 @@ import type {
|
||||
} from "@modelcontextprotocol/sdk/types.js";
|
||||
import { loadServerDefinitions, type ServerDefinition } from "./config.js";
|
||||
import { withEnvOverrides } from "./env.js";
|
||||
import { createOAuthSession, type OAuthSession } from "./oauth.js";
|
||||
|
||||
const PACKAGE_NAME = "mcp-runtime";
|
||||
const CLIENT_VERSION = "0.0.1";
|
||||
@ -73,6 +74,7 @@ interface ClientContext {
|
||||
readonly client: Client;
|
||||
readonly transport: Transport & { close(): Promise<void> };
|
||||
readonly definition: ServerDefinition;
|
||||
readonly oauthSession?: OAuthSession;
|
||||
}
|
||||
|
||||
export async function createRuntime(
|
||||
@ -201,6 +203,7 @@ class McpRuntime implements Runtime {
|
||||
return;
|
||||
}
|
||||
await context.transport.close().catch(() => {});
|
||||
await context.oauthSession?.close().catch(() => {});
|
||||
this.clients.delete(normalized);
|
||||
return;
|
||||
}
|
||||
@ -209,6 +212,7 @@ class McpRuntime implements Runtime {
|
||||
try {
|
||||
const context = await promise;
|
||||
await context.transport.close().catch(() => {});
|
||||
await context.oauthSession?.close().catch(() => {});
|
||||
} finally {
|
||||
this.clients.delete(name);
|
||||
}
|
||||
@ -221,6 +225,11 @@ class McpRuntime implements Runtime {
|
||||
const client = new Client(this.clientInfo);
|
||||
|
||||
return withEnvOverrides(definition.env, async () => {
|
||||
let oauthSession: OAuthSession | undefined;
|
||||
if (definition.auth === "oauth") {
|
||||
oauthSession = await createOAuthSession(definition, this.logger);
|
||||
}
|
||||
|
||||
if (definition.command.kind === "stdio") {
|
||||
const transport = new StdioClientTransport({
|
||||
command: definition.command.command,
|
||||
@ -228,42 +237,109 @@ class McpRuntime implements Runtime {
|
||||
cwd: definition.command.cwd,
|
||||
});
|
||||
await client.connect(transport);
|
||||
return { client, transport, definition };
|
||||
return { client, transport, definition, oauthSession };
|
||||
}
|
||||
|
||||
const requestInit: RequestInit = definition.command.headers
|
||||
? { headers: definition.command.headers as HeadersInit }
|
||||
: {};
|
||||
|
||||
const baseOptions = {
|
||||
requestInit,
|
||||
authProvider: oauthSession?.provider,
|
||||
};
|
||||
|
||||
const streamableTransport = new StreamableHTTPClientTransport(
|
||||
definition.command.url,
|
||||
{
|
||||
requestInit,
|
||||
},
|
||||
baseOptions,
|
||||
);
|
||||
|
||||
try {
|
||||
await client.connect(streamableTransport);
|
||||
return { client, transport: streamableTransport, definition };
|
||||
} catch (error) {
|
||||
await streamableTransport.close().catch(() => {});
|
||||
if (error instanceof UnauthorizedError) {
|
||||
this.logger.warn(
|
||||
`Authentication required for '${definition.name}'. OAuth flows are not yet implemented in mcp-runtime.`,
|
||||
try {
|
||||
await this.connectWithAuth(
|
||||
client,
|
||||
streamableTransport,
|
||||
oauthSession,
|
||||
definition.name,
|
||||
);
|
||||
throw error;
|
||||
return {
|
||||
client,
|
||||
transport: streamableTransport,
|
||||
definition,
|
||||
oauthSession,
|
||||
};
|
||||
} catch (error) {
|
||||
await streamableTransport.close().catch(() => {});
|
||||
this.logger.info(
|
||||
`Falling back to SSE transport for '${definition.name}': ${(error as Error).message}`,
|
||||
);
|
||||
const sseTransport = new SSEClientTransport(definition.command.url, {
|
||||
...baseOptions,
|
||||
});
|
||||
await this.connectWithAuth(
|
||||
client,
|
||||
sseTransport,
|
||||
oauthSession,
|
||||
definition.name,
|
||||
);
|
||||
return { client, transport: sseTransport, definition, oauthSession };
|
||||
}
|
||||
this.logger.info(
|
||||
`Falling back to SSE transport for '${definition.name}': ${(error as Error).message}`,
|
||||
);
|
||||
const sseTransport = new SSEClientTransport(definition.command.url, {
|
||||
requestInit,
|
||||
});
|
||||
await client.connect(sseTransport);
|
||||
return { client, transport: sseTransport, definition };
|
||||
} catch (error) {
|
||||
await oauthSession?.close().catch(() => {});
|
||||
throw error;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private async connectWithAuth(
|
||||
client: Client,
|
||||
transport: Transport & {
|
||||
close(): Promise<void>;
|
||||
finishAuth?: (authorizationCode: string) => Promise<void>;
|
||||
},
|
||||
session?: OAuthSession,
|
||||
serverName?: string,
|
||||
maxAttempts = 3,
|
||||
): Promise<void> {
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
try {
|
||||
await client.connect(transport);
|
||||
return;
|
||||
} catch (error) {
|
||||
if (!(error instanceof UnauthorizedError) || !session) {
|
||||
throw error;
|
||||
}
|
||||
attempt += 1;
|
||||
if (attempt > maxAttempts) {
|
||||
throw error;
|
||||
}
|
||||
this.logger.warn(
|
||||
`OAuth authorization required for '${serverName ?? "unknown"}'. Waiting for browser approval...`,
|
||||
);
|
||||
try {
|
||||
const code = await session.waitForAuthorizationCode();
|
||||
if (typeof transport.finishAuth === "function") {
|
||||
await transport.finishAuth(code);
|
||||
this.logger.info(
|
||||
"Authorization code accepted. Retrying connection...",
|
||||
);
|
||||
} else {
|
||||
this.logger.warn(
|
||||
"Transport does not support finishAuth; cannot complete OAuth flow automatically.",
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
} catch (authError) {
|
||||
this.logger.error(
|
||||
"OAuth authorization failed while waiting for callback.",
|
||||
authError,
|
||||
);
|
||||
throw authError;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function createConsoleLogger(): RuntimeLogger {
|
||||
|
||||
108
tests/runtime-integration.test.ts
Normal file
108
tests/runtime-integration.test.ts
Normal file
@ -0,0 +1,108 @@
|
||||
import type { Server as HttpServer } from "node:http";
|
||||
import type { AddressInfo } from "node:net";
|
||||
import {
|
||||
McpServer,
|
||||
ResourceTemplate,
|
||||
} from "@modelcontextprotocol/sdk/server/mcp.js";
|
||||
import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js";
|
||||
import express from "express";
|
||||
import { afterAll, beforeAll, describe, expect, it } from "vitest";
|
||||
import { z } from "zod";
|
||||
import { createRuntime } from "../src/runtime.js";
|
||||
|
||||
const app = express();
|
||||
app.use(express.json());
|
||||
|
||||
const server = new McpServer({
|
||||
name: "integration-demo",
|
||||
version: "1.0.0",
|
||||
});
|
||||
|
||||
server.registerTool(
|
||||
"add",
|
||||
{
|
||||
title: "Addition Tool",
|
||||
description: "Add two numbers",
|
||||
inputSchema: { a: z.number(), b: z.number() },
|
||||
outputSchema: { result: z.number() },
|
||||
},
|
||||
async ({ a, b }) => {
|
||||
const result = { result: a + b };
|
||||
return {
|
||||
content: [{ type: "text", text: JSON.stringify(result) }],
|
||||
structuredContent: result,
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
server.registerResource(
|
||||
"greeting",
|
||||
new ResourceTemplate("greeting://{name}", { list: undefined }),
|
||||
{
|
||||
title: "Greeting",
|
||||
description: "Dynamic greeting resource",
|
||||
},
|
||||
async (uri, { name }) => ({
|
||||
contents: [
|
||||
{
|
||||
uri: uri.href,
|
||||
text: `Hello, ${name}!`,
|
||||
},
|
||||
],
|
||||
}),
|
||||
);
|
||||
|
||||
app.post("/mcp", async (req, res) => {
|
||||
const transport = new StreamableHTTPServerTransport({
|
||||
enableJsonResponse: true,
|
||||
});
|
||||
|
||||
res.on("close", () => {
|
||||
transport.close().catch(() => {});
|
||||
});
|
||||
|
||||
await server.connect(transport);
|
||||
await transport.handleRequest(req, res, req.body);
|
||||
});
|
||||
|
||||
let httpServer: HttpServer;
|
||||
let baseUrl: URL;
|
||||
|
||||
describe("runtime integration", () => {
|
||||
beforeAll(async () => {
|
||||
httpServer = app.listen(0, "127.0.0.1");
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
httpServer.once("listening", resolve);
|
||||
httpServer.once("error", reject);
|
||||
});
|
||||
const address = httpServer.address() as AddressInfo;
|
||||
baseUrl = new URL(`http://127.0.0.1:${address.port}/mcp`);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await new Promise<void>((resolve) => httpServer.close(() => resolve()));
|
||||
});
|
||||
|
||||
it("lists tools and calls a tool over HTTP", async () => {
|
||||
const runtime = await createRuntime({
|
||||
servers: [
|
||||
{
|
||||
name: "integration",
|
||||
description: "Integration test server",
|
||||
command: { kind: "http", url: baseUrl },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const tools = await runtime.listTools("integration");
|
||||
expect(tools.some((tool) => tool.name === "add")).toBe(true);
|
||||
|
||||
const result = (await runtime.callTool("integration", "add", {
|
||||
args: { a: 3, b: 4 },
|
||||
})) as { structuredContent?: { result: number } };
|
||||
|
||||
expect(result.structuredContent?.result).toBe(7);
|
||||
|
||||
await runtime.close("integration");
|
||||
});
|
||||
});
|
||||
Loading…
Reference in New Issue
Block a user