[ 小工具 ] 使用python实现多线程网络文件下载器

前言

其实本来没打算自己写，但找了很长时间也没有找到符合需求的工具。

下面的代码存在一些问题，最后是使用 Aria2 解决的。

aria2官网:http://aria2.github.io/

aria2使用说明:http://aria2c.com/usage.html

其中印象最深的问题是：下载一个视频后，播放时会花屏。解决方法是给文件的写操作加互斥锁，保证同一时刻只有一个线程在写文件。不过，下面的代码中各线程共用同一个 headers 字典并直接修改其中的 Range 字段，一个线程设置的 Range 可能被另一个线程覆盖，导致数据被写到错误的位置，这也可能是花屏的原因之一。

# -*- coding:utf-8 -*-
# 多线程文件下载器
import os,requests,math,time,threading
from concurrent.futures import ThreadPoolExecutor

# Each pool task downloads exactly one block of at most `blockSize` bytes
# (1048576 bytes = 1 MiB by default).
class threadDownLoad:
    """Multi-threaded HTTP file downloader built on byte-range requests.

    The remote file is cut into fixed-size blocks; each block is fetched with
    its own ``Range`` request on a thread pool and written at its offset in
    the local file.
    """

    # Extra attempts downBlock makes after a failed attempt (4 attempts total).
    maxRetries = 3

    def __init__(self, fileUrl, savePath, cookie="", threadCount=20, timeOut=150, blockSize=1048576):
        """Store the settings and probe the remote file size.

        Args:
            fileUrl: URL of the file to download.
            savePath: local path the file is saved to.
            cookie: value sent in the ``Cookie`` request header.
            threadCount: number of worker threads in the pool.
            timeOut: per-request timeout in seconds.
            blockSize: size in bytes of each downloaded block.

        Raises:
            ValueError: ``blockSize`` is not positive.
            requests.HTTPError: the size probe returned an error status.
            KeyError: the server sent no ``Content-Length`` header.
        """
        if blockSize <= 0:
            raise ValueError("blockSize must be positive")
        self.fileUrl = fileUrl
        self.savePath = savePath
        self.cookie = cookie
        self.threadCount = threadCount
        self.timeOut = timeOut
        self.blockSize = blockSize
        self.mutex = threading.Lock()  # serializes writes to the output file
        self.headers = {
            "Cookie": self.cookie,
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.105 Safari/537.36 Edg/84.0.522.52"
        }
        # stream=True only fetches the headers; the with-block closes the
        # connection so the unread body does not leak it.
        with requests.get(self.fileUrl, headers=self.headers, stream=True, timeout=self.timeOut) as res:
            res.raise_for_status()
            self.fileSize = int(res.headers['Content-Length'])

    def splitFile(self):
        """Split the file into inclusive byte ranges.

        Returns:
            A list of ``[start, end]`` pairs (both ends inclusive, as in the
            HTTP ``Range`` header) covering bytes ``0 .. fileSize - 1``;
            empty when the file is empty.
        """
        return [
            [start, min(start + self.blockSize, self.fileSize) - 1]
            for start in range(0, self.fileSize, self.blockSize)
        ]

    def creactFile(self):
        """Create (or truncate) the output file, creating parent directories as needed."""
        path, _ = self.fileDirAndName(self.savePath)
        if path:  # a bare file name lives in the current directory
            os.makedirs(path, exist_ok=True)
        with open(self.savePath, 'wb'):
            pass

    def download(self):
        """Download the whole file into ``self.savePath`` while drawing a progress bar.

        Returns:
            True when every block was fetched and written, False otherwise.
        """
        self.creactFile()
        blocks = self.splitFile()
        startTime = int(time.time() * 1000)
        # The with-statement shuts the pool down (joins its threads) on exit.
        with ThreadPoolExecutor(self.threadCount) as pool:
            taskList = [pool.submit(self.downBlock, start, end) for start, end in blocks]
            dataLen = len(taskList)
            # Poll the tasks and redraw the progress bar until all are done.
            isAchieve, achieveCount = self.isPoolDone(taskList)
            while not isAchieve:
                when = (int(time.time() * 1000) - startTime) / 1000
                self.processBar(achieveCount / dataLen, start_str='正在下载...', end_str="{0}/{1} 用时: {2} 秒".format(achieveCount, dataLen, when), total_length=30)
                time.sleep(0.5)
                isAchieve, achieveCount = self.isPoolDone(taskList)
        # getPoolResData returns (results, None); only the list matters here.
        results, _ = self.getPoolResData(taskList)
        failed = results.count(False)
        if failed:
            print("")
            print("download failed: {0}/{1} block(s) could not be downloaded".format(failed, dataLen))
            return False
        self.processBar(100 / 100, start_str='正在下载...', end_str="{0}/{1} 用时: {2} 秒".format(dataLen, dataLen, (int(time.time() * 1000) - startTime) / 1000), total_length=30)
        print("")
        print("successful download!")
        return True

    def downBlock(self, start, end, errorCount=0):
        """Fetch bytes ``start..end`` (inclusive) and write them into place.

        Args:
            start: first byte offset of the block.
            end: last byte offset of the block (inclusive).
            errorCount: number of attempts already used; a failed attempt is
                retried while this is below ``maxRetries``.

        Returns:
            True on success, False once the retries are exhausted.
        """
        # Per-call copy: mutating the shared self.headers would let
        # concurrent threads overwrite each other's Range header and write
        # data at the wrong offset.
        headers = dict(self.headers)
        headers["Range"] = "bytes={0}-{1}".format(start, end)
        expected = end - start + 1
        attempt = errorCount
        while True:
            try:
                with requests.get(self.fileUrl, headers=headers, stream=True, timeout=self.timeOut) as res:
                    res.raise_for_status()
                    content = res.content
                # A server that ignores Range answers with the whole file;
                # writing that at `start` would corrupt the output.
                if len(content) != expected:
                    raise ValueError("unexpected block length {0}, expected {1}".format(len(content), expected))
                self.writeFile(start, end, content)
                return True
            except (requests.RequestException, OSError, ValueError):
                if attempt >= self.maxRetries:
                    return False
                attempt += 1

    def writeFile(self, start, end, content):
        """Write ``content`` at byte offset ``start`` of the output file.

        ``end`` is accepted for interface compatibility and is not used.
        Writes are serialized by ``self.mutex``, which is released even when
        the write raises.
        """
        with self.mutex:
            # 'r+b' allows writing in place without truncating blocks that
            # other tasks have already written ('rb' would be read-only).
            with open(self.savePath, 'r+b') as fd:
                fd.seek(start)
                fd.write(content)

    def isPoolDone(self, tasks):
        """Return ``(all_done, done_count)`` for the given futures."""
        doneCount = sum(1 for task in tasks if task.done())
        return doneCount == len(tasks), doneCount

    def getPoolResData(self, tasks):
        """Collect the results of the finished futures.

        Returns:
            ``(results, None)``: the results of the tasks that are done, in
            submission order. The second element is always None and is kept
            for backward compatibility.
        """
        return [task.result() for task in tasks if task.done()], None

    def fileDirAndName(self, file):
        """Split ``file`` into ``(directory, file_name)``.

        The directory is "" when ``file`` is a bare file name.
        """
        dirStr, name = os.path.split(file)
        return dirStr, name

    def processBar(self, percent, start_str='', end_str='', total_length=0):
        """Redraw a one-line progress bar in place.

        Args:
            percent: progress as a fraction in [0, 1].
            start_str: text printed before the bar.
            end_str: text printed after the percentage.
            total_length: bar width in characters.
        """
        bar = '■' * int(percent * total_length)
        bar = '\r' + start_str + bar.ljust(total_length) + ' {:0>4.1f}%|'.format(percent * 100) + end_str
        print(bar, end='', flush=True)

if __name__ == "__main__":
    # Demo: fetch the same sample video eleven times, each into its own file
    # (yuque_0.mp4 .. yuque_10.mp4 in the current directory).
    demoUrl = "https://gw.alipayobjects.com/mdn/prod_resou/afts/file/A*oIndSLbrp0kAAAAAAAAAAABjARQnAQ"
    for index in range(11):
        target = "./yuque_{0}.mp4".format(index)
        downloader = threadDownLoad(demoUrl, target)
        downloader.download()