Tensorflow.js 3. 构建图像分类器

在本章中,我们将通过构建几个检测图像中对象的 Web 应用程序来更深入地了解 TensorFlow.js 的功能。

将会有更完整的代码示例和说明,让您更好地了解如何在您的项目中实施机器学习。

3.1 使用预训练模型
我们将要构建的第一个项目是一个快速游戏,系统会提示您找到周围的特定物体,使用设备的相机拍摄它们的照片,然后检查机器学习模型是否识别它们。

输出如下:



这个项目的核心是我们之前谈到的同一个对象检测模型,称为 mobilenet 。

该模型使用开源 ImageNet 数据库进行预训练,该数据库由按 1000 个不同类别组织的图像组成。

这意味着该模型能够根据它所训练的数据识别 1000 个不同的对象。

要开始这个项目,我们需要导入 TensorFlow.js 和 mobilenet 模型。

有两种方法可以做到这一点。 您可以使用 HTML 文件中的脚本标签导入它们。

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"></script>

或者,如果您使用的是前端框架,例如 React.js,则可以在依赖项中安装 TensorFlow.js,然后将其导入到 JavaScript 文件中。

在您的终端中:

npm install @tensorflow/tfjs @tensorflow-models/mobilenet
yarn add @tensorflow/tfjs @tensorflow-models/mobilenet

在您的 JavaScript 文件中:

import "@tensorflow/tfjs";
import "@tensorflow-models/mobilenet";

导入这两个文件使我们能够访问 tf 和 mobilenet 对象。

我们需要采取的第一步是在应用程序中加载模型。

async function app(){
       const model = await model.load();
}

模型是非常重的文件,加载它们可能需要几秒钟,因此应该使用 async/await 加载。

如果你想知道这个对象包含什么,你可以记录它并查看它的属性。

请记住,您不必了解对象中的每个属性才能使用它。

然而,可能有趣的属性之一是模型中的输入属性。



此属性向我们展示了用于训练模型的输入类型。在这种情况下,我们可以看到使用了图像,考虑到这是一个对象检测模型,这是有道理的。更重要的是,我们可以看到训练过程中使用的数据的形状。

shape 属性显示值 [-1, 224, 224, 3],这意味着输入模型的图像是大小为 224*224 的 RGB 图像(数组末尾的值 3 表示数量或通道)像素。

这个值对于本章的下一部分特别有趣,我们将在那里研究使用 mobilenet 模型进行迁移学习。

随意进一步探索模型。

构建此应用程序的下一步是允许 TensorFlow.js 访问来自网络摄像头的输入,以便能够运行预测和检测对象。

由于我们的项目使用设备的网络摄像头,因此我们的 HTML 中有一个 <video> 元素。

在 JavaScript 中,我们需要访问此元素并使用 TensorFlow 的方法之一从数据 API 创建一个对象,该对象可以将图像捕获为张量。

const webcamElement = document.getElementsByTagName("video")[0];
const webcam = await tf.data.webcam(webcamElement);

这两行仍然是我们应用程序设置过程的一部分。 目前,我们只加载了模型并创建了这个网络摄像头变量,它将把快照从相机转换为张量。

现在,为了实现逻辑,我们需要从向 HTML 中添加一个简单的按钮开始。 它将用于在单击时触发图像捕获。

<button class="capture-image">SNAP</button>

在我们的 JavaScript 文件中,我们需要访问此元素,使用 onclick 事件侦听器,并使用 TensorFlow.js 捕获图像并对其进行分类。

const captureButton = document.getElementsByClassName("capture-image")[0];
captureButton.onclick = async () => {
    const img = await webcam.capture();
    const predictions = await model.classify(img);
    return predictions;
};

为了从视频源中捕获图像,TensorFlow.js 有一个 capture() 内置方法,需要在之前使用 tf.data.webcam 创建的对象上调用该方法。

它允许将单个图像直接转换为张量,以便可以轻松地与其他 TensorFlow.js 操作一起使用。

捕获图像后,我们通过在 mobilenet.classify 中传递它来生成预测。

这将返回一组预测。

例如,这张塑料瓶的照片将返回以下预测数组。




如您所见,第一个预测,即模型最有信心的预测,带有“pop 瓶、汽水瓶”的标签。 它成功地检测到图像中存在瓶子; 然而,即使它是正确的结果,概率也确实很差。

预测的置信水平只有 30% 的事实可能是由于对象背后的背景。 背景越复杂,模型就越难找到图像中的对象并对其进行分类。

这个问题更多地与计算机视觉领域本身有关,而不是框架问题。

如下图所示,如果您尝试在更清晰的背景上拍摄相同的图片,预测的质量似乎要好得多。



不仅概率要高得多,接近 89%,而且以下预测也更准确。

在第一个例子中,第二个预测是“真空吸尘器”,这远非准确,但在这里,它返回了“水瓶”,这是一个更接近事实的结果。

如果您计划将对象检测集成到您的应用程序中,则绝对应该考虑此限制。 考虑将使用您的项目的上下文对于避免糟糕的用户体验很重要。

最后,这个过程还有最后一步。 我们需要清理不再需要的内存。 一旦图像被捕获并提供给 TensorFlow.js 进行分类,我们就不再需要它了,因此应该释放它占用的内存。

为此,TensorFlow.js 提供了您像这样使用的 dispose 方法。

img.dispose();

我们已经介绍了对象检测逻辑的主要部分。然而,游戏的第一部分是提示寻找特定的物体来拍照。

此代码不是 TensorFlow.js 特定的,它可以是一个简单的 UI,每次您成功找到前一个对象时,它都会要求您找到一个新对象。

但是,如果您的 UI 要求您查找手机,则需要确保模型已使用手机图片进行训练,以便它可以检测到正确的对象。

幸运的是,在 https://github.com/tensorflow/tfjs-models/blob/master/mobilenet/src/imagenet_classes.ts 的存储库中提供了可以被 mobilenet 模型识别的对象类列表。

如果您在应用程序中导入此列表,则您的代码可以循环遍历这个包含 1000 个条目的对象,并在 UI 中随机显示一个,以要求用户在它们周围找到这个对象。

由于此代码不涉及 TensorFlow.js 库的使用,因此我们不会在本书中对其进行介绍。

但是,如果您想查看之前显示的所有代码示例如何组合在一起,那么它应该是什么样子的。

<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0"
    />
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"></script>
    <title>Snap it</title>
  </head>
  <body>
    <main>
      <section class="content">
        <h1>Snap it</h1>
        <video></video>
        <button>SNAP</button>
      </section>
    </main>
  </body>
  <script src="index.js"></script>
</html>
async function app() {
  const webcamElement = document.getElementsByTagName("video")[0];
  const model = await mobilenet.load();
  const webcam = await tf.data.webcam(webcamElement);
  const captureButton = document.getElementsByTagName("button")[0];
  captureButton.onclick = async () => {
    const img = await webcam.capture();
    const predictions = await model.classify(img);
    img.dispose();
    return predictions;
  };
}
app();

在本子章节中,我们使用对象检测来构建一个小游戏,但它可以用于非常不同的应用程序。

3.2 迁移学习

使用预先训练的模型非常有用,可以让您非常快速地构建项目,但如果您发现自己需要更多定制的东西,您可以很快达到其极限。

在本子章节中,我们将利用我们在前几页中编写的部分代码,并调整它们以使用自定义输入数据。

我们将从我们的网络摄像头收集自定义数据样本,以构建一个可以识别我们头部运动的模型。然后可以将其用作界面的潜在控件,因此您可以想象使用此模型通过上下倾斜头部或使用相同的动作来导航地图来滚动网页。

该项目将专注于训练模型以识别新样本并测试其预测。

您将在接下来的几页中阅读的代码将生成一个界面,其中包含用于收集新数据的按钮和用于运行预测的附加按钮。结果将显示在页面上,供您验证模型的准确性。




正如您在前面的屏幕截图中看到的那样,可以准确预测向下和向左之间的头部运动。

首先,我们需要导入 TensorFlow.js、mobilenet 模块和 K-近邻分类器。

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>

如前所述,我们还需要一个视频元素来显示网络摄像头提要、一些按钮和一个段落来显示我们的预测结果。

<video class="webcam"></video>
<section class="buttons">
      <button>Up</button>
      <button>Down</button>
      <button>Left</button>
      <button>Right</button>
</section>
<section class="buttons">
      <button class="predict">Predict</button>
</section>
<p class="prediction"></p>

在 JavaScript 文件中,我们需要编写逻辑,当我们单击按钮并将其提供给 KNN 分类器时,将从网络摄像头收集样本。

在深入研究逻辑之前,我们需要先为分类器、模型和网络摄像头实例化一些变量。

const classifier = knnClassifier.create();
const net = await mobilenet.load();
const webcam = await tf.data.webcam(webcamElement);

在最后一行,webcamElement 变量指的是您将通过使用标准文档接口方法(例如 getElementsByClassName)获得的 HTML 视频元素。

为了实现逻辑,我们可以创建一个新的函数,我们将调用 addExample。 该函数将从网络摄像头捕获图像,将其转换为张量,使用图像张量及其标签重新训练 mobilenet 模型,将该示例添加到 KNN 分类器,并处理该张量。

这听起来可能很多,但执行此操作所需的代码实际上只有几行。

const addExample = async classId => {
      const img = await webcam.capture();
      const activation = net.infer(img, "conv_preds");
      classifier.addExample(activation, classId);
      img.dispose();
};

第二行允许我们从网络摄像头提要中捕获单个图像并将其直接转换为张量,因此它可以立即与其他 TensorFlow.js 方法一起使用。

激活变量保存使用来自网络摄像头的新图像张量重新训练的 mobilenet 模型的值,使用其激活函数之一称为“conv_preds”。

激活函数是一种帮助神经网络学习数据中复杂模式的函数。

下一步是使用重新训练模型的结果并将其作为示例添加到我们的分类器中,并使用类 ID 将新样本映射到其标签。

在机器学习中,尽管我们通常将标签视为字符串,例如,在我们的例子中“右”、“左”等,但在训练过程中,这些标签实际上是与它们在标签数组中的索引交换的.

如果我们的类是 ["up", "down", "left", "right"],当我们训练模型识别我们的头部向下移动时,类 ID 将为 1,因为“向下”是我们数组中的第二个元素.

最后,一旦使用了图像张量,我们就会处理它,以释放一些内存。

当我们单击四个按钮之一时,需要触发此 addExample 方法。

for (var i = 0; i < buttons.length; i++) {
    if (buttons[i] !== predictButton) {
      let index = i;
      buttons[i].onclick = () => addExample(index);
    }
}

考虑到按钮变量保存了 DOM 中存在的按钮元素,我们希望在所有按钮上触发我们的 addExample 函数,除了用于运行预测的按钮。

我们将按钮索引传递给函数,因此当我们单击“向上”按钮时,例如,类 ID 将为 0。

这样,每次我们单击四个按钮之一时,都会将一个示例添加到分类器中,并带有相应的类 ID。

一旦我们重新训练我们的模型几次,我们就可以点击预测按钮来运行实时预测。

predictButton.onclick = () => runPredictions();

这个 runPredictions 函数将重复与前面解释的类似的步骤; 然而,它不会将示例添加到 KNN 分类器,而是会触发 predictClass 方法根据我们刚刚经历的训练过程对来自网络摄像头的实时输入进行分类。

async function runPredictions() {
    while (true) {
      if (classifier.getNumClasses() > 0) {
        const img = await webcam.capture();
        const activation = net.infer(img, "conv_preds");
        const result = await classifier.predictClass(activation);
        predictionParagraph.innerText = `
           prediction: ${classes[result.label]},
           probability: ${result.confidences[result.label]}`;
        img.dispose();
      }
      await tf.nextFrame();
    }
}

在前面的示例中,我们将逻辑包装在一个 while 循环中,因为我们希望连续预测来自网络摄像头的输入; 但是,如果您只想在单击元素后获得预测,您也可以将其替换为 onclick 事件。

如果分类器已经用新样本训练过,我们重复从网络摄像头捕获图像并将其与 mobilenet 模型一起使用的两个步骤。

const img = await webcam.capture();
const activation = net.infer(img, "conv_preds");

然后我们将这些数据传递到在 KNN 分类器上调用的 predictClass 方法中以预测其标签。

调用此方法的结果是一个对象,其中包含一个 classIndex、一个标签和一个名为 confidences 的对象。



在这种情况下,我将头向右倾斜,因此 classIndex 和 label 返回值为 3,因为训练模型识别此手势的按钮是 4 中的最后一个。

置信对象向我们展示了预测标签的概率。 值为 1 表示模型非常确信识别出的手势是正确的。

概率值可以在 0 和 1 之间变化。

从预测中得到结果后,我们处理图像以释放一些内存。

最后,我们调用 tf.nextFrame() 等待 requestAnimationFrame 完成,然后再次运行此代码并预测下一帧的类。

以下是代码的整体工作方式。

<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0"
     />
    <title>Transfer learning</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
  </head>
  <body>
    <main>
      <section class="content">
        <video class="webcam"></video>
        <section class="buttons">
          <button>Up</button>
          <button>Down</button>
          <button>Left</button>
          <button>Right</button>
        </section>
        <section class="buttons">
          <button class="predict">Predict</button>
        </section>
        <p class="prediction"></p>
      </section>
    </main>
    <script src="index.js"></script>
  </body>
</html>
const webcamElement = document.getElementsByClassName("webcam")[0];
const buttons = document.getElementsByTagName("button");
const predictButton = document.getElementsByClassName("predict")[0];
const classes = ["up", "down", "left", "right"];
const predictionParagraph = document.getElementsByClassName("prediction")[0];
async function app() {
  const classifier = knnClassifier.create();
  const net = await mobilenet.load();
  const webcam = await tf.data.webcam(webcamElement);
  const addExample = async classId => {
    const img = await webcam.capture();
    const activation = net.infer(img, "conv_preds");
    classifier.addExample(activation, classId);
    img.dispose();
  };
  for (var i = 0; i < buttons.length; i++) {
    if (buttons[i] !== predictButton) {
      let index = i;
      buttons[i].onclick = () => addExample(index);
    }
  }
  predictButton.onclick = () => runPredictions();
  async function runPredictions() {
    while (true) {
      if (classifier.getNumClasses() > 0) {
        const img = await webcam.capture();
        const activation = net.infer(img, "conv_preds");
        const result = await classifier.predictClass(activation);
        predictionParagraph.innerText = `
            prediction: ${classes[result.label]},
            probability: ${result.confidences[result.label]}`;
        img.dispose();
      }
      await tf.nextFrame();
    }
  }
}
app();

使用迁移学习使我们能够非常快速地重新训练模型以适应定制的输入。 只需几行代码,我们就能够创建自定义的图像分类模型。

根据您提供的新输入数据,您可能需要添加更多或更少的新示例才能获得准确的预测,但这总是比收集全新的标记数据集并从头开始创建自己的机器学习模型要快。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,233评论 6 495
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,357评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,831评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,313评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,417评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,470评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,482评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,265评论 0 269
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,708评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,997评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,176评论 1 342
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,827评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,503评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,150评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,391评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,034评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,063评论 2 352

推荐阅读更多精彩内容