diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..ff30e0c Binary files /dev/null and b/.DS_Store differ diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..2bbd4ec --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,476 @@ +# Tachikoma Architecture + +This document provides a detailed technical overview of the Tachikoma AI integration library architecture. + +## Overview + +Tachikoma is designed as a modular, type-safe Swift package that abstracts AI provider differences behind a unified interface. The architecture emphasizes Swift 6 concurrency safety, performance, and extensibility. + +## Core Architecture Principles + +### 1. Protocol-Oriented Design +All AI providers implement the `ModelInterface` protocol, ensuring consistent behavior across different services while allowing provider-specific optimizations. + +### 2. Swift 6 Strict Concurrency +- All public APIs are actor-safe +- Sendable conformance throughout the type system +- `@MainActor` isolation where appropriate +- No data races or concurrency issues + +### 3. Type Safety +- Strongly-typed message system with enum-based content types +- Compile-time verification of tool parameters +- Generic tool system with context type safety + +### 4. Performance First +- Intelligent caching with configurable policies +- Streaming responses with minimal memory overhead +- Lazy provider initialization +- Efficient JSON handling without reflection + +## Module Structure + +``` +Tachikoma/ +├── Sources/Tachikoma/ +│ ├── Tachikoma.swift # Main API entry point +│ ├── Core/ # Core abstractions +│ │ ├── ModelInterface.swift # Provider protocol +│ │ ├── ModelProvider.swift # Provider registry & management +│ │ ├── MessageTypes.swift # Message type system +│ │ ├── StreamingTypes.swift # Streaming event system +│ │ ├── ModelParameters.swift # Request/response parameters +│ │ ├── ToolDefinitions.swift # Tool calling system +│ │ └── TachikomaError.swift # Error handling +│ └── Providers/ # Provider implementations +│ ├── OpenAI/ +│ │ ├── OpenAIModel.swift # OpenAI implementation +│ │ └── OpenAITypes.swift # OpenAI-specific types +│ ├── Anthropic/ +│ │ ├── AnthropicModel.swift # Anthropic implementation +│ │ └── AnthropicTypes.swift # Anthropic-specific types +│ ├── Grok/ +│ │ ├── GrokModel.swift # Grok implementation +│ │ └── GrokTypes.swift # Grok-specific types +│ └── Ollama/ +│ ├── OllamaModel.swift # Ollama implementation +│ └── OllamaTypes.swift # Ollama-specific types +``` + +## Core Components + +### ModelInterface Protocol + +The central abstraction that all providers implement: + +```swift +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public protocol ModelInterface: Sendable { + /// Masked API key for debugging + var maskedApiKey: String { get } + + /// Get a single response + func getResponse(request: ModelRequest) async throws -> ModelResponse + + /// Get streaming response + func getStreamedResponse(request: ModelRequest) async throws -> AsyncThrowingStream +} +``` + +**Key Design Decisions:** +- `Sendable` conformance ensures thread safety +- Async/await for all network operations +- Streaming uses `AsyncThrowingStream` for memory efficiency +- Availability annotations ensure compatibility + +### Message Type System + +Hierarchical message types that handle all AI interaction patterns: + +```swift +public enum Message: Codable, Sendable { + case system(id: String? = nil, content: String) + case user(id: String? = nil, content: MessageContent) + case assistant(id: String? = nil, content: [AssistantContent], status: MessageStatus = .completed) + case tool(id: String? = nil, toolCallId: String, content: String) + case reasoning(id: String? = nil, content: String) +} +``` + +**Content Type Hierarchy:** +- `MessageContent`: Text, images, multimodal, files, audio +- `AssistantContent`: Text output, refusals, tool calls +- `ImageContent`: URLs, base64 data, detail levels +- `AudioContent`: Transcripts, durations, metadata + +**Benefits:** +- Type safety prevents invalid message construction +- Codable conformance for persistence +- Sendable for concurrency safety +- Extensible for new content types + +### Streaming System + +Real-time event processing with comprehensive event types: + +```swift +public enum StreamEvent { + case responseStarted(StreamResponseStarted) + case textDelta(StreamTextDelta) + case toolCallDelta(StreamToolCallDelta) + case toolCallCompleted(StreamToolCallCompleted) + case responseCompleted(StreamResponseCompleted) + case error(StreamError) +} +``` + +**Stream Processing Flow:** +1. `responseStarted` - Metadata and initialization +2. `textDelta` - Incremental text content +3. `toolCallDelta` - Incremental tool call construction +4. `toolCallCompleted` - Complete tool call available +5. `responseCompleted` - Stream finished with final metadata +6. `error` - Error events for handling failures + +**Implementation Details:** +- Each provider converts its streaming format to unified events +- Back-pressure handling through `AsyncThrowingStream` +- Automatic event ordering and consistency validation + +### Tool Calling System + +Generic, type-safe tool execution with context support: + +```swift +public struct Tool { + public let execute: (ToolInput, Context) async throws -> ToolOutput + + public func toToolDefinition() -> ToolDefinition { + // Convert to provider-agnostic definition + } +} +``` + +**Type Safety Features:** +- Generic context ensures compile-time type checking +- Parameter validation through JSON Schema +- Async execution for I/O operations +- Error handling with structured failures + +**Tool Definition System:** +```swift +public struct ToolDefinition { + public let function: FunctionDefinition + public let type: ToolType = .function +} + +public struct FunctionDefinition { + public let name: String + public let description: String? + public let parameters: ToolParameters +} +``` + +### Error Handling + +Comprehensive error system with recovery guidance: + +```swift +public enum TachikomaError: Error, LocalizedError { + case modelNotFound(String) + case authenticationFailed + case invalidConfiguration(String) + case networkError(underlying: any Error) + case rateLimited + case insufficientQuota + case contextLengthExceeded + // ... more cases + + public var isRetryable: Bool { /* logic */ } + public var recoverySuggestion: String? { /* guidance */ } +} +``` + +**Error Categories:** +- **Client Errors**: Invalid requests, configuration issues +- **Authentication Errors**: API key problems, quota issues +- **Network Errors**: Connectivity, timeouts, server errors +- **Provider Errors**: Model-specific limitations + +## Provider Implementations + +### OpenAI Provider + +**Dual API Support:** +- Chat Completions API (`/v1/chat/completions`) for standard models +- Responses API (`/v1/responses`) for o3/o4 reasoning models + +**Key Features:** +- Automatic API selection based on model capabilities +- Parameter filtering (o3/o4 models don't support temperature) +- Reasoning summary handling for thinking models +- Complete streaming support for both APIs + +**Implementation Highlights:** +```swift +private func convertToOpenAIRequest(_ request: ModelRequest, stream: Bool) throws -> OpenAIRequest { + // Convert unified request to OpenAI format + // Handle parameter filtering + // Support both API formats +} +``` + +### Anthropic Provider + +**Native Claude Integration:** +- Direct Claude API with proper message formatting +- Content blocks for multimodal inputs +- System prompt separation +- Tool result handling as user messages + +**Streaming Implementation:** +- Server-Sent Events (SSE) processing +- Delta accumulation for tool calls +- Proper handling of Claude's content block structure + +**Claude 4 Features:** +- Extended thinking modes +- Long-running task support +- Advanced reasoning capabilities + +### Grok Provider + +**OpenAI Compatibility:** +- Uses OpenAI-compatible Chat Completions API +- Parameter filtering for Grok 3/4 models +- Standard streaming implementation + +**Optimizations:** +- Efficient parameter encoding +- Proper error response handling +- Rate limiting awareness + +### Ollama Provider + +**Local Inference:** +- HTTP API for local models +- Custom timeout handling (5 minutes for model loading) +- Tool calling detection for compatible models + +**Model Support:** +- Language models: llama3.3, mistral, etc. +- Vision models: llava, bakllava (no tool calling) +- Custom model endpoints + +## Provider Registry & Management + +### ModelProvider (Actor) + +Central registry for all model factories and instances: + +```swift +@MainActor +public final class ModelProvider { + public static let shared = ModelProvider() + + private var modelFactories: [String: @Sendable () throws -> any ModelInterface] = [:] + private var modelCache: [String: any ModelInterface] = [:] + + public func getModel(_ modelName: String) async throws -> any ModelInterface + public func register(modelName: String, factory: @escaping @Sendable () throws -> any ModelInterface) +} +``` + +**Registration System:** +- Default model registration at startup +- Custom factory registration +- Lenient name matching (e.g., "gpt" → "gpt-4.1") +- Provider/model path resolution ("openai/gpt-4") + +**Caching Strategy:** +- Model instances cached after first creation +- Cache invalidation on registration changes +- Memory-efficient with weak references where appropriate + +## Concurrency & Threading + +### Actor Usage + +**ModelProvider as MainActor:** +- Centralizes model management +- Ensures thread-safe registration +- Coordinates provider initialization + +**Sendable Conformance:** +- All message types are Sendable +- Model instances are Sendable +- Error types are Sendable +- Tool definitions are Sendable + +**Async/Await Integration:** +- All network operations are async +- Streaming uses AsyncThrowingStream +- No blocking operations on main thread + +### Memory Management + +**Streaming Efficiency:** +- Events processed incrementally +- No accumulation of entire responses +- Automatic memory cleanup + +**Cache Management:** +- Model instances cached intelligently +- Configurable cache policies +- Weak references for large objects + +## Security Considerations + +### API Key Handling + +**Environment Variables:** +- Support for multiple key formats +- Secure key storage recommendations +- Masked keys in debug output + +**Key Security:** +- Never log full API keys +- Secure transmission only +- No persistence of keys in plain text + +### Input Validation + +**Parameter Validation:** +- Type-safe parameter construction +- Range validation for numeric parameters +- Required field enforcement + +**Content Filtering:** +- Provider-specific content policies +- Error handling for filtered content +- Transparent policy communication + +## Performance Characteristics + +### Network Efficiency + +**Connection Management:** +- URLSession with appropriate timeouts +- HTTP/2 support where available +- Connection pooling + +**Request Optimization:** +- Minimal payload size +- Efficient JSON encoding +- Compression support + +### Memory Usage + +**Streaming Responses:** +- Constant memory usage regardless of response size +- Incremental processing +- Automatic garbage collection + +**Object Creation:** +- Minimal allocations in hot paths +- Reuse of formatter objects +- Efficient string handling + +## Testing Strategy + +### Unit Tests + +**Provider Tests:** +- Mock network responses +- Error condition testing +- Parameter validation tests + +**Integration Tests:** +- End-to-end flow testing +- Streaming behavior validation +- Tool calling integration + +**Performance Tests:** +- Memory usage profiling +- Response time benchmarks +- Concurrent request handling + +## Extension Points + +### Custom Providers + +Implement `ModelInterface` to add new providers: + +```swift +class CustomProvider: ModelInterface { + var maskedApiKey: String { "custom-***" } + + func getResponse(request: ModelRequest) async throws -> ModelResponse { + // Custom implementation + } + + func getStreamedResponse(request: ModelRequest) async throws -> AsyncThrowingStream { + // Custom streaming + } +} +``` + +### Message Type Extensions + +Add new content types by extending `MessageContent`: + +```swift +extension MessageContent { + case customType(CustomData) +} +``` + +### Tool System Extensions + +Create specialized tool contexts: + +```swift +struct DatabaseContext { + let connection: DatabaseConnection + let schema: Schema +} + +let dbTool = Tool { input, context in + // Database operations with type-safe context +} +``` + +## Future Considerations + +### Planned Features + +**Enhanced Caching:** +- Persistent cache with TTL +- Smart cache invalidation +- Distributed caching support + +**Advanced Streaming:** +- Bidirectional streaming +- Stream multiplexing +- Custom event types + +**Provider Enhancements:** +- More granular configuration +- Provider-specific optimizations +- Enhanced error recovery + +### Scalability + +**High-Volume Usage:** +- Connection pooling improvements +- Request batching +- Rate limiting integration + +**Enterprise Features:** +- Audit logging +- Metrics collection +- Custom authentication + +--- + +This architecture provides a solid foundation for AI integration while maintaining flexibility for future enhancements and provider additions. \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..d05c911 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,173 @@ +# Changelog + +All notable changes to the Tachikoma project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [1.0.0] - 2025-01-XX + +### Added + +#### Core Framework +- Initial release of Tachikoma AI integration library +- Unified `ModelInterface` protocol for all AI providers +- Comprehensive message type system with multimodal support +- Real-time streaming response handling with `AsyncThrowingStream` +- Type-safe tool calling system with generic context support +- Actor-based provider registry with intelligent caching +- Swift 6 strict concurrency compliance throughout + +#### Provider Support +- **OpenAI Provider**: Complete integration with dual API support + - Chat Completions API for standard models (GPT-4o, GPT-4.1) + - Responses API for reasoning models (o3, o4 series) + - Automatic API selection based on model capabilities + - Parameter filtering for reasoning models + - Full streaming support for both APIs + - Reasoning summary handling for thinking models + +- **Anthropic Provider**: Native Claude API integration + - Support for Claude 4 (Opus, Sonnet) with thinking modes + - Claude 3.5/3.7 series compatibility + - Content block handling for multimodal inputs + - System prompt separation + - Server-Sent Events streaming + - Extended reasoning capabilities + +- **Grok Provider**: xAI integration with OpenAI compatibility + - Grok 4, Grok 3, Grok 2 series support + - Vision model capabilities + - Parameter filtering for Grok 3/4 models + - Standard streaming implementation + - OpenAI-compatible Chat Completions API + +- **Ollama Provider**: Local model inference support + - Support for Llama 3.3 (recommended), Mistral, CodeLlama + - Vision models (llava, bakllava) without tool calling + - Configurable endpoints for self-hosted deployments + - Extended timeouts for local model loading + - Tool calling detection for compatible models + +#### Message System +- **Unified Message Types**: Support for system, user, assistant, tool, and reasoning messages +- **Content Types**: Text, images (URL/base64), multimodal, files, audio with transcripts +- **Assistant Content**: Text output, refusals, tool calls with proper typing +- **Image Support**: High/low detail levels, multiple formats, base64 encoding +- **Audio Support**: Transcript extraction, duration metadata + +#### Streaming System +- **Event-Based Architecture**: Comprehensive streaming event types +- **Real-Time Processing**: Incremental text deltas, tool call construction +- **Memory Efficiency**: Constant memory usage regardless of response size +- **Error Handling**: Structured error events with recovery information +- **Provider Abstraction**: Unified events across different provider formats + +#### Tool Calling +- **Generic Tool System**: Type-safe tool execution with context support +- **Parameter Validation**: JSON Schema-based parameter validation +- **Async Execution**: Non-blocking tool execution with proper error handling +- **Tool Definitions**: Provider-agnostic tool definition format +- **Context Management**: Type-safe context passing to tool functions + +#### Error Handling +- **Comprehensive Error Types**: Structured error hierarchy with recovery guidance +- **Provider-Specific Errors**: Tailored error handling for each provider +- **Retry Logic**: Built-in retry detection with exponential backoff support +- **Error Categories**: Client, authentication, network, and provider errors +- **Localized Descriptions**: User-friendly error messages with recovery suggestions + +#### Configuration System +- **Environment Variables**: Support for standard API key environment variables +- **Provider Configuration**: Flexible configuration for custom endpoints +- **Model Registration**: Runtime model factory registration +- **Lenient Matching**: Intelligent model name resolution +- **Cache Management**: Configurable caching policies + +### Technical Features + +#### Swift 6 Compliance +- **Strict Concurrency**: Full Swift 6 strict concurrency mode compliance +- **Sendable Conformance**: All public types conform to Sendable protocol +- **Actor Safety**: Thread-safe operations with proper isolation +- **Memory Safety**: No data races or concurrency issues +- **Performance**: Optimized for concurrent execution + +#### Performance Optimizations +- **Intelligent Caching**: Model instance caching with smart invalidation +- **Connection Pooling**: Efficient network connection management +- **Memory Management**: Minimal allocations and efficient garbage collection +- **Streaming Efficiency**: Incremental processing without accumulation +- **JSON Optimization**: Fast encoding/decoding without reflection + +#### Type Safety +- **Compile-Time Verification**: Strong typing throughout the API +- **Generic Constraints**: Type-safe tool contexts and parameters +- **Enum-Based Design**: Exhaustive pattern matching for robustness +- **Protocol-Oriented**: Clean abstractions with concrete implementations + +### Documentation +- Comprehensive README with quick start guide +- Detailed architecture documentation +- API reference documentation +- Code examples for common usage patterns +- Migration guide from PeekabooCore +- Performance optimization guidelines + +### Testing +- Unit tests for all core components +- Integration tests for provider functionality +- Mock providers for testing scenarios +- Performance benchmarks +- Concurrency safety tests + +### Platform Support +- macOS 14.0+ +- iOS 17.0+ +- watchOS 10.0+ +- tvOS 17.0+ +- Swift 6.0+ +- Xcode 16.0+ + +## [Unreleased] + +### Planned Features +- Enhanced caching with persistence and TTL +- Bidirectional streaming support +- Request batching for high-volume usage +- Advanced error recovery mechanisms +- Metrics collection and monitoring +- Distributed caching support + +--- + +## Version History + +- **v1.0.0**: Initial release extracted from PeekabooCore with Swift 6 compliance +- **v0.x.x**: Development versions (internal) + +## Migration Notes + +### From PeekabooCore +When migrating from PeekabooCore's AI system: + +1. **Error Types**: Replace `PeekabooError` with `TachikomaError` +2. **Import Statements**: Update to `import Tachikoma` +3. **Model Creation**: Use `Tachikoma.shared.getModel()` instead of direct instantiation +4. **Streaming Events**: Update event handling for new event type hierarchy +5. **Message Types**: Adopt new unified message type system +6. **Tool Calling**: Update to generic tool system with context support + +### Breaking Changes +This is the initial release, so no breaking changes from previous versions. + +## Contributors + +- **Extraction Lead**: AI Assistant +- **Original Code**: Peekaboo project contributors +- **Architecture Design**: Based on proven patterns from PeekabooCore +- **Swift 6 Migration**: Complete rewrite for strict concurrency compliance + +## License + +This project is licensed under the MIT License. See LICENSE file for details. \ No newline at end of file diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 0000000..f2401e3 --- /dev/null +++ b/Package.resolved @@ -0,0 +1,15 @@ +{ + "originHash" : "3929c9b9fd81518c26df5464b3e6a459042d74460f8bfd7fb79e21dc5c507bdf", + "pins" : [ + { + "identity" : "swift-log", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-log.git", + "state" : { + "revision" : "ce592ae52f982c847a4efc0dd881cc9eb32d29f2", + "version" : "1.6.4" + } + } + ], + "version" : 3 +} diff --git a/Package.swift b/Package.swift new file mode 100644 index 0000000..e820da2 --- /dev/null +++ b/Package.swift @@ -0,0 +1,65 @@ +// swift-tools-version: 6.0 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import PackageDescription + +let package = Package( + name: "Tachikoma", + platforms: [ + .macOS(.v14), + .iOS(.v17), + .watchOS(.v10), + .tvOS(.v17) + ], + products: [ + .library( + name: "Tachikoma", + targets: ["Tachikoma"] + ), + ], + dependencies: [ + .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"), + ], + targets: [ + .target( + name: "Tachikoma", + dependencies: [ + .product(name: "Logging", package: "swift-log"), + ], + path: "Sources", + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency"), + .enableUpcomingFeature("BareSlashRegexLiterals"), + .enableUpcomingFeature("ConciseMagicFile"), + .enableUpcomingFeature("ForwardTrailingClosures"), + .enableUpcomingFeature("ImportObjcForwardDeclarations"), + .enableUpcomingFeature("DisableOutwardActorInference"), + .enableUpcomingFeature("ExistentialAny"), + .enableUpcomingFeature("DeprecateApplicationMain"), + .enableUpcomingFeature("GlobalConcurrency"), + .enableUpcomingFeature("IsolatedDefaultValues"), + .enableUpcomingFeature("InternalImportsByDefault"), + ] + ), + .testTarget( + name: "TachikomaTests", + dependencies: [ + "Tachikoma", + .product(name: "Logging", package: "swift-log"), + ], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency"), + .enableUpcomingFeature("BareSlashRegexLiterals"), + .enableUpcomingFeature("ConciseMagicFile"), + .enableUpcomingFeature("ForwardTrailingClosures"), + .enableUpcomingFeature("ImportObjcForwardDeclarations"), + .enableUpcomingFeature("DisableOutwardActorInference"), + .enableUpcomingFeature("ExistentialAny"), + .enableUpcomingFeature("DeprecateApplicationMain"), + .enableUpcomingFeature("GlobalConcurrency"), + .enableUpcomingFeature("IsolatedDefaultValues"), + .enableUpcomingFeature("InternalImportsByDefault"), + ] + ), + ] +) \ No newline at end of file diff --git a/README.md b/README.md index 14aa939..76b6d7a 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,429 @@ -# Tachikoma -One interface, every AI model. A Swift SDK to interface with AI providers. +
+ Tachikoma Logo +

Tachikoma

+
+ +A comprehensive Swift package for AI model integration, providing a unified interface for multiple AI providers including OpenAI, Anthropic, Grok (xAI), and Ollama. + +Named after the spider-tank AI from Ghost in the Shell, Tachikoma provides an intelligent, adaptable interface for AI services. + +## Features + +- **Unified API**: Single interface for multiple AI providers +- **Swift 6 Compliant**: Built with Swift 6 strict concurrency mode for maximum safety +- **Streaming Support**: Real-time streaming responses for all supported providers +- **Tool Calling**: Complete function calling support for AI agent workflows +- **Multimodal**: Support for text, images, audio, and file inputs +- **Type Safety**: Strongly-typed message handling and error management +- **Performance**: Optimized for efficiency with intelligent caching and resource management + +## Supported Providers + +### OpenAI +- **Models**: GPT-4o, GPT-4.1, o3, o4 series with full parameter support +- **Features**: Chat Completions API, Responses API, streaming, tool calling, multimodal +- **API Types**: Automatic selection between Chat Completions and Responses APIs + +### Anthropic (Claude) +- **Models**: Claude 4 (Opus, Sonnet), Claude 3.5/3.7 series with thinking modes +- **Features**: Native streaming, tool calling, multimodal, extended reasoning +- **Capabilities**: Long-running tasks, system prompts, safety filtering + +### Grok (xAI) +- **Models**: Grok 4, Grok 3, Grok 2 series with vision capabilities +- **Features**: OpenAI-compatible API, streaming, tool calling, parameter filtering +- **Performance**: High-speed inference with competitive pricing + +### Ollama +- **Models**: Llama 3.3 (recommended), Mistral, CodeLlama, and vision models +- **Features**: Local inference, tool calling (select models), streaming +- **Deployment**: Self-hosted with configurable endpoints + +## Installation + +### Swift Package Manager + +Add Tachikoma as a dependency in your `Package.swift`: + +```swift +dependencies: [ + .package(url: "https://github.com/steipete/Tachikoma", from: "1.0.0") +] +``` + +### Requirements + +- macOS 14.0+, iOS 17.0+, watchOS 10.0+, tvOS 17.0+ +- Swift 6.0+ +- Xcode 16.0+ + +## Quick Start + +### Basic Usage + +```swift +import Tachikoma + +// Get a model instance +let tachikoma = Tachikoma.shared +let model = try await tachikoma.getModel("claude-opus-4") + +// Create a simple request +let request = ModelRequest( + messages: [ + .user(content: .text("What is the capital of France?")) + ] +) + +// Get response +let response = try await model.getResponse(request: request) +print(response.content.first?.text ?? "No response") +``` + +### Streaming Responses + +```swift +let request = ModelRequest( + messages: [.user(content: .text("Write a story about AI"))], + settings: ModelSettings(temperature: 0.7) +) + +for try await event in try await model.getStreamedResponse(request: request) { + switch event { + case .textDelta(let delta): + print(delta.delta, terminator: "") + case .responseCompleted: + print("\n[Stream completed]") + default: + break + } +} +``` + +### Tool Calling + +```swift +// Define a tool +let weatherTool = ToolDefinition( + function: FunctionDefinition( + name: "get_weather", + description: "Get current weather for a location", + parameters: ToolParameters( + type: "object", + properties: [ + "location": ParameterSchema( + type: .string, + description: "City name" + ) + ], + required: ["location"] + ) + ) +) + +let request = ModelRequest( + messages: [.user(content: .text("What's the weather in Tokyo?"))], + tools: [weatherTool], + settings: ModelSettings(toolChoice: .auto) +) + +let response = try await model.getResponse(request: request) +// Handle tool calls in response.content +``` + +### Multimodal Inputs + +```swift +let imageData = Data(contentsOf: imageURL) +let base64Image = imageData.base64EncodedString() + +let request = ModelRequest( + messages: [ + .user(content: .multimodal([ + .text("What do you see in this image?"), + .imageUrl(ImageUrl( + base64: base64Image, + detail: .high + )) + ])) + ] +) + +let response = try await model.getResponse(request: request) +``` + +## Configuration + +### Environment Variables + +```bash +# OpenAI +export OPENAI_API_KEY="sk-..." + +# Anthropic +export ANTHROPIC_API_KEY="sk-ant-..." + +# Grok (xAI) +export X_AI_API_KEY="xai-..." +# or +export XAI_API_KEY="xai-..." + +# Ollama (optional, defaults to localhost:11434) +export PEEKABOO_OLLAMA_BASE_URL="http://localhost:11434" +``` + +### Provider Configuration + +```swift +// Configure custom provider settings +let openAIConfig = ProviderConfiguration.openAI( + apiKey: "your-api-key", + baseURL: URL(string: "https://api.openai.com/v1"), + organizationId: "org-id" +) + +try await Tachikoma.shared.configureProvider(openAIConfig) + +// Register custom model +await Tachikoma.shared.registerModel(name: "custom-gpt") { + OpenAIModel( + apiKey: "your-key", + modelName: "gpt-4-custom" + ) +} +``` + +## Architecture + +### Core Components + +#### ModelInterface +The unified protocol that all AI providers implement: + +```swift +protocol ModelInterface: Sendable { + var maskedApiKey: String { get } + func getResponse(request: ModelRequest) async throws -> ModelResponse + func getStreamedResponse(request: ModelRequest) async throws -> AsyncThrowingStream +} +``` + +#### Message System +Type-safe message handling with support for all content types: + +```swift +public enum Message: Codable, Sendable { + case system(id: String? = nil, content: String) + case user(id: String? = nil, content: MessageContent) + case assistant(id: String? = nil, content: [AssistantContent], status: MessageStatus = .completed) + case tool(id: String? = nil, toolCallId: String, content: String) + case reasoning(id: String? = nil, content: String) +} +``` + +#### Streaming System +Real-time event handling for streaming responses: + +```swift +public enum StreamEvent { + case responseStarted(StreamResponseStarted) + case textDelta(StreamTextDelta) + case toolCallDelta(StreamToolCallDelta) + case toolCallCompleted(StreamToolCallCompleted) + case responseCompleted(StreamResponseCompleted) + case error(StreamError) +} +``` + +#### Tool System +Comprehensive function calling with generic context support: + +```swift +public struct Tool { + public let execute: (ToolInput, Context) async throws -> ToolOutput + public func toToolDefinition() -> ToolDefinition +} +``` + +### Provider Architecture + +Each provider implements the `ModelInterface` with provider-specific optimizations: + +- **OpenAI**: Dual API support (Chat Completions + Responses API) +- **Anthropic**: Native SSE streaming with content blocks +- **Grok**: OpenAI-compatible with parameter filtering +- **Ollama**: Local inference with tool calling detection + +### Error Handling + +Comprehensive error types with recovery suggestions: + +```swift +public enum TachikomaError: Error, LocalizedError { + case modelNotFound(String) + case authenticationFailed + case rateLimited + case insufficientQuota + case contextLengthExceeded + // ... more cases with detailed descriptions + + public var isRetryable: Bool { /* ... */ } + public var recoverySuggestion: String? { /* ... */ } +} +``` + +## Advanced Usage + +### Custom Providers + +```swift +class CustomAIModel: ModelInterface { + var maskedApiKey: String { "custom-***" } + + func getResponse(request: ModelRequest) async throws -> ModelResponse { + // Custom implementation + } + + func getStreamedResponse(request: ModelRequest) async throws -> AsyncThrowingStream { + // Custom streaming implementation + } +} + +// Register the custom provider +await Tachikoma.shared.registerModel(name: "custom-ai") { + CustomAIModel() +} +``` + +### Batch Processing + +```swift +let requests = [ + ModelRequest(messages: [.user(content: .text("Question 1"))]), + ModelRequest(messages: [.user(content: .text("Question 2"))]), + ModelRequest(messages: [.user(content: .text("Question 3"))]) +] + +let responses = try await withThrowingTaskGroup(of: ModelResponse.self) { group in + for request in requests { + group.addTask { + try await model.getResponse(request: request) + } + } + + var results: [ModelResponse] = [] + for try await response in group { + results.append(response) + } + return results +} +``` + +### Tool Context Management + +```swift +struct WeatherContext { + let apiKey: String + let units: String +} + +let weatherTool = Tool { input, context in + let location = input.parameters["location"] as? String ?? "" + // Use context.apiKey and context.units for API call + return ToolOutput(content: "Weather data for \(location)") +} + +// Use with context +let context = WeatherContext(apiKey: "weather-key", units: "metric") +let toolDefinition = weatherTool.toToolDefinition() +``` + +## Testing + +Tachikoma includes comprehensive test coverage: + +```bash +# Run tests +swift test + +# Run with verbose output +swift test --verbose + +# Run specific test suites +swift test --filter "ModelProviderTests" +``` + +### Mock Providers + +```swift +class MockModel: ModelInterface { + var maskedApiKey: String = "mock-***" + var responses: [ModelResponse] = [] + + func getResponse(request: ModelRequest) async throws -> ModelResponse { + return responses.removeFirst() + } +} + +// Use in tests +let mockModel = MockModel() +mockModel.responses = [/* test responses */] +await Tachikoma.shared.registerModel(name: "mock") { mockModel } +``` + +## Performance Considerations + +### Caching +- Model instances are cached by default +- Clear cache with `ModelProvider.shared.clearCache()` +- Disable caching for specific models if needed + +### Memory Management +- Streaming responses use `AsyncThrowingStream` for memory efficiency +- Large responses are processed incrementally +- Tool contexts should be lightweight for optimal performance + +### Concurrency +- All APIs are actor-safe and Swift 6 compliant +- Use `TaskGroup` for parallel processing +- Respect rate limits with proper error handling + +## Migration Guide + +### From PeekabooCore + +If migrating from PeekabooCore's AI system: + +1. Replace `PeekabooError` with `TachikomaError` +2. Update import statements +3. Use `Tachikoma.shared.getModel()` instead of direct model creation +4. Update streaming event handling for new event types + +### Version History + +- **v1.0.0**: Initial release with Swift 6 support +- Core providers: OpenAI, Anthropic, Grok, Ollama +- Complete tool calling and streaming support + +## Contributing + +1. Fork the repository +2. Create a feature branch +3. Ensure Swift 6 strict mode compliance +4. Add comprehensive tests +5. Update documentation +6. Submit a pull request + +## License + +Tachikoma is available under the MIT License. See LICENSE for details. + +## Support + +- **Issues**: [GitHub Issues](https://github.com/steipete/Tachikoma/issues) +- **Discussions**: [GitHub Discussions](https://github.com/steipete/Tachikoma/discussions) +- **Documentation**: [API Reference](https://steipete.github.io/Tachikoma/) + +--- + +Built with ❤️ for the Swift AI community. diff --git a/Sources/Core/MessageTypes.swift b/Sources/Core/MessageTypes.swift new file mode 100644 index 0000000..f364eb6 --- /dev/null +++ b/Sources/Core/MessageTypes.swift @@ -0,0 +1,366 @@ +import Foundation + +// MARK: - Unified Message Type + +/// Unified message enum that provides type-safe message handling +public enum Message: Codable, Sendable { + case system(id: String? = nil, content: String) + case user(id: String? = nil, content: MessageContent) + case assistant(id: String? = nil, content: [AssistantContent], status: MessageStatus = .completed) + case tool(id: String? = nil, toolCallId: String, content: String) + case reasoning(id: String? = nil, content: String) + + // MARK: - Properties + + /// Get the message type + public var type: MessageType { + switch self { + case .system: .system + case .user: .user + case .assistant: .assistant + case .tool: .tool + case .reasoning: .reasoning + } + } + + /// Get the message ID + public var id: String? { + switch self { + case let .system(id, _), let .user(id, _), let .assistant(id, _, _), + let .tool(id, _, _), let .reasoning(id, _): + id + } + } + + // MARK: - Codable Implementation + + private enum CodingKeys: String, CodingKey { + case type, id, content, status, toolCallId + } + + public enum MessageType: String, Codable { + case system, user, assistant, tool, reasoning + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(MessageType.self, forKey: .type) + let id = try container.decodeIfPresent(String.self, forKey: .id) + + switch type { + case .system: + let content = try container.decode(String.self, forKey: .content) + self = .system(id: id, content: content) + + case .user: + let content = try container.decode(MessageContent.self, forKey: .content) + self = .user(id: id, content: content) + + case .assistant: + let content = try container.decode([AssistantContent].self, forKey: .content) + let status = try container.decodeIfPresent(MessageStatus.self, forKey: .status) ?? .completed + self = .assistant(id: id, content: content, status: status) + + case .tool: + let toolCallId = try container.decode(String.self, forKey: .toolCallId) + let content = try container.decode(String.self, forKey: .content) + self = .tool(id: id, toolCallId: toolCallId, content: content) + + case .reasoning: + let content = try container.decode(String.self, forKey: .content) + self = .reasoning(id: id, content: content) + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(self.type, forKey: .type) + + switch self { + case let .system(id, content): + try container.encodeIfPresent(id, forKey: .id) + try container.encode(content, forKey: .content) + + case let .user(id, content): + try container.encodeIfPresent(id, forKey: .id) + try container.encode(content, forKey: .content) + + case let .assistant(id, content, status): + try container.encodeIfPresent(id, forKey: .id) + try container.encode(content, forKey: .content) + try container.encode(status, forKey: .status) + + case let .tool(id, toolCallId, content): + try container.encodeIfPresent(id, forKey: .id) + try container.encode(toolCallId, forKey: .toolCallId) + try container.encode(content, forKey: .content) + + case let .reasoning(id, content): + try container.encodeIfPresent(id, forKey: .id) + try container.encode(content, forKey: .content) + } + } +} + +// MARK: - Content Types + +/// User message content variants +public enum MessageContent: Codable, Sendable { + case text(String) + case image(ImageContent) + case file(FileContent) + case audio(AudioContent) + case multimodal([MessageContentPart]) + + // Custom coding for enum + enum CodingKeys: String, CodingKey { + case type, value + } + + enum ContentType: String, Codable { + case text, image, file, audio, multimodal + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(ContentType.self, forKey: .type) + + switch type { + case .text: + let value = try container.decode(String.self, forKey: .value) + self = .text(value) + case .image: + let value = try container.decode(ImageContent.self, forKey: .value) + self = .image(value) + case .file: + let value = try container.decode(FileContent.self, forKey: .value) + self = .file(value) + case .audio: + let value = try container.decode(AudioContent.self, forKey: .value) + self = .audio(value) + case .multimodal: + let value = try container.decode([MessageContentPart].self, forKey: .value) + self = .multimodal(value) + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + + switch self { + case let .text(value): + try container.encode(ContentType.text, forKey: .type) + try container.encode(value, forKey: .value) + case let .image(value): + try container.encode(ContentType.image, forKey: .type) + try container.encode(value, forKey: .value) + case let .file(value): + try container.encode(ContentType.file, forKey: .type) + try container.encode(value, forKey: .value) + case let .audio(value): + try container.encode(ContentType.audio, forKey: .type) + try container.encode(value, forKey: .value) + case let .multimodal(value): + try container.encode(ContentType.multimodal, forKey: .type) + try container.encode(value, forKey: .value) + } + } +} + +/// Image content for messages +public struct ImageContent: Codable, Sendable { + public let url: String? + public let base64: String? + public let detail: ImageDetail? + + public enum ImageDetail: String, Codable, Sendable { + case auto, low, high + } + + public init(url: String? = nil, base64: String? = nil, detail: ImageDetail? = nil) { + self.url = url + self.base64 = base64 + self.detail = detail + } +} + +/// File content for messages +public struct FileContent: Codable, Sendable { + public let id: String? + public let url: String? + public let name: String? + public let filename: String? + public let content: String? + public let mimeType: String? + + public init( + id: String? = nil, + url: String? = nil, + name: String? = nil, + filename: String? = nil, + content: String? = nil, + mimeType: String? = nil) + { + self.id = id + self.url = url + self.name = name + self.filename = filename + self.content = content + self.mimeType = mimeType + } + + // Convenience constructor for test compatibility + public init(filename: String, content: String, mimeType: String) { + self.init(name: filename, filename: filename, content: content, mimeType: mimeType) + } +} + +/// Audio content for messages +public struct AudioContent: Codable, Sendable { + public let url: String? + public let base64: String? + public let transcript: String? + public let duration: TimeInterval? + public let mimeType: String? + + public init( + url: String? = nil, + base64: String? = nil, + transcript: String? = nil, + duration: TimeInterval? = nil, + mimeType: String? = nil) + { + self.url = url + self.base64 = base64 + self.transcript = transcript + self.duration = duration + self.mimeType = mimeType + } +} + +/// Multimodal content part +public struct MessageContentPart: Codable, Sendable { + public let type: String + public let text: String? + public let imageUrl: ImageContent? + + public init(type: String, text: String? = nil, imageUrl: ImageContent? = nil) { + self.type = type + self.text = text + self.imageUrl = imageUrl + } +} + +/// Assistant response content variants +public enum AssistantContent: Codable, Sendable { + case outputText(String) + case refusal(String) + case toolCall(ToolCallItem) + + // Custom coding + enum CodingKeys: String, CodingKey { + case type, value + } + + enum ContentType: String, Codable { + case text, refusal, toolCall + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(ContentType.self, forKey: .type) + + switch type { + case .text: + let value = try container.decode(String.self, forKey: .value) + self = .outputText(value) + case .refusal: + let value = try container.decode(String.self, forKey: .value) + self = .refusal(value) + case .toolCall: + let value = try container.decode(ToolCallItem.self, forKey: .value) + self = .toolCall(value) + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + + switch self { + case let .outputText(value): + try container.encode(ContentType.text, forKey: .type) + try container.encode(value, forKey: .value) + case let .refusal(value): + try container.encode(ContentType.refusal, forKey: .type) + try container.encode(value, forKey: .value) + case let .toolCall(value): + try container.encode(ContentType.toolCall, forKey: .type) + try container.encode(value, forKey: .value) + } + } +} + +// MARK: - Tool Call Types + +/// Tool call item representing a function invocation +public struct ToolCallItem: Codable, Sendable { + public let id: String + public let type: ToolCallType + public let function: FunctionCall + public let status: ToolCallStatus? + + public init(id: String, type: ToolCallType = .function, function: FunctionCall, status: ToolCallStatus? = nil) { + self.id = id + self.type = type + self.function = function + self.status = status + } +} + +/// Types of tool calls +public enum ToolCallType: String, Codable, Sendable { + case function + case hosted = "hosted_tool" + case computer +} + +/// Function call details +public struct FunctionCall: Codable, Sendable { + public let name: String + public let arguments: String + + public init(name: String, arguments: String) { + self.name = name + self.arguments = arguments + } +} + +/// Tool call execution status +public enum ToolCallStatus: String, Codable, Sendable { + case inProgress = "in_progress" + case completed + case failed +} + +/// Message processing status +public enum MessageStatus: String, Codable, Sendable { + case inProgress = "in_progress" + case completed + case incomplete +} + +// MARK: - Helper Extensions + +extension AssistantContent { + /// Extract text content if available + public var textContent: String? { + switch self { + case let .outputText(text): + text + case let .refusal(text): + text + case .toolCall: + nil + } + } +} \ No newline at end of file diff --git a/Sources/Core/ModelInterface.swift b/Sources/Core/ModelInterface.swift new file mode 100644 index 0000000..c1eb5ca --- /dev/null +++ b/Sources/Core/ModelInterface.swift @@ -0,0 +1,424 @@ +import Foundation + +// MARK: - Model Interface Protocol + +/// Protocol defining the interface for AI model providers +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public protocol ModelInterface: Sendable { + /// Get a non-streaming response from the model + /// - Parameter request: The model request containing messages, tools, and settings + /// - Returns: The model response + func getResponse(request: ModelRequest) async throws -> ModelResponse + + /// Get a streaming response from the model + /// - Parameter request: The model request containing messages, tools, and settings + /// - Returns: An async stream of events + func getStreamedResponse(request: ModelRequest) async throws -> AsyncThrowingStream + + /// Get a masked version of the API key for debugging + /// Returns the first 6 and last 2 characters of the API key + /// - Returns: Masked API key string (e.g., "sk-ant...AA") + var maskedApiKey: String { get } +} + +// MARK: - Model Request & Response Types + +/// Request to send to a model +public struct ModelRequest: Codable, Sendable { + /// The messages to send to the model + public let messages: [Message] + + /// Available tools for the model to use + public let tools: [ToolDefinition]? + + /// Model-specific settings + public let settings: ModelSettings + + /// System instructions (some models support this separately from messages) + public let systemInstructions: String? + + public init( + messages: [Message], + tools: [ToolDefinition]? = nil, + settings: ModelSettings, + systemInstructions: String? = nil) + { + self.messages = messages + self.tools = tools + self.settings = settings + self.systemInstructions = systemInstructions + } +} + +/// Response from a model +public struct ModelResponse: Codable, Sendable { + /// Unique identifier for the response + public let id: String + + /// The model that generated the response + public let model: String? + + /// Content returned by the model + public let content: [AssistantContent] + + /// Token usage statistics + public let usage: Usage? + + /// Whether the response was flagged for safety + public let flagged: Bool + + /// Reason for flagging if applicable + public let flaggedCategories: [String]? + + /// Finish reason + public let finishReason: FinishReason? + + public init( + id: String, + model: String? = nil, + content: [AssistantContent], + usage: Usage? = nil, + flagged: Bool = false, + flaggedCategories: [String]? = nil, + finishReason: FinishReason? = nil) + { + self.id = id + self.model = model + self.content = content + self.usage = usage + self.flagged = flagged + self.flaggedCategories = flaggedCategories + self.finishReason = finishReason + } +} + +// MARK: - Model Settings + +/// Settings for model behavior +public struct ModelSettings: Codable, Sendable { + /// The model name/identifier + public let modelName: String + + /// Temperature for randomness (0.0 to 2.0) + public let temperature: Double? + + /// Top-p sampling parameter + public let topP: Double? + + /// Maximum tokens to generate + public let maxTokens: Int? + + /// Frequency penalty (-2.0 to 2.0) + public let frequencyPenalty: Double? + + /// Presence penalty (-2.0 to 2.0) + public let presencePenalty: Double? + + /// Stop sequences + public let stopSequences: [String]? + + /// Tool choice setting + public let toolChoice: ToolChoice? + + /// Whether to use parallel tool calls + public let parallelToolCalls: Bool? + + /// Response format + public let responseFormat: ResponseFormat? + + /// Seed for deterministic generation + public let seed: Int? + + /// User identifier for tracking + public let user: String? + + /// Additional provider-specific parameters + public let additionalParameters: ModelParameters? + + public init( + modelName: String, + temperature: Double? = nil, + topP: Double? = nil, + maxTokens: Int? = nil, + frequencyPenalty: Double? = nil, + presencePenalty: Double? = nil, + stopSequences: [String]? = nil, + toolChoice: ToolChoice? = nil, + parallelToolCalls: Bool? = nil, + responseFormat: ResponseFormat? = nil, + seed: Int? = nil, + user: String? = nil, + additionalParameters: ModelParameters? = nil) + { + self.modelName = modelName + self.temperature = temperature + self.topP = topP + self.maxTokens = maxTokens + self.frequencyPenalty = frequencyPenalty + self.presencePenalty = presencePenalty + self.stopSequences = stopSequences + self.toolChoice = toolChoice + self.parallelToolCalls = parallelToolCalls + self.responseFormat = responseFormat + self.seed = seed + self.user = user + self.additionalParameters = additionalParameters + } + + /// Default settings for Claude Opus 4 + public static var `default`: ModelSettings { + ModelSettings(modelName: "claude-opus-4-20250514") + } + + // MARK: - Convenience Constructors + + /// Create settings with just specified parameters, using Claude Opus 4 as default model + public init( + temperature: Double? = nil, + topP: Double? = nil, + maxTokens: Int? = nil, + frequencyPenalty: Double? = nil, + presencePenalty: Double? = nil, + stopSequences: [String]? = nil, + toolChoice: ToolChoice? = nil, + parallelToolCalls: Bool? = nil, + responseFormat: ResponseFormat? = nil, + seed: Int? = nil, + user: String? = nil, + additionalParameters: ModelParameters? = nil) + { + self.init( + modelName: "claude-opus-4-20250514", + temperature: temperature, + topP: topP, + maxTokens: maxTokens, + frequencyPenalty: frequencyPenalty, + presencePenalty: presencePenalty, + stopSequences: stopSequences, + toolChoice: toolChoice, + parallelToolCalls: parallelToolCalls, + responseFormat: responseFormat, + seed: seed, + user: user, + additionalParameters: additionalParameters + ) + } + + /// Convenience constructors for specific API types and reasoning parameters + public init( + apiType: String, + modelName: String = "claude-opus-4-20250514") + { + let params = ModelParameters([ + "apiType": ModelParameters.Value.string(apiType) + ]) + self.init(modelName: modelName, additionalParameters: params) + } + + public init( + reasoningEffort: String, + reasoning: [String: String]? = nil, + temperature: Double? = nil, + modelName: String = "o3") + { + var params: [String: ModelParameters.Value] = [ + "reasoningEffort": ModelParameters.Value.string(reasoningEffort) + ] + if let reasoning = reasoning { + let reasoningValue = reasoning.mapValues { ModelParameters.Value.string($0) } + params["reasoning"] = ModelParameters.Value.dictionary(reasoningValue) + } + let modelParams = ModelParameters(params) + self.init( + modelName: modelName, + temperature: temperature, + additionalParameters: modelParams + ) + } + + // Custom coding for additionalParameters + enum CodingKeys: String, CodingKey { + case modelName, temperature, topP, maxTokens + case frequencyPenalty, presencePenalty, stopSequences + case toolChoice, parallelToolCalls, responseFormat + case seed, user, additionalParameters + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + + self.modelName = try container.decode(String.self, forKey: .modelName) + self.temperature = try container.decodeIfPresent(Double.self, forKey: .temperature) + self.topP = try container.decodeIfPresent(Double.self, forKey: .topP) + self.maxTokens = try container.decodeIfPresent(Int.self, forKey: .maxTokens) + self.frequencyPenalty = try container.decodeIfPresent(Double.self, forKey: .frequencyPenalty) + self.presencePenalty = try container.decodeIfPresent(Double.self, forKey: .presencePenalty) + self.stopSequences = try container.decodeIfPresent([String].self, forKey: .stopSequences) + self.toolChoice = try container.decodeIfPresent(ToolChoice.self, forKey: .toolChoice) + self.parallelToolCalls = try container.decodeIfPresent(Bool.self, forKey: .parallelToolCalls) + self.responseFormat = try container.decodeIfPresent(ResponseFormat.self, forKey: .responseFormat) + self.seed = try container.decodeIfPresent(Int.self, forKey: .seed) + self.user = try container.decodeIfPresent(String.self, forKey: .user) + + // Decode additional parameters + self.additionalParameters = try container.decodeIfPresent(ModelParameters.self, forKey: .additionalParameters) + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + + try container.encode(self.modelName, forKey: .modelName) + try container.encodeIfPresent(self.temperature, forKey: .temperature) + try container.encodeIfPresent(self.topP, forKey: .topP) + try container.encodeIfPresent(self.maxTokens, forKey: .maxTokens) + try container.encodeIfPresent(self.frequencyPenalty, forKey: .frequencyPenalty) + try container.encodeIfPresent(self.presencePenalty, forKey: .presencePenalty) + try container.encodeIfPresent(self.stopSequences, forKey: .stopSequences) + try container.encodeIfPresent(self.toolChoice, forKey: .toolChoice) + try container.encodeIfPresent(self.parallelToolCalls, forKey: .parallelToolCalls) + try container.encodeIfPresent(self.responseFormat, forKey: .responseFormat) + try container.encodeIfPresent(self.seed, forKey: .seed) + try container.encodeIfPresent(self.user, forKey: .user) + + // Encode additional parameters + try container.encodeIfPresent(self.additionalParameters, forKey: .additionalParameters) + } +} + +// MARK: - Tool Choice + +/// Tool choice setting for models +public enum ToolChoice: Codable, Sendable, Equatable { + case auto + case none + case required + case specific(toolName: String) + + // Custom coding + enum CodingKeys: String, CodingKey { + case type, toolName + } + + enum ChoiceType: String, Codable { + case auto, none, required, specific + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(ChoiceType.self, forKey: .type) + + switch type { + case .auto: + self = .auto + case .none: + self = .none + case .required: + self = .required + case .specific: + let toolName = try container.decode(String.self, forKey: .toolName) + self = .specific(toolName: toolName) + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + + switch self { + case .auto: + try container.encode(ChoiceType.auto, forKey: .type) + case .none: + try container.encode(ChoiceType.none, forKey: .type) + case .required: + try container.encode(ChoiceType.required, forKey: .type) + case let .specific(toolName): + try container.encode(ChoiceType.specific, forKey: .type) + try container.encode(toolName, forKey: .toolName) + } + } +} + +// MARK: - Response Format + +/// Response format specification +public struct ResponseFormat: Codable, Sendable { + public let type: ResponseFormatType + public let jsonSchema: JSONSchema? + + public init(type: ResponseFormatType, jsonSchema: JSONSchema? = nil) { + self.type = type + self.jsonSchema = jsonSchema + } + + /// Plain text response + public static var text: ResponseFormat { + ResponseFormat(type: .text) + } + + /// JSON object response + public static var jsonObject: ResponseFormat { + ResponseFormat(type: .jsonObject) + } +} + +/// Response format types +public enum ResponseFormatType: String, Codable, Sendable { + case text + case jsonObject = "json_object" + case jsonSchema = "json_schema" +} + +/// JSON schema specification +public struct JSONSchema: Codable, Sendable { + public let name: String + public let strict: Bool + private let schemaData: Data // Store as raw JSON data + + // Custom coding for schema + enum CodingKeys: String, CodingKey { + case name, strict, schema + } + + public init(name: String, strict: Bool = true, schema: [String: Any]) throws { + self.name = name + self.strict = strict + self.schemaData = try JSONSerialization.data(withJSONObject: schema) + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.name = try container.decode(String.self, forKey: .name) + self.strict = try container.decode(Bool.self, forKey: .strict) + self.schemaData = try container.decode(Data.self, forKey: .schema) + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(self.name, forKey: .name) + try container.encode(self.strict, forKey: .strict) + try container.encode(self.schemaData, forKey: .schema) + } + + /// Get the schema as a dictionary + public func getSchema() throws -> [String: Any] { + guard let dict = try JSONSerialization.jsonObject(with: schemaData) as? [String: Any] else { + throw TachikomaError.invalidConfiguration("Invalid JSON schema data") + } + return dict + } +} + +// MARK: - Model Provider Protocol + +/// Protocol for model provider factories +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public protocol ModelProviderProtocol { + /// Get a model by name + /// - Parameter modelName: The name of the model to retrieve + /// - Returns: A model instance conforming to ModelInterface + func getModel(modelName: String) throws -> any ModelInterface +} + +// MARK: - Model Errors + +// MARK: - Note +// ModelError is defined in TachikomaError.swift to avoid duplication \ No newline at end of file diff --git a/Sources/Core/ModelParameters.swift b/Sources/Core/ModelParameters.swift new file mode 100644 index 0000000..d0f9175 --- /dev/null +++ b/Sources/Core/ModelParameters.swift @@ -0,0 +1,240 @@ +import Foundation + +// MARK: - Model Parameters + +/// Type-safe representation of additional model parameters +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct ModelParameters: Codable, Sendable { + private let storage: [String: Value] + + /// Supported parameter value types + public enum Value: Codable, Sendable { + case string(String) + case int(Int) + case double(Double) + case bool(Bool) + case dictionary([String: Value]) + case array([Value]) + + // MARK: - Codable + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + + if let intValue = try? container.decode(Int.self) { + self = .int(intValue) + } else if let doubleValue = try? container.decode(Double.self) { + self = .double(doubleValue) + } else if let boolValue = try? container.decode(Bool.self) { + self = .bool(boolValue) + } else if let stringValue = try? container.decode(String.self) { + self = .string(stringValue) + } else if let dictValue = try? container.decode([String: Value].self) { + self = .dictionary(dictValue) + } else if let arrayValue = try? container.decode([Value].self) { + self = .array(arrayValue) + } else { + throw DecodingError.dataCorruptedError( + in: container, + debugDescription: "Unable to decode ModelParameters.Value") + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + + switch self { + case let .string(value): + try container.encode(value) + case let .int(value): + try container.encode(value) + case let .double(value): + try container.encode(value) + case let .bool(value): + try container.encode(value) + case let .dictionary(value): + try container.encode(value) + case let .array(value): + try container.encode(value) + } + } + + /// Convert to raw value for JSON serialization + public var rawValue: Any { + switch self { + case let .string(value): + value + case let .int(value): + value + case let .double(value): + value + case let .bool(value): + value + case let .dictionary(dict): + dict.mapValues { $0.rawValue } + case let .array(array): + array.map(\.rawValue) + } + } + } + + // MARK: - Initialization + + public init(_ storage: [String: Value] = [:]) { + self.storage = storage + } + + /// Initialize from a dictionary of raw values + public init(from rawValues: [String: Any]) { + var convertedStorage: [String: Value] = [:] + for (key, value) in rawValues { + if let converted = Self.convertToValue(value) { + convertedStorage[key] = converted + } + } + self.storage = convertedStorage + } + + /// Convert any value to our Value enum + private static func convertToValue(_ value: Any) -> Value? { + switch value { + case let string as String: + return .string(string) + case let int as Int: + return .int(int) + case let double as Double: + return .double(double) + case let bool as Bool: + return .bool(bool) + case let dict as [String: Any]: + var converted: [String: Value] = [:] + for (k, v) in dict { + if let convertedValue = convertToValue(v) { + converted[k] = convertedValue + } + } + return .dictionary(converted) + case let array as [Any]: + let converted = array.compactMap { self.convertToValue($0) } + return .array(converted) + default: + return nil + } + } + + // MARK: - Access Methods + + public subscript(key: String) -> Value? { self.storage[key] } + + public func string(_ key: String) -> String? { + guard case let .string(value) = storage[key] else { return nil } + return value + } + + public func int(_ key: String) -> Int? { + guard case let .int(value) = storage[key] else { return nil } + return value + } + + public func double(_ key: String) -> Double? { + guard case let .double(value) = storage[key] else { return nil } + return value + } + + public func bool(_ key: String) -> Bool? { + guard case let .bool(value) = storage[key] else { return nil } + return value + } + + /// Get the raw dictionary for JSON serialization + public var rawDictionary: [String: Any] { + self.storage.mapValues { $0.rawValue } + } + + /// Check if empty + public var isEmpty: Bool { + self.storage.isEmpty + } + + // MARK: - Mutating Methods + + /// Set a parameter value + public mutating func set(_ key: String, value: Any) { + if let converted = Self.convertToValue(value) { + var newStorage = self.storage + newStorage[key] = converted + self = ModelParameters(newStorage) + } + } + + /// Get a parameter value + public func get(_ key: String) -> Any? { + return self.storage[key]?.rawValue + } + + // MARK: - Builder Methods + + public func with(_ key: String, value: String) -> ModelParameters { + var newStorage = self.storage + newStorage[key] = .string(value) + return ModelParameters(newStorage) + } + + public func with(_ key: String, value: Int) -> ModelParameters { + var newStorage = self.storage + newStorage[key] = .int(value) + return ModelParameters(newStorage) + } + + public func with(_ key: String, value: Double) -> ModelParameters { + var newStorage = self.storage + newStorage[key] = .double(value) + return ModelParameters(newStorage) + } + + public func with(_ key: String, value: Bool) -> ModelParameters { + var newStorage = self.storage + newStorage[key] = .bool(value) + return ModelParameters(newStorage) + } + + public func with(_ key: String, value: [String: Any]) -> ModelParameters { + guard let converted = Self.convertToValue(value) else { return self } + var newStorage = self.storage + newStorage[key] = converted + return ModelParameters(newStorage) + } + + // MARK: - Codable + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + self.storage = try container.decode([String: Value].self) + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + try container.encode(self.storage) + } +} + +// MARK: - Convenience Builders + +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +extension ModelParameters { + /// Create parameters for OpenAI o3/o4 models + public static func o3Parameters( + reasoningEffort: String = "medium", + maxCompletionTokens: Int = 32768) -> ModelParameters + { + ModelParameters() + .with("reasoning_effort", value: reasoningEffort) + .with("max_completion_tokens", value: maxCompletionTokens) + .with("reasoning", value: ["summary": "detailed"]) + } + + /// Create parameters with API type + public static func withAPIType(_ apiType: String) -> ModelParameters { + ModelParameters().with("apiType", value: apiType) + } +} \ No newline at end of file diff --git a/Sources/Core/ModelProvider.swift b/Sources/Core/ModelProvider.swift new file mode 100644 index 0000000..487ee4e --- /dev/null +++ b/Sources/Core/ModelProvider.swift @@ -0,0 +1,650 @@ +import Foundation + +// MARK: - Model Provider + +/// Singleton provider for managing model instances in Tachikoma +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public actor ModelProvider { + /// Shared instance + public static let shared = ModelProvider() + + /// Registered model factories + private var modelFactories: [String: @Sendable () throws -> any ModelInterface] = [:] + + /// Model instance cache + private var modelCache: [String: any ModelInterface] = [:] + + private init() { + // Register default models when initialized + Task { + await self.registerDefaultModels() + } + } + + // MARK: - Public Methods + + /// Register a model factory + /// - Parameters: + /// - name: The model name + /// - factory: Factory closure that creates the model + public func register( + modelName: String, + factory: @escaping @Sendable () throws -> any ModelInterface) + { + self.modelFactories[modelName] = factory + // Clear cache for this model + self.modelCache.removeValue(forKey: modelName) + } + + /// Get a model by name + /// - Parameter modelName: The name of the model + /// - Returns: A model instance + /// - Throws: TachikomaError if model not found + public func getModel(modelName: String) throws -> any ModelInterface { + // Check cache first + if let cached = modelCache[modelName] { + return cached + } + + // Check for custom provider first (format: provider-id/model-path) + if let slashIndex = modelName.firstIndex(of: "/") { + let providerId = String(modelName[.. [String] { + Array(self.modelFactories.keys).sorted() + } + + /// Clear model cache + public func clearCache() { + self.modelCache.removeAll() + } + + /// Clear all model registrations and cache (useful for testing) + public func clearAll() async { + self.modelCache.removeAll() + self.modelFactories.removeAll() + // Re-register default models + await self.registerDefaultModels() + } + + /// Unregister a model + public func unregister(modelName: String) { + self.modelFactories.removeValue(forKey: modelName) + self.modelCache.removeValue(forKey: modelName) + } + + // MARK: - Private Methods + + private func registerDefaultModels() async { + // Register OpenAI models + self.registerOpenAIModels() + + // Register Anthropic models + self.registerAnthropicModels() + + // Register Grok models + self.registerGrokModels() + + // Register Ollama models + self.registerOllamaModels() + } + + /// Resolve lenient model names to their full versions + private func resolveLenientModelName(_ modelName: String) -> String? { + let lowercased = modelName.lowercased() + + // Claude model shortcuts + if lowercased == "claude-4-opus" || lowercased == "claude-opus-4" || lowercased == "claude-opus" { + return "claude-opus-4-20250514" + } + if lowercased == "claude-4-sonnet" || lowercased == "claude-sonnet-4" || lowercased == "claude-sonnet" { + return "claude-sonnet-4-20250514" + } + if lowercased == "claude-3.7-sonnet" || lowercased == "claude-3-7-sonnet" || lowercased == "claude-sonnet-3.7" { + return "claude-3-7-sonnet" + } + if lowercased == "claude-3.5-sonnet" || lowercased == "claude-3-5-sonnet" || lowercased == "claude-sonnet-3.5" { + return "claude-3-5-sonnet" + } + if lowercased == "claude-3.5-haiku" || lowercased == "claude-3-5-haiku" || lowercased == "claude-haiku-3.5" { + return "claude-3-5-haiku" + } + if lowercased == "claude-3.5-opus" || lowercased == "claude-3-5-opus" || lowercased == "claude-opus-3.5" { + return "claude-3-5-opus" + } + if lowercased == "claude" { + return "claude-opus-4-20250514" // Default to Claude Opus 4 + } + + // OpenAI model shortcuts + if lowercased == "gpt4" || lowercased == "gpt-4" { + return "gpt-4.1" + } + if lowercased == "gpt4-mini" || lowercased == "gpt-4-mini" { + return "gpt-4.1-mini" + } + if lowercased == "gpt" { + return "gpt-4.1" // Default to latest GPT + } + + // Grok model shortcuts + if lowercased == "grok" || lowercased == "grok4" || lowercased == "grok-4" { + return "grok-4-0709" + } + if lowercased == "grok3" || lowercased == "grok-3" { + return "grok-3" + } + if lowercased == "grok2" || lowercased == "grok-2" { + return "grok-2-vision-1212" + } + + // Ollama model shortcuts + if lowercased == "ollama" || lowercased == "llama" { + return "llama3.3" // Default to llama3.3 - best for agent tasks with tool support + } + if lowercased == "llama3" || lowercased == "llama-3" { + return "llama3.3" // Default to latest llama 3.x + } + + // Check if it's a partial match for any registered model + let registeredModels = Array(modelFactories.keys) + for model in registeredModels { + if model.lowercased().contains(lowercased) || lowercased.contains(model.lowercased()) { + return model + } + } + + return nil + } + + private func registerOpenAIModels() { + let models = [ + // GPT-4o series + "gpt-4o", + "gpt-4o-mini", + + // GPT-4.1 series + "gpt-4.1", + "gpt-4.1-mini", + + // o3 series (Responses API only) + "o3", + "o3-mini", + "o3-pro", + + // o4 series (Responses API only) + "o4-mini", + ] + + for modelName in models { + self.register(modelName: modelName) { + guard let apiKey = self.getOpenAIAPIKey() else { + throw TachikomaError.authenticationFailed + } + + return OpenAIModel(apiKey: apiKey, modelName: modelName) + } + } + } + + nonisolated private func getOpenAIAPIKey() -> String? { + // Check environment variable + if let apiKey = ProcessInfo.processInfo.environment["OPENAI_API_KEY"] { + return apiKey + } + + // Check standard configuration directory + let configPath = FileManager.default.homeDirectoryForCurrentUser + .appendingPathComponent(".tachikoma") + .appendingPathComponent("credentials") + + if let credentials = try? String(contentsOf: configPath) { + for line in credentials.components(separatedBy: .newlines) { + let trimmed = line.trimmingCharacters(in: .whitespaces) + if trimmed.hasPrefix("OPENAI_API_KEY=") { + return String(trimmed.dropFirst("OPENAI_API_KEY=".count)) + } + } + } + + return nil + } + + private func registerAnthropicModels() { + // Map of model names to their actual IDs + let modelMappings: [String: String] = [ + // Claude 4 series (Latest - May 2025) + "claude-opus-4-20250514": "claude-opus-4-20250514", + "claude-opus-4-20250514-thinking": "claude-opus-4-20250514-thinking", + "claude-sonnet-4-20250514": "claude-sonnet-4-20250514", + "claude-sonnet-4-20250514-thinking": "claude-sonnet-4-20250514-thinking", + + // Claude 3.7 series (February 2025) + "claude-3-7-sonnet": "claude-3-7-sonnet", + + // Claude 3.5 series (Still available) + "claude-3-5-haiku": "claude-3-5-haiku", + "claude-3-5-sonnet": "claude-3-5-sonnet", + "claude-3-5-opus": "claude-3-5-opus", + ] + + for (alias, actualModelId) in modelMappings { + self.register(modelName: alias) { + guard let apiKey = self.getAnthropicAPIKey() else { + throw TachikomaError.authenticationFailed + } + + return AnthropicModel(apiKey: apiKey, modelName: actualModelId) + } + } + } + + nonisolated private func getAnthropicAPIKey() -> String? { + // Check environment variable + if let apiKey = ProcessInfo.processInfo.environment["ANTHROPIC_API_KEY"] { + return apiKey + } + + // Check standard configuration directory + let configPath = FileManager.default.homeDirectoryForCurrentUser + .appendingPathComponent(".tachikoma") + .appendingPathComponent("credentials") + + if let credentials = try? String(contentsOf: configPath) { + for line in credentials.components(separatedBy: .newlines) { + let trimmed = line.trimmingCharacters(in: .whitespaces) + if trimmed.hasPrefix("ANTHROPIC_API_KEY=") { + return String(trimmed.dropFirst("ANTHROPIC_API_KEY=".count)) + } + } + } + + return nil + } + + private func registerGrokModels() { + let models = [ + // Grok 4 series + "grok-4", + "grok-4-0709", + "grok-4-latest", + + // Grok 3 series + "grok-3", + "grok-3-mini", + "grok-3-fast", + "grok-3-mini-fast", + + // Grok 2 series + "grok-2-1212", + "grok-2-vision-1212", + "grok-2-image-1212", + + // Beta models + "grok-beta", + "grok-vision-beta", + ] + + for modelName in models { + self.register(modelName: modelName) { + guard let apiKey = self.getGrokAPIKey() else { + throw TachikomaError.authenticationFailed + } + + return GrokModel(apiKey: apiKey, modelName: modelName) + } + } + } + + nonisolated private func getGrokAPIKey() -> String? { + // Check environment variables (both variants) + if let apiKey = ProcessInfo.processInfo.environment["X_AI_API_KEY"] { + return apiKey + } + if let apiKey = ProcessInfo.processInfo.environment["XAI_API_KEY"] { + return apiKey + } + + // Check standard configuration directory + let configPath = FileManager.default.homeDirectoryForCurrentUser + .appendingPathComponent(".tachikoma") + .appendingPathComponent("credentials") + + if let credentials = try? String(contentsOf: configPath) { + for line in credentials.components(separatedBy: .newlines) { + let trimmed = line.trimmingCharacters(in: .whitespaces) + if trimmed.hasPrefix("X_AI_API_KEY=") { + return String(trimmed.dropFirst("X_AI_API_KEY=".count)) + } + if trimmed.hasPrefix("XAI_API_KEY=") { + return String(trimmed.dropFirst("XAI_API_KEY=".count)) + } + } + } + + return nil + } + + private func registerOllamaModels() { + // Common Ollama models + let models = [ + // Language models with tool support (recommended for agent tasks) + "llama3.3", + "llama3.3:latest", + "llama3.2", + "llama3.2:latest", + + // Vision models (NOTE: These do NOT support tool calling) + "llava:latest", + "llava", + "bakllava:latest", + "bakllava", + "llama3.2-vision:11b", + "llama3.2-vision:90b", + "qwen2.5vl:7b", + "qwen2.5vl:32b", + + // Other language models (tool support varies) + "llama2", + "llama2:latest", + "llama4", + "llama4:latest", + "codellama", + "codellama:latest", + "mistral", + "mistral:latest", + "mixtral", + "mixtral:latest", + "neural-chat", + "neural-chat:latest", + "gemma", + "gemma:latest", + "devstral", + "devstral:latest", + "deepseek-r1:8b", + "deepseek-r1:671b", + ] + + // Get base URL from environment or default + let baseURLString = ProcessInfo.processInfo.environment["TACHIKOMA_OLLAMA_BASE_URL"] ?? "http://localhost:11434" + guard let baseURL = URL(string: baseURLString) else { return } + + for modelName in models { + self.register(modelName: modelName) { + OllamaModel(modelName: modelName, baseURL: baseURL) + } + } + } + + // MARK: - Custom Provider Support + + /// Create a model instance for a custom provider + /// - Parameters: + /// - providerId: The custom provider ID + /// - modelPath: The model path within the provider + /// - Returns: A model instance + /// - Throws: TachikomaError if provider not found or configuration invalid + private func createCustomProviderModel(providerId: String, modelPath: String) throws -> any ModelInterface { + // For now, return a basic implementation + // This can be extended with a configuration system later + throw TachikomaError.modelNotFound("Custom providers not yet implemented") + } +} + +// MARK: - Model Provider Configuration + +/// Configuration for model providers +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum ProviderConfiguration { + /// OpenAI configuration + public struct OpenAI: Sendable { + public let apiKey: String + public let organizationId: String? + public let baseURL: URL? + + public init( + apiKey: String, + organizationId: String? = nil, + baseURL: URL? = nil) + { + self.apiKey = apiKey + self.organizationId = organizationId + self.baseURL = baseURL + } + } + + /// Anthropic configuration + public struct Anthropic: Sendable { + public let apiKey: String + public let baseURL: URL? + + public init( + apiKey: String, + baseURL: URL? = nil) + { + self.apiKey = apiKey + self.baseURL = baseURL + } + } + + /// Ollama configuration + public struct Ollama: Sendable { + public let baseURL: URL + + public init(baseURL: URL = URL(string: "http://localhost:11434")!) { + self.baseURL = baseURL + } + } + + /// Grok/xAI configuration + public struct Grok: Sendable { + public let apiKey: String + public let baseURL: URL? + + public init( + apiKey: String, + baseURL: URL? = nil) + { + self.apiKey = apiKey + self.baseURL = baseURL + } + } +} + +// MARK: - Model Provider Extensions + +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +extension ModelProvider { + /// Configure OpenAI models with specific settings + public func configureOpenAI(_ config: ProviderConfiguration.OpenAI) { + let models = [ + // GPT-4o series + "gpt-4o", + "gpt-4o-mini", + + // GPT-4.1 series + "gpt-4.1", + "gpt-4.1-mini", + + // o3 series (Responses API only) + "o3", + "o3-mini", + "o3-pro", + + // o4 series (Responses API only) + "o4-mini", + ] + + for modelName in models { + self.register(modelName: modelName) { + OpenAIModel( + apiKey: config.apiKey, + baseURL: config.baseURL ?? URL(string: "https://api.openai.com/v1")!, + organizationId: config.organizationId, + modelName: modelName) + } + } + } + + /// Configure Anthropic models with specific settings + public func configureAnthropic(_ config: ProviderConfiguration.Anthropic) { + // Map of model names to their actual IDs + let modelMappings: [String: String] = [ + // Claude 4 series (Latest - May 2025) + "claude-opus-4-20250514": "claude-opus-4-20250514", + "claude-opus-4-20250514-thinking": "claude-opus-4-20250514-thinking", + "claude-sonnet-4-20250514": "claude-sonnet-4-20250514", + "claude-sonnet-4-20250514-thinking": "claude-sonnet-4-20250514-thinking", + + // Claude 3.7 series (February 2025) + "claude-3-7-sonnet": "claude-3-7-sonnet", + + // Claude 3.5 series (Still available) + "claude-3-5-haiku": "claude-3-5-haiku", + "claude-3-5-sonnet": "claude-3-5-sonnet", + "claude-3-5-opus": "claude-3-5-opus", + ] + + for (alias, actualModelId) in modelMappings { + self.register(modelName: alias) { + AnthropicModel( + apiKey: config.apiKey, + baseURL: config.baseURL ?? URL(string: "https://api.anthropic.com/v1")!, + modelName: actualModelId) + } + } + } + + /// Configure Ollama models with specific settings + public func configureOllama(_ config: ProviderConfiguration.Ollama) { + let models = [ + // Vision models + "llava:latest", + "llava", + "bakllava:latest", + "bakllava", + "llama3.2-vision:11b", + "llama3.2-vision:90b", + "qwen2.5vl:7b", + "qwen2.5vl:32b", + + // Language models + "llama2", + "llama2:latest", + "llama3.2", + "llama3.2:latest", + "llama3.3", + "llama3.3:latest", + "llama4", + "llama4:latest", + "codellama", + "codellama:latest", + "mistral", + "mistral:latest", + "mixtral", + "mixtral:latest", + "neural-chat", + "neural-chat:latest", + "gemma", + "gemma:latest", + "devstral", + "devstral:latest", + "deepseek-r1:8b", + "deepseek-r1:671b", + ] + + for modelName in models { + self.register(modelName: modelName) { + OllamaModel(modelName: modelName, baseURL: config.baseURL) + } + } + } + + /// Configure Grok models with specific settings + public func configureGrok(_ config: ProviderConfiguration.Grok) { + let models = [ + // Grok 4 series + "grok-4", + "grok-4-0709", + "grok-4-latest", + + // Grok 3 series + "grok-3", + "grok-3-mini", + "grok-3-fast", + "grok-3-mini-fast", + + // Grok 2 series + "grok-2-1212", + "grok-2-vision-1212", + "grok-2-image-1212", + + // Beta models + "grok-beta", + "grok-vision-beta", + ] + + for modelName in models { + self.register(modelName: modelName) { + GrokModel( + apiKey: config.apiKey, + modelName: modelName, + baseURL: config.baseURL ?? URL(string: "https://api.x.ai/v1")!) + } + } + } + + /// Quick setup with API key from environment + public func setupFromEnvironment() async throws { + if let apiKey = ProcessInfo.processInfo.environment["OPENAI_API_KEY"] { + self.configureOpenAI(ProviderConfiguration.OpenAI(apiKey: apiKey)) + } + + if let apiKey = ProcessInfo.processInfo.environment["ANTHROPIC_API_KEY"] { + self.configureAnthropic(ProviderConfiguration.Anthropic(apiKey: apiKey)) + } + + // Configure Ollama (no API key needed) + let ollamaBaseURL = ProcessInfo.processInfo.environment["TACHIKOMA_OLLAMA_BASE_URL"] ?? "http://localhost:11434" + if let baseURL = URL(string: ollamaBaseURL) { + self.configureOllama(ProviderConfiguration.Ollama(baseURL: baseURL)) + } + + // Configure Grok with various API key options + if let apiKey = ProcessInfo.processInfo.environment["X_AI_API_KEY"] ?? + ProcessInfo.processInfo.environment["XAI_API_KEY"] + { + self.configureGrok(ProviderConfiguration.Grok(apiKey: apiKey)) + } + } +} \ No newline at end of file diff --git a/Sources/Core/StreamingTypes.swift b/Sources/Core/StreamingTypes.swift new file mode 100644 index 0000000..66202f4 --- /dev/null +++ b/Sources/Core/StreamingTypes.swift @@ -0,0 +1,381 @@ +import Foundation + +// MARK: - Streaming Event Types + +/// Base protocol for all streaming events +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public protocol StreamingEvent: Codable, Sendable { + var type: StreamEventType { get } +} + +/// Types of streaming events +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum StreamEventType: String, Codable, Sendable { + case textDelta = "text_delta" + case responseStarted = "response_started" + case responseCompleted = "response_completed" + case toolCallDelta = "tool_call_delta" + case toolCallCompleted = "tool_call_completed" + case functionCallArgumentsDelta = "function_call_arguments_delta" + case error + case unknown + case reasoningSummaryDelta = "reasoning_summary_delta" + case reasoningSummaryCompleted = "reasoning_summary_completed" +} + +/// Main streaming event enum that encompasses all event types +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum StreamEvent: Codable, Sendable { + case textDelta(StreamTextDelta) + case responseStarted(StreamResponseStarted) + case responseCompleted(StreamResponseCompleted) + case toolCallDelta(StreamToolCallDelta) + case toolCallCompleted(StreamToolCallCompleted) + case functionCallArgumentsDelta(StreamFunctionCallArgumentsDelta) + case error(StreamError) + case unknown(StreamUnknown) + case reasoningSummaryDelta(StreamReasoningSummaryDelta) + case reasoningSummaryCompleted(StreamReasoningSummaryCompleted) + + // Custom coding for the enum + enum CodingKeys: String, CodingKey { + case type, data + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(StreamEventType.self, forKey: .type) + + switch type { + case .textDelta: + let data = try container.decode(StreamTextDelta.self, forKey: .data) + self = .textDelta(data) + case .responseStarted: + let data = try container.decode(StreamResponseStarted.self, forKey: .data) + self = .responseStarted(data) + case .responseCompleted: + let data = try container.decode(StreamResponseCompleted.self, forKey: .data) + self = .responseCompleted(data) + case .toolCallDelta: + let data = try container.decode(StreamToolCallDelta.self, forKey: .data) + self = .toolCallDelta(data) + case .toolCallCompleted: + let data = try container.decode(StreamToolCallCompleted.self, forKey: .data) + self = .toolCallCompleted(data) + case .functionCallArgumentsDelta: + let data = try container.decode(StreamFunctionCallArgumentsDelta.self, forKey: .data) + self = .functionCallArgumentsDelta(data) + case .error: + let data = try container.decode(StreamError.self, forKey: .data) + self = .error(data) + case .unknown: + let data = try container.decode(StreamUnknown.self, forKey: .data) + self = .unknown(data) + case .reasoningSummaryDelta: + let data = try container.decode(StreamReasoningSummaryDelta.self, forKey: .data) + self = .reasoningSummaryDelta(data) + case .reasoningSummaryCompleted: + let data = try container.decode(StreamReasoningSummaryCompleted.self, forKey: .data) + self = .reasoningSummaryCompleted(data) + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + + switch self { + case let .textDelta(data): + try container.encode(StreamEventType.textDelta, forKey: .type) + try container.encode(data, forKey: .data) + case let .responseStarted(data): + try container.encode(StreamEventType.responseStarted, forKey: .type) + try container.encode(data, forKey: .data) + case let .responseCompleted(data): + try container.encode(StreamEventType.responseCompleted, forKey: .type) + try container.encode(data, forKey: .data) + case let .toolCallDelta(data): + try container.encode(StreamEventType.toolCallDelta, forKey: .type) + try container.encode(data, forKey: .data) + case let .toolCallCompleted(data): + try container.encode(StreamEventType.toolCallCompleted, forKey: .type) + try container.encode(data, forKey: .data) + case let .functionCallArgumentsDelta(data): + try container.encode(StreamEventType.functionCallArgumentsDelta, forKey: .type) + try container.encode(data, forKey: .data) + case let .error(data): + try container.encode(StreamEventType.error, forKey: .type) + try container.encode(data, forKey: .data) + case let .unknown(data): + try container.encode(StreamEventType.unknown, forKey: .type) + try container.encode(data, forKey: .data) + case let .reasoningSummaryDelta(data): + try container.encode(StreamEventType.reasoningSummaryDelta, forKey: .type) + try container.encode(data, forKey: .data) + case let .reasoningSummaryCompleted(data): + try container.encode(StreamEventType.reasoningSummaryCompleted, forKey: .type) + try container.encode(data, forKey: .data) + } + } +} + +// MARK: - Concrete Streaming Event Types + +/// Text delta event containing incremental text output +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct StreamTextDelta: StreamingEvent, Codable, Sendable { + public var type = StreamEventType.textDelta + public let delta: String + public let index: Int? + + public init(delta: String, index: Int? = nil) { + self.delta = delta + self.index = index + } +} + +/// Response started event +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct StreamResponseStarted: StreamingEvent, Codable, Sendable { + public var type = StreamEventType.responseStarted + public let id: String + public let model: String? + public let systemFingerprint: String? + + public init(id: String, model: String? = nil, systemFingerprint: String? = nil) { + self.id = id + self.model = model + self.systemFingerprint = systemFingerprint + } +} + +/// Response completed event with final metadata +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct StreamResponseCompleted: StreamingEvent, Codable, Sendable { + public var type = StreamEventType.responseCompleted + public let id: String + public let usage: Usage? + public let finishReason: FinishReason? + + public init(id: String, usage: Usage? = nil, finishReason: FinishReason? = nil) { + self.id = id + self.usage = usage + self.finishReason = finishReason + } +} + +/// Tool call delta event for incremental tool call information +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct StreamToolCallDelta: StreamingEvent, Codable, Sendable { + public var type = StreamEventType.toolCallDelta + public let id: String + public let index: Int + public let function: FunctionCallDelta + + public init(id: String, index: Int, function: FunctionCallDelta) { + self.id = id + self.index = index + self.function = function + } +} + +/// Tool call completed event +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct StreamToolCallCompleted: StreamingEvent, Codable, Sendable { + public var type = StreamEventType.toolCallCompleted + public let id: String + public let function: FunctionCall + + public init(id: String, function: FunctionCall) { + self.id = id + self.function = function + } +} + +/// Function call arguments delta event for incremental function call argument information +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct StreamFunctionCallArgumentsDelta: StreamingEvent, Codable, Sendable { + public var type = StreamEventType.functionCallArgumentsDelta + public let id: String + public let arguments: String + + public init(id: String, arguments: String) { + self.id = id + self.arguments = arguments + } +} + +/// Error event for stream errors +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct StreamError: StreamingEvent, Codable, Sendable { + public var type = StreamEventType.error + public let error: ErrorDetail + + public init(error: ErrorDetail) { + self.error = error + } +} + +/// Unknown event for forward compatibility +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct StreamUnknown: StreamingEvent, Codable, Sendable { + public var type = StreamEventType.unknown + public let eventType: String + public let rawJSON: [UInt8] + + public init(eventType: String, rawJSON: [UInt8]) { + self.eventType = eventType + self.rawJSON = rawJSON + } + + /// Get the raw data as a dictionary if possible + public func getRawData() throws -> [String: Any]? { + let data = Data(self.rawJSON) + return try JSONSerialization.jsonObject(with: data) as? [String: Any] + } + + /// Get the raw data as a pretty-printed JSON string + public func getRawJSONString() -> String? { + let data = Data(self.rawJSON) + if let json = try? JSONSerialization.jsonObject(with: data), + let prettyData = try? JSONSerialization.data(withJSONObject: json, options: .prettyPrinted) + { + return String(data: prettyData, encoding: .utf8) + } + return String(data: data, encoding: .utf8) + } +} + +/// Reasoning summary delta event for o3 models +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct StreamReasoningSummaryDelta: StreamingEvent, Codable, Sendable { + public var type = StreamEventType.reasoningSummaryDelta + public let delta: String + public let index: Int? + + public init(delta: String, index: Int? = nil) { + self.delta = delta + self.index = index + } +} + +/// Reasoning summary completed event for o3 models +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct StreamReasoningSummaryCompleted: StreamingEvent, Codable, Sendable { + public var type = StreamEventType.reasoningSummaryCompleted + public let summary: String + public let reasoningTokens: Int? + + public init(summary: String, reasoningTokens: Int? = nil) { + self.summary = summary + self.reasoningTokens = reasoningTokens + } +} + +// MARK: - Supporting Types + +/// Function call delta for incremental function information +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct FunctionCallDelta: Codable, Sendable { + public let name: String? + public let arguments: String? + + public init(name: String? = nil, arguments: String? = nil) { + self.name = name + self.arguments = arguments + } +} + +/// Token usage information +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct Usage: Codable, Sendable { + public let promptTokens: Int + public let completionTokens: Int + public let totalTokens: Int + public let promptTokensDetails: TokenDetails? + public let completionTokensDetails: TokenDetails? + + public init( + promptTokens: Int, + completionTokens: Int, + totalTokens: Int, + promptTokensDetails: TokenDetails? = nil, + completionTokensDetails: TokenDetails? = nil) + { + self.promptTokens = promptTokens + self.completionTokens = completionTokens + self.totalTokens = totalTokens + self.promptTokensDetails = promptTokensDetails + self.completionTokensDetails = completionTokensDetails + } +} + +/// Detailed token usage breakdown +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct TokenDetails: Codable, Sendable { + public let cachedTokens: Int? + public let audioTokens: Int? + public let reasoningTokens: Int? + + public init(cachedTokens: Int? = nil, audioTokens: Int? = nil, reasoningTokens: Int? = nil) { + self.cachedTokens = cachedTokens + self.audioTokens = audioTokens + self.reasoningTokens = reasoningTokens + } +} + +/// Reason why the response finished +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum FinishReason: String, Codable, Sendable { + case stop + case length + case toolCalls = "tool_calls" + case contentFilter = "content_filter" + case functionCall = "function_call" +} + +/// Error detail information +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct ErrorDetail: Codable, Sendable { + public let message: String + public let type: String? + public let code: String? + public let param: String? + + public init(message: String, type: String? = nil, code: String? = nil, param: String? = nil) { + self.message = message + self.type = type + self.code = code + self.param = param + } +} + +// MARK: - Stream Event Extensions + +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +extension StreamEvent { + /// Check if this is a final event + public var isFinal: Bool { + switch self { + case .responseCompleted, .error, .reasoningSummaryCompleted: + true + default: + false + } + } + + /// Extract any text content from the event + public var textContent: String? { + switch self { + case let .textDelta(delta): + delta.delta + case let .reasoningSummaryDelta(delta): + delta.delta + case let .reasoningSummaryCompleted(completed): + completed.summary + case let .error(error): + error.error.message + default: + nil + } + } +} \ No newline at end of file diff --git a/Sources/Core/TachikomaError.swift b/Sources/Core/TachikomaError.swift new file mode 100644 index 0000000..871d9d2 --- /dev/null +++ b/Sources/Core/TachikomaError.swift @@ -0,0 +1,202 @@ +@_exported import Foundation + +// MARK: - Tachikoma Error Types + +/// Main error type for the Tachikoma library +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum TachikomaError: Error, LocalizedError, Sendable { + case modelNotFound(String) + case authenticationFailed + case invalidConfiguration(String) + case networkError(underlying: any Error) + case decodingError(underlying: any Error) + case invalidRequest(String) + case apiError(message: String, code: String? = nil) + case timeout + case rateLimited + case insufficientQuota + case modelOverloaded + case contextLengthExceeded + case contentFiltered + case invalidToolCall(String) + case streamingError(String) + case configurationError(String) + + public var errorDescription: String? { + switch self { + case let .modelNotFound(model): + "Model not found: \(model)" + case .authenticationFailed: + "Authentication failed - check your API key" + case let .invalidConfiguration(message): + "Invalid configuration: \(message)" + case let .networkError(underlying): + "Network error: \(underlying.localizedDescription)" + case let .decodingError(underlying): + "Failed to decode response: \(underlying.localizedDescription)" + case let .invalidRequest(message): + "Invalid request: \(message)" + case let .apiError(message, code): + if let code { + "API error (\(code)): \(message)" + } else { + "API error: \(message)" + } + case .timeout: + "Request timed out" + case .rateLimited: + "Rate limited - please slow down requests" + case .insufficientQuota: + "Insufficient quota - check your billing" + case .modelOverloaded: + "Model is currently overloaded - try again later" + case .contextLengthExceeded: + "Context length exceeded - reduce input size" + case .contentFiltered: + "Content was filtered by safety systems" + case let .invalidToolCall(message): + "Invalid tool call: \(message)" + case let .streamingError(message): + "Streaming error: \(message)" + case let .configurationError(message): + "Configuration error: \(message)" + } + } + + public var recoverySuggestion: String? { + switch self { + case .modelNotFound: + "Check available models with listModels() or verify the model name is correct" + case .authenticationFailed: + "Verify your API key is set correctly in environment variables or credentials file" + case .rateLimited: + "Wait a moment before making another request" + case .insufficientQuota: + "Check your billing settings and account quota" + case .modelOverloaded: + "Try using a different model or retry after a delay" + case .contextLengthExceeded: + "Reduce the length of your input messages or use a model with larger context" + case .contentFiltered: + "Modify your input to comply with content policies" + case .timeout: + "Check your network connection and try again" + default: + nil + } + } + + /// Check if this error indicates a temporary condition that might resolve with retry + public var isRetryable: Bool { + switch self { + case .rateLimited, .modelOverloaded, .timeout, .networkError: + true + default: + false + } + } + + /// Check if this error indicates an authentication issue + public var isAuthenticationError: Bool { + switch self { + case .authenticationFailed, .insufficientQuota: + true + default: + false + } + } + + /// Check if this error indicates a client-side issue + public var isClientError: Bool { + switch self { + case .invalidRequest, .invalidConfiguration, .contextLengthExceeded, .contentFiltered, .invalidToolCall: + true + default: + false + } + } +} + +// MARK: - Model Request/Response Errors + +/// Specific errors for model requests and responses +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum ModelError: Error, LocalizedError, Sendable { + case invalidInput(String) + case missingRequiredParameter(String) + case unsupportedParameter(String) + case invalidParameterValue(String, value: String) + case responseTooLarge + case emptyResponse + case malformedResponse(String) + + public var errorDescription: String? { + switch self { + case let .invalidInput(message): + "Invalid input: \(message)" + case let .missingRequiredParameter(param): + "Missing required parameter: \(param)" + case let .unsupportedParameter(param): + "Unsupported parameter: \(param)" + case let .invalidParameterValue(param, value): + "Invalid value for parameter '\(param)': \(value)" + case .responseTooLarge: + "Response exceeds maximum size limit" + case .emptyResponse: + "Received empty response from API" + case let .malformedResponse(message): + "Malformed response: \(message)" + } + } +} + +// MARK: - Streaming Errors + +/// Errors specific to streaming operations +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum StreamingError: Error, LocalizedError, Sendable { + case streamClosed + case invalidEventFormat(String) + case bufferOverflow + case connectionLost + + public var errorDescription: String? { + switch self { + case .streamClosed: + "Stream was closed unexpectedly" + case let .invalidEventFormat(format): + "Invalid event format: \(format)" + case .bufferOverflow: + "Stream buffer overflow" + case .connectionLost: + "Connection to stream was lost" + } + } +} + +// MARK: - Tool Execution Errors + +/// Errors for tool execution +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum ToolExecutionError: Error, LocalizedError, Sendable { + case toolNotFound(String) + case invalidArguments(String) + case executionFailed(String) + case timeout + case missingContext + + public var errorDescription: String? { + switch self { + case let .toolNotFound(name): + "Tool not found: \(name)" + case let .invalidArguments(message): + "Invalid tool arguments: \(message)" + case let .executionFailed(message): + "Tool execution failed: \(message)" + case .timeout: + "Tool execution timed out" + case .missingContext: + "Required context is missing for tool execution" + } + } +} \ No newline at end of file diff --git a/Sources/Core/ToolDefinitions.swift b/Sources/Core/ToolDefinitions.swift new file mode 100644 index 0000000..3fd2969 --- /dev/null +++ b/Sources/Core/ToolDefinitions.swift @@ -0,0 +1,597 @@ +import Foundation + +// MARK: - Tool Definition + +/// A tool that can be used by an agent to perform actions +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct Tool { + /// Unique name of the tool + public let name: String + + /// Description of what the tool does + public let description: String + + /// Parameters the tool accepts + public let parameters: ToolParameters + + /// Whether to use strict parameter validation + public let strict: Bool + + /// The function to execute when the tool is called + public let execute: (ToolInput, Context) async throws -> ToolOutput + + public init( + name: String, + description: String, + parameters: ToolParameters, + strict: Bool = true, + execute: @escaping (ToolInput, Context) async throws -> ToolOutput) + { + self.name = name + self.description = description + self.parameters = parameters + self.strict = strict + self.execute = execute + } + + /// Convert to a tool definition for the model + public func toToolDefinition() -> ToolDefinition { + ToolDefinition( + type: .function, + function: FunctionDefinition( + name: self.name, + description: self.description, + parameters: self.parameters, + strict: self.strict)) + } +} + +// MARK: - Tool Definition Types + +/// Definition of a tool that can be sent to a model +public struct ToolDefinition: Codable, Sendable { + public let type: ToolType + public let function: FunctionDefinition + + public init(type: ToolType = .function, function: FunctionDefinition) { + self.type = type + self.function = function + } +} + +/// Type of tool +public enum ToolType: String, Codable, Sendable { + case function +} + +/// Function definition for a tool +public struct FunctionDefinition: Codable, Sendable { + public let name: String + public let description: String + public let parameters: ToolParameters + public let strict: Bool? + + public init( + name: String, + description: String, + parameters: ToolParameters, + strict: Bool? = nil) + { + self.name = name + self.description = description + self.parameters = parameters + self.strict = strict + } +} + +// MARK: - Tool Parameters + +/// Parameters schema for a tool +public struct ToolParameters: Codable, Sendable { + public let type: String + public let properties: [String: ParameterSchema] + public let required: [String] + public let additionalProperties: Bool + + public init( + type: String = "object", + properties: [String: ParameterSchema] = [:], + required: [String] = [], + additionalProperties: Bool = false) + { + self.type = type + self.properties = properties + self.required = required + self.additionalProperties = additionalProperties + } + + /// Create parameters from a dictionary of property definitions + public static func object( + properties: [String: ParameterSchema], + required: [String] = []) -> ToolParameters + { + ToolParameters( + type: "object", + properties: properties, + required: required, + additionalProperties: false) + } +} + +/// Schema for a single parameter +public struct ParameterSchema: Codable, Sendable { + public let type: ParameterType + public let description: String? + public let enumValues: [String]? + public let items: Box? + public let properties: [String: ParameterSchema]? + public let minimum: Double? + public let maximum: Double? + public let pattern: String? + + public init( + type: ParameterType, + description: String? = nil, + enumValues: [String]? = nil, + items: ParameterSchema? = nil, + properties: [String: ParameterSchema]? = nil, + minimum: Double? = nil, + maximum: Double? = nil, + pattern: String? = nil) + { + self.type = type + self.description = description + self.enumValues = enumValues + self.items = items.map(Box.init) + self.properties = properties + self.minimum = minimum + self.maximum = maximum + self.pattern = pattern + } + + // Convenience initializers + public static func string(description: String? = nil, pattern: String? = nil) -> ParameterSchema { + ParameterSchema(type: .string, description: description, pattern: pattern) + } + + public static func number( + description: String? = nil, + minimum: Double? = nil, + maximum: Double? = nil) -> ParameterSchema + { + ParameterSchema(type: .number, description: description, minimum: minimum, maximum: maximum) + } + + public static func integer( + description: String? = nil, + minimum: Double? = nil, + maximum: Double? = nil) -> ParameterSchema + { + ParameterSchema(type: .integer, description: description, minimum: minimum, maximum: maximum) + } + + public static func boolean(description: String? = nil) -> ParameterSchema { + ParameterSchema(type: .boolean, description: description) + } + + public static func array(of items: ParameterSchema, description: String? = nil) -> ParameterSchema { + ParameterSchema(type: .array, description: description, items: items) + } + + public static func object(properties: [String: ParameterSchema], description: String? = nil) -> ParameterSchema { + ParameterSchema(type: .object, description: description, properties: properties) + } + + public static func enumeration(_ values: [String], description: String? = nil) -> ParameterSchema { + ParameterSchema(type: .string, description: description, enumValues: values) + } + + // Custom coding keys + enum CodingKeys: String, CodingKey { + case type, description + case enumValues = "enum" + case items, properties + case minimum, maximum, pattern + } +} + +/// Parameter types +public enum ParameterType: String, Codable, Sendable { + case string + case number + case integer + case boolean + case array + case object + case null +} + +// MARK: - Tool Input/Output + +/// Input provided to a tool +public enum ToolInput { + case string(String) + case dictionary([String: Any]) + case array([Any]) + case null + + /// Parse from a JSON string + public init(jsonString: String) throws { + // Handle empty string as empty dictionary + if jsonString.isEmpty { + self = .dictionary([:]) + return + } + + guard let data = jsonString.data(using: .utf8) else { + throw ToolError.invalidInput("Invalid JSON string") + } + + let parsed = try JSONSerialization.jsonObject(with: data) + + if let dict = parsed as? [String: Any] { + self = .dictionary(dict) + } else if let array = parsed as? [Any] { + self = .array(array) + } else if let string = parsed as? String { + self = .string(string) + } else { + self = .null + } + } + + /// Get value for a specific key (for dictionary inputs) + public func value(for key: String) -> T? { + guard case let .dictionary(dict) = self else { return nil } + return dict[key] as? T + } + + /// Get the raw string value + public var stringValue: String? { + switch self { + case let .string(str): + return str + case .dictionary, .array: + if let data = try? JSONSerialization.data(withJSONObject: rawValue), + let str = String(data: data, encoding: .utf8) + { + return str + } + return nil + case .null: + return nil + } + } + + /// Get the raw value + public var rawValue: Any { + switch self { + case let .string(str): + str + case let .dictionary(dict): + dict + case let .array(array): + array + case .null: + NSNull() + } + } +} + +/// Strongly-typed output from a tool +public enum ToolOutput: Codable, Sendable { + case string(String) + case number(Double) + case boolean(Bool) + case object([String: ToolOutput]) + case array([ToolOutput]) + case null + case error(message: String, code: String? = nil) + + // MARK: - Codable Implementation + + private enum CodingKeys: String, CodingKey { + case type, value, message, code + } + + private enum OutputType: String, Codable { + case string, number, boolean, object, array, null, error + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(OutputType.self, forKey: .type) + + switch type { + case .string: + let value = try container.decode(String.self, forKey: .value) + self = .string(value) + case .number: + let value = try container.decode(Double.self, forKey: .value) + self = .number(value) + case .boolean: + let value = try container.decode(Bool.self, forKey: .value) + self = .boolean(value) + case .object: + let value = try container.decode([String: ToolOutput].self, forKey: .value) + self = .object(value) + case .array: + let value = try container.decode([ToolOutput].self, forKey: .value) + self = .array(value) + case .null: + self = .null + case .error: + let message = try container.decode(String.self, forKey: .message) + let code = try container.decodeIfPresent(String.self, forKey: .code) + self = .error(message: message, code: code) + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + + switch self { + case let .string(value): + try container.encode(OutputType.string, forKey: .type) + try container.encode(value, forKey: .value) + case let .number(value): + try container.encode(OutputType.number, forKey: .type) + try container.encode(value, forKey: .value) + case let .boolean(value): + try container.encode(OutputType.boolean, forKey: .type) + try container.encode(value, forKey: .value) + case let .object(value): + try container.encode(OutputType.object, forKey: .type) + try container.encode(value, forKey: .value) + case let .array(value): + try container.encode(OutputType.array, forKey: .type) + try container.encode(value, forKey: .value) + case .null: + try container.encode(OutputType.null, forKey: .type) + case let .error(message, code): + try container.encode(OutputType.error, forKey: .type) + try container.encode(message, forKey: .message) + try container.encodeIfPresent(code, forKey: .code) + } + } + + // MARK: - Conversion Methods + + /// Convert to JSON string for the model + public func toJSONString() throws -> String { + switch self { + case let .string(str): + return str // Return string directly for text output + case let .error(message, code): + // Special handling for errors to match expected format + var errorDict: [String: ToolOutput] = ["error": .string(message)] + if let code { + errorDict["error_code"] = .string(code) + } + let data = try JSONEncoder().encode(ToolOutput.object(errorDict)) + guard let string = String(data: data, encoding: .utf8) else { + throw ToolError.serializationFailed + } + return string + default: + // For all other types, encode normally + let encoder = JSONEncoder() + encoder.outputFormatting = .prettyPrinted + let data = try encoder.encode(self) + guard let string = String(data: data, encoding: .utf8) else { + throw ToolError.serializationFailed + } + return string + } + } + + /// Convert to a dictionary representation (for compatibility) + public func toDictionary() -> [String: Any]? { + switch self { + case let .object(dict): + var result: [String: Any] = [:] + for (key, value) in dict { + if let converted = value.toAny() { + result[key] = converted + } + } + return result + default: + return nil + } + } + + /// Convert to Any (for legacy compatibility) + private func toAny() -> Any? { + switch self { + case let .string(value): + return value + case let .number(value): + return value + case let .boolean(value): + return value + case let .object(dict): + var result: [String: Any] = [:] + for (key, value) in dict { + if let converted = value.toAny() { + result[key] = converted + } + } + return result + case let .array(array): + return array.compactMap { $0.toAny() } + case .null: + return NSNull() + case let .error(message, _): + return ["error": message] + } + } +} + +// MARK: - Builder Methods + +extension ToolOutput { + /// Create a dictionary/object output using a builder pattern + public static func dictionary(_ builder: () -> [String: ToolOutput]) -> ToolOutput { + .object(builder()) + } + + /// Create a dictionary/object output from key-value pairs + public static func dictionary(_ pairs: (String, ToolOutput)...) -> ToolOutput { + var dict: [String: ToolOutput] = [:] + for (key, value) in pairs { + dict[key] = value + } + return .object(dict) + } + + /// Create from a Swift dictionary with automatic type conversion + public static func from(_ dict: [String: Any]) -> ToolOutput { + var result: [String: ToolOutput] = [:] + for (key, value) in dict { + result[key] = self.from(value) + } + return .object(result) + } + + /// Create from any Swift value with automatic type conversion + public static func from(_ value: Any) -> ToolOutput { + switch value { + case let str as String: + .string(str) + case let num as Int: + .number(Double(num)) + case let num as Double: + .number(num) + case let bool as Bool: + .boolean(bool) + case let dict as [String: Any]: + self.from(dict) + case let array as [Any]: + .array(array.map { self.from($0) }) + case is NSNull: + .null + default: + // Fallback to string representation + .string(String(describing: value)) + } + } + + /// Convenience method for success results + public static func success(_ message: String, metadata: (String, ToolOutput)...) -> ToolOutput { + var dict: [String: ToolOutput] = ["result": .string(message)] + for (key, value) in metadata { + dict[key] = value + } + return .object(dict) + } +} + +// MARK: - Tool Errors + +/// Errors that can occur during tool execution +public enum ToolError: Error, LocalizedError, Sendable { + case invalidInput(String) + case executionFailed(String) + case serializationFailed + case contextMissing + case toolNotFound(String) + + public var errorDescription: String? { + switch self { + case let .invalidInput(message): + "Invalid tool input: \(message)" + case let .executionFailed(message): + "Tool execution failed: \(message)" + case .serializationFailed: + "Failed to serialize tool output" + case .contextMissing: + "Required context is missing" + case let .toolNotFound(name): + "Tool not found: \(name)" + } + } +} + +// MARK: - Helper Types + +/// Box type for recursive data structures +public final class Box: Codable, Sendable { + public let value: T + + public init(_ value: T) { + self.value = value + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + self.value = try container.decode(T.self) + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + try container.encode(self.value) + } +} + +// MARK: - Tool Builder + +/// Builder pattern for creating tools +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct ToolBuilder { + private var name: String = "" + private var description: String = "" + private var parameters: ToolParameters = .init() + private var strict: Bool = true + private var execute: ((ToolInput, Context) async throws -> ToolOutput)? + + public init() {} + + public func withName(_ name: String) -> ToolBuilder { + var builder = self + builder.name = name + return builder + } + + public func withDescription(_ description: String) -> ToolBuilder { + var builder = self + builder.description = description + return builder + } + + public func withParameters(_ parameters: ToolParameters) -> ToolBuilder { + var builder = self + builder.parameters = parameters + return builder + } + + public func withStrict(_ strict: Bool) -> ToolBuilder { + var builder = self + builder.strict = strict + return builder + } + + public func withExecution(_ execute: @escaping (ToolInput, Context) async throws -> ToolOutput) + -> ToolBuilder { + var builder = self + builder.execute = execute + return builder + } + + public func build() throws -> Tool { + guard !self.name.isEmpty else { + throw ToolError.invalidInput("Tool name is required") + } + + guard let execute else { + throw ToolError.invalidInput("Tool execution function is required") + } + + return Tool( + name: self.name, + description: self.description, + parameters: self.parameters, + strict: self.strict, + execute: execute) + } +} \ No newline at end of file diff --git a/Sources/Providers/Anthropic/AnthropicModel.swift b/Sources/Providers/Anthropic/AnthropicModel.swift new file mode 100644 index 0000000..151106a --- /dev/null +++ b/Sources/Providers/Anthropic/AnthropicModel.swift @@ -0,0 +1,642 @@ +import Foundation + +/// Anthropic model implementation conforming to ModelInterface +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public final class AnthropicModel: ModelInterface, Sendable { + private let apiKey: String + private let baseURL: URL + private let session: URLSession + private let anthropicVersion: String + private let modelName: String? + private let customHeaders: [String: String]? + + public init( + apiKey: String, + baseURL: URL = URL(string: "https://api.anthropic.com/v1")!, + anthropicVersion: String = "2023-06-01", + modelName: String? = nil, + session: URLSession? = nil) + { + self.apiKey = apiKey + self.baseURL = baseURL + self.anthropicVersion = anthropicVersion + self.modelName = modelName + self.customHeaders = nil + + if let session { + self.session = session + } else { + let config = URLSessionConfiguration.default + config.timeoutIntervalForRequest = 120 // 2 minutes + config.timeoutIntervalForResource = 120 + self.session = URLSession(configuration: config) + } + } + + /// Initialize with custom provider configuration + public init( + apiKey: String, + baseURL: String, + modelName: String? = nil, + headers: [String: String]? = nil, + anthropicVersion: String = "2023-06-01", + session: URLSession? = nil) + { + self.apiKey = apiKey + self.baseURL = URL(string: baseURL) ?? URL(string: "https://api.anthropic.com/v1")! + self.anthropicVersion = anthropicVersion + self.modelName = modelName + self.customHeaders = headers + + if let session { + self.session = session + } else { + let config = URLSessionConfiguration.default + config.timeoutIntervalForRequest = 120 // 2 minutes + config.timeoutIntervalForResource = 120 + self.session = URLSession(configuration: config) + } + } + + // MARK: - ModelInterface Implementation + + public var maskedApiKey: String { + guard self.apiKey.count > 8 else { return "***" } + let start = self.apiKey.prefix(6) + let end = self.apiKey.suffix(2) + return "\(start)...\(end)" + } + + public func getResponse(request: ModelRequest) async throws -> ModelResponse { + let anthropicRequest = try convertToAnthropicRequest(request, stream: false) + let urlRequest = try createURLRequest(endpoint: "messages", body: anthropicRequest) + + let (data, response) = try await session.data(for: urlRequest) + + guard let httpResponse = response as? HTTPURLResponse else { + throw TachikomaError.networkError(underlying: URLError(.badServerResponse)) + } + + if httpResponse.statusCode != 200 { + try handleErrorResponse(data: data, response: httpResponse) + } + + do { + let anthropicResponse = try JSONDecoder().decode(AnthropicResponse.self, from: data) + return try self.convertFromAnthropicResponse(anthropicResponse) + } catch { + throw TachikomaError.decodingError(underlying: error) + } + } + + public func getStreamedResponse(request: ModelRequest) async throws -> AsyncThrowingStream { + let anthropicRequest = try convertToAnthropicRequest(request, stream: true) + let urlRequest = try createURLRequest(endpoint: "messages", body: anthropicRequest) + + return AsyncThrowingStream { continuation in + Task { + do { + let (bytes, response) = try await session.bytes(for: urlRequest) + + guard let httpResponse = response as? HTTPURLResponse else { + continuation.finish(throwing: TachikomaError.networkError(underlying: URLError(.badServerResponse))) + return + } + + if httpResponse.statusCode != 200 { + var errorData = Data() + for try await byte in bytes.prefix(1024) { + errorData.append(byte) + } + try self.handleErrorResponse(data: errorData, response: httpResponse) + } + + // Process SSE stream + var currentToolCalls: [String: PartialToolCall] = [:] + var responseId: String? + var accumulatedText = "" + var currentContentIndex = 0 + var pendingUsage: Usage? + + for try await line in bytes.lines { + // Skip empty lines + if line.isEmpty { + continue + } + + // Handle SSE format + if line.hasPrefix("data: ") { + let data = String(line.dropFirst(6)) + + // Parse the event + if let eventData = data.data(using: .utf8) { + do { + let event = try JSONDecoder().decode( + AnthropicStreamEvent.self, + from: eventData) + + switch event.type { + case "message_start": + if let message = event.message { + responseId = message.id + continuation.yield(.responseStarted(StreamResponseStarted( + id: message.id, + model: message.model, + systemFingerprint: nil))) + } + + case "content_block_start": + currentContentIndex = event.index ?? 0 + if let block = event.contentBlock { + if block.type == "tool_use", let id = block.id, let name = block.name { + // Start tracking this tool call + let partialCall = PartialToolCall( + id: id, + name: name, + index: currentContentIndex) + currentToolCalls[id] = partialCall + } + } + + case "content_block_delta": + if let delta = event.delta { + if let text = delta.text { + // Text delta + continuation.yield(.textDelta(StreamTextDelta( + delta: text, + index: currentContentIndex))) + accumulatedText += text + } else if let partialJson = delta.partialJson { + // Tool use arguments delta + // Find the tool call being updated + if let toolCall = currentToolCalls.values + .first(where: { $0.index == currentContentIndex }) + { + toolCall.appendArguments(partialJson) + continuation.yield(.toolCallDelta(StreamToolCallDelta( + id: toolCall.id, + index: toolCall.index, + function: FunctionCallDelta( + name: toolCall.name, + arguments: partialJson)))) + } + } + } + + case "content_block_stop": + // Complete any tool calls at this index + for (id, toolCall) in currentToolCalls { + if toolCall.index == currentContentIndex { + if let completed = toolCall.toCompleted() { + continuation.yield(.toolCallCompleted( + StreamToolCallCompleted(id: id, function: completed))) + } + } + } + + case "message_delta": + // Skip regular parsing - will handle in catch block for usage data + break + + case "message_stop": + // Final completion - emit responseCompleted with usage if available + if let id = responseId { + continuation.yield(.responseCompleted(StreamResponseCompleted( + id: id, + usage: pendingUsage, + finishReason: .stop))) + } + continuation.finish() + return + + case "error": + if let error = event.error { + continuation.finish(throwing: TachikomaError.apiError( + message: error.message)) + return + } + + default: + // Unknown event type, ignore + break + } + } catch { + // Special handling for message_delta events with usage + if let jsonData = data.data(using: .utf8) { + do { + if let json = try JSONSerialization + .jsonObject(with: jsonData) as? [String: Any] + { + if json["type"] as? String == "message_delta" { + if let usage = json["usage"] as? [String: Any] { + if let outputTokens = usage["output_tokens"] as? Int { + // Extract input tokens if available + let inputTokens = usage["input_tokens"] as? Int ?? 0 + + // Create Usage object + let tokenUsage = Usage( + promptTokens: inputTokens, + completionTokens: outputTokens, + totalTokens: inputTokens + outputTokens, + promptTokensDetails: nil, + completionTokensDetails: nil) + + // Store the usage for later + pendingUsage = tokenUsage + } + } + } + } + } catch { + // Failed to parse JSON, ignore + } + } + } + } + } + } + + continuation.finish() + } catch { + continuation.finish(throwing: TachikomaError.streamingError(error.localizedDescription)) + } + } + } + } + + // MARK: - Private Methods + + private func createURLRequest(endpoint: String, body: any Encodable) throws -> URLRequest { + let url = self.baseURL.appendingPathComponent(endpoint) + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue(self.apiKey, forHTTPHeaderField: "x-api-key") + request.setValue(self.anthropicVersion, forHTTPHeaderField: "anthropic-version") + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + + // Add custom headers from provider configuration + customHeaders?.forEach { key, value in + request.setValue(value, forHTTPHeaderField: key) + } + + let encoder = JSONEncoder() + encoder.outputFormatting = .sortedKeys + + do { + request.httpBody = try encoder.encode(body) + } catch { + throw TachikomaError.configurationError("Failed to encode Anthropic request: \(error.localizedDescription)") + } + + request.timeoutInterval = 120 + + return request + } + + private func convertToAnthropicRequest(_ request: ModelRequest, stream: Bool) throws -> AnthropicRequest { + var anthropicMessages: [AnthropicMessage] = [] + var systemPrompt: String? = nil + + // Convert messages + for message in request.messages { + switch message { + case let .system(_, content): + // Anthropic uses a separate system parameter + if systemPrompt == nil { + systemPrompt = content + } else { + systemPrompt = (systemPrompt ?? "") + "\n\n" + content + } + + case let .user(_, content): + let anthropicMessage = try convertUserMessage(content) + anthropicMessages.append(anthropicMessage) + + case let .assistant(_, content, status): + let anthropicMessage = try convertAssistantMessage(content, status: status) + anthropicMessages.append(anthropicMessage) + + case let .tool(_, toolCallId, content): + // Convert tool result to user message with tool_result content block + let toolResultBlock = AnthropicContentBlock.toolResult( + toolUseId: toolCallId, + content: content) + anthropicMessages.append(AnthropicMessage( + role: .user, + content: .array([toolResultBlock]))) + + case let .reasoning(_, content): + // Treat reasoning as a system message for now + if systemPrompt == nil { + systemPrompt = "[Reasoning] " + content + } else { + systemPrompt = (systemPrompt ?? "") + "\n\n[Reasoning] " + content + } + } + } + + // Convert tools + let tools = request.tools?.map { toolDef -> AnthropicTool in + AnthropicTool( + name: toolDef.function.name, + description: toolDef.function.description, + inputSchema: self.convertToolParameters(toolDef.function.parameters)) + } + + // Convert tool choice + let toolChoice = self.convertToolChoice(request.settings.toolChoice) + + // Create system content with cache control + let systemContent: AnthropicSystemContent? = if let systemPrompt { + // Use array format with cache control for system prompt + .array([ + AnthropicSystemBlock( + type: "text", + text: systemPrompt, + cacheControl: AnthropicCacheControl(type: "ephemeral")), + ]) + } else { + nil + } + + return AnthropicRequest( + model: self.modelName ?? request.settings.modelName, + messages: anthropicMessages, + system: systemContent, + maxTokens: request.settings.maxTokens ?? 4096, + temperature: request.settings.temperature, + topP: request.settings.topP, + topK: request.settings.additionalParameters?.int("top_k"), + stream: stream, + stopSequences: request.settings.stopSequences, + tools: tools, + toolChoice: toolChoice, + metadata: request.settings.user.map { AnthropicMetadata(userId: $0) }) + } + + private func convertUserMessage(_ content: MessageContent) throws -> AnthropicMessage { + switch content { + case let .text(text): + return AnthropicMessage(role: .user, content: .string(text)) + + case let .image(imageContent): + var blocks: [AnthropicContentBlock] = [] + + if let base64 = imageContent.base64 { + blocks.append(.image(base64: base64, mediaType: "image/jpeg")) + } else if imageContent.url != nil { + // For URLs, we'd need to download and convert to base64 + // For now, throw an error + throw TachikomaError.invalidRequest("Image URLs not supported - please provide base64 data") + } + + return AnthropicMessage(role: .user, content: .array(blocks)) + + case let .multimodal(parts): + let blocks = try parts.compactMap { part -> AnthropicContentBlock? in + if let text = part.text { + return .text(text) + } else if let image = part.imageUrl { + if let base64 = image.base64 { + return .image(base64: base64, mediaType: "image/jpeg") + } else if image.url != nil { + throw TachikomaError.invalidRequest("Image URLs not supported - please provide base64 data") + } + } + return nil + } + return AnthropicMessage(role: .user, content: .array(blocks)) + + case .file: + throw TachikomaError.invalidRequest("File content not supported in Anthropic API") + + case let .audio(audioContent): + // Claude doesn't support native audio, so we need to use the transcript + if let transcript = audioContent.transcript { + // Include metadata about the audio source + var text = transcript + if let duration = audioContent.duration { + text = "[Audio transcript, duration: \(Int(duration))s] \(transcript)" + } else { + text = "[Audio transcript] \(transcript)" + } + return AnthropicMessage(role: .user, content: .string(text)) + } else { + throw TachikomaError.invalidRequest("Audio content must be transcribed before sending to Claude. Please ensure transcript is provided.") + } + } + } + + private func convertAssistantMessage( + _ content: [AssistantContent], + status: MessageStatus) throws -> AnthropicMessage + { + var blocks: [AnthropicContentBlock] = [] + + for content in content { + switch content { + case let .outputText(text): + blocks.append(.text(text)) + + case let .refusal(refusal): + blocks.append(.text(refusal)) + + case let .toolCall(toolCall): + // Parse arguments as JSON + let arguments: [String: Any] = if let data = toolCall.function.arguments.data(using: .utf8), + let json = try? JSONSerialization + .jsonObject(with: data) as? [String: Any] + { + json + } else { + [:] + } + + blocks.append(.toolUse( + id: toolCall.id, + name: toolCall.function.name, + input: arguments)) + } + } + + return AnthropicMessage(role: .assistant, content: .array(blocks)) + } + + private func convertToolParameters(_ params: ToolParameters) -> AnthropicJSONSchema { + var properties: [String: AnthropicPropertySchema] = [:] + + for (key, schema) in params.properties { + properties[key] = self.convertParameterSchema(schema) + } + + return AnthropicJSONSchema( + type: params.type, + properties: properties, + required: params.required) + } + + private func convertParameterSchema(_ schema: ParameterSchema) -> AnthropicPropertySchema { + // Handle nested items for arrays + let items: AnthropicPropertySchema? = if schema.type == .array, let schemaItems = schema.items { + self.convertParameterSchema(schemaItems.value) + } else { + nil + } + + // Handle nested properties for objects + let properties: [String: AnthropicPropertySchema]? + if schema.type == .object, let schemaProps = schema.properties { + var convertedProps: [String: AnthropicPropertySchema] = [:] + for (key, nestedSchema) in schemaProps { + convertedProps[key] = self.convertParameterSchema(nestedSchema) + } + properties = convertedProps + } else { + properties = nil + } + + return AnthropicPropertySchema( + type: schema.type.rawValue, + description: schema.description, + enum: schema.enumValues, + items: items, + properties: properties, + required: nil // ParameterSchema doesn't have required field at this level + ) + } + + private func convertToolChoice(_ toolChoice: ToolChoice?) -> AnthropicToolChoice? { + guard let toolChoice else { return nil } + + switch toolChoice { + case .auto: + return .auto + case .none: + return nil // Anthropic doesn't have a "none" option + case .required: + return .any + case let .specific(toolName): + return .tool(name: toolName) + } + } + + private func convertFromAnthropicResponse(_ response: AnthropicResponse) throws -> ModelResponse { + var content: [AssistantContent] = [] + + // Convert content blocks + for block in response.content { + switch block.type { + case "text": + if let text = block.text { + content.append(.outputText(text)) + } + + case "tool_use": + if let id = block.id, + let name = block.name, + let input = block.input + { + // Convert input dictionary to JSON string + let arguments: String = if let data = try? JSONSerialization.data(withJSONObject: input.mapValues { $0.toAny() }), + let json = String(data: data, encoding: .utf8) + { + json + } else { + "{}" + } + + content.append(.toolCall(ToolCallItem( + id: id, + type: .function, + function: FunctionCall( + name: name, + arguments: arguments)))) + } + + default: + // Unknown content block type, ignore + break + } + } + + let usage = Usage( + promptTokens: response.usage.inputTokens, + completionTokens: response.usage.outputTokens, + totalTokens: response.usage.inputTokens + response.usage.outputTokens, + promptTokensDetails: nil, + completionTokensDetails: nil) + + return ModelResponse( + id: response.id, + model: response.model, + content: content, + usage: usage, + flagged: false, + finishReason: self.convertStopReason(response.stopReason)) + } + + private func convertStopReason(_ reason: String?) -> FinishReason? { + guard let reason else { return nil } + + switch reason { + case "end_turn": + return .stop + case "max_tokens": + return .length + case "stop_sequence": + return .stop + case "tool_use": + return .toolCalls + default: + return .stop + } + } + + private func handleErrorResponse(data: Data, response: HTTPURLResponse) throws { + if let errorResponse = try? JSONDecoder().decode(AnthropicErrorResponse.self, from: data) { + let message = errorResponse.error.message + + switch response.statusCode { + case 401: + throw TachikomaError.authenticationFailed + case 429: + throw TachikomaError.rateLimited + case 400: + if message.contains("credit") || message.contains("usage") { + throw TachikomaError.insufficientQuota + } else { + throw TachikomaError.invalidRequest(message) + } + case 500...599: + throw TachikomaError.modelOverloaded + default: + throw TachikomaError.apiError(message: message) + } + } else { + throw TachikomaError.apiError(message: "HTTP \(response.statusCode)") + } + } +} + +// MARK: - Helper Types + +private class PartialToolCall { + let id: String + let name: String + let index: Int + var arguments: String = "" + + init(id: String, name: String, index: Int) { + self.id = id + self.name = name + self.index = index + } + + func appendArguments(_ args: String) { + self.arguments += args + } + + func toCompleted() -> FunctionCall? { + FunctionCall(name: self.name, arguments: self.arguments) + } +} \ No newline at end of file diff --git a/Sources/Providers/Anthropic/AnthropicTypes.swift b/Sources/Providers/Anthropic/AnthropicTypes.swift new file mode 100644 index 0000000..1e91af9 --- /dev/null +++ b/Sources/Providers/Anthropic/AnthropicTypes.swift @@ -0,0 +1,678 @@ +import Foundation + +// MARK: - Anthropic API Request Types + +/// Cache control configuration +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicCacheControl: Codable, Sendable { + public let type: String // "ephemeral" + + public init(type: String = "ephemeral") { + self.type = type + } +} + +/// System content that can be cached +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum AnthropicSystemContent: Codable, Sendable { + case string(String) + case array([AnthropicSystemBlock]) + + // Custom encoding/decoding + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + + if let stringValue = try? container.decode(String.self) { + self = .string(stringValue) + } else if let arrayValue = try? container.decode([AnthropicSystemBlock].self) { + self = .array(arrayValue) + } else { + throw DecodingError.typeMismatch( + AnthropicSystemContent.self, + DecodingError.Context( + codingPath: decoder.codingPath, + debugDescription: "Expected String or Array of system blocks")) + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case let .string(value): + try container.encode(value) + case let .array(blocks): + try container.encode(blocks) + } + } +} + +/// System block that can contain cache control +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicSystemBlock: Codable, Sendable { + public let type: String // "text" + public let text: String + public let cacheControl: AnthropicCacheControl? + + enum CodingKeys: String, CodingKey { + case type, text + case cacheControl = "cache_control" + } + + public init(type: String = "text", text: String, cacheControl: AnthropicCacheControl? = nil) { + self.type = type + self.text = text + self.cacheControl = cacheControl + } +} + +/// Main request structure for Anthropic's Messages API +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicRequest: Codable, Sendable { + /// ID of the model to use (e.g., "claude-3-opus-20240229") + public let model: String + + /// Input messages + public let messages: [AnthropicMessage] + + /// System prompt (separate from messages) + public let system: AnthropicSystemContent? + + /// Maximum number of tokens to generate + public let maxTokens: Int + + /// Temperature for randomness (0.0 to 1.0) + public let temperature: Double? + + /// Top-p sampling parameter + public let topP: Double? + + /// Top-k sampling parameter + public let topK: Int? + + /// Whether to stream the response + public let stream: Bool? + + /// Stop sequences + public let stopSequences: [String]? + + /// Available tools + public let tools: [AnthropicTool]? + + /// Tool choice configuration + public let toolChoice: AnthropicToolChoice? + + /// Metadata about the request + public let metadata: AnthropicMetadata? + + enum CodingKeys: String, CodingKey { + case model, messages, system + case maxTokens = "max_tokens" + case temperature + case topP = "top_p" + case topK = "top_k" + case stream + case stopSequences = "stop_sequences" + case tools + case toolChoice = "tool_choice" + case metadata + } +} + +/// Anthropic message structure +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicMessage: Codable, Sendable { + /// Role of the message sender + public let role: AnthropicRole + + /// Content of the message + public let content: AnthropicContent + + public init(role: AnthropicRole, content: AnthropicContent) { + self.role = role + self.content = content + } +} + +/// Message roles in Anthropic API +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum AnthropicRole: String, Codable, Sendable { + case user + case assistant +} + +/// Content types for Anthropic messages +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum AnthropicContent: Codable, Sendable { + case string(String) + case array([AnthropicContentBlock]) + + // Custom encoding/decoding + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + + if let stringValue = try? container.decode(String.self) { + self = .string(stringValue) + } else if let arrayValue = try? container.decode([AnthropicContentBlock].self) { + self = .array(arrayValue) + } else { + throw DecodingError.typeMismatch( + AnthropicContent.self, + DecodingError.Context( + codingPath: decoder.codingPath, + debugDescription: "Expected String or Array of content blocks")) + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case let .string(value): + try container.encode(value) + case let .array(blocks): + try container.encode(blocks) + } + } +} + +/// Content block for multimodal messages +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicContentBlock: Codable, Sendable { + public let type: String + + // Text content + public let text: String? + + // Image content + public let source: AnthropicImageSource? + + // Tool use content + public let id: String? + public let name: String? + public let input: [String: AnthropicInputValue]? + + // Tool result content + public let toolUseId: String? + public let content: AnthropicContent? + public let isError: Bool? + + // Cache control + public let cacheControl: AnthropicCacheControl? + + enum CodingKeys: String, CodingKey { + case type, text, source, id, name, input + case toolUseId = "tool_use_id" + case content + case isError = "is_error" + case cacheControl = "cache_control" + } +} + +/// Image source for content blocks +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicImageSource: Codable, Sendable { + public let type: String // "base64" + public let mediaType: String // "image/jpeg", "image/png", etc. + public let data: String // base64 encoded image data + + enum CodingKeys: String, CodingKey { + case type + case mediaType = "media_type" + case data + } +} + +// MARK: - Tool Definitions + +/// Tool definition for Anthropic +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicTool: Codable, Sendable { + public let name: String + public let description: String + public let inputSchema: AnthropicJSONSchema + public let cacheControl: AnthropicCacheControl? + + enum CodingKeys: String, CodingKey { + case name, description + case inputSchema = "input_schema" + case cacheControl = "cache_control" + } + + public init( + name: String, + description: String, + inputSchema: AnthropicJSONSchema, + cacheControl: AnthropicCacheControl? = nil) + { + self.name = name + self.description = description + self.inputSchema = inputSchema + self.cacheControl = cacheControl + } +} + +/// JSON Schema for tool parameters +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicJSONSchema: Codable, Sendable { + public let type: String + public let properties: [String: AnthropicPropertySchema]? + public let required: [String]? + public let description: String? + + public init( + type: String = "object", + properties: [String: AnthropicPropertySchema]? = nil, + required: [String]? = nil, + description: String? = nil) + { + self.type = type + self.properties = properties + self.required = required + self.description = description + } +} + +/// Tool choice configuration +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum AnthropicToolChoice: Codable, Sendable { + case auto + case any + case tool(name: String) + + // Custom encoding/decoding + enum CodingKeys: String, CodingKey { + case type, name + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(String.self, forKey: .type) + + switch type { + case "auto": + self = .auto + case "any": + self = .any + case "tool": + let name = try container.decode(String.self, forKey: .name) + self = .tool(name: name) + default: + throw DecodingError.dataCorruptedError( + forKey: .type, + in: container, + debugDescription: "Unknown tool choice type: \(type)") + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + + switch self { + case .auto: + try container.encode("auto", forKey: .type) + case .any: + try container.encode("any", forKey: .type) + case let .tool(name): + try container.encode("tool", forKey: .type) + try container.encode(name, forKey: .name) + } + } +} + +/// Metadata for requests +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicMetadata: Codable, Sendable { + public let userId: String? + + enum CodingKeys: String, CodingKey { + case userId = "user_id" + } +} + +// MARK: - Response Types + +/// Response from Anthropic's Messages API +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicResponse: Codable, Sendable { + public let id: String + public let type: String + public let role: AnthropicRole + public let content: [AnthropicContentBlock] + public let model: String + public let stopReason: String? + public let stopSequence: String? + public let usage: AnthropicUsage + + enum CodingKeys: String, CodingKey { + case id, type, role, content, model + case stopReason = "stop_reason" + case stopSequence = "stop_sequence" + case usage + } +} + +/// Usage statistics +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicUsage: Codable, Sendable { + public let inputTokens: Int + public let outputTokens: Int + + enum CodingKeys: String, CodingKey { + case inputTokens = "input_tokens" + case outputTokens = "output_tokens" + } +} + +// MARK: - Streaming Types + +/// Server-sent event for streaming +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicStreamEvent: Codable, Sendable { + public let type: String + + // Common fields + public let index: Int? + public let delta: AnthropicDelta? + + // Message start + public let message: AnthropicStreamMessage? + + // Content block start + public let contentBlock: AnthropicContentBlock? + + // Message complete + public let usage: AnthropicUsage? + public let stopReason: String? + public let stopSequence: String? + + // Error + public let error: AnthropicError? + + enum CodingKeys: String, CodingKey { + case type, index, delta, message + case contentBlock = "content_block" + case usage + case stopReason = "stop_reason" + case stopSequence = "stop_sequence" + case error + } +} + +/// Stream message metadata +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicStreamMessage: Codable, Sendable { + public let id: String + public let type: String + public let role: AnthropicRole + public let content: [AnthropicContentBlock] + public let model: String + public let usage: AnthropicUsage + + enum CodingKeys: String, CodingKey { + case id, type, role, content, model, usage + } +} + +/// Delta updates for streaming +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicDelta: Codable, Sendable { + // Text delta + public let text: String? + + // Message delta fields + public let stopReason: String? + public let stopSequence: String? + + // Tool use delta + public let type: String? + public let partialJson: String? + + enum CodingKeys: String, CodingKey { + case text + case stopReason = "stop_reason" + case stopSequence = "stop_sequence" + case type + case partialJson = "partial_json" + } +} + +// MARK: - Error Types + +/// Anthropic error response +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicErrorResponse: Codable, Sendable { + public let error: AnthropicError + + public var message: String { + self.error.message + } + + public var code: String? { + nil // Anthropic doesn't provide error codes + } + + public var type: String? { + self.error.type + } +} + +/// Error details +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicError: Codable, Sendable { + public let type: String + public let message: String +} + +// MARK: - Helper Extensions + +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +extension AnthropicContentBlock { + /// Create a text content block + public static func text(_ text: String, cacheControl: AnthropicCacheControl? = nil) -> AnthropicContentBlock { + AnthropicContentBlock( + type: "text", + text: text, + source: nil, + id: nil, + name: nil, + input: nil, + toolUseId: nil, + content: nil, + isError: nil, + cacheControl: cacheControl) + } + + /// Create an image content block + public static func image( + base64: String, + mediaType: String, + cacheControl: AnthropicCacheControl? = nil) -> AnthropicContentBlock + { + AnthropicContentBlock( + type: "image", + text: nil, + source: AnthropicImageSource( + type: "base64", + mediaType: mediaType, + data: base64), + id: nil, + name: nil, + input: nil, + toolUseId: nil, + content: nil, + isError: nil, + cacheControl: cacheControl) + } + + /// Create a tool use content block + public static func toolUse(id: String, name: String, input: [String: Any]) -> AnthropicContentBlock { + AnthropicContentBlock( + type: "tool_use", + text: nil, + source: nil, + id: id, + name: name, + input: input.compactMapValues { AnthropicInputValue(from: $0) }, + toolUseId: nil, + content: nil, + isError: nil, + cacheControl: nil) + } + + /// Create a tool result content block + public static func toolResult(toolUseId: String, content: String, isError: Bool = false) -> AnthropicContentBlock { + AnthropicContentBlock( + type: "tool_result", + text: nil, + source: nil, + id: nil, + name: nil, + input: nil, + toolUseId: toolUseId, + content: .string(content), + isError: isError, + cacheControl: nil) + } +} + +// MARK: - Input Value Types + +/// Type-safe input value for Anthropic tool use +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public enum AnthropicInputValue: Codable, Sendable { + case string(String) + case int(Int) + case double(Double) + case bool(Bool) + case array([AnthropicInputValue]) + case object([String: AnthropicInputValue]) + case null + + // MARK: - Codable + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + + if container.decodeNil() { + self = .null + } else if let string = try? container.decode(String.self) { + self = .string(string) + } else if let int = try? container.decode(Int.self) { + self = .int(int) + } else if let double = try? container.decode(Double.self) { + self = .double(double) + } else if let bool = try? container.decode(Bool.self) { + self = .bool(bool) + } else if let array = try? container.decode([AnthropicInputValue].self) { + self = .array(array) + } else if let dict = try? container.decode([String: AnthropicInputValue].self) { + self = .object(dict) + } else { + throw DecodingError.dataCorruptedError(in: container, debugDescription: "Cannot decode AnthropicInputValue") + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + + switch self { + case let .string(value): + try container.encode(value) + case let .int(value): + try container.encode(value) + case let .double(value): + try container.encode(value) + case let .bool(value): + try container.encode(value) + case let .array(values): + try container.encode(values) + case let .object(dict): + try container.encode(dict) + case .null: + try container.encodeNil() + } + } + + // MARK: - Conversion + + /// Convert from Any value (for migration) + public init?(from value: Any) { + switch value { + case let string as String: + self = .string(string) + case let int as Int: + self = .int(int) + case let double as Double: + self = .double(double) + case let bool as Bool: + self = .bool(bool) + case let array as [Any]: + let values = array.compactMap { AnthropicInputValue(from: $0) } + if values.count == array.count { + self = .array(values) + } else { + return nil + } + case let dict as [String: Any]: + var values: [String: AnthropicInputValue] = [:] + for (key, val) in dict { + if let inputValue = AnthropicInputValue(from: val) { + values[key] = inputValue + } else { + return nil + } + } + self = .object(values) + case is NSNull: + self = .null + default: + return nil + } + } + + /// Convert to Any (for JSON serialization) + public func toAny() -> Any { + switch self { + case let .string(value): + value + case let .int(value): + value + case let .double(value): + value + case let .bool(value): + value + case let .array(values): + values.map { $0.toAny() } + case let .object(dict): + dict.mapValues { $0.toAny() } + case .null: + NSNull() + } + } +} + +/// Type-safe property schema for Anthropic JSON Schema +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnthropicPropertySchema: Codable, Sendable { + public let type: String + public let description: String? + public let `enum`: [String]? + public let items: Box? + public let properties: [String: AnthropicPropertySchema]? + public let required: [String]? + + public init( + type: String, + description: String? = nil, + enum enumValues: [String]? = nil, + items: AnthropicPropertySchema? = nil, + properties: [String: AnthropicPropertySchema]? = nil, + required: [String]? = nil) + { + self.type = type + self.description = description + self.enum = enumValues + self.items = items.map(Box.init) + self.properties = properties + self.required = required + } +} \ No newline at end of file diff --git a/Sources/Providers/Grok/GrokModel.swift b/Sources/Providers/Grok/GrokModel.swift new file mode 100644 index 0000000..d397ece --- /dev/null +++ b/Sources/Providers/Grok/GrokModel.swift @@ -0,0 +1,479 @@ +import Foundation + +/// Grok model implementation using OpenAI-compatible Chat Completions API +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public final class GrokModel: ModelInterface, Sendable { + private let apiKey: String + private let modelName: String + private let baseURL: URL + private let session: URLSession + + public init( + apiKey: String, + modelName: String = "grok-4-0709", + baseURL: URL = URL(string: "https://api.x.ai/v1")!, + session: URLSession? = nil) + { + self.apiKey = apiKey + self.modelName = modelName + self.baseURL = baseURL + + // Create custom session with appropriate timeout + if let session { + self.session = session + } else { + let config = URLSessionConfiguration.default + config.timeoutIntervalForRequest = 300 // 5 minutes + config.timeoutIntervalForResource = 300 + self.session = URLSession(configuration: config) + } + } + + // MARK: - ModelInterface Implementation + + public var maskedApiKey: String { + guard self.apiKey.count > 8 else { return "***" } + let start = self.apiKey.prefix(6) + let end = self.apiKey.suffix(2) + return "\(start)...\(end)" + } + + public func getResponse(request: ModelRequest) async throws -> ModelResponse { + let grokRequest = try convertToGrokRequest(request, stream: false) + let urlRequest = try createURLRequest(endpoint: "chat/completions", body: grokRequest) + + let (data, response) = try await session.data(for: urlRequest) + + guard let httpResponse = response as? HTTPURLResponse else { + throw TachikomaError.networkError(underlying: URLError(.badServerResponse)) + } + + if httpResponse.statusCode != 200 { + try handleErrorResponse(data: data, response: httpResponse) + } + + let chatResponse = try JSONDecoder().decode(GrokChatCompletionResponse.self, from: data) + return try self.convertFromGrokResponse(chatResponse) + } + + public func getStreamedResponse(request: ModelRequest) async throws -> AsyncThrowingStream { + let grokRequest = try convertToGrokRequest(request, stream: true) + let urlRequest = try createURLRequest(endpoint: "chat/completions", body: grokRequest) + + return AsyncThrowingStream { continuation in + Task { + do { + let (bytes, response) = try await session.bytes(for: urlRequest) + + guard let httpResponse = response as? HTTPURLResponse else { + continuation.finish(throwing: TachikomaError.networkError(underlying: URLError(.badServerResponse))) + return + } + + if httpResponse.statusCode != 200 { + var errorData = Data() + for try await byte in bytes.prefix(1024) { + errorData.append(byte) + } + try self.handleErrorResponse(data: errorData, response: httpResponse) + } + + // Process SSE stream + var currentToolCalls: [String: GrokPartialToolCall] = [:] + + for try await line in bytes.lines { + // Handle SSE format + if line.hasPrefix("data: ") { + let data = String(line.dropFirst(6)) + + if data == "[DONE]" { + // Send any pending tool calls + for (id, toolCall) in currentToolCalls { + if let completed = toolCall.toCompleted() { + continuation.yield(.toolCallCompleted( + StreamToolCallCompleted(id: id, function: completed))) + } + } + continuation.finish() + return + } + + // Parse chunk + if let chunkData = data.data(using: .utf8), + let chunk = try? JSONDecoder().decode( + GrokChatCompletionChunk.self, + from: chunkData) + { + if let events = self.processGrokChunk(chunk, toolCalls: ¤tToolCalls) { + for event in events { + continuation.yield(event) + } + } + } + } + } + + continuation.finish() + } catch { + continuation.finish(throwing: TachikomaError.streamingError(error.localizedDescription)) + } + } + } + } + + // MARK: - Private Helper Methods + + private func createURLRequest(endpoint: String, body: any Encodable) throws -> URLRequest { + let url = self.baseURL.appendingPathComponent(endpoint) + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue("Bearer \(self.apiKey)", forHTTPHeaderField: "Authorization") + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + + let encoder = JSONEncoder() + encoder.outputFormatting = .sortedKeys + + do { + request.httpBody = try encoder.encode(body) + } catch { + throw TachikomaError.configurationError("Failed to encode Grok request: \(error.localizedDescription)") + } + + request.timeoutInterval = 300 // 5 minutes for Grok + + return request + } + + private func convertToGrokRequest(_ request: ModelRequest, stream: Bool) throws -> GrokChatCompletionRequest { + // Convert messages to OpenAI-compatible format + let messages = try request.messages.map { message -> GrokMessage in + switch message { + case let .system(_, content): + return GrokMessage(role: "system", content: .string(content), toolCalls: nil, toolCallId: nil) + + case let .user(_, content): + return try self.convertUserMessageContent(content) + + case let .assistant(_, content, _): + return try self.convertAssistantMessageContent(content) + + case let .tool(_, toolCallId, content): + return GrokMessage( + role: "tool", + content: .string(content), + toolCalls: nil, + toolCallId: toolCallId) + + case .reasoning: + throw TachikomaError.invalidRequest("Reasoning messages not supported in Grok") + } + } + + // Convert tools to OpenAI-compatible format if present + let tools = request.tools?.map { toolDef -> GrokTool in + GrokTool( + type: "function", + function: GrokTool.Function( + name: toolDef.function.name, + description: toolDef.function.description, + parameters: self.convertToolParameters(toolDef.function.parameters))) + } + + // Filter parameters for Grok 4 + let temperature = request.settings.temperature + var frequencyPenalty = request.settings.frequencyPenalty + var presencePenalty = request.settings.presencePenalty + var stop = request.settings.stopSequences + + if self.modelName.contains("grok-4") || self.modelName.contains("grok-3") { + // Grok 3 and 4 models don't support these parameters + frequencyPenalty = nil + presencePenalty = nil + stop = nil + } + + return GrokChatCompletionRequest( + model: self.modelName, + messages: messages, + tools: tools, + toolChoice: self.convertToolChoice(request.settings.toolChoice), + temperature: temperature, + maxTokens: request.settings.maxTokens, + stream: stream, + frequencyPenalty: frequencyPenalty, + presencePenalty: presencePenalty, + stop: stop) + } + + private func convertUserMessageContent(_ content: MessageContent) throws -> GrokMessage { + switch content { + case let .text(text): + return GrokMessage(role: "user", content: .string(text), toolCalls: nil, toolCallId: nil) + + case let .image(imageContent): + var content: [GrokMessageContentPart] = [] + + if let url = imageContent.url { + content.append(GrokMessageContentPart( + type: "image_url", + text: nil, + imageUrl: GrokImageUrl( + url: url, + detail: imageContent.detail?.rawValue))) + } else if let base64 = imageContent.base64 { + content.append(GrokMessageContentPart( + type: "image_url", + text: nil, + imageUrl: GrokImageUrl( + url: "data:image/jpeg;base64,\(base64)", + detail: imageContent.detail?.rawValue))) + } + + return GrokMessage(role: "user", content: .array(content), toolCalls: nil, toolCallId: nil) + + case let .multimodal(parts): + let content = parts.compactMap { part -> GrokMessageContentPart? in + if let text = part.text { + return GrokMessageContentPart( + type: "text", + text: text, + imageUrl: nil) + } else if let image = part.imageUrl { + if let url = image.url { + return GrokMessageContentPart( + type: "image_url", + text: nil, + imageUrl: GrokImageUrl(url: url, detail: image.detail?.rawValue)) + } else if let base64 = image.base64 { + return GrokMessageContentPart( + type: "image_url", + text: nil, + imageUrl: GrokImageUrl( + url: "data:image/jpeg;base64,\(base64)", + detail: image.detail?.rawValue)) + } + } + return nil + } + return GrokMessage(role: "user", content: .array(content), toolCalls: nil, toolCallId: nil) + + case .file: + throw TachikomaError.invalidRequest("File content not supported in Grok chat completions") + + case let .audio(audioContent): + // Grok doesn't support native audio, so we need to use the transcript + if let transcript = audioContent.transcript { + // Include metadata about the audio source + var text = transcript + if let duration = audioContent.duration { + text = "[Audio transcript, duration: \(Int(duration))s] \(transcript)" + } else { + text = "[Audio transcript] \(transcript)" + } + return GrokMessage(role: "user", content: .string(text), toolCalls: nil, toolCallId: nil) + } else { + throw TachikomaError.invalidRequest("Audio content must be transcribed before sending to Grok. Please ensure transcript is provided.") + } + } + } + + private func convertAssistantMessageContent(_ contentArray: [AssistantContent]) throws -> GrokMessage { + var textContent = "" + var toolCalls: [GrokToolCall] = [] + + for content in contentArray { + switch content { + case let .outputText(text): + textContent += text + + case let .refusal(refusal): + return GrokMessage(role: "assistant", content: .string(refusal), toolCalls: nil, toolCallId: nil) + + case let .toolCall(toolCall): + toolCalls.append(GrokToolCall( + id: toolCall.id, + type: toolCall.type.rawValue, + function: GrokFunctionCall( + name: toolCall.function.name, + arguments: toolCall.function.arguments))) + } + } + + // Include tool calls if present + if !toolCalls.isEmpty { + return GrokMessage( + role: "assistant", + content: textContent.isEmpty ? nil : .string(textContent), + toolCalls: toolCalls, + toolCallId: nil) + } + + return GrokMessage(role: "assistant", content: .string(textContent), toolCalls: nil, toolCallId: nil) + } + + private func convertToolParameters(_ params: ToolParameters) -> GrokTool.Parameters { + let (type, properties, required) = params.toGrokParameters() + return GrokTool.Parameters( + type: type, + properties: properties, + required: required) + } + + private func convertToolChoice(_ toolChoice: ToolChoice?) -> GrokToolChoice? { + guard let toolChoice else { return nil } + + switch toolChoice { + case .auto: + return .string("auto") + case .none: + return .string("none") + case .required: + return .string("required") + case let .specific(toolName): + return .object(GrokToolChoiceObject( + type: "function", + function: GrokToolChoiceFunction(name: toolName))) + } + } + + private func convertFromGrokResponse(_ response: GrokChatCompletionResponse) throws -> ModelResponse { + guard let choice = response.choices.first else { + throw TachikomaError.apiError(message: "No choices in response") + } + + var content: [AssistantContent] = [] + + // Add text content if present + if let textContent = choice.message.content { + content.append(.outputText(textContent)) + } + + // Add tool calls if present + if let toolCalls = choice.message.toolCalls { + for toolCall in toolCalls { + content.append(.toolCall(ToolCallItem( + id: toolCall.id, + type: .function, + function: FunctionCall( + name: toolCall.function.name, + arguments: toolCall.function.arguments)))) + } + } + + let usage = response.usage.map { usage in + Usage( + promptTokens: usage.promptTokens, + completionTokens: usage.completionTokens, + totalTokens: usage.totalTokens, + promptTokensDetails: nil, + completionTokensDetails: nil) + } + + return ModelResponse( + id: response.id, + model: response.model, + content: content, + usage: usage, + flagged: false, + finishReason: self.convertFinishReason(choice.finishReason)) + } + + private func convertFinishReason(_ reason: String?) -> FinishReason? { + guard let reason else { return nil } + return FinishReason(rawValue: reason) + } + + private func processGrokChunk( + _ chunk: GrokChatCompletionChunk, + toolCalls: inout [String: GrokPartialToolCall]) -> [StreamEvent]? + { + var events: [StreamEvent] = [] + + // First chunk often contains metadata + if !chunk.id.isEmpty, chunk.model.isEmpty == false { + events.append(.responseStarted(StreamResponseStarted( + id: chunk.id, + model: chunk.model, + systemFingerprint: chunk.systemFingerprint))) + } + + for choice in chunk.choices { + let delta = choice.delta + + // Handle text content + if let content = delta.content, !content.isEmpty { + events.append(.textDelta(StreamTextDelta(delta: content, index: choice.index))) + } + + // Handle tool calls + if let deltaToolCalls = delta.toolCalls { + for toolCallDelta in deltaToolCalls { + let toolCallId = toolCallDelta.id ?? "" + + if toolCalls[toolCallId] == nil { + let partialCall = GrokPartialToolCall(from: toolCallDelta) + toolCalls[toolCallId] = partialCall + } else { + toolCalls[toolCallId]?.update(with: toolCallDelta) + } + + // Emit delta event + if let functionDelta = toolCallDelta.function { + events.append(.toolCallDelta(StreamToolCallDelta( + id: toolCallId, + index: toolCallDelta.index, + function: FunctionCallDelta( + name: functionDelta.name, + arguments: functionDelta.arguments)))) + } + } + } + + // Handle finish reason + if let finishReason = choice.finishReason { + // If this is a tool call finish, emit completed events + if finishReason == "tool_calls" { + for (id, toolCall) in toolCalls { + if let completed = toolCall.toCompleted() { + events.append(.toolCallCompleted( + StreamToolCallCompleted(id: id, function: completed))) + } + } + } + + events.append(.responseCompleted(StreamResponseCompleted( + id: chunk.id, + usage: nil, + finishReason: FinishReason(rawValue: finishReason)))) + } + } + + return events.isEmpty ? nil : events + } + + private func handleErrorResponse(data: Data, response: HTTPURLResponse) throws { + if let errorResponse = try? JSONDecoder().decode(GrokErrorResponse.self, from: data) { + let message = errorResponse.error.message + + switch response.statusCode { + case 401: + throw TachikomaError.authenticationFailed + case 429: + throw TachikomaError.rateLimited + case 400: + if message.contains("credit") || message.contains("usage") { + throw TachikomaError.insufficientQuota + } else { + throw TachikomaError.invalidRequest(message) + } + case 500...599: + throw TachikomaError.modelOverloaded + default: + throw TachikomaError.apiError(message: message, code: errorResponse.error.code) + } + } else { + throw TachikomaError.apiError(message: "HTTP \(response.statusCode)") + } + } +} + diff --git a/Sources/Providers/Grok/GrokTypes.swift b/Sources/Providers/Grok/GrokTypes.swift new file mode 100644 index 0000000..cbee943 --- /dev/null +++ b/Sources/Providers/Grok/GrokTypes.swift @@ -0,0 +1,346 @@ +import Foundation + +// MARK: - Helper Types + +internal class GrokPartialToolCall { + var id: String = "" + var type: String = "function" + var index: Int = 0 + var name: String? + var arguments: String = "" + + init() { + // Default initializer + } + + init(from delta: GrokToolCallDelta) { + self.id = delta.id ?? "" + self.index = delta.index + self.name = delta.function?.name + self.arguments = delta.function?.arguments ?? "" + } + + func update(with delta: GrokToolCallDelta) { + if let funcName = delta.function?.name { + self.name = funcName + } + if let args = delta.function?.arguments { + self.arguments += args + } + } + + func toCompleted() -> FunctionCall? { + guard let name else { return nil } + return FunctionCall(name: name, arguments: self.arguments) + } +} + +// MARK: - Grok Request Types + +internal struct GrokChatCompletionRequest: Encodable { + let model: String + let messages: [GrokMessage] + let tools: [GrokTool]? + let toolChoice: GrokToolChoice? + let temperature: Double? + let maxTokens: Int? + let stream: Bool + let frequencyPenalty: Double? + let presencePenalty: Double? + let stop: [String]? + + enum CodingKeys: String, CodingKey { + case model, messages, tools, temperature, stream + case toolChoice = "tool_choice" + case maxTokens = "max_tokens" + case frequencyPenalty = "frequency_penalty" + case presencePenalty = "presence_penalty" + case stop + } +} + +internal struct GrokMessage: Encodable { + let role: String + let content: GrokMessageContent? + let toolCalls: [GrokToolCall]? + let toolCallId: String? + + enum CodingKeys: String, CodingKey { + case role, content + case toolCalls = "tool_calls" + case toolCallId = "tool_call_id" + } +} + +internal enum GrokMessageContent: Encodable { + case string(String) + case array([GrokMessageContentPart]) + + func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case let .string(text): + try container.encode(text) + case let .array(parts): + try container.encode(parts) + } + } +} + +internal struct GrokMessageContentPart: Encodable { + let type: String + let text: String? + let imageUrl: GrokImageUrl? + + enum CodingKeys: String, CodingKey { + case type, text + case imageUrl = "image_url" + } +} + +internal struct GrokImageUrl: Encodable { + let url: String + let detail: String? +} + +internal struct GrokToolCall: Encodable { + let id: String + let type: String + let function: GrokFunctionCall +} + +internal struct GrokFunctionCall: Encodable { + let name: String + let arguments: String +} + +internal struct GrokTool: Encodable { + let type: String + let function: Function + + struct Function: Encodable { + let name: String + let description: String? + let parameters: Parameters + } + + struct Parameters: Encodable { + let type: String + let properties: [String: GrokPropertySchema] + let required: [String] + } +} + +internal enum GrokToolChoice: Encodable { + case string(String) + case object(GrokToolChoiceObject) + + func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case let .string(value): + try container.encode(value) + case let .object(obj): + try container.encode(obj) + } + } +} + +internal struct GrokToolChoiceObject: Encodable { + let type: String + let function: GrokToolChoiceFunction +} + +internal struct GrokToolChoiceFunction: Encodable { + let name: String +} + +// MARK: - Response Types + +internal struct GrokChatCompletionResponse: Decodable { + let id: String + let model: String + let choices: [Choice] + let usage: Usage? + + struct Choice: Decodable { + let message: Message + let finishReason: String? + + enum CodingKeys: String, CodingKey { + case message + case finishReason = "finish_reason" + } + } + + struct Message: Decodable { + let role: String + let content: String? + let toolCalls: [ToolCall]? + + enum CodingKeys: String, CodingKey { + case role, content + case toolCalls = "tool_calls" + } + + struct ToolCall: Decodable { + let id: String + let type: String + let function: Function + + struct Function: Decodable { + let name: String + let arguments: String + } + } + } + + struct Usage: Decodable { + let promptTokens: Int + let completionTokens: Int + let totalTokens: Int + + enum CodingKeys: String, CodingKey { + case promptTokens = "prompt_tokens" + case completionTokens = "completion_tokens" + case totalTokens = "total_tokens" + } + } +} + +// MARK: - Streaming Types + +internal struct GrokChatCompletionChunk: Decodable { + let id: String + let model: String + let choices: [StreamChoice] + let systemFingerprint: String? + + enum CodingKeys: String, CodingKey { + case id, model, choices + case systemFingerprint = "system_fingerprint" + } + + struct StreamChoice: Decodable { + let index: Int + let delta: Delta + let finishReason: String? + + enum CodingKeys: String, CodingKey { + case index, delta + case finishReason = "finish_reason" + } + + struct Delta: Decodable { + let role: String? + let content: String? + let toolCalls: [GrokToolCallDelta]? + + enum CodingKeys: String, CodingKey { + case role, content + case toolCalls = "tool_calls" + } + } + } +} + +internal struct GrokToolCallDelta: Decodable { + let index: Int + let id: String? + let type: String? + let function: StreamFunction? + + struct StreamFunction: Decodable { + let name: String? + let arguments: String? + } +} + +// MARK: - Error Types + +internal struct GrokErrorResponse: Decodable { + let error: GrokError + + var message: String { + self.error.message + } + + var code: String? { + self.error.code + } + + var type: String? { + self.error.type + } +} + +internal struct GrokError: Decodable { + let message: String + let type: String + let code: String? +} + +// MARK: - Property Schema + +/// Type-safe property schema for Grok tool parameters +internal struct GrokPropertySchema: Codable, Sendable { + let type: String + let description: String? + let `enum`: [String]? + let items: Box? + let properties: [String: GrokPropertySchema]? + let minimum: Double? + let maximum: Double? + let pattern: String? + let required: [String]? + + init( + type: String, + description: String? = nil, + enum enumValues: [String]? = nil, + items: GrokPropertySchema? = nil, + properties: [String: GrokPropertySchema]? = nil, + minimum: Double? = nil, + maximum: Double? = nil, + pattern: String? = nil, + required: [String]? = nil) + { + self.type = type + self.description = description + self.enum = enumValues + self.items = items.map(Box.init) + self.properties = properties + self.minimum = minimum + self.maximum = maximum + self.pattern = pattern + self.required = required + } + + /// Create from a ParameterSchema + init(from schema: ParameterSchema) { + self.type = schema.type.rawValue + self.description = schema.description + self.enum = schema.enumValues + self.items = schema.items.map { Box(GrokPropertySchema(from: $0.value)) } + self.properties = schema.properties?.mapValues { GrokPropertySchema(from: $0) } + self.minimum = schema.minimum + self.maximum = schema.maximum + self.pattern = schema.pattern + self.required = nil + } +} + +// MARK: - Extensions + +/// Helper to convert ToolParameters to Grok-compatible structure +internal extension ToolParameters { + func toGrokParameters() -> (type: String, properties: [String: GrokPropertySchema], required: [String]) { + var grokProperties: [String: GrokPropertySchema] = [:] + + for (key, schema) in properties { + grokProperties[key] = GrokPropertySchema(from: schema) + } + + return (type: type, properties: grokProperties, required: required) + } +} \ No newline at end of file diff --git a/Sources/Providers/Ollama/OllamaModel.swift b/Sources/Providers/Ollama/OllamaModel.swift new file mode 100644 index 0000000..e58ffd2 --- /dev/null +++ b/Sources/Providers/Ollama/OllamaModel.swift @@ -0,0 +1,336 @@ +import Foundation + +/// Ollama model implementation conforming to ModelInterface +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public final class OllamaModel: ModelInterface, Sendable { + private let modelName: String + private let baseURL: URL + private let session: URLSession + + public init( + modelName: String, + baseURL: URL = URL(string: "http://localhost:11434")!, + session: URLSession? = nil) + { + self.modelName = modelName + self.baseURL = baseURL + + if let session { + self.session = session + } else { + let config = URLSessionConfiguration.default + config.timeoutIntervalForRequest = 300 // 5 minutes for local models + config.timeoutIntervalForResource = 300 + self.session = URLSession(configuration: config) + } + } + + // MARK: - ModelInterface Implementation + + public var maskedApiKey: String { + "local-ollama" + } + + public func getResponse(request: ModelRequest) async throws -> ModelResponse { + let ollamaRequest = try convertToOllamaRequest(request, stream: false) + let urlRequest = try createURLRequest(endpoint: "api/chat", body: ollamaRequest) + + let (data, response) = try await session.data(for: urlRequest) + + guard let httpResponse = response as? HTTPURLResponse else { + throw TachikomaError.networkError(underlying: URLError(.badServerResponse)) + } + + if httpResponse.statusCode != 200 { + try handleErrorResponse(data: data, response: httpResponse) + } + + let ollamaResponse = try JSONDecoder().decode(OllamaChatResponse.self, from: data) + return try self.convertFromOllamaResponse(ollamaResponse) + } + + public func getStreamedResponse(request: ModelRequest) async throws -> AsyncThrowingStream { + let ollamaRequest = try convertToOllamaRequest(request, stream: true) + let urlRequest = try createURLRequest(endpoint: "api/chat", body: ollamaRequest) + + return AsyncThrowingStream { continuation in + Task { + do { + let (bytes, response) = try await session.bytes(for: urlRequest) + + guard let httpResponse = response as? HTTPURLResponse else { + continuation.finish(throwing: TachikomaError.networkError(underlying: URLError(.badServerResponse))) + return + } + + if httpResponse.statusCode != 200 { + var errorData = Data() + for try await byte in bytes.prefix(1024) { + errorData.append(byte) + } + try self.handleErrorResponse(data: errorData, response: httpResponse) + } + + // Process JSON stream (one JSON object per line) + for try await line in bytes.lines { + if line.isEmpty { continue } + + if let data = line.data(using: .utf8), + let chunk = try? JSONDecoder().decode(OllamaChatChunk.self, from: data) { + + if let events = self.processOllamaChunk(chunk) { + for event in events { + continuation.yield(event) + } + } + + // Check if done + if chunk.done { + continuation.finish() + return + } + } + } + + continuation.finish() + } catch { + continuation.finish(throwing: TachikomaError.streamingError(error.localizedDescription)) + } + } + } + } + + // MARK: - Private Methods + + private func createURLRequest(endpoint: String, body: any Encodable) throws -> URLRequest { + let url = self.baseURL.appendingPathComponent(endpoint) + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + + let encoder = JSONEncoder() + encoder.outputFormatting = .sortedKeys + + do { + request.httpBody = try encoder.encode(body) + } catch { + throw TachikomaError.configurationError("Failed to encode Ollama request: \(error.localizedDescription)") + } + + request.timeoutInterval = 300 // 5 minutes for local models + + return request + } + + private func convertToOllamaRequest(_ request: ModelRequest, stream: Bool) throws -> OllamaChatRequest { + // Convert messages + let messages = try request.messages.compactMap { message -> OllamaMessage? in + switch message { + case let .system(_, content): + return OllamaMessage(role: "system", content: content, images: nil) + + case let .user(_, content): + return try convertUserMessage(content) + + case let .assistant(_, content, _): + return convertAssistantMessage(content) + + case let .tool(_, _, content): + // Ollama handles tool results differently - might need adaptation + return OllamaMessage(role: "user", content: "Tool result: \(content)", images: nil) + + case .reasoning: + // Skip reasoning messages for Ollama + return nil + } + } + + // Convert tools (if supported by the model) + let tools = request.tools?.map { toolDef -> OllamaTool in + OllamaTool( + type: "function", + function: OllamaFunction( + name: toolDef.function.name, + description: toolDef.function.description, + parameters: convertToolParameters(toolDef.function.parameters))) + } + + return OllamaChatRequest( + model: self.modelName, + messages: messages, + tools: tools, + stream: stream, + options: OllamaOptions( + temperature: request.settings.temperature, + topP: request.settings.topP, + stop: request.settings.stopSequences)) + } + + private func convertUserMessage(_ content: MessageContent) throws -> OllamaMessage { + switch content { + case let .text(text): + return OllamaMessage(role: "user", content: text, images: nil) + + case let .image(imageContent): + // Ollama supports base64 images + if let base64 = imageContent.base64 { + return OllamaMessage(role: "user", content: "", images: [base64]) + } else if imageContent.url != nil { + throw TachikomaError.invalidRequest("Image URLs not supported in Ollama - please provide base64 data") + } else { + throw TachikomaError.invalidRequest("No image data provided") + } + + case let .multimodal(parts): + var text = "" + var images: [String] = [] + + for part in parts { + if let partText = part.text { + text += partText + } else if let image = part.imageUrl { + if let base64 = image.base64 { + images.append(base64) + } else if image.url != nil { + throw TachikomaError.invalidRequest("Image URLs not supported in Ollama - please provide base64 data") + } + } + } + + return OllamaMessage(role: "user", content: text, images: images.isEmpty ? nil : images) + + case .file: + throw TachikomaError.invalidRequest("File content not supported in Ollama API") + + case let .audio(audioContent): + // Ollama doesn't support native audio, use transcript if available + if let transcript = audioContent.transcript { + var text = transcript + if let duration = audioContent.duration { + text = "[Audio transcript, duration: \(Int(duration))s] \(transcript)" + } else { + text = "[Audio transcript] \(transcript)" + } + return OllamaMessage(role: "user", content: text, images: nil) + } else { + throw TachikomaError.invalidRequest("Audio content must be transcribed before sending to Ollama") + } + } + } + + private func convertAssistantMessage(_ content: [AssistantContent]) -> OllamaMessage { + var text = "" + + for item in content { + switch item { + case let .outputText(outputText): + text += outputText + case let .refusal(refusal): + text += refusal + case let .toolCall(toolCall): + // Convert tool call to text representation for now + text += "\n[Tool Call: \(toolCall.function.name)(\(toolCall.function.arguments))]" + } + } + + return OllamaMessage(role: "assistant", content: text, images: nil) + } + + private func convertToolParameters(_ params: ToolParameters) -> [String: Any] { + var properties: [String: Any] = [:] + + for (key, schema) in params.properties { + properties[key] = convertParameterSchema(schema) + } + + return [ + "type": params.type, + "properties": properties, + "required": params.required + ] + } + + private func convertParameterSchema(_ schema: ParameterSchema) -> [String: Any] { + var result: [String: Any] = [ + "type": schema.type.rawValue + ] + + if let description = schema.description { + result["description"] = description + } + + if let enumValues = schema.enumValues { + result["enum"] = enumValues + } + + return result + } + + private func convertFromOllamaResponse(_ response: OllamaChatResponse) throws -> ModelResponse { + let content: [AssistantContent] = [.outputText(response.message.content)] + + // Ollama doesn't provide detailed usage info in the same format + let usage = Usage( + promptTokens: 0, // Not provided by Ollama + completionTokens: 0, // Not provided by Ollama + totalTokens: 0, // Not provided by Ollama + promptTokensDetails: nil, + completionTokensDetails: nil) + + return ModelResponse( + id: UUID().uuidString, // Ollama doesn't provide response IDs + model: response.model, + content: content, + usage: usage, + flagged: false, + finishReason: response.done ? .stop : nil) + } + + private func processOllamaChunk(_ chunk: OllamaChatChunk) -> [StreamEvent]? { + var events: [StreamEvent] = [] + + // First chunk with model info + if !chunk.model.isEmpty && events.isEmpty { + events.append(.responseStarted(StreamResponseStarted( + id: UUID().uuidString, + model: chunk.model, + systemFingerprint: nil))) + } + + // Text content + if let content = chunk.message?.content, !content.isEmpty { + events.append(.textDelta(StreamTextDelta(delta: content, index: 0))) + } + + // Done + if chunk.done { + events.append(.responseCompleted(StreamResponseCompleted( + id: UUID().uuidString, + usage: nil, + finishReason: .stop))) + } + + return events.isEmpty ? nil : events + } + + private func handleErrorResponse(data: Data, response: HTTPURLResponse) throws { + // Try to decode Ollama error format + if let errorText = String(data: data, encoding: .utf8) { + let message = errorText.isEmpty ? "HTTP \(response.statusCode)" : errorText + + switch response.statusCode { + case 400: + throw TachikomaError.invalidRequest(message) + case 404: + throw TachikomaError.modelNotFound(self.modelName) + case 500...599: + throw TachikomaError.modelOverloaded + default: + throw TachikomaError.apiError(message: message) + } + } else { + throw TachikomaError.apiError(message: "HTTP \(response.statusCode)") + } + } +} + diff --git a/Sources/Providers/Ollama/OllamaTypes.swift b/Sources/Providers/Ollama/OllamaTypes.swift new file mode 100644 index 0000000..c64799f --- /dev/null +++ b/Sources/Providers/Ollama/OllamaTypes.swift @@ -0,0 +1,69 @@ +import Foundation + +// MARK: - Ollama Request Types + +internal struct OllamaChatRequest: Encodable { + let model: String + let messages: [OllamaMessage] + let tools: [OllamaTool]? + let stream: Bool + let options: OllamaOptions? +} + +internal struct OllamaMessage: Codable { + let role: String + let content: String + let images: [String]? +} + +internal struct OllamaTool: Encodable { + let type: String + let function: OllamaFunction +} + +internal struct OllamaFunction: Encodable { + let name: String + let description: String + let parameters: [String: Any] + + enum CodingKeys: String, CodingKey { + case name, description, parameters + } + + func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(name, forKey: .name) + try container.encode(description, forKey: .description) + + // Encode parameters as JSON + let data = try JSONSerialization.data(withJSONObject: parameters) + let jsonString = String(data: data, encoding: .utf8) ?? "{}" + try container.encode(jsonString, forKey: .parameters) + } +} + +internal struct OllamaOptions: Encodable { + let temperature: Double? + let topP: Double? + let stop: [String]? + + enum CodingKeys: String, CodingKey { + case temperature + case topP = "top_p" + case stop + } +} + +// MARK: - Ollama Response Types + +internal struct OllamaChatResponse: Decodable { + let model: String + let message: OllamaMessage + let done: Bool +} + +internal struct OllamaChatChunk: Decodable { + let model: String + let message: OllamaMessage? + let done: Bool +} \ No newline at end of file diff --git a/Sources/Providers/OpenAI/OpenAIModel.swift b/Sources/Providers/OpenAI/OpenAIModel.swift new file mode 100644 index 0000000..8abd840 --- /dev/null +++ b/Sources/Providers/OpenAI/OpenAIModel.swift @@ -0,0 +1,516 @@ +import Foundation + +/// OpenAI model implementation conforming to ModelInterface +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public final class OpenAIModel: ModelInterface, Sendable { + private let apiKey: String + private let baseURL: URL + private let session: URLSession + private let organizationId: String? + private let customHeaders: [String: String]? + private let customModelName: String? + + public init( + apiKey: String, + baseURL: URL = URL(string: "https://api.openai.com/v1")!, + organizationId: String? = nil, + modelName: String? = nil, + headers: [String: String]? = nil, + session: URLSession? = nil) + { + self.apiKey = apiKey + self.baseURL = baseURL + self.organizationId = organizationId + self.customHeaders = headers + self.customModelName = modelName + + if let session { + self.session = session + } else { + let config = URLSessionConfiguration.default + config.timeoutIntervalForRequest = 600 // 10 minutes for o3 models + config.timeoutIntervalForResource = 600 + self.session = URLSession(configuration: config) + } + } + + // MARK: - ModelInterface Implementation + + public var maskedApiKey: String { + guard self.apiKey.count > 8 else { return "***" } + let start = self.apiKey.prefix(6) + let end = self.apiKey.suffix(2) + return "\(start)...\(end)" + } + + public func getResponse(request: ModelRequest) async throws -> ModelResponse { + let openAIRequest = try convertToOpenAIRequest(request, stream: false) + let endpoint = getEndpointForModel(request.settings.modelName) + let urlRequest = try createURLRequest(endpoint: endpoint, body: openAIRequest) + + let (data, response) = try await session.data(for: urlRequest) + + guard let httpResponse = response as? HTTPURLResponse else { + throw TachikomaError.networkError(underlying: URLError(.badServerResponse)) + } + + if httpResponse.statusCode != 200 { + try handleErrorResponse(data: data, response: httpResponse) + } + + do { + let openAIResponse = try JSONDecoder().decode(OpenAIResponse.self, from: data) + return try convertFromOpenAIResponse(openAIResponse) + } catch { + throw TachikomaError.decodingError(underlying: error) + } + } + + public func getStreamedResponse(request: ModelRequest) async throws -> AsyncThrowingStream { + let openAIRequest = try convertToOpenAIRequest(request, stream: true) + let endpoint = getEndpointForModel(request.settings.modelName) + let urlRequest = try createURLRequest(endpoint: endpoint, body: openAIRequest) + + return AsyncThrowingStream { continuation in + Task { + do { + let (bytes, response) = try await session.bytes(for: urlRequest) + + guard let httpResponse = response as? HTTPURLResponse else { + continuation.finish(throwing: TachikomaError.networkError(underlying: URLError(.badServerResponse))) + return + } + + if httpResponse.statusCode != 200 { + var errorData = Data() + for try await byte in bytes.prefix(1024) { + errorData.append(byte) + } + try self.handleErrorResponse(data: errorData, response: httpResponse) + } + + // Process SSE stream + for try await line in bytes.lines { + if line.hasPrefix("data: ") { + let data = String(line.dropFirst(6)) + + if data == "[DONE]" { + continuation.finish() + return + } + + if let chunkData = data.data(using: .utf8), + let chunk = try? JSONDecoder().decode(OpenAIStreamChunk.self, from: chunkData) { + if let events = processStreamChunk(chunk) { + for event in events { + continuation.yield(event) + } + } + } + } + } + + continuation.finish() + } catch { + continuation.finish(throwing: TachikomaError.streamingError(error.localizedDescription)) + } + } + } + } + + // MARK: - Private Methods + + private func getEndpointForModel(_ modelName: String) -> String { + // Use Responses API for o3/o4 models, Chat Completions for others + if modelName.hasPrefix("o3") || modelName.hasPrefix("o4") { + return "responses" + } else { + return "chat/completions" + } + } + + private func createURLRequest(endpoint: String, body: any Encodable) throws -> URLRequest { + let url = self.baseURL.appendingPathComponent(endpoint) + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue("Bearer \(self.apiKey)", forHTTPHeaderField: "Authorization") + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + + if let orgId = organizationId { + request.setValue(orgId, forHTTPHeaderField: "OpenAI-Organization") + } + + customHeaders?.forEach { key, value in + request.setValue(value, forHTTPHeaderField: key) + } + + let encoder = JSONEncoder() + encoder.outputFormatting = .sortedKeys + + do { + request.httpBody = try encoder.encode(body) + } catch { + throw TachikomaError.configurationError("Failed to encode OpenAI request: \(error.localizedDescription)") + } + + // Set timeout for different API types + if endpoint == "responses" { + request.timeoutInterval = 600 // 10 minutes for o3 models + } else { + request.timeoutInterval = 120 // 2 minutes for other models + } + + return request + } + + private func convertToOpenAIRequest(_ request: ModelRequest, stream: Bool) throws -> any Encodable { + let modelName = customModelName ?? request.settings.modelName + + if modelName.hasPrefix("o3") || modelName.hasPrefix("o4") { + return try convertToResponsesRequest(request, stream: stream) + } else { + return try convertToChatRequest(request, stream: stream) + } + } + + private func convertToChatRequest(_ request: ModelRequest, stream: Bool) throws -> OpenAIChatRequest { + let messages = try request.messages.map { message -> OpenAIMessage in + switch message { + case let .system(_, content): + return OpenAIMessage(role: "system", content: .string(content)) + case let .user(_, content): + return try convertUserMessage(content) + case let .assistant(_, content, _): + return try convertAssistantMessage(content) + case let .tool(_, toolCallId, content): + return OpenAIMessage(role: "tool", content: .string(content), toolCallId: toolCallId) + case .reasoning: + throw TachikomaError.invalidRequest("Reasoning messages not supported in Chat Completions API") + } + } + + let tools = request.tools?.map { tool in + OpenAITool( + type: "function", + function: OpenAIFunction( + name: tool.function.name, + description: tool.function.description, + parameters: convertToolParameters(tool.function.parameters))) + } + + return OpenAIChatRequest( + model: customModelName ?? request.settings.modelName, + messages: messages, + tools: tools, + toolChoice: convertToolChoice(request.settings.toolChoice), + temperature: request.settings.temperature, + topP: request.settings.topP, + stream: stream, + maxTokens: request.settings.maxTokens) + } + + private func convertToResponsesRequest(_ request: ModelRequest, stream: Bool) throws -> OpenAIResponsesRequest { + let messages = try request.messages.compactMap { message -> OpenAIMessage? in + switch message { + case let .system(_, content): + return OpenAIMessage(role: "system", content: .string(content)) + case let .user(_, content): + return try convertUserMessage(content) + case let .assistant(_, content, _): + return try convertAssistantMessage(content) + case let .tool(_, _, content): + return OpenAIMessage(role: "user", content: .string(content)) + case .reasoning: + return nil // Skip reasoning messages + } + } + + let tools = request.tools?.map { tool in + OpenAIResponsesTool( + type: "function", + name: tool.function.name, + description: tool.function.description, + parameters: convertToolParameters(tool.function.parameters)) + } + + let modelName = customModelName ?? request.settings.modelName + + return OpenAIResponsesRequest( + model: modelName, + input: messages, + tools: tools, + toolChoice: convertToolChoice(request.settings.toolChoice), + temperature: nil, // o3/o4 models don't support temperature + topP: request.settings.topP, + stream: stream, + maxOutputTokens: request.settings.maxTokens ?? 65536, + reasoning: (modelName.hasPrefix("o3") || modelName.hasPrefix("o4")) ? + OpenAIReasoning( + effort: request.settings.additionalParameters?.string("reasoning_effort") ?? "medium", + summary: "detailed") : nil) + } + + private func convertUserMessage(_ content: MessageContent) throws -> OpenAIMessage { + switch content { + case let .text(text): + return OpenAIMessage(role: "user", content: .string(text)) + case let .image(imageContent): + var parts: [OpenAIMessageContentPart] = [] + + if let url = imageContent.url { + parts.append(OpenAIMessageContentPart( + type: "image_url", + text: nil, + imageUrl: OpenAIImageUrl(url: url, detail: imageContent.detail?.rawValue))) + } else if let base64 = imageContent.base64 { + parts.append(OpenAIMessageContentPart( + type: "image_url", + text: nil, + imageUrl: OpenAIImageUrl(url: "data:image/jpeg;base64,\(base64)", detail: imageContent.detail?.rawValue))) + } + + return OpenAIMessage(role: "user", content: .array(parts)) + case let .multimodal(parts): + let contentParts = parts.compactMap { part -> OpenAIMessageContentPart? in + if let text = part.text { + return OpenAIMessageContentPart(type: "text", text: text, imageUrl: nil) + } else if let image = part.imageUrl { + if let url = image.url { + return OpenAIMessageContentPart( + type: "image_url", + text: nil, + imageUrl: OpenAIImageUrl(url: url, detail: image.detail?.rawValue)) + } else if let base64 = image.base64 { + return OpenAIMessageContentPart( + type: "image_url", + text: nil, + imageUrl: OpenAIImageUrl(url: "data:image/jpeg;base64,\(base64)", detail: image.detail?.rawValue)) + } + } + return nil + } + return OpenAIMessage(role: "user", content: .array(contentParts)) + case .file: + throw TachikomaError.invalidRequest("File content not supported in OpenAI API") + case let .audio(audioContent): + if let transcript = audioContent.transcript { + var text = transcript + if let duration = audioContent.duration { + text = "[Audio transcript, duration: \(Int(duration))s] \(transcript)" + } else { + text = "[Audio transcript] \(transcript)" + } + return OpenAIMessage(role: "user", content: .string(text)) + } else { + throw TachikomaError.invalidRequest("Audio content must be transcribed before sending to OpenAI") + } + } + } + + private func convertAssistantMessage(_ content: [AssistantContent]) throws -> OpenAIMessage { + var textContent = "" + var toolCalls: [OpenAIToolCall] = [] + + for item in content { + switch item { + case let .outputText(text): + textContent += text + case let .refusal(refusal): + return OpenAIMessage(role: "assistant", content: .string(refusal)) + case let .toolCall(toolCall): + toolCalls.append(OpenAIToolCall( + id: toolCall.id, + type: "function", + function: OpenAIFunction( + name: toolCall.function.name, + description: nil, + parameters: nil, + arguments: toolCall.function.arguments))) + } + } + + if !toolCalls.isEmpty { + return OpenAIMessage(role: "assistant", content: .string(textContent), toolCalls: toolCalls) + } else { + return OpenAIMessage(role: "assistant", content: .string(textContent)) + } + } + + private func convertToolParameters(_ params: ToolParameters) -> [String: Any] { + var properties: [String: Any] = [:] + + for (key, schema) in params.properties { + properties[key] = convertParameterSchema(schema) + } + + return [ + "type": params.type, + "properties": properties, + "required": params.required, + "additionalProperties": params.additionalProperties + ] + } + + private func convertParameterSchema(_ schema: ParameterSchema) -> [String: Any] { + var result: [String: Any] = [ + "type": schema.type.rawValue + ] + + if let description = schema.description { + result["description"] = description + } + + if let enumValues = schema.enumValues { + result["enum"] = enumValues + } + + if let minimum = schema.minimum { + result["minimum"] = minimum + } + + if let maximum = schema.maximum { + result["maximum"] = maximum + } + + if let pattern = schema.pattern { + result["pattern"] = pattern + } + + if let items = schema.items { + result["items"] = convertParameterSchema(items.value) + } + + if let properties = schema.properties { + result["properties"] = properties.mapValues { convertParameterSchema($0) } + } + + return result + } + + private func convertToolChoice(_ toolChoice: ToolChoice?) -> String? { + guard let toolChoice else { return nil } + + switch toolChoice { + case .auto: + return "auto" + case .none: + return "none" + case .required: + return "required" + case let .specific(toolName): + return toolName + } + } + + private func convertFromOpenAIResponse(_ response: OpenAIResponse) throws -> ModelResponse { + guard let choice = response.choices.first else { + throw TachikomaError.apiError(message: "No choices in OpenAI response") + } + + var content: [AssistantContent] = [] + + if let messageContent = choice.message.content { + switch messageContent { + case let .string(text): + content.append(.outputText(text)) + case let .array(parts): + // Extract text from content parts + let text = parts.compactMap { $0.text }.joined() + if !text.isEmpty { + content.append(.outputText(text)) + } + } + } + + if let toolCalls = choice.message.toolCalls { + for toolCall in toolCalls { + content.append(.toolCall(ToolCallItem( + id: toolCall.id, + type: .function, + function: FunctionCall( + name: toolCall.function.name, + arguments: toolCall.function.arguments ?? "")))) + } + } + + let usage = response.usage.map { usage in + Usage( + promptTokens: usage.promptTokens, + completionTokens: usage.completionTokens, + totalTokens: usage.totalTokens, + promptTokensDetails: nil, + completionTokensDetails: nil) + } + + return ModelResponse( + id: response.id, + model: response.model, + content: content, + usage: usage, + flagged: false, + finishReason: convertFinishReason(choice.finishReason)) + } + + private func convertFinishReason(_ reason: String?) -> FinishReason? { + guard let reason else { return nil } + return FinishReason(rawValue: reason) + } + + private func processStreamChunk(_ chunk: OpenAIStreamChunk) -> [StreamEvent]? { + var events: [StreamEvent] = [] + + if let delta = chunk.choices?.first?.delta { + if let content = delta.content { + events.append(.textDelta(StreamTextDelta(delta: content, index: 0))) + } + + if let toolCalls = delta.toolCalls { + for toolCall in toolCalls { + if let id = toolCall.id, let function = toolCall.function { + events.append(.toolCallDelta(StreamToolCallDelta( + id: id, + index: toolCall.index ?? 0, + function: FunctionCallDelta( + name: function.name, + arguments: function.arguments)))) + } + } + } + } + + if let finishReason = chunk.choices?.first?.finishReason { + events.append(.responseCompleted(StreamResponseCompleted( + id: chunk.id ?? "", + usage: nil, + finishReason: convertFinishReason(finishReason)))) + } + + return events.isEmpty ? nil : events + } + + private func handleErrorResponse(data: Data, response: HTTPURLResponse) throws { + if let errorResponse = try? JSONDecoder().decode(OpenAIErrorResponse.self, from: data) { + let message = errorResponse.error.message + let code = errorResponse.error.code + + switch response.statusCode { + case 401: + throw TachikomaError.authenticationFailed + case 429: + throw TachikomaError.rateLimited + case 400: + if message.contains("context_length_exceeded") { + throw TachikomaError.contextLengthExceeded + } else { + throw TachikomaError.invalidRequest(message) + } + case 500...599: + throw TachikomaError.modelOverloaded + default: + throw TachikomaError.apiError(message: message, code: code) + } + } else { + throw TachikomaError.apiError(message: "HTTP \(response.statusCode)", code: "\(response.statusCode)") + } + } +} \ No newline at end of file diff --git a/Sources/Providers/OpenAI/OpenAITypes.swift b/Sources/Providers/OpenAI/OpenAITypes.swift new file mode 100644 index 0000000..d44ecf6 --- /dev/null +++ b/Sources/Providers/OpenAI/OpenAITypes.swift @@ -0,0 +1,481 @@ +import Foundation + +// MARK: - Helper Types + +/// Sendable wrapper for Any values +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnySendable: Codable, @unchecked Sendable { + public let value: Any + + public init(_ value: Any) { + self.value = value + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + + if let intValue = try? container.decode(Int.self) { + self.value = intValue + } else if let doubleValue = try? container.decode(Double.self) { + self.value = doubleValue + } else if let boolValue = try? container.decode(Bool.self) { + self.value = boolValue + } else if let stringValue = try? container.decode(String.self) { + self.value = stringValue + } else if let arrayValue = try? container.decode([AnySendable].self) { + self.value = arrayValue.map(\.value) + } else if let dictValue = try? container.decode([String: AnySendable].self) { + self.value = dictValue.mapValues(\.value) + } else { + self.value = NSNull() + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + + switch value { + case let intValue as Int: + try container.encode(intValue) + case let doubleValue as Double: + try container.encode(doubleValue) + case let boolValue as Bool: + try container.encode(boolValue) + case let stringValue as String: + try container.encode(stringValue) + case let arrayValue as [Any]: + try container.encode(arrayValue.map { AnySendable($0) }) + case let dictValue as [String: Any]: + try container.encode(dictValue.mapValues { AnySendable($0) }) + default: + try container.encodeNil() + } + } +} + +// MARK: - OpenAI API Request Types + +/// Chat Completions API request format +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIChatRequest: Codable, Sendable { + public let model: String + public let messages: [OpenAIMessage] + public let tools: [OpenAITool]? + public let toolChoice: String? + public let temperature: Double? + public let topP: Double? + public let stream: Bool? + public let maxTokens: Int? + + enum CodingKeys: String, CodingKey { + case model, messages, tools, temperature, stream + case toolChoice = "tool_choice" + case topP = "top_p" + case maxTokens = "max_tokens" + } +} + +/// Responses API request format (for o3/o4 models) +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIResponsesRequest: Codable, Sendable { + public let model: String + public let input: [OpenAIMessage] + public let tools: [OpenAIResponsesTool]? + public let toolChoice: String? + public let temperature: Double? + public let topP: Double? + public let stream: Bool? + public let maxOutputTokens: Int? + public let reasoning: OpenAIReasoning? + + enum CodingKeys: String, CodingKey { + case model, input, tools, temperature, stream, reasoning + case toolChoice = "tool_choice" + case topP = "top_p" + case maxOutputTokens = "max_output_tokens" + } +} + +/// OpenAI message format +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIMessage: Codable, Sendable { + public let role: String + public let content: MessageContent? + public let toolCalls: [OpenAIToolCall]? + public let toolCallId: String? + + public enum MessageContent: Codable, Sendable { + case string(String) + case array([OpenAIMessageContentPart]) + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + if let stringValue = try? container.decode(String.self) { + self = .string(stringValue) + } else if let arrayValue = try? container.decode([OpenAIMessageContentPart].self) { + self = .array(arrayValue) + } else { + throw DecodingError.dataCorruptedError(in: container, debugDescription: "Unable to decode message content") + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case let .string(value): + try container.encode(value) + case let .array(value): + try container.encode(value) + } + } + } + + enum CodingKeys: String, CodingKey { + case role, content + case toolCalls = "tool_calls" + case toolCallId = "tool_call_id" + } + + public init(role: String, content: MessageContent? = nil, toolCalls: [OpenAIToolCall]? = nil, toolCallId: String? = nil) { + self.role = role + self.content = content + self.toolCalls = toolCalls + self.toolCallId = toolCallId + } +} + +/// OpenAI message content part for multimodal messages +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIMessageContentPart: Codable, Sendable { + public let type: String + public let text: String? + public let imageUrl: OpenAIImageUrl? + + enum CodingKeys: String, CodingKey { + case type, text + case imageUrl = "image_url" + } + + public init(type: String, text: String? = nil, imageUrl: OpenAIImageUrl? = nil) { + self.type = type + self.text = text + self.imageUrl = imageUrl + } +} + +/// OpenAI image URL format +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIImageUrl: Codable, Sendable { + public let url: String + public let detail: String? + + public init(url: String, detail: String? = nil) { + self.url = url + self.detail = detail + } +} + +/// OpenAI tool definition for Chat Completions API +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAITool: Codable, Sendable { + public let type: String + public let function: OpenAIFunction + + public init(type: String, function: OpenAIFunction) { + self.type = type + self.function = function + } +} + +/// OpenAI tool definition for Responses API +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIResponsesTool: Codable, Sendable { + public let type: String + public let name: String + public let description: String + public let parameters: [String: AnySendable] + + enum CodingKeys: String, CodingKey { + case type, name, description, parameters + } + + public init(type: String, name: String, description: String, parameters: [String: Any]) { + self.type = type + self.name = name + self.description = description + self.parameters = parameters.mapValues { AnySendable($0) } + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.type = try container.decode(String.self, forKey: .type) + self.name = try container.decode(String.self, forKey: .name) + self.description = try container.decode(String.self, forKey: .description) + + // Decode parameters as generic JSON + let parametersContainer = try container.nestedContainer(keyedBy: AnyCodingKey.self, forKey: .parameters) + var parameters: [String: AnySendable] = [:] + for key in parametersContainer.allKeys { + parameters[key.stringValue] = try parametersContainer.decode(AnySendable.self, forKey: key) + } + self.parameters = parameters + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(type, forKey: .type) + try container.encode(name, forKey: .name) + try container.encode(description, forKey: .description) + + // Skip encoding parameters for now - this needs to be handled at runtime + try container.encodeIfPresent(nil as String?, forKey: .parameters) + } +} + +/// OpenAI function definition +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIFunction: Codable, Sendable { + public let name: String + public let description: String? + public let parameters: [String: AnySendable]? + public let arguments: String? + + enum CodingKeys: String, CodingKey { + case name, description, parameters, arguments + } + + public init(name: String, description: String? = nil, parameters: [String: Any]? = nil, arguments: String? = nil) { + self.name = name + self.description = description + self.parameters = parameters?.mapValues { AnySendable($0) } + self.arguments = arguments + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.name = try container.decode(String.self, forKey: .name) + self.description = try container.decodeIfPresent(String.self, forKey: .description) + self.arguments = try container.decodeIfPresent(String.self, forKey: .arguments) + + if container.contains(.parameters) { + let parametersContainer = try container.nestedContainer(keyedBy: AnyCodingKey.self, forKey: .parameters) + var parameters: [String: AnySendable] = [:] + for key in parametersContainer.allKeys { + parameters[key.stringValue] = try parametersContainer.decode(AnySendable.self, forKey: key) + } + self.parameters = parameters + } else { + self.parameters = nil + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(name, forKey: .name) + try container.encodeIfPresent(description, forKey: .description) + try container.encodeIfPresent(arguments, forKey: .arguments) + + if let parameters = parameters { + // Convert AnySendable parameters to AnyCodable for encoding + let codableParams = parameters.mapValues { AnyCodable($0.value) } + try container.encode(codableParams, forKey: .parameters) + } + } +} + +/// OpenAI tool call +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIToolCall: Codable, Sendable { + public let id: String + public let type: String + public let function: OpenAIFunction + + public init(id: String, type: String, function: OpenAIFunction) { + self.id = id + self.type = type + self.function = function + } +} + +/// OpenAI reasoning configuration for o3/o4 models +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIReasoning: Codable, Sendable { + public let effort: String + public let summary: String + + public init(effort: String, summary: String) { + self.effort = effort + self.summary = summary + } +} + +// MARK: - OpenAI Response Types + +/// OpenAI Chat Completions response +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIResponse: Codable, Sendable { + public let id: String + public let model: String + public let choices: [OpenAIChoice] + public let usage: OpenAIUsage? +} + +/// OpenAI response choice +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIChoice: Codable, Sendable { + public let message: OpenAIMessage + public let finishReason: String? + + enum CodingKeys: String, CodingKey { + case message + case finishReason = "finish_reason" + } +} + +/// OpenAI usage information +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIUsage: Codable, Sendable { + public let promptTokens: Int + public let completionTokens: Int + public let totalTokens: Int + + enum CodingKeys: String, CodingKey { + case promptTokens = "prompt_tokens" + case completionTokens = "completion_tokens" + case totalTokens = "total_tokens" + } +} + +// MARK: - OpenAI Streaming Types + +/// OpenAI streaming chunk +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIStreamChunk: Codable, Sendable { + public let id: String? + public let model: String? + public let choices: [OpenAIStreamChoice]? +} + +/// OpenAI streaming choice +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIStreamChoice: Codable, Sendable { + public let delta: OpenAIStreamDelta? + public let finishReason: String? + + enum CodingKeys: String, CodingKey { + case delta + case finishReason = "finish_reason" + } +} + +/// OpenAI streaming delta +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIStreamDelta: Codable, Sendable { + public let content: String? + public let toolCalls: [OpenAIStreamToolCall]? + + enum CodingKeys: String, CodingKey { + case content + case toolCalls = "tool_calls" + } +} + +/// OpenAI streaming tool call +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIStreamToolCall: Codable, Sendable { + public let id: String? + public let type: String? + public let index: Int? + public let function: OpenAIStreamFunction? +} + +/// OpenAI streaming function +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIStreamFunction: Codable, Sendable { + public let name: String? + public let arguments: String? +} + +// MARK: - Error Types + +/// OpenAI error response +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIErrorResponse: Codable, Sendable { + public let error: OpenAIError +} + +/// OpenAI error details +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct OpenAIError: Codable, Sendable { + public let message: String + public let type: String? + public let code: String? +} + +// MARK: - Helper Types + +/// Dynamic coding key for JSON encoding/decoding +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +struct AnyCodingKey: CodingKey { + var stringValue: String + var intValue: Int? + + init?(stringValue: String) { + self.stringValue = stringValue + self.intValue = nil + } + + init?(intValue: Int) { + self.stringValue = "\(intValue)" + self.intValue = intValue + } +} + +/// Wrapper for any codable value +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public struct AnyCodable: Codable, @unchecked Sendable { + public let value: Any + + public init(_ value: Any) { + self.value = value + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + + if let intValue = value as? Int { + try container.encode(intValue) + } else if let doubleValue = value as? Double { + try container.encode(doubleValue) + } else if let stringValue = value as? String { + try container.encode(stringValue) + } else if let boolValue = value as? Bool { + try container.encode(boolValue) + } else { + let data = try JSONSerialization.data(withJSONObject: value) + let str = String(data: data, encoding: .utf8) ?? "{}" + try container.encode(str) + } + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + + if let intValue = try? container.decode(Int.self) { + value = intValue + } else if let doubleValue = try? container.decode(Double.self) { + value = doubleValue + } else if let stringValue = try? container.decode(String.self) { + value = stringValue + } else if let boolValue = try? container.decode(Bool.self) { + value = boolValue + } else { + value = NSNull() + } + } +} + +// MARK: - Container Extensions + +// Note: Complex Any encoding/decoding extensions removed +// We use AnySendable for type-safe handling of dynamic JSON data \ No newline at end of file diff --git a/Sources/Tachikoma.swift b/Sources/Tachikoma.swift new file mode 100644 index 0000000..5d00166 --- /dev/null +++ b/Sources/Tachikoma.swift @@ -0,0 +1,87 @@ +import Foundation +@_exported import Logging + +/// Tachikoma - A comprehensive Swift package for AI model integration +/// +/// Tachikoma provides a unified interface for connecting to various AI providers +/// including OpenAI, Anthropic, Grok (xAI), Ollama, and custom endpoints. +/// It supports both streaming and non-streaming responses, tool calling, +/// multimodal inputs, and configuration management. +/// +/// Named after the AI entity from Ghost in the Shell, Tachikoma embodies +/// the cyberpunk aesthetic of autonomous AI systems. +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) +public final class Tachikoma: @unchecked Sendable { + public static let shared = Tachikoma() + + private let logger: Logger + + private init() { + self.logger = Logger(label: "build.tachikoma") + } + + /// Get the shared logger instance + public var log: Logger { + logger + } + + /// Get a model instance for the specified model name + /// - Parameter modelName: The model identifier (e.g., "gpt-4.1", "claude-opus-4", "provider-id/model-name") + /// - Returns: A model instance conforming to ModelInterface + /// - Throws: TachikomaError if the model is not available or configuration is invalid + public func getModel(_ modelName: String) async throws -> any ModelInterface { + return try await ModelProvider.shared.getModel(modelName: modelName) + } + + /// Configure OpenAI provider with specific settings + /// - Parameter configuration: OpenAI configuration + public func configureOpenAI(_ configuration: ProviderConfiguration.OpenAI) async { + await ModelProvider.shared.configureOpenAI(configuration) + } + + /// Configure Anthropic provider with specific settings + /// - Parameter configuration: Anthropic configuration + public func configureAnthropic(_ configuration: ProviderConfiguration.Anthropic) async { + await ModelProvider.shared.configureAnthropic(configuration) + } + + /// Configure Ollama provider with specific settings + /// - Parameter configuration: Ollama configuration + public func configureOllama(_ configuration: ProviderConfiguration.Ollama) async { + await ModelProvider.shared.configureOllama(configuration) + } + + /// Configure Grok provider with specific settings + /// - Parameter configuration: Grok configuration + public func configureGrok(_ configuration: ProviderConfiguration.Grok) async { + await ModelProvider.shared.configureGrok(configuration) + } + + /// Set up all providers from environment variables + /// - Throws: TachikomaError if setup fails + public func setupFromEnvironment() async throws { + try await ModelProvider.shared.setupFromEnvironment() + } + + /// List all available models from configured providers + /// - Returns: Array of available model identifiers + public func availableModels() async -> [String] { + return await ModelProvider.shared.listModels() + } + + /// Clear all cached model instances + public func clearModelCache() async { + await ModelProvider.shared.clearCache() + } + + /// Register a custom model factory + /// - Parameters: + /// - modelName: The model name to register + /// - factory: Factory closure that creates the model instance + public func registerModel( + name modelName: String, + factory: @escaping @Sendable () throws -> any ModelInterface) async + { + await ModelProvider.shared.register(modelName: modelName, factory: factory) + } +} \ No newline at end of file diff --git a/Tests/TachikomaTests/AnthropicModelTests.swift b/Tests/TachikomaTests/AnthropicModelTests.swift new file mode 100644 index 0000000..8c34cd7 --- /dev/null +++ b/Tests/TachikomaTests/AnthropicModelTests.swift @@ -0,0 +1,373 @@ +import Foundation +import Testing +@testable import Tachikoma + +@Suite("Anthropic Model Tests") +struct AnthropicModelTests { + @Test("Model initialization") + func modelInitialization() async throws { + let model = AnthropicModel( + apiKey: "sk-ant-test-key-123456789", + modelName: "claude-opus-4-20250514") + + #expect(model.maskedApiKey == "sk-ant...789") + } + + @Test("API key masking") + func apiKeyMasking() async throws { + // Test short key + let shortModel = AnthropicModel(apiKey: "short") + #expect(shortModel.maskedApiKey == "***") + + // Test normal key + let normalModel = AnthropicModel(apiKey: "sk-ant-api-key-1234567890abcdefghijklmnopqrstuvwxyz") + #expect(normalModel.maskedApiKey == "sk-ant...xyz") + } + + @Test("System message extraction") + func systemMessageExtraction() async throws { + let model = AnthropicModel(apiKey: "test-key") + + let request = ModelRequest( + messages: [ + Message.system(content: "You are a helpful assistant."), + Message.user(content: .text("Hello!")), + ], + tools: nil, + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + // System messages should be properly handled + #expect(request.messages.first?.type == .system) + } + + @Test("Tool conversion") + func toolConversion() async throws { + let model = AnthropicModel(apiKey: "test-key") + + let toolDef = ToolDefinition( + function: FunctionDefinition( + name: "get_weather", + description: "Get the current weather", + parameters: ToolParameters( + properties: ["location": ParameterSchema(type: .string, description: "The location")], + required: ["location"]))) + + let request = ModelRequest( + messages: [ + Message.user(content: .text("What's the weather?")), + ], + tools: [toolDef], + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + #expect(request.tools?.count == 1) + #expect(request.tools?.first?.function.name == "get_weather") + + // Test that the model can process the request (will fail at network level) + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Image content handling") + func imageContentHandling() async throws { + let model = AnthropicModel(apiKey: "test-key", modelName: "claude-opus-4-20250514") + + let imageData = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==" + + let request = ModelRequest( + messages: [ + Message.user(content: .multimodal([ + MessageContentPart(type: "text", text: "What's in this image?"), + MessageContentPart(type: "image", imageUrl: ImageContent(base64: imageData)), + ])), + ], + tools: nil, + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + if case let .user(_, content) = request.messages.first, + case let .multimodal(parts) = content + { + #expect(parts.count == 2) + } else { + Issue.record("Expected multimodal content") + } + + // Test that URL images are supported + let urlRequest = ModelRequest( + messages: [ + Message.user(content: .image(ImageContent(url: "https://example.com/image.jpg"))), + ], + tools: nil, + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + if case let .user(_, content) = urlRequest.messages.first, + case .image = content + { + // Expected image content + } else { + Issue.record("Expected image content") + } + + // Test processing (will fail at network level) + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Message type conversion") + func messageTypeConversion() async throws { + let model = AnthropicModel(apiKey: "test-key") + + // Test various message types + let messages: [Message] = [ + Message.system(content: "You are Claude."), + Message.user(content: .text("Hello Claude!")), + Message.assistant(content: [.outputText("Hello! How can I help you?")]), + Message.user(content: .text("What's 2+2?")), + Message.assistant(content: [.outputText("2 + 2 = 4")]), + ] + + let request = ModelRequest( + messages: messages, + tools: nil, + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + // Verify message structure + #expect(request.messages.count == 5) + #expect(request.messages[0].type == .system) + #expect(request.messages[1].type == .user) + #expect(request.messages[2].type == .assistant) + + // Test processing (will fail at network level) + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Streaming response handling") + func streamingResponseHandling() async throws { + let model = AnthropicModel(apiKey: "test-key") + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Write a short poem")), + ], + tools: nil, + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + // Test streaming (will fail at network level) + do { + let stream = try await model.getStreamedResponse(request: request) + var eventCount = 0 + + for try await event in stream { + eventCount += 1 + _ = event + } + + Issue.record("Expected network error but got \(eventCount) events") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Tool call handling") + func toolCallHandling() async throws { + let model = AnthropicModel(apiKey: "test-key") + + // Create a tool call message + let toolCall = ToolCallItem( + id: "call_123", + type: .function, + function: FunctionCall( + name: "get_weather", + arguments: "{\"location\": \"Paris\"}" + ) + ) + + let messages: [Message] = [ + Message.user(content: .text("What's the weather in Paris?")), + Message.assistant(content: [.toolCall(toolCall)]), + Message.tool(toolCallId: "call_123", content: "It's sunny, 22°C"), + ] + + let request = ModelRequest( + messages: messages, + tools: nil, + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + // Test tool call processing (will fail at network level) + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Model variants") + func modelVariants() async throws { + let modelNames = [ + "claude-opus-4-20250514", + "claude-sonnet-4-20250514", + "claude-opus-4-20250514-thinking", + "claude-sonnet-4-20250514-thinking", + "claude-3-7-sonnet", + "claude-3-5-sonnet", + "claude-3-5-haiku", + ] + + for modelName in modelNames { + let model = AnthropicModel(apiKey: "test-key", modelName: modelName) + #expect(model.maskedApiKey == "***") + + // Test that each model variant can be created and handles requests + let request = ModelRequest( + messages: [Message.user(content: .text("Test"))], + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error for \(modelName)") + } catch { + #expect(error is TachikomaError) + } + } + } + + @Test("Error handling") + func errorHandling() async throws { + let model = AnthropicModel(apiKey: "invalid-key") + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Test")), + ], + tools: nil, + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error but got success") + } catch let error as TachikomaError { + // Verify we get appropriate error types + switch error { + case .apiError, .authenticationFailed, .networkError: + // Expected error types for invalid API key + break + default: + Issue.record("Unexpected error type: \(error)") + } + } catch { + Issue.record("Unexpected error type: \(type(of: error))") + } + } + + @Test("Audio content handling") + func audioContentHandling() async throws { + let model = AnthropicModel(apiKey: "test-key") + + let audioContent = AudioContent( + transcript: "Hello, this is a test transcript.", + duration: 5.0 + ) + + let request = ModelRequest( + messages: [ + Message.user(content: .audio(audioContent)), + ], + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + // Test audio content processing + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("File content rejection") + func fileContentRejection() async throws { + let model = AnthropicModel(apiKey: "test-key") + + let fileContent = FileContent( + id: nil, + url: nil, + name: "test.txt" + ) + + let request = ModelRequest( + messages: [ + Message.user(content: .file(fileContent)), + ], + settings: ModelSettings(modelName: "claude-opus-4-20250514")) + + // File content should be rejected + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error for file content") + } catch let error as TachikomaError { + switch error { + case .invalidRequest: + // Expected - file content not supported + break + default: + Issue.record("Expected invalidRequest error, got: \(error)") + } + } + } +} + +// MARK: - Provider Configuration Tests + +@Suite("Anthropic Provider Configuration Tests") +struct AnthropicProviderConfigurationTests { + @Test("Provider configuration") + func providerConfiguration() async throws { + let tachikoma = Tachikoma.shared + + // Test custom Anthropic configuration + let config = ProviderConfiguration.Anthropic( + apiKey: "test-key", + baseURL: URL(string: "https://api.anthropic.com/v1")! + ) + + await tachikoma.configureAnthropic(config) + } + + @Test("Model registration") + func modelRegistration() async throws { + let tachikoma = Tachikoma.shared + + // Register Claude models + let modelNames = [ + "claude-opus-4", + "claude-sonnet-4", + "claude-3-5-sonnet", + "claude-3-5-haiku", + ] + + for modelName in modelNames { + await tachikoma.registerModel(name: modelName, factory: { + AnthropicModel(apiKey: "test-key", modelName: modelName) + }) + + do { + let model = try await tachikoma.getModel(modelName) + #expect(model is AnthropicModel) + } catch { + Issue.record("Failed to get model \(modelName): \(error)") + } + } + } +} \ No newline at end of file diff --git a/Tests/TachikomaTests/GrokModelTests.swift b/Tests/TachikomaTests/GrokModelTests.swift new file mode 100644 index 0000000..47f2f96 --- /dev/null +++ b/Tests/TachikomaTests/GrokModelTests.swift @@ -0,0 +1,331 @@ +import Foundation +import Testing +@testable import Tachikoma + +@Suite("Grok Model Tests") +struct GrokModelTests { + @Test("Model initialization") + func modelInitialization() async throws { + let model = GrokModel( + apiKey: "test-key-123456", + baseURL: URL(string: "https://api.x.ai/v1")!) + + #expect(model.maskedApiKey == "test-k...56") + } + + @Test("API key masking") + func apiKeyMasking() async throws { + // Test short key + let shortModel = GrokModel(apiKey: "short") + #expect(shortModel.maskedApiKey == "***") + + // Test normal key + let normalModel = GrokModel(apiKey: "test-api-key-1234567890abcdefghijklmnopqrstuvwxyz") + #expect(normalModel.maskedApiKey == "test-a...yz") + } + + @Test("Default base URL") + func defaultBaseURL() async throws { + let model = GrokModel(apiKey: "test-key-123456") + + // Verify it uses the correct xAI API endpoint + // We can't directly access baseURL, but we can test the behavior + #expect(model.maskedApiKey == "test-k...56") + } + + @Test("Parameter filtering for Grok 4") + func grok4ParameterFiltering() async throws { + let model = GrokModel(apiKey: "test-key") + + // Create a request with parameters that should be filtered for Grok 4 + let settings = ModelSettings( + modelName: "grok-4", + temperature: 0.7, + frequencyPenalty: 0.5, // Should be removed for grok-4 + presencePenalty: 0.5, // Should be removed for grok-4 + stopSequences: ["stop"] // Should be removed for grok-4 + ) + + let request = ModelRequest( + messages: [ + Message.system(content: "Test system message"), + Message.user(content: .text("Test user message")), + ], + tools: nil, + settings: settings) + + // We can't directly test the filtering without mocking the network request + // But we can verify the model handles the request without crashing + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + // Expected to fail due to no valid API key/network + #expect(error is TachikomaError) + } + } + + @Test("Tool parameter conversion") + func toolParameterConversion() async throws { + let model = GrokModel(apiKey: "test-key") + + // Create a tool definition + let tool = ToolDefinition( + function: FunctionDefinition( + name: "test_tool", + description: "A test tool", + parameters: ToolParameters( + type: "object", + properties: [ + "message": ParameterSchema( + type: .string, + description: "A test message"), + "count": ParameterSchema( + type: .integer, + description: "A count", + minimum: 0, + maximum: 100), + ], + required: ["message"]))) + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Use the test tool")), + ], + tools: [tool], + settings: ModelSettings(modelName: "grok-4")) + + // Verify the model can process tool definitions + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + // Expected to fail due to no valid API key/network + #expect(error is TachikomaError) + } + } + + @Test("Message type conversion") + func messageTypeConversion() async throws { + let model = GrokModel(apiKey: "test-key") + + // Test various message types + let messages: [Message] = [ + Message.system(content: "System prompt"), + Message.user(content: .text("User text")), + Message.assistant(content: [.outputText("Assistant response")]), + Message.tool( + toolCallId: "tool-123", + content: "Tool result"), + ] + + let request = ModelRequest( + messages: messages, + tools: nil, + settings: ModelSettings(modelName: "grok-4")) + + // Verify message conversion doesn't crash + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + // Expected to fail due to no valid API key/network + #expect(error is TachikomaError) + } + } + + @Test("Multimodal message support") + func multimodalMessageSupport() async throws { + let model = GrokModel(apiKey: "test-key", modelName: "grok-2-vision-1212") + + // Create a multimodal message with text and image + let imageData = Data([0xFF, 0xD8, 0xFF]) // Minimal JPEG header + + let request = ModelRequest( + messages: [ + Message.user(content: .multimodal([ + MessageContentPart(type: "text", text: "What is in this image?"), + MessageContentPart(type: "image", imageUrl: ImageContent(base64: imageData.base64EncodedString())), + ])), + ], + tools: nil, + settings: ModelSettings(modelName: "grok-2-vision-1212")) + + // Verify multimodal content handling + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + // Expected to fail due to no valid API key/network + #expect(error is TachikomaError) + } + } + + @Test("Streaming response handling") + func streamingResponse() async throws { + let model = GrokModel(apiKey: "test-key") + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Stream this response")), + ], + tools: nil, + settings: ModelSettings(modelName: "grok-4")) + + // Test streaming + do { + let stream = try await model.getStreamedResponse(request: request) + var eventCount = 0 + + for try await event in stream { + eventCount += 1 + // Would normally process events here + _ = event + } + + Issue.record("Expected network error but got success with \(eventCount) events") + } catch { + // Expected to fail due to no valid API key/network + #expect(error is TachikomaError) + } + } + + @Test("Error handling") + func errorHandling() async throws { + let model = GrokModel(apiKey: "invalid-key") + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Test")), + ], + tools: nil, + settings: ModelSettings(modelName: "grok-4")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error but got success") + } catch let error as TachikomaError { + // Verify we get appropriate error types + switch error { + case .apiError, .authenticationFailed: + // Expected error types for invalid API key + break + default: + Issue.record("Unexpected error type: \(error)") + } + } catch { + Issue.record("Unexpected error type: \(type(of: error))") + } + } + + @Test("Reasoning message rejection") + func reasoningMessageRejection() async throws { + let model = GrokModel(apiKey: "test-key") + + let request = ModelRequest( + messages: [ + Message.reasoning(content: "Some reasoning"), + Message.user(content: .text("Test")), + ], + tools: nil, + settings: ModelSettings(modelName: "grok-4")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error for reasoning messages") + } catch let error as TachikomaError { + switch error { + case .invalidRequest: + // Expected - reasoning messages not supported + break + default: + Issue.record("Expected invalidRequest error, got: \(error)") + } + } + } +} + +// MARK: - Model Provider Tests + +@Suite("Grok Model Provider Tests") +struct GrokModelProviderTests { + @Test("Grok model registration") + func grokModelRegistration() async throws { + let tachikoma = Tachikoma.shared + + // Register a test Grok model + await tachikoma.registerModel(name: "test-grok", factory: { + GrokModel(apiKey: "test-key", modelName: "grok-4") + }) + + // Test that we can get the model + do { + let model = try await tachikoma.getModel("test-grok") + #expect(model is GrokModel) + } catch { + Issue.record("Failed to get registered model: \(error)") + } + } + + @Test("Model name variants") + func modelNameVariants() async throws { + let tachikoma = Tachikoma.shared + + // Register various Grok model names + let modelNames = [ + "grok-4", + "grok-4-0709", + "grok-4-latest", + "grok-2-1212", + "grok-2-vision-1212", + "grok-beta", + "grok-vision-beta", + ] + + for modelName in modelNames { + await tachikoma.registerModel(name: modelName, factory: { + GrokModel(apiKey: "test-key", modelName: modelName) + }) + + do { + let model = try await tachikoma.getModel(modelName) + #expect(model is GrokModel) + } catch { + Issue.record("Failed to get model \(modelName): \(error)") + } + } + } + + @Test("Parameter filtering detection") + func parameterFilteringDetection() async throws { + // Test that Grok 3 and 4 models filter parameters correctly + let grok3Model = GrokModel(apiKey: "test-key", modelName: "grok-3") + let grok4Model = GrokModel(apiKey: "test-key", modelName: "grok-4") + let grokBetaModel = GrokModel(apiKey: "test-key", modelName: "grok-beta") + + // All should handle the same request without errors (though network will fail) + let settings = ModelSettings( + modelName: "grok-4", + temperature: 0.7, + frequencyPenalty: 0.5, + presencePenalty: 0.5, + stopSequences: ["stop"] + ) + + let request = ModelRequest( + messages: [Message.user(content: .text("Test"))], + settings: settings + ) + + // Test each model handles parameter filtering + for (name, model) in [("grok-3", grok3Model), ("grok-4", grok4Model), ("grok-beta", grokBetaModel)] { + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error for \(name)") + } catch { + // Expected to fail due to network/auth, but not due to parameter issues + #expect(error is TachikomaError) + } + } + } +} \ No newline at end of file diff --git a/Tests/TachikomaTests/OllamaModelTests.swift b/Tests/TachikomaTests/OllamaModelTests.swift new file mode 100644 index 0000000..f4cdbd0 --- /dev/null +++ b/Tests/TachikomaTests/OllamaModelTests.swift @@ -0,0 +1,611 @@ +import Foundation +import Testing +@testable import Tachikoma + +@Suite("Ollama Model Tests") +struct OllamaModelTests { + @Test("Model initialization") + func modelInitialization() async throws { + let model = OllamaModel( + modelName: "llama3.3", + baseURL: URL(string: "http://localhost:11434")!) + + // Ollama doesn't use API keys, so maskedApiKey should return a placeholder + #expect(model.maskedApiKey == "local-ollama") + } + + @Test("Default base URL") + func defaultBaseURL() async throws { + let model = OllamaModel(modelName: "llama3.3") + + // Should use default localhost URL + #expect(model.maskedApiKey == "local-ollama") + } + + @Test("Custom base URL") + func customBaseURL() async throws { + let customURL = URL(string: "http://remote-server:11434")! + let model = OllamaModel(modelName: "llama3.3", baseURL: customURL) + + #expect(model.maskedApiKey == "local-ollama") + } + + @Test("Tool calling support detection") + func toolCallingSupportDetection() async throws { + // Models with tool calling support + let supportedModels = [ + "llama3.3", + "llama3.2", + "llama3.1", + "mistral-nemo", + "firefunction-v2", + "command-r-plus", + "command-r", + ] + + for modelName in supportedModels { + let model = OllamaModel(modelName: modelName) + + let toolDef = ToolDefinition( + function: FunctionDefinition( + name: "get_time", + description: "Get current time", + parameters: ToolParameters( + properties: [:], + required: []))) + + let request = ModelRequest( + messages: [ + Message.user(content: .text("What time is it?")), + ], + tools: [toolDef], + settings: ModelSettings(modelName: "llama3.3")) + + // Should handle tool calls (will fail at network level) + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error for \(modelName)") + } catch { + #expect(error is TachikomaError) + } + } + } + + @Test("Vision model tool rejection") + func visionModelToolRejection() async throws { + // Vision models that don't support tool calling + let visionModels = [ + "llava", + "bakllava", + "llama3.2-vision:11b", + "qwen2.5vl:7b", + ] + + for modelName in visionModels { + let model = OllamaModel(modelName: modelName) + + let toolDef = ToolDefinition( + function: FunctionDefinition( + name: "get_time", + description: "Get current time", + parameters: ToolParameters( + properties: [:], + required: []))) + + let request = ModelRequest( + messages: [ + Message.user(content: .text("What time is it?")), + ], + tools: [toolDef], + settings: ModelSettings(modelName: "llama3.3")) + + // Should reject tool calls for vision models + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error for vision model \(modelName) with tools") + } catch let error as TachikomaError { + switch error { + case .invalidRequest: + // Expected - vision models don't support tool calling + break + default: + Issue.record("Expected invalidRequest error for \(modelName), got: \(error)") + } + } + } + } + + @Test("Message type conversion") + func messageTypeConversion() async throws { + let model = OllamaModel(modelName: "llama3.3") + + // Test various message types + let messages: [Message] = [ + Message.system(content: "You are a helpful assistant."), + Message.user(content: .text("Hello!")), + Message.assistant(content: [.outputText("Hi there!")]), + Message.tool( + toolCallId: "call_123", + content: "Current time: 2:30 PM"), + ] + + let request = ModelRequest( + messages: messages, + tools: nil, + settings: ModelSettings(modelName: "llama3.3")) + + // Verify message conversion doesn't crash + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Multimodal message support for vision models") + func multimodalMessageSupport() async throws { + let model = OllamaModel(modelName: "llava") + + // Create a multimodal message with text and image + let imageData = Data([0xFF, 0xD8, 0xFF]) // Minimal JPEG header + + let request = ModelRequest( + messages: [ + Message.user(content: .multimodal([ + MessageContentPart(type: "text", text: "What is in this image?"), + MessageContentPart(type: "image", imageUrl: ImageContent(base64: imageData.base64EncodedString())), + ])), + ], + tools: nil, + settings: ModelSettings(modelName: "llava")) + + // Verify multimodal content handling for vision models + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Text-only models image rejection") + func textOnlyModelsImageRejection() async throws { + let model = OllamaModel(modelName: "llama3.3") // Text-only model + + let imageData = Data([0xFF, 0xD8, 0xFF]) + + let request = ModelRequest( + messages: [ + Message.user(content: .image(ImageContent(base64: imageData.base64EncodedString()))), + ], + tools: nil, + settings: ModelSettings(modelName: "llama3.3")) + + // Should reject image content for text-only models + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error for text-only model with image") + } catch let error as TachikomaError { + switch error { + case .invalidRequest: + // Expected - text-only models don't support images + break + default: + Issue.record("Expected invalidRequest error, got: \(error)") + } + } + } + + @Test("Streaming response handling") + func streamingResponse() async throws { + let model = OllamaModel(modelName: "llama3.3") + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Tell me a joke")), + ], + tools: nil, + settings: ModelSettings(modelName: "llama3.3")) + + // Test streaming + do { + let stream = try await model.getStreamedResponse(request: request) + var eventCount = 0 + + for try await event in stream { + eventCount += 1 + _ = event + } + + Issue.record("Expected network error but got success with \(eventCount) events") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Extended timeout handling") + func extendedTimeoutHandling() async throws { + let model = OllamaModel(modelName: "llama3.3") + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Generate a long response")), + ], + tools: nil, + settings: ModelSettings(modelName: "llama3.3")) + + // Ollama requests should have extended timeouts (5 minutes) + // We can't test the actual timeout without a real server, + // but we can verify the request structure + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Tool call JSON parsing") + func toolCallJSONParsing() async throws { + let model = OllamaModel(modelName: "llama3.3") + + // Test models that output tool calls as JSON in content + let toolDef = ToolDefinition( + function: FunctionDefinition( + name: "calculate", + description: "Perform calculation", + parameters: ToolParameters( + properties: [ + "expression": ParameterSchema(type: .string, description: "Math expression") + ], + required: ["expression"]))) + + let request = ModelRequest( + messages: [ + Message.user(content: .text("What is 5 + 3?")), + ], + tools: [toolDef], + settings: ModelSettings(modelName: "llama3.3")) + + // Should handle tool call parsing + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Error handling") + func errorHandling() async throws { + let model = OllamaModel(modelName: "nonexistent-model") + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Test")), + ], + tools: nil, + settings: ModelSettings(modelName: "llama3.3")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error but got success") + } catch let error as TachikomaError { + // Verify we get appropriate error types + switch error { + case .apiError, .networkError: + // Expected error types for nonexistent model + break + default: + Issue.record("Unexpected error type: \(error)") + } + } catch { + Issue.record("Unexpected error type: \(type(of: error))") + } + } + + @Test("Reasoning message rejection") + func reasoningMessageRejection() async throws { + let model = OllamaModel(modelName: "llama3.3") + + let request = ModelRequest( + messages: [ + Message.reasoning(content: "Let me think..."), + Message.user(content: .text("Test")), + ], + tools: nil, + settings: ModelSettings(modelName: "llama3.3")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error for reasoning messages") + } catch let error as TachikomaError { + switch error { + case .invalidRequest: + // Expected - reasoning messages not supported + break + default: + Issue.record("Expected invalidRequest error, got: \(error)") + } + } + } + + @Test("File content rejection") + func fileContentRejection() async throws { + let model = OllamaModel(modelName: "llama3.3") + + let fileContent = FileContent( + id: nil, + url: nil, + name: "test.txt" + ) + + let request = ModelRequest( + messages: [ + Message.user(content: .file(fileContent)), + ], + settings: ModelSettings(modelName: "llama3.3")) + + // File content should be rejected + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error for file content") + } catch let error as TachikomaError { + switch error { + case .invalidRequest: + // Expected - file content not supported + break + default: + Issue.record("Expected invalidRequest error, got: \(error)") + } + } + } + + @Test("Audio content handling") + func audioContentHandling() async throws { + let model = OllamaModel(modelName: "llama3.3") + + let audioContent = AudioContent( + transcript: "Hello, this is a test transcript.", + duration: 5.0 + ) + + let request = ModelRequest( + messages: [ + Message.user(content: .audio(audioContent)), + ], + settings: ModelSettings(modelName: "llama3.3")) + + // Test audio content processing (converts to text) + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error") + } catch { + #expect(error is TachikomaError) + } + } +} + +// MARK: - Model Provider Tests + +@Suite("Ollama Model Provider Tests") +struct OllamaModelProviderTests { + @Test("Ollama model registration") + func ollamaModelRegistration() async throws { + let tachikoma = Tachikoma.shared + + // Register a test Ollama model + await tachikoma.registerModel(name: "test-llama", factory: { + OllamaModel(modelName: "llama3.3") + }) + + // Test that we can get the model + do { + let model = try await tachikoma.getModel("test-llama") + #expect(model is OllamaModel) + } catch { + Issue.record("Failed to get registered model: \(error)") + } + } + + @Test("Model name variants") + func modelNameVariants() async throws { + let tachikoma = Tachikoma.shared + + // Register various Ollama model names + let modelNames = [ + "llama3.3", + "llama3.2", + "llama3.1", + "llava", + "mistral-nemo", + "firefunction-v2", + "command-r-plus", + "deepseek-r1:8b", + ] + + for modelName in modelNames { + await tachikoma.registerModel(name: modelName, factory: { + OllamaModel(modelName: modelName) + }) + + do { + let model = try await tachikoma.getModel(modelName) + #expect(model is OllamaModel) + } catch { + Issue.record("Failed to get model \(modelName): \(error)") + } + } + } + + @Test("Base URL configuration") + func baseURLConfiguration() async throws { + let tachikoma = Tachikoma.shared + + // Test custom base URL + let customURL = URL(string: "http://remote-ollama:11434")! + + await tachikoma.registerModel(name: "custom-ollama", factory: { + OllamaModel(modelName: "llama3.3", baseURL: customURL) + }) + + do { + let model = try await tachikoma.getModel("custom-ollama") + #expect(model is OllamaModel) + } catch { + Issue.record("Failed to get custom URL model: \(error)") + } + } + + @Test("Tool support matrix") + func toolSupportMatrix() async throws { + let tachikoma = Tachikoma.shared + + // Models with tool support + let toolSupportedModels = [ + "llama3.3", + "llama3.2", + "mistral-nemo", + "firefunction-v2", + ] + + // Models without tool support + let nonToolModels = [ + "llava", + "bakllava", + "devstral", + ] + + let toolDef = ToolDefinition( + function: FunctionDefinition( + name: "test_tool", + description: "Test tool", + parameters: ToolParameters(properties: [:], required: []))) + + // Test tool-supported models + for modelName in toolSupportedModels { + await tachikoma.registerModel(name: "tool-\(modelName)", factory: { + OllamaModel(modelName: modelName) + }) + + do { + let model = try await tachikoma.getModel("tool-\(modelName)") + #expect(model is OllamaModel) + + // Should accept tool definitions + let request = ModelRequest( + messages: [Message.user(content: .text("Use tool"))], + tools: [toolDef], + settings: ModelSettings(modelName: "llama3.3")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error for \(modelName)") + } catch { + #expect(error is TachikomaError) + } + } catch { + Issue.record("Failed to get tool model \(modelName): \(error)") + } + } + + // Test non-tool models + for modelName in nonToolModels { + await tachikoma.registerModel(name: "no-tool-\(modelName)", factory: { + OllamaModel(modelName: modelName) + }) + + do { + let model = try await tachikoma.getModel("no-tool-\(modelName)") + #expect(model is OllamaModel) + + // Should reject tool definitions + let request = ModelRequest( + messages: [Message.user(content: .text("Use tool"))], + tools: [toolDef], + settings: ModelSettings(modelName: "llama3.3")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error for non-tool model \(modelName)") + } catch let error as TachikomaError { + switch error { + case .invalidRequest: + // Expected for non-tool models + break + default: + Issue.record("Expected invalidRequest for \(modelName), got: \(error)") + } + } + } catch { + Issue.record("Failed to get non-tool model \(modelName): \(error)") + } + } + } + + @Test("Vision capability detection") + func visionCapabilityDetection() async throws { + let tachikoma = Tachikoma.shared + + // Vision models + let visionModels = ["llava", "bakllava", "llama3.2-vision:11b"] + + // Text-only models + let textModels = ["llama3.3", "mistral-nemo", "command-r"] + + let imageData = Data([0xFF, 0xD8, 0xFF]) + + // Test vision models accept images + for modelName in visionModels { + await tachikoma.registerModel(name: "vision-\(modelName)", factory: { + OllamaModel(modelName: modelName) + }) + + do { + let model = try await tachikoma.getModel("vision-\(modelName)") + let request = ModelRequest( + messages: [Message.user(content: .image(ImageContent(base64: imageData.base64EncodedString())))], + settings: ModelSettings(modelName: "llama3.3")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error for vision model \(modelName)") + } catch { + #expect(error is TachikomaError) + } + } catch { + Issue.record("Failed vision model test for \(modelName): \(error)") + } + } + + // Test text models reject images + for modelName in textModels { + await tachikoma.registerModel(name: "text-\(modelName)", factory: { + OllamaModel(modelName: modelName) + }) + + do { + let model = try await tachikoma.getModel("text-\(modelName)") + let request = ModelRequest( + messages: [Message.user(content: .image(ImageContent(base64: imageData.base64EncodedString())))], + settings: ModelSettings(modelName: "llama3.3")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error for text model \(modelName) with image") + } catch let error as TachikomaError { + switch error { + case .invalidRequest: + // Expected for text-only models + break + default: + Issue.record("Expected invalidRequest for \(modelName), got: \(error)") + } + } + } catch { + Issue.record("Failed text model test for \(modelName): \(error)") + } + } + } +} \ No newline at end of file diff --git a/Tests/TachikomaTests/OpenAIModelTests.swift b/Tests/TachikomaTests/OpenAIModelTests.swift new file mode 100644 index 0000000..a1f1644 --- /dev/null +++ b/Tests/TachikomaTests/OpenAIModelTests.swift @@ -0,0 +1,475 @@ +import Foundation +import Testing +@testable import Tachikoma + +@Suite("OpenAI Model Tests") +struct OpenAIModelTests { + @Test("Model initialization") + func modelInitialization() async throws { + let model = OpenAIModel( + apiKey: "sk-test-key-123456789", + modelName: "gpt-4.1") + + #expect(model.maskedApiKey == "sk-...789") + } + + @Test("API key masking") + func apiKeyMasking() async throws { + // Test short key + let shortModel = OpenAIModel(apiKey: "short") + #expect(shortModel.maskedApiKey == "***") + + // Test normal key + let normalModel = OpenAIModel(apiKey: "sk-test-api-key-1234567890abcdefghijklmnopqrstuvwxyz") + #expect(normalModel.maskedApiKey == "sk-...xyz") + } + + @Test("Default base URL") + func defaultBaseURL() async throws { + let model = OpenAIModel(apiKey: "test-key") + + // Verify it uses the correct OpenAI API endpoint + // We can't directly access baseURL, but we can test the behavior + #expect(model.maskedApiKey == "***") + } + + @Test("Dual API support") + func dualAPISupport() async throws { + let model = OpenAIModel(apiKey: "test-key") + + // Test Chat Completions API request + let chatRequest = ModelRequest( + messages: [ + Message.user(content: .text("Hello")), + ], + tools: nil, + settings: ModelSettings(apiType: "chat")) + + // Test Responses API request (for reasoning models) + let responsesRequest = ModelRequest( + messages: [ + Message.user(content: .text("Hello")), + ], + tools: nil, + settings: ModelSettings(apiType: "responses")) + + // Both should be processable (will fail at network level) + do { + _ = try await model.getResponse(request: chatRequest) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + + do { + _ = try await model.getResponse(request: responsesRequest) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Reasoning models parameter handling") + func reasoningModelsParameterHandling() async throws { + let model = OpenAIModel(apiKey: "test-key", modelName: "o3") + + // Create a request with reasoning parameters + let settings = ModelSettings( + reasoningEffort: "medium", + reasoning: ["summary": "detailed"], + temperature: nil // o3 models don't support temperature + ) + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Solve this complex problem")), + ], + tools: nil, + settings: settings) + + // Should handle reasoning parameters correctly + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Tool parameter conversion") + func toolParameterConversion() async throws { + let model = OpenAIModel(apiKey: "test-key") + + // Create a tool definition + let tool = ToolDefinition( + function: FunctionDefinition( + name: "get_weather", + description: "Get current weather", + parameters: ToolParameters( + type: "object", + properties: [ + "location": ParameterSchema( + type: .string, + description: "The location"), + "units": ParameterSchema( + type: .string, + description: "Temperature units", + enumValues: ["celsius", "fahrenheit"]), + ], + required: ["location"]))) + + let request = ModelRequest( + messages: [ + Message.user(content: .text("What's the weather in Paris?")), + ], + tools: [tool], + settings: ModelSettings(modelName: "gpt-4.1")) + + // Verify the model can process tool definitions + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Message type conversion") + func messageTypeConversion() async throws { + let model = OpenAIModel(apiKey: "test-key") + + // Test various message types + let messages: [Message] = [ + Message.system(content: "You are a helpful assistant."), + Message.user(content: .text("Hello!")), + Message.assistant(content: [.outputText("Hi there!")]), + Message.tool( + toolCallId: "call_123", + content: "Weather data"), + ] + + let request = ModelRequest( + messages: messages, + tools: nil, + settings: ModelSettings(modelName: "gpt-4.1")) + + // Verify message conversion doesn't crash + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Multimodal message support") + func multimodalMessageSupport() async throws { + let model = OpenAIModel(apiKey: "test-key", modelName: "gpt-4o") + + // Create a multimodal message with text and image + let imageData = Data([0xFF, 0xD8, 0xFF]) // Minimal JPEG header + + let request = ModelRequest( + messages: [ + Message.user(content: .multimodal([ + MessageContentPart(type: "text", text: "What is in this image?"), + MessageContentPart(type: "image", imageUrl: ImageContent(base64: imageData.base64EncodedString())), + ])), + ], + tools: nil, + settings: ModelSettings(modelName: "gpt-4.1")) + + // Verify multimodal content handling + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Streaming response handling") + func streamingResponse() async throws { + let model = OpenAIModel(apiKey: "test-key") + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Write a short story")), + ], + tools: nil, + settings: ModelSettings(modelName: "gpt-4.1")) + + // Test streaming + do { + let stream = try await model.getStreamedResponse(request: request) + var eventCount = 0 + + for try await event in stream { + eventCount += 1 + _ = event + } + + Issue.record("Expected network error but got success with \(eventCount) events") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Error handling") + func errorHandling() async throws { + let model = OpenAIModel(apiKey: "invalid-key") + + let request = ModelRequest( + messages: [ + Message.user(content: .text("Test")), + ], + tools: nil, + settings: ModelSettings(modelName: "gpt-4.1")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error but got success") + } catch let error as TachikomaError { + // Verify we get appropriate error types + switch error { + case .apiError, .authenticationFailed: + // Expected error types for invalid API key + break + default: + Issue.record("Unexpected error type: \(error)") + } + } catch { + Issue.record("Unexpected error type: \(type(of: error))") + } + } + + @Test("Reasoning message handling") + func reasoningMessageHandling() async throws { + let model = OpenAIModel(apiKey: "test-key", modelName: "o3") + + let request = ModelRequest( + messages: [ + Message.reasoning(content: "Let me think about this step by step..."), + Message.user(content: .text("What is 2+2?")), + ], + tools: nil, + settings: ModelSettings(modelName: "gpt-4.1")) + + // o3 models should handle reasoning messages + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error but got success") + } catch { + #expect(error is TachikomaError) + } + } + + @Test("Model variants") + func modelVariants() async throws { + let modelNames = [ + "o3", + "o3-mini", + "o3-pro", + "o4-mini", + "gpt-4.1", + "gpt-4.1-mini", + "gpt-4o", + "gpt-4o-mini", + ] + + for modelName in modelNames { + let model = OpenAIModel(apiKey: "test-key", modelName: modelName) + #expect(model.maskedApiKey == "***") + + // Test that each model variant can be created and handles requests + let request = ModelRequest( + messages: [Message.user(content: .text("Test"))], + settings: ModelSettings(modelName: "gpt-4.1")) + + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error for \(modelName)") + } catch { + #expect(error is TachikomaError) + } + } + } + + @Test("Temperature parameter filtering") + func temperatureParameterFiltering() async throws { + // o3 models should not support temperature + let o3Model = OpenAIModel(apiKey: "test-key", modelName: "o3") + let gptModel = OpenAIModel(apiKey: "test-key", modelName: "gpt-4.1") + + let settings = ModelSettings(temperature: 0.7) + + let request = ModelRequest( + messages: [Message.user(content: .text("Test"))], + settings: settings) + + // Both should handle the request (temperature filtered for o3) + for (name, model) in [("o3", o3Model), ("gpt-4.1", gptModel)] { + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error for \(name)") + } catch { + #expect(error is TachikomaError) + } + } + } + + @Test("File content rejection") + func fileContentRejection() async throws { + let model = OpenAIModel(apiKey: "test-key") + + let fileContent = FileContent( + filename: "test.txt", + content: "Test file content", + mimeType: "text/plain" + ) + + let request = ModelRequest( + messages: [ + Message.user(content: .file(fileContent)), + ], + settings: ModelSettings(modelName: "gpt-4.1")) + + // File content should be rejected + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected error for file content") + } catch let error as TachikomaError { + switch error { + case .invalidRequest: + // Expected - file content not supported + break + default: + Issue.record("Expected invalidRequest error, got: \(error)") + } + } + } + + @Test("Audio content handling") + func audioContentHandling() async throws { + let model = OpenAIModel(apiKey: "test-key") + + let audioContent = AudioContent( + transcript: "Hello, this is a test transcript.", + duration: 5.0 + ) + + let request = ModelRequest( + messages: [ + Message.user(content: .audio(audioContent)), + ], + settings: ModelSettings(modelName: "gpt-4.1")) + + // Test audio content processing + do { + _ = try await model.getResponse(request: request) + Issue.record("Expected network error") + } catch { + #expect(error is TachikomaError) + } + } +} + +// MARK: - Provider Configuration Tests + +@Suite("OpenAI Provider Configuration Tests") +struct OpenAIProviderConfigurationTests { + @Test("Provider configuration") + func providerConfiguration() async throws { + let tachikoma = Tachikoma.shared + + // Test custom OpenAI configuration + let config = ProviderConfiguration.OpenAI( + apiKey: "test-key", + baseURL: URL(string: "https://api.openai.com/v1")! + ) + + await tachikoma.configureOpenAI(config) + } + + @Test("Model registration") + func modelRegistration() async throws { + let tachikoma = Tachikoma.shared + + // Register OpenAI models + let modelNames = [ + "o3", + "o3-mini", + "gpt-4.1", + "gpt-4o", + ] + + for modelName in modelNames { + await tachikoma.registerModel(name: modelName, factory: { + OpenAIModel(apiKey: "test-key", modelName: modelName) + }) + + do { + let model = try await tachikoma.getModel(modelName) + #expect(model is OpenAIModel) + } catch { + Issue.record("Failed to get model \(modelName): \(error)") + } + } + } + + @Test("API type selection") + func apiTypeSelection() async throws { + let tachikoma = Tachikoma.shared + + // Register models with different API preferences + await tachikoma.registerModel(name: "gpt-4.1-chat", factory: { + OpenAIModel(apiKey: "test-key", modelName: "gpt-4.1") + }) + + await tachikoma.registerModel(name: "o3-responses", factory: { + OpenAIModel(apiKey: "test-key", modelName: "o3") + }) + + // Test that both can be retrieved + do { + let chatModel = try await tachikoma.getModel("gpt-4.1-chat") + #expect(chatModel is OpenAIModel) + + let responsesModel = try await tachikoma.getModel("o3-responses") + #expect(responsesModel is OpenAIModel) + } catch { + Issue.record("Failed to get models: \(error)") + } + } + + @Test("Lenient model name resolution") + func lenientModelNameResolution() async throws { + let tachikoma = Tachikoma.shared + + // Register base models + await tachikoma.registerModel(name: "gpt-4.1", factory: { + OpenAIModel(apiKey: "test-key", modelName: "gpt-4.1") + }) + + await tachikoma.registerModel(name: "o3", factory: { + OpenAIModel(apiKey: "test-key", modelName: "o3") + }) + + // Test lenient name matching + let nameMapping = [ + "gpt": "gpt-4.1", + "gpt-4": "gpt-4.1", + "gpt4": "gpt-4.1", + "o3": "o3", + ] + + for (input, _) in nameMapping { + do { + let model = try await tachikoma.getModel(input) + #expect(model is OpenAIModel) + } catch { + Issue.record("Failed to resolve \(input): \(error)") + } + } + } +} \ No newline at end of file diff --git a/Tests/TachikomaTests/StreamingAndToolTests.swift b/Tests/TachikomaTests/StreamingAndToolTests.swift new file mode 100644 index 0000000..b40458d --- /dev/null +++ b/Tests/TachikomaTests/StreamingAndToolTests.swift @@ -0,0 +1,439 @@ +import Foundation +import Testing +@testable import Tachikoma + +@Suite("Streaming Types Tests") +struct StreamingTypesTests { + @Test("StreamEvent creation and types") + func streamEventCreationAndTypes() { + let events: [StreamEvent] = [ + .contentDelta("Hello"), + .contentComplete("Hello, world!"), + .toolCallDelta("tool_123", "get_", "{}"), + .toolCallComplete("tool_123", "get_weather", "{\"location\": \"SF\"}"), + .reasoningDelta("I need to think..."), + .reasoningComplete("I need to think about this carefully."), + .done, + .error("Network error"), + .metadata(["usage": ["tokens": 42]]), + ] + + #expect(events.count == 9) + + // Test specific event properties + if case let .contentDelta(text) = events[0] { + #expect(text == "Hello") + } else { + Issue.record("Expected contentDelta event") + } + + if case let .toolCallDelta(id, name, args) = events[2] { + #expect(id == "tool_123") + #expect(name == "get_") + #expect(args == "{}") + } else { + Issue.record("Expected toolCallDelta event") + } + + if case .done = events[6] { + // Expected + } else { + Issue.record("Expected done event") + } + } + + @Test("StreamEvent Codable") + func streamEventCodable() throws { + let originalEvents: [StreamEvent] = [ + .contentDelta("Hello"), + .toolCallComplete("tool_123", "get_weather", "{\"location\": \"SF\"}"), + .done, + .error("Network error"), + .metadata(["usage": ["tokens": 42]]), + ] + + let encoder = JSONEncoder() + let decoder = JSONDecoder() + + for originalEvent in originalEvents { + let data = try encoder.encode(originalEvent) + let decodedEvent = try decoder.decode(StreamEvent.self, from: data) + + // Compare events (simplified comparison) + switch (originalEvent, decodedEvent) { + case let (.contentDelta(orig), .contentDelta(decoded)): + #expect(orig == decoded) + case let (.toolCallComplete(origId, origName, origArgs), .toolCallComplete(decodedId, decodedName, decodedArgs)): + #expect(origId == decodedId) + #expect(origName == decodedName) + #expect(origArgs == decodedArgs) + case (.done, .done): + // Expected + break + case let (.error(orig), .error(decoded)): + #expect(orig == decoded) + case (.metadata, .metadata): + // Metadata comparison is complex, just verify it decodes + break + default: + Issue.record("Event types don't match") + } + } + } + + @Test("StreamingResponseIterator behavior") + func streamingResponseIteratorBehavior() async throws { + // Create a mock stream of events + let events: [StreamEvent] = [ + .contentDelta("Hello"), + .contentDelta(" world"), + .contentComplete("Hello world!"), + .done + ] + + // Create an AsyncStream to simulate streaming + let stream = AsyncThrowingStream { continuation in + Task { + for event in events { + continuation.yield(event) + try await Task.sleep(nanoseconds: 1_000_000) // 1ms delay + } + continuation.finish() + } + } + + // Collect events from the stream + var collectedEvents: [StreamEvent] = [] + + do { + for try await event in stream { + collectedEvents.append(event) + } + } catch { + Issue.record("Stream iteration failed: \(error)") + } + + #expect(collectedEvents.count == 4) + + // Verify event sequence + if case let .contentDelta(text1) = collectedEvents[0] { + #expect(text1 == "Hello") + } else { + Issue.record("Expected first contentDelta") + } + + if case let .contentDelta(text2) = collectedEvents[1] { + #expect(text2 == " world") + } else { + Issue.record("Expected second contentDelta") + } + + if case .done = collectedEvents[3] { + // Expected + } else { + Issue.record("Expected done event at end") + } + } + + @Test("StreamEvent error handling") + func streamEventErrorHandling() async throws { + let stream = AsyncThrowingStream { continuation in + continuation.yield(.contentDelta("Hello")) + continuation.finish(throwing: TachikomaError.networkError("Connection lost")) + } + + var collectedEvents: [StreamEvent] = [] + var caughtError: Error? + + do { + for try await event in stream { + collectedEvents.append(event) + } + } catch { + caughtError = error + } + + #expect(collectedEvents.count == 1) + #expect(caughtError is TachikomaError) + + if let tachikomaError = caughtError as? TachikomaError, + case let .networkError(message) = tachikomaError { + #expect(message == "Connection lost") + } else { + Issue.record("Expected TachikomaError.networkError") + } + } +} + +// MARK: - Tool Definition Tests + +@Suite("Tool Definition Tests") +struct ToolDefinitionTests { + @Test("Simple tool definition") + func simpleToolDefinition() { + let tool = ToolDefinition( + function: FunctionDefinition( + name: "get_time", + description: "Get the current time", + parameters: ToolParameters( + properties: [:], + required: [] + ) + ) + ) + + #expect(tool.function.name == "get_time") + #expect(tool.function.description == "Get the current time") + #expect(tool.function.parameters.properties.isEmpty) + #expect(tool.function.parameters.required.isEmpty) + } + + @Test("Tool definition with parameters") + func toolDefinitionWithParameters() { + let tool = ToolDefinition( + function: FunctionDefinition( + name: "get_weather", + description: "Get weather information", + parameters: ToolParameters( + type: "object", + properties: [ + "location": ParameterSchema( + type: .string, + description: "The location to get weather for" + ), + "units": ParameterSchema( + type: .string, + description: "Temperature units", + enumValues: ["celsius", "fahrenheit"] + ), + "include_forecast": ParameterSchema( + type: .boolean, + description: "Include forecast data" + ) + ], + required: ["location"] + ) + ) + ) + + #expect(tool.function.name == "get_weather") + #expect(tool.function.parameters.properties.count == 3) + #expect(tool.function.parameters.required == ["location"]) + + // Check parameter schemas + let locationParam = tool.function.parameters.properties["location"] + #expect(locationParam?.type == .string) + #expect(locationParam?.description == "The location to get weather for") + + let unitsParam = tool.function.parameters.properties["units"] + #expect(unitsParam?.enumValues == ["celsius", "fahrenheit"]) + } + + @Test("ParameterSchema types") + func parameterSchemaTypes() { + let schemas: [ParameterSchema] = [ + ParameterSchema(type: .string, description: "A string value"), + ParameterSchema(type: .integer, description: "An integer value", minimum: 0, maximum: 100), + ParameterSchema(type: .number, description: "A number value"), + ParameterSchema(type: .boolean, description: "A boolean value"), + ParameterSchema(type: .array, description: "An array value"), + ParameterSchema(type: .object, description: "An object value"), + ] + + #expect(schemas.count == 6) + + // Test specific properties + let intSchema = schemas[1] + #expect(intSchema.type == .integer) + #expect(intSchema.minimum == 0) + #expect(intSchema.maximum == 100) + } + + @Test("Tool parameters codable") + func toolParametersCodable() throws { + let original = ToolParameters( + type: "object", + properties: [ + "name": ParameterSchema(type: .string, description: "Name field"), + "age": ParameterSchema(type: .integer, description: "Age field", minimum: 0), + ], + required: ["name"] + ) + + let encoder = JSONEncoder() + let data = try encoder.encode(original) + + let decoder = JSONDecoder() + let decoded = try decoder.decode(ToolParameters.self, from: data) + + #expect(decoded.type == original.type) + #expect(decoded.properties.count == original.properties.count) + #expect(decoded.required == original.required) + + let nameParam = decoded.properties["name"] + #expect(nameParam?.type == .string) + #expect(nameParam?.description == "Name field") + } + + @Test("Generic Tool creation") + func genericToolCreation() async throws { + // Context type for the tool + struct WeatherContext { + let apiKey: String + } + + // Create a generic tool + let weatherTool = Tool( + name: "get_weather", + description: "Get weather for a location", + parameters: ToolParameters( + properties: [ + "location": ParameterSchema(type: .string, description: "Location name") + ], + required: ["location"] + ) + ) { input, context in + // Simulate tool execution + var location = "Unknown" + if case let .dictionary(dict) = input { + location = dict["location"] as? String ?? "Unknown" + } + return .string("Weather in \(location): 72°F and sunny (API key: \(context.apiKey.prefix(3))...)") + } + + // Convert to tool definition + let toolDef = weatherTool.toToolDefinition() + + #expect(toolDef.function.name == "get_weather") + #expect(toolDef.function.description == "Get weather for a location") + + // Test tool execution + let context = WeatherContext(apiKey: "test-api-key-123") + let input = ToolInput.dictionary(["location": "San Francisco"]) + + let output = try await weatherTool.execute(input, context) + + if case let .string(result) = output { + #expect(result.contains("San Francisco")) + #expect(result.contains("test...")) + } else { + Issue.record("Expected string output") + } + } + + @Test("ToolCallItem creation") + func toolCallItemCreation() { + let toolCall = ToolCallItem( + id: "call_abc123", + type: .function, + function: FunctionCall( + name: "calculate", + arguments: "{\"expression\": \"2 + 2\"}" + ) + ) + + #expect(toolCall.id == "call_abc123") + #expect(toolCall.type == .function) + #expect(toolCall.function.name == "calculate") + #expect(toolCall.function.arguments == "{\"expression\": \"2 + 2\"}") + } + + @Test("ToolInput and ToolOutput") + func toolInputAndOutput() { + let input = ToolInput.dictionary([ + "location": "Paris", + "units": "celsius" + ]) + + if case let .dictionary(args) = input { + #expect(args["location"] as? String == "Paris") + #expect(args["units"] as? String == "celsius") + } else { + Issue.record("Expected dictionary input") + } + + let output = ToolOutput.string("Weather in Paris: 18°C and cloudy") + + if case let .string(content) = output { + #expect(content == "Weather in Paris: 18°C and cloudy") + } else { + Issue.record("Expected string output") + } + } + + @Test("FunctionCall argument parsing") + func functionCallArgumentParsing() throws { + let functionCall = FunctionCall( + name: "get_weather", + arguments: "{\"location\": \"Tokyo\", \"units\": \"celsius\"}" + ) + + #expect(functionCall.name == "get_weather") + #expect(functionCall.arguments == "{\"location\": \"Tokyo\", \"units\": \"celsius\"}") + + // Test parsing arguments as JSON + let data = functionCall.arguments.data(using: .utf8)! + let parsed = try JSONSerialization.jsonObject(with: data) as? [String: Any] + + #expect(parsed?["location"] as? String == "Tokyo") + #expect(parsed?["units"] as? String == "celsius") + } + + @Test("ParameterSchema with complex properties") + func parameterSchemaWithComplexProperties() { + let schema = ParameterSchema( + type: .object, + description: "A complex object", + properties: [ + "nested_string": ParameterSchema(type: .string, description: "Nested string"), + "nested_array": ParameterSchema(type: .array, description: "Nested array"), + ] + ) + + #expect(schema.type == .object) + #expect(schema.properties?.count == 2) + + let nestedString = schema.properties?["nested_string"] as? ParameterSchema + #expect(nestedString?.type == .string) + #expect(nestedString?.description == "Nested string") + } + + @Test("Tool definition JSON serialization") + func toolDefinitionJSONSerialization() throws { + let tool = ToolDefinition( + function: FunctionDefinition( + name: "calculate", + description: "Perform a calculation", + parameters: ToolParameters( + type: "object", + properties: [ + "expression": ParameterSchema( + type: .string, + description: "Mathematical expression to evaluate" + ) + ], + required: ["expression"] + ) + ) + ) + + let encoder = JSONEncoder() + encoder.outputFormatting = .prettyPrinted + + let jsonData = try encoder.encode(tool) + let jsonString = String(data: jsonData, encoding: .utf8)! + + #expect(jsonString.contains("calculate")) + #expect(jsonString.contains("Perform a calculation")) + #expect(jsonString.contains("expression")) + #expect(jsonString.contains("Mathematical expression")) + + // Test round-trip encoding/decoding + let decoder = JSONDecoder() + let decoded = try decoder.decode(ToolDefinition.self, from: jsonData) + + #expect(decoded.function.name == tool.function.name) + #expect(decoded.function.description == tool.function.description) + #expect(decoded.function.parameters.required == tool.function.parameters.required) + } +} \ No newline at end of file diff --git a/Tests/TachikomaTests/TachikomaCoreTests.swift b/Tests/TachikomaTests/TachikomaCoreTests.swift new file mode 100644 index 0000000..06cd39f --- /dev/null +++ b/Tests/TachikomaTests/TachikomaCoreTests.swift @@ -0,0 +1,507 @@ +import Foundation +import Testing +@testable import Tachikoma + +@Suite("Tachikoma Core Tests") +struct TachikomaCoreTests { + @Test("Tachikoma singleton initialization") + func singletonInitialization() async throws { + let tachikoma1 = Tachikoma.shared + let tachikoma2 = Tachikoma.shared + + // Should be the same instance + #expect(tachikoma1 === tachikoma2) + } + + @Test("Model registration and retrieval") + func modelRegistrationAndRetrieval() async throws { + let tachikoma = Tachikoma.shared + + // Register a test model + await tachikoma.registerModel(name: "test-model", factory: { + OpenAIModel(apiKey: "test-key", modelName: "gpt-4.1") + }) + + // Should be able to retrieve it + do { + let model = try await tachikoma.getModel("test-model") + #expect(model is OpenAIModel) + } catch { + Issue.record("Failed to retrieve registered model: \(error)") + } + } + + @Test("Model not found error") + func modelNotFoundError() async throws { + let tachikoma = Tachikoma.shared + + // Should throw error for non-existent model + do { + _ = try await tachikoma.getModel("nonexistent-model") + Issue.record("Expected error for nonexistent model") + } catch let error as TachikomaError { + switch error { + case .modelNotFound: + // Expected + break + default: + Issue.record("Expected modelNotFound error, got: \(error)") + } + } + } + + @Test("Provider configuration") + func providerConfiguration() async throws { + let tachikoma = Tachikoma.shared + + // Test OpenAI configuration + let openaiConfig = ProviderConfiguration.OpenAI( + apiKey: "test-key", + baseURL: URL(string: "https://api.openai.com/v1")! + ) + await tachikoma.configureOpenAI(openaiConfig) + + // Test Anthropic configuration + let anthropicConfig = ProviderConfiguration.Anthropic( + apiKey: "test-key", + baseURL: URL(string: "https://api.anthropic.com/v1")! + ) + await tachikoma.configureAnthropic(anthropicConfig) + + // Test Grok configuration + let grokConfig = ProviderConfiguration.Grok( + apiKey: "test-key", + baseURL: URL(string: "https://api.x.ai/v1")! + ) + await tachikoma.configureGrok(grokConfig) + + // Test Ollama configuration + let ollamaConfig = ProviderConfiguration.Ollama( + baseURL: URL(string: "http://localhost:11434")! + ) + await tachikoma.configureOllama(ollamaConfig) + } + + @Test("Model factory functions") + func modelFactoryFunctions() async throws { + let tachikoma = Tachikoma.shared + + // Test that factory functions work correctly + await tachikoma.registerModel(name: "factory-openai", factory: { + OpenAIModel(apiKey: "test-key", modelName: "gpt-4.1") + }) + + await tachikoma.registerModel(name: "factory-anthropic", factory: { + AnthropicModel(apiKey: "test-key", modelName: "claude-opus-4-20250514") + }) + + await tachikoma.registerModel(name: "factory-grok", factory: { + GrokModel(apiKey: "test-key", modelName: "grok-4") + }) + + await tachikoma.registerModel(name: "factory-ollama", factory: { + OllamaModel(modelName: "llama3.3") + }) + + // Test retrieval + let modelNames = ["factory-openai", "factory-anthropic", "factory-grok", "factory-ollama"] + + for modelName in modelNames { + do { + let model = try await tachikoma.getModel(modelName) + #expect(model is any ModelInterface) + } catch { + Issue.record("Failed to get \(modelName): \(error)") + } + } + } + + @Test("Concurrent model access") + func concurrentModelAccess() async throws { + let tachikoma = Tachikoma.shared + + // Register a model + await tachikoma.registerModel(name: "concurrent-test", factory: { + OpenAIModel(apiKey: "test-key", modelName: "gpt-4.1") + }) + + // Access it concurrently + await withTaskGroup(of: Void.self) { group in + for i in 0..<10 { + group.addTask { + do { + let model = try await tachikoma.getModel("concurrent-test") + #expect(model is OpenAIModel) + } catch { + Issue.record("Concurrent access failed for iteration \(i): \(error)") + } + } + } + } + } +} + +// MARK: - Message Type Tests + +@Suite("Message Type Tests") +struct MessageTypeTests { + @Test("System message creation") + func systemMessageCreation() { + let message = Message.system(content: "You are a helpful assistant.") + + if case let .system(id, content) = message { + #expect(id == nil) + #expect(content == "You are a helpful assistant.") + } else { + Issue.record("Expected system message") + } + } + + @Test("User message with text content") + func userMessageWithTextContent() { + let message = Message.user(content: .text("Hello, AI!")) + + if case let .user(id, content) = message { + #expect(id == nil) + if case let .text(text) = content { + #expect(text == "Hello, AI!") + } else { + Issue.record("Expected text content") + } + } else { + Issue.record("Expected user message") + } + } + + @Test("User message with multimodal content") + func userMessageWithMultimodalContent() { + let imageData = Data([0xFF, 0xD8, 0xFF]) + let message = Message.user(content: .multimodal([ + .text("What's in this image?"), + .imageUrl(ImageUrl(base64: imageData.base64EncodedString())) + ])) + + if case let .user(_, content) = message { + if case let .multimodal(parts) = content { + #expect(parts.count == 2) + + if case let .text(text) = parts[0] { + #expect(text == "What's in this image?") + } else { + Issue.record("Expected text part at index 0") + } + + if case .imageUrl = parts[1] { + // Expected image part + } else { + Issue.record("Expected image part at index 1") + } + } else { + Issue.record("Expected multimodal content") + } + } else { + Issue.record("Expected user message") + } + } + + @Test("Assistant message creation") + func assistantMessageCreation() { + let message = Message.assistant(content: [ + .outputText("Hello! How can I help you?") + ]) + + if case let .assistant(id, content, status) = message { + #expect(id == nil) + #expect(status == .completed) + #expect(content.count == 1) + + if case let .outputText(text) = content[0] { + #expect(text == "Hello! How can I help you?") + } else { + Issue.record("Expected output text content") + } + } else { + Issue.record("Expected assistant message") + } + } + + @Test("Tool call message") + func toolCallMessage() { + let toolCall = ToolCallItem( + id: "call_123", + type: .function, + function: FunctionCall( + name: "get_weather", + arguments: "{\"location\": \"San Francisco\"}" + ) + ) + + let message = Message.assistant(content: [.toolCall(toolCall)]) + + if case let .assistant(_, content, _) = message { + if case let .toolCall(call) = content[0] { + #expect(call.id == "call_123") + #expect(call.function.name == "get_weather") + #expect(call.function.arguments == "{\"location\": \"San Francisco\"}") + } else { + Issue.record("Expected tool call content") + } + } else { + Issue.record("Expected assistant message") + } + } + + @Test("Tool result message") + func toolResultMessage() { + let message = Message.tool( + toolCallId: "call_123", + content: "The weather in San Francisco is 72°F and sunny." + ) + + if case let .tool(id, toolCallId, content) = message { + #expect(id == nil) + #expect(toolCallId == "call_123") + #expect(content == "The weather in San Francisco is 72°F and sunny.") + } else { + Issue.record("Expected tool message") + } + } + + @Test("Reasoning message") + func reasoningMessage() { + let message = Message.reasoning(content: "Let me think about this step by step...") + + if case let .reasoning(id, content) = message { + #expect(id == nil) + #expect(content == "Let me think about this step by step...") + } else { + Issue.record("Expected reasoning message") + } + } + + @Test("Message with custom ID") + func messageWithCustomID() { + let message = Message.user(id: "custom-123", content: .text("Hello")) + + if case let .user(id, _) = message { + #expect(id == "custom-123") + } else { + Issue.record("Expected user message with custom ID") + } + } + + @Test("Message type property") + func messageTypeProperty() { + let messages: [Message] = [ + .system(content: "System"), + .user(content: .text("User")), + .assistant(content: [.outputText("Assistant")]), + .tool(toolCallId: "call", content: "Tool"), + .reasoning(content: "Reasoning") + ] + + let expectedTypes: [MessageType] = [.system, .user, .assistant, .tool, .reasoning] + + for (message, expectedType) in zip(messages, expectedTypes) { + #expect(message.type == expectedType) + } + } +} + +// MARK: - MessageContent Tests + +@Suite("MessageContent Tests") +struct MessageContentTests { + @Test("Text content") + func textContent() { + let content = MessageContent.text("Hello, world!") + + if case let .text(text) = content { + #expect(text == "Hello, world!") + } else { + Issue.record("Expected text content") + } + } + + @Test("Image content with base64") + func imageContentWithBase64() { + let imageData = Data([0xFF, 0xD8, 0xFF]) + let imageUrl = ImageUrl(base64: imageData.base64EncodedString()) + let content = MessageContent.image(imageUrl) + + if case let .image(url) = content { + #expect(url.base64 == imageData.base64EncodedString()) + #expect(url.url == nil) + } else { + Issue.record("Expected image content") + } + } + + @Test("Image content with URL") + func imageContentWithURL() { + let imageUrl = ImageUrl(url: "https://example.com/image.jpg") + let content = MessageContent.image(imageUrl) + + if case let .image(url) = content { + #expect(url.url == "https://example.com/image.jpg") + #expect(url.base64 == nil) + } else { + Issue.record("Expected image content") + } + } + + @Test("Audio content") + func audioContent() { + let audioData = AudioContent( + transcript: "Hello, this is a test.", + duration: 5.0 + ) + let content = MessageContent.audio(audioData) + + if case let .audio(audio) = content { + #expect(audio.transcript == "Hello, this is a test.") + #expect(audio.duration == 5.0) + } else { + Issue.record("Expected audio content") + } + } + + @Test("File content") + func fileContent() { + let fileData = FileContent( + filename: "test.txt", + content: "File content here", + mimeType: "text/plain" + ) + let content = MessageContent.file(fileData) + + if case let .file(file) = content { + #expect(file.filename == "test.txt") + #expect(file.content == "File content here") + #expect(file.mimeType == "text/plain") + } else { + Issue.record("Expected file content") + } + } + + @Test("Multimodal content") + func multimodalContent() { + let parts: [MessageContentPart] = [ + .text("Describe this image:"), + .imageUrl(ImageUrl(url: "https://example.com/image.jpg")), + .text("What do you see?") + ] + let content = MessageContent.multimodal(parts) + + if case let .multimodal(contentParts) = content { + #expect(contentParts.count == 3) + + #expect(contentParts[0].type == "text") + #expect(contentParts[0].text == "Describe this image:") + + #expect(contentParts[1].type == "image") + #expect(contentParts[1].imageUrl != nil) + + #expect(contentParts[2].type == "text") + #expect(contentParts[2].text == "What do you see?") + } else { + Issue.record("Expected multimodal content") + } + } +} + +// MARK: - Error Handling Tests + +@Suite("Error Handling Tests") +struct ErrorHandlingTests { + @Test("TachikomaError cases") + func tachikomaErrorCases() { + let errors: [TachikomaError] = [ + .modelNotFound("test-model"), + .invalidRequest("Invalid parameters"), + .authenticationFailed, + .apiError(message: "Rate limit exceeded", code: "429"), + .networkError(underlying: URLError(.notConnectedToInternet)), + .configurationError("Missing configuration"), + .streamingError("Stream interrupted"), + ] + + // Verify all error cases can be created + #expect(errors.count == 9) + + // Test error descriptions + for error in errors { + let description = error.localizedDescription + #expect(!description.isEmpty) + } + } + + @Test("Error equality") + func errorEquality() { + let error1 = TachikomaError.modelNotFound("test") + let error2 = TachikomaError.modelNotFound("test") + let error3 = TachikomaError.modelNotFound("different") + + #expect(error1.localizedDescription == error2.localizedDescription) + #expect(error1.localizedDescription != error3.localizedDescription) + } +} + +// MARK: - Model Settings Tests + +@Suite("Model Settings Tests") +struct ModelSettingsTests { + @Test("Default model settings") + func defaultModelSettings() { + let settings = ModelSettings(modelName: "test-model") + + #expect(settings.modelName == "test-model") + #expect(settings.temperature == nil) + #expect(settings.maxTokens == nil) + #expect(settings.topP == nil) + #expect(settings.frequencyPenalty == nil) + #expect(settings.presencePenalty == nil) + #expect(settings.stopSequences == nil) + #expect(settings.seed == nil) + #expect(settings.toolChoice == nil) + #expect(settings.parallelToolCalls == nil) + #expect(settings.additionalParameters == nil) + } + + @Test("Custom model settings") + func customModelSettings() { + var additionalParams = ModelParameters() + additionalParams.set("apiType", value: "chat") + additionalParams.set("reasoningEffort", value: "medium") + additionalParams.set("reasoning", value: ["summary": "detailed"]) + additionalParams.set("logprobs", value: true) + additionalParams.set("topLogprobs", value: 5) + + let settings = ModelSettings( + modelName: "test-model", + temperature: 0.7, + topP: 0.9, + maxTokens: 1000, + frequencyPenalty: 0.1, + presencePenalty: 0.2, + stopSequences: ["STOP"], + toolChoice: .auto, + parallelToolCalls: true, + seed: 42, + additionalParameters: additionalParams + ) + + #expect(settings.modelName == "test-model") + #expect(settings.temperature == 0.7) + #expect(settings.maxTokens == 1000) + #expect(settings.topP == 0.9) + #expect(settings.frequencyPenalty == 0.1) + #expect(settings.presencePenalty == 0.2) + #expect(settings.stopSequences == ["STOP"]) + #expect(settings.seed == 42) + #expect(settings.toolChoice == .auto) + #expect(settings.parallelToolCalls == true) + #expect(settings.additionalParameters?.get("apiType") as? String == "chat") + #expect(settings.additionalParameters?.get("reasoningEffort") as? String == "medium") + } +} \ No newline at end of file diff --git a/assets/logo.png b/assets/logo.png new file mode 100644 index 0000000..6681c8e Binary files /dev/null and b/assets/logo.png differ