背景
现在大多业务都使用机器学习,程序启动时加载训练好的模型文件,运行期也会触发模型的 reload。 在程序启动时如果加载耗时比较长,那么程序自然有段时间不可服务(模型没有准备好),但是运行期由于是双 buffer 切换,耗时长些也无所谓。
优化前
加载 14 个模型文件,并行加载,文件最小的几k,最大有 300m,加载时间取短板最长耗时 20s
优化后
对最大的四个文件,采用并行加载,耗时最大减少到 3s,优化完成
代码串行处理逻辑
原有单个文件也是串行处理逻辑
- os.Open 打开文件
- buf.ReadString 按行读取数据
- 根据业务需求,解析各个字段
- 追加到模型字典
这个逻辑非常简单,符合人的直觉思维,但同时也非常低效。
并行优化1
思路很简单:模型文件没有顺序,可以一次性全读到内存中,然后按行去并行解析,最后合并到字典,非常类似 MapReduce
- ioutil.ReadFile 全部读到内存中
- bytes.Split 根据 '\n' 分隔符打散
- 开启 n 个 goroutine 并行解析每行数据
- 合并到模型字典
第一次优化后耗时降为 10s,初步成效,但是仍然不理想。纺计每一步耗时后发现,对于最大 300m 的文件,bytes.Split 打散耗时 4s, 模型 Map 合并耗时 5s
并行优化2
和同事探讨下如何继续优化,对于 Map 无法并行。当前模型实现方式用单一 Map,如果加锁就和串行合并行为是一致的。当初始化 Map 指定大小时,合并时间从 5s 降到 2s,避免了 rehash copy 的开销,效果很明显。
另外 bytes.Split 打散耗时超长是没有想到的,看了下源码,内部两次遍历,耗时自然和数量成正比。同事提义将打散移到并行阶段,由每个 goroutine 去完成,预估并行数量,然后按 batch 打散。有几点需要注意:
- 无所提前知道总数据量大小,模型 Map 初始化要预估大小,按 30 byte 一行猜测即可
- 每个 gorouinte 划分数据也是不均等的,但一定要以 '\n' 分隔符打散,不能打数据截断
最后共耗时 3s,一次性加载内存维 150ms,并行解析 1s,合并 Map 2s
代码示例:
当前性价比最高的优化,如果大家有更好的方式可以共同交流一下,第一个是抽象的执行函数,第二个是示例使用方示
// ParallelLoadModelFile 并行加载模型文件
// @params data 文件二进制数据
// @params sep 分隔符
// @params name 识别标记
// @params parallel 并发数目, 一般不超过20, 过大没用
// @params parse 用户处理函数
// @params merge 用户合并函数
// 原类类似 MapReduce, 先将文件并行处理, 最后 reduce 合并。使用请参考 loadPassengerFeatures2
// 原则:尽量将耗时操作并行化
// 注意:
// 1. map 初始化时一定要指定大小,否则 rehash copy 成本非常高 测试 800W 条记录合并消耗 2s
// 2. 数据在 parse 和 merge 函数流动要用 channel, 具体类型及解析合并由调用方决定
// 3. 需要特殊处理行不能使用这个函数, 要单独处理
//
// 流程优化:
// 读文件 | 解析每行数据并写到map
// +------------------------+
//
// load内存并打散 分片 聚合
// +-----+
// +--------+ |-----| +------+
// +-----+
// load 打散分片 聚合
//
// +----+
// +----+ |----| +---+
// +----+
//
// 1. ioutil.ReadFile一次性读入内存 2. bytes.Split 按\n打散 3. 分片计算 4. 合并merge
// 在大文件时 bytes.Split 非常耗时, 将第2步移到并行阶段, 和3一起算。合并 map 非常耗时
// Map 操作只能串行, 并发也需要加锁来互斥, 等同于串行, 暂时没想到好的合并方法
func ParallelLoadModelFile(data []byte, sep []byte, name string, parallel int, parse func([]byte), merge func()) {
if parallel <= 0 || parallel > 30 || parse == nil || merge == nil || len(sep) == 0 {
panic("ParallelLoadModelFile params illegal")
}
var (
wait = sync.WaitGroup{} // sync
size = len(data) // file size
batch = size / parallel // batch size
num = size/batch + 1 // parallel goroutine
start = 0
end = batch
)
for i := 0; i < num; i++ {
wait.Add(1)
// 获取第一个 sep 所在的 index
idx := bytes.Index(data[end:], sep)
if idx == -1 {
end = len(data) - 1
} else {
end += idx
}
go parse(data[start:end])
start = end
if (end + batch) < len(data) {
end += batch
} else {
end = len(data) - 1
}
}
go func() {
for i := 0; i < num; i++ {
merge()
wait.Done()
}
}()
// 同步阻塞,等待所有 MapReduce
wait.Wait()
}
//加载小时特征 并行版本
func LoadHourGEOInfo2(model_data_center *ModelDataCenter, file_name string) error {
now := time.Now().UnixNano()
defer func() {
logger.Info("load[%s] time=%dms", file_name, (time.Now().UnixNano()-now)/1e6)
}()
content, err := ioutil.ReadFile(file_name)
if err != nil {
logger.Error("ioutil readfile error, file_name=%s", file_name)
return err
}
// 预估map大小
model_data_center.HourGEOInfoData = make(map[string]DynamicDiscountGEOInfo, len(content)/30)
// model 消息
modelChan := make(chan map[string]DynamicDiscountGEOInfo, 10)
// map 并行处理函数
mapParse := func(content []byte) {
var (
data = bytes.Split(content, SepLine)
m = make(map[string]DynamicDiscountGEOInfo, len(data))
)
defer func() {
// 将数据扔到 chan 待合并
// 用 defer 防止遗望
modelChan <- m
}()
for _, l := range data {
line := string(l)
// 兼容\r\n换行的情况
line = strings.Replace(line, "\r", "", -1)
list := strings.Split(line, ",")
var hour_geo_info DynamicDiscountGEOInfo
if len(list) != 8 && len(list) != 10 {
logger.Warn("wrong fomat file=%s line=%s cols.Size=%d", file_name, line, len(list))
continue
}
lng_lat, err := strconv.Atoi(list[0])
if err != nil {
logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 0)
continue
}
hour, err := strconv.Atoi(list[1])
if err != nil {
logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 1)
continue
}
hour_geo_key := GetGEOKey(hour, lng_lat, "HOUR", 0)
hour_geo_info.StartGEOInfo.CarpoolNum, err = strconv.Atoi(list[2])
if err != nil {
logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 2)
continue
}
hour_geo_info.StartGEOInfo.SucCarpoolNum, err = strconv.Atoi(list[3])
if err != nil {
logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 3)
continue
}
hour_geo_info.StartGEOInfo.SucCarpoolRate, err = strconv.Atoi(list[4])
if err != nil {
logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 4)
continue
}
hour_geo_info.DestGEOInfo.CarpoolNum, err = strconv.Atoi(list[5])
if err != nil {
logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 5)
continue
}
hour_geo_info.DestGEOInfo.SucCarpoolNum, err = strconv.Atoi(list[6])
if err != nil {
logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 6)
continue
}
hour_geo_info.DestGEOInfo.SucCarpoolRate, err = strconv.Atoi(list[7])
if err != nil {
logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 7)
continue
}
if len(list) == 10 {
hour_geo_info.StartGEOInfo.InComeRate, err = strconv.ParseFloat(list[8], 64)
if err != nil {
logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 8)
continue
}
hour_geo_info.DestGEOInfo.InComeRate, err = strconv.ParseFloat(list[9], 64)
if err != nil {
logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 9)
continue
}
} else {
hour_geo_info.StartGEOInfo.InComeRate = -1.0
hour_geo_info.DestGEOInfo.InComeRate = -1.0
}
// 更新 map
m[hour_geo_key] = hour_geo_info
}
}
// reduce 最终合并函数
mergeReduce := func() {
select {
// merge model msg
case m := <-modelChan:
logger.Info("parallel load[%s]||line_num=%d", file_name, len(m))
for k := range m {
model_data_center.HourGEOInfoData[k] = m[k]
}
}
}
ParallelLoadModelFile(content, SepLine, file_name, 3, mapParse, mergeReduce)
return nil
}