Towhee自定义大模型能力接口

Towhee 允许用户自定义模型能力接口,通过创建自定义操作器(Operator)或组件,您可以将任何现有模型或自定义模型集成到 Towhee 的管道中。这种灵活性使您能够根据特定需求扩展 Towhee 的功能,构建复杂且高效的 AI 工作流。

1. 基本概念

Towhee 的核心是管道(Pipeline),它由多个操作器(Operator)组成。每个操作器执行特定的任务,如数据预处理、特征提取、模型推理等。通过自定义操作器,您可以集成任何模型或算法,以满足特定的业务需求或技术要求。

2. 创建自定义操作器

以下是创建和使用自定义操作器的步骤:

步骤 1:定义自定义操作器

首先,您需要定义一个自定义操作器类,继承自 towhee.Operator 并实现必要的方法。假设您有一个自定义的文本分类模型,您希望将其集成到 Towhee 管道中。

# my_text_classifier.py
from towhee import operator
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

@operator
class MyTextClassifier:
    def __init__(self, model_name: str = "distilbert-base-uncased-finetuned-sst-2-english"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()
    
    def __call__(self, text: str) -> dict:
        inputs = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            scores = torch.nn.functional.softmax(outputs.logits, dim=-1)
            confidence, predicted_class = torch.max(scores, dim=-1)
        return {
            "text": text,
            "predicted_class": self.model.config.id2label[predicted_class.item()],
            "confidence": confidence.item()
        }

输出示例:

{'text': 'I love programming!', 'predicted_class': 'POSITIVE', 'confidence': 0.9998}
{'text': 'This movie was terrible.', 'predicted_class': 'NEGATIVE', 'confidence': 0.9985}
{'text': 'The weather is nice today.', 'predicted_class': 'POSITIVE', 'confidence': 0.9872}

3. 高级自定义:参数化操作器

您还可以为自定义操作器添加参数,以使其更加灵活。例如,支持不同的模型或配置选项:

# my_text_classifier.py
from towhee import operator
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

@operator
class MyTextClassifier:
    def __init__(self, model_name: str = "distilbert-base-uncased-finetuned-sst-2-english"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()
    
    def __call__(self, text: str, model_name: str = None) -> dict:
        if model_name:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.model.eval()
        
        inputs = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            scores = torch.nn.functional.softmax(outputs.logits, dim=-1)
            confidence, predicted_class = torch.max(scores, dim=-1)
        return {
            "text": text,
            "predicted_class": self.model.config.id2label[predicted_class.item()],
            "confidence": confidence.item()
        }

使用时,您可以在调用管道时传递不同的模型名称:

classification_pipe = (
    pipe.input("text", model_name="optional_model_name")
    .map(text_classifier)  # 使用自定义操作器
    .output(['text', 'predicted_class', 'confidence'])
)

results = classification_pipe(data, model_name="distilbert-base-uncased-finetuned-sst-2-english")

4. 自定义逻辑与功能

除了集成预训练模型,您还可以在自定义操作器中添加其他逻辑或功能,例如数据增强、后处理、结果过滤等。这使得 Towhee 管道能够满足更复杂的需求。

@operator
class AdvancedTextProcessor:
    def __init__(self, model_name: str):
        self.classifier = MyTextClassifier(model_name)
    
    def __call__(self, text: str) -> dict:
        # 数据预处理
        cleaned_text = self.clean_text(text)
        
        # 分类
        classification = self.classifier(cleaned_text)
        
        # 后处理
        if classification['confidence'] > 0.95:
            classification['status'] = 'high_confidence'
        else:
            classification['status'] = 'low_confidence'
        
        return classification
    
    def clean_text(self, text: str) -> str:
        # 示例清洗函数
        return text.strip().lower()

5. 整合多模型与多步骤流程

您可以在管道中集成多个自定义操作器,形成复杂的多步骤流程。例如,先进行文本清洗,再进行分类,最后进行结果过滤:

from towhee import pipe
from my_text_classifier import MyTextClassifier
from my_advanced_processor import AdvancedTextProcessor

# 定义操作器实例
text_cleaner = AdvancedTextProcessor(model_name="distilbert-base-uncased-finetuned-sst-2-english")

# 定义管道
full_process_pipe = (
    pipe.input("text")
    .map(text_cleaner)  # 清洗和分类
    .filter(lambda x: x['status'] == 'high_confidence')  # 过滤低置信度结果
    .output(['text', 'predicted_class', 'confidence'])
)

# 运行管道
results = full_process_pipe(data)

# 打印结果
for result in results:
    print(result)

6. 部署与扩展

Towhee 支持将自定义管道部署为服务,使前端应用能够通过 API 调用这些管道。这可以通过将管道与 Web 框架(如 Flask 或 FastAPI)集成来实现。

示例:使用 FastAPI 部署自定义管道

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from towhee import pipe
from my_text_classifier import MyTextClassifier

app = FastAPI()

# 定义操作器
text_classifier = MyTextClassifier()

# 定义管道
classification_pipe = (
    pipe.input("text")
    .map(text_classifier)
    .output(['text', 'predicted_class', 'confidence'])
)

# 定义请求模型
class TextRequest(BaseModel):
    text: str

@app.post("/classify")
def classify_text(request: TextRequest):
    try:
        results = classification_pipe([{"text": request.text}])
        return results[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# 运行 Server
# 在命令行运行: uvicorn your_script_name:app --reload
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容