changes
This commit is contained in:
+77
-15
@@ -1,11 +1,75 @@
|
|||||||
import { NextRequest } from 'next/server';
|
import { NextRequest } from 'next/server';
|
||||||
import { Message } from 'ollama';
|
import { Message } from 'ollama';
|
||||||
import ollama from '@/lib/ollama';
|
import ollama from '@/lib/ollama';
|
||||||
import { allTools, executeTool } from '@/lib/tools';
|
import { allTools, executeTool, getRegisteredTools } from '@/lib/tools';
|
||||||
|
|
||||||
// Maximum number of tool call iterations to prevent infinite loops
|
// Maximum number of tool call iterations to prevent infinite loops
|
||||||
const MAX_TOOL_ITERATIONS = 10;
|
const MAX_TOOL_ITERATIONS = 10;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse text-based tool calls from model output
|
||||||
|
* Supports formats like:
|
||||||
|
* - tool_name[ARGS]{"arg": "value"}
|
||||||
|
* - <tool_call>{"name": "tool_name", "arguments": {...}}</tool_call>
|
||||||
|
* - {"tool": "tool_name", "arguments": {...}}
|
||||||
|
*/
|
||||||
|
function parseTextToolCalls(
|
||||||
|
content: string
|
||||||
|
): Array<{ name: string; arguments: Record<string, unknown> }> {
|
||||||
|
const toolCalls: Array<{ name: string; arguments: Record<string, unknown> }> = [];
|
||||||
|
const registeredTools = getRegisteredTools();
|
||||||
|
|
||||||
|
// Pattern 1: tool_name[ARGS]{...}
|
||||||
|
const argsPattern = /(\w+)\[ARGS\](\{[\s\S]*?\})/g;
|
||||||
|
let match;
|
||||||
|
while ((match = argsPattern.exec(content)) !== null) {
|
||||||
|
const toolName = match[1];
|
||||||
|
try {
|
||||||
|
const args = JSON.parse(match[2]);
|
||||||
|
if (registeredTools.includes(toolName)) {
|
||||||
|
toolCalls.push({ name: toolName, arguments: args });
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Invalid JSON, skip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern 2: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
|
||||||
|
const toolCallTagPattern = /<tool_call>([\s\S]*?)<\/tool_call>/g;
|
||||||
|
while ((match = toolCallTagPattern.exec(content)) !== null) {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(match[1]);
|
||||||
|
if (parsed.name && registeredTools.includes(parsed.name)) {
|
||||||
|
toolCalls.push({
|
||||||
|
name: parsed.name,
|
||||||
|
arguments: parsed.arguments || {},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Invalid JSON, skip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern 3: Standalone JSON with tool field
|
||||||
|
const jsonToolPattern = /\{[\s\S]*?"(?:tool|function)"[\s\S]*?\}/g;
|
||||||
|
while ((match = jsonToolPattern.exec(content)) !== null) {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(match[0]);
|
||||||
|
const toolName = parsed.tool || parsed.function || parsed.name;
|
||||||
|
if (toolName && registeredTools.includes(toolName)) {
|
||||||
|
toolCalls.push({
|
||||||
|
name: toolName,
|
||||||
|
arguments: parsed.arguments || parsed.params || parsed.parameters || {},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Invalid JSON, skip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return toolCalls;
|
||||||
|
}
|
||||||
|
|
||||||
export async function POST(request: NextRequest) {
|
export async function POST(request: NextRequest) {
|
||||||
try {
|
try {
|
||||||
const { model, messages, enableTools = true } = await request.json();
|
const { model, messages, enableTools = true } = await request.json();
|
||||||
@@ -50,7 +114,7 @@ export async function POST(request: NextRequest) {
|
|||||||
controller.enqueue(encoder.encode(text));
|
controller.enqueue(encoder.encode(text));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for tool calls in the final chunk
|
// Check for native tool calls in the final chunk
|
||||||
if (chunk.message?.tool_calls) {
|
if (chunk.message?.tool_calls) {
|
||||||
toolCalls = chunk.message.tool_calls.map((tc) => ({
|
toolCalls = chunk.message.tool_calls.map((tc) => ({
|
||||||
name: tc.function.name,
|
name: tc.function.name,
|
||||||
@@ -59,38 +123,36 @@ export async function POST(request: NextRequest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If no native tool calls, try to parse text-based tool calls
|
||||||
|
if (toolCalls.length === 0 && enableTools) {
|
||||||
|
toolCalls = parseTextToolCalls(fullContent);
|
||||||
|
}
|
||||||
|
|
||||||
// If no tool calls, we're done
|
// If no tool calls, we're done
|
||||||
if (toolCalls.length === 0) {
|
if (toolCalls.length === 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process tool calls
|
// Process tool calls
|
||||||
// Send a marker so frontend knows tool calls are happening
|
controller.enqueue(encoder.encode('\n\n'));
|
||||||
controller.enqueue(encoder.encode('\n\n---TOOL_CALLS---\n'));
|
|
||||||
|
|
||||||
// Add the assistant's response with tool calls to messages
|
// Add the assistant's response to messages
|
||||||
workingMessages.push({
|
workingMessages.push({
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: fullContent,
|
content: fullContent,
|
||||||
tool_calls: toolCalls.map((tc) => ({
|
|
||||||
function: {
|
|
||||||
name: tc.name,
|
|
||||||
arguments: tc.arguments,
|
|
||||||
},
|
|
||||||
})),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Execute each tool and collect results
|
// Execute each tool and collect results
|
||||||
for (const toolCall of toolCalls) {
|
for (const toolCall of toolCalls) {
|
||||||
controller.enqueue(encoder.encode(`\n**Using tool: ${toolCall.name}**\n`));
|
controller.enqueue(encoder.encode(`**Using tool: ${toolCall.name}**\n`));
|
||||||
|
|
||||||
const result = await executeTool(toolCall.name, toolCall.arguments);
|
const result = await executeTool(toolCall.name, toolCall.arguments);
|
||||||
|
|
||||||
// Send tool result to stream
|
// Send tool result to stream
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
controller.enqueue(encoder.encode(`\`\`\`\n${result.result}\n\`\`\`\n`));
|
controller.enqueue(encoder.encode(`\`\`\`\n${result.result}\n\`\`\`\n\n`));
|
||||||
} else {
|
} else {
|
||||||
controller.enqueue(encoder.encode(`Error: ${result.error}\n`));
|
controller.enqueue(encoder.encode(`Error: ${result.error}\n\n`));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add tool result to messages for next iteration
|
// Add tool result to messages for next iteration
|
||||||
@@ -100,7 +162,7 @@ export async function POST(request: NextRequest) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.enqueue(encoder.encode('\n---END_TOOL_CALLS---\n\n'));
|
// Continue to let the model respond with the tool results
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.close();
|
controller.close();
|
||||||
|
|||||||
Reference in New Issue
Block a user