package common

import (
	"context"
	"errors"
	"fmt"
	"math"
	"sync"
	"sync/atomic"

	"git.shuncheng.lu/bigthing/gocommon/pkg/internal/util"
)

var (
	// GoWithRecover re-exports util.GoWithRecover from this package.
	GoWithRecover = util.GoWithRecover
)

/*
Fork/join TODO:
 1. Allow the caller to supply a custom fork (partitioning) function.
 2. (done) Limit the maximum number of running goroutines. Go has no built-in
    pool, so one would have to be hand-written, and a pool is not a better fit
    than simply starting a fresh goroutine; the approach taken instead is to
    cap how many goroutines are running at the same time.
*/

// Job processes one shard of the work. Jobs run concurrently, so anything they
// share must be made safe for concurrent use by the implementation.
//
// The shard is the half-open index range [startIndex, endIndex).
// result is handed to the JobResultHandler; a non-nil err aborts the whole run.
type Job func(ctx context.Context, startIndex, endIndex uint64) (result interface{}, err error)

// JobResultHandler consumes the result of a single Job. ParallelJobRun invokes
// it serially from the calling goroutine, so it never runs concurrently with
// itself.
type JobResultHandler func(ctx context.Context, result interface{}) error

// ParallelJobRun splits the index range [0, totalCount) into shards of at most
// forkCount indexes, runs job on every shard in its own goroutine and feeds
// each shard's result, one at a time, to jobResultHandler.
//
// Parameters:
//   - ctx: parent context; a derived, cancellable context is passed to job and
//     jobResultHandler and is cancelled when ParallelJobRun returns.
//   - totalCount: total number of items to process.
//   - forkCount: number of items per shard, so ceil(totalCount/forkCount)
//     goroutines are started in total.
//   - job: shard worker, executed concurrently.
//   - jobResultHandler: consumer of shard results, executed serially.
//   - maxRunG: upper bound on concurrently running job goroutines; values <= 0
//     are replaced by math.MaxInt32. (The original author measured roughly
//     4800 MB of memory for 1,000,000 concurrent goroutines, hence the limit.)
//
// It returns nil once every shard has been handled. Otherwise it returns the
// first error observed: a parameter validation error, an error returned by a
// job or by jobResultHandler, a recovered job panic wrapped as an error, or
// ctx.Err() if the context is cancelled first.
func ParallelJobRun(ctx context.Context, totalCount, forkCount uint64, job Job, jobResultHandler JobResultHandler, maxRunG int64) error {
	if ctx == nil || totalCount == 0 || forkCount == 0 || job == nil || jobResultHandler == nil {
		return errors.New("the params has error")
	}
	if maxRunG <= 0 {
		maxRunG = math.MaxInt32
	}

	// Both channels are unbuffered: results are consumed serially by the loop
	// below, so a buffer would only let job goroutines exit a little earlier.
	receiveChannel := make(chan interface{})
	errChannel := make(chan error)

	ctx, cancel := context.WithCancel(ctx)
	// Cancelling on return tells every job still running to stop trying to
	// deliver its result or error. errChannel is deliberately never closed:
	// jobs may still be selecting on a send to it after we return, and a send
	// on a closed channel panics. Cancellation is the only shutdown signal.
	defer cancel()

	// Dispatcher: starts the job goroutines, then closes receiveChannel once
	// all of them have finished so the loop below can detect completion.
	go func() {
		var wg sync.WaitGroup
		defer func() {
			wg.Wait()
			close(receiveChannel)
		}()

		// Number of shards, rounding up for a trailing partial shard.
		gNum := totalCount / forkCount
		if totalCount%forkCount != 0 {
			gNum++
		}

		limitGoroutineRunJob(gNum, &jobInfo{
			ctx:            ctx,
			wg:             &wg,
			forkCount:      forkCount,
			totalCount:     totalCount,
			errChannel:     errChannel,
			receiveChannel: receiveChannel,
			job:            job,
		}, maxRunG)
	}()

	for {
		select {
		case data, isOpen := <-receiveChannel:
			if !isOpen {
				// Closed by the dispatcher: every job has finished.
				return nil
			}
			// Handled synchronously, one result at a time.
			if err := jobResultHandler(ctx, data); err != nil {
				return err
			}
		case err := <-errChannel:
			if err != nil {
				return err
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

// runJob executes shard number num of info.job and delivers either its result
// on info.receiveChannel or its error on info.errChannel. A panic raised by
// the job is recovered and reported as an error on info.errChannel.
//
// The caller must have called info.wg.Add(1) before starting runJob; runJob
// calls info.wg.Done() exactly once when it finishes. Every send also selects
// on info.ctx.Done(), so runJob never blocks after the run has been cancelled.
func runJob(info *jobInfo, num uint64) {
	defer func() {
		if r := recover(); r != nil {
			select {
			case <-info.ctx.Done():
			case info.errChannel <- fmt.Errorf("panic:\n %v", r):
			}
		}
		// Done comes last so receiveChannel is not closed while this goroutine
		// could still be sending.
		info.wg.Done()
	}()

	// Shard bounds, half-open. The end is derived from the remaining count so
	// the computation cannot overflow uint64.
	start := num * info.forkCount
	end := info.totalCount
	if end-start > info.forkCount {
		end = start + info.forkCount
	}

	result, err := info.job(info.ctx, start, end)
	if err != nil {
		select {
		case <-info.ctx.Done():
		case info.errChannel <- err:
		}
		return
	}
	select {
	case <-info.ctx.Done():
	case info.receiveChannel <- result:
	}
}

// jobInfo bundles everything a job goroutine needs for one ParallelJobRun call.
type jobInfo struct {
	ctx                   context.Context      // cancellable context of the run
	wg                    *sync.WaitGroup      // tracks running job goroutines
	forkCount, totalCount uint64               // shard size and total item count
	errChannel            chan<- error         // job errors and recovered panics
	receiveChannel        chan<- interface{}   // successful job results
	job                   Job                  // the shard worker
}

// limitGoroutineRunJob starts one goroutine per shard (gNum shards in total),
// each running runJob, while trying to keep at most maxRunNum of them running
// at once. It registers every goroutine on job.wg before starting it and stops
// starting new shards as soon as job.ctx is cancelled.
//
// NOTE(review): throttling relies on NewCond, defined elsewhere in this
// package; this code assumes Wait blocks until some job calls Notify and that
// extra Notify calls are harmless — confirm against its implementation.
func limitGoroutineRunJob(gNum uint64, job *jobInfo, maxRunNum int64) {
	var (
		running int64 // number of job goroutines currently running
		cond    = NewCond()
	)
	for count := uint64(0); count < gNum; count++ {
		// Once the run is cancelled its outcome is already decided, so there
		// is no point in launching the remaining shards.
		if job.ctx.Err() != nil {
			return
		}
		if atomic.AddInt64(&running, 1) > maxRunNum {
			cond.Wait()
		}
		// Add must happen here, before the goroutine starts; otherwise the
		// dispatcher's wg.Wait could return before this job is counted.
		job.wg.Add(1)
		go func(count uint64) {
			defer func() {
				atomic.AddInt64(&running, -1)
				// Notifying unconditionally is fine, so no threshold check.
				cond.Notify()
			}()
			runJob(job, count)
		}(count)
	}
}