How it works(14) GDAL2Tiles源码阅读

引入

gdal2tiles(以下简称g2t)这个历史悠久的切图脚本至今依然好用,因为它稳定地做好了它应做的事.相比前面介绍过的gdal2mbtiles(以下简称g2m),我倒是更喜欢它:单文件脚本,运行时只需安装一个GDAL库.同样因为有了g2m,我也是带着对比的心态提出几个问题:

  • 从表现来看,g2t更慢
    • 慢的原因是什么
    • 可以采用g2m加速吗
  • 与g2m对比,其算法有何差异

精简

原始的g2t脚本近3000行,包含了详细的注释和一些其他的功能,分析起来会产生干扰,因此我精简掉全部注释和我所用不上的功能,保留了核心功能,方便分析和使用.

相比原来精简掉的内容:

  • 与KML输出相关的功能和变量
  • 与生成HTML脚本相关的功能和变量
  • 只保留了生成web墨卡托投影瓦片的功能
  • 只保留了"average"一种重采样算法
  • 限定只能生成png格式的瓦片

相比原来修改的内容:

  • 将相关的功能整合进类中
  • 修改命名规范为小驼峰

以下是修改后的脚本,不到500行,下面就基于这个修改后的核心脚本进行分析

from xml.etree import ElementTree
import json
from osgeo import gdal, osr
from uuid import uuid4
import sys
import shutil
import tempfile
import os
import math
from multiprocessing import Pool, Process, Manager

# Highest zoom level constant (not referenced anywhere else in this listing).
MAXZOOMLEVEL = 24
# Edge length of every output tile, in pixels.
TILESIZE = 256
# GDAL driver used to write tiles, and the matching file extension.
TILEDRIVER = 'PNG'
TILEEXT = 'png'


class GlobalMercator(object):
    """Spherical (Web) Mercator helpers for the EPSG:3857 tile pyramid.

    Pixel and tile coordinates start at the south-west corner of the world
    extent (pixel (0, 0) maps to (-originShift, -originShift) metres), so
    y grows northwards. Each zoom level halves the resolution of the previous.
    """

    def __init__(self):
        earthCircumference = 2 * math.pi * 6378137
        # Metres per pixel at zoom 0.
        self.initialResolution = earthCircumference / TILESIZE
        # Distance from the projection origin to the edge of the world extent.
        self.originShift = earthCircumference / 2.0

    def MetersToLatLon(self, mx, my):
        """Convert EPSG:3857 metres to WGS84 degrees; returns (lat, lon)."""
        lon = (mx / self.originShift) * 180.0
        mercatorLat = (my / self.originShift) * 180.0
        lat = 180 / math.pi * (2 * math.atan(math.exp(mercatorLat * math.pi / 180.0)) - math.pi / 2.0)
        return lat, lon

    def PixelsToMeters(self, px, py, zoom):
        """Convert pyramid pixel coordinates at *zoom* to EPSG:3857 metres."""
        res = self.Resolution(zoom)
        return px * res - self.originShift, py * res - self.originShift

    def MetersToPixels(self, mx, my, zoom):
        """Convert EPSG:3857 metres to pyramid pixel coordinates at *zoom*."""
        res = self.Resolution(zoom)
        return (mx + self.originShift) / res, (my + self.originShift) / res

    def PixelsToTile(self, px, py):
        """Return the (tx, ty) tile containing a pixel.

        Uses ceil(p / TILESIZE) - 1, so a coordinate lying exactly on a tile
        edge is attributed to the tile before that edge.
        """
        tileSize = float(TILESIZE)
        return (int(math.ceil(px / tileSize) - 1),
                int(math.ceil(py / tileSize) - 1))

    def MetersToTile(self, mx, my, zoom):
        """Return the (tx, ty) tile containing an EPSG:3857 point at *zoom*."""
        return self.PixelsToTile(*self.MetersToPixels(mx, my, zoom))

    def TileBounds(self, tx, ty, zoom):
        """Return the tile's bounds in metres as (minx, miny, maxx, maxy)."""
        west, south = self.PixelsToMeters(tx * TILESIZE, ty * TILESIZE, zoom)
        east, north = self.PixelsToMeters((tx + 1) * TILESIZE, (ty + 1) * TILESIZE, zoom)
        return (west, south, east, north)

    def Resolution(self, zoom):
        """Metres per pixel at *zoom*."""
        return self.initialResolution / (2 ** zoom)


def check(status, message):
    """Abort the program with exit code 3 when *status* is truthy.

    The message is written to stderr first; a falsy *status* is a no-op.
    """
    if not status:
        return
    sys.stderr.write("运行出错: %s\n" % message)
    sys.exit(3)


class TileDetail(object):
    """Read/write window description for one base tile.

    Keyword arguments that name one of the attributes below override its
    default; any other keyword is ignored.
    """
    # Tile column, row and zoom.
    tx = 0
    ty = 0
    tz = 0
    # Source window: pixel offset and size in the warped dataset.
    rx = 0
    ry = 0
    rxsize = 0
    rysize = 0
    # Destination window: offset and size inside the query ("window") dataset.
    wx = 0
    wy = 0
    wxsize = 0
    wysize = 0

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            if not hasattr(self, name):
                continue
            setattr(self, name, value)


class TileJobInfo(object):
    """Settings shared by every base-tile job.

    Keyword arguments that name an existing attribute are applied; unknown
    keys are silently ignored.
    """
    srcFile = ""          # path of the warped VRT the tiles are read from
    nbDataBands = 0       # band count of output tiles (data bands + alpha)
    outputFilePath = ""   # root folder of the tile pyramid
    tminmax = []          # class-level declaration only; see __init__
    tminz = 0             # lowest zoom level to generate
    tmaxz = 0             # highest zoom level (the base tiles)
    outGeoTrans = []      # class-level declaration only; see __init__

    def __init__(self, **kwargs):
        # Give each instance its own lists. Previously instances that were not
        # passed these keys all shared (and could mutate) the class-level list.
        self.tminmax = []      # per-zoom (tminx, tminy, tmaxx, tmaxy), indexed by zoom
        self.outGeoTrans = []  # not set by any code in this listing
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)


class TileJobsMaker(object):
    """Prepare the base-tile jobs for one input raster.

    The input is reprojected to EPSG:3857 through a warped VRT that is saved
    in a fresh temporary directory (``vrtFilename``) so tile workers can
    re-open it; SingleProcessTiling removes that directory when done. The
    dataset bounds (metres) and the tile range of every zoom 0-31 are kept.
    """

    def __init__(self, inputFile, outputFolder, options):
        # Only RGB sources are supported: 3 data bands plus 1 alpha band.
        self.dataBandsCount = 4
        self.vrtFilename = os.path.join(tempfile.mkdtemp(), str(uuid4()) + '.vrt')
        self.inputFile = inputFile
        self.outputFolder = outputFolder
        self.options = options
        # "-z" is either "N" or "MIN-MAX"; a single level means MIN == MAX.
        zoomMin, _, zoomMax = self.options.zoom.partition('-')
        self.tminz = int(zoomMin)
        self.tmaxz = int(zoomMax) if zoomMax else self.tminz

    def updateNoDataValue(self):
        """Make no-data areas of ``self.warpedDataset`` transparent.

        The warped VRT is serialised to a temporary file, INIT_DEST=NO_DATA and
        UNIFIED_SRC_NODATA=YES are injected into its <GDALWarpOptions>, and the
        edited VRT is re-opened with NODATA_VALUES "0 0 0 0" and stored back
        in ``self.warpedDataset``.
        """
        def gdalVrtWarp(options, key, value):
            # Prepend <Option name="key">value</Option> to the warp options.
            tb = ElementTree.TreeBuilder()
            tb.start("Option", {"name": key})
            tb.data(value)
            tb.end("Option")
            options.insert(0, tb.close())

        # mkstemp creates the file atomically; tempfile.mktemp (used before) is
        # deprecated because the returned name can be taken by someone else.
        fd, tempFile = tempfile.mkstemp('-TileJobsMaker.vrt')
        os.close(fd)
        try:
            self.warpedDataset.GetDriver().CreateCopy(tempFile, self.warpedDataset)
            with open(tempFile, 'r', encoding='utf-8') as f:
                vrtRoot = ElementTree.fromstring(f.read())
            options = vrtRoot.find("GDALWarpOptions")
            # Every destination pixel starts as no-data.
            gdalVrtWarp(options, "INIT_DEST", "NO_DATA")
            # A source pixel is no-data only when all bands match their no-data value.
            gdalVrtWarp(options, "UNIFIED_SRC_NODATA", "YES")
            vrtString = ElementTree.tostring(vrtRoot).decode()
            with open(tempFile, 'w', encoding='utf-8') as f:
                f.write(vrtString)
            correctedDataset = gdal.Open(tempFile)
        finally:
            # Deleted right after opening, as before; the file is also removed
            # when any step above fails instead of being left behind.
            os.unlink(tempFile)
        correctedDataset.SetMetadataItem('NODATA_VALUES', '0 0 0 0')
        self.warpedDataset = correctedDataset

    def openData(self):
        """Open the input, warp it to EPSG:3857 and compute bounds/tile ranges.

        Exits via check() when the input cannot be opened, has no bands, has no
        georeferencing, or the warped geotransform is rotated.
        """
        gdal.AllRegister()
        inputDataset = gdal.Open(self.inputFile, gdal.GA_ReadOnly)
        check(not inputDataset, "数据无法打开")
        check(inputDataset.RasterCount == 0, "数据无波段")
        geoTransform = inputDataset.GetGeoTransform()
        gcpCount = inputDataset.GetGCPCount()
        # An identity geotransform with no GCPs means "not georeferenced".
        check(geoTransform == (0.0, 1.0, 0.0, 0.0, 0.0, 1.0) and gcpCount == 0, "数据缺少空间信息")
        inputSrs = osr.SpatialReference()
        inputSrs.ImportFromWkt(inputDataset.GetProjection())
        outputSrs = osr.SpatialReference()
        outputSrs.ImportFromEPSG(3857)
        self.warpedDataset = gdal.AutoCreateWarpedVRT(inputDataset,
                                                      inputSrs.ExportToWkt(), outputSrs.ExportToWkt())
        self.updateNoDataValue()
        # Persist the corrected VRT; createBaseTile re-opens it by path.
        self.warpedDataset.GetDriver().CreateCopy(self.vrtFilename, self.warpedDataset)
        outGeotrans = self.warpedDataset.GetGeoTransform()
        check((outGeotrans[2], outGeotrans[4]) != (0, 0), "不支持变形后的数据")
        # Bounds in metres. The bottom edge uses the pixel height gt[5]
        # (negative for north-up) rather than the pixel width gt[1]: identical
        # for square pixels, and also correct for non-square ones.
        self.ominx = outGeotrans[0]
        self.omaxx = outGeotrans[0] + self.warpedDataset.RasterXSize * outGeotrans[1]
        self.omaxy = outGeotrans[3]
        self.ominy = outGeotrans[3] + self.warpedDataset.RasterYSize * outGeotrans[5]
        self.mercator = GlobalMercator()
        # (tminx, tminy, tmaxx, tmaxy) for every zoom 0-31, clamped to the world grid.
        self.tminmax = []
        for tz in range(0, 32):
            tminx, tminy = self.mercator.MetersToTile(self.ominx, self.ominy, tz)
            tmaxx, tmaxy = self.mercator.MetersToTile(self.omaxx, self.omaxy, tz)
            lastIndex = 2**tz - 1
            self.tminmax.append((max(0, tminx), max(0, tminy),
                                 min(lastIndex, tmaxx), min(lastIndex, tmaxy)))

    def makeMetadata(self):
        """Write metadata.json with the dataset bounds in WGS84 degrees.

        Latitudes are clamped to the Web Mercator limit (+/-85.05112878) and
        longitudes to +/-180.
        """
        south, west = self.mercator.MetersToLatLon(self.ominx, self.ominy)
        north, east = self.mercator.MetersToLatLon(self.omaxx, self.omaxy)
        south, west = max(-85.05112878, south), max(-180.0, west)
        north, east = min(85.05112878, north), min(180.0, east)
        metadata = {"south": south, "north": north, "west": west, "east": east}
        with open(os.path.join(self.outputFolder, 'metadata.json'), 'w') as f:
            json.dump(metadata, f)

    def makeBaseTiles(self):
        """Build one TileDetail per tile of the highest zoom level.

        Also creates the tile output directories. Returns (TileJobInfo, list of
        TileDetail); tiles are ordered from the northernmost row down.
        """
        tminx, tminy, tmaxx, tmaxy = self.tminmax[self.tmaxz]
        tileDetails = []
        tz = self.tmaxz
        # ty grows northwards, so iterating downwards starts at the top-left.
        for ty in range(tmaxy, tminy - 1, -1):
            for tx in range(tminx, tmaxx + 1):
                tilefilename = os.path.join(self.outputFolder, str(tz), str(tx), "%s.%s" % (ty, TILEEXT))
                os.makedirs(os.path.dirname(tilefilename), exist_ok=True)
                b = self.mercator.TileBounds(tx, ty, tz)
                # geoQuery expects (ulx, uly, lrx, lry): minx, maxy, maxx, miny.
                rb, wb = self.geoQuery(b[0], b[3], b[2], b[1])
                rx, ry, rxsize, rysize = rb
                wx, wy, wxsize, wysize = wb
                tileDetails.append(
                    TileDetail(
                        tx=tx,
                        ty=ty,
                        tz=tz,
                        rx=rx,
                        ry=ry,
                        rxsize=rxsize,
                        rysize=rysize,
                        wx=wx,
                        wy=wy,
                        wxsize=wxsize,
                        wysize=wysize,
                    ))
        conf = TileJobInfo(
            srcFile=self.vrtFilename,
            nbDataBands=self.dataBandsCount,
            outputFilePath=self.outputFolder,
            tminmax=self.tminmax,
            tminz=self.tminz,
            tmaxz=self.tmaxz,
        )
        return conf, tileDetails

    def geoQuery(self, ulx, uly, lrx, lry):
        """Map a tile's bounds (metres) to read and write windows.

        Returns ((rx, ry, rxsize, rysize), (wx, wy, wxsize, wysize)).
        The read window is in warped-dataset pixels. The write window lies
        inside a 4*TILESIZE square "query" dataset. Both are clipped to the
        raster extent, and the write window shrinks proportionally so that
        areas outside the data stay untouched.
        """
        ds = self.warpedDataset
        geotran = ds.GetGeoTransform()
        rx = int((ulx - geotran[0]) / geotran[1] + 0.001)
        ry = int((uly - geotran[3]) / geotran[5] + 0.001)
        rxsize = int((lrx - ulx) / geotran[1] + 0.5)
        rysize = int((lry - uly) / geotran[5] + 0.5)
        # Query at 4x tile size; it is averaged down to the tile afterwards.
        wxsize, wysize = 4 * TILESIZE, 4 * TILESIZE
        wx = 0
        if rx < 0:
            rxshift = abs(rx)
            wx = int(wxsize * (float(rxshift) / rxsize))
            wxsize = wxsize - wx
            rxsize = rxsize - int(rxsize * (float(rxshift) / rxsize))
            rx = 0
        if rx + rxsize > ds.RasterXSize:
            wxsize = int(wxsize * (float(ds.RasterXSize - rx) / rxsize))
            rxsize = ds.RasterXSize - rx
        wy = 0
        if ry < 0:
            ryshift = abs(ry)
            wy = int(wysize * (float(ryshift) / rysize))
            wysize = wysize - wy
            rysize = rysize - int(rysize * (float(ryshift) / rysize))
            ry = 0
        if ry + rysize > ds.RasterYSize:
            wysize = int(wysize * (float(ds.RasterYSize - ry) / rysize))
            rysize = ds.RasterYSize - ry
        return (rx, ry, rxsize, rysize), (wx, wy, wxsize, wysize)


class ProgressBar(object):
    """Minimal console progress indicator.

    Prints a title line, "0" on start(), then a "." for every 2.5% step and
    the percentage at every 10% step, ending with "100" and a newline.
    """

    def __init__(self, total_items, title):
        sys.stdout.write("%s 共%d张 \n" % (title, total_items))
        self.total_items = total_items
        self.nb_items_done = 0
        self.current_progress = 0
        self.STEP = 2.5

    def start(self):
        sys.stdout.write("0")

    def updateProgress(self, nb_items=1):
        """Record *nb_items* more finished items and print any crossed steps."""
        self.nb_items_done += nb_items
        progress = float(self.nb_items_done) / self.total_items * 100
        out = sys.stdout
        while self.current_progress + self.STEP <= progress:
            self.current_progress += self.STEP
            reached = self.current_progress
            if reached % 10 != 0:
                out.write(".")
                continue
            out.write(str(int(reached)))
            if reached == 100:
                out.write("\n")
        out.flush()


class SingleProcessTiling(object):
    """Cut the whole tile pyramid for one input raster in this process.

    Construction does all the work. Base tiles at the highest zoom are read
    from the warped VRT. Each lower zoom is then built by merging the four
    child tiles of the level above. The temporary VRT directory is always
    removed afterwards.
    """

    def __init__(self, inputFile, outputFolder, options):
        self.inputFile = inputFile
        self.outputFolder = outputFolder
        self.options = options
        self.total = 0
        try:
            self.tiling()
            self.createOverviewTiles()
        finally:
            # Clean up the temporary VRT directory even when tiling fails.
            # tileJobInfo exists only once the job list has been built.
            tileJobInfo = getattr(self, 'tileJobInfo', None)
            if tileJobInfo is not None:
                shutil.rmtree(os.path.dirname(tileJobInfo.srcFile), ignore_errors=True)

    def tiling(self):
        """Cut every base tile (highest zoom) sequentially, showing progress."""
        tileDetails = self.workerTileDetails()
        tilecount = len(tileDetails)
        self.total += tilecount
        progressBar = ProgressBar(tilecount, '切割顶层瓦片')
        progressBar.start()
        for tileDetail in tileDetails:
            self.createBaseTile(tileDetail)
            progressBar.updateProgress()

    def createBaseTile(self, tileDetail, queue=None):
        """Render one base tile from the warped VRT and write it as PNG.

        The data bands and band 1's mask (as alpha) are read into a 4x-sized
        query dataset, then averaged down to TILESIZE. A tile is always
        written, and left empty when its window has no data. When *queue* is
        given, a progress message is put on it (multi-process mode).
        """
        gdal.AllRegister()
        tileJobInfo = self.tileJobInfo
        tilebands = tileJobInfo.nbDataBands
        dataBands = list(range(1, tilebands))
        # Each call opens its own handle, so workers never share a dataset.
        ds = gdal.Open(tileJobInfo.srcFile, gdal.GA_ReadOnly)
        memDrv = gdal.GetDriverByName('MEM')
        outDrv = gdal.GetDriverByName(TILEDRIVER)
        alphaband = ds.GetRasterBand(1).GetMaskBand()
        d = tileDetail
        querysize = 4 * TILESIZE
        tilefilename = os.path.join(tileJobInfo.outputFilePath, str(d.tz), str(d.tx), "%s.%s" % (d.ty, TILEEXT))
        dstile = memDrv.Create('', TILESIZE, TILESIZE, tilebands)
        # Edge tiles can be clipped to an empty window. Requiring > 0, rather
        # than only != 0, also skips a negative size from rounding.
        if min(d.rxsize, d.rysize, d.wxsize, d.wysize) > 0:
            data = ds.ReadRaster(d.rx, d.ry, d.rxsize, d.rysize, d.wxsize, d.wysize, band_list=dataBands)
            alpha = alphaband.ReadRaster(d.rx, d.ry, d.rxsize, d.rysize, d.wxsize, d.wysize)
            if data:
                dsquery = memDrv.Create('', querysize, querysize, tilebands)
                dsquery.WriteRaster(d.wx, d.wy, d.wxsize, d.wysize, data, band_list=dataBands)
                dsquery.WriteRaster(d.wx, d.wy, d.wxsize, d.wysize, alpha, band_list=[tilebands])
                self.scaleQueryToTile(dsquery, dstile, tilefilename)
                del dsquery
        del ds
        outDrv.CreateCopy(tilefilename, dstile, strict=0)
        del dstile
        if queue is not None:
            queue.put("tile %s %s %s" % (d.tx, d.ty, d.tz))

    def workerTileDetails(self):
        """Open the input, write metadata.json and return the base-tile jobs.

        Stores the shared job settings in ``self.tileJobInfo``.
        """
        tileJobsMaker = TileJobsMaker(self.inputFile, self.outputFolder, self.options)
        tileJobsMaker.openData()
        tileJobsMaker.makeMetadata()
        conf, tileDetails = tileJobsMaker.makeBaseTiles()
        self.tileJobInfo = conf
        return tileDetails

    def scaleQueryToTile(self, dsquery, dstile, tilefilename=''):
        """Average-resample every band of *dsquery* into *dstile*.

        Exits via check() if GDAL returns a non-zero error code.
        """
        tilebands = dstile.RasterCount
        for i in range(1, tilebands + 1):
            res = gdal.RegenerateOverview(dsquery.GetRasterBand(i), dstile.GetRasterBand(i), 'average')
            check(res != 0, "概览生成失败 %s,%d" % (tilefilename, res))

    def createOverviewTiles(self):
        """Build every zoom below the base level from its four child tiles."""
        tileJobInfo = self.tileJobInfo
        memDriver = gdal.GetDriverByName('MEM')
        outDriver = gdal.GetDriverByName(TILEDRIVER)
        tilebands = tileJobInfo.nbDataBands
        allBands = list(range(1, tilebands + 1))
        zooms = range(tileJobInfo.tmaxz - 1, tileJobInfo.tminz - 1, -1)
        tcount = 0
        for tz in zooms:
            tminx, tminy, tmaxx, tmaxy = tileJobInfo.tminmax[tz]
            tcount += (1 + abs(tmaxx - tminx)) * (1 + abs(tmaxy - tminy))
        if tcount == 0:
            return
        self.total += tcount
        progressBar = ProgressBar(tcount, '切割下层瓦片')
        progressBar.start()
        for tz in zooms:
            tminx, tminy, tmaxx, tmaxy = tileJobInfo.tminmax[tz]
            # Tile range of the already generated child level (zoom tz + 1).
            cminx, cminy, cmaxx, cmaxy = tileJobInfo.tminmax[tz + 1]
            for ty in range(tmaxy, tminy - 1, -1):
                for tx in range(tminx, tmaxx + 1):
                    tilefilename = os.path.join(self.outputFolder, str(tz), str(tx), "%s.%s" % (ty, TILEEXT))
                    os.makedirs(os.path.dirname(tilefilename), exist_ok=True)
                    dsquery = memDriver.Create('', 2 * TILESIZE, 2 * TILESIZE, tilebands)
                    dstile = memDriver.Create('', TILESIZE, TILESIZE, tilebands)
                    # Children of (tx, ty) are x in {2tx, 2tx+1}, y in {2ty, 2ty+1}.
                    for y in (2 * ty, 2 * ty + 1):
                        for x in (2 * tx, 2 * tx + 1):
                            if not (cminx <= x <= cmaxx and cminy <= y <= cmaxy):
                                continue
                            path = os.path.join(self.outputFolder, str(tz + 1), str(x), "%s.%s" % (y, TILEEXT))
                            dsquerytile = gdal.Open(path, gdal.GA_ReadOnly)
                            # Image rows run north to south while ty grows
                            # northwards, so the northern child (2ty+1) goes on top.
                            tileposx = (x - 2 * tx) * TILESIZE
                            tileposy = (2 * ty + 1 - y) * TILESIZE
                            tileData = dsquerytile.ReadRaster(0, 0, TILESIZE, TILESIZE)
                            dsquery.WriteRaster(tileposx, tileposy, TILESIZE, TILESIZE, tileData, band_list=allBands)
                            del dsquerytile
                    self.scaleQueryToTile(dsquery, dstile, tilefilename=tilefilename)
                    outDriver.CreateCopy(tilefilename, dstile, strict=0)
                    progressBar.updateProgress()


class MultiProcessTiling(SingleProcessTiling):
    """SingleProcessTiling variant that cuts base tiles with a process pool.

    Only the base (highest-zoom) tiles are cut in parallel. Overview tiles
    are still built serially by the inherited createOverviewTiles().
    """

    def __init__(self, inputFile, outputFolder, options):
        super().__init__(inputFile, outputFolder, options)

    def tiling(self):
        """Cut base tiles in ``options.processes`` workers, with a progress process.

        Re-raises the first worker exception after the pool has finished.
        """
        processes = self.options.processes or 1
        tileDetails = self.workerTileDetails()
        tilecount = len(tileDetails)
        self.total += tilecount
        with Manager() as manager:
            queue = manager.Queue()
            pool = Pool(processes=processes)
            # Keep the AsyncResults: apply_async drops worker exceptions unless
            # they are inspected, and a failed tile never reports progress.
            results = [pool.apply_async(self.createBaseTile, (tileDetail, ), {"queue": queue})
                       for tileDetail in tileDetails]
            p = Process(target=self.progressPrinter, args=[queue, tilecount, '切割顶层瓦片'])
            p.start()
            pool.close()
            pool.join()
            failed = [r for r in results if not r.successful()]
            if failed:
                # The printer would wait forever for the missing progress messages.
                p.terminate()
                p.join()
                failed[0].get()  # re-raises the worker's exception here
            p.join()

    def progressPrinter(self, queue, nb_jobs, title):
        """Consume *nb_jobs* progress messages from *queue*, updating a bar."""
        pb = ProgressBar(nb_jobs, title)
        pb.start()
        for _ in range(nb_jobs):
            queue.get()
            pb.updateProgress()
            queue.task_done()


def process_args(argv):
    """Parse the script's options and make sure the output folder exists.

    Returns (inputFile, outputFolder, args). Exits via check() when the input
    is not an existing file. Without -o, the output folder sits next to the
    input and is named after the input file minus its extension.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-z", dest='zoom', metavar='切割级别', required=True, help="如 '12-20'")
    parser.add_argument("-p", dest='processes', metavar='进程数', type=int, default=1, help='默认单进程')
    parser.add_argument("-i", dest='input', metavar='tif文件', required=True)
    parser.add_argument("-o", dest='output', metavar="输出文件夹", help="可选,默认在输入文件夹下的同名文件夹")
    args = parser.parse_args(argv)
    inputFile = args.input
    check(not os.path.isfile(inputFile), "%s不存在或非文件" % inputFile)
    outputFolder = args.output
    if not outputFolder:
        # splitext removes only the final extension: "a.b.tif" -> "a.b".
        # The previous split('.')[0] cut at the first dot and produced "a".
        tifname = os.path.splitext(os.path.basename(inputFile))[0]
        outputFolder = os.path.join(os.path.dirname(os.path.abspath(inputFile)), tifname)
    if not os.path.exists(outputFolder):
        os.makedirs(outputFolder)
    return inputFile, outputFolder, args


def main():
    """Command-line entry point.

    Runs the tiling in a single process when ``-p`` is 1 (the default),
    otherwise with a process pool.
    """
    # GeneralCmdLineProcessor handles GDAL's generic switches (--config, ...).
    # It returns None when GDAL has already dealt with the command line (e.g.
    # --version) or rejected it; indexing None would then raise TypeError.
    argv = gdal.GeneralCmdLineProcessor(sys.argv)
    if argv is None:
        return
    inputFile, outputFolder, options = process_args(argv[1:])

    if options.processes == 1:
        SingleProcessTiling(inputFile, outputFolder, options)
    else:
        MultiProcessTiling(inputFile, outputFolder, options)


# Run only when executed as a script, not when the module is imported
# (for example by multiprocessing's "spawn" start method in worker processes).
if __name__ == '__main__':
    main()

运行流程

g2t的流程没有分支,非常清晰:


从原始数据获取所需的最高级别的瓦片,更低级的瓦片只需从这些最高级瓦片一层一层生成.

这样速度更快:最高级的瓦片只能利用gdal从原始tif中读取,其速度受tif尺寸影响很大;而且直接从tif上读取的瓦片级别越低,单次所取范围越大,速度也越慢.举个实际的例子,从原始tif上直接获取某位置17级瓦片的时间,将远远大于从原始tif获取对应位置的4张18级瓦片并将其合成的时间.

整体架构

1. 生成切割顶级瓦片的任务列表

生成任务列表的功能被完全封装在TileJobsMaker这个类中,观察TileJobsMaker的调用可以发现其实际运行流程如下

# Open the data
tileJobsMaker = TileJobsMaker(self.inputFile, self.outputFolder, self.options)
tileJobsMaker.openData()
# Write the metadata file
tileJobsMaker.makeMetadata()
# Build the job list
conf, tileDetails = tileJobsMaker.makeBaseTiles()

1.1 打开数据

def __init__(self, inputFile, outputFolder, options):
    # By default only RGB data is supported (3 data bands + 1 alpha band).
    self.dataBandsCount = 4
    # The pipeline works through a GDAL VRT file to speed things up (similar to g2m).
    self.vrtFilename = os.path.join(tempfile.mkdtemp(), str(uuid4()) + '.vrt')
    self.inputFile = inputFile
    self.outputFolder = outputFolder
    self.options = options
    # Parse the zoom range: "MIN-MAX", or a single level "N" (MIN == MAX).
    minmax = self.options.zoom.split('-', 1)
    minmax.extend([''])
    zoom_min, zoom_max = minmax[:2]
    self.tminz = int(zoom_min)
    if zoom_max:
        self.tmaxz = int(zoom_max)
    else:
        self.tmaxz = int(zoom_min)

def openData(self):
    # NOTE(review): this excerpt omits the check() calls of the full listing
    # (unopenable input, no bands, no georeferencing, rotated geotransform).
    gdal.AllRegister()
    inputDataset = gdal.Open(self.inputFile, gdal.GA_ReadOnly)
    # Always reproject to Web Mercator: tiles are usually published as web
    # services, and EPSG:3857 is the most convenient choice for that.
    inputSrs = osr.SpatialReference()
    inputSrs.ImportFromWkt(inputDataset.GetProjection())
    outputSrs = osr.SpatialReference()
    outputSrs.ImportFromEPSG(3857)
    # Reproject. From here on the dataset being worked on is actually a VRT.
    self.warpedDataset = gdal.AutoCreateWarpedVRT(inputDataset,inputSrs.ExportToWkt(), outputSrs.ExportToWkt())
    # Force no-data areas to be transparent.
    self.updateNoDataValue()
    # Save the VRT dataset to the chosen .vrt file for later use.
    self.warpedDataset.GetDriver().CreateCopy(self.vrtFilename, self.warpedDataset)
    # Compute the dataset bounds; after reprojection to Web Mercator the unit is metres.
    outGeotrans = self.warpedDataset.GetGeoTransform()
    self.ominx = outGeotrans[0]
    self.omaxx = outGeotrans[0] + self.warpedDataset.RasterXSize * outGeotrans[1]
    self.omaxy = outGeotrans[3]
    # NOTE(review): uses the pixel width outGeotrans[1]; this equals the pixel
    # height only for square pixels. outGeotrans[3] + RasterYSize * outGeotrans[5]
    # would be exact in general.
    self.ominy = outGeotrans[3] - self.warpedDataset.RasterYSize * outGeotrans[1]
    # GlobalMercator bundles the projection helper methods.
    self.mercator = GlobalMercator()
    self.tminmax = list(range(0, 32))
    # Compute the tile column/row range for every zoom level.
    # Only the user-requested zoom range is really needed, but computing all
    # 32 levels costs little.
    for tz in range(0, 32):
        tminx, tminy = self.mercator.MetersToTile(self.ominx, self.ominy, tz)
        tmaxx, tmaxy = self.mercator.MetersToTile(self.omaxx, self.omaxy, tz)
        tminx, tminy = max(0, tminx), max(0, tminy)
        tmaxx, tmaxy = min(2**tz - 1, tmaxx), min(2**tz - 1, tmaxy)
        self.tminmax[tz] = (tminx, tminy, tmaxx, tmaxy)
        
def updateNoDataValue(self):
        """Make the no-data areas transparent."""
        def gdalVrtWarp(options, key, value):
            """Prepend an <Option name="key">value</Option> element to the VRT warp options."""
            tb = ElementTree.TreeBuilder()
            tb.start("Option", {"name": key})
            tb.data(value)
            tb.end("Option")
            elem = tb.close()
            options.insert(0, elem)

        # NOTE(review): tempfile.mktemp is deprecated (the name can be taken by
        # another process before it is used); tempfile.mkstemp is the safe form.
        tempFile = tempfile.mktemp('-TileJobsMaker.vrt')
        self.warpedDataset.GetDriver().CreateCopy(tempFile, self.warpedDataset)
        # Add the no-data settings by editing the VRT XML directly.
        with open(tempFile, 'r', encoding='utf-8') as f:
            vrtString = f.read()
            vrtRoot = ElementTree.fromstring(vrtString)
            options = vrtRoot.find("GDALWarpOptions")
            # Initialise every destination pixel to no-data.
            gdalVrtWarp(options, "INIT_DEST", "NO_DATA")
            # Treat a source pixel as no-data only when all bands match their
            # no-data values, instead of handling each band independently.
            gdalVrtWarp(options, "UNIFIED_SRC_NODATA", "YES")
            vrtString = ElementTree.tostring(vrtRoot).decode()
        with open(tempFile, 'w') as f:
            f.write(vrtString)
        # Load the edited VRT file.
        correctedDataset = gdal.Open(tempFile)
        os.unlink(tempFile)
        # Declare no-data as transparent: 0 on all four RGBA bands.
        correctedDataset.SetMetadataItem('NODATA_VALUES', '0 0 0 0')
        self.warpedDataset = correctedDataset

1.2 生成元数据文件

def makeMetadata(self):
    # Compute the dataset bounds in lat/lon, clamped to the Web Mercator
    # limits, so the front-end map can zoom to them.
    south, west = self.mercator.MetersToLatLon(self.ominx, self.ominy)
    north, east = self.mercator.MetersToLatLon(self.omaxx, self.omaxy)
    south, west = max(-85.05112878, south), max(-180.0, west)
    north, east = min(85.05112878, north), min(180.0, east)
    metadata = {"south": south, "north": north, "west": west, "east": east}
    with open(os.path.join(self.outputFolder, 'metadata.json'), 'w') as f:
        json.dump(metadata, f)

当然可以尽可能的将有用的信息写入元数据中.

1.3 生成任务列表

def makeBaseTiles(self):
        # Tile column/row range at the highest zoom level.
        tminx, tminy, tmaxx, tmaxy = self.tminmax[self.tmaxz]
        tileDetails = []
        tz = self.tmaxz
        # Iterate ty downwards; ty grows northwards, so cutting starts at the top-left.
        for ty in range(tmaxy, tminy - 1, -1):
            for tx in range(tminx, tmaxx + 1):
                # tx, ty, tz are the tile's x (column), y (row) and z (zoom).
                tilefilename = os.path.join(self.outputFolder, str(tz), str(tx), "%s.%s" % (ty, TILEEXT))
                if not os.path.exists(os.path.dirname(tilefilename)):
                  os.makedirs(os.path.dirname(tilefilename))
                # Tile bounds in EPSG:3857 metres, as (minx, miny, maxx, maxy).
                b = self.mercator.TileBounds(tx, ty, tz)
                # Read/write window offsets and sizes for this tile.
                rb, wb = self.geoQuery(b[0], b[3], b[2], b[1])
                rx, ry, rxsize, rysize = rb
                wx, wy, wxsize, wysize = wb
                tileDetails.append(
                    TileDetail(
                        tx=tx,
                        ty=ty,
                        tz=tz,
                        rx=rx,
                        ry=ry,
                        rxsize=rxsize,
                        rysize=rysize,
                        wx=wx,
                        wy=wy,
                        wxsize=wxsize,
                        wysize=wysize,
                    ))
        conf = TileJobInfo(
            srcFile=self.vrtFilename,
            nbDataBands=self.dataBandsCount,
            outputFilePath=self.outputFolder,
            tminmax=self.tminmax,
            tminz=self.tminz,
            tmaxz=self.tmaxz,
        )
        return conf, tileDetails

def geoQuery(self, ulx, uly, lrx, lry):
        ds = self.warpedDataset
        # geotran[0]/[3]: x/y of the raster's top-left corner.
        # geotran[1]/[5]: pixel width/height (the height is typically negative for north-up data).
        geotran = ds.GetGeoTransform()
        # Pixel offset of the tile's top-left corner in the source raster.
        rx = int((ulx - geotran[0]) / geotran[1] + 0.001)
        ry = int((uly - geotran[3]) / geotran[5] + 0.001)
        # Tile width/height measured in source pixels.
        rxsize = int((lrx - ulx) / geotran[1] + 0.5)
        rysize = int((lry - uly) / geotran[5] + 0.5)
        # Window size: 4x the tile size, for better quality when it is later
        # resampled down to the tile.
        wxsize, wysize = 4 * TILESIZE, 4 * TILESIZE
        # Window offset.
        wx = 0
        # Edge cases: clip the read window to the raster extent and shrink the
        # write window by the same proportion.
        if rx < 0:
            rxshift = abs(rx)
            wx = int(wxsize * (float(rxshift) / rxsize))
            wxsize = wxsize - wx
            rxsize = rxsize - int(rxsize * (float(rxshift) / rxsize))
            rx = 0
        if rx + rxsize > ds.RasterXSize:
            wxsize = int(wxsize * (float(ds.RasterXSize - rx) / rxsize))
            rxsize = ds.RasterXSize - rx
        wy = 0
        if ry < 0:
            ryshift = abs(ry)
            wy = int(wysize * (float(ryshift) / rysize))
            wysize = wysize - wy
            rysize = rysize - int(rysize * (float(ryshift) / rysize))
            ry = 0
        if ry + rysize > ds.RasterYSize:
            wysize = int(wysize * (float(ds.RasterYSize - ry) / rysize))
            rysize = ds.RasterYSize - ry
        return (rx, ry, rxsize, rysize), (wx, wy, wxsize, wysize)

这里最重要的部分就是坐标偏移量的计算,生成的几组变量在后面切图的时候都会用到:

  • 控制如何在源图上找到瓦片所需位置的数据
    • rx/ry:该张瓦片的左上角相对于原始tif的像素偏移
    • rxsize/rysize:该张瓦片的宽/高对应于原始tif上的多少像素.由于瓦片是固定尺寸的正方形,而投影后的像元一般也是正方形,所以这两个值在一般情况下相等
  • 控制如何把读取到的数据绘制到'窗口'数据集(4倍瓦片尺寸,之后再重采样为瓦片)上(原文此处附有一张示意图)
    • wx/wy:窗口偏移,大多数情况为0
    • wxsize/wysize:实际绘制时的大小,大多数情况为4倍瓦片尺寸

特殊情况:在所有tif的边缘处,都会出现瓦片冗余,这些地方需要单独处理:

  • x最小的瓦片,rx会小于0,即取到不在tif上的位置,需要将rx置零,缩小rxsize到有数据的范围,保证只在有数据的地方取数据.同时调整wx,wxsize,将数据绘制在对应的地方
  • x最大的瓦片,rx+rxsize会超过tif本身宽度,需要缩小rxsize,同样保证只在有数据的地方取数据.调整wxsize,只在有数据的地方绘制,其他无数据的地方已经在前面设为透明了,无需绘制.

g2t本可以不生成任务列表,边算边切瓦片的.强制生成任务列表的目的只有一个:可以多进程切割.

2. 切割顶级瓦片

def createBaseTile(self, tileDetail, queue=None):
        gdal.AllRegister()
        # Job parameters.
        tileJobInfo = self.tileJobInfo
        output = tileJobInfo.outputFilePath
        tilebands = tileJobInfo.nbDataBands
        # Open the VRT read-only. Every call (and so every worker process)
        # opens its own dataset handle; no handle is shared between processes.
        ds = gdal.Open(tileJobInfo.srcFile, gdal.GA_ReadOnly)
        memDrv = gdal.GetDriverByName('MEM')
        outDrv = gdal.GetDriverByName(TILEDRIVER)
        alphaband = ds.GetRasterBand(1).GetMaskBand()
        tx = tileDetail.tx
        ty = tileDetail.ty
        tz = tileDetail.tz
        rx = tileDetail.rx
        ry = tileDetail.ry
        rxsize = tileDetail.rxsize
        rysize = tileDetail.rysize
        wx = tileDetail.wx
        wy = tileDetail.wy
        wxsize = tileDetail.wxsize
        wysize = tileDetail.wysize
        # The 'window' dataset is 4x the tile size.
        querysize = 4 * TILESIZE
        tilefilename = os.path.join(output, str(tz), str(tx), "%s.%s" % (ty, TILEEXT))
        # In-memory dataset that holds the final tile.
        dstile = memDrv.Create('', TILESIZE, TILESIZE, tilebands)
        data = alpha = None
        if rxsize != 0 and rysize != 0 and wxsize != 0 and wysize != 0:
            # Read this tile's data bands (1..tilebands-1) using the offsets computed earlier.
            data = ds.ReadRaster(rx, ry, rxsize, rysize, wxsize, wysize, band_list=list(range(1, tilebands)))
            # The alpha comes from band 1's mask band.
            alpha = alphaband.ReadRaster(rx, ry, rxsize, rysize, wxsize, wysize)
            if data:
                # The so-called 'window' dataset.
                dsquery = memDrv.Create('', querysize, querysize, tilebands)
                # Write the data and the alpha into the window.
                dsquery.WriteRaster(wx, wy, wxsize, wysize, data, band_list=list(range(1, tilebands)))
                dsquery.WriteRaster(wx, wy, wxsize, wysize, alpha, band_list=[tilebands])
                # Resample into the tile dataset.
                self.scaleQueryToTile(dsquery, dstile, tilefilename)
                del dsquery
        del ds
        del data
        # Write the tile dataset out as a PNG at its target path.
        outDrv.CreateCopy(tilefilename, dstile, strict=0)
        del dstile
        # In multi-process mode, report progress to the main process.
        if queue:
            queue.put("tile %s %s %s" % (tx, ty, tz))

def scaleQueryToTile(self, dsquery, dstile, tilefilename=''):
    """Resample the 'window' dataset into the tile dataset with 'average'.

    Exits via check() when GDAL returns a non-zero error code for a band.
    """
    tilebands = dstile.RasterCount
    for i in range(1, tilebands + 1):
        res = gdal.RegenerateOverview(dsquery.GetRasterBand(i), dstile.GetRasterBand(i), 'average')
        # The error code used to be computed and then ignored.
        check(res != 0, "概览生成失败 %s,%d" % (tilefilename, res))

单进程调用:

# Build the job list
tileDetails = self.workerTileDetails()
# Run every job sequentially
for tileDetail in tileDetails:
    self.createBaseTile(tileDetail)

多进程调用:

processes = self.options.processes or 1
tileDetails = self.workerTileDetails()
manager = Manager()
queue = manager.Queue()
# At most `processes` worker processes run at the same time
pool = Pool(processes=processes)
for tileDetail in tileDetails:
    # NOTE(review): the returned AsyncResult is discarded, so an exception
    # raised inside a worker is never reported to the main process.
    pool.apply_async(self.createBaseTile, (tileDetail, ), {"queue": queue})
pool.close()
pool.join()

切瓦片的时候充分利用多核资源无疑是非常合算的.在阅读源码之前,我构思的多进程方案是固定若干个处理进程,然后给每个进程单独分配任务.g2t借助multiprocessing.Pool来实现,代码更加简单清晰.需要说明的是,Pool内部正是维护固定数量的工作进程并不断向它们分发任务,并不会为每个任务创建/销毁一次进程(除非设置了maxtasksperchild),所以进程创建的开销基本可以忽略.

3. 生成下层瓦片

 def createOverviewTiles(self):
        # NOTE(review): as pasted, the `def` line carries a stray leading space,
        # so this excerpt is not valid Python on its own; the full listing above
        # has the correct indentation.
        tileJobInfo = self.tileJobInfo
        memDriver = gdal.GetDriverByName('MEM')
        outDriver = gdal.GetDriverByName(TILEDRIVER)
        tilebands = tileJobInfo.nbDataBands
        for tz in range(tileJobInfo.tmaxz - 1, tileJobInfo.tminz - 1, -1):
            tminx, tminy, tmaxx, tmaxy = tileJobInfo.tminmax[tz]
            for ty in range(tmaxy, tminy - 1, -1):
                for tx in range(tminx, tmaxx + 1):
                    # Iterate over every tile of zoom tz.
                    tilefilename = os.path.join(self.outputFolder, str(tz), str(tx), "%s.%s" % (ty, TILEEXT))
                    if not os.path.exists(os.path.dirname(tilefilename)):
                        os.makedirs(os.path.dirname(tilefilename))
                    dsquery = memDriver.Create('', 2 * TILESIZE, 2 * TILESIZE, tilebands)
                    dstile = memDriver.Create('', TILESIZE, TILESIZE, tilebands)
                    # Tile indices at each zoom are half those of the next zoom up.
                    # Each tile covers the same area as its 4 children at zoom
                    # tz + 1, so it can be mosaicked from those 4 tiles.
                    for y in range(2 * ty, 2 * ty + 2):
                        for x in range(2 * tx, 2 * tx + 2):
                            minx, miny, maxx, maxy = tileJobInfo.tminmax[tz + 1]
                            # Only use children inside the generated range of zoom tz + 1.
                            if x >= minx and x <= maxx and y >= miny and y <= maxy:
                                path = os.path.join(self.outputFolder, str(tz + 1), str(x), "%s.%s" % (y, TILEEXT))
                                # Open this child tile.
                                dsquerytile = gdal.Open(path, gdal.GA_ReadOnly)
                                # Place the child in its quadrant: the result is
                                # tileposx == (x - 2*tx) * TILESIZE, and the northern
                                # child (y == 2*ty + 1) goes to the top half (tileposy 0).
                                if (ty == 0 and y == 1) or (ty != 0 and(y % (2 * ty)) != 0):
                                    tileposy = 0
                                else:
                                    tileposy = TILESIZE
                                if tx:
                                    tileposx = x % (2 * tx) * TILESIZE
                                elif tx == 0 and x == 1:
                                    tileposx = TILESIZE
                                else:
                                    tileposx = 0
                                # Copy the child's pixels into the 'window'.
                                tempRaseter = dsquerytile.ReadRaster(0, 0, TILESIZE, TILESIZE)
                                dsquery.WriteRaster(tileposx, tileposy, TILESIZE, TILESIZE, tempRaseter, band_list=list(range(1, tilebands + 1)))
                    # Resample down to the tile and write it to its file.
                    self.scaleQueryToTile(dsquery, dstile, tilefilename=tilefilename)
                    outDriver.CreateCopy(tilefilename, dstile, strict=0)

切下层瓦片时并没有采用多进程模式,我猜可能是已经够快了,事实也确实如此.如果下层瓦片也采用多进程模式,整体运行或许还会更快?这应该在瓦片数量比较多的情况下才有所体现吧.

总结

只用GDAL使得g2t部署简单,但也限制了其性能.如果把读取tif和生成png都替换成类似g2m的vips,或许能获得和g2m一样的效率.

修改后的gdal2tiles

  • 默认改为谷歌瓦片模式
  • 支持火星偏移
  • 支持瓦片压缩和删除无内容瓦片
# -*- coding=utf-8 -*-
import platform
from xml.etree import ElementTree
import json
from osgeo import gdal, osr
from uuid import uuid4
import sys
import shutil
import tempfile
import os
import math
from multiprocessing import Pool, Process, Manager, cpu_count
# Tile edge length in pixels; only 512 or 256 are supported
TILESIZE = 256
# Output tile format: GDAL driver name and matching file extension
TILEDRIVER = 'PNG'
TILEEXT = 'png'
# File size in bytes of an empty tile, keyed by str(TILESIZE). During
# compression a tile whose size equals this value is deleted instead of
# quantized.
# NOTE(review): these byte counts depend on the PNG encoder's exact output —
# confirm they still match the GDAL version in use.
EMPTY = {'512': 1096, "256": 334}
# Path of the pngquant executable handed to pngquant.config(). Only set on
# Windows (expects pngquant.exe in the current working directory); None
# elsewhere — presumably the pngquant module then finds the binary itself, TODO confirm.
QUANTFILE = None
if(platform.system() == 'Windows'):
    QUANTFILE = os.path.join(os.getcwd(), "pngquant.exe")
# Compression switch: when True, tiles are quantized with the third-party
# `pngquant` module after tiling (empty tiles are removed); the module is only
# imported when needed.
COMPRESS = False
if COMPRESS:
    import pngquant


class GlobalMercator(object):
    """Web Mercator (EPSG:3857) tile-grid maths for TILESIZE-pixel tiles.

    Converts between WGS84 lat/lon, EPSG:3857 metres, pyramid pixels and
    tile indices. Pixel and tile y values grow northwards (TMS order); see
    ``GoogleTile`` for the flip to Google/XYZ row numbering.
    """

    def __init__(self):
        # Metres per pixel at zoom 0, where a single tile spans the whole
        # equator (Earth radius 6378137 m)
        self.initialResolution = 2 * math.pi * 6378137 / TILESIZE
        # Half the equator length: projected x/y run from -originShift to
        # +originShift
        self.originShift = 2 * math.pi * 6378137 / 2.0

    def MetersToLatLon(self, mx, my):
        """Convert EPSG:3857 metres to WGS84 ``(lat, lon)`` degrees.

        Both values are rounded to 6 decimal places.
        """
        lon = (mx / self.originShift) * 180.0
        lat = (my / self.originShift) * 180.0
        lat = 180 / math.pi * \
            (2 * math.atan(math.exp(lat * math.pi / 180.0)) - math.pi / 2.0)
        # Drop precision that carries no useful information
        lat = round(lat, 6)
        lon = round(lon, 6)
        return lat, lon

    def PixelsToMeters(self, px, py, zoom):
        """Convert pyramid pixel coordinates at ``zoom`` to EPSG:3857 metres."""
        res = self.Resolution(zoom)
        mx = px * res - self.originShift
        my = py * res - self.originShift
        return mx, my

    def MetersToPixels(self, mx, my, zoom):
        """Convert EPSG:3857 metres to (fractional) pyramid pixels at ``zoom``."""
        res = self.Resolution(zoom)
        px = (mx + self.originShift) / res
        py = (my + self.originShift) / res
        return px, py

    def PixelsToTile(self, px, py):
        """Return the ``(tx, ty)`` tile index for pixel ``(px, py)``.

        Uses ``ceil(p / TILESIZE) - 1``, so a pixel exactly on a tile edge is
        assigned to the lower-index tile.
        """
        tx = int(math.ceil(px / float(TILESIZE)) - 1)
        ty = int(math.ceil(py / float(TILESIZE)) - 1)
        return tx, ty

    def MetersToTile(self, mx, my, zoom):
        """Return the ``(tx, ty)`` tile index at ``zoom`` for EPSG:3857 metres."""
        px, py = self.MetersToPixels(mx, my, zoom)
        return self.PixelsToTile(px, py)

    def LatLonToMeters(self, lat, lon):
        """Convert WGS84 ``(lat, lon)`` degrees to EPSG:3857 ``(mx, my)`` metres."""
        mx = lon * self.originShift / 180.0
        my = math.log(math.tan((90 + lat) * math.pi / 360.0)) / \
            (math.pi / 180.0)
        my = my * self.originShift / 180.0
        return mx, my

    def TileBounds(self, tx, ty, zoom):
        """Return ``(minx, miny, maxx, maxy)`` in metres of tile (tx, ty).

        No y flip is applied here: ``ty`` is the TMS row counted from the
        south.
        """
        minx, miny = self.PixelsToMeters(tx * TILESIZE, ty * TILESIZE, zoom)
        maxx, maxy = self.PixelsToMeters(
            (tx + 1) * TILESIZE, (ty + 1) * TILESIZE, zoom)
        return (minx, miny, maxx, maxy)

    def Resolution(self, zoom):
        """Metres per pixel at ``zoom``."""
        return self.initialResolution / (2 ** zoom)

    def WGS84ToGCJ02(self, lat, lng):
        """Shift a WGS84 ``(lat, lng)`` to the GCJ-02 ("Mars") datum.

        Algorithm from https://github.com/wandergis/coordTransform_py. The
        offset is applied unconditionally (there is no check that the point
        lies in China). Results are rounded to 6 decimal places.
        """
        pi = 3.1415926535897932384626
        ee = 0.00669342162296594323
        a = 6378245.0

        def _transformlat(lng, lat):
            ret = -100.0 + 2.0 * lng + 3.0 * lat + 0.2 * lat * lat + \
                0.1 * lng * lat + 0.2 * math.sqrt(math.fabs(lng))
            ret += (20.0 * math.sin(6.0 * lng * pi) + 20.0 *
                    math.sin(2.0 * lng * pi)) * 2.0 / 3.0
            ret += (20.0 * math.sin(lat * pi) + 40.0 *
                    math.sin(lat / 3.0 * pi)) * 2.0 / 3.0
            ret += (160.0 * math.sin(lat / 12.0 * pi) + 320 *
                    math.sin(lat * pi / 30.0)) * 2.0 / 3.0
            return ret

        def _transformlng(lng, lat):
            ret = 300.0 + lng + 2.0 * lat + 0.1 * lng * lng + \
                0.1 * lng * lat + 0.1 * math.sqrt(math.fabs(lng))
            ret += (20.0 * math.sin(6.0 * lng * pi) + 20.0 *
                    math.sin(2.0 * lng * pi)) * 2.0 / 3.0
            ret += (20.0 * math.sin(lng * pi) + 40.0 *
                    math.sin(lng / 3.0 * pi)) * 2.0 / 3.0
            ret += (150.0 * math.sin(lng / 12.0 * pi) + 300.0 *
                    math.sin(lng / 30.0 * pi)) * 2.0 / 3.0
            return ret
        dlat = _transformlat(lng - 105.0, lat - 35.0)
        dlng = _transformlng(lng - 105.0, lat - 35.0)
        radlat = lat / 180.0 * pi
        magic = math.sin(radlat)
        magic = 1 - ee * magic * magic
        sqrtmagic = math.sqrt(magic)
        dlat = (dlat * 180.0) / ((a * (1 - ee)) / (magic * sqrtmagic) * pi)
        dlng = (dlng * 180.0) / (a / sqrtmagic * math.cos(radlat) * pi)
        mglat = lat + dlat
        mglng = lng + dlng
        mglat = round(mglat, 6)
        mglng = round(mglng, 6)
        return mglat, mglng

    def ZoomForPixelSize(self, pixelSize):
        """Return the deepest zoom (0..31) whose resolution is >= ``pixelSize``.

        Used to pick zoom levels that do not oversample the source. Input
        coarser than zoom 0 yields 0; input finer than zoom 31 yields 31.
        """
        for i in range(32):
            if pixelSize > self.Resolution(i):
                # Zoom i is already finer than the source: step back one
                # level, but never below zoom 0
                return max(0, i - 1)
        # Finer than every supported zoom: use the deepest one
        return 31


def check(status, message):
    """Abort the program with exit code 3 if ``status`` is truthy.

    ``message`` is written to stderr before exiting; a falsy ``status`` is
    a no-op.
    """
    if not status:
        return
    sys.stderr.write("运行出错: %s\n" % message)
    sys.exit(3)

def GoogleTile(zoom, ty):
    """Flip a TMS tile row ``ty`` into Google/XYZ row numbering at ``zoom``.

    With 512-pixel tiles the grid at ``zoom`` has 2**(zoom - 1) rows
    instead of 2**zoom, so the flip uses that row count.
    """
    rows = 2 ** (zoom - 1) if TILESIZE == 512 else 2 ** zoom
    return rows - 1 - ty

def gettempfilename(suffix):
    """Return a fresh temporary file path ending in ``suffix``.

    When the ``_`` environment variable mentions "wine", the name is built
    by hand inside $TMP (or the current directory); otherwise
    tempfile.mktemp is used. The file itself is not created.
    """
    launcher = os.environ.get('_')
    if launcher is not None and 'wine' in launcher:
        import time
        import random
        tmpdir = os.environ.get('TMP', '.')
        random.seed(time.time())
        return os.path.join(
            tmpdir, 'file%d' % random.randint(0, 1000000000) + suffix)
    return tempfile.mktemp(suffix)


def add_alpha_band_to_string_vrt(vrt_string):
    """Return ``vrt_string`` (warped-VRT XML) with an extra alpha band.

    A ``VRTWarpedRasterBand`` whose ColorInterp is ``Alpha`` is inserted
    directly after the run of existing ``VRTRasterBand`` elements. The
    ``GDALWarpOptions`` element gets a ``DstAlphaBand`` pointing at it and an
    ``INIT_DEST`` option of 0.

    Raises:
        Exception: if one of the existing bands is already an alpha band.
    """
    root = ElementTree.fromstring(vrt_string)
    band_count = 0
    insert_at = 0
    for child in list(root):
        if child.tag != "VRTRasterBand":
            # First non-band element after the bands ends the scan
            if band_count:
                break
        else:
            band_count += 1
            interp = child.find("./ColorInterp")
            if interp is not None and interp.text == "Alpha":
                raise Exception("Alpha band already present")
        insert_at += 1

    alpha_index = str(band_count + 1)

    alpha_band = ElementTree.Element(
        "VRTRasterBand",
        {'dataType': "Byte", "band": alpha_index,
         "subClass": "VRTWarpedRasterBand"})
    color_interp = ElementTree.SubElement(alpha_band, "ColorInterp")
    color_interp.text = "Alpha"
    root.insert(insert_at, alpha_band)

    dst_alpha = ElementTree.Element("DstAlphaBand")
    dst_alpha.text = alpha_index
    init_dest = ElementTree.Element("Option", {"name": "INIT_DEST"})
    init_dest.text = "0"

    warp_options = root.find(".//GDALWarpOptions")
    warp_options.append(dst_alpha)
    warp_options.append(init_dest)

    return ElementTree.tostring(root).decode()


class TileDetail(object):
    """Job description for one base tile, built by TileJobsMaker.makeBaseTiles.

    Construct with keyword arguments; keywords that are not fields below
    are ignored, and unset fields keep their class default of 0.
    """
    # Tile address: column, TMS row, zoom
    tx = 0
    ty = 0
    tz = 0
    # Source window to read from the warped raster (pixel offset and size)
    rx = 0
    ry = 0
    rxsize = 0
    rysize = 0
    # Placement of the read data inside the query window (offset and size)
    wx = 0
    wy = 0
    wxsize = 0
    wysize = 0

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            if not hasattr(self, name):
                continue
            setattr(self, name, value)


class TileJobInfo(object):
    """Settings shared by every tile job of one run.

    Construct with keyword arguments; keywords that are not fields below
    are ignored.
    """
    # Path of the warped VRT that tiles are read from
    srcFile = ""
    # Band count of every output tile (data bands + alpha)
    nbDataBands = 0
    # Root folder of the z/x/y tile tree
    outputFilePath = ""
    # tminmax[tz] = (tminx, tminy, tmaxx, tmaxy) tile range per zoom
    tminmax = []
    tminz = 0
    tmaxz = 0
    # Not populated by TileJobsMaker.makeBaseTiles
    outGeoTrans = []

    def __init__(self, **kwargs):
        # Give every instance its own lists; otherwise mutating one job's
        # tminmax/outGeoTrans would alter the shared class-level defaults
        # and leak into every other instance.
        self.tminmax = []
        self.outGeoTrans = []
        for key in kwargs:
            if hasattr(self, key):
                setattr(self, key, kwargs[key])


class TileJobsMaker(object):
    """Prepare the input raster and plan the base-zoom tile jobs.

    Call order: ``openData()`` (warp to EPSG:3857, choose zooms, compute tile
    ranges), optionally ``makeMetadata()``, then ``makeBaseTiles()``.
    """

    def __init__(self, inputFile, outputFolder, options):
        """
        Args:
            inputFile: path of the raster to tile.
            outputFolder: root folder for the z/x/y tiles and metadata.json.
            options: parsed command-line options; ``zoom``, ``s_srs`` and
                ``gcj02`` are read.
        """
        # Output tiles always have 4 bands: 3 data bands + alpha (only
        # RGB-style input is supported)
        self.dataBandsCount = 4
        # The warped dataset is saved as a VRT in its own temporary directory
        # (similar to gdal2mbtiles) and reopened by every tile job
        self.vrtFilename = os.path.join(
            tempfile.mkdtemp(), str(uuid4()) + '.vrt')
        self.inputFile = inputFile
        self.outputFolder = outputFolder
        self.options = options
        # -1 means "not given": openData() then derives both bounds from the
        # data's resolution
        self.tminz = -1
        self.tmaxz = -1
        if self.options.zoom:
            # "12-20" -> (12, 20); "12" -> (12, 12)
            minmax = self.options.zoom.split('-', 1)
            minmax.extend([''])
            zoom_min, zoom_max = minmax[:2]
            self.tminz = int(zoom_min)
            if zoom_max:
                self.tmaxz = int(zoom_max)
            else:
                self.tmaxz = int(zoom_min)

    def updateNoDataValue(self):
        """Make the input's NoData pixels transparent.

        Round-trips ``warpedDataset`` through a temporary VRT file to add the
        warp options INIT_DEST=NO_DATA and UNIFIED_SRC_NODATA=YES, then
        reopens it with NODATA_VALUES set to '0 0 0 0'.
        """
        def gdalVrtWarp(options, key, value):
            """Prepend ``<Option name=key>value</Option>`` to ``options``."""
            tb = ElementTree.TreeBuilder()
            tb.start("Option", {"name": key})
            tb.data(value)
            tb.end("Option")
            elem = tb.close()
            options.insert(0, elem)

        tempFile = tempfile.mktemp('-TileJobsMaker.vrt')
        self.warpedDataset.GetDriver().CreateCopy(tempFile, self.warpedDataset)
        with open(tempFile, 'r', encoding='utf-8') as f:
            vrtString = f.read()
        vrtRoot = ElementTree.fromstring(vrtString)
        options = vrtRoot.find("GDALWarpOptions")
        # Initialise every destination pixel to NoData
        gdalVrtWarp(options, "INIT_DEST", "NO_DATA")
        # Treat a pixel as NoData only when all bands match, rather than
        # handling each band independently
        gdalVrtWarp(options, "UNIFIED_SRC_NODATA", "YES")
        vrtString = ElementTree.tostring(vrtRoot).decode()
        with open(tempFile, 'w', encoding='utf-8') as f:
            f.write(vrtString)
        # Reload the edited VRT
        correctedDataset = gdal.Open(tempFile)
        os.unlink(tempFile)
        # NoData becomes transparent: all four RGBA values are 0
        correctedDataset.SetMetadataItem('NODATA_VALUES', '0 0 0 0')
        self.warpedDataset = correctedDataset

    def updateAlphaForNonAlphaData(self):
        """Add an alpha band to 1- or 3-band warped data; others are untouched."""
        warpedDataset = self.warpedDataset
        if warpedDataset.RasterCount in [1, 3]:
            tempfilename = gettempfilename('-gdal2tiles.vrt')
            warpedDataset.GetDriver().CreateCopy(tempfilename, warpedDataset)
            # Read/write the VRT XML as UTF-8, the same encoding that
            # updateNoDataValue uses. The platform default (e.g. GBK on Chinese
            # Windows) would corrupt non-ASCII source paths inside the VRT.
            with open(tempfilename, encoding='utf-8') as f:
                orig_data = f.read()
            alpha_data = add_alpha_band_to_string_vrt(orig_data)
            with open(tempfilename, 'w', encoding='utf-8') as f:
                f.write(alpha_data)
            warpedDataset = gdal.Open(tempfilename)
            os.unlink(tempfilename)
        self.warpedDataset = warpedDataset

    def openData(self):
        """Open the input, warp it to EPSG:3857 and derive zooms and ranges.

        Sets ``mercator`` and ``warpedDataset`` (also saved to
        ``vrtFilename``). Sets the data extent ``ominx/ominy/omaxx/omaxy`` in
        metres. Sets ``tminz/tmaxz`` when no zoom was given, and
        ``tminmax[tz] = (tminx, tminy, tmaxx, tmaxy)`` for every zoom 0..31.
        Exits through ``check`` on unusable input.
        """
        gdal.AllRegister()
        self.mercator = GlobalMercator()
        inputDataset = gdal.Open(self.inputFile, gdal.GA_ReadOnly)
        check(not inputDataset, "数据无法打开")
        check(inputDataset.RasterCount == 0, "数据无波段")
        GetGeoTransform = inputDataset.GetGeoTransform()
        gcpCount = inputDataset.GetGCPCount()
        # An identity geotransform with no GCPs means "not georeferenced"
        check(GetGeoTransform == (0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
              and gcpCount == 0, "数据缺少空间信息")
        # Always reproject to Web Mercator (EPSG:3857): tiles are meant to be
        # published as web map services, where 3857 is the most convenient
        inputSrs = osr.SpatialReference()
        if self.options.s_srs:
            inputSrs.SetFromUserInput(self.options.s_srs)
        else:
            inputSrs.ImportFromWkt(inputDataset.GetProjection())
        outputSrs = osr.SpatialReference()
        outputSrs.ImportFromEPSG(3857)
        # From here on we work with a warped VRT, not the raw input
        self.warpedDataset = gdal.AutoCreateWarpedVRT(inputDataset,
                                                      inputSrs.ExportToWkt(),
                                                      outputSrs.ExportToWkt())

        # NoData areas become transparent; data without NoData gets an alpha
        # band instead (for 1- or 3-band input)
        inNodata = []
        for i in range(1, inputDataset.RasterCount+1):
            rasterNoData = inputDataset.GetRasterBand(i).GetNoDataValue()
            if rasterNoData is not None:
                inNodata.append(rasterNoData)
        if inNodata:
            self.updateNoDataValue()
        else:
            self.updateAlphaForNonAlphaData()
        # GCJ-02 ("Mars") offset: only the raster origin is shifted, by the
        # offset computed at its top-left corner
        if self.options.gcj02:
            geotrans = self.warpedDataset.GetGeoTransform()
            lat, lng = self.mercator.MetersToLatLon(geotrans[0], geotrans[3])
            lat, lng = self.mercator.WGS84ToGCJ02(lat, lng)
            x, y = self.mercator.LatLonToMeters(lat, lng)
            warpedGeotrans = [x, geotrans[1], 0, y, 0, geotrans[5]]
            self.warpedDataset.SetGeoTransform(warpedGeotrans)
        # Persist the VRT so each tile job can reopen it
        self.warpedDataset.GetDriver().CreateCopy(self.vrtFilename,
                                                  self.warpedDataset)
        outGeotrans = self.warpedDataset.GetGeoTransform()
        check((outGeotrans[2], outGeotrans[4]) != (0, 0), "不支持变形后的数据")
        # Automatic zoom range
        if self.tminz == -1:
            # Shallowest zoom: the whole raster fits roughly in one tile
            self.tminz = self.mercator.ZoomForPixelSize(
                outGeotrans[1] *
                max(self.warpedDataset.RasterXSize,
                    self.warpedDataset.RasterYSize) /
                float(TILESIZE))
            # Deepest zoom: closest to the native resolution, but at least 6
            tmaxz = self.mercator.ZoomForPixelSize(outGeotrans[1])
            if tmaxz < 6:
                tmaxz = 6
            self.tmaxz = tmaxz
        # Data extent in EPSG:3857 metres
        self.ominx = outGeotrans[0]
        self.omaxx = outGeotrans[0] + \
            self.warpedDataset.RasterXSize * outGeotrans[1]
        self.omaxy = outGeotrans[3]
        self.ominy = outGeotrans[3] - \
            self.warpedDataset.RasterYSize * outGeotrans[1]
        self.tminmax = list(range(0, 32))
        # Tile index range per zoom. All 32 zooms are computed; restricting
        # to tminz..tmaxz would save little.
        for tz in range(0, 32):
            # Map to the 256-px pyramid level (512-px tiles sit one level up)
            _tz = int(tz - (TILESIZE / 256 - 1))
            tminx, tminy = self.mercator.MetersToTile(
                self.ominx, self.ominy, _tz)
            tmaxx, tmaxy = self.mercator.MetersToTile(
                self.omaxx, self.omaxy, _tz)
            tminx, tminy = max(0, tminx), max(0, tminy)
            tmaxx, tmaxy = min(2**_tz - 1, tmaxx), min(2**_tz - 1, tmaxy)
            self.tminmax[tz] = (tminx, tminy, tmaxx, tmaxy)

    def makeMetadata(self):
        """Write ``metadata.json`` (lat/lng bbox, tile size, zoom range, gcj02).

        The bbox is clamped to the Web Mercator latitude limit and
        [-180, 180] longitude, for front-end maps to zoom to.
        """
        minlat, minlng = self.mercator.MetersToLatLon(self.ominx, self.ominy)
        maxlat, maxlng = self.mercator.MetersToLatLon(self.omaxx, self.omaxy)
        minlat, minlng = max(-85.05112878, minlat), max(-180.0, minlng)
        maxlat, maxlng = min(85.05112878, maxlat), min(180.0, maxlng)
        bbox = [[minlat, minlng], [maxlat, maxlng]]
        metadata = {"bbox": bbox, "tilesize": TILESIZE, 'minZoom': self.tminz,
                    'maxZoom': self.tmaxz, "gcj02": self.options.gcj02}
        with open(os.path.join(self.outputFolder, 'metadata.json'), 'w') as f:
            json.dump(metadata, f, indent=4)

    def makeBaseTiles(self):
        """Plan one TileDetail per tile of the deepest zoom (tmaxz).

        Also creates each tile's ``outputFolder/z/x`` directory.

        Returns:
            (TileJobInfo, list of TileDetail): shared job settings and the
            per-tile read/write windows, rows ordered north to south.
        """
        tminx, tminy, tmaxx, tmaxy = self.tminmax[self.tmaxz]
        tileDetails = []
        tz = self.tmaxz
        # 256-px pyramid level for this zoom (512-px tiles sit one level up)
        _tz = int(tz - (TILESIZE / 256 - 1))
        # Highest TMS row first, i.e. start tiling from the top-left corner
        for ty in range(tmaxy, tminy - 1, -1):
            for tx in range(tminx, tmaxx + 1):
                os.makedirs(os.path.join(self.outputFolder, str(tz), str(tx)),
                            exist_ok=True)
                # Tile extent in metres: (minx, miny, maxx, maxy)
                b = self.mercator.TileBounds(tx, ty, _tz)
                # Query with upper-left (minx, maxy) and lower-right (maxx, miny)
                rb, wb = self.geoQuery(b[0], b[3], b[2], b[1])
                rx, ry, rxsize, rysize = rb
                wx, wy, wxsize, wysize = wb
                tileDetails.append(
                    TileDetail(
                        tx=tx,
                        ty=ty,
                        tz=tz,
                        rx=rx,
                        ry=ry,
                        rxsize=rxsize,
                        rysize=rysize,
                        wx=wx,
                        wy=wy,
                        wxsize=wxsize,
                        wysize=wysize,
                    ))
        conf = TileJobInfo(
            srcFile=self.vrtFilename,
            nbDataBands=self.dataBandsCount,
            outputFilePath=self.outputFolder,
            tminmax=self.tminmax,
            tminz=self.tminz,
            tmaxz=self.tmaxz,
        )
        return conf, tileDetails

    def geoQuery(self, ulx, uly, lrx, lry):
        """Map a tile's corners (EPSG:3857 metres) onto the warped raster.

        Returns:
            ((rx, ry, rxsize, rysize), (wx, wy, wxsize, wysize)): the source
            pixel window to read, clipped to the raster, and where that data
            lands inside the 4*TILESIZE query window.
        """
        ds = self.warpedDataset
        geotran = ds.GetGeoTransform()
        # geotran[0]/[3]: x/y of the raster's top-left corner
        # geotran[1]/[5]: pixel width/height
        # Pixel offset of the tile's top-left corner in the source
        rx = int((ulx - geotran[0]) / geotran[1] + 0.001)
        ry = int((uly - geotran[3]) / geotran[5] + 0.001)
        # Tile size in source pixels. Clamped to >= 1 (as upstream gdal2tiles
        # does) so a tile smaller than half a source pixel cannot produce 0
        # and a ZeroDivisionError in the clipping below.
        rxsize = max(1, int((lrx - ulx) / geotran[1] + 0.5))
        rysize = max(1, int((lry - uly) / geotran[5] + 0.5))
        # Query window is 4x the tile size to improve the final downsampling
        wxsize, wysize = 4 * TILESIZE, 4 * TILESIZE
        wx = 0
        if rx < 0:
            # Tile starts left of the raster: shrink both windows in proportion
            rxshift = abs(rx)
            wx = int(wxsize * (float(rxshift) / rxsize))
            wxsize = wxsize - wx
            rxsize = rxsize - int(rxsize * (float(rxshift) / rxsize))
            rx = 0
        if rx + rxsize > ds.RasterXSize:
            wxsize = int(wxsize * (float(ds.RasterXSize - rx) / rxsize))
            rxsize = ds.RasterXSize - rx
        wy = 0
        if ry < 0:
            ryshift = abs(ry)
            wy = int(wysize * (float(ryshift) / rysize))
            wysize = wysize - wy
            rysize = rysize - int(rysize * (float(ryshift) / rysize))
            ry = 0
        if ry + rysize > ds.RasterYSize:
            wysize = int(wysize * (float(ds.RasterYSize - ry) / rysize))
            rysize = ds.RasterYSize - ry
        return (rx, ry, rxsize, rysize), (wx, wy, wxsize, wysize)


class ProgressBar(object):
    """GDAL-style console progress: "0...10...20...100".

    A dot is printed every STEP percent, the number at each multiple of 10,
    and a newline once 100 is reached.
    """

    def __init__(self, total_items, title):
        sys.stdout.write("%s 共%d张 \n" % (title, total_items))
        self.total_items = total_items
        self.nb_items_done = 0
        self.current_progress = 0
        self.STEP = 2.5

    def start(self):
        """Print the leading "0"."""
        sys.stdout.write("0")

    def updateProgress(self, nb_items=1):
        """Count ``nb_items`` more finished items and print any new ticks."""
        self.nb_items_done += nb_items
        percent = float(self.nb_items_done) / self.total_items * 100
        while self.current_progress + self.STEP <= percent:
            self.current_progress += self.STEP
            if self.current_progress % 10 != 0:
                sys.stdout.write(".")
                continue
            sys.stdout.write(str(int(self.current_progress)))
            if self.current_progress == 100:
                sys.stdout.write("\n")
        sys.stdout.flush()


class SingleProcessTiling(object):
    """Tile a raster into ``outputFolder/z/x/y.png`` in the current process.

    All work happens in the constructor:
    - render the base tiles at tmaxz from the source;
    - build each shallower zoom down to tminz from its four children;
    - remove the temporary VRT directory;
    - when COMPRESS is on, quantize the PNGs.
    """

    def __init__(self, inputFile, outputFolder, options):
        self.inputFile = inputFile
        self.outputFolder = outputFolder
        self.options = options
        # Running count of tiles produced (base + overview tiles)
        self.total = 0
        self.tiling()
        self.createOverviewTiles()
        # The VRT sits alone in the mkdtemp() directory created by
        # TileJobsMaker, so the whole directory is removed
        shutil.rmtree(os.path.dirname(self.tileJobInfo.srcFile))
        if COMPRESS:
            self.compressPng()

    def tiling(self):
        """Render every base-zoom tile sequentially, reporting progress."""
        tileDetails = self.workerTileDetails()
        tilecount = len(tileDetails)
        self.total += tilecount
        progressBar = ProgressBar(tilecount, '切割顶层瓦片')
        progressBar.start()
        for tileDetail in tileDetails:
            self.createBaseTile(tileDetail)
            progressBar.updateProgress()

    def createBaseTile(self, tileDetail, queue=None):
        """Render one base-zoom tile from the warped VRT and save it.

        The source window is read into a 4*TILESIZE query dataset, with the
        buffer size set to wxsize x wysize. It is then averaged down to
        TILESIZE. Tiles without any source overlap are still written, from
        the untouched MEM dataset.

        Args:
            tileDetail: TileDetail produced by TileJobsMaker.makeBaseTiles.
            queue: optional queue; one message is put per finished tile so a
                parent process can report progress.
        """
        gdal.AllRegister()
        tileJobInfo = self.tileJobInfo
        output = tileJobInfo.outputFilePath
        tilebands = tileJobInfo.nbDataBands
        # Bands 1..tilebands-1 carry data; the last band receives the alpha
        dataBands = list(range(1, tilebands))
        # Each call opens its own read-only handle on the VRT, so worker
        # processes never share a GDAL dataset object
        ds = gdal.Open(tileJobInfo.srcFile, gdal.GA_ReadOnly)
        memDrv = gdal.GetDriverByName('MEM')
        outDrv = gdal.GetDriverByName(TILEDRIVER)
        # The mask of band 1 supplies the alpha channel
        alphaband = ds.GetRasterBand(1).GetMaskBand()
        tx, ty, tz = tileDetail.tx, tileDetail.ty, tileDetail.tz
        rx, ry = tileDetail.rx, tileDetail.ry
        rxsize, rysize = tileDetail.rxsize, tileDetail.rysize
        wx, wy = tileDetail.wx, tileDetail.wy
        wxsize, wysize = tileDetail.wxsize, tileDetail.wysize
        # Query window is 4x the tile size
        querysize = 4 * TILESIZE
        ty_final = GoogleTile(tz, ty)
        tilefilename = os.path.join(output, str(
            tz), str(tx), "%s.%s" % (ty_final, TILEEXT))
        # Dataset that becomes the final tile
        dstile = memDrv.Create('', TILESIZE, TILESIZE, tilebands)
        data = alpha = None
        if rxsize != 0 and rysize != 0 and wxsize != 0 and wysize != 0:
            data = ds.ReadRaster(rx, ry, rxsize, rysize, wxsize,
                                 wysize, band_list=dataBands)
            alpha = alphaband.ReadRaster(
                rx, ry, rxsize, rysize, wxsize, wysize)
            if data:
                dsquery = memDrv.Create('', querysize, querysize, tilebands)
                # Paste data and alpha at (wx, wy) inside the query window
                dsquery.WriteRaster(wx, wy, wxsize, wysize,
                                    data, band_list=dataBands)
                dsquery.WriteRaster(wx, wy, wxsize, wysize,
                                    alpha, band_list=[tilebands])
                # Downsample the query window into the tile
                self.scaleQueryToTile(dsquery, dstile, tilefilename)
                del dsquery
        del ds
        del data
        outDrv.CreateCopy(tilefilename, dstile, strict=0)
        del dstile
        # Multi-process mode: tell the parent one more tile is done
        if queue is not None:
            queue.put("tile %s %s %s" % (tx, ty, tz))

    def workerTileDetails(self):
        """Open the data, write metadata.json and plan the base tiles.

        Stores the shared settings in ``self.tileJobInfo`` and returns the
        list of TileDetail jobs.
        """
        tileJobsMaker = TileJobsMaker(
            self.inputFile, self.outputFolder, self.options)
        tileJobsMaker.openData()
        tileJobsMaker.makeMetadata()
        conf, tileDetails = tileJobsMaker.makeBaseTiles()
        self.tileJobInfo = conf
        return tileDetails

    def compressPng(self):
        """Compress (or delete, if empty) every PNG under outputFolder."""
        progressBar = ProgressBar(self.total, '压缩全部瓦片')
        progressBar.start()
        pngquant.config(quant_file=QUANTFILE, min_quality=70,
                        max_quality=95, speed=10)
        for root, dirs, files in os.walk(self.outputFolder, True):
            for file in files:
                if file.endswith('.png'):
                    realPath = os.path.join(root, file)
                    self.compress(realPath)
                    progressBar.updateProgress()

    def compress(self, imgPath, queue=None):
        """Quantize one PNG in place with pngquant, or delete it if empty.

        A tile whose file size equals EMPTY[str(TILESIZE)] is treated as
        blank and removed. When ``queue`` is given (multi-process mode), a
        per-process temporary file is configured for pngquant and the tile
        path is reported on the queue for progress.
        """
        if queue is not None:
            # Each worker process needs its own pngquant temp file
            import uuid
            tmp_file = os.path.join(tempfile.gettempdir(
            ), '{0}.quant.tmp.png'.format(uuid.uuid4().hex))
            pngquant.config(quant_file=QUANTFILE, min_quality=70,
                            max_quality=95, speed=5, tmp_file=tmp_file)
            queue.put(imgPath)
        # Empty tiles are not kept
        if os.path.getsize(imgPath) == EMPTY[str(TILESIZE)]:
            os.remove(imgPath)
        else:
            pngquant.quant_image(imgPath, imgPath)

    def scaleQueryToTile(self, dsquery, dstile, tilefilename=''):
        """Resample every band of ``dsquery`` into ``dstile`` with 'average'.

        Exits through ``check`` if GDAL reports an error.
        """
        tilebands = dstile.RasterCount
        for i in range(1, tilebands + 1):
            res = gdal.RegenerateOverview(dsquery.GetRasterBand(
                i), dstile.GetRasterBand(i), 'average')
            check(res != 0, "概览生成失败 %s,%d" % (tilefilename, res))

    def createOverviewTiles(self):
        """Build zooms tmaxz-1 down to tminz from already-written tiles.

        A tile (tx, ty) at zoom tz covers exactly the 2x2 children
        (2*tx..2*tx+1, 2*ty..2*ty+1) at zoom tz+1. The children that exist
        are pasted into a 2*TILESIZE mosaic, which is averaged down to one
        tile; missing children leave that quadrant empty.
        """
        tileJobInfo = self.tileJobInfo
        memDriver = gdal.GetDriverByName('MEM')
        outDriver = gdal.GetDriverByName(TILEDRIVER)
        tilebands = tileJobInfo.nbDataBands
        allBands = list(range(1, tilebands + 1))
        tcount = 0
        for tz in range(tileJobInfo.tmaxz - 1, tileJobInfo.tminz - 1, -1):
            tminx, tminy, tmaxx, tmaxy = tileJobInfo.tminmax[tz]
            tcount += (1 + abs(tmaxx - tminx)) * (1 + abs(tmaxy - tminy))
        if tcount == 0:
            return
        self.total += tcount
        progressBar = ProgressBar(tcount, '切割下层瓦片')
        progressBar.start()
        for tz in range(tileJobInfo.tmaxz - 1, tileJobInfo.tminz - 1, -1):
            tminx, tminy, tmaxx, tmaxy = tileJobInfo.tminmax[tz]
            # Tile range of the child zoom: only these children exist on disk
            cminx, cminy, cmaxx, cmaxy = tileJobInfo.tminmax[tz + 1]
            for ty in range(tmaxy, tminy - 1, -1):
                for tx in range(tminx, tmaxx + 1):
                    ty_final = GoogleTile(tz, ty)
                    tileDir = os.path.join(self.outputFolder, str(tz), str(tx))
                    os.makedirs(tileDir, exist_ok=True)
                    tilefilename = os.path.join(
                        tileDir, "%s.%s" % (ty_final, TILEEXT))
                    dsquery = memDriver.Create(
                        '', 2 * TILESIZE, 2 * TILESIZE, tilebands)
                    dstile = memDriver.Create(
                        '', TILESIZE, TILESIZE, tilebands)
                    for y in range(2 * ty, 2 * ty + 2):
                        for x in range(2 * tx, 2 * tx + 2):
                            if not (cminx <= x <= cmaxx and cminy <= y <= cmaxy):
                                continue
                            childPath = os.path.join(
                                self.outputFolder, str(tz + 1), str(x),
                                "%s.%s" % (GoogleTile(tz + 1, y), TILEEXT))
                            dsquerytile = gdal.Open(childPath, gdal.GA_ReadOnly)
                            # Even child column goes left. Rows count
                            # northwards (TMS), so the even child row
                            # (2*ty) is the southern half: bottom of mosaic.
                            tileposx = (x - 2 * tx) * TILESIZE
                            tileposy = (1 - (y - 2 * ty)) * TILESIZE
                            tempRaster = dsquerytile.ReadRaster(
                                0, 0, TILESIZE, TILESIZE)
                            dsquery.WriteRaster(tileposx, tileposy,
                                                TILESIZE, TILESIZE,
                                                tempRaster, band_list=allBands)
                            del dsquerytile
                    # Downsample the mosaic and write the tile
                    self.scaleQueryToTile(
                        dsquery, dstile, tilefilename=tilefilename)
                    outDriver.CreateCopy(tilefilename, dstile, strict=0)
                    progressBar.updateProgress()


class MultiProcessTiling(SingleProcessTiling):
    """
    Multi-process variant: base tiles and PNG compression run in a worker
    Pool. Overview tiles are still built serially by the inherited
    createOverviewTiles.
    """

    def __init__(self, inputFile, outputFolder, options):
        super().__init__(inputFile, outputFolder, options)

    def _runWithProgress(self, func, jobs, title):
        """Run ``func(job, queue=queue)`` for every job in a Pool.

        A separate process prints progress from the queue messages. If any
        job raised, exits through ``check`` after everything has finished.
        """
        processes = self.options.processes or cpu_count()
        manager = Manager()
        queue = manager.Queue()
        errors = []

        def onError(exc):
            # A failed job may never report on the queue itself. Report on
            # its behalf so progressPrinter still receives len(jobs) messages
            # and exits; otherwise p.join() below would block forever.
            errors.append(exc)
            queue.put('error')

        pool = Pool(processes=processes)
        for job in jobs:
            pool.apply_async(func, (job, ), {"queue": queue},
                             error_callback=onError)
        p = Process(target=self.progressPrinter,
                    args=[queue, len(jobs), title])
        p.start()
        pool.close()
        # NOTE(review): a worker that calls sys.exit (e.g. via check()) dies
        # without delivering a result, which can still stall pool.join().
        pool.join()
        p.join()
        if errors:
            check(True, "%d个任务执行失败, 首个错误: %r" % (len(errors), errors[0]))

    def tiling(self):
        """Render all base-zoom tiles in parallel."""
        tileDetails = self.workerTileDetails()
        self.total += len(tileDetails)
        self._runWithProgress(self.createBaseTile, tileDetails, '切割顶层瓦片')

    def progressPrinter(self, queue, nb_jobs, title):
        """Print progress in multi-process mode: one tick per queue message."""
        pb = ProgressBar(nb_jobs, title)
        pb.start()
        for _ in range(nb_jobs):
            queue.get()
            pb.updateProgress()
            queue.task_done()

    def compressPng(self):
        """Compress (or delete, if empty) every PNG under outputFolder in parallel.

        The progress total is the actual number of PNG files found, not
        ``self.total``: a mismatch would leave progressPrinter waiting for
        messages that never arrive.
        """
        pngPaths = [os.path.join(root, name)
                    for root, dirs, files in os.walk(self.outputFolder, True)
                    for name in files if name.endswith('.png')]
        self._runWithProgress(self.compress, pngPaths, '压缩全部瓦片')


def process_args(argv):
    """Parse command-line arguments and prepare the output folder.

    Returns:
        (inputFile, outputFolder, args): the validated input path; the output
        folder (created if missing, and by default a folder next to the input
        named after it without its extension); and the argparse namespace.
    Exits through ``check`` if the input is not an existing file.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-z", dest='zoom', metavar='切割级别', help="\t如 '12-20'")
    parser.add_argument("-p", dest='processes', metavar='进程数',
                        type=int, default=1, help='\t默认单进程')
    parser.add_argument("-g", dest='gcj02', action='store_true',
                        default=False, help='\t火星偏移')
    parser.add_argument("-i", dest='input', metavar='tif文件', required=True)
    parser.add_argument("-s", dest='s_srs', metavar='输入文件的参考系')
    parser.add_argument("-o", dest='output', metavar="输出文件夹",
                        help="\t可选,默认在输入文件夹下的同名文件夹")
    args = parser.parse_args(argv)
    inputFile = args.input
    check(not os.path.isfile(inputFile), "%s不存在或非文件" % inputFile)
    outputFolder = args.output
    if not outputFolder:
        # Strip only the final extension: "a.b.tif" -> "a.b" (not "a")
        tifname = os.path.splitext(os.path.basename(inputFile))[0]
        outputFolder = os.path.join(os.path.dirname(
            os.path.abspath(inputFile)), tifname)
    os.makedirs(outputFolder, exist_ok=True)
    return inputFile, outputFolder, args


def main():
    """Command-line entry point: parse arguments, tile, report elapsed time.

    Uses the multi-process tiler unless exactly one process is requested.
    """
    import time
    startedAt = int(time.time())
    cmdArgs = gdal.GeneralCmdLineProcessor(sys.argv)
    inputFile, outputFolder, options = process_args(cmdArgs[1:])
    if options.processes == 1:
        tilingClass = SingleProcessTiling
    else:
        tilingClass = MultiProcessTiling
    tilingClass(inputFile, outputFolder, options)
    elapsed = int(time.time()) - startedAt
    print('全部结束,用时:%d秒' % elapsed)


# Run only when executed as a script. The guard also keeps multiprocessing
# workers that re-import this module (spawn start method, e.g. on Windows)
# from starting the whole tiling run again.
if __name__ == '__main__':
    main()
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,133评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,682评论 3 390
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,784评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,508评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,603评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,607评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,604评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,359评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,805评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,121评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,280评论 1 344
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,959评论 5 339
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,588评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,206评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,442评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,193评论 2 367
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,144评论 2 352

推荐阅读更多精彩内容