# embedding_adapters.py
# -*- coding: utf-8 -*-
import logging
import requests
import traceback
from typing import List
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
def ensure_openai_base_url_has_v1(url: str) -> str:
    """Append ``/v1`` to an OpenAI-style base URL that carries no version.

    The URL is stripped of surrounding whitespace first. It is returned
    unchanged (apart from that stripping) when it is empty, when it already
    ends with a ``/v<digits>`` segment, or when ``/v1`` appears anywhere in
    it. Otherwise any trailing slashes are removed and ``/v1`` is appended.

    :param url: base URL as entered by the user
    :return: the stripped URL, with ``/v1`` appended when it was missing
    """
    import re

    cleaned = url.strip()
    if not cleaned:
        return cleaned

    # Either an explicit trailing version (/v1, /v2, ...) or a '/v1'
    # somewhere inside the URL means there is nothing to add.
    has_trailing_version = re.search(r'/v\d+$', cleaned) is not None
    if has_trailing_version or '/v1' in cleaned:
        return cleaned

    return cleaned.rstrip('/') + '/v1'
class BaseEmbeddingAdapter:
    """Common base class defining the embedding interface.

    Both methods raise ``NotImplementedError`` here; subclasses are expected
    to override them.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of ``texts``."""
        raise NotImplementedError()

    def embed_query(self, query: str) -> List[float]:
        """Return the embedding vector for a single ``query`` string."""
        raise NotImplementedError()
class OpenAIEmbeddingAdapter(BaseEmbeddingAdapter):
    """Adapter for OpenAI-compatible ``/embeddings`` endpoints (local deployments included).

    Requests are POSTed to ``{base_url}/embeddings``. Any failure (network,
    HTTP status, unexpected response shape) is logged and replaced by zero
    vectors of length 1536 so callers always receive a result.
    """

    # Length of the zero vector returned when a request fails.
    _FALLBACK_DIM = 1536

    def __init__(self, api_key: str, base_url: str, model_name: str):
        """
        :param api_key: API key; sent as a bearer token when non-empty
        :param base_url: endpoint root (trailing slashes are removed)
        :param model_name: value sent as the ``model`` field of each request
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model_name = model_name

    def _request_embeddings(self, payload_input):
        """POST ``payload_input`` as ``input`` and return the response's ``data`` list."""
        # Only attach the header when a key was supplied, so key-less local
        # servers receive exactly the same request as before.
        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else None
        response = requests.post(
            f"{self.base_url}/embeddings",
            json={
                "model": self.model_name,
                "input": payload_input
            },
            headers=headers,
        )
        response.raise_for_status()
        return response.json()["data"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in one request.

        :return: one vector per text; zero vectors if the request fails,
                 an empty list when ``texts`` is empty (no request is made)
        """
        if not texts:
            return []
        try:
            return [item["embedding"] for item in self._request_embeddings(texts)]
        except Exception as e:
            logging.error(f"Embedding request error: {e}\n{traceback.format_exc()}")
            # Build a separate list per row: ``[[...]] * n`` would make every
            # row the same object, so mutating one would change them all.
            return [[0.0] * self._FALLBACK_DIM for _ in texts]

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string; returns a zero vector if the request fails."""
        try:
            return self._request_embeddings(query)[0]["embedding"]
        except Exception as e:
            logging.error(f"Embedding request error: {e}\n{traceback.format_exc()}")
            return [0.0] * self._FALLBACK_DIM
class AzureOpenAIEmbeddingAdapter(BaseEmbeddingAdapter):
    """Adapter that delegates to ``AzureOpenAIEmbeddings`` (or a compatible interface)."""

    def __init__(self, api_key: str, base_url: str, model_name: str):
        """Parse the full Azure embeddings URL and build the underlying client.

        :param api_key: key passed to ``AzureOpenAIEmbeddings`` as ``openai_api_key``
        :param base_url: URL of the form
            ``https://<host>/openai/deployments/<deployment>/embeddings?api-version=<version>``
        :param model_name: accepted for signature parity; not used by this adapter
        :raises ValueError: if ``base_url`` does not match the expected form
        """
        import re

        parsed = re.match(
            r'https://(.+?)/openai/deployments/(.+?)/embeddings\?api-version=(.+)',
            base_url,
        )
        if parsed is None:
            raise ValueError("Invalid Azure OpenAI base_url format")

        # Groups, in order: host, deployment name, API version.
        host, deployment, version = parsed.groups()
        self.azure_endpoint = f"https://{host}"
        self.azure_deployment = deployment
        self.api_version = version

        self._embedding = AzureOpenAIEmbeddings(
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            openai_api_key=api_key,
            api_version=self.api_version,
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Forward ``texts`` to the wrapped client's ``embed_documents``."""
        client = self._embedding
        return client.embed_documents(texts)

    def embed_query(self, query: str) -> List[float]:
        """Forward ``query`` to the wrapped client's ``embed_query``."""
        client = self._embedding
        return client.embed_query(query)
class OllamaEmbeddingAdapter(BaseEmbeddingAdapter):
    """Adapter for an Ollama server, whose embeddings endpoint is ``/api/embeddings``."""

    def __init__(self, model_name: str, base_url: str):
        """
        :param model_name: value sent as the ``model`` field of each request
        :param base_url: server URL (trailing slashes are removed); may already
            contain ``/api``, ``/api/embeddings`` or an OpenAI-style ``/v1`` path
        """
        self.model_name = model_name
        self.base_url = base_url.rstrip("/")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with its own request and return the vectors in order."""
        return [self._embed_single(text) for text in texts]

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string."""
        return self._embed_single(query)

    def _endpoint_url(self) -> str:
        """Derive the full ``/api/embeddings`` URL from ``self.base_url``.

        The ``/api`` and ``/v1`` checks look only at the path part of the URL.
        Matching against the whole URL would also hit the ``//`` before a host
        such as ``api.example.com`` or ``v1.example.com`` and build a wrong
        (or truncated) endpoint.
        """
        base = self.base_url.rstrip("/")

        # Split "<scheme>://<host><path>" by hand; a URL without a scheme is
        # treated as "<host><path>".
        scheme, sep, remainder = base.partition("://")
        if sep:
            prefix = scheme + sep
        else:
            prefix, remainder = "", base
        host, slash, tail = remainder.partition("/")
        path = slash + tail

        if "/api/embeddings" in path:
            return base
        if "/api" in path:
            return f"{base}/embeddings"
        if "/v1" in path:
            # Drop an OpenAI-style "/v1..." suffix before adding Ollama's path.
            path = path[:path.index("/v1")]
        return f"{prefix}{host}{path}/api/embeddings"

    def _embed_single(self, text: str) -> List[float]:
        """Call the Ollama ``/api/embeddings`` endpoint for one text.

        :return: the ``embedding`` list from the response, or ``[]`` when the
            request fails with a ``requests`` exception (the error is logged)
        :raises ValueError: if the response JSON has no ``embedding`` field
        """
        url = self._endpoint_url()
        data = {
            "model": self.model_name,
            "prompt": text
        }
        try:
            response = requests.post(url, json=data)
            response.raise_for_status()
            result = response.json()
            if "embedding" not in result:
                raise ValueError("No 'embedding' field in Ollama response.")
            return result["embedding"]
        except requests.exceptions.RequestException as e:
            logging.error(f"Ollama embeddings request error: {e}\n{traceback.format_exc()}")
            return []
class MLStudioEmbeddingAdapter(BaseEmbeddingAdapter):
    """Adapter for a locally deployed ML Studio server.

    Requests are POSTed to ``{base_url}/embeddings``. Any failure (network,
    HTTP status, unexpected response shape) is logged and replaced by zero
    vectors of length 1536 so callers always receive a result.
    """

    # Length of the zero vector returned when a request fails.
    _FALLBACK_DIM = 1536

    def __init__(self, api_key: str, base_url: str, model_name: str):
        """
        :param api_key: API key; sent as a bearer token when non-empty
        :param base_url: endpoint root (trailing slashes are removed)
        :param model_name: value sent as the ``model`` field of each request
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model_name = model_name

    def _request_embeddings(self, payload_input):
        """POST ``payload_input`` as ``input`` and return the response's ``data`` list."""
        # Only attach the header when a key was supplied, so key-less local
        # servers receive exactly the same request as before.
        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else None
        response = requests.post(
            f"{self.base_url}/embeddings",
            json={
                "model": self.model_name,
                "input": payload_input
            },
            headers=headers,
        )
        response.raise_for_status()
        return response.json()["data"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in one request.

        :return: one vector per text; zero vectors if the request fails,
                 an empty list when ``texts`` is empty (no request is made)
        """
        if not texts:
            return []
        try:
            return [item["embedding"] for item in self._request_embeddings(texts)]
        except Exception as e:
            logging.error(f"ML Studio embedding request error: {e}\n{traceback.format_exc()}")
            # Build a separate list per row: ``[[...]] * n`` would make every
            # row the same object, so mutating one would change them all.
            return [[0.0] * self._FALLBACK_DIM for _ in texts]

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string; returns a zero vector if the request fails."""
        try:
            return self._request_embeddings(query)[0]["embedding"]
        except Exception as e:
            logging.error(f"ML Studio embedding request error: {e}\n{traceback.format_exc()}")
            return [0.0] * self._FALLBACK_DIM
class GeminiEmbeddingAdapter(BaseEmbeddingAdapter):
    """Embedding adapter for the Google Generative AI (Gemini) API.

    Uses direct POST requests. Example URL:
    https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key=YOUR_API_KEY
    """

    def __init__(self, api_key: str, model_name: str, base_url: str):
        """
        :param api_key: Google API key
        :param model_name: typically "text-embedding-004"
        :param base_url: e.g. https://generativelanguage.googleapis.com/v1beta/models
        """
        self.api_key = api_key
        self.model_name = model_name
        self.base_url = base_url.rstrip("/")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with its own request and return the vectors in order."""
        return [self._embed_single(text) for text in texts]

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string."""
        return self._embed_single(query)

    def _redact(self, message: str) -> str:
        """Return ``message`` with every occurrence of the API key replaced by ``***``."""
        if not self.api_key:
            return message
        return message.replace(self.api_key, "***")

    def _embed_single(self, text: str) -> List[float]:
        """Call the Gemini ``embedContent`` endpoint for one text.

        :return: the ``embedding.values`` list from the response, or ``[]``
            when that field is missing or any error occurs (errors are logged)
        """
        url = f"{self.base_url}/{self.model_name}:embedContent?key={self.api_key}"
        payload = {
            "model": self.model_name,
            "content": {
                "parts": [
                    {"text": text}
                ]
            }
        }
        try:
            response = requests.post(url, json=payload)
            response.raise_for_status()
            result = response.json()
            embedding_data = result.get("embedding", {})
            return embedding_data.get("values", [])
        except requests.exceptions.RequestException as e:
            # requests error messages include the request URL, which carries
            # the API key as a query parameter; redact it before logging.
            logging.error(self._redact(
                f"Gemini embed_content request error: {e}\n{traceback.format_exc()}"
            ))
            return []
        except Exception as e:
            logging.error(self._redact(
                f"Gemini embed_content parse error: {e}\n{traceback.format_exc()}"
            ))
            return []
def create_embedding_adapter(
    interface_format: str,
    api_key: str,
    base_url: str,
    model_name: str
) -> BaseEmbeddingAdapter:
    """Factory: build the embedding adapter selected by ``interface_format``.

    The format name is matched after stripping whitespace and lower-casing.
    Recognized names: "openai", "azure openai", "ollama", "ml studio", "gemini".

    :raises ValueError: if ``interface_format`` is not one of the names above
    """
    kind = interface_format.strip().lower()

    # Lazy constructors: only the selected adapter is instantiated.
    builders = {
        "openai": lambda: OpenAIEmbeddingAdapter(api_key, base_url, model_name),
        "azure openai": lambda: AzureOpenAIEmbeddingAdapter(api_key, base_url, model_name),
        "ollama": lambda: OllamaEmbeddingAdapter(model_name, base_url),
        "ml studio": lambda: MLStudioEmbeddingAdapter(api_key, base_url, model_name),
        "gemini": lambda: GeminiEmbeddingAdapter(api_key, model_name, base_url),
    }

    build = builders.get(kind)
    if build is None:
        raise ValueError(f"Unknown embedding interface_format: {interface_format}")
    return build()