第56章 字符卷积






hello -> [‘h’, ‘e’, ‘l’, ‘l’, ‘o’]

这样可以看到,一个单词hello被人为拆成了’h’, ‘e’, ‘l’, ‘l’, ‘o’这5个字母。对于hello的处理有两种方法,

  • 独热编码。
  • 字符嵌入。

处理的结果,单词“hello”将被转成一个[5, n]的矩阵。本例将采用独热编码的方法处理。


图1 使用CNN处理字符文本分类的原理

对于AG News数据集来说,每条新闻都有对应的分类,也有标题和正文。对于正文的抽取在前几章中已经介绍。这里直接对新闻标题进行处理,如下所示,

3 Wall St. Bears Claw Back Into the Black (Reuters)
3 Wall St. Bears Claw Back Into the Black (Reuters)
3 Carlyle Looks Toward Commercial Aerospace (Reuters)
3 Oil and Economy Cloud Stocks' Outlook (Reuters)
3 Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)
3 Oil prices soar to all-time record, posing new menace to US economy (AFP)
3 Stocks End Up, But Near Year Lows (Reuters)
3 Money Funds Fell in Latest Week (AP)
3 Fed minutes show dissent over inflation (USATODAY.com)
3 Safety Net (Forbes.com)
3 Wall St. Bears Claw Back Into the Black

由于只对文本标题进行处理,因此在进行数据清洗时不用处理停用词,也不用进行词根还原。对于空格,由于是字符计算,因此不需要保留,直接删除即可。修改原来代码如下,

def stop_words():
    """Download the NLTK stop-word corpus and return the English stop words.

    Returns:
        The list of English stop words from ``nltk.corpus.stopwords``.
    """
    # The listing had lost its ``try:`` line, leaving a dangling ``except``.
    # NOTE: replacing the default HTTPS context disables certificate
    # verification for the whole process; it is kept only so that the NLTK
    # download works on machines without a usable certificate store.
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        # This Python has no such helper: keep the default context.
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context
    download_dir = "/tmp/"
    nltk.download("stopwords", download_dir = download_dir)
    # NLTK only searches the directories listed in nltk.data.path, so the
    # custom download directory has to be registered before the lookup.
    if download_dir not in nltk.data.path:
        nltk.data.path.append(download_dir)
    # The corpus file is named "english"; "English" only resolves on
    # case-insensitive file systems.
    stops = nltk.corpus.stopwords.words("english")
    return stops

def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
    """Normalise a text for character-level processing.

    The text is lower-cased, every match of ``pattern`` is substituted with
    ``replacement``, each run of spaces is substituted with ``replacement``
    too, the result is stripped and the marker ``"eos"`` is appended directly
    to its end.
    """
    lowered = string.lower()
    substituted = re.sub(pattern, replacement, lowered)
    # Collapse every run of consecutive spaces.
    collapsed = re.sub(r" +", replacement, substituted)
    return collapsed.strip() + "eos"

def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = None):
    """Tokenise a text, drop stop words and stem the remaining words.

    Args:
        string: Raw text.
        pattern: Regular expression of the characters to substitute.
        replacement: Text substituted for each match and for each run of spaces.
        stops: Words to discard. When None, the result of ``stop_words()`` is
            fetched on first use and cached on this function, so importing
            the module no longer triggers a download.

    Returns:
        A list starting with the marker ``"bos"`` followed by the stemmed words.
    """
    if stops is None:
        stops = getattr(purify_stops, "_default_stops", None)
        if stops is None:
            stops = frozenset(stop_words())
            purify_stops._default_stops = stops
    else:
        # Set membership is O(1) per word instead of scanning a list.
        stops = frozenset(stops)
    string = string.lower()
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace consecutive spaces with a single replacement
    string = re.sub(pattern = r" +",  repl = replacement, string = string)
    # Trim the string
    string = string.strip()
    # One stemmer is enough for all words of the text.
    stemmer = nltk.PorterStemmer()
    # Separate the string at spaces and keep the words that are not stop words
    strings = [stemmer.stem(word) for word in string.split(" ") if word not in stops]
    return ["bos"] + strings

def setup():
    """Read the AG News training CSV.

    Returns:
        A ``(labels, titles, descriptions)`` tuple of parallel lists.
    """
    labels = []
    titles = []
    descriptions = []
    # newline = "" is what the csv module asks for when reading files.
    with open("../../Shares/ag_news_csv/train.csv", "r", newline = "") as handler:
        for line in csv.reader(handler):
            # NOTE(review): the loop body was missing from the listing. It is
            # rebuilt on the assumption that every row is
            # (class index, title, description) and that titles go through
            # purify() and descriptions through purify_stops() — confirm.
            labels.append(int(line[0]))
            titles.append(purify(line[1]))
            descriptions.append(purify_stops(line[2]))
    return labels, titles, descriptions


def one_hot(strings):
    # NOTE(review): only the first statement of this function is present in
    # the listing; as written it binds a local alphabet and returns None.
    alphabet = "abcdefghijklmnopqrstuvwxyz"


[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


alphabet = "abcdefghijklmnopqrstuvwxyz"

def one_hot(characters):
    """Convert alphabet indexes into one-hot rows.

    Args:
        characters: Sequence of integer indexes below ``len(alphabet) + 1``.

    Returns:
        A float array with one row per index and ``len(alphabet) + 1``
        columns. An empty input yields an array with zero rows.
    """
    array = numpy.array(characters)
    if array.size == 0:
        # numpy.array([]) defaults to float64, which cannot be used as an index.
        array = array.astype(numpy.intp)
    length = len(alphabet) + 1
    # numpy.eye(length) is the identity matrix; selecting row i yields the
    # one-hot vector of index i.
    eyes = numpy.eye(length)[array]
    return eyes


def indexes_of(characters):
    """Return the position of every character in the module-level alphabet.

    Raises:
        ValueError: If a character does not occur in the alphabet.
    """
    # The listing computed each index but never stored it, so the function
    # always returned an empty list.
    return [alphabet.index(character) for character in characters]

def train():
    """Print the alphabet indexes of the sample word ``hello``."""
    word = "hello"
    positions = indexes_of(word)
    print("string =", word, ", indexes =", positions)

if __name__ == "__main__":
    # The listing had lost the body of this guard; run the demo entry point.
    train()


string = hello , indexes = [7, 4, 11, 11, 14]


import numpy

def one_hot(characters, alphabet = None):
    """Convert alphabet indexes into one-hot rows.

    Args:
        characters: Sequence of integer indexes into ``alphabet``.
        alphabet: Characters that define the encoding width; the lowercase
            Latin letters are used when it is None.

    Returns:
        A float array with one row per index and ``len(alphabet)`` columns.
        An empty input yields an array with zero rows.
    """
    if alphabet is None:
        alphabet = "abcdefghijklmnopqrstuvwxyz"
    array = numpy.array(characters)
    if array.size == 0:
        # numpy.array([]) defaults to float64, which cannot be used as an index.
        array = array.astype(numpy.intp)
    length = len(alphabet)
    # numpy.eye(length) is the identity matrix; selecting row i yields the
    # one-hot vector of index i.
    eyes = numpy.eye(length)[array]
    return eyes

def indexes_of(characters, alphabet = None):
    """Return the position of every character within ``alphabet``.

    Args:
        characters: Iterable of single characters, typically a string.
        alphabet: Lookup string; the lowercase Latin letters when None.

    Raises:
        ValueError: If a character does not occur in ``alphabet``.
    """
    if alphabet is None:
        alphabet = "abcdefghijklmnopqrstuvwxyz"
    # The listing computed each index but never stored it, so the function
    # always returned an empty list.
    return [alphabet.index(character) for character in characters]

def indexes_matrix(string):
    """Encode ``string`` as a matrix holding one one-hot row per character."""
    return one_hot(indexes_of(string))

def train():
    """Print the one-hot matrix built for the sample word ``hello``."""
    # Reading the data set through AgNewsCsvReader.setup() is disabled here.
    word = "hello"
    encoded = indexes_matrix(word)
    print("string =", word, ", indexes =", encoded)

if __name__ == "__main__":
    # The listing had lost the body of this guard; run the demo entry point.
    train()


string = hello , indexes = [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]]

可以看到,单词“hello”被转换成一个[5, 27]大小的矩阵(与上面打印的结果一致,每行27个元素),供下一步处理。有了上面定义的方法,下一步就是对新闻标题进行独热编码处理。代码如下,

import numpy
import sys
import AgNewsCsvReader

def one_hot(characters, alphabet = None):
    """Convert alphabet indexes into one-hot rows.

    Args:
        characters: Sequence of integer indexes into ``alphabet``.
        alphabet: Characters that define the encoding width; the lowercase
            Latin letters are used when it is None.

    Returns:
        A float array with one row per index and ``len(alphabet)`` columns.
        An empty input yields an array with zero rows.
    """
    if alphabet is None:
        alphabet = "abcdefghijklmnopqrstuvwxyz"
    array = numpy.array(characters)
    if array.size == 0:
        # numpy.array([]) defaults to float64, which cannot be used as an index.
        array = array.astype(numpy.intp)
    length = len(alphabet)
    # numpy.eye(length) is the identity matrix; selecting row i yields the
    # one-hot vector of index i.
    eyes = numpy.eye(length)[array]
    return eyes

def indexes_of(characters, alphabet = None):
    """Return the position of every character within ``alphabet``.

    Args:
        characters: Iterable of single characters, typically a string.
        alphabet: Lookup string; the lowercase Latin letters when None.

    Raises:
        ValueError: If a character does not occur in ``alphabet``.
    """
    if alphabet is None:
        alphabet = "abcdefghijklmnopqrstuvwxyz"
    # The listing computed each index but never stored it, so the function
    # always returned an empty list.
    return [alphabet.index(character) for character in characters]

def indexes_matrix(string):
    """Encode ``string`` as a matrix holding one one-hot row per character."""
    return one_hot(indexes_of(string))

def train():
    """Print the one-hot matrix shape of the first ten titles."""
    _, titles, _ = AgNewsCsvReader.setup()
    for headline in titles[: 10]:
        shape = indexes_matrix(headline).shape
        print("string =", headline, ", indexes.shape =", shape)

if __name__ == "__main__":
    # The listing had lost the body of this guard; run the demo entry point.
    train()


string = wallstbearsclawbackintotheblackreuterseos , indexes.shape = (41, 27)
string = carlylelookstowardcommercialaerospacereuterseos , indexes.shape = (47, 27)
string = oilandeconomycloudstocksoutlookreuterseos , indexes.shape = (41, 27)
string = iraqhaltsoilexportsfrommainsouthernpipelinereuterseos , indexes.shape = (53, 27)
string = oilpricessoartoalltimerecordposingnewmenacetouseconomyafpeos , indexes.shape = (60, 27)
string = stocksendupbutnearyearlowsreuterseos , indexes.shape = (36, 27)
string = moneyfundsfellinlatestweekapeos , indexes.shape = (31, 27)
string = fedminutesshowdissentoverinflationusatodaycomeos , indexes.shape = (48, 27)
string = safetynetforbescomeos , indexes.shape = (21, 27)
string = wallstbearsclawbackintotheblackeos , indexes.shape = (34, 27)




def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    """Encode ``string`` as a one-hot matrix with ``maximum_length`` rows.

    Longer strings are truncated; shorter ones are padded with all-zero rows.

    Args:
        string: Text to encode.
        maximum_length: Number of rows of the returned matrix.
        alphabet: Its length is the width of the zero padding.

    Returns:
        The encoded matrix, ``maximum_length`` rows tall.
    """
    length = len(string)
    if length > maximum_length:
        string = string[: maximum_length]
        matrix = indexes_matrix(string)
        return matrix
    # The listing had lost the ``else:`` of this branch, which left the
    # padding code unreachable and made short strings return None.
    # NOTE(review): indexes_matrix is called without the alphabet, so its own
    # default must produce len(alphabet) columns for the concatenation to work.
    matrix = indexes_matrix(string)
    length = maximum_length - length
    matrix_padded = numpy.zeros([length, len(alphabet)])
    matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
    return matrix

def train():
    """Print the encoding of ``hello`` and the aligned shapes of ten titles."""
    word = "hello"
    encoded = indexes_matrix(word)
    print("string =", word, ", indexes =", encoded)
    _, titles, _ = AgNewsCsvReader.setup()
    for headline in titles[: 10]:
        aligned = align_string_matrix(headline)
        print("string =", headline, ", indexes.shape =", aligned.shape)
if __name__ == "__main__":
    # The listing had lost the body of this guard; run the demo entry point.
    train()


  • 大于maximum_length(默认64,可根据需要自行设置该值)的字符串,截取前部分进行矩阵转换。
  • 长度小于maximum_length的,先生成由0构成的补全矩阵,再与原矩阵进行串接(numpy.concatenate)。


string = wall st bears claw back into the black reuterseos , indexes.shape = (64, 27)
string = carlyle looks toward commercial aerospace reuterseos , indexes.shape = (64, 27)
string = oil and economy cloud stocks outlook reuterseos , indexes.shape = (64, 27)
string = iraq halts oil exports from main southern pipeline reuterseos , indexes.shape = (64, 27)
string = oil prices soar to all time record posing new menace to us economy afpeos , indexes.shape = (64, 27)
string = stocks end up but near year lows reuterseos , indexes.shape = (64, 27)
string = money funds fell in latest week apeos , indexes.shape = (64, 27)
string = fed minutes show dissent over inflation usatoday comeos , indexes.shape = (64, 27)
string = safety net forbes comeos , indexes.shape = (64, 27)
string = wall st bears claw back into the blackeos , indexes.shape = (64, 27)



def one_hot_numbers(numbers):
    """One-hot encode integer labels.

    The width of the result is the largest label plus one, so label 0 always
    owns a column even when it never occurs.
    """
    values = numpy.array(numbers)
    width = values.max() + 1
    return numpy.eye(width)[values]

def train():
    """Print sample encodings, ten aligned title shapes and the label shape."""
    word = "hello"
    encoded = indexes_matrix(word)
    print("string =", word, ", indexes =", encoded)
    labels, titles, _ = AgNewsCsvReader.setup()
    for headline in titles[: 10]:
        aligned = align_string_matrix(headline)
        print("string =", headline, ", indexes.shape =", aligned.shape)
    one_hoted_labels = one_hot_numbers(labels)
    print("one_hoted_labels.shape = ", one_hoted_labels.shape)
if __name__ == "__main__":
    # The listing had lost the body of this guard; run the demo entry point.
    train()



import csv
import re
import jax
import ssl
import nltk
def stop_words():
    """Download the NLTK stop-word corpus and return the English stop words.

    Returns:
        The list of English stop words from ``nltk.corpus.stopwords``.
    """
    # The listing had lost its ``try:`` line, leaving a dangling ``except``.
    # NOTE: replacing the default HTTPS context disables certificate
    # verification for the whole process; it is kept only so that the NLTK
    # download works on machines without a usable certificate store.
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        # This Python has no such helper: keep the default context.
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context
    download_dir = "/tmp/"
    nltk.download("stopwords", download_dir = download_dir)
    # NLTK only searches the directories listed in nltk.data.path, so the
    # custom download directory has to be registered before the lookup.
    if download_dir not in nltk.data.path:
        nltk.data.path.append(download_dir)
    # The corpus file is named "english"; "English" only resolves on
    # case-insensitive file systems.
    stops = nltk.corpus.stopwords.words("english")
    return stops

def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
    """Normalise a text for character-level processing.

    The text is lower-cased, every match of ``pattern`` is substituted with
    ``replacement``, each run of spaces is substituted with ``replacement``
    too, the result is stripped and the marker ``"eos"`` is appended directly
    to its end. Spaces between words are kept with the default replacement.
    """
    lowered = string.lower()
    substituted = re.sub(pattern, replacement, lowered)
    # Collapse every run of consecutive spaces.
    collapsed = re.sub(r" +", replacement, substituted)
    return collapsed.strip() + "eos"

def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = None):
    """Tokenise a text, drop stop words and stem the remaining words.

    Args:
        string: Raw text.
        pattern: Regular expression of the characters to substitute.
        replacement: Text substituted for each match and for each run of spaces.
        stops: Words to discard. When None, the result of ``stop_words()`` is
            fetched on first use and cached on this function, so importing
            the module no longer triggers a download.

    Returns:
        A list starting with the marker ``"bos"`` followed by the stemmed words.
    """
    if stops is None:
        stops = getattr(purify_stops, "_default_stops", None)
        if stops is None:
            stops = frozenset(stop_words())
            purify_stops._default_stops = stops
    else:
        # Set membership is O(1) per word instead of scanning a list.
        stops = frozenset(stops)
    string = string.lower()
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace consecutive spaces with a single replacement
    string = re.sub(pattern = r" +",  repl = replacement, string = string)
    # Trim the string
    string = string.strip()
    # One stemmer is enough for all words of the text.
    stemmer = nltk.PorterStemmer()
    # Separate the string at spaces and keep the words that are not stop words
    strings = [stemmer.stem(word) for word in string.split(" ") if word not in stops]
    return ["bos"] + strings

def setup():
    """Read the AG News training CSV.

    Returns:
        A ``(labels, titles, descriptions)`` tuple of parallel lists.
    """
    labels = []
    titles = []
    descriptions = []
    # newline = "" is what the csv module asks for when reading files.
    with open("../../Shares/ag_news_csv/train.csv", "r", newline = "") as handler:
        # Iterating the reader directly replaces the index loop whose
        # ``trains[I]`` referenced an undefined name.
        for line in csv.reader(handler):
            # NOTE(review): the rest of the loop body was missing from the
            # listing. It is rebuilt on the assumption that every row is
            # (class index, title, description) and that titles go through
            # purify() and descriptions through purify_stops() — confirm.
            labels.append(int(line[0]))
            titles.append(purify(line[1]))
            descriptions.append(purify_stops(line[2]))
    return labels, titles, descriptions


import numpy
import sys
import AgNewsCsvReader

def one_hot(characters, alphabet = None):
    """Convert alphabet indexes into one-hot rows.

    Args:
        characters: Sequence of integer indexes into ``alphabet``.
        alphabet: Characters that define the encoding width; the lowercase
            Latin letters are used when it is None.

    Returns:
        A float array with one row per index and ``len(alphabet)`` columns.
        An empty input yields an array with zero rows.
    """
    if alphabet is None:
        alphabet = "abcdefghijklmnopqrstuvwxyz"
    array = numpy.array(characters)
    if array.size == 0:
        # numpy.array([]) defaults to float64, which cannot be used as an index.
        array = array.astype(numpy.intp)
    length = len(alphabet)
    # numpy.eye(length) is the identity matrix; selecting row i yields the
    # one-hot vector of index i.
    eyes = numpy.eye(length)[array]
    return eyes

def one_hot_numbers(numbers):
    """One-hot encode integer labels.

    The width of the result is the largest label plus one, so label 0 always
    owns a column even when it never occurs.
    """
    values = numpy.array(numbers)
    width = values.max() + 1
    return numpy.eye(width)[values]

def indexes_of(characters, alphabet = None):
    """Return the position of every character within ``alphabet``.

    Args:
        characters: Iterable of single characters, typically a string.
        alphabet: Lookup string; the lowercase Latin letters when None.

    Raises:
        ValueError: If a character does not occur in ``alphabet``.
    """
    if alphabet is None:
        alphabet = "abcdefghijklmnopqrstuvwxyz"
    # The listing computed each index but never stored it, so the function
    # always returned an empty list.
    return [alphabet.index(character) for character in characters]

def indexes_matrix(string, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    """Encode ``string`` as one one-hot row per character over ``alphabet``."""
    return one_hot(indexes_of(string, alphabet), alphabet)
def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    """Encode ``string`` as a one-hot matrix with ``maximum_length`` rows.

    Longer strings are truncated; shorter ones are padded with all-zero rows.

    Args:
        string: Text whose characters all occur in ``alphabet``.
        maximum_length: Number of rows of the returned matrix.
        alphabet: Encoding alphabet; its length is the matrix width.

    Returns:
        A matrix of shape ``(maximum_length, len(alphabet))``.
    """
    length = len(string)
    if length > maximum_length:
        string = string[: maximum_length]
        # The alphabet is forwarded so that a caller-supplied alphabet is
        # honoured instead of silently falling back to the helper's default.
        matrix = indexes_matrix(string, alphabet)
        return matrix
    # The listing had lost the ``else:`` of this branch, which left the
    # padding code unreachable and made short strings return None.
    matrix = indexes_matrix(string, alphabet)
    length = maximum_length - length
    matrix_padded = numpy.zeros([length, len(alphabet)])
    matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
    return matrix

def train():
    """Build the training tensor and the one-hot labels and print their shapes."""
    string = "hello"
    indexes = indexes_matrix(string)
    print("string =", string, ", indexes =", indexes)
    labels, titles, descriptions = AgNewsCsvReader.setup()
    # The listing built each matrix and then dropped it, so ``trains`` stayed
    # empty. Every title is encoded (not only the first ten) so that the
    # texts stay aligned with ``labels``, which covers the whole data set.
    trains = [align_string_matrix(title) for title in titles]
    # Append a trailing axis of size 1 to the stacked matrices.
    trains = numpy.expand_dims(numpy.array(trains), axis = -1)
    labels = one_hot_numbers(labels)
    print("trains.shape =", trains.shape, ", labels.shape =", labels.shape)
if __name__ == "__main__":
    # The listing had lost the body of this guard; run the demo entry point.
    train()


trains.shape = (120000, 64, 27, 1) , labels.shape = (120000, 5)


  • 训练集的维度为[120000, 64, 27, 1],第一个数字代表样本总数,第二个和第三个数字为生成的矩阵维度,最后一个1代表这里只使用1个通道。
  • 标签数据为[120000, 5],是一个二维矩阵,120000是样本的总数,5是类别。注意,one-hot是从0开始的,而标签的分类是从1开始的,因此会自动添加一个0的标签。




层级 名称
1 Conv 3 x 3, 1 x 1
2 Conv 5 x 5, 1 x 1
3 Conv 3 x 3, 1 x 1
4 Fully Connected 256
5 Fully Connected 5


def cnn(number_classes):
    """Build the stax model: three convolutions and two dense layers.

    Args:
        number_classes: Width of the final dense layer.

    Returns:
        The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``.
    """
    # NOTE(review): the listing was cut after the second convolution. The
    # remaining layers follow the chapter's layer table (Conv 3x3, Conv 5x5,
    # Conv 3x3, Dense 256, Dense number_classes). The Relu activations and
    # the Flatten are assumptions; LogSoftmax is assumed because the loss
    # multiplies the targets with the predictions — confirm.
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(1, (3, 3)),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Conv(1, (5, 5)),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Conv(1, (3, 3)),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Flatten,
        jax.example_libraries.stax.Dense(256),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Dense(number_classes),
        jax.example_libraries.stax.LogSoftmax
    )



import csv
import re
import jax
import ssl
import nltk
def stop_words():
    """Download the NLTK stop-word corpus and return the English stop words.

    Returns:
        The list of English stop words from ``nltk.corpus.stopwords``.
    """
    # The listing had lost its ``try:`` line, leaving a dangling ``except``.
    # NOTE: replacing the default HTTPS context disables certificate
    # verification for the whole process; it is kept only so that the NLTK
    # download works on machines without a usable certificate store.
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        # This Python has no such helper: keep the default context.
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context
    download_dir = "/tmp/"
    nltk.download("stopwords", download_dir = download_dir)
    # NLTK only searches the directories listed in nltk.data.path, so the
    # custom download directory has to be registered before the lookup.
    if download_dir not in nltk.data.path:
        nltk.data.path.append(download_dir)
    # The corpus file is named "english"; "English" only resolves on
    # case-insensitive file systems.
    stops = nltk.corpus.stopwords.words("english")
    return stops

def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
    """Normalise a text for character-level processing.

    The text is lower-cased, every match of ``pattern`` is substituted with
    ``replacement``, each run of spaces is substituted with ``replacement``
    too, the result is stripped and the marker ``" eos"`` (with a leading
    space) is appended.
    """
    lowered = string.lower()
    substituted = re.sub(pattern, replacement, lowered)
    # Collapse every run of consecutive spaces.
    collapsed = re.sub(r" +", replacement, substituted)
    return collapsed.strip() + " eos"

def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = None):
    """Tokenise a text, drop stop words and stem the remaining words.

    Args:
        string: Raw text.
        pattern: Regular expression of the characters to substitute.
        replacement: Text substituted for each match and for each run of spaces.
        stops: Words to discard. When None, the result of ``stop_words()`` is
            fetched on first use and cached on this function, so importing
            the module no longer triggers a download.

    Returns:
        A list starting with the marker ``"bos"`` followed by the stemmed words.
    """
    if stops is None:
        stops = getattr(purify_stops, "_default_stops", None)
        if stops is None:
            stops = frozenset(stop_words())
            purify_stops._default_stops = stops
    else:
        # Set membership is O(1) per word instead of scanning a list.
        stops = frozenset(stops)
    string = string.lower()
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace consecutive spaces with a single replacement
    string = re.sub(pattern = r" +",  repl = replacement, string = string)
    # Trim the string
    string = string.strip()
    # One stemmer is enough for all words of the text.
    stemmer = nltk.PorterStemmer()
    # Separate the string at spaces and keep the words that are not stop words
    strings = [stemmer.stem(word) for word in string.split(" ") if word not in stops]
    return ["bos"] + strings

def _read_split(path):
    """Read one AG News CSV file into (labels, titles, descriptions) lists."""
    labels = []
    titles = []
    descriptions = []
    # newline = "" is what the csv module asks for when reading files.
    with open(path, "r", newline = "") as handler:
        # Iterating the reader directly replaces the index loop whose
        # ``trains[I]`` / ``tests[I]`` referenced an undefined name.
        for line in csv.reader(handler):
            # NOTE(review): the rest of the loop body was missing from the
            # listing. It is rebuilt on the assumption that every row is
            # (class index, title, description) and that titles go through
            # purify() and descriptions through purify_stops() — confirm.
            labels.append(int(line[0]))
            titles.append(purify(line[1]))
            descriptions.append(purify_stops(line[2]))
    return labels, titles, descriptions

def setup():
    """Read the AG News training and test CSV files.

    Returns:
        ``((train_labels, train_titles, train_descriptions),
        (test_labels, test_titles, test_descriptions))``.
    """
    train_split = _read_split("../../Shares/ag_news_csv/train.csv")
    test_split = _read_split("../../Shares/ag_news_csv/test.csv")
    return train_split, test_split

def main():
    """Load both splits and print how many entries each column holds."""
    (train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions) = setup()
    # The visible part of setup() collects plain lists, which have no
    # ``.shape``; len() works for lists and arrays alike.
    print((len(train_labels), len(train_titles), len(train_descriptions)), (len(test_labels), len(test_titles), len(test_descriptions)))
if __name__ == "__main__":
    # The listing had lost the body of this guard; run the module's entry point.
    main()


import numpy
import jax
import jax.example_libraries.stax
import jax.example_libraries.optimizers
import sys


import AgNewsCsvReader

def one_hot(characters, alphabet):
    """Convert alphabet indexes into one-hot rows.

    Args:
        characters: Sequence of integer indexes into ``alphabet``.
        alphabet: Characters that define the encoding width.

    Returns:
        A float array with one row per index and ``len(alphabet)`` columns.
        An empty input yields an array with zero rows.
    """
    array = numpy.array(characters)
    if array.size == 0:
        # numpy.array([]) defaults to float64, which cannot be used as an index.
        array = array.astype(numpy.intp)
    length = len(alphabet)
    # numpy.eye(length) is the identity matrix; selecting row i yields the
    # one-hot vector of index i.
    eyes = numpy.eye(length)[array]
    return eyes

def one_hot_numbers(numbers):
    """One-hot encode integer labels.

    The width of the result is the largest label plus one, so label 0 always
    owns a column even when it never occurs.
    """
    values = numpy.array(numbers)
    width = values.max() + 1
    return numpy.eye(width)[values]

def indexes_of(characters, alphabet):
    """Return the position of every character within ``alphabet``.

    Raises:
        ValueError: If a character does not occur in ``alphabet``.
    """
    # The listing computed each index but never stored it, so the function
    # always returned an empty list.
    return [alphabet.index(character) for character in characters]

def indexes_matrix(string, alphabet):
    """Encode ``string`` as one one-hot row per character over ``alphabet``."""
    return one_hot(indexes_of(string, alphabet), alphabet)

def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    """Encode ``string`` as a one-hot matrix with ``maximum_length`` rows.

    Longer strings are truncated; shorter ones are padded with all-zero rows.

    Args:
        string: Text whose characters all occur in ``alphabet``.
        maximum_length: Number of rows of the returned matrix.
        alphabet: Encoding alphabet; its length is the matrix width.

    Returns:
        A matrix of shape ``(maximum_length, len(alphabet))``.
    """
    length = len(string)
    if length > maximum_length:
        string = string[: maximum_length]
        matrix = indexes_matrix(string, alphabet)
        return matrix
    # The listing had lost the ``else:`` of this branch, which left the
    # padding code unreachable and made short strings return None.
    matrix = indexes_matrix(string, alphabet)
    length = maximum_length - length
    matrix_padded = numpy.zeros([length, len(alphabet)])
    matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
    return matrix
def cnn(number_classes):
    """Build the stax model: three convolutions and two dense layers.

    Args:
        number_classes: Width of the final dense layer.

    Returns:
        The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``.
    """
    # NOTE(review): the listing was cut after the second convolution. The
    # remaining layers follow the chapter's layer table (Conv 3x3, Conv 5x5,
    # Conv 3x3, Dense 256, Dense number_classes). The Relu activations and
    # the Flatten are assumptions; LogSoftmax is assumed because the loss
    # multiplies the targets with the predictions — confirm.
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(1, (3, 3)),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Conv(1, (5, 5)),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Conv(1, (3, 3)),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Flatten,
        jax.example_libraries.stax.Dense(256),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Dense(number_classes),
        jax.example_libraries.stax.LogSoftmax
    )

def setup():
    """Load the data, build the model and initialise the optimizer.

    Returns:
        A 3-tuple of
        ``(prng, number_classes, batch_size, epochs, init_params, optimizer_state)``,
        ``(init_random_params, optimizer_init_function, predict,
        optimizer_update_function, get_params_function)`` and
        ``((train_texts, train_labels), (test_texts, test_labels))``.
    """
    prng = jax.random.PRNGKey(15)
    (train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions) = AgNewsCsvReader.setup()
    # The listing built each matrix and then dropped it, which left the text
    # tensors empty; collect them and append a trailing axis of size 1.
    train_texts = numpy.array([align_string_matrix(title) for title in train_titles])
    train_texts = numpy.expand_dims(train_texts, axis = -1)
    train_labels = one_hot_numbers(train_labels)
    test_texts = numpy.array([align_string_matrix(title) for title in test_titles])
    test_texts = numpy.expand_dims(test_texts, axis = -1)
    test_labels = one_hot_numbers(test_labels)
    number_classes = 5
    # Derived from the data instead of the hard-coded [-1, 64, 28, 1]: the
    # default alphabet of align_string_matrix has 27 characters, so a width
    # of 28 would size the parameters for inputs the model never receives.
    input_shape = (-1,) + tuple(train_texts.shape[1:])
    batch_size = 100
    epochs = 5
    init_random_params, predict = cnn(number_classes)
    optimizer_init_function, optimizer_update_function, get_params_function = jax.example_libraries.optimizers.adam(step_size = 2.17e-4)
    _, init_params = init_random_params(prng, input_shape = input_shape)
    optimizer_state = optimizer_init_function(init_params)
    # The end of this statement was cut off in the listing; it is completed
    # to match the unpacking done in train().
    return (prng, number_classes, batch_size, epochs, init_params, optimizer_state), (init_random_params, optimizer_init_function, predict, optimizer_update_function, get_params_function), ((train_texts, train_labels), (test_texts, test_labels))
def verify_accuracy(params, batch, predict_function):
    """Count the samples of ``batch`` whose predicted class is correct.

    Args:
        params: Parameters handed to ``predict_function``.
        batch: ``(inputs, targets)`` pair; ``targets`` holds one row per sample.
        predict_function: Callable ``(params, inputs)`` returning one score
            row per sample.

    Returns:
        The number of rows whose highest score is in the same column as the
        highest target entry.
    """
    inputs, targets = batch
    predictions = predict_function(params, inputs)
    class_ = jax.numpy.argmax(predictions, axis = 1)
    targets = jax.numpy.argmax(targets, axis = 1)
    # Compare class indexes with class indexes; the listing compared the raw
    # score matrix with the target indexes.
    return jax.numpy.sum(class_ == targets)

def loss_function(params, batch, predict_function):
    """Mean over the batch of the per-sample sum of ``-targets * predictions``.

    This is the cross-entropy when ``predict_function`` returns
    log-probabilities and ``targets`` are one-hot rows.
    """
    inputs, targets = batch
    weighted = -targets * predict_function(params, inputs)
    per_sample = jax.numpy.sum(weighted, axis = 1)
    return jax.numpy.mean(per_sample)

def update_function(i, optimizer_state, batch, get_params_function, optimizer_update_function, predict_function):
    """Run one optimizer step on ``batch`` and return the new optimizer state.

    The gradient of ``loss_function`` with respect to the current parameters
    is passed to ``optimizer_update_function`` together with the index ``i``.
    """
    current = get_params_function(optimizer_state)
    grads = jax.grad(loss_function)(current, batch, predict_function)
    return optimizer_update_function(i, grads, optimizer_state)
def train():
    """Train the model batch by batch and report the test accuracy per epoch."""
    # The end of this statement was cut off in the listing.
    (prng, number_classes, batch_size, epochs, init_params, optimizer_state), (init_random_params, optimizer_init_function, predict, optimizer_update_function, get_params_function), ((train_texts, train_labels), (test_texts, test_labels)) = setup()
    print("train_texts.shape =", train_texts.shape, ", train_labels.shape =", train_labels.shape, ", test_texts.shape =", test_texts.shape, ", test_labels.shape =", test_labels.shape)
    train_batch_number = len(train_texts) // batch_size
    test_batch_number = len(test_texts) // batch_size
    # Kept outside the epoch loop so that it holds one entry per epoch.
    accuracies = []
    # The optimizer expects a running step counter; the listing passed the
    # epoch index, which repeats the same step number for a whole epoch.
    step = 0
    for i in range(epochs):
        print(f"Epoch {i} started")
        for j in range(train_batch_number):
            start = j * batch_size
            end = start + batch_size
            batch = (train_texts[start: end], train_labels[start: end])
            optimizer_state = update_function(step, optimizer_state, batch, get_params_function, optimizer_update_function, predict)
            step += 1
            if (j + 1) % 10 == 0:
                params = get_params_function(optimizer_state)
                # loss_function needs the predict function as third argument.
                losses = loss_function(params, batch, predict)
                print("Losses now is =", losses)
        params = get_params_function(optimizer_state)
        print(f"Epoch {i} compeleted")
        predictions = 0.0
        for j in range(test_batch_number):
            start = j * batch_size
            end = start + batch_size
            batch = (test_texts[start: end], test_labels[start: end])
            # verify_accuracy needs the predict function as third argument.
            predictions += verify_accuracy(params, batch, predict)
        # Divide by the number of evaluated test samples; the listing divided
        # by the size of the training set.
        evaluated = max(test_batch_number * batch_size, 1)
        accuracies.append(predictions / float(evaluated))
        # These figures are measured on the test split, hence the label.
        print("Test accuracies =", accuracies)
if __name__ == "__main__":
    # The listing had lost the body of this guard; run the training entry point.
    train()



本章基于AG News新闻标题和分类标签,使用一层卷积和全连接层构建了一个文本分类模型。注意,这个示例只是为了说明问题,效果并不一定好。

