Springboot 2.x对接AI连接自己的接口和数据库

Springboot3.0可以直接引入ai的starter包
springboot2.x需要通过接口的方式去实现

实现的效果如下

image.png

代码逻辑如下

application.properties文件

#ai
spring.ai.openai.api-key=您的令牌
spring.ai.openai.base-url=ai地址
spring.ai.openai.chat.options.model=hunyuan-lite

WebClientConfig 配置全局请求

/**
 * 配置全局请求
 */
@Configuration
public class AiConfig {

    @Value("${spring.ai.openai.api-key}")
    private String apiKey;

    @Value("${spring.ai.openai.base-url}")
    private String baseUrl;

    @Bean
    public WebClient init() {
        return WebClient.builder().baseUrl(baseUrl).defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey)
                .defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).codecs(configurer -> {
                    configurer.defaultCodecs().maxInMemorySize(16 * 1024 * 1024); // 16MB
                }).build();
    }
}

AiController 控制层

@Valid
@Slf4j
@RestController
@RequiredArgsConstructor
@RequestMapping(value = "/api/ai")
@ApiSort(99)
public class AiController {

    private final WebClient webClient;
    private final ObjectMapper objectMapper;

    // 构建系统提示词
    public static final String SYSTEM_PROMPT = "你是一个数据库查询助手,可以根据用户的问题调用相应的查询函数。请根据用户的问题选择合适的函数,并提取必要的参数。如果用户的问题不明确,请要求用户提供更多信息。"
            + "重要规则:"
            + "1、当用户调用get_ticket_process和get_work_clockin函数时,参数格式必须是:YYYY-MM-DD至YYYY-MM-DD,如果用户使用相对日期(如昨天、今天、最近3天、本月、上月等),如果是今天(或昨天)则开始和结束都相同的日期范围,如果是月默认取第一天和最后一天,请参考当前时间将其转换为具体的日期范围。"
            + "2、当用户调用get_user_salary函数时,参数格式必须是:YYYY-MM,如果用户使用相对日期(如本月、上月等),请参考当前时间将其转换为具体的月份(格式:YYYY-MM)。";

    @Value("${spring.ai.openai.chat.options.model}")
    private String model;

    // 简单的函数注册表
    private final Map<String, AiFunctionExecutor> functionExecutors = new HashMap<>();

    private final AiTicketProcessFunctionExecutor aiTicketProcessFunctionExecutor;
    private final AiWorkClockinFunctionExecutor aiWorkClockinFunctionExecutor;
    private final AiUserSalaryFunctionExecutor aiUserSalaryFunctionExecutor;

    @PostConstruct
    public void init() {
        // 注册测试函数
        functionExecutors.put("get_ticket_process", aiTicketProcessFunctionExecutor);
        functionExecutors.put("get_work_clockin", aiWorkClockinFunctionExecutor);
        functionExecutors.put("get_user_salary", aiUserSalaryFunctionExecutor);

    }

    /**
     * 构建函数定义
     */
    private List<Map<String, Object>> buildFunctionDefinitions() {
        List<Map<String, Object>> tools = new ArrayList<>();
        // 计件记录
        tools.add(createToolDefinition("get_ticket_process", "获取当前用户某时间段内的计件数量、金额。",
                Map.of("type", "object", "properties",
                        Map.of("date", Map.of("type", "string", "description", "日期范围,格式如:YYYY-MM-DD至YYYY-MM-DD")),
                        "required", List.of("date"))));
        // 计时记录
        tools.add(createToolDefinition("get_work_clockin", "获取当前用户某时间段内的计时数量、金额。",
                Map.of("type", "object", "properties",
                        Map.of("date", Map.of("type", "string", "description", "日期范围,格式如:YYYY-MM-DD至YYYY-MM-DD")),
                        "required", List.of("date"))));
        // 工资信息
        tools.add(createToolDefinition("get_user_salary", "获取当前用户本月或上月的工资薪酬信息。",
                Map.of("type", "object", "properties",
                        Map.of("month", Map.of("type", "string", "description", "月份,格式如:YYYY-MM")), "required",
                        List.of("month"))));
        return tools;
    }

    /**
     * 使用ResponseBodyEmitter的SSE接口
     */
    @SkipResponseWrap
    @GetMapping(value = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public ResponseBodyEmitter streamChat(@RequestParam("prompt") String prompt) {
//      log.info("收到聊天请求,prompt: {}", prompt);
        String teamCode = ContextHolder.ctx().getTeamCode();
        String userCode = ContextHolder.ctx().getUserCode();
        ContextTeamUserVo ctu = ContextHolder.ctx().getCtu();

        ResponseBodyEmitter emitter = new ResponseBodyEmitter();
        CompletableFuture.runAsync(() -> {
            try {
                // 构建请求体
                Map<String, Object> requestBody = new HashMap<>();
                requestBody.put("model", model);
                requestBody.put("stream", false);
                // 构建函数定义
                List<Map<String, Object>> tools = buildFunctionDefinitions();
                requestBody.put("tools", tools);
                requestBody.put("tool_choice", "auto");

                List<Map<String, Object>> messages = new ArrayList<>();
                messages.add(Map.of("role", "system", "content", SYSTEM_PROMPT + "\n 当前时间:" + new Date()));
                messages.add(Map.of("role", "user", "content", prompt));
                requestBody.put("messages", messages);

                String requestBodyJson = objectMapper.writeValueAsString(requestBody);
//              log.info("发送OpenAI请求: {}", requestBodyJson);

                // 发送请求
                String resp = webClient.post().uri("/chat/completions").bodyValue(requestBodyJson).retrieve()
                        .bodyToMono(String.class).block();
//              log.info("收到OpenAI响应: {}", resp);

                // 处理响应
                String result = processOpenAIResponse(resp, ctu, teamCode, userCode);
                // 发送SSE格式数据,明确指定UTF-8编码
                String sseData = "data: " + result + "\n\n";

                MediaType mediaType = MediaType.parseMediaType("text/event-stream;charset=UTF-8");
                emitter.send(sseData, mediaType);
                emitter.complete();

            } catch (Exception e) {
                log.error("处理请求失败", e);
                try {
                    String errorMessage = "data: 处理请求时出现错误: " + e.getMessage() + "\n\n";
                    MediaType mediaType = MediaType.parseMediaType("text/event-stream;charset=UTF-8");
                    emitter.send(errorMessage, mediaType);
                    emitter.complete();
                } catch (IOException ex) {
                    emitter.completeWithError(ex);
                }
            }
        });

        return emitter;
    }

    /**
     * 处理OpenAI响应,支持Function Calling
     * 
     * @param ctu
     * @param userCode
     * @param teamCode
     */
    private String processOpenAIResponse(String responseJson, ContextTeamUserVo ctu, String teamCode, String userCode)
            throws Exception {
        @SuppressWarnings("unchecked")
        Map<String, Object> response = objectMapper.readValue(responseJson, Map.class);
        @SuppressWarnings("unchecked")
        List<Map<String, Object>> choices = (List<Map<String, Object>>) response.get("choices");

        if (choices == null || choices.isEmpty()) {
            return "未收到有效响应";
        }

        Map<String, Object> choice = choices.get(0);
        @SuppressWarnings("unchecked")
        Map<String, Object> message = (Map<String, Object>) choice.get("message");

        if (message == null) {
            return "响应格式异常";
        }

        // 检查是否有工具调用(新的格式)
        @SuppressWarnings("unchecked")
        List<Map<String, Object>> toolCalls = (List<Map<String, Object>>) message.get("tool_calls");
        if (toolCalls != null && !toolCalls.isEmpty()) {
            // 取第一个工具调用
            Map<String, Object> toolCall = toolCalls.get(0);
            @SuppressWarnings("unchecked")
            Map<String, Object> function = (Map<String, Object>) toolCall.get("function");
            String functionName = (String) function.get("name");
            String argumentsJson = (String) function.get("arguments");
//          log.info("检测到工具调用: {}, 参数: {}", functionName, argumentsJson);
            // 执行函数并返回结果
            return executeFunctionAndRespond(functionName, argumentsJson, ctu, teamCode, userCode);
        } else {
            // 没有工具调用,直接返回内容
            String content = (String) message.get("content");
//          log.info("无工具调用: {}", content);
            return content != null ? content : "未获得有效回复";
        }
    }

    /**
     * 执行函数并生成最终响应
     * 
     * @param ctu
     * @param userCode
     * @param teamCode
     */
    private String executeFunctionAndRespond(String functionName, String argumentsJson, ContextTeamUserVo ctu,
            String teamCode, String userCode) throws Exception {
        @SuppressWarnings("unchecked")
        Map<String, Object> arguments = objectMapper.readValue(argumentsJson, Map.class);
        arguments.put("workCode", ctu.getWorkCode());
        arguments.put("teamCode", teamCode);
        arguments.put("userCode", userCode);
        // 执行函数
        AiFunctionExecutor executor = functionExecutors.get(functionName);
        if (executor == null) {
            return "未知的函数: " + functionName;
        }

        Object functionResult = executor.execute(arguments);
        String resultStr = functionResult.toString();

//      log.info("函数执行结果: {}", resultStr);

        // 直接返回函数执行结果(简化处理,不进行二次AI调用)
        return "根据您的请求,我查询到以下信息:\n\n" + resultStr;
    }

    /**
     * AI工具方法定义
     * 
     * @param name        名称
     * @param description 说明
     * @param parameters  参数
     * @return
     */
    private Map<String, Object> createToolDefinition(String name, String description, Map<String, Object> parameters) {
        Map<String, Object> tool = new HashMap<>();
        tool.put("Type", "function");
        Map<String, Object> function = new HashMap<>();
        function.put("Name", name);
        function.put("Description", description);
        try {
            // 将参数转换为JSON字符串
            String parametersJson = objectMapper.writeValueAsString(parameters);
            function.put("Parameters", parametersJson);
        } catch (Exception e) {
            log.error("参数序列化失败", e);
            function.put("Parameters", "{}");
        }
        tool.put("Function", function);
        return tool;
    }

}

定义AI工具的接口类

/**
 * Ai方法接口
 */
public interface AiFunctionExecutor {
    /**
     * 执行方法
     * 
     * @param arguments 参数
     * @return
     */
    Object execute(Map<String, Object> arguments);
}

实现其中一个接口


@Slf4j
@RequiredArgsConstructor
@Transactional(readOnly = true)
@Service
public class AiUserSalaryFunctionExecutor implements AiFunctionExecutor {
    private final SalaryRecordService salaryRecordService;

    @Override
    public Object execute(Map<String, Object> arguments) {
        String month = (String) arguments.get("month");
        if (month == null || month.trim().isEmpty() || !month.contains("-")) {
            return "请提供查询工资的月份";
        }
        String workCode = (String) arguments.get("workCode");
        String teamCode = (String) arguments.get("teamCode");

        List<UserSalaryListVo> list = salaryRecordService.selectUserSalaryList(teamCode, month, workCode);

        if (list == null || list.isEmpty()) {
            StringBuilder sb = new StringBuilder();
            sb.append("   在").append(month).append("期间,没有找到记录。");
            return sb.toString();
        }
        try {

            return formatWeatherResponse(month, list.get(0));
        } catch (Exception e) {
            return "获取工资信息失败";
        }
    }

    private String formatWeatherResponse(String month, UserSalaryListVo vo) {
        StringBuilder sb = new StringBuilder();
        sb.append("   在").append(month).append("期间,工资记录汇总如下:\n");
        sb.append("   计件工资:").append(vo.getProcessPrice().doubleValue()).append("\n");
        sb.append("   固定计件:").append(vo.getPiecePrice().doubleValue()).append("\n");
        sb.append("   计时工资:").append(vo.getClockinPrice().doubleValue()).append("\n");
        sb.append("   奖励金额:").append(vo.getRewardPrice().doubleValue()).append("\n");
        sb.append("   扣款金额:").append(vo.getPunishPrice().doubleValue()).append("\n");

        sb.append("   合计薪资:").append(vo.getTotalPrice().doubleValue()).append("\n");
        sb.append("   已发薪资:").append(vo.getSalaryPrice().doubleValue()).append("\n");
        sb.append("   待发薪资:").append(vo.getRemainingPrice().doubleValue()).append("\n");

        sb.append("   如需查询详情,请前往菜单:[功能-计件计时-我的薪酬]查看。");
        return sb.toString();
    }

}

其他的类似 ,对接其他接口即可

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容