记一个简单的协程池

记一个简单的协程池
github地址
https://github.com/JeonYang/chanPool
包结构

channal/
├── dispatcher.go
├── error.go
├── job.go
├── pool.go
├── pool_test.go
└── worker.go

dispatcher.go

package channal

import (
    "errors"
)

// Dispatcher routes submitted jobs to idle workers. The implementation in
// this package is dispatcher, created by NewDispatcher.
type Dispatcher interface {
	// Start launches the routing loop; Stop shuts it down.
	Start()
	Stop()

	// AddJob enqueues a job; JobQueueLen reports how many are still queued.
	AddJob(job Job) error
	JobQueueLen() int

	// WorkerPool returns the channel on which idle workers register.
	WorkerPool() chan Worker
}

// dispatcher is the channel-based implementation of Dispatcher.
type dispatcher struct {
	workerPool chan Worker   // idle workers registered and waiting for a job
	jobQueue   chan Job      // submitted jobs not yet handed to a worker
	stopSignal chan struct{} // carries the stop request and its acknowledgement

	stop, stoped bool // stop: Stop has begun; stoped: Stop has completed
}

// NewDispatcher returns a Dispatcher that hands jobs read from jobQueue to
// workers registered on workerPool. Jobs are only dispatched once Start has
// been called.
func NewDispatcher(workerPool chan Worker, jobQueue chan Job) Dispatcher {
	d := &dispatcher{
		workerPool: workerPool,
		jobQueue:   jobQueue,
	}
	// Unbuffered: Stop and the dispatch loop use it as a synchronous handshake.
	d.stopSignal = make(chan struct{})
	return d
}

// Start clears the stop flags and launches the dispatch loop in a new
// goroutine. The loop repeatedly takes a job from jobQueue, waits for an idle
// worker on workerPool and hands the job to it. When a stop request arrives on
// stopSignal, it stops every worker currently idle in workerPool, acknowledges
// the request on stopSignal and exits.
func (dis *dispatcher) Start() {
	dis.stoped = false
	dis.stop = false
	go func() {
		for {
			select {
			case job := <-dis.jobQueue:
				// Blocks until some worker registers itself as idle.
				worker := <-dis.workerPool
				worker.AddJob(job)
			case <-dis.stopSignal:
				// Snapshot the idle count once. Re-evaluating len() in the loop
				// condition while receiving shrinks the bound each iteration and
				// stops only about half of the workers. Only this goroutine
				// receives from workerPool, so the count cannot drop underneath us.
				// Workers busy with a job right now are not in workerPool and are
				// therefore not stopped here.
				idle := len(dis.workerPool)
				for i := 0; i < idle; i++ {
					worker := <-dis.workerPool
					worker.Stop()
				}
				// Acknowledge so the goroutine blocked in Stop can return.
				dis.stopSignal <- struct{}{}
				return
			}
		}
	}()
}

// Stop asks the dispatch loop started by Start to stop its idle workers and
// exit, and blocks until the loop acknowledges. Calling Stop again without an
// intervening Start is a no-op. Previously such a call blocked forever, because
// no goroutine was left to receive the signal.
// NOTE(review): Stop before the first Start still blocks forever (stop is
// false and nothing receives on stopSignal); confirm callers always Start first.
func (dis *dispatcher) Stop() {
	if dis.stop {
		return
	}
	dis.stop = true
	// Handshake on the unbuffered channel: send the request, then wait for the ack.
	dis.stopSignal <- struct{}{}
	<-dis.stopSignal
	dis.stoped = true
}

// JobQueueLen reports how many submitted jobs are still buffered in the job
// queue, i.e. not yet taken by the dispatch loop.
func (dis *dispatcher) JobQueueLen() int {
	queued := len(dis.jobQueue)
	return queued
}

// WorkerPool returns the channel on which idle workers register themselves.
// It is the same channel that was passed to NewDispatcher.
func (dis *dispatcher) WorkerPool() chan Worker {
	idle := dis.workerPool
	return idle
}

// AddJob enqueues job for dispatch and returns nil. Once Stop has been called
// (and until the next Start), it instead returns an error whose message is
// Stoped and does not queue the job. The send blocks while the job queue is full.
func (dis *dispatcher) AddJob(job Job) error {
	if dis.stop {
		// Reject rather than enqueue: after Stop no goroutine reads jobQueue, so
		// the job would be stranded, or the send would block forever on a full
		// or unbuffered queue.
		return errors.New(Stoped)
	}
	dis.jobQueue <- job
	return nil
}

error.go

package channal

var (
	// Stoped is the message of the error returned when a job is submitted to a
	// pool or dispatcher that has been stopped.
	Stoped = "STOPED"
)


job.go

package channal

type (
	// Job is a unit of work: a function a worker calls with no arguments.
	Job func()
)



pool.go

package channal

import (
    "sync"
    "errors"
)

// Pool runs submitted jobs on worker goroutines. The implementation in this
// package is pool, created by NewPool.
type Pool interface {
	// Start and Stop control the pool's lifecycle.
	Start()
	Stop()

	// AddJob submits a job; it returns an error once the pool is stopped.
	AddJob(job Job) error

	// EnableWaitForAll toggles job tracking; WaitForAll blocks on tracked jobs.
	EnableWaitForAll(enable bool)
	WaitForAll()
}

// pool is the Dispatcher-backed implementation of Pool.
type pool struct {
	dispatcher       Dispatcher
	wg               sync.WaitGroup // counts jobs tracked for WaitForAll
	enableWaitForAll bool           // whether AddJob tracks jobs in wg
	workerNum        int            // maximum number of workers AddJob starts
	jobNum           int            // capacity of the job queue given to the dispatcher
	workerCount      int            // workers started since the last Stop
	stoped, stop     bool           // stop: Stop has begun; stoped: Stop has completed
}

// NewPool creates a Pool that starts at most workerNum workers and whose job
// queue buffers up to jobNum jobs. Waiting via WaitForAll is disabled by
// default. Call Start before submitting jobs so they are dispatched.
func NewPool(workerNum, jobNum int) Pool {
	p := &pool{workerNum: workerNum, jobNum: jobNum}
	p.enableWaitForAll = false
	p.dispatcher = NewDispatcher(make(chan Worker, workerNum), make(chan Job, jobNum))
	return p
}

// AddJob submits job to the pool.
//
// It returns an error whose message is Stoped if the pool has been stopped,
// and it returns any error from the dispatcher. On success it starts one more
// worker, as long as fewer than workerNum have been started since the last Stop.
//
// NOTE(review): p.stop and p.workerCount are read and written without
// synchronization, so concurrent AddJob calls race. Confirm that callers
// serialize them.
func (p *pool) AddJob(job Job) error {
	if p.stop {
		return errors.New(Stoped)
	}
	// Decide once, at submission time, whether this job is tracked by
	// WaitForAll. Re-reading enableWaitForAll when the job finishes would
	// unbalance the WaitGroup if it were toggled in between: a Done without an
	// Add panics, and an Add without a Done makes WaitForAll hang.
	wait := p.enableWaitForAll
	if wait {
		p.wg.Add(1)
	}
	err := p.dispatcher.AddJob(func() {
		if wait {
			defer p.wg.Done()
		}
		job()
	})
	if err != nil {
		// The job was never queued, so undo the Add made for it.
		if wait {
			p.wg.Done()
		}
		return err
	}
	// Grow the worker set lazily, one worker per submitted job, up to
	// workerNum. Gating this on JobQueueLen() > 0 raced with the dispatch
	// goroutine: if that goroutine had already dequeued the job, the length
	// read 0, no worker was started, and the job could wait forever. With
	// an unbuffered job queue the length is always 0.
	if p.workerCount < p.workerNum {
		worker := NewWorker(p.dispatcher.WorkerPool())
		worker.Start()
		p.workerCount++
	}
	return nil
}

// WaitForAll blocks on the pool's WaitGroup, which AddJob increments for jobs
// submitted while waiting is enabled. It returns immediately when waiting is
// currently disabled.
func (p *pool) WaitForAll() {
	if !p.enableWaitForAll {
		return
	}
	p.wg.Wait()
}

// Stop stops the dispatcher and marks the pool as stopped, so AddJob rejects
// new work until Start is called again. Calling Stop on an already-stopped
// pool is a no-op.
// NOTE(review): with the dispatcher in this package, Stop on a pool that was
// never started still blocks in the dispatcher's stop handshake. Confirm that
// callers always Start first.
func (p *pool) Stop() {
	if p.stop {
		// Already stopped: the dispatch loop has exited, so a second
		// dispatcher.Stop would wait forever for an acknowledgement.
		return
	}
	p.stop = true
	p.dispatcher.Stop()
	p.stoped = true
	// Forget the workers started so far; AddJob starts fresh ones after Start.
	p.workerCount = 0
}

// EnableWaitForAll sets whether jobs submitted from now on are counted in the
// WaitGroup that WaitForAll blocks on. Set it before submitting jobs. Toggling
// it while jobs are in flight is best avoided.
func (p *pool) EnableWaitForAll(enable bool) {
	p.enableWaitForAll = enable
}

// Start launches the dispatcher and clears the stop flags, so the pool accepts
// jobs again. It can also be used to restart the pool after a previous Stop.
func (p *pool) Start() {
	p.dispatcher.Start()
	p.stop, p.stoped = false, false
}

pool_test.go

package channal

import (
    "testing"
    "fmt"
    "time"
)

// TestNewPool runs a few jobs through a fresh pool, waits for them, checks
// that a stopped pool rejects work, and checks that the pool can be restarted.
func TestNewPool(t *testing.T) {
	p := NewPool(10, 10)
	p.Start()
	// Waiting must be enabled for WaitForAll to block. With it disabled, the
	// pool could be stopped before any job had run.
	p.EnableWaitForAll(true)
	for i := 0; i < 3; i++ {
		if err := p.AddJob(job_.do); err != nil {
			t.Fatalf("AddJob #%d on a running pool: %v", i, err)
		}
	}

	// Bound the wait so a dispatching bug fails the test instead of hanging it.
	done := make(chan struct{})
	go func() {
		p.WaitForAll()
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("WaitForAll did not return; jobs were never run")
	}
	p.Stop()

	// A stopped pool must reject new work.
	if err := p.AddJob(do); err == nil {
		t.Fatal("AddJob on a stopped pool returned nil error")
	}

	// The pool can be restarted and stopped again. The former one-minute
	// sleep only kept the process alive and made the test needlessly slow.
	p.Start()
	p.Stop()
}
// do is a trivial job used by the tests; it prints a separator line.
func do() {
	fmt.Print("=========\n")
}
// job_ is the sample job value whose do method the tests submit to the pool.
var job_ = job{name: "123", val: "321"}
// job is a small test fixture carrying two strings that its do method prints.
type job struct {
	name, val string
}

// do prints the job's name and value, one labelled line each.
func (j job) do() {
	lines := [...][2]string{
		{"name=========", j.name},
		{"val=========", j.val},
	}
	for _, kv := range lines {
		fmt.Println(kv[0], kv[1])
	}
}


worker.go

package channal

// Worker executes the jobs handed to it. The implementation in this package
// is worker, created by NewWorker.
type Worker interface {
	// AddJob hands a job to the worker.
	AddJob(job Job)
	// Start launches the worker; Stop shuts it down.
	Start()
	Stop()
}

// worker is the goroutine-backed implementation of Worker.
type worker struct {
	workerPool chan Worker   // where the worker registers itself when idle
	jobQueue   chan Job      // jobs handed to this worker
	stopSignal chan struct{} // carries the stop request and its acknowledgement

	stoped, stop bool
}

// NewWorker creates a worker that registers itself on workerPool whenever it
// is idle. The worker does not run until Start is called.
func NewWorker(workerPool chan Worker) *worker {
	w := new(worker)
	w.workerPool = workerPool
	// Both channels are unbuffered: sends to them block until the worker's
	// goroutine receives.
	w.jobQueue = make(chan Job)
	w.stopSignal = make(chan struct{})
	return w
}

// Start runs the worker loop in a new goroutine. On each iteration the worker
// registers itself on workerPool as idle, then waits for either a job, which
// it runs synchronously, or a stop request, which it acknowledges on
// stopSignal before exiting.
func (w *worker) Start() {
	loop := func() {
		for {
			w.workerPool <- w
			select {
			case <-w.stopSignal:
				w.stopSignal <- struct{}{}
				return
			case task := <-w.jobQueue:
				task()
			}
		}
	}
	go loop()
}

// Stop signals the worker goroutine to exit, waits for its acknowledgement,
// then closes the worker's channels. The dispatcher calls it on workers it has
// taken from workerPool. A second call panics, because stopSignal is closed by then.
func (w *worker) Stop() {
	signal := w.stopSignal
	signal <- struct{}{} // ask the goroutine to exit
	<-signal             // wait until it has acknowledged and is returning
	close(signal)
	close(w.jobQueue)
}

// AddJob hands job to the worker's goroutine. jobQueue is unbuffered, so this
// blocks until the goroutine receives the job. It panics if called after Stop,
// because Stop closes jobQueue.
func (w *worker) AddJob(job Job) {
	queue := w.jobQueue
	queue <- job
}

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容