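Towhee lets you define your own model interfaces: by creating a custom operator (or other component), you can integrate any existing or self-built model into a Towhee pipeline. This flexibility lets you extend Towhee to fit your specific requirements and build complex, efficient AI workflows.
1. Basic Concepts
The core of Towhee is the pipeline, which is composed of multiple operators. Each operator performs one specific task, such as data preprocessing, feature extraction, or model inference. By writing custom operators, you can integrate any model or algorithm that your business or technical requirements call for.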
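To make the pipeline/operator relationship concrete, here is a conceptual sketch in plain Python. It is not Towhee's implementation: Towhee's runtime works with named columns and schedules its nodes itself. The sketch only illustrates the idea that each operator is a callable and that a pipeline applies operators in order. The operators `strip_text` and `count_words` are hypothetical.

```python
from typing import Any, Callable, List

# Conceptual model only: an "operator" is any callable that takes one value and returns one value.
Operator = Callable[[Any], Any]


def run_pipeline(operators: List[Operator], data: Any) -> Any:
    """Feed data through each operator in order, like the stages of a pipeline."""
    for op in operators:
        data = op(data)
    return data


# Hypothetical operator 1: preprocessing.
def strip_text(text: str) -> str:
    return text.strip().lower()


# Hypothetical operator 2: a trivial "feature extraction" step.
def count_words(text: str) -> dict:
    return {"text": text, "num_words": len(text.split())}


result = run_pipeline([strip_text, count_words], "  Hello Towhee World  ")
print(result)  # {'text': 'hello towhee world', 'num_words': 3}
```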
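2. Creating a Custom Operator
Creating and using a custom operator involves the following steps:
Step 1: Define the custom operator
First, define a custom operator class that loads the model in `__init__` and processes a single input in `__call__`. A Towhee pipeline's `map` step accepts any Python callable, so an instance of such a class can be used directly as an operator; no special base class or decorator is required. Suppose you want to integrate a text-classification model (here, a DistilBERT model fine-tuned on SST-2, loaded with Hugging Face Transformers) into a Towhee pipeline: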
```python
# my_text_classifier.py
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class MyTextClassifier:
    """A Towhee operator: any callable class can be used in a pipeline's map step."""

    def __init__(self, model_name: str = "distilbert-base-uncased-finetuned-sst-2-english"):
        # Load the tokenizer and model once, when the operator is created.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()  # inference mode (disables dropout)

    def __call__(self, text: str) -> dict:
        # truncation=True keeps long inputs within the model's maximum length.
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Convert logits to probabilities and take the most likely class.
        scores = torch.nn.functional.softmax(outputs.logits, dim=-1)
        confidence, predicted_class = torch.max(scores, dim=-1)
        return {
            "text": text,
            "predicted_class": self.model.config.id2label[predicted_class.item()],
            "confidence": confidence.item(),
        }
```
Example output:
{'text': 'I love programming!', 'predicted_class': 'POSITIVE', 'confidence': 0.9998}
{'text': 'This movie was terrible.', 'predicted_class': 'NEGATIVE', 'confidence': 0.9985}
{'text': 'The weather is nice today.', 'predicted_class': 'POSITIVE', 'confidence': 0.9872}
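The post-processing in `__call__` turns the model's raw logits into a label and a confidence. The plain-Python sketch below reproduces, for a single example, what `torch.nn.functional.softmax` followed by `torch.max` computes. The logit values are made up for illustration. The `id2label` mapping is assumed here to be `{0: "NEGATIVE", 1: "POSITIVE"}`, as for an SST-2 model.

```python
import math
from typing import Dict, List, Tuple


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max logit for numerical stability; the probabilities are unchanged.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits: List[float], id2label: Dict[int, str]) -> Tuple[str, float]:
    scores = softmax(logits)
    # Counterpart of torch.max(scores, dim=-1): the highest probability and its index.
    predicted_class = max(range(len(scores)), key=scores.__getitem__)
    return id2label[predicted_class], scores[predicted_class]


id2label = {0: "NEGATIVE", 1: "POSITIVE"}  # assumed label mapping for an SST-2 model
label, confidence = predict([-2.0, 2.0], id2label)  # hypothetical logits
print(label, round(confidence, 4))  # POSITIVE 0.982
```

With two classes, the confidence is simply the sigmoid of the logit difference: here 1 / (1 + e^-4) ≈ 0.982.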
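3. Advanced Customization: Parameterized Operators
You can also add parameters to a custom operator to make it more flexible, for example to support different models or configuration options: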
```python
# my_text_classifier.py
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class MyTextClassifier:
    def __init__(self, model_name: str = "distilbert-base-uncased-finetuned-sst-2-english"):
        self._load(model_name)

    def _load(self, model_name: str) -> None:
        # Load (or swap in) the tokenizer and model for the given name.
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()

    def __call__(self, text: str, model_name: str = None) -> dict:
        # Reload only when a different model is requested, not on every call.
        if model_name and model_name != self.model_name:
            self._load(model_name)
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        scores = torch.nn.functional.softmax(outputs.logits, dim=-1)
        confidence, predicted_class = torch.max(scores, dim=-1)
        return {
            "text": text,
            "predicted_class": self.model.config.id2label[predicted_class.item()],
            "confidence": confidence.item(),
        }
```
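To choose the model at call time, declare `model_name` as a second pipeline input; `map` then passes both columns to the operator: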
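```python
from towhee import pipe
from my_text_classifier import MyTextClassifier

text_classifier = MyTextClassifier()
# map(input_columns, output_column, fn) calls text_classifier(text, model_name).
classification_pipe = (
    pipe.input("text", "model_name")
        .map(("text", "model_name"), "result", text_classifier)
        .output("result")
)
result = classification_pipe("I love programming!", "distilbert-base-uncased-finetuned-sst-2-english").get()[0]
```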
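Loading a tokenizer and model is expensive, so an operator that switches between several models benefits from keeping every model it has already loaded. Below is a minimal sketch of that idea. It assumes a hypothetical `load_model` function standing in for the two `from_pretrained` calls; the function only records how often it is called, so the example runs without downloading anything.

```python
from typing import Any, Callable, Dict


class ModelCache:
    """Load each model once per name and reuse it on later calls."""

    def __init__(self, loader: Callable[[str], Any]):
        self._loader = loader
        self._models: Dict[str, Any] = {}

    def get(self, model_name: str) -> Any:
        # Call the loader only the first time a name is requested.
        if model_name not in self._models:
            self._models[model_name] = self._loader(model_name)
        return self._models[model_name]


load_calls = []


def load_model(model_name: str) -> str:
    # Hypothetical stand-in for AutoTokenizer/AutoModelForSequenceClassification.from_pretrained.
    load_calls.append(model_name)
    return f"<model {model_name}>"


cache = ModelCache(load_model)
for name in ["model-a", "model-b", "model-a", "model-a"]:
    cache.get(name)
print(load_calls)  # ['model-a', 'model-b']
```

This cache grows with every distinct model name. For a long-running service you may prefer a bounded cache, for example by wrapping the loader with `functools.lru_cache(maxsize=...)`.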
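4. Custom Logic and Functionality
Besides wrapping pretrained models, a custom operator can contain any additional logic, such as data augmentation, post-processing, or result filtering. This lets a Towhee pipeline handle more complex requirements. The following operator cleans the text, classifies it, and tags each result with a confidence level: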
```python
# my_advanced_processor.py
from my_text_classifier import MyTextClassifier


class AdvancedTextProcessor:
    def __init__(self, model_name: str):
        self.classifier = MyTextClassifier(model_name)

    def __call__(self, text: str) -> dict:
        # Preprocessing
        cleaned_text = self.clean_text(text)
        # Classification
        classification = self.classifier(cleaned_text)
        # Post-processing: tag the result by confidence level
        if classification["confidence"] > 0.95:
            classification["status"] = "high_confidence"
        else:
            classification["status"] = "low_confidence"
        return classification

    def clean_text(self, text: str) -> str:
        # Example cleaning step: trim whitespace and lowercase
        return text.strip().lower()
```
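`AdvancedTextProcessor` builds its own `MyTextClassifier`, so it cannot be exercised without downloading a model. One design option, sketched below as an assumption rather than part of the original code, is to pass the classifier in and make the confidence threshold a parameter. `ThresholdedProcessor` and `fake_classifier` are hypothetical names; the fake classifier returns fixed scores.

```python
from typing import Callable, Dict


class ThresholdedProcessor:
    """Clean text, classify it with any callable, and tag the confidence level."""

    def __init__(self, classifier: Callable[[str], Dict], threshold: float = 0.95):
        self.classifier = classifier
        self.threshold = threshold

    def __call__(self, text: str) -> Dict:
        cleaned = text.strip().lower()
        result = dict(self.classifier(cleaned))  # copy so the classifier's dict is not mutated
        high = result["confidence"] > self.threshold
        result["status"] = "high_confidence" if high else "low_confidence"
        return result


def fake_classifier(text: str) -> Dict:
    # Hypothetical classifier: confident only about texts that mention "love".
    confidence = 0.99 if "love" in text else 0.6
    return {"text": text, "predicted_class": "POSITIVE", "confidence": confidence}


processor = ThresholdedProcessor(fake_classifier)
print(processor("  I LOVE programming!  ")["status"])  # high_confidence
print(processor("It was okay.")["status"])  # low_confidence
```

In a real pipeline you would pass `MyTextClassifier(model_name)` as the `classifier` argument.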
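5. Combining Multiple Models and Multi-Step Flows
A pipeline can chain several custom operators into a multi-step flow. For example, clean the text, classify it, and then filter out low-confidence results: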
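```python
from towhee import pipe
from my_advanced_processor import AdvancedTextProcessor

# Create the operator instance (cleaning + classification)
text_processor = AdvancedTextProcessor(model_name="distilbert-base-uncased-finetuned-sst-2-english")

# Define the pipeline
full_process_pipe = (
    pipe.input("text")
        .map("text", "result", text_processor)          # clean and classify
        .filter("result", "confident", "result",        # drop low-confidence results
                lambda r: r["status"] == "high_confidence")
        .output("confident")
)

# Run the pipeline on a few sample texts and print the results that pass the filter
data = ["I love programming!", "This movie was terrible.", "The weather is nice today."]
for text in data:
    row = full_process_pipe(text).get()  # None if the row was filtered out
    if row is not None:
        print(row[0])
```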
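6. Deployment and Scaling
A custom Towhee pipeline can be deployed as a service so that front-end applications call it through an API. One way to do this is to wrap the pipeline in a web framework such as Flask or FastAPI.
Example: deploying the custom pipeline with FastAPI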
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from towhee import pipe
from my_text_classifier import MyTextClassifier

app = FastAPI()

# Create the operator once at startup so the model is loaded only once
text_classifier = MyTextClassifier()

# Define the pipeline
classification_pipe = (
    pipe.input("text")
        .map("text", "result", text_classifier)
        .output("result")
)


# Request body schema
class TextRequest(BaseModel):
    text: str


@app.post("/classify")
def classify_text(request: TextRequest):
    try:
        # get() returns the output columns of one row; "result" is the only one.
        return classification_pipe(request.text).get()[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Run the server from the command line:
#   uvicorn your_script_name:app --reload
```
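Once the server is running, any HTTP client can call the endpoint. Below is a minimal standard-library sketch. It assumes uvicorn's default address `http://127.0.0.1:8000`. The helper names `build_classify_request` and `classify_remote` are hypothetical.

```python
import json
import urllib.request

DEFAULT_URL = "http://127.0.0.1:8000/classify"  # uvicorn's default host and port


def build_classify_request(text: str, url: str = DEFAULT_URL) -> urllib.request.Request:
    # The JSON body must match the TextRequest model: {"text": "..."}.
    body = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )


def classify_remote(text: str, url: str = DEFAULT_URL) -> dict:
    # Send the request and decode the JSON response returned by /classify.
    with urllib.request.urlopen(build_classify_request(text, url), timeout=10) as resp:
        return json.loads(resp.read().decode("utf-8"))


# Usage, with the server running:
#   classify_remote("I love programming!")
```

Building the request separately from sending it keeps the payload logic testable without a live server.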