Springboot3.0可以直接引入ai的starter包
springboot2.x需要通过接口的方式去实现
实现的效果如下

image.png
代码逻辑如下
application.properties文件
#ai
spring.ai.openai.api-key=您的令牌
spring.ai.openai.base-url=ai地址
spring.ai.openai.chat.options.model=hunyuan-lite
WebClientConfig 配置全局请求
/**
* 配置全局请求
*/
@Configuration
public class AiConfig {
@Value("${spring.ai.openai.api-key}")
private String apiKey;
@Value("${spring.ai.openai.base-url}")
private String baseUrl;
@Bean
public WebClient init() {
return WebClient.builder().baseUrl(baseUrl).defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey)
.defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).codecs(configurer -> {
configurer.defaultCodecs().maxInMemorySize(16 * 1024 * 1024); // 16MB
}).build();
}
}
AiController 控制层
@Valid
@Slf4j
@RestController
@RequiredArgsConstructor
@RequestMapping(value = "/api/ai")
@ApiSort(99)
public class AiController {
private final WebClient webClient;
private final ObjectMapper objectMapper;
// 构建系统提示词
public static final String SYSTEM_PROMPT = "你是一个数据库查询助手,可以根据用户的问题调用相应的查询函数。请根据用户的问题选择合适的函数,并提取必要的参数。如果用户的问题不明确,请要求用户提供更多信息。"
+ "重要规则:"
+ "1、当用户调用get_ticket_process和get_work_clockin函数时,参数格式必须是:YYYY-MM-DD至YYYY-MM-DD,如果用户使用相对日期(如昨天、今天、最近3天、本月、上月等),如果是今天(或昨天)则开始和结束都相同的日期范围,如果是月默认取第一天和最后一天,请参考当前时间将其转换为具体的日期范围。"
+ "2、当用户调用get_user_salary函数时,参数格式必须是:YYYY-MM,如果用户使用相对日期(如本月、上月等),请参考当前时间将其转换为具体的月份(格式:YYYY-MM)。";
@Value("${spring.ai.openai.chat.options.model}")
private String model;
// 简单的函数注册表
private final Map<String, AiFunctionExecutor> functionExecutors = new HashMap<>();
private final AiTicketProcessFunctionExecutor aiTicketProcessFunctionExecutor;
private final AiWorkClockinFunctionExecutor aiWorkClockinFunctionExecutor;
private final AiUserSalaryFunctionExecutor aiUserSalaryFunctionExecutor;
@PostConstruct
public void init() {
// 注册测试函数
functionExecutors.put("get_ticket_process", aiTicketProcessFunctionExecutor);
functionExecutors.put("get_work_clockin", aiWorkClockinFunctionExecutor);
functionExecutors.put("get_user_salary", aiUserSalaryFunctionExecutor);
}
/**
* 构建函数定义
*/
private List<Map<String, Object>> buildFunctionDefinitions() {
List<Map<String, Object>> tools = new ArrayList<>();
// 计件记录
tools.add(createToolDefinition("get_ticket_process", "获取当前用户某时间段内的计件数量、金额。",
Map.of("type", "object", "properties",
Map.of("date", Map.of("type", "string", "description", "日期范围,格式如:YYYY-MM-DD至YYYY-MM-DD")),
"required", List.of("date"))));
// 计时记录
tools.add(createToolDefinition("get_work_clockin", "获取当前用户某时间段内的计时数量、金额。",
Map.of("type", "object", "properties",
Map.of("date", Map.of("type", "string", "description", "日期范围,格式如:YYYY-MM-DD至YYYY-MM-DD")),
"required", List.of("date"))));
// 工资信息
tools.add(createToolDefinition("get_user_salary", "获取当前用户本月或上月的工资薪酬信息。",
Map.of("type", "object", "properties",
Map.of("month", Map.of("type", "string", "description", "月份,格式如:YYYY-MM")), "required",
List.of("month"))));
return tools;
}
/**
* 使用ResponseBodyEmitter的SSE接口
*/
@SkipResponseWrap
@GetMapping(value = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public ResponseBodyEmitter streamChat(@RequestParam("prompt") String prompt) {
// log.info("收到聊天请求,prompt: {}", prompt);
String teamCode = ContextHolder.ctx().getTeamCode();
String userCode = ContextHolder.ctx().getUserCode();
ContextTeamUserVo ctu = ContextHolder.ctx().getCtu();
ResponseBodyEmitter emitter = new ResponseBodyEmitter();
CompletableFuture.runAsync(() -> {
try {
// 构建请求体
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("model", model);
requestBody.put("stream", false);
// 构建函数定义
List<Map<String, Object>> tools = buildFunctionDefinitions();
requestBody.put("tools", tools);
requestBody.put("tool_choice", "auto");
List<Map<String, Object>> messages = new ArrayList<>();
messages.add(Map.of("role", "system", "content", SYSTEM_PROMPT + "\n 当前时间:" + new Date()));
messages.add(Map.of("role", "user", "content", prompt));
requestBody.put("messages", messages);
String requestBodyJson = objectMapper.writeValueAsString(requestBody);
// log.info("发送OpenAI请求: {}", requestBodyJson);
// 发送请求
String resp = webClient.post().uri("/chat/completions").bodyValue(requestBodyJson).retrieve()
.bodyToMono(String.class).block();
// log.info("收到OpenAI响应: {}", resp);
// 处理响应
String result = processOpenAIResponse(resp, ctu, teamCode, userCode);
// 发送SSE格式数据,明确指定UTF-8编码
String sseData = "data: " + result + "\n\n";
MediaType mediaType = MediaType.parseMediaType("text/event-stream;charset=UTF-8");
emitter.send(sseData, mediaType);
emitter.complete();
} catch (Exception e) {
log.error("处理请求失败", e);
try {
String errorMessage = "data: 处理请求时出现错误: " + e.getMessage() + "\n\n";
MediaType mediaType = MediaType.parseMediaType("text/event-stream;charset=UTF-8");
emitter.send(errorMessage, mediaType);
emitter.complete();
} catch (IOException ex) {
emitter.completeWithError(ex);
}
}
});
return emitter;
}
/**
* 处理OpenAI响应,支持Function Calling
*
* @param ctu
* @param userCode
* @param teamCode
*/
private String processOpenAIResponse(String responseJson, ContextTeamUserVo ctu, String teamCode, String userCode)
throws Exception {
@SuppressWarnings("unchecked")
Map<String, Object> response = objectMapper.readValue(responseJson, Map.class);
@SuppressWarnings("unchecked")
List<Map<String, Object>> choices = (List<Map<String, Object>>) response.get("choices");
if (choices == null || choices.isEmpty()) {
return "未收到有效响应";
}
Map<String, Object> choice = choices.get(0);
@SuppressWarnings("unchecked")
Map<String, Object> message = (Map<String, Object>) choice.get("message");
if (message == null) {
return "响应格式异常";
}
// 检查是否有工具调用(新的格式)
@SuppressWarnings("unchecked")
List<Map<String, Object>> toolCalls = (List<Map<String, Object>>) message.get("tool_calls");
if (toolCalls != null && !toolCalls.isEmpty()) {
// 取第一个工具调用
Map<String, Object> toolCall = toolCalls.get(0);
@SuppressWarnings("unchecked")
Map<String, Object> function = (Map<String, Object>) toolCall.get("function");
String functionName = (String) function.get("name");
String argumentsJson = (String) function.get("arguments");
// log.info("检测到工具调用: {}, 参数: {}", functionName, argumentsJson);
// 执行函数并返回结果
return executeFunctionAndRespond(functionName, argumentsJson, ctu, teamCode, userCode);
} else {
// 没有工具调用,直接返回内容
String content = (String) message.get("content");
// log.info("无工具调用: {}", content);
return content != null ? content : "未获得有效回复";
}
}
/**
* 执行函数并生成最终响应
*
* @param ctu
* @param userCode
* @param teamCode
*/
private String executeFunctionAndRespond(String functionName, String argumentsJson, ContextTeamUserVo ctu,
String teamCode, String userCode) throws Exception {
@SuppressWarnings("unchecked")
Map<String, Object> arguments = objectMapper.readValue(argumentsJson, Map.class);
arguments.put("workCode", ctu.getWorkCode());
arguments.put("teamCode", teamCode);
arguments.put("userCode", userCode);
// 执行函数
AiFunctionExecutor executor = functionExecutors.get(functionName);
if (executor == null) {
return "未知的函数: " + functionName;
}
Object functionResult = executor.execute(arguments);
String resultStr = functionResult.toString();
// log.info("函数执行结果: {}", resultStr);
// 直接返回函数执行结果(简化处理,不进行二次AI调用)
return "根据您的请求,我查询到以下信息:\n\n" + resultStr;
}
/**
* AI工具方法定义
*
* @param name 名称
* @param description 说明
* @param parameters 参数
* @return
*/
private Map<String, Object> createToolDefinition(String name, String description, Map<String, Object> parameters) {
Map<String, Object> tool = new HashMap<>();
tool.put("Type", "function");
Map<String, Object> function = new HashMap<>();
function.put("Name", name);
function.put("Description", description);
try {
// 将参数转换为JSON字符串
String parametersJson = objectMapper.writeValueAsString(parameters);
function.put("Parameters", parametersJson);
} catch (Exception e) {
log.error("参数序列化失败", e);
function.put("Parameters", "{}");
}
tool.put("Function", function);
return tool;
}
}
定义AI工具的接口类
/**
* Ai方法接口
*/
public interface AiFunctionExecutor {
/**
* 执行方法
*
* @param arguments 参数
* @return
*/
Object execute(Map<String, Object> arguments);
}
实现其中一个接口
@Slf4j
@RequiredArgsConstructor
@Transactional(readOnly = true)
@Service
public class AiUserSalaryFunctionExecutor implements AiFunctionExecutor {
private final SalaryRecordService salaryRecordService;
@Override
public Object execute(Map<String, Object> arguments) {
String month = (String) arguments.get("month");
if (month == null || month.trim().isEmpty() || !month.contains("-")) {
return "请提供查询工资的月份";
}
String workCode = (String) arguments.get("workCode");
String teamCode = (String) arguments.get("teamCode");
List<UserSalaryListVo> list = salaryRecordService.selectUserSalaryList(teamCode, month, workCode);
if (list == null || list.isEmpty()) {
StringBuilder sb = new StringBuilder();
sb.append(" 在").append(month).append("期间,没有找到记录。");
return sb.toString();
}
try {
return formatWeatherResponse(month, list.get(0));
} catch (Exception e) {
return "获取工资信息失败";
}
}
private String formatWeatherResponse(String month, UserSalaryListVo vo) {
StringBuilder sb = new StringBuilder();
sb.append(" 在").append(month).append("期间,工资记录汇总如下:\n");
sb.append(" 计件工资:").append(vo.getProcessPrice().doubleValue()).append("\n");
sb.append(" 固定计件:").append(vo.getPiecePrice().doubleValue()).append("\n");
sb.append(" 计时工资:").append(vo.getClockinPrice().doubleValue()).append("\n");
sb.append(" 奖励金额:").append(vo.getRewardPrice().doubleValue()).append("\n");
sb.append(" 扣款金额:").append(vo.getPunishPrice().doubleValue()).append("\n");
sb.append(" 合计薪资:").append(vo.getTotalPrice().doubleValue()).append("\n");
sb.append(" 已发薪资:").append(vo.getSalaryPrice().doubleValue()).append("\n");
sb.append(" 待发薪资:").append(vo.getRemainingPrice().doubleValue()).append("\n");
sb.append(" 如需查询详情,请前往菜单:[功能-计件计时-我的薪酬]查看。");
return sb.toString();
}
}
其他的类似 ,对接其他接口即可