在上一篇中我们介绍了 mpi4py 的若干使用技巧,并且简要介绍了 caput 及其 mpiutil 模块,下面我们将介绍 mpiutil 中提供的若干方便和易用的函数,这些函数可以使我们更加方便地进行 Python 并行编程,并且使我们的程序很容易地做到兼容非 MPI 编程环境。
函数接口
以下介绍的所有函数都可以兼容非 MPI 环境(此时 _comm 为 None),当 mpi4py 可用时 _comm 是 MPI.COMM_WORLD,也可以传递一个其它的通信子,此时将在该通信子上执行相应的操作。
mpilist(full_list, method='con', comm=_comm)
将一个序列 full_list
按照方法 method
划分成 n (n 为 comm
的 size,当 comm
为 None 时, 其 size = 1) 份并分配给每个进程,每个进程分得的元素数目相等(如果可以均分)或相差 1 (如果不能均分,rank 较小的进程会多 1 个),返回每个进程得到的子序列。当 fullist
的元素数目少于进程数目时,后面的进程会因为分不到元素而返回一个空序列。method
的取值可以为:
- 'con':连续划分,即 rank 较小的进程得到
full_list
中前面的元素,为默认值; - 'alt':交替划分,即 rank = 0 的进程得到位置为 [0, n, 2n, ...] 的元素, rank == 1 的进程得到位置为 [1, n+1, 2n+1, ...] 的元素, 等;
- 'rand':随机划分,效果相当于对原系列随机重排后再按确定的方式分配给每个进程。
comm
可以是 None 或一个通信子对象,其默认值 _comm 为 None (mpi4py 不可用时) 或者 MPI.COMM_WORLD (mpi4py 可用时),也可以传递一个其它的通信子对象,此时将 full_list
划分给该通信子对象所包含的每个进程。
mpirange(*args, **kargs)
MPI 版本的 range 函数,参数同 range 函数,另外加可选的参数 method
(默认值 con
) 和 comm
(默认值为 _comm),这两个参数的意思同上面介绍的 mpilist。这个函数的执行效果是 mpilist(range(*args), method, comm),即将 range 函数生成的序列作为 full_list
调用 mpilist 函数,返回每个进程得到的子序列。
barrier(comm=_comm)
栅障同步,参数 comm
的意义同函数 mpilist。
bcast(data, root=0, comm=_comm)
广播操作,数据从 root
进程广播到所有其它进程,参数 comm
的意义同函数 mpilist。
reduce(sendobj, root=0, op=None, comm=_comm)
规约操作,将数据按照方法 op
(默认值 None 会执行 MPI.SUM) 规约到 root
进程,参数 comm
的意义同函数 mpilist。
allreduce(sendobj, op=None, comm=_comm)
全规约操作,将数据按照方法 op
作全规约,参数 comm
的意义同函数 mpilist。
gather_list(lst, root=None, comm=_comm)
将各个进程的列表 lst
收集到 root
进程中并合并成一个新的列表,如果 root
为一个整数,则只有 rank 为该整数的进程会收集到数据并返回合并的列表,其它进程返回 None,如果 root
为 None,则所有进程都会收集(全收集操作),每个进程都会返回合并的列表。参数 comm
的意义同函数 mpilist。
parallel_map(func, glist, root=None, method='con', comm=_comm)
将序列 glist
按照方法 method
划分给每个进程,然后将函数 func
作用到每个进程所得的子序列的每个元素上,函数的返回值会被收集到 root
中经合并后返回。如果 root
为一个整数,则只有 rank 为该整数的进程会收集到数据并返回合并的列表,其它进程返回 None,如果 root
为 None,则所有进程都会收集(全收集操作),每个进程都会返回合并的列表。参数 method
和 comm
的意义同函数 mpilist。
split_all(n, comm=_comm)
将一个长度为 n
的序列顺序连续地划分给每个进程,返回一个三元组 (num, start, end),其中 num,start,end 都为长度为 comm
的 size (1 如果 comm
为 None)的 numpy 数组,分别给出每个进程分配到的元素数目,每个进程分配到的元素在原系列中的起始和结束位置。参数 comm
的意义同函数 mpilist。
split_local(n, comm=_comm)
将一个长度为 n
的序列顺序连续地划分给每个进程,返回一个三元组 (num, start, end),其中 num,start,end 都为整数,分别给出该进程自身分配到的元素数目,该进程分配到的元素在原系列中的起始和结束位置。参数 comm
的意义同函数 mpilist。
gather_local(global_array, local_array, local_start, root=0, comm=_comm)
将各个进程中的 numpy 数组 local_array
收集到 root
进程的 global_array
中,local_start
指明 local_array
放置在 global_array
中的起始位置,是一个长度为 global_array
维数的 tuple,其每一个元素指明放置在该维的起始位置。如果 root
为一个整数,则 local_array
只会被收集到 rank 为该整数的进程中,其它进程可以设置 global_array
为 None,如果 root
为 None,则 local_array
会被收集到所有进程中。参数 comm
的意义同函数 mpilist。
gather_array(local_array, axis=0, root=0, comm=_comm)
将各个进程中的 numpy 数组 local_array
沿着轴 axis
收集到 root
进程,合并成一个大的 numpy 数组后返回。如果 root
为一个整数,则 local_array
只会被收集到 rank 为该整数的进程中,其它进程会返回 None,如果 root
为 None,则 local_array
会被收集到所有进程中。参数 comm
的意义同函数 mpilist。
scatter_local(global_array, local_array, local_start, root=None, comm=_comm)
将 numpy 数组 global_array
散发到各个进程的 local_array
中,local_start
指明从 global_array
散发的起始位置,是一个长度为 global_array
维数的 tuple,其每一个元素指明该维的起始位置。如果 root
为一个整数,则只会从 rank 为该整数的进程的 global_array
中散发数据到所有其它进程中,因此其它进程的 global_array
可以为 None,如果 root
为 None,则每个进程会从各自的 global_array
中获取对应的数据放置到 local_array
中,因此一般要求每个进程的 global_array
都相同(但也可以不同)。参数 comm
的意义同函数 mpilist。
scatter_array(global_array, axis=0, root=None, comm=_comm)
将 numpy 数组 global_array
按照轴 axis
散发到各个进程,各个进程返回所得到的子数组。global_array
的 axis
轴会尽量均分,如果不能均分,则 rank 较小的进程会多 1,如果不够分,则 rank 最大的若干进程会返回空数组。如果 root
为一个整数,则只会从 rank 为该整数的进程的 global_array
中散发数据到所有其它进程中,因此其它进程的 global_array
可以为 None,如果 root
为 None,则每个进程会从各自的 global_array
中获取对应的数据,因此一般要求每个进程的 global_array
都相同(但也可以不同)。参数 comm
的意义同函数 mpilist。
例程
下面给出以上介绍的若干函数的使用例程。
# mpiutil_funcs.py
"""
Demonstrates the usage of mpilist, mpirange, bcast, gather_list, parallel_map,
split_all, split_local, gather_array, scatter_array.
Run this with 4 processes like:
$ mpiexec -n 4 python mpiutil_funcs.py
"""
import sys
import time
import numpy as np
from caput import mpiutil
rank = mpiutil.rank
size = mpiutil.size
sec = 5 # seconds to wait
def separator(sec, tag):
# sleep, sync, and flush to avoid output of different parts being mixed
time.sleep(sec)
mpiutil.barrier()
sys.stdout.flush()
if rank == 0:
print
print '-' * 35 + ' ' + tag + ' ' + '-' * 35
# mpilist
separator(sec, 'mpilist')
full_list = [1, 2.5, 'a', True, (3, 4), {'x':1}]
local_list = mpiutil.mpilist(full_list)
print "rank %d has %s with method = 'con'" % (rank, local_list)
local_list = mpiutil.mpilist(full_list, method='alt')
print "rank %d has %s with method = 'alt'" % (rank, local_list)
local_list = mpiutil.mpilist(full_list, method='rand')
print "rank %d has %s with method = 'rand'" % (rank, local_list)
# mpirange
separator(sec, 'mpirange')
local_ary = mpiutil.mpirange(1, 7)
print "rank %d has %s with method = 'con'" % (rank, local_ary)
local_ary = mpiutil.mpirange(1, 7, method='alt')
print "rank %d has %s with method = 'alt'" % (rank, local_ary)
local_ary = mpiutil.mpirange(1, 7, method='rand')
print "rank %d has %s with method = 'rand'" % (rank, local_ary)
# bcast
separator(sec, 'bcast')
if rank == 0:
sendobj = 'obj'
else:
sendobj = None
sendobj = mpiutil.bcast(sendobj, root=0)
print 'rank %d has sendobj = %s after bcast' % (rank, sendobj)
# gather_list
separator(sec, 'gather_list')
if rank == 0:
lst = [0.5, 2]
elif rank == 1:
lst = ['a', False, 'xy']
elif rank == 2:
lst = [{'x': 1}]
else:
lst = []
lst = mpiutil.gather_list(lst, root=None)
print 'rank %d has %s after gather_list' % (rank, lst)
# parallel_map
separator(sec, 'parallel_map')
glist = range(6)
result = mpiutil.parallel_map(lambda x: x*x, glist, root=0)
if rank == 0:
print 'result = %s' % result
# split_all
separator(sec, 'split_all')
print 'rank %d has: %s' % (rank, mpiutil.split_all(6))
# split_local
separator(sec, 'split_local')
print 'rank %d has: %s' % (rank, mpiutil.split_local(6))
# gather_array
separator(sec, 'gather_array')
if rank == 0:
local_ary = np.array([[0, 1], [6, 7]])
elif rank == 1:
local_ary = np.array([[2], [8]])
elif rank == 2:
local_ary = np.array([[3], [9]])
if rank == 3:
local_ary = np.array([[4, 5], [10, 11]])
global_ary = mpiutil.gather_array(local_ary, axis=1, root=0)
if rank == 0:
print 'global_ary = %s' % global_ary
# scatter_array
separator(sec, 'scatter_array')
local_ary = mpiutil.scatter_array(global_ary, axis=1, root=0)
print 'rank %d has local_ary = %s' % (rank, local_ary)
运行结果如下:
$ mpiexec -n 4 python mpiutil_funcs.py
Starting MPI rank=3 [size=4]
Starting MPI rank=2 [size=4]
Starting MPI rank=0 [size=4]
Starting MPI rank=1 [size=4]
----------------------------------- mpilist -----------------------------------
rank 1 has ['a', True] with method = 'con'
rank 1 has [2.5, {'x': 1}] with method = 'alt'
rank 2 has [(3, 4)] with method = 'con'
rank 2 has ['a'] with method = 'alt'
rank 0 has [1, 2.5] with method = 'con'
rank 0 has [1, (3, 4)] with method = 'alt'
rank 3 has [{'x': 1}] with method = 'con'
rank 3 has [True] with method = 'alt'
rank 0 has [{'x': 1}, 'a'] with method = 'rand'
rank 3 has [True] with method = 'rand'
rank 1 has [2.5, (3, 4)] with method = 'rand'
rank 2 has [1] with method = 'rand'
----------------------------------- mpirange -----------------------------------
rank 1 has [3, 4] with method = 'con'
rank 1 has [2, 6] with method = 'alt'
rank 0 has [1, 2] with method = 'con'
rank 0 has [1, 5] with method = 'alt'
rank 2 has [5] with method = 'con'
rank 2 has [3] with method = 'alt'
rank 3 has [6] with method = 'con'
rank 3 has [4] with method = 'alt'
rank 3 has [3] with method = 'rand'
rank 2 has [2] with method = 'rand'
rank 0 has [5, 1] with method = 'rand'
rank 1 has [4, 6] with method = 'rand'
----------------------------------- bcast -----------------------------------
rank 1 has sendobj = obj after bcast
rank 3 has sendobj = obj after bcast
rank 2 has sendobj = obj after bcast
rank 0 has sendobj = obj after bcast
----------------------------------- gather_list -----------------------------------
rank 1 has [0.5, 2, 'a', False, 'xy', {'x': 1}] after gather_list
rank 3 has [0.5, 2, 'a', False, 'xy', {'x': 1}] after gather_list
rank 2 has [0.5, 2, 'a', False, 'xy', {'x': 1}] after gather_list
rank 0 has [0.5, 2, 'a', False, 'xy', {'x': 1}] after gather_list
----------------------------------- parallel_map -----------------------------------
result = [0, 1, 4, 9, 16, 25]
----------------------------------- split_all -----------------------------------
rank 3 has: [[2 2 1 1]
[0 2 4 5]
[2 4 5 6]]
rank 0 has: [[2 2 1 1]
[0 2 4 5]
[2 4 5 6]]
rank 2 has: [[2 2 1 1]
[0 2 4 5]
[2 4 5 6]]
rank 1 has: [[2 2 1 1]
[0 2 4 5]
[2 4 5 6]]
----------------------------------- split_local -----------------------------------
rank 1 has: [2 2 4]
rank 0 has: [2 0 2]
rank 2 has: [1 4 5]
rank 3 has: [1 5 6]
----------------------------------- gather_array -----------------------------------
global_ary = [[ 0 1 2 3 4 5]
[ 6 7 8 9 10 11]]
----------------------------------- scatter_array -----------------------------------
rank 0 has local_ary = [[0 1]
[6 7]]
rank 1 has local_ary = [[2 3]
[8 9]]
rank 3 has local_ary = [[ 5]
[11]]
rank 2 has local_ary = [[ 4]
[10]]
以上我们介绍了 mpiutil 中提供的若干方便和易用的函数,在下一篇中我们将介绍建立在 numpy array 基础上的并行分布式数组 MPIArray。