Introduction
This article takes a look at Tool Calling in Spring AI.
ToolCallback
org/springframework/ai/tool/ToolCallback.java
public interface ToolCallback extends FunctionCallback {
/**
* Definition used by the AI model to determine when and how to call the tool.
*/
ToolDefinition getToolDefinition();
/**
* Metadata providing additional information on how to handle the tool.
*/
default ToolMetadata getToolMetadata() {
return ToolMetadata.builder().build();
}
/**
* Execute tool with the given input and return the result to send back to the AI
* model.
*/
String call(String toolInput);
/**
* Execute tool with the given input and context, and return the result to send back
* to the AI model.
*/
default String call(String toolInput, @Nullable ToolContext toolContext) {
if (toolContext != null && !toolContext.getContext().isEmpty()) {
throw new UnsupportedOperationException("Tool context is not supported!");
}
return call(toolInput);
}
@Override
@Deprecated // Call getToolDefinition().name() instead
default String getName() {
return getToolDefinition().name();
}
@Override
@Deprecated // Call getToolDefinition().description() instead
default String getDescription() {
return getToolDefinition().description();
}
@Override
@Deprecated // Call getToolDefinition().inputSchema() instead
default String getInputTypeSchema() {
return getToolDefinition().inputSchema();
}
}
ToolCallback extends the FunctionCallback interface, which is slated for deprecation. ToolCallback mainly defines the getToolDefinition, getToolMetadata, and call methods. It has two basic implementations: MethodToolCallback and FunctionToolCallback.
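To make this contract concrete, here is a minimal sketch built on hypothetical, simplified types (SimpleToolDefinition, SimpleToolContext, SimpleToolCallback, EchoTool). These are not the Spring AI classes. The sketch shows what the default call(String, ToolContext) method implies: a tool that only implements call(String) works without a context, but rejects a non-empty one.

```java
import java.util.Map;

// Hypothetical, simplified stand-ins for ToolDefinition, ToolContext and ToolCallback (not Spring AI types).
record SimpleToolDefinition(String name, String description, String inputSchema) {
}

record SimpleToolContext(Map<String, Object> context) {
}

interface SimpleToolCallback {

    SimpleToolDefinition getToolDefinition();

    String call(String toolInput);

    // Same rule as ToolCallback's default method: a non-empty context is rejected
    // unless the implementation overrides this overload.
    default String call(String toolInput, SimpleToolContext toolContext) {
        if (toolContext != null && !toolContext.context().isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }

    // Like the deprecated getName(): derived from the definition, not stored separately.
    default String getName() {
        return getToolDefinition().name();
    }
}

// A trivial tool that implements only the single-argument call.
class EchoTool implements SimpleToolCallback {

    @Override
    public SimpleToolDefinition getToolDefinition() {
        return new SimpleToolDefinition("echo", "Return the input unchanged",
                "{\"type\":\"object\",\"properties\":{}}");
    }

    @Override
    public String call(String toolInput) {
        return toolInput;
    }
}
```

A tool that needs context data would override the two-argument overload instead of relying on the default.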
MethodToolCallback
org/springframework/ai/tool/method/MethodToolCallback.java
public class MethodToolCallback implements ToolCallback {
private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class);
private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();
private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();
private final ToolDefinition toolDefinition;
private final ToolMetadata toolMetadata;
private final Method toolMethod;
@Nullable
private final Object toolObject;
private final ToolCallResultConverter toolCallResultConverter;
public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod,
@Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
Assert.notNull(toolDefinition, "toolDefinition cannot be null");
Assert.notNull(toolMethod, "toolMethod cannot be null");
Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null,
"toolObject cannot be null for non-static methods");
this.toolDefinition = toolDefinition;
this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
this.toolMethod = toolMethod;
this.toolObject = toolObject;
this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
: DEFAULT_RESULT_CONVERTER;
}
@Override
public ToolDefinition getToolDefinition() {
return toolDefinition;
}
@Override
public ToolMetadata getToolMetadata() {
return toolMetadata;
}
@Override
public String call(String toolInput) {
return call(toolInput, null);
}
@Override
public String call(String toolInput, @Nullable ToolContext toolContext) {
Assert.hasText(toolInput, "toolInput cannot be null or empty");
logger.debug("Starting execution of tool: {}", toolDefinition.name());
validateToolContextSupport(toolContext);
Map<String, Object> toolArguments = extractToolArguments(toolInput);
Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);
Object result = callMethod(methodArguments);
logger.debug("Successful execution of tool: {}", toolDefinition.name());
Type returnType = toolMethod.getGenericReturnType();
return toolCallResultConverter.convert(result, returnType);
}
@Nullable
private Object callMethod(Object[] methodArguments) {
if (isObjectNotPublic() || isMethodNotPublic()) {
toolMethod.setAccessible(true);
}
Object result;
try {
result = toolMethod.invoke(toolObject, methodArguments);
}
catch (IllegalAccessException ex) {
throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
}
catch (InvocationTargetException ex) {
throw new ToolExecutionException(toolDefinition, ex.getCause());
}
return result;
}
//......
}
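MethodToolCallback implements the ToolCallback interface. Its call method first validates tool-context support and parses the JSON input into an argument map (extractToolArguments). It then builds the method arguments with buildMethodArguments, obtains the return value with callMethod, and converts that value with toolCallResultConverter.convert. callMethod invokes the tool method via reflection. It calls setAccessible(true) first when the tool object or method is not public, and it wraps any exception thrown by the method in a ToolExecutionException.
The following types are currently not supported as method tool parameters or return types: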
- Optional
- Asynchronous types (e.g. CompletableFuture, Future)
- Reactive types (e.g. Flow, Mono, Flux)
- Functional types (e.g. Function, Supplier, Consumer).
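The reflective core of callMethod can be sketched with plain java.lang.reflect. This is a simplified, hypothetical version: ReflectiveToolInvoker, SketchToolExecutionException and MathTools are illustrative names. It skips argument parsing and result conversion, and shows only the accessibility handling and the unwrapping of InvocationTargetException.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

// Hypothetical stand-in for Spring AI's ToolExecutionException.
class SketchToolExecutionException extends RuntimeException {
    SketchToolExecutionException(String toolName, Throwable cause) {
        super("Tool '" + toolName + "' failed: " + cause.getMessage(), cause);
    }
}

class ReflectiveToolInvoker {

    // Invoke toolMethod on toolObject (null for static methods), following the callMethod flow.
    static Object invoke(String toolName, Method toolMethod, Object toolObject, Object... args) {
        if (!Modifier.isStatic(toolMethod.getModifiers()) && toolObject == null) {
            throw new IllegalArgumentException("toolObject cannot be null for non-static methods");
        }
        // Non-public tool classes or methods are made accessible before the reflective call.
        if (!Modifier.isPublic(toolMethod.getModifiers())
                || !Modifier.isPublic(toolMethod.getDeclaringClass().getModifiers())) {
            toolMethod.setAccessible(true);
        }
        try {
            return toolMethod.invoke(toolObject, args);
        }
        catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        }
        catch (InvocationTargetException ex) {
            // Unwrap so the tool's own exception becomes the cause.
            throw new SketchToolExecutionException(toolName, ex.getCause());
        }
    }

    // Helper for the example: find a declared method by name without checked exceptions.
    static Method find(Class<?> type, String name) {
        for (Method m : type.getDeclaredMethods()) {
            if (m.getName().equals(name)) {
                return m;
            }
        }
        throw new IllegalArgumentException("No method named " + name);
    }
}

class MathTools {

    int add(int a, int b) {
        return a + b;
    }

    static String shout(String text) {
        return text.toUpperCase();
    }

    int fail() {
        throw new IllegalStateException("boom");
    }
}
```

Unwrapping matters because Method.invoke always reports the tool's failure as an InvocationTargetException. Passing getCause() on keeps the real error visible to whatever processes the exception next.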
FunctionToolCallback
org/springframework/ai/tool/function/FunctionToolCallback.java
public class FunctionToolCallback<I, O> implements ToolCallback {
private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallback.class);
private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();
private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();
private final ToolDefinition toolDefinition;
private final ToolMetadata toolMetadata;
private final Type toolInputType;
private final BiFunction<I, ToolContext, O> toolFunction;
private final ToolCallResultConverter toolCallResultConverter;
public FunctionToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Type toolInputType,
BiFunction<I, ToolContext, O> toolFunction, @Nullable ToolCallResultConverter toolCallResultConverter) {
Assert.notNull(toolDefinition, "toolDefinition cannot be null");
Assert.notNull(toolInputType, "toolInputType cannot be null");
Assert.notNull(toolFunction, "toolFunction cannot be null");
this.toolDefinition = toolDefinition;
this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
this.toolFunction = toolFunction;
this.toolInputType = toolInputType;
this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
: DEFAULT_RESULT_CONVERTER;
}
@Override
public ToolDefinition getToolDefinition() {
return toolDefinition;
}
@Override
public ToolMetadata getToolMetadata() {
return toolMetadata;
}
@Override
public String call(String toolInput) {
return call(toolInput, null);
}
@Override
public String call(String toolInput, @Nullable ToolContext toolContext) {
Assert.hasText(toolInput, "toolInput cannot be null or empty");
logger.debug("Starting execution of tool: {}", toolDefinition.name());
I request = JsonParser.fromJson(toolInput, toolInputType);
O response = toolFunction.apply(request, toolContext);
logger.debug("Successful execution of tool: {}", toolDefinition.name());
return toolCallResultConverter.convert(response, null);
}
@Override
public String toString() {
return "FunctionToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}';
}
//......
}
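FunctionToolCallback implements the ToolCallback interface. Its call method converts the input with JsonParser.fromJson(toolInput, toolInputType), obtains the result with toolFunction.apply(request, toolContext), and finally converts that result with toolCallResultConverter.convert(response, null). The return type passed to the converter is therefore always null.
The following types are currently not supported as function tool input or output types: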
- Primitive types
- Optional
- Collection types (e.g. List, Map, Array, Set)
- Asynchronous types (e.g. CompletableFuture, Future)
- Reactive types (e.g. Flow, Mono, Flux).
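The parse, apply, convert sequence in FunctionToolCallback.call can be sketched as follows. This is a hypothetical model with several substitutions: SketchFunctionTool is not a Spring AI class, a plain Map stands in for ToolContext, and the input parser and result converter are passed in as functions instead of using JsonParser. The of factory shows one way a context-free Function can be adapted to the BiFunction<I, ToolContext, O> shape that FunctionToolCallback stores. WeatherRequest and WeatherResponse are simplified guesses, since the article does not show their fields.

```java
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

// Simplified, assumed shapes for the article's request/response types.
record WeatherRequest(String location) {
}

record WeatherResponse(double temp, String unit) {
}

class SketchFunctionTool<I, O> {

    private final String name;
    private final Function<String, I> inputParser;
    private final BiFunction<I, Map<String, Object>, O> toolFunction;
    private final Function<O, String> resultConverter;

    SketchFunctionTool(String name, Function<String, I> inputParser,
            BiFunction<I, Map<String, Object>, O> toolFunction, Function<O, String> resultConverter) {
        this.name = name;
        this.inputParser = inputParser;
        this.toolFunction = toolFunction;
        this.resultConverter = resultConverter;
    }

    // Adapts a context-free Function by ignoring the context argument.
    static <I, O> SketchFunctionTool<I, O> of(String name, Function<String, I> inputParser,
            Function<I, O> function, Function<O, String> resultConverter) {
        return new SketchFunctionTool<>(name, inputParser, (request, context) -> function.apply(request),
                resultConverter);
    }

    String name() {
        return name;
    }

    String call(String toolInput, Map<String, Object> toolContext) {
        if (toolInput == null || toolInput.isBlank()) {
            throw new IllegalArgumentException("toolInput cannot be null or empty");
        }
        I request = inputParser.apply(toolInput);              // JsonParser.fromJson(...) in Spring AI
        O response = toolFunction.apply(request, toolContext); // the wrapped function
        return resultConverter.apply(response);                // toolCallResultConverter.convert(response, null)
    }
}
```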
Example
class DateTimeTools {
String getCurrentDateTime() {
return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
}
}
MethodToolCallback
Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolCallback toolCallback = MethodToolCallback.builder()
.toolDefinition(ToolDefinition.builder(method)
.description("Get the current date and time in the user's timezone")
.build())
.toolMethod(method)
.toolObject(new DateTimeTools())
.build();
Alternatively, use the @Tool annotation:
class DateTimeTools {
@Tool(description = "Get the current date and time in the user's timezone")
String getCurrentDateTime() {
return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
}
}
Or use the ToolCallbacks.from method, which builds callbacks from the @Tool-annotated methods of the given object:
ToolCallback[] dateTimeTools = ToolCallbacks.from(new DateTimeTools());
FunctionToolCallback
public class WeatherService implements Function<WeatherRequest, WeatherResponse> {
public WeatherResponse apply(WeatherRequest request) {
return new WeatherResponse(30.0, Unit.C);
}
}
ToolCallback toolCallback = FunctionToolCallback
.builder("currentWeather", new WeatherService())
.description("Get the weather in location")
.inputType(WeatherRequest.class)
.build();
ChatClient.create(chatModel)
.prompt("What's the weather like in Copenhagen?")
.tools(toolCallback)
.call()
.content();
Or set it on the chat options:
ChatOptions chatOptions = ToolCallingChatOptions.builder()
.toolCallbacks(toolCallback)
.build();
Prompt prompt = new Prompt("What's the weather like in Copenhagen?", chatOptions);
chatModel.call(prompt);
Or register it as a Spring bean and refer to it by name:
@Configuration(proxyBeanMethods = false)
class WeatherTools {
WeatherService weatherService = new WeatherService();
@Bean
@Description("Get the weather in location")
Function<WeatherRequest, WeatherResponse> currentWeather() {
return weatherService;
}
}
ChatClient.create(chatModel)
.prompt("What's the weather like in Copenhagen?")
.tools("currentWeather")
.call()
.content();
Tool Specification
ToolDefinition
org/springframework/ai/tool/definition/ToolDefinition.java
public interface ToolDefinition {
/**
* The tool name. Unique within the tool set provided to a model.
*/
String name();
/**
* The tool description, used by the AI model to determine what the tool does.
*/
String description();
/**
* The schema of the parameters used to call the tool.
*/
String inputSchema();
/**
* Create a default {@link ToolDefinition} builder.
*/
static DefaultToolDefinition.Builder builder() {
return DefaultToolDefinition.builder();
}
/**
* Create a default {@link ToolDefinition} builder from a {@link Method}.
*/
static DefaultToolDefinition.Builder builder(Method method) {
Assert.notNull(method, "method cannot be null");
return DefaultToolDefinition.builder()
.name(ToolUtils.getToolName(method))
.description(ToolUtils.getToolDescription(method))
.inputSchema(JsonSchemaGenerator.generateForMethodInput(method));
}
/**
* Create a default {@link ToolDefinition} instance from a {@link Method}.
*/
static ToolDefinition from(Method method) {
return ToolDefinition.builder(method).build();
}
}
ToolDefinition defines the name, description, and inputSchema properties. It also provides builder methods, including one that builds a DefaultToolDefinition from a Method.
Example
Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolDefinition toolDefinition = ToolDefinition.builder(method)
.name("currentDateTime")
.description("Get the current date and time in the user's timezone")
.inputSchema(JsonSchemaGenerator.generateForMethodInput(method))
.build();
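The idea behind ToolDefinition.builder(method) is to derive defaults from the method and let the builder override them. The following minimal sketch uses a hypothetical @SketchTool annotation in place of @Tool: annotation values win, and otherwise the method name is used. This fallback rule is an assumption made for illustration. Spring AI's ToolUtils may derive the default description differently.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Hypothetical annotation playing the role of @Tool in this sketch.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface SketchTool {
    String name() default "";
    String description() default "";
}

record SketchToolDefinition(String name, String description) {

    // Annotation values win; otherwise fall back to the method name.
    static SketchToolDefinition from(Method method) {
        SketchTool tool = method.getAnnotation(SketchTool.class);
        String name = tool != null && !tool.name().isEmpty() ? tool.name() : method.getName();
        String description = tool != null && !tool.description().isEmpty() ? tool.description() : method.getName();
        return new SketchToolDefinition(name, description);
    }

    // Builder-style override, comparable to calling .name(...) after ToolDefinition.builder(method).
    SketchToolDefinition withName(String newName) {
        return new SketchToolDefinition(newName, description);
    }

    // Helper for the example: find a declared method by name without checked exceptions.
    static Method findMethod(Class<?> type, String methodName) {
        for (Method m : type.getDeclaredMethods()) {
            if (m.getName().equals(methodName)) {
                return m;
            }
        }
        throw new IllegalArgumentException("No method named " + methodName);
    }
}

class DateTimeTools {

    @SketchTool(description = "Get the current date and time in the user's timezone")
    String getCurrentDateTime() {
        return java.time.ZonedDateTime.now().toString();
    }

    String getCurrentDate() {
        return java.time.LocalDate.now().toString();
    }
}
```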
JSON Schema
Spring AI provides JsonSchemaGenerator to generate the JSON schema for the input parameters of a given method or function. Parameter descriptions can be supplied with the following annotations:
- @ToolParam(description = "…") from Spring AI
- @JsonClassDescription("…") from Jackson
- @JsonPropertyDescription("…") from Jackson
- @Schema(description = "…") from Swagger
Example
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.context.i18n.LocaleContextHolder;
class DateTimeTools {
@Tool(description = "Set a user alarm for the given time")
void setAlarm(@ToolParam(description = "Time in ISO-8601 format") String time) {
LocalDateTime alarmTime = LocalDateTime.parse(time, DateTimeFormatter.ISO_DATE_TIME);
System.out.println("Alarm set for " + alarmTime);
}
}
Whether a parameter is required can be controlled with the following annotations:
- @ToolParam(required = false) from Spring AI
- @JsonProperty(required = false) from Jackson
- @Schema(required = false) from Swagger
- @Nullable from Spring Framework
Example:
class CustomerTools {
@Tool(description = "Update customer information")
void updateCustomerInfo(Long id, String name, @ToolParam(required = false) String email) {
System.out.println("Updated info for customer with id: " + id);
}
}
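To show how descriptions and required flags end up in the generated schema, here is a deliberately tiny schema generator. It is not JsonSchemaGenerator. It supports only a few scalar types and does no JSON escaping. It relies on a hypothetical @SketchParam annotation that carries an explicit name, because Parameter.getName() only returns real parameter names when the code is compiled with -parameters (otherwise it returns arg0, arg1, ...).

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;

// Hypothetical annotation combining the roles of @ToolParam(description, required).
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
@interface SketchParam {
    String name();
    String description() default "";
    boolean required() default true;
}

class SketchSchemaGenerator {

    // Builds {"type":"object","properties":{...},"required":[...]} in parameter order.
    // Assumes names and descriptions contain no characters that need JSON escaping.
    static String generateForMethodInput(Method method) {
        StringBuilder props = new StringBuilder();
        List<String> required = new ArrayList<>();
        for (Parameter p : method.getParameters()) {
            SketchParam ann = p.getAnnotation(SketchParam.class);
            String name = ann != null ? ann.name() : p.getName();
            if (props.length() > 0) {
                props.append(',');
            }
            props.append('"').append(name).append("\":{\"type\":\"").append(jsonType(p.getType())).append('"');
            if (ann != null && !ann.description().isEmpty()) {
                props.append(",\"description\":\"").append(ann.description()).append('"');
            }
            props.append('}');
            if (ann == null || ann.required()) {
                required.add('"' + name + '"');
            }
        }
        return "{\"type\":\"object\",\"properties\":{" + props + "},\"required\":[" + String.join(",", required) + "]}";
    }

    private static String jsonType(Class<?> type) {
        if (type == String.class) return "string";
        if (type == int.class || type == Integer.class || type == long.class || type == Long.class) return "integer";
        if (type == double.class || type == Double.class) return "number";
        if (type == boolean.class || type == Boolean.class) return "boolean";
        return "object";
    }

    static Method findMethod(Class<?> type, String name) {
        for (Method m : type.getDeclaredMethods()) {
            if (m.getName().equals(name)) return m;
        }
        throw new IllegalArgumentException("No method named " + name);
    }
}

class CustomerTools {
    void updateCustomerInfo(@SketchParam(name = "id") Long id,
            @SketchParam(name = "name") String name,
            @SketchParam(name = "email", description = "Contact email", required = false) String email) {
    }
}
```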
Result Conversion
Spring AI provides ToolCallResultConverter to convert the data returned by a tool call before it is sent back to the AI model:
org/springframework/ai/tool/execution/ToolCallResultConverter.java
@FunctionalInterface
public interface ToolCallResultConverter {
/**
* Given an Object returned by a tool, convert it to a String compatible with the
* given class type.
*/
String convert(@Nullable Object result, @Nullable Type returnType);
}
It has a default implementation, DefaultToolCallResultConverter:
public final class DefaultToolCallResultConverter implements ToolCallResultConverter {
private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class);
@Override
public String convert(@Nullable Object result, @Nullable Type returnType) {
if (returnType == Void.TYPE) {
logger.debug("The tool has no return type. Converting to conventional response.");
return "Done";
}
else {
logger.debug("Converting tool result to JSON.");
return JsonParser.toJson(result);
}
}
}
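DefaultToolCallResultConverter returns "Done" when the return type is void. Otherwise it serializes the result to a JSON string with JsonParser.toJson(result).
You can also specify your own converter, for example: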
class CustomerTools {
@Tool(description = "Retrieve customer information", resultConverter = CustomToolCallResultConverter.class)
Customer getCustomerInfo(Long id) {
return customerRepository.findById(id);
}
}
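The default conversion rule can be modeled in a few lines. The following is a simplified stand-in for DefaultToolCallResultConverter with hypothetical names. Its toJson handles only null, strings, numbers, booleans, and flat maps without escaping, whereas the real class delegates to JsonParser.

```java
import java.lang.reflect.Type;
import java.util.Map;
import java.util.StringJoiner;

@FunctionalInterface
interface SketchResultConverter {
    String convert(Object result, Type returnType);
}

// Simplified stand-in: void -> "Done", everything else -> a minimal JSON rendering.
class SketchDefaultResultConverter implements SketchResultConverter {

    @Override
    public String convert(Object result, Type returnType) {
        if (returnType == Void.TYPE) {
            return "Done";
        }
        return toJson(result);
    }

    // Minimal JSON writer for the example only (no escaping, no nested objects besides maps).
    static String toJson(Object value) {
        if (value == null) return "null";
        if (value instanceof String s) return "\"" + s + "\"";
        if (value instanceof Number || value instanceof Boolean) return value.toString();
        if (value instanceof Map<?, ?> map) {
            StringJoiner joiner = new StringJoiner(",", "{", "}");
            map.forEach((k, v) -> joiner.add("\"" + k + "\":" + toJson(v)));
            return joiner.toString();
        }
        return toJson(value.toString());
    }
}
```

Because the interface is a @FunctionalInterface, a custom converter can be as small as a lambda such as (result, type) -> String.valueOf(result). FunctionToolCallback always passes null as the return type, so the "Done" branch only applies to method-based tools.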
Tool Context
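Spring AI provides ToolContext for passing additional contextual information to tools. It lets developers supply extra, user-provided data that the tool can use during execution alongside the arguments passed by the AI model. For example: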
class CustomerTools {
@Tool(description = "Retrieve customer information")
Customer getCustomerInfo(Long id, ToolContext toolContext) {
return customerRepository.findById(id, toolContext.getContext().get("tenantId"));
}
}
With ChatClient:
ChatModel chatModel = ...
String response = ChatClient.create(chatModel)
.prompt("Tell me more about the customer with ID 42")
.tools(new CustomerTools())
.toolContext(Map.of("tenantId", "acme"))
.call()
.content();
System.out.println(response);
With ChatModel:
ChatModel chatModel = ...
ToolCallback[] customerTools = ToolCallbacks.from(new CustomerTools());
ChatOptions chatOptions = ToolCallingChatOptions.builder()
.toolCallbacks(customerTools)
.toolContext(Map.of("tenantId", "acme"))
.build();
Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);
chatModel.call(prompt);
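The division of labor is the point of ToolContext. The model supplies the id, while the application supplies the tenant, which is not sent to the model. The following minimal sketch models this with a hypothetical SketchToolContext that keeps an unmodifiable copy of the map. TenantCustomerTools and Customer are illustrative names, and no repository is involved.

```java
import java.util.Map;

// Hypothetical immutable context holder, similar in spirit to ToolContext.getContext().
final class SketchToolContext {

    private final Map<String, Object> context;

    SketchToolContext(Map<String, Object> context) {
        this.context = Map.copyOf(context); // defensive, unmodifiable copy
    }

    Map<String, Object> getContext() {
        return context;
    }
}

record Customer(long id, String tenantId) {
}

class TenantCustomerTools {

    // `id` comes from the model's tool arguments; `tenantId` comes from the application.
    Customer getCustomerInfo(Long id, SketchToolContext toolContext) {
        String tenantId = (String) toolContext.getContext().get("tenantId");
        return new Customer(id, tenantId);
    }
}
```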
Return Direct
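Spring AI provides a returnDirect parameter. When it is set to true, the tool call result is returned directly to the caller instead of being sent back through the model. By default, the result is sent back to the AI model, which processes it before responding to the user.
For example: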
class CustomerTools {
@Tool(description = "Retrieve customer information", returnDirect = true)
Customer getCustomerInfo(Long id) {
return customerRepository.findById(id);
}
}
Or, when building a ToolCallback programmatically, set it through ToolMetadata:
ToolMetadata toolMetadata = ToolMetadata.builder()
.returnDirect(true)
.build();
ToolCallingManager
org/springframework/ai/model/tool/ToolCallingManager.java
public interface ToolCallingManager {
/**
* Resolve the tool definitions from the model's tool calling options.
*/
List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions);
/**
* Execute the tool calls requested by the model.
*/
ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse);
/**
* Create a default {@link ToolCallingManager} builder.
*/
static DefaultToolCallingManager.Builder builder() {
return DefaultToolCallingManager.builder();
}
}
ToolCallingManager defines the resolveToolDefinitions and executeToolCalls methods. Its default implementation is DefaultToolCallingManager.
DefaultToolCallingManager
org/springframework/ai/model/tool/DefaultToolCallingManager.java
public class DefaultToolCallingManager implements ToolCallingManager {
private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallingManager.class);
// @formatter:off
private static final ObservationRegistry DEFAULT_OBSERVATION_REGISTRY
= ObservationRegistry.NOOP;
private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER
= new DelegatingToolCallbackResolver(List.of());
private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
= DefaultToolExecutionExceptionProcessor.builder().build();
// @formatter:on
private final ObservationRegistry observationRegistry;
private final ToolCallbackResolver toolCallbackResolver;
private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;
public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");
this.observationRegistry = observationRegistry;
this.toolCallbackResolver = toolCallbackResolver;
this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
}
@Override
public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
Assert.notNull(chatOptions, "chatOptions cannot be null");
List<FunctionCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());
for (String toolName : chatOptions.getToolNames()) {
// Skip the tool if it is already present in the request toolCallbacks.
// That might happen if a tool is defined in the options
// both as a ToolCallback and as a tool name.
if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) {
continue;
}
FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
if (toolCallback == null) {
throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
}
toolCallbacks.add(toolCallback);
}
return toolCallbacks.stream().map(functionCallback -> {
if (functionCallback instanceof ToolCallback toolCallback) {
return toolCallback.getToolDefinition();
}
else {
return ToolDefinition.builder()
.name(functionCallback.getName())
.description(functionCallback.getDescription())
.inputSchema(functionCallback.getInputTypeSchema())
.build();
}
}).toList();
}
@Override
public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
Assert.notNull(prompt, "prompt cannot be null");
Assert.notNull(chatResponse, "chatResponse cannot be null");
Optional<Generation> toolCallGeneration = chatResponse.getResults()
.stream()
.filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
.findFirst();
if (toolCallGeneration.isEmpty()) {
throw new IllegalStateException("No tool call requested by the chat model");
}
AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();
ToolContext toolContext = buildToolContext(prompt, assistantMessage);
InternalToolExecutionResult internalToolExecutionResult = executeToolCall(prompt, assistantMessage,
toolContext);
List<Message> conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(),
assistantMessage, internalToolExecutionResult.toolResponseMessage());
return ToolExecutionResult.builder()
.conversationHistory(conversationHistory)
.returnDirect(internalToolExecutionResult.returnDirect())
.build();
}
//......
/**
* Execute the tool call and return the response message. To ensure backward
* compatibility, both {@link ToolCallback} and {@link FunctionCallback} are
* supported.
*/
private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
ToolContext toolContext) {
List<FunctionCallback> toolCallbacks = List.of();
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
toolCallbacks = toolCallingChatOptions.getToolCallbacks();
}
else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) {
toolCallbacks = functionOptions.getFunctionCallbacks();
}
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
Boolean returnDirect = null;
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
logger.debug("Executing tool call: {}", toolCall.name());
String toolName = toolCall.name();
String toolInputArguments = toolCall.arguments();
FunctionCallback toolCallback = toolCallbacks.stream()
.filter(tool -> toolName.equals(tool.getName()))
.findFirst()
.orElseGet(() -> toolCallbackResolver.resolve(toolName));
if (toolCallback == null) {
throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
}
if (returnDirect == null && toolCallback instanceof ToolCallback callback) {
returnDirect = callback.getToolMetadata().returnDirect();
}
else if (toolCallback instanceof ToolCallback callback) {
returnDirect = returnDirect && callback.getToolMetadata().returnDirect();
}
else if (returnDirect == null) {
// This is a temporary solution to ensure backward compatibility with
// FunctionCallback.
// TODO: remove this block when FunctionCallback is removed.
returnDirect = false;
}
String toolResult;
try {
toolResult = toolCallback.call(toolInputArguments, toolContext);
}
catch (ToolExecutionException ex) {
toolResult = toolExecutionExceptionProcessor.process(ex);
}
toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult));
}
return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
}
private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) {
List<Message> messages = new ArrayList<>(previousMessages);
messages.add(assistantMessage);
messages.add(toolResponseMessage);
return messages;
}
}
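In DefaultToolCallingManager, resolveToolDefinitions starts from chatOptions.getToolCallbacks(). It resolves each name in chatOptions.getToolNames() that is not already among those callbacks through toolCallbackResolver, throwing an IllegalStateException if nothing is found, and then maps every callback to a ToolDefinition. executeToolCalls picks the first generation whose assistant message contains tool calls and builds the ToolContext. It then calls executeToolCall to get the execution result, and from that builds the conversation history: the previous messages, followed by the assistant message and the tool response message.
executeToolCall iterates over assistantMessage.getToolCalls(). For each call, it looks up the tool by name among the tool callbacks in the prompt options, falling back to toolCallbackResolver.resolve(toolName), and obtains the result with toolCallback.call(toolInputArguments, toolContext). If a ToolExecutionException is thrown, toolExecutionExceptionProcessor.process(ex) handles it as a fallback. The returnDirect flags of the invoked ToolCallbacks are combined with a logical AND.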
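The name-resolution part of resolveToolDefinitions can be sketched independently of Spring AI. In this hypothetical model, a ToolDef record plays the role of both callback and definition, and a Function<String, ToolDef> stands in for ToolCallbackResolver.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

record ToolDef(String name, String description) {
}

class SketchDefinitionResolver {

    // Explicit callbacks first; then tool names resolved through the lookup,
    // skipping names already present and failing on unknown names.
    static List<ToolDef> resolveToolDefinitions(List<ToolDef> toolCallbacks, List<String> toolNames,
            Function<String, ToolDef> resolver) {
        List<ToolDef> result = new ArrayList<>(toolCallbacks);
        for (String toolName : toolNames) {
            if (toolCallbacks.stream().anyMatch(t -> t.name().equals(toolName))) {
                continue;
            }
            ToolDef resolved = resolver.apply(toolName);
            if (resolved == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }
            result.add(resolved);
        }
        return result;
    }
}
```

The skip step matters when the same tool is registered both as a callback and by name, because without it the model would see the tool twice.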
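The loop in executeToolCall combines three rules: an options-first lookup with a resolver fallback, AND-combination of returnDirect, and conversion of tool failures into text. The following is a hypothetical, self-contained model of those rules. ExecTool, ToolFailure, and the records are illustrative, not Spring AI types. Unlike the original, which returns a nullable Boolean, the sketch treats "no tool calls" as returnDirect = false.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

record ToolCall(String id, String name, String arguments) {
}

record ToolResponse(String id, String name, String responseData) {
}

record ExecutionResult(List<ToolResponse> responses, boolean returnDirect) {
}

// Hypothetical stand-in for ToolExecutionException.
class ToolFailure extends RuntimeException {
    ToolFailure(String message) {
        super(message);
    }
}

interface ExecTool {
    String name();
    boolean returnDirect();
    String call(String arguments);
}

class SketchToolExecutor {

    static ExecutionResult execute(List<ToolCall> toolCalls, List<ExecTool> optionTools,
            Function<String, ExecTool> resolver, Function<ToolFailure, String> exceptionProcessor) {
        List<ToolResponse> responses = new ArrayList<>();
        Boolean returnDirect = null;
        for (ToolCall toolCall : toolCalls) {
            // Prefer tools registered on the prompt options, then fall back to the resolver.
            ExecTool tool = optionTools.stream()
                    .filter(t -> t.name().equals(toolCall.name()))
                    .findFirst()
                    .orElseGet(() -> resolver.apply(toolCall.name()));
            if (tool == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolCall.name());
            }
            // returnDirect stays true only while every invoked tool asks for it.
            returnDirect = returnDirect == null ? tool.returnDirect() : returnDirect && tool.returnDirect();
            String result;
            try {
                result = tool.call(toolCall.arguments());
            }
            catch (ToolFailure ex) {
                // The failure becomes text for the model instead of aborting the loop.
                result = exceptionProcessor.apply(ex);
            }
            responses.add(new ToolResponse(toolCall.id(), toolCall.name(), result));
        }
        return new ExecutionResult(responses, Boolean.TRUE.equals(returnDirect));
    }

    // Convenience factory for the usage example.
    static ExecTool tool(String name, boolean returnDirect, Function<String, String> body) {
        return new ExecTool() {
            public String name() { return name; }
            public boolean returnDirect() { return returnDirect; }
            public String call(String arguments) { return body.apply(arguments); }
        };
    }
}
```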
ToolExecutionExceptionProcessor
org/springframework/ai/tool/execution/ToolExecutionExceptionProcessor.java
@FunctionalInterface
public interface ToolExecutionExceptionProcessor {
/**
* Convert an exception thrown by a tool to a String that can be sent back to the AI
* model or throw an exception to be handled by the caller.
*/
String process(ToolExecutionException exception);
}
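ToolExecutionExceptionProcessor defines the process method, which either converts the exception into a String for the AI model or throws an exception for the caller to handle.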
DefaultToolExecutionExceptionProcessor
public class DefaultToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor {
private final static Logger logger = LoggerFactory.getLogger(DefaultToolExecutionExceptionProcessor.class);
private static final boolean DEFAULT_ALWAYS_THROW = false;
private final boolean alwaysThrow;
public DefaultToolExecutionExceptionProcessor(boolean alwaysThrow) {
this.alwaysThrow = alwaysThrow;
}
@Override
public String process(ToolExecutionException exception) {
Assert.notNull(exception, "exception cannot be null");
if (alwaysThrow) {
throw exception;
}
logger.debug("Exception thrown by tool: {}. Message: {}", exception.getToolDefinition().name(),
exception.getMessage());
return exception.getMessage();
}
//......
}
DefaultToolExecutionExceptionProcessor rethrows the exception when alwaysThrow is true (the default is false); otherwise it returns the exception message.
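The alwaysThrow switch can be modeled as follows. SketchToolExecutionException and SketchExceptionProcessor are hypothetical names, and in this sketch the exception message is simply taken from the cause.

```java
// Hypothetical stand-in for ToolExecutionException; the message is taken from the cause.
class SketchToolExecutionException extends RuntimeException {

    private final String toolName;

    SketchToolExecutionException(String toolName, Throwable cause) {
        super(cause.getMessage(), cause);
        this.toolName = toolName;
    }

    String toolName() {
        return toolName;
    }
}

class SketchExceptionProcessor {

    private final boolean alwaysThrow;

    SketchExceptionProcessor(boolean alwaysThrow) {
        this.alwaysThrow = alwaysThrow;
    }

    // Either rethrow to the caller, or return the message so it can be sent to the model.
    String process(SketchToolExecutionException exception) {
        if (alwaysThrow) {
            throw exception;
        }
        return exception.getMessage();
    }
}
```

With the default (false), the model receives the error text as the tool result and can explain the failure or retry. With true, the failure surfaces in application code instead.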
User-Controlled Tool Execution
ToolCallingChatOptions provides the internalToolExecutionEnabled property. Setting it to false lets you control the tool execution process yourself (you can also implement ToolExecutionEligibilityPredicate for this). For example:
ChatModel chatModel = ...
ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();
ChatOptions chatOptions = ToolCallingChatOptions.builder()
.toolCallbacks(ToolCallbacks.from(new CustomerTools()))
.internalToolExecutionEnabled(false)
.build();
Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);
ChatResponse chatResponse = chatModel.call(prompt);
while (chatResponse.hasToolCalls()) {
ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);
prompt = new Prompt(toolExecutionResult.conversationHistory(), chatOptions);
chatResponse = chatModel.call(prompt);
}
System.out.println(chatResponse.getResult().getOutput().getText());
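Here the application executes the tool calls itself via toolCallingManager.executeToolCalls, then passes the resulting conversation history back to chatModel.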
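The control flow of that loop can be exercised without a real model. In this hypothetical sketch, ScriptedModel requests a tool until a tool message appears in the history and then answers. executeTool stands in for ToolCallingManager.executeToolCalls, and the history is extended in the same order as buildConversationHistoryAfterToolExecution: assistant message, then tool response.

```java
import java.util.ArrayList;
import java.util.List;

record Msg(String role, String text) {
}

// A reply either carries a tool request (toolCallName != null) or a final answer.
record Reply(Msg message, String toolCallName) {
    boolean hasToolCalls() {
        return toolCallName != null;
    }
}

// Hypothetical scripted "model": asks for a tool until a tool result is in the history.
class ScriptedModel {
    Reply call(List<Msg> history) {
        boolean sawTool = history.stream().anyMatch(m -> m.role().equals("tool"));
        if (!sawTool) {
            return new Reply(new Msg("assistant", ""), "getCustomerInfo");
        }
        String toolText = history.get(history.size() - 1).text();
        return new Reply(new Msg("assistant", "Customer details: " + toolText), null);
    }
}

class ManualToolLoop {

    static String run(ScriptedModel model, String userText) {
        List<Msg> history = new ArrayList<>(List.of(new Msg("user", userText)));
        Reply reply = model.call(history);
        while (reply.hasToolCalls()) {
            history.add(reply.message());
            history.add(new Msg("tool", executeTool(reply.toolCallName())));
            reply = model.call(history);
        }
        return reply.message().text();
    }

    // Hypothetical tool dispatch standing in for ToolCallingManager.executeToolCalls.
    static String executeTool(String name) {
        return name.equals("getCustomerInfo") ? "{\"id\":42,\"name\":\"Ada\"}" : "unknown tool";
    }
}
```

Owning the loop means the application can log, filter, or require approval for each tool call before it runs.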
ToolCallbackResolver
spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java
public interface ToolCallbackResolver {
/**
* Resolve the {@link FunctionCallback} for the given tool name.
*/
@Nullable
FunctionCallback resolve(String toolName);
}
ToolCallbackResolver defines the resolve method, which returns the FunctionCallback for a given toolName. It has three implementations: StaticToolCallbackResolver, SpringBeanToolCallbackResolver, and DelegatingToolCallbackResolver.
StaticToolCallbackResolver
spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java
public class StaticToolCallbackResolver implements ToolCallbackResolver {
private static final Logger logger = LoggerFactory.getLogger(StaticToolCallbackResolver.class);
private final Map<String, FunctionCallback> toolCallbacks = new HashMap<>();
public StaticToolCallbackResolver(List<FunctionCallback> toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
toolCallbacks.forEach(callback -> {
if (callback instanceof ToolCallback toolCallback) {
this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback);
}
this.toolCallbacks.put(callback.getName(), callback);
});
}
@Override
public FunctionCallback resolve(String toolName) {
Assert.hasText(toolName, "toolName cannot be null or empty");
logger.debug("ToolCallback resolution attempt from static registry");
return toolCallbacks.get(toolName);
}
}
StaticToolCallbackResolver looks up tools by name in the List<FunctionCallback> passed to its constructor, which it indexes into a map up front.
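A static resolver is essentially a name-indexed map built once. This hypothetical sketch uses a NamedTool record in place of FunctionCallback. Because entries are put into a HashMap in list order, a later callback with the same name replaces an earlier one, as in the quoted constructor. For a ToolCallback, the two put calls use the same key anyway, since getName() defaults to getToolDefinition().name().

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

record NamedTool(String name, String description) {
}

class SketchStaticResolver {

    private final Map<String, NamedTool> toolCallbacks = new HashMap<>();

    SketchStaticResolver(List<NamedTool> tools) {
        // Indexed once at construction; a later entry with the same name overwrites an earlier one.
        tools.forEach(t -> toolCallbacks.put(t.name(), t));
    }

    // Returns null for unknown names, leaving the decision to the caller.
    NamedTool resolve(String toolName) {
        if (toolName == null || toolName.isBlank()) {
            throw new IllegalArgumentException("toolName cannot be null or empty");
        }
        return toolCallbacks.get(toolName);
    }
}
```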
SpringBeanToolCallbackResolver
spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java
public class SpringBeanToolCallbackResolver implements ToolCallbackResolver {
private static final Logger logger = LoggerFactory.getLogger(SpringBeanToolCallbackResolver.class);
private static final Map<String, ToolCallback> toolCallbacksCache = new HashMap<>();
private static final SchemaType DEFAULT_SCHEMA_TYPE = SchemaType.JSON_SCHEMA;
private final GenericApplicationContext applicationContext;
private final SchemaType schemaType;
public SpringBeanToolCallbackResolver(GenericApplicationContext applicationContext,
@Nullable SchemaType schemaType) {
Assert.notNull(applicationContext, "applicationContext cannot be null");
this.applicationContext = applicationContext;
this.schemaType = schemaType != null ? schemaType : DEFAULT_SCHEMA_TYPE;
}
@Override
public ToolCallback resolve(String toolName) {
Assert.hasText(toolName, "toolName cannot be null or empty");
logger.debug("ToolCallback resolution attempt from Spring application context");
ToolCallback resolvedToolCallback = toolCallbacksCache.get(toolName);
if (resolvedToolCallback != null) {
return resolvedToolCallback;
}
ResolvableType toolType = TypeResolverHelper.resolveBeanType(applicationContext, toolName);
ResolvableType toolInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(toolType))
? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(toolType, 0);
String toolDescription = resolveToolDescription(toolName, toolInputType.toClass());
Object bean = applicationContext.getBean(toolName);
resolvedToolCallback = buildToolCallback(toolName, toolType, toolInputType, toolDescription, bean);
toolCallbacksCache.put(toolName, resolvedToolCallback);
return resolvedToolCallback;
}
//......
}
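SpringBeanToolCallbackResolver uses the GenericApplicationContext to look up the bean named toolName in the Spring container. It derives the input type and the description, builds a ToolCallback, and caches it in toolCallbacksCache.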
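The "look up once, then serve from cache" behavior can be sketched as follows. This is a hypothetical model: beanLookup is a plain function standing in for the application-context lookup, not a Spring API. Note that in the quoted class, toolCallbacksCache is a static field, so it is shared by all resolver instances. This sketch deliberately uses a per-instance ConcurrentHashMap instead.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

// Hypothetical cache-on-first-resolve wrapper around an arbitrary name lookup.
class CachingResolver<T> {

    private final Map<String, T> cache = new ConcurrentHashMap<>();
    private final Function<String, T> beanLookup;

    CachingResolver(Function<String, T> beanLookup) {
        this.beanLookup = beanLookup;
    }

    T resolve(String toolName) {
        // computeIfAbsent records no mapping when the lookup returns null,
        // so unknown names are looked up again on the next call.
        return cache.computeIfAbsent(toolName, beanLookup);
    }
}
```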
DelegatingToolCallbackResolver
spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java
public class DelegatingToolCallbackResolver implements ToolCallbackResolver {
private final List<ToolCallbackResolver> toolCallbackResolvers;
public DelegatingToolCallbackResolver(List<ToolCallbackResolver> toolCallbackResolvers) {
Assert.notNull(toolCallbackResolvers, "toolCallbackResolvers cannot be null");
Assert.noNullElements(toolCallbackResolvers, "toolCallbackResolvers cannot contain null elements");
this.toolCallbackResolvers = toolCallbackResolvers;
}
@Override
@Nullable
public FunctionCallback resolve(String toolName) {
Assert.hasText(toolName, "toolName cannot be null or empty");
for (ToolCallbackResolver toolCallbackResolver : toolCallbackResolvers) {
FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
if (toolCallback != null) {
return toolCallback;
}
}
return null;
}
}
DelegatingToolCallbackResolver delegates resolve to the toolCallbackResolvers passed to its constructor, returning the first non-null result.
Summary
Spring AI implements Tool Calling through ToolCallback, which extends the soon-to-be-deprecated FunctionCallback interface and mainly defines the getToolDefinition, getToolMetadata, and call methods. Its two basic implementations are MethodToolCallback and FunctionToolCallback.
Tool Specification covers Tool Callback, Tool Definition, JSON Schema, Result Conversion, Tool Context, and Return Direct.
Tool Execution covers Framework-Controlled Tool Execution, User-Controlled Tool Execution, and Exception Handling.
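Delegation is a first-match chain, so the order of the list expresses priority. The following hypothetical sketch (NameResolver and DelegatingNameResolver are illustrative names) resolves names to strings. One consequence is visible in the quoted code: the default resolver in DefaultToolCallingManager is a DelegatingToolCallbackResolver over an empty list. It therefore resolves nothing, and only callbacks set on the options can be used unless another resolver is configured.

```java
import java.util.List;

@FunctionalInterface
interface NameResolver {
    String resolve(String toolName); // null when this resolver does not know the name
}

class DelegatingNameResolver implements NameResolver {

    private final List<NameResolver> delegates;

    DelegatingNameResolver(List<NameResolver> delegates) {
        this.delegates = List.copyOf(delegates); // rejects null elements, like Assert.noNullElements
    }

    @Override
    public String resolve(String toolName) {
        for (NameResolver delegate : delegates) {
            String found = delegate.resolve(toolName);
            if (found != null) {
                return found; // first match wins
            }
        }
        return null;
    }
}
```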