# 2025-02-18

# embedding_adapters.py

# -*- coding: utf-8 -*-

import logging

import requests

import traceback

from typing import List

from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings

def ensure_openai_base_url_has_v1(url: str) -> str:
    """
    Append '/v1' to *url* when it lacks one.

    The suffix is added only when the URL neither ends in a version
    segment (e.g. '/v2') nor already contains '/v1' anywhere in it.
    """
    import re

    cleaned = url.strip()
    if not cleaned:
        return cleaned

    ends_with_version = re.search(r'/v\d+$', cleaned) is not None
    if not ends_with_version and '/v1' not in cleaned:
        cleaned = cleaned.rstrip('/') + '/v1'
    return cleaned

class BaseEmbeddingAdapter:
    """Unified base class for all embedding backends."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts; concrete adapters must override."""
        raise NotImplementedError

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string; concrete adapters must override."""
        raise NotImplementedError

class OpenAIEmbeddingAdapter(BaseEmbeddingAdapter):
    """
    Adapter for OpenAI-compatible embedding endpoints, including locally
    deployed servers that implement the same REST API.
    """

    # Fallback embedding width returned when a request fails
    # (matches OpenAI's text-embedding-ada-002 dimensionality).
    _FALLBACK_DIM = 1536

    def __init__(self, api_key: str, base_url: str, model_name: str):
        """
        :param api_key: key sent as a Bearer token; may be empty for local
                        servers that do not enforce authentication
        :param base_url: endpoint root, e.g. "https://api.openai.com/v1"
        :param model_name: embedding model identifier
        """
        # Fix: the original accepted api_key but never used it, so any
        # auth-protected endpoint would reject every request.
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model_name = model_name

    def _headers(self) -> dict:
        """Request headers; Authorization is included only when a key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _post_embeddings(self, payload: dict) -> dict:
        """POST *payload* to the /embeddings endpoint and return parsed JSON."""
        response = requests.post(
            f"{self.base_url}/embeddings",
            json=payload,
            headers=self._headers(),
            timeout=60,  # fix: no timeout could hang the caller indefinitely
        )
        response.raise_for_status()
        return response.json()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of texts.

        :param texts: input strings
        :return: one vector per input; zero vectors on failure (preserves
                 the original best-effort behavior instead of raising)
        """
        try:
            result = self._post_embeddings({"model": self.model_name, "input": texts})
            return [item["embedding"] for item in result["data"]]
        except Exception as e:
            logging.error(f"Embedding request error: {e}\n{traceback.format_exc()}")
            return [[0.0] * self._FALLBACK_DIM] * len(texts)  # 返回空向量

    def embed_query(self, query: str) -> List[float]:
        """
        Embed a single query string.

        :param query: input text
        :return: embedding vector; a zero vector on failure
        """
        try:
            result = self._post_embeddings({"model": self.model_name, "input": query})
            return result["data"][0]["embedding"]
        except Exception as e:
            logging.error(f"Embedding request error: {e}\n{traceback.format_exc()}")
            return [0.0] * self._FALLBACK_DIM  # 返回空向量

class AzureOpenAIEmbeddingAdapter(BaseEmbeddingAdapter):
    """
    Adapter backed by AzureOpenAIEmbeddings (or a compatible interface).

    The full Azure endpoint URL is decomposed into its endpoint host,
    deployment name and api-version components.
    """

    # Expected shape of base_url, e.g.
    # https://<resource>.openai.azure.com/openai/deployments/<dep>/embeddings?api-version=<ver>
    _URL_PATTERN = r'https://(.+?)/openai/deployments/(.+?)/embeddings\?api-version=(.+)'

    def __init__(self, api_key: str, base_url: str, model_name: str):
        import re

        parsed = re.match(self._URL_PATTERN, base_url)
        if parsed is None:
            raise ValueError("Invalid Azure OpenAI base_url format")

        host, deployment, version = parsed.groups()
        self.azure_endpoint = f"https://{host}"
        self.azure_deployment = deployment
        self.api_version = version

        self._embedding = AzureOpenAIEmbeddings(
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            openai_api_key=api_key,
            api_version=self.api_version,
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Delegate batch embedding to the wrapped Azure client."""
        return self._embedding.embed_documents(texts)

    def embed_query(self, query: str) -> List[float]:
        """Delegate single-query embedding to the wrapped Azure client."""
        return self._embedding.embed_query(query)

class OllamaEmbeddingAdapter(BaseEmbeddingAdapter):
    """
    Adapter for a local Ollama server; its endpoint path is /api/embeddings.
    """

    def __init__(self, model_name: str, base_url: str):
        self.model_name = model_name
        self.base_url = base_url.rstrip("/")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed texts one at a time (no batch endpoint is used here)."""
        return [self._embed_single(item) for item in texts]

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string."""
        return self._embed_single(query)

    def _resolve_url(self) -> str:
        """Normalize base_url into a full /api/embeddings endpoint URL."""
        endpoint = self.base_url.rstrip("/")
        if "/api/embeddings" in endpoint:
            return endpoint
        if "/api" in endpoint:
            return f"{endpoint}/embeddings"
        # An OpenAI-style '/v1' suffix is dropped before appending the
        # native Ollama path.
        if "/v1" in endpoint:
            endpoint = endpoint[:endpoint.index("/v1")]
        return f"{endpoint}/api/embeddings"

    def _embed_single(self, text: str) -> List[float]:
        """
        POST one text to the Ollama /api/embeddings endpoint.

        :return: the embedding vector, or [] when the HTTP request fails.
        :raises ValueError: when the response lacks an 'embedding' field
                            (deliberately not swallowed, as before).
        """
        payload = {
            "model": self.model_name,
            "prompt": text
        }
        try:
            response = requests.post(self._resolve_url(), json=payload)
            response.raise_for_status()
            result = response.json()
            if "embedding" not in result:
                raise ValueError("No 'embedding' field in Ollama response.")
            return result["embedding"]
        except requests.exceptions.RequestException as e:
            logging.error(f"Ollama embeddings request error: {e}\n{traceback.format_exc()}")
            return []

class MLStudioEmbeddingAdapter(BaseEmbeddingAdapter):
    """
    Adapter for a locally deployed ML Studio server exposing an
    OpenAI-style /embeddings endpoint.
    """

    # Fallback embedding width returned when a request fails.
    _FALLBACK_DIM = 1536

    def __init__(self, api_key: str, base_url: str, model_name: str):
        """
        :param api_key: key sent as a Bearer token; may be empty for local
                        servers that do not enforce authentication
        :param base_url: endpoint root of the local server
        :param model_name: embedding model identifier
        """
        # Fix: the original accepted api_key but never used it.
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model_name = model_name

    def _headers(self) -> dict:
        """Request headers; Authorization is included only when a key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _post_embeddings(self, payload: dict) -> dict:
        """POST *payload* to the /embeddings endpoint and return parsed JSON."""
        response = requests.post(
            f"{self.base_url}/embeddings",
            json=payload,
            headers=self._headers(),
            timeout=60,  # fix: no timeout could hang the caller indefinitely
        )
        response.raise_for_status()
        return response.json()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of texts.

        :param texts: input strings
        :return: one vector per input; zero vectors on failure (preserves
                 the original best-effort behavior instead of raising)
        """
        try:
            result = self._post_embeddings({"model": self.model_name, "input": texts})
            return [item["embedding"] for item in result["data"]]
        except Exception as e:
            logging.error(f"ML Studio embedding request error: {e}\n{traceback.format_exc()}")
            return [[0.0] * self._FALLBACK_DIM] * len(texts)  # 返回空向量

    def embed_query(self, query: str) -> List[float]:
        """
        Embed a single query string.

        :param query: input text
        :return: embedding vector; a zero vector on failure
        """
        try:
            result = self._post_embeddings({"model": self.model_name, "input": query})
            return result["data"][0]["embedding"]
        except Exception as e:
            logging.error(f"ML Studio embedding request error: {e}\n{traceback.format_exc()}")
            return [0.0] * self._FALLBACK_DIM  # 返回空向量

class GeminiEmbeddingAdapter(BaseEmbeddingAdapter):
    """
    Embedding adapter for the Google Generative AI (Gemini) REST API.

    Posts directly to a URL of the form:
    https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key=YOUR_API_KEY
    """

    def __init__(self, api_key: str, model_name: str, base_url: str):
        """
        :param api_key: Google API key
        :param model_name: usually "text-embedding-004"
        :param base_url: e.g. https://generativelanguage.googleapis.com/v1beta/models
        """
        self.api_key = api_key
        self.model_name = model_name
        self.base_url = base_url.rstrip("/")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed texts one at a time (the endpoint takes a single content)."""
        return [self._embed_single(text) for text in texts]

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string."""
        return self._embed_single(query)

    def _embed_single(self, text: str) -> List[float]:
        """
        Call the Gemini embedContent endpoint for one text.

        :return: the embedding vector, or [] on request/parse failure.
        """
        # NOTE(review): the key travels as a query parameter and may end up
        # in access logs; the 'x-goog-api-key' header would avoid that.
        url = f"{self.base_url}/{self.model_name}:embedContent?key={self.api_key}"
        payload = {
            "model": self.model_name,
            "content": {
                "parts": [
                    {"text": text}
                ]
            }
        }
        try:
            # Fix: removed a stray debug print(response.text) that dumped
            # every raw API response (and error diagnostics) to stdout.
            response = requests.post(url, json=payload, timeout=60)
            response.raise_for_status()
            result = response.json()
            embedding_data = result.get("embedding", {})
            return embedding_data.get("values", [])
        except requests.exceptions.RequestException as e:
            logging.error(f"Gemini embed_content request error: {e}\n{traceback.format_exc()}")
            return []
        except Exception as e:
            logging.error(f"Gemini embed_content parse error: {e}\n{traceback.format_exc()}")
            return []

def create_embedding_adapter(
    interface_format: str,
    api_key: str,
    base_url: str,
    model_name: str
) -> BaseEmbeddingAdapter:
    """
    Factory: build the embedding adapter matching *interface_format*.

    Matching is case-insensitive and ignores surrounding whitespace.

    :raises ValueError: when interface_format is not a known backend
    """
    fmt = interface_format.strip().lower()
    builders = {
        "openai": lambda: OpenAIEmbeddingAdapter(api_key, base_url, model_name),
        "azure openai": lambda: AzureOpenAIEmbeddingAdapter(api_key, base_url, model_name),
        "ollama": lambda: OllamaEmbeddingAdapter(model_name, base_url),
        "ml studio": lambda: MLStudioEmbeddingAdapter(api_key, base_url, model_name),
        "gemini": lambda: GeminiEmbeddingAdapter(api_key, model_name, base_url),
    }
    builder = builders.get(fmt)
    if builder is None:
        raise ValueError(f"Unknown embedding interface_format: {interface_format}")
    return builder()

# ©著作权归作者所有,转载或内容合作请联系作者
# 平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

# 推荐阅读更多精彩内容