文章源于下面对话

做Python开发三年多了,一直不怎么擅长线程类编程。近来在开发一个项目时又遇到了多线程提速的问题,项目需求比较简单,获取到视频链接后下载到指定文件夹。
我寻思着反正是做个demo,随便写写,for循环获取视频链接再下载,能满足功能就可以了,于是出现了文章开头的对话。唉!客户就是上帝,还是政府客户,要求得满足啊!给他们来个多线程下载吧,一个线程获取视频链接,塞到队列里,再起个多线程从队列拿视频链接下载。
# 只要队列不为空,就一直取元素并下载
while not queue.empty():
item = queue.get()
download(item)
当我开发完之后运行发现,不对劲啊,队列里面还有数据咋下载线程就结束了呢,仔细一排查发现获取视频链接的速度赶不上拿视频链接下载的速度。这好办呀,给它加个结束标志,只有拿到这个标志时才可以结束下载线程。
# 结束标志
sentinel = object()
# 线程循环取元素并下载,直到拿到结束标志,才退出线程
while True:
item = queue.get()
if item == sentinel:
break
else:
download(item)
紧接着运行又出现问题了,程序执行到一半不执行了,再次排查发现,n个下载线程只有一个下载线程拿到了结束标志退出了,其他还跑着呢,看来拿到队列结束标志后还得再塞回队列里,让其他线程也能获取到。
# 结束标志
sentinel = object()
# 线程循环取元素并下载,直到拿到结束标志,将其塞回队列供其它运行的线程继续获取,然后才退出线程
while True:
item = queue.get()
if item == sentinel:
queue.put(item)
break
else:
download(item)
其实这里获取视频链接塞入队列就相当于生产者,从队列拿视频链接并下载就是消费者,于是一个生产者多个消费者的模型就出来了。又加了一点修改,将其封装成了多个生产者多个消费者的模型,代码如下。
# -*- coding:utf-8 -*-
import queue
import threading
from loguru import logger
from collections import deque
producer_count = 3
consumer_count = 9
queue = queue.Queue()
sentinels = [object() for _ in range(0, producer_count)]
class Producer(threading.Thread):
def __init__(self, name, sentinel):
super().__init__(name=name)
self.queue = queue
self.sentinel = sentinel
logger.debug('{} 已创建'.format(name))
def run(self):
for item in range(1, 201):
self.queue.put(item)
logger.info("{} 已生产 {} 到队列".format(self.getName(), item))
self.queue.put(self.sentinel)
logger.debug("{} 完成生产并销毁".format(self.getName()))
class Consumer(threading.Thread):
def __init__(self, name):
super().__init__(name=name)
self.queue = queue
self.deque = deque()
logger.debug('{} 已创建'.format(name))
def run(self):
while True:
item = self.queue.get()
if item in sentinels:
if item not in self.deque:
self.deque.append(item)
self.queue.put(item)
if len(self.deque) == len(sentinels):
logger.debug('{} 已销毁'.format(self.getName()))
break
else:
logger.info("{} 已经被 {} 消费".format(item, self.getName()))
def main():
producer_threads = []
for i in range(0, producer_count):
producer = Producer('生产者' + str(i + 1), sentinels[i])
producer_threads.append(producer)
producer.start()
consumer_threads = []
for i in range(0, consumer_count):
consumer = Consumer('消费者' + str(i + 1))
consumer_threads.append(consumer)
consumer.start()
for consumer in consumer_threads:
consumer.join()
for producer in producer_threads:
producer.join()
logger.debug(queue.qsize())
while not queue.empty():
logger.debug(queue.get())
logger.debug(queue.empty())
if __name__ == '__main__':
main()