package com.mushiny.task;

import java.util.Collection;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;

/**
 * Runs a {@link TaskDefinition}'s handler once per data element on a thread pool and waits
 * (up to a configured timeout) for all of those jobs to finish.
 *
 * <p>Intended lifecycle: construct, call one of the {@code init} methods, then call
 * {@link #execute()}. {@code execute()} always calls {@link #destroy()} when it finishes, which
 * clears the pool reference, so an instance can run at most one batch.
 *
 * <p>This class performs no synchronization on its own fields; do not call its methods from
 * multiple threads concurrently.
 *
 * @author lihao
 */
public class TaskExecutor<T, U, R> {

    /**
     * Thread pool that runs one job per data element. Set to {@code null} by {@link #destroy()};
     * the pool itself is not shut down by this class.
     */
    private ExecutorService pool;

    /**
     * The task to run (data, handler and shared context). Set by {@code init}, cleared by
     * {@link #destroy()}.
     */
    private TaskDefinition<T, U, R> task;

    /**
     * Maximum time {@link #execute()} waits for all jobs to finish, in units of {@link #timeUnit}.
     */
    private final Long countDownTimeOut;

    /**
     * Unit of {@link #countDownTimeOut}.
     */
    private final TimeUnit timeUnit;

    /**
     * @param pool             thread pool the per-element jobs are submitted to
     * @param countDownTimeOut maximum time {@link #execute()} waits for all jobs
     * @param timeUnit         unit of {@code countDownTimeOut}
     */
    public TaskExecutor(ThreadPoolExecutor pool, Long countDownTimeOut, TimeUnit timeUnit) {
        this.pool = pool;
        this.countDownTimeOut = countDownTimeOut;
        this.timeUnit = timeUnit;
    }

    /**
     * Sets the task to be run by the next call to {@link #execute()}.
     *
     * @param task the task definition (data, handler and context)
     */
    public void init(TaskDefinition<T, U, R> task) {
        this.task = task;
    }

    /**
     * Convenience overload that wraps the arguments in a {@link TaskDefinition}.
     *
     * @param handler function applied to each data element together with {@code context}
     * @param data    elements to process; one job is submitted per element
     * @param context value passed unchanged as the second argument to every handler call
     */
    public void init(BiFunction<T, U, R> handler, Collection<T> data, U context) {
        init(new TaskDefinition<>(data, handler, context));
    }

    /**
     * Submits one job per data element to the pool and waits for them to finish.
     *
     * <p>The handler's return values are discarded; only success or failure is tracked. When the
     * wait times out or is interrupted, jobs that have not finished keep running in the pool. Their
     * outcome is not reflected in the result.
     *
     * <p>{@link #destroy()} is always called before this method returns or throws (after the
     * precondition checks), so the instance cannot be executed again.
     *
     * @return {@code true} only if every job completed within the timeout without throwing;
     *         {@code false} on timeout, on any job failure, or if the calling thread was
     *         interrupted while waiting (its interrupt flag is then restored)
     * @throws IllegalStateException      if the pool was already released by {@link #destroy()},
     *                                    or no task was set via {@code init}
     * @throws RejectedExecutionException if the pool refuses to accept a job
     */
    public boolean execute() {
        final ExecutorService executor = this.pool;
        final TaskDefinition<T, U, R> t = this.task;
        if (executor == null) {
            throw new IllegalStateException("Thread pool has been released by destroy(); create a new TaskExecutor");
        }
        if (t == null) {
            throw new IllegalStateException("No task to execute; call init(...) first");
        }
        Collection<T> data = t.getData();
        // Local latch: workers must not read a field that destroy() nulls while they may still run.
        final CountDownLatch latch = new CountDownLatch(data.size());
        final AtomicInteger failCount = new AtomicInteger();
        try {
            for (T d : data) {
                executor.execute(() -> runOne(t, d, failCount, latch));
            }
            // await(...) returns false if the timeout elapsed before all jobs counted down.
            boolean finished = latch.await(countDownTimeOut, timeUnit);
            return finished && failCount.get() == 0;
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            return false;
        } finally {
            this.destroy();
        }
    }

    /**
     * Runs the handler for a single element, records a failure if it throws, and always counts
     * down the latch.
     */
    private static <A, B, C> void runOne(TaskDefinition<A, B, C> t, A item, AtomicInteger failCount,
                                         CountDownLatch latch) {
        try {
            t.getHandler().apply(item, t.getContext());
        } catch (Exception ex) {
            failCount.incrementAndGet();
        } catch (Error err) {
            // Count it as a failure, but let the Error propagate to the pool's worker thread.
            failCount.incrementAndGet();
            throw err;
        } finally {
            latch.countDown();
        }
    }

    /**
     * Releases the references to the task and the thread pool. The pool itself is not shut down.
     * After this call, {@link #execute()} throws {@link IllegalStateException}.
     */
    public void destroy() {
        this.task = null;
        this.pool = null;
    }
}
