LinerUnit

线性单元

  1. 感知器有一个问题,当面对的数据集不是线性可分的时候,『感知器规则』可能无法收敛,这意味着我们永远也无法完成一个感知器的训练。为了解决这个问题,我们使用一个连续的线性函数来替代感知器的阶跃函数,这种感知器就叫做线性单元。线性单元在面对线性不可分的数据集时,会收敛到一个最佳的近似上。
  2. 那么线性单元就是将感知机的输出激活函数由分段函数改为了连续函数,进而输出的值域也由{0,1}\rightarrow[-\infty,+\infty]

举例说明

当我们说模型时,我们实际上在谈论根据输入x预测输出y的算法。比如,x可以是一个人的工作年限,y可以是他的月薪,我们可以用某种算法来根据一个人的工作年限来预测他的收入。\\y=w*x+b

其中w,b是可以拟合年限输入和月薪输出的待求权重参数。工作年限称为一个特征,输入可以包含多个特征如:行业,公司,职级等。当特征变多时,对应的每个特征都需要一个权重w_i用于拟合输入和输出之间的关系。
\\y = w_1*x_1+w_2*x_2+\dots+w_n*x_n+b,矩阵表示
y=\textbf{W}^T\textbf{X}\\其中
\textbf{W}=\begin{bmatrix} w_i\\ \vdots \\ w_n\\ b \\ \end{bmatrix}, \textbf{X}=\begin{bmatrix} x_i \\ \vdots \\ x_n \\ 1\\ \end{bmatrix}\\

代码

由于相较于Perceptron只改变了激活函数,所以我们可以继承Perceptron快速实现LinerUnit

class LinerUnit(Perceptron):
    def __init__(self, input_dim, activator) -> None:
        super().__init__(input_dim, activator)

生成训练数据,定义可视化

# 新定义的连续线性激活函数
def liner_activater(x):
    return x

def get_training_dataset():
    """
    construct training_set, consist of n samples
    Working years and corresponding salary.
    """
    data = [[5], [3], [8], [1.4], [10.1], [8.1]]
    labels = [5500, 2300, 7600, 1800, 11400, 20000]
    return data, labels

def train_liner_unit(iterations, lr):
    """
    Train a liner_unit with training_set.
    """
    lu = LinerUnit(input_dim=1, activator=liner_activater)
    lu.train(*get_training_dataset(), iterations=iterations, lr=lr)
    return lu

def show_results(linear_unit, samples):
    """
    Visualize the line after the linear unit fit
    """
    predicts = [linear_unit.predict(s) for s in samples]
    plt.scatter(samples, predicts, marker="o")
    x_fit = np.linspace(start=0, stop=max(samples), num=100)
    y_fit = linear_unit.weights * x_fit + linear_unit.bias
    plt.plot(x_fit, y_fit, linestyle="-")
    plt.xlabel("Working years")
    plt.ylabel("Salary")
    plt.show()

训练,测试,并可视化

if __name__ == "__main__":
    linear_unit = train_liner_unit(10, 0.1)
    test_samples = [[3.4], [15], [1.5], [6.3], [8]]
    # test
    for year in test_samples:
        print(f"Work {year} years, monthly salary = {linear_unit.predict(year)}")

    show_results(linear_unit=linear_unit, samples=test_samples)

结果

控制台输出.png
可视化结果.png
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容