from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from numpy import *
import numpy as np
import time
import struct
#读取图片
def read_image(file_name):
#先用二进制方式把文件都读进来
file_handle=open(file_name,"rb") #以二进制打开文档
file_content=file_handle.read() #读取到缓冲区中
offset=0
head = struct.unpack_from('>IIII', file_content, offset) # 取前4个整数,返回一个元组
offset += struct.calcsize('>IIII')
imgNum = head[1] #图片数
rows = head[2] #宽度
cols = head[3] #高度
images=np.empty((imgNum , 784))#empty,是它所常见的数组内的所有元素均为空,没有实际意义,它是创建数组最快的方法
image_size=rows*cols#单个图片的大小
fmt='>' + str(image_size) + 'B'#单个图片的format
for i in range(imgNum):
images[i] = np.array(struct.unpack_from(fmt, file_content, offset))
# images[i] = np.array(struct.unpack_from(fmt, file_content, offset)).reshape((rows, cols))
offset += struct.calcsize(fmt)
return images
#读取标签
def read_label(file_name):
file_handle = open(file_name, "rb") # 以二进制打开文档
file_content = file_handle.read() # 读取到缓冲区中
head = struct.unpack_from('>II', file_content, 0) # 取前2个整数,返回一个元组
offset = struct.calcsize('>II')
labelNum = head[1] # label数
# print(labelNum)
bitsString = '>' + str(labelNum) + 'B' # fmt格式:'>47040000B'
label = struct.unpack_from(bitsString, file_content, offset) # 取data数据,返回一个元组
return np.array(label)
def loadDataSet():
train_x_filename="/Users/didi/work/git_repo/demo_proj/src/logist_test/train-images-idx3-ubyte"
train_y_filename="/Users/didi/work/git_repo/demo_proj/src/logist_test/train-labels-idx1-ubyte"
test_x_filename="/Users/didi/work/git_repo/demo_proj/src/logist_test/t10k-images-idx3-ubyte"
test_y_filename="/Users/didi/work/git_repo/demo_proj/src/logist_test/t10k-labels-idx1-ubyte"
train_x=read_image(train_x_filename)
train_y=read_label(train_y_filename)
test_x=read_image(test_x_filename)
test_y=read_label(test_y_filename)
return train_x, test_x, train_y, test_y
if __name__=='__main__':
print("Start reading data...")
time1=time.time()
train_x, test_x, train_y, test_y = loadDataSet()
clf = LogisticRegression()
clf.fit(train_x, train_y)
y_pred = clf.predict(test_x)
print('准确率:'+accuracy_score(test_y, y_pred))
Logistic 回归(mnist数据集)
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
推荐阅读更多精彩内容
- MNIST数据集是一个入门级的计算机视觉数据集,它包含各种手写数字照片,它也包含每一张图片对应的标签,告诉我们这是...
- 本文作者:陈 鼎,中南财经政法大学统计与数学学院文字编辑:任 哲技术总编:张馨月 Logistic回归分析是一...
- MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范” 1.逻辑回...