diff --git a/app/api/chat/route.ts b/app/api/chat/route.ts index 6a68603..078cdd3 100644 --- a/app/api/chat/route.ts +++ b/app/api/chat/route.ts @@ -1,11 +1,75 @@ import { NextRequest } from 'next/server'; import { Message } from 'ollama'; import ollama from '@/lib/ollama'; -import { allTools, executeTool } from '@/lib/tools'; +import { allTools, executeTool, getRegisteredTools } from '@/lib/tools'; // Maximum number of tool call iterations to prevent infinite loops const MAX_TOOL_ITERATIONS = 10; +/** + * Parse text-based tool calls from model output + * Supports formats like: + * - tool_name[ARGS]{"arg": "value"} + * - {"name": "tool_name", "arguments": {...}} + * - {"tool": "tool_name", "arguments": {...}} + */ +function parseTextToolCalls( + content: string +): Array<{ name: string; arguments: Record }> { + const toolCalls: Array<{ name: string; arguments: Record }> = []; + const registeredTools = getRegisteredTools(); + + // Pattern 1: tool_name[ARGS]{...} + const argsPattern = /(\w+)\[ARGS\](\{[\s\S]*?\})/g; + let match; + while ((match = argsPattern.exec(content)) !== null) { + const toolName = match[1]; + try { + const args = JSON.parse(match[2]); + if (registeredTools.includes(toolName)) { + toolCalls.push({ name: toolName, arguments: args }); + } + } catch { + // Invalid JSON, skip + } + } + + // Pattern 2: {"name": "...", "arguments": {...}} + const toolCallTagPattern = /([\s\S]*?)<\/tool_call>/g; + while ((match = toolCallTagPattern.exec(content)) !== null) { + try { + const parsed = JSON.parse(match[1]); + if (parsed.name && registeredTools.includes(parsed.name)) { + toolCalls.push({ + name: parsed.name, + arguments: parsed.arguments || {}, + }); + } + } catch { + // Invalid JSON, skip + } + } + + // Pattern 3: Standalone JSON with tool field + const jsonToolPattern = /\{[\s\S]*?"(?:tool|function)"[\s\S]*?\}/g; + while ((match = jsonToolPattern.exec(content)) !== null) { + try { + const parsed = JSON.parse(match[0]); + const toolName = parsed.tool || parsed.function || parsed.name; + if (toolName && registeredTools.includes(toolName)) { + toolCalls.push({ + name: toolName, + arguments: parsed.arguments || parsed.params || parsed.parameters || {}, + }); + } + } catch { + // Invalid JSON, skip + } + } + + return toolCalls; +} + export async function POST(request: NextRequest) { try { const { model, messages, enableTools = true } = await request.json(); @@ -50,7 +114,7 @@ export async function POST(request: NextRequest) { controller.enqueue(encoder.encode(text)); } - // Check for tool calls in the final chunk + // Check for native tool calls in the final chunk if (chunk.message?.tool_calls) { toolCalls = chunk.message.tool_calls.map((tc) => ({ name: tc.function.name, @@ -59,38 +123,36 @@ export async function POST(request: NextRequest) { } } + // If no native tool calls, try to parse text-based tool calls + if (toolCalls.length === 0 && enableTools) { + toolCalls = parseTextToolCalls(fullContent); + } + // If no tool calls, we're done if (toolCalls.length === 0) { break; } // Process tool calls - // Send a marker so frontend knows tool calls are happening - controller.enqueue(encoder.encode('\n\n---TOOL_CALLS---\n')); + controller.enqueue(encoder.encode('\n\n')); - // Add the assistant's response with tool calls to messages + // Add the assistant's response to messages workingMessages.push({ role: 'assistant', content: fullContent, - tool_calls: toolCalls.map((tc) => ({ - function: { - name: tc.name, - arguments: tc.arguments, - }, - })), }); // Execute each tool and collect results for (const toolCall of toolCalls) { - controller.enqueue(encoder.encode(`\n**Using tool: ${toolCall.name}**\n`)); + controller.enqueue(encoder.encode(`**Using tool: ${toolCall.name}**\n`)); const result = await executeTool(toolCall.name, toolCall.arguments); // Send tool result to stream if (result.success) { - controller.enqueue(encoder.encode(`\`\`\`\n${result.result}\n\`\`\`\n`)); + controller.enqueue(encoder.encode(`\`\`\`\n${result.result}\n\`\`\`\n\n`)); } else { - controller.enqueue(encoder.encode(`Error: ${result.error}\n`)); + controller.enqueue(encoder.encode(`Error: ${result.error}\n\n`)); } // Add tool result to messages for next iteration @@ -100,7 +162,7 @@ export async function POST(request: NextRequest) { }); } - controller.enqueue(encoder.encode('\n---END_TOOL_CALLS---\n\n')); + // Continue to let the model respond with the tool results } controller.close();