A Look at Tool Calling in Spring AI

This article takes a look at Tool Calling in Spring AI.

ToolCallback

org/springframework/ai/tool/ToolCallback.java

public interface ToolCallback extends FunctionCallback {

    /**
     * Definition used by the AI model to determine when and how to call the tool.
     */
    ToolDefinition getToolDefinition();

    /**
     * Metadata providing additional information on how to handle the tool.
     */
    default ToolMetadata getToolMetadata() {
        return ToolMetadata.builder().build();
    }

    /**
     * Execute tool with the given input and return the result to send back to the AI
     * model.
     */
    String call(String toolInput);

    /**
     * Execute tool with the given input and context, and return the result to send back
     * to the AI model.
     */
    default String call(String toolInput, @Nullable ToolContext tooContext) {
        if (tooContext != null && !tooContext.getContext().isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }

    @Override
    @Deprecated // Call getToolDefinition().name() instead
    default String getName() {
        return getToolDefinition().name();
    }

    @Override
    @Deprecated // Call getToolDefinition().description() instead
    default String getDescription() {
        return getToolDefinition().description();
    }

    @Override
    @Deprecated // Call getToolDefinition().inputTypeSchema() instead
    default String getInputTypeSchema() {
        return getToolDefinition().inputSchema();
    }

}

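ToolCallback extends the FunctionCallback interface, although FunctionCallback is slated for deprecation. It mainly defines the getToolDefinition, getToolMetadata and call methods, and has two basic implementations: MethodToolCallback and FunctionToolCallback.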
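The default call(String, ToolContext) shown above rejects any non-empty context unless an implementation overrides it. The following is a minimal plain-JDK sketch of that contract; SketchTool and DefaultContextSketch are hypothetical stand-ins written for illustration, not Spring AI types, and a Map takes the place of ToolContext.

```java
import java.util.Map;

// Hypothetical stand-in for the ToolCallback contract (not the Spring AI interface).
interface SketchTool {

    String call(String toolInput);

    // Mirrors the default method: a non-empty context is rejected unless overridden.
    default String call(String toolInput, Map<String, Object> context) {
        if (context != null && !context.isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }
}

class DefaultContextSketch {

    // Calls the tool and reports a rejection as a string instead of propagating it.
    static String tryCall(SketchTool tool, String input, Map<String, Object> context) {
        try {
            return tool.call(input, context);
        } catch (UnsupportedOperationException ex) {
            return "rejected: " + ex.getMessage();
        }
    }

    public static void main(String[] args) {
        SketchTool echo = input -> "echo:" + input;
        System.out.println(tryCall(echo, "{}", null));
        System.out.println(tryCall(echo, "{}", Map.of("tenantId", "acme")));
    }
}
```

Both MethodToolCallback and FunctionToolCallback override the two-argument call, which is why they can accept a context.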

MethodToolCallback

org/springframework/ai/tool/method/MethodToolCallback.java

public class MethodToolCallback implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class);

    private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();

    private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;

    private final ToolMetadata toolMetadata;

    private final Method toolMethod;

    @Nullable
    private final Object toolObject;

    private final ToolCallResultConverter toolCallResultConverter;

    public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod,
            @Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        Assert.notNull(toolMethod, "toolMethod cannot be null");
        Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null,
                "toolObject cannot be null for non-static methods");
        this.toolDefinition = toolDefinition;
        this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
        this.toolMethod = toolMethod;
        this.toolObject = toolObject;
        this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
                : DEFAULT_RESULT_CONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return toolMetadata;
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");

        logger.debug("Starting execution of tool: {}", toolDefinition.name());

        validateToolContextSupport(toolContext);

        Map<String, Object> toolArguments = extractToolArguments(toolInput);

        Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);

        Object result = callMethod(methodArguments);

        logger.debug("Successful execution of tool: {}", toolDefinition.name());

        Type returnType = toolMethod.getGenericReturnType();

        return toolCallResultConverter.convert(result, returnType);
    }

    @Nullable
    private Object callMethod(Object[] methodArguments) {
        if (isObjectNotPublic() || isMethodNotPublic()) {
            toolMethod.setAccessible(true);
        }

        Object result;
        try {
            result = toolMethod.invoke(toolObject, methodArguments);
        }
        catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        }
        catch (InvocationTargetException ex) {
            throw new ToolExecutionException(toolDefinition, ex.getCause());
        }
        return result;
    }

    //......
}   

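MethodToolCallback implements the ToolCallback interface. Its call method extracts the tool arguments from the JSON input, builds the method arguments via buildMethodArguments, obtains the return value via callMethod, and finally converts it with toolCallResultConverter.convert. callMethod mainly invokes the target method through reflection.
The following types are currently not supported as parameters or return types: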

  • Optional
  • Asynchronous types (e.g. CompletableFuture, Future)
  • Reactive types (e.g. Flow, Mono, Flux)
  • Functional types (e.g. Function, Supplier, Consumer).
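The reflective invocation in callMethod can be tried out in isolation. The sketch below uses only the JDK; ReflectiveCallSketch and Greeter are hypothetical names, and a plain RuntimeException stands in for ToolExecutionException. It follows the same steps as the excerpt: make a non-public method accessible, invoke it, and unwrap the InvocationTargetException so the caller sees the tool's own failure.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

class ReflectiveCallSketch {

    // Hypothetical tool class with a package-private method, like DateTimeTools above.
    static class Greeter {
        String greet(String name) {
            if (name.isEmpty()) {
                throw new IllegalArgumentException("name is empty");
            }
            return "Hello, " + name;
        }
    }

    // Invokes the method reflectively and unwraps failures thrown by the tool itself.
    static Object invoke(Object target, Method method, Object... args) {
        if (!Modifier.isPublic(method.getModifiers())
                || !Modifier.isPublic(method.getDeclaringClass().getModifiers())) {
            method.setAccessible(true);
        }
        try {
            return method.invoke(target, args);
        } catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        } catch (InvocationTargetException ex) {
            // Stand-in for ToolExecutionException: keep the original cause.
            throw new RuntimeException("Tool failed: " + ex.getCause().getMessage(), ex.getCause());
        }
    }

    // Convenience wrapper so callers do not deal with checked reflection exceptions.
    static Object callGreet(String name) {
        try {
            Method greet = Greeter.class.getDeclaredMethod("greet", String.class);
            return invoke(new Greeter(), greet, name);
        } catch (NoSuchMethodException ex) {
            throw new IllegalStateException(ex);
        }
    }

    public static void main(String[] args) {
        System.out.println(callGreet("Spring AI"));
    }
}
```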

FunctionToolCallback

org/springframework/ai/tool/function/FunctionToolCallback.java

public class FunctionToolCallback<I, O> implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallback.class);

    private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();

    private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;

    private final ToolMetadata toolMetadata;

    private final Type toolInputType;

    private final BiFunction<I, ToolContext, O> toolFunction;

    private final ToolCallResultConverter toolCallResultConverter;

    public FunctionToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Type toolInputType,
            BiFunction<I, ToolContext, O> toolFunction, @Nullable ToolCallResultConverter toolCallResultConverter) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        Assert.notNull(toolInputType, "toolInputType cannot be null");
        Assert.notNull(toolFunction, "toolFunction cannot be null");
        this.toolDefinition = toolDefinition;
        this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
        this.toolFunction = toolFunction;
        this.toolInputType = toolInputType;
        this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
                : DEFAULT_RESULT_CONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return toolMetadata;
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");

        logger.debug("Starting execution of tool: {}", toolDefinition.name());

        I request = JsonParser.fromJson(toolInput, toolInputType);
        O response = toolFunction.apply(request, toolContext);

        logger.debug("Successful execution of tool: {}", toolDefinition.name());

        return toolCallResultConverter.convert(response, null);
    }

    @Override
    public String toString() {
        return "FunctionToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}';
    }

    //......
}   

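FunctionToolCallback implements the ToolCallback interface. Its call method converts the request with JsonParser.fromJson(toolInput, toolInputType), obtains the result via toolFunction.apply(request, toolContext), and finally converts it with toolCallResultConverter.convert(response, null).
The following types are currently not supported as input or output types: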

  • Primitive types
  • Optional
  • Collection types (e.g. List, Map, Array, Set)
  • Asynchronous types (e.g. CompletableFuture, Future)
  • Reactive types (e.g. Flow, Mono, Flux).

Example

class DateTimeTools {

    String getCurrentDateTime() {
        return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
    }

}

MethodToolCallback

Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolCallback toolCallback = MethodToolCallback.builder()
    .toolDefinition(ToolDefinition.builder(method)
            .description("Get the current date and time in the user's timezone")
            .build())
    .toolMethod(method)
    .toolObject(new DateTimeTools())
    .build();

Alternatively, use the @Tool annotation:

class DateTimeTools {

    @Tool(description = "Get the current date and time in the user's timezone")
    String getCurrentDateTime() {
        return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
    }

}

Or use the ToolCallbacks.from method:

ToolCallback[] dateTimeTools = ToolCallbacks.from(new DateTimeTools());

FunctionToolCallback

public class WeatherService implements Function<WeatherRequest, WeatherResponse> {
    public WeatherResponse apply(WeatherRequest request) {
        return new WeatherResponse(30.0, Unit.C);
    }
}

ToolCallback toolCallback = FunctionToolCallback
    .builder("currentWeather", new WeatherService())
    .description("Get the weather in location")
    .inputType(WeatherRequest.class)
    .build();

ChatClient.create(chatModel)
    .prompt("What's the weather like in Copenhagen?")
    .tools(toolCallback)
    .call()
    .content();    

Or set it on the chat options:

ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(toolCallback)
    .build();
Prompt prompt = new Prompt("What's the weather like in Copenhagen?", chatOptions);
chatModel.call(prompt);

Or register it as a Spring bean and refer to it by name:

@Configuration(proxyBeanMethods = false)
class WeatherTools {

    WeatherService weatherService = new WeatherService();

    @Bean
    @Description("Get the weather in location")
    Function<WeatherRequest, WeatherResponse> currentWeather() {
        return weatherService;
    }

}

ChatClient.create(chatModel)
    .prompt("What's the weather like in Copenhagen?")
    .tools("currentWeather")
    .call()
    .content();

Tool Specification

ToolDefinition

org/springframework/ai/tool/definition/ToolDefinition.java

public interface ToolDefinition {

    /**
     * The tool name. Unique within the tool set provided to a model.
     */
    String name();

    /**
     * The tool description, used by the AI model to determine what the tool does.
     */
    String description();

    /**
     * The schema of the parameters used to call the tool.
     */
    String inputSchema();

    /**
     * Create a default {@link ToolDefinition} builder.
     */
    static DefaultToolDefinition.Builder builder() {
        return DefaultToolDefinition.builder();
    }

    /**
     * Create a default {@link ToolDefinition} builder from a {@link Method}.
     */
    static DefaultToolDefinition.Builder builder(Method method) {
        Assert.notNull(method, "method cannot be null");
        return DefaultToolDefinition.builder()
            .name(ToolUtils.getToolName(method))
            .description(ToolUtils.getToolDescription(method))
            .inputSchema(JsonSchemaGenerator.generateForMethodInput(method));
    }

    /**
     * Create a default {@link ToolDefinition} instance from a {@link Method}.
     */
    static ToolDefinition from(Method method) {
        return ToolDefinition.builder(method).build();
    }

}

ToolDefinition defines the name, description and inputSchema properties. It also provides a builder(Method) method that builds a DefaultToolDefinition from a Method.

Example

Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolDefinition toolDefinition = ToolDefinition.builder(method)
    .name("currentDateTime")
    .description("Get the current date and time in the user's timezone")
    .inputSchema(JsonSchemaGenerator.generateForMethodInput(method))
    .build();

JSON Schema

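Spring AI provides JsonSchemaGenerator to generate the JSON schema for the input parameters of a given method or function. Parameter descriptions can be supplied with the following annotations: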

@ToolParam(description = "…") from Spring AI
@JsonClassDescription("…") from Jackson
@JsonPropertyDescription("…") from Jackson
@Schema(description = "…") from Swagger.

Example

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.context.i18n.LocaleContextHolder;

class DateTimeTools {

    @Tool(description = "Set a user alarm for the given time")
    void setAlarm(@ToolParam(description = "Time in ISO-8601 format") String time) {
        LocalDateTime alarmTime = LocalDateTime.parse(time, DateTimeFormatter.ISO_DATE_TIME);
        System.out.println("Alarm set for " + alarmTime);
    }

}

To mark whether a parameter is required, the following annotations can be used:

@ToolParam(required = false) from Spring AI
@JsonProperty(required = false) from Jackson
@Schema(required = false) from Swagger
@Nullable from Spring Framework.

Example:

class CustomerTools {

    @Tool(description = "Update customer information")
    void updateCustomerInfo(Long id, String name, @ToolParam(required = false) String email) {
        System.out.println("Updated info for customer with id: " + id);
    }

}

Result Conversion

Spring AI provides ToolCallResultConverter to convert the value returned by a tool call before it is sent back to the AI model.

org/springframework/ai/tool/execution/ToolCallResultConverter.java

@FunctionalInterface
public interface ToolCallResultConverter {

    /**
     * Given an Object returned by a tool, convert it to a String compatible with the
     * given class type.
     */
    String convert(@Nullable Object result, @Nullable Type returnType);

}

It has a default implementation, DefaultToolCallResultConverter:

public final class DefaultToolCallResultConverter implements ToolCallResultConverter {

    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class);

    @Override
    public String convert(@Nullable Object result, @Nullable Type returnType) {
        if (returnType == Void.TYPE) {
            logger.debug("The tool has no return type. Converting to conventional response.");
            return "Done";
        }
        else {
            logger.debug("Converting tool result to JSON.");
            return JsonParser.toJson(result);
        }
    }

}

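DefaultToolCallResultConverter uses JsonParser.toJson(result) to serialize the return value to a JSON string; for a void return type it returns the conventional string "Done" instead.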
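The void-versus-value branch is easy to check with a small plain-JDK sketch. Everything here is an assumption made for illustration: Converter is a simplified stand-in with the same shape as ToolCallResultConverter, String.valueOf replaces the real JSON serialization, and setAlarm and countAlarms are hypothetical tools.

```java
import java.lang.reflect.Method;
import java.lang.reflect.Type;

class ResultConversionSketch {

    // Stand-in with the same shape as ToolCallResultConverter (simplified assumption).
    interface Converter {
        String convert(Object result, Type returnType);
    }

    // void maps to "Done"; everything else uses String.valueOf instead of real JSON.
    static final Converter DEFAULT =
            (result, returnType) -> returnType == Void.TYPE ? "Done" : String.valueOf(result);

    // Two hypothetical tools: one returns nothing, one returns a value.
    static void setAlarm() {
    }

    static int countAlarms() {
        return 3;
    }

    // Invokes a no-arg static tool by name and converts its result.
    static String callAndConvert(String toolMethodName) {
        try {
            Method method = ResultConversionSketch.class.getDeclaredMethod(toolMethodName);
            Object result = method.invoke(null);
            return DEFAULT.convert(result, method.getGenericReturnType());
        } catch (ReflectiveOperationException ex) {
            throw new IllegalStateException(ex);
        }
    }

    public static void main(String[] args) {
        System.out.println(callAndConvert("setAlarm"));
        System.out.println(callAndConvert("countAlarms"));
    }
}
```

Note that FunctionToolCallback passes null as the return type, so in that path the converter never takes the "Done" branch.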

You can also specify your own converter, for example:

class CustomerTools {

    @Tool(description = "Retrieve customer information", resultConverter = CustomToolCallResultConverter.class)
    Customer getCustomerInfo(Long id) {
        return customerRepository.findById(id);
    }

}

Tool Context

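Spring AI provides ToolContext for passing additional contextual information to tools. This lets developers supply extra, user-provided data that can be used during tool execution alongside the tool arguments passed by the AI model. Usage example: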

class CustomerTools {

    @Tool(description = "Retrieve customer information")
    Customer getCustomerInfo(Long id, ToolContext toolContext) {
        return customerRepository.findById(id, toolContext.getContext().get("tenantId"));
    }

}

With ChatClient:

ChatModel chatModel = ...

String response = ChatClient.create(chatModel)
        .prompt("Tell me more about the customer with ID 42")
        .tools(new CustomerTools())
        .toolContext(Map.of("tenantId", "acme"))
        .call()
        .content();

System.out.println(response);

With ChatModel:

ChatModel chatModel = ...
ToolCallback[] customerTools = ToolCallbacks.from(new CustomerTools());
ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(customerTools)
    .toolContext(Map.of("tenantId", "acme"))
    .build();
Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);
chatModel.call(prompt);

Return Direct

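Spring AI provides the returnDirect attribute. When it is set to true, the result of the tool call is returned directly to the caller instead of going through the model. By default the result is sent back to the AI model, which processes it before responding to the user.
Example: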

class CustomerTools {

    @Tool(description = "Retrieve customer information", returnDirect = true)
    Customer getCustomerInfo(Long id) {
        return customerRepository.findById(id);
    }

}

Or:

ToolMetadata toolMetadata = ToolMetadata.builder()
    .returnDirect(true)
    .build();

ToolCallingManager

org/springframework/ai/model/tool/ToolCallingManager.java

public interface ToolCallingManager {

    /**
     * Resolve the tool definitions from the model's tool calling options.
     */
    List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions);

    /**
     * Execute the tool calls requested by the model.
     */
    ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse);

    /**
     * Create a default {@link ToolCallingManager} builder.
     */
    static DefaultToolCallingManager.Builder builder() {
        return DefaultToolCallingManager.builder();
    }

}

ToolCallingManager defines the resolveToolDefinitions and executeToolCalls methods; its default implementation is DefaultToolCallingManager.

DefaultToolCallingManager

org/springframework/ai/model/tool/DefaultToolCallingManager.java

public class DefaultToolCallingManager implements ToolCallingManager {

    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallingManager.class);

    // @formatter:off

    private static final ObservationRegistry DEFAULT_OBSERVATION_REGISTRY
            = ObservationRegistry.NOOP;

    private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER
            = new DelegatingToolCallbackResolver(List.of());

    private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
            = DefaultToolExecutionExceptionProcessor.builder().build();

    // @formatter:on

    private final ObservationRegistry observationRegistry;

    private final ToolCallbackResolver toolCallbackResolver;

    private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;

    public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
            ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
        Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");

        this.observationRegistry = observationRegistry;
        this.toolCallbackResolver = toolCallbackResolver;
        this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
    }

    @Override
    public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
        Assert.notNull(chatOptions, "chatOptions cannot be null");

        List<FunctionCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());
        for (String toolName : chatOptions.getToolNames()) {
            // Skip the tool if it is already present in the request toolCallbacks.
            // That might happen if a tool is defined in the options
            // both as a ToolCallback and as a tool name.
            if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) {
                continue;
            }
            FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
            if (toolCallback == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }
            toolCallbacks.add(toolCallback);
        }

        return toolCallbacks.stream().map(functionCallback -> {
            if (functionCallback instanceof ToolCallback toolCallback) {
                return toolCallback.getToolDefinition();
            }
            else {
                return ToolDefinition.builder()
                    .name(functionCallback.getName())
                    .description(functionCallback.getDescription())
                    .inputSchema(functionCallback.getInputTypeSchema())
                    .build();
            }
        }).toList();
    }

    @Override
    public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
        Assert.notNull(prompt, "prompt cannot be null");
        Assert.notNull(chatResponse, "chatResponse cannot be null");

        Optional<Generation> toolCallGeneration = chatResponse.getResults()
            .stream()
            .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
            .findFirst();

        if (toolCallGeneration.isEmpty()) {
            throw new IllegalStateException("No tool call requested by the chat model");
        }

        AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();

        ToolContext toolContext = buildToolContext(prompt, assistantMessage);

        InternalToolExecutionResult internalToolExecutionResult = executeToolCall(prompt, assistantMessage,
                toolContext);

        List<Message> conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(),
                assistantMessage, internalToolExecutionResult.toolResponseMessage());

        return ToolExecutionResult.builder()
            .conversationHistory(conversationHistory)
            .returnDirect(internalToolExecutionResult.returnDirect())
            .build();
    }

    //......

    /**
     * Execute the tool call and return the response message. To ensure backward
     * compatibility, both {@link ToolCallback} and {@link FunctionCallback} are
     * supported.
     */
    private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
            ToolContext toolContext) {
        List<FunctionCallback> toolCallbacks = List.of();
        if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
            toolCallbacks = toolCallingChatOptions.getToolCallbacks();
        }
        else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) {
            toolCallbacks = functionOptions.getFunctionCallbacks();
        }

        List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();

        Boolean returnDirect = null;

        for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {

            logger.debug("Executing tool call: {}", toolCall.name());

            String toolName = toolCall.name();
            String toolInputArguments = toolCall.arguments();

            FunctionCallback toolCallback = toolCallbacks.stream()
                .filter(tool -> toolName.equals(tool.getName()))
                .findFirst()
                .orElseGet(() -> toolCallbackResolver.resolve(toolName));

            if (toolCallback == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }

            if (returnDirect == null && toolCallback instanceof ToolCallback callback) {
                returnDirect = callback.getToolMetadata().returnDirect();
            }
            else if (toolCallback instanceof ToolCallback callback) {
                returnDirect = returnDirect && callback.getToolMetadata().returnDirect();
            }
            else if (returnDirect == null) {
                // This is a temporary solution to ensure backward compatibility with
                // FunctionCallback.
                // TODO: remove this block when FunctionCallback is removed.
                returnDirect = false;
            }

            String toolResult;
            try {
                toolResult = toolCallback.call(toolInputArguments, toolContext);
            }
            catch (ToolExecutionException ex) {
                toolResult = toolExecutionExceptionProcessor.process(ex);
            }

            toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult));
        }

        return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
    }

    private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
            AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) {
        List<Message> messages = new ArrayList<>(previousMessages);
        messages.add(assistantMessage);
        messages.add(toolResponseMessage);
        return messages;
    }   
}   

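DefaultToolCallingManager's resolveToolDefinitions method starts from chatOptions.getToolCallbacks() and then resolves each name in chatOptions.getToolNames() through toolCallbackResolver, skipping names already covered by a callback and throwing an IllegalStateException when a name cannot be resolved. It finally maps every callback to its ToolDefinition. The executeToolCalls method picks the first generation whose assistantMessage contains tool calls, builds the toolContext, calls executeToolCall to obtain the execution result, and builds the conversationHistory from it.
The executeToolCall method iterates over assistantMessage.getToolCalls(). For each call it first looks for a matching callback among those in the prompt options and only falls back to toolCallbackResolver.resolve(toolName) when none is found. It then obtains the result via toolCallback.call(toolInputArguments, toolContext); if a ToolExecutionException is thrown, toolExecutionExceptionProcessor.process(ex) handles it as a fallback. The resulting returnDirect flag is true only if every invoked ToolCallback has returnDirect set.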
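The way executeToolCall combines the returnDirect flags of several tool calls is worth spelling out: the first ToolCallback sets the value and each later one is AND-ed in. A minimal sketch of that aggregation follows; ReturnDirectSketch is a hypothetical name, and returning false for an empty list is a choice made in this sketch, not behavior taken from the excerpt.

```java
import java.util.List;

class ReturnDirectSketch {

    // Each element is the returnDirect flag of one tool requested in the same assistant message.
    static boolean aggregate(List<Boolean> flags) {
        Boolean returnDirect = null;
        for (boolean flag : flags) {
            // The first tool sets the value; every later tool can only turn it off.
            returnDirect = (returnDirect == null) ? flag : returnDirect && flag;
        }
        // Sketch-only choice: no tool calls means nothing is returned directly.
        return returnDirect != null && returnDirect;
    }

    public static void main(String[] args) {
        System.out.println(aggregate(List.of(true, true)));
        System.out.println(aggregate(List.of(true, false)));
    }
}
```

In practice this means a single tool without returnDirect in a batch sends the whole batch of results back through the model.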

ToolExecutionExceptionProcessor

org/springframework/ai/tool/execution/ToolExecutionExceptionProcessor.java

@FunctionalInterface
public interface ToolExecutionExceptionProcessor {

    /**
     * Convert an exception thrown by a tool to a String that can be sent back to the AI
     * model or throw an exception to be handled by the caller.
     */
    String process(ToolExecutionException exception);

}

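ToolExecutionExceptionProcessor defines a single process method, which either converts the exception into a String to send back to the model or rethrows it to the caller.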

DefaultToolExecutionExceptionProcessor

public class DefaultToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor {

    private final static Logger logger = LoggerFactory.getLogger(DefaultToolExecutionExceptionProcessor.class);

    private static final boolean DEFAULT_ALWAYS_THROW = false;

    private final boolean alwaysThrow;

    public DefaultToolExecutionExceptionProcessor(boolean alwaysThrow) {
        this.alwaysThrow = alwaysThrow;
    }

    @Override
    public String process(ToolExecutionException exception) {
        Assert.notNull(exception, "exception cannot be null");
        if (alwaysThrow) {
            throw exception;
        }
        logger.debug("Exception thrown by tool: {}. Message: {}", exception.getToolDefinition().name(),
                exception.getMessage());
        return exception.getMessage();
    }

    //......
}   

When alwaysThrow is true (it defaults to false), DefaultToolExecutionExceptionProcessor rethrows the exception; otherwise it returns the exception message.
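The effect of the alwaysThrow switch can be sketched with plain JDK types. ExceptionProcessingSketch is a hypothetical class, IllegalStateException stands in for ToolExecutionException, and a Supplier stands in for the tool; the point is only to show which path hands the message to the model and which one surfaces the failure to the caller.

```java
import java.util.function.Supplier;

class ExceptionProcessingSketch {

    // Mirrors the process method: rethrow, or turn the failure into text for the model.
    static String process(RuntimeException exception, boolean alwaysThrow) {
        if (alwaysThrow) {
            throw exception;
        }
        return exception.getMessage();
    }

    // Runs a tool and routes its failure through process, as the calling manager does.
    static String runTool(Supplier<String> tool, boolean alwaysThrow) {
        try {
            return tool.get();
        } catch (IllegalStateException ex) {
            return process(ex, alwaysThrow);
        }
    }

    public static void main(String[] args) {
        Supplier<String> failing = () -> {
            throw new IllegalStateException("customer not found");
        };
        // With alwaysThrow = false the message becomes the tool result.
        System.out.println(runTool(failing, false));
    }
}
```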

User-Controlled Tool Execution

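ToolCallingChatOptions provides the internalToolExecutionEnabled property. Setting it to false lets you control the tool invocation process yourself (you can also implement your own ToolExecutionEligibilityPredicate to control it). Example: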

ChatModel chatModel = ...
ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();

ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(ToolCallbacks.from(new CustomerTools()))
    .internalToolExecutionEnabled(false)
    .build();
Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);

ChatResponse chatResponse = chatModel.call(prompt);

while (chatResponse.hasToolCalls()) {
    ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

    prompt = new Prompt(toolExecutionResult.conversationHistory(), chatOptions);

    chatResponse = chatModel.call(prompt);
}

System.out.println(chatResponse.getResult().getOutput().getText());

Here the tool calls are executed manually through toolCallingManager.executeToolCalls, and the resulting conversation history is passed back to chatModel.
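The control flow of that loop can be simulated without a real model. In the sketch below everything is hypothetical: fakeModel stands in for chatModel.call, a Function stands in for the tool, and the conversation history is a list of plain strings. The fake model keeps requesting the tool until a tool result appears in the history, which is the condition that ends the loop.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

class ManualToolLoopSketch {

    // Hypothetical model reply: either a tool request or a final answer.
    static final class Reply {
        final String toolName;
        final String toolInput;
        final String text;

        Reply(String toolName, String toolInput, String text) {
            this.toolName = toolName;
            this.toolInput = toolInput;
            this.text = text;
        }

        boolean hasToolCall() {
            return toolName != null;
        }
    }

    // Fake model: asks for the tool until the last history entry is a tool result.
    static Reply fakeModel(List<String> history) {
        String last = history.get(history.size() - 1);
        if (last.startsWith("tool:")) {
            return new Reply(null, null, "Customer 42 is " + last.substring("tool:".length()));
        }
        return new Reply("getCustomerInfo", "42", null);
    }

    // Same shape as the loop above: call, execute tools, extend history, call again.
    static String run(String userMessage, Function<String, String> tool) {
        List<String> history = new ArrayList<>(List.of("user:" + userMessage));
        Reply reply = fakeModel(history);
        while (reply.hasToolCall()) {
            history.add("assistant:call " + reply.toolName);
            history.add("tool:" + tool.apply(reply.toolInput));
            reply = fakeModel(history);
        }
        return reply.text;
    }

    public static void main(String[] args) {
        System.out.println(run("Tell me more about the customer with ID 42", id -> "Alice"));
    }
}
```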

ToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java

public interface ToolCallbackResolver {

    /**
     * Resolve the {@link FunctionCallback} for the given tool name.
     */
    @Nullable
    FunctionCallback resolve(String toolName);

}

ToolCallbackResolver defines the resolve method, which returns the FunctionCallback for a given toolName. It has three implementations: StaticToolCallbackResolver, SpringBeanToolCallbackResolver and DelegatingToolCallbackResolver.

StaticToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java

public class StaticToolCallbackResolver implements ToolCallbackResolver {

    private static final Logger logger = LoggerFactory.getLogger(StaticToolCallbackResolver.class);

    private final Map<String, FunctionCallback> toolCallbacks = new HashMap<>();

    public StaticToolCallbackResolver(List<FunctionCallback> toolCallbacks) {
        Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
        Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");

        toolCallbacks.forEach(callback -> {
            if (callback instanceof ToolCallback toolCallback) {
                this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback);
            }
            this.toolCallbacks.put(callback.getName(), callback);
        });
    }

    @Override
    public FunctionCallback resolve(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");
        logger.debug("ToolCallback resolution attempt from static registry");
        return toolCallbacks.get(toolName);
    }

}

StaticToolCallbackResolver looks the callback up in a map built from the List<FunctionCallback> passed to its constructor.

SpringBeanToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java

public class SpringBeanToolCallbackResolver implements ToolCallbackResolver {

    private static final Logger logger = LoggerFactory.getLogger(SpringBeanToolCallbackResolver.class);

    private static final Map<String, ToolCallback> toolCallbacksCache = new HashMap<>();

    private static final SchemaType DEFAULT_SCHEMA_TYPE = SchemaType.JSON_SCHEMA;

    private final GenericApplicationContext applicationContext;

    private final SchemaType schemaType;

    public SpringBeanToolCallbackResolver(GenericApplicationContext applicationContext,
            @Nullable SchemaType schemaType) {
        Assert.notNull(applicationContext, "applicationContext cannot be null");

        this.applicationContext = applicationContext;
        this.schemaType = schemaType != null ? schemaType : DEFAULT_SCHEMA_TYPE;
    }

    @Override
    public ToolCallback resolve(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");

        logger.debug("ToolCallback resolution attempt from Spring application context");

        ToolCallback resolvedToolCallback = toolCallbacksCache.get(toolName);

        if (resolvedToolCallback != null) {
            return resolvedToolCallback;
        }

        ResolvableType toolType = TypeResolverHelper.resolveBeanType(applicationContext, toolName);
        ResolvableType toolInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(toolType))
                ? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(toolType, 0);

        String toolDescription = resolveToolDescription(toolName, toolInputType.toClass());
        Object bean = applicationContext.getBean(toolName);

        resolvedToolCallback = buildToolCallback(toolName, toolType, toolInputType, toolDescription, bean);

        toolCallbacksCache.put(toolName, resolvedToolCallback);

        return resolvedToolCallback;
    }

    //......
}   

SpringBeanToolCallbackResolver uses the GenericApplicationContext to look up a bean named toolName in the Spring container; once resolved, the callback is stored in toolCallbacksCache.

DelegatingToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java

public class DelegatingToolCallbackResolver implements ToolCallbackResolver {

    private final List<ToolCallbackResolver> toolCallbackResolvers;

    public DelegatingToolCallbackResolver(List<ToolCallbackResolver> toolCallbackResolvers) {
        Assert.notNull(toolCallbackResolvers, "toolCallbackResolvers cannot be null");
        Assert.noNullElements(toolCallbackResolvers, "toolCallbackResolvers cannot contain null elements");
        this.toolCallbackResolvers = toolCallbackResolvers;
    }

    @Override
    @Nullable
    public FunctionCallback resolve(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");

        for (ToolCallbackResolver toolCallbackResolver : toolCallbackResolvers) {
            FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
            if (toolCallback != null) {
                return toolCallback;
            }
        }
        return null;
    }

}

DelegatingToolCallbackResolver delegates resolve to the toolCallbackResolvers passed to its constructor and returns the first non-null result.
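The delegation order matters when the same tool name is known to more than one resolver. A plain-JDK sketch of the first-match-wins lookup is given below; DelegatingResolverSketch is a hypothetical class, each resolver is modeled as a Function from tool name to callback, and strings stand in for the callbacks themselves.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;

class DelegatingResolverSketch {

    // Tries each resolver in order and returns the first non-null callback, else null.
    static String resolve(List<Function<String, String>> resolvers, String toolName) {
        for (Function<String, String> resolver : resolvers) {
            String callback = resolver.apply(toolName);
            if (callback != null) {
                return callback;
            }
        }
        return null;
    }

    // A map-backed resolver, similar in spirit to a static registry.
    static Function<String, String> staticResolver(Map<String, String> registry) {
        return registry::get;
    }

    public static void main(String[] args) {
        List<Function<String, String>> chain = List.of(
                staticResolver(Map.of("currentWeather", "static-callback")),
                staticResolver(Map.of("currentWeather", "bean-callback", "currentDateTime", "bean-callback")));
        // The first resolver wins for names that both of them know.
        System.out.println(resolve(chain, "currentWeather"));
        System.out.println(resolve(chain, "currentDateTime"));
    }
}
```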

Summary

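Spring AI provides ToolCallback to implement Tool Calling. It extends the FunctionCallback interface, although FunctionCallback is slated for deprecation. It mainly defines the getToolDefinition, getToolMetadata and call methods, and has two basic implementations: MethodToolCallback and FunctionToolCallback.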

The Tool Specification part covers Tool Callback, Tool Definition, JSON Schema, Result Conversion, Tool Context and Return Direct.
The Tool Execution part covers Framework-Controlled Tool Execution, User-Controlled Tool Execution and Exception Handling.
