/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.common.write;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.write.PushState;
import org.apache.celeborn.common.write.PushStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link PushStrategy} that throttles pushes per worker address with a window that grows on
 * success and halves on congestion, modeled on TCP slow start / congestion avoidance.
 *
 * <p>Each {@code host:pushPort} gets its own {@link CongestControlContext}. The context tracks how
 * many requests may currently be in flight to that address, starting at 1:
 *
 * <ul>
 *   <li>{@link #onSuccess(String)} grows the window. It adds one per success below the threshold,
 *       and one per "window-size" successes at or above it.
 *   <li>{@link #onCongestControl(String)} halves the window (never below 1) and lowers the
 *       threshold to the new window size.
 *   <li>{@link #limitPushSpeed(PushState, String)} sleeps the calling thread. The sleep is longer
 *       the smaller the window is.
 * </ul>
 *
 * <p>Per-address state lives in a concurrent map. Mutations of a single context are
 * {@code synchronized} on that context.
 */
public class SlowStartPushStrategy extends PushStrategy {
    private static final Logger logger = LoggerFactory.getLogger(SlowStartPushStrategy.class);

    /** Window size at or above which no sleep is applied; also the initial slow-start threshold. */
    private final int maxInFlightPerWorker;
    /** Base sleep (ms) from which 60 ms per permitted in-flight request is subtracted. */
    private final long initialSleepMills;
    /** Upper bound (ms) for the sleep applied while the window is at its minimum of 1. */
    private final long maxSleepMills;
    /** Lazily created congestion state, keyed by {@code host:pushPort}. */
    private final ConcurrentHashMap<String, CongestControlContext> congestControlInfoPerAddress;

    /**
     * Creates the strategy from client push configuration.
     *
     * @param conf configuration supplying the per-worker in-flight limit and the slow-start sleep
     *     parameters
     */
    public SlowStartPushStrategy(CelebornConf conf) {
        super(conf);
        this.maxInFlightPerWorker = conf.clientPushMaxReqsInFlightPerWorker();
        this.initialSleepMills = conf.clientPushSlowStartInitialSleepTime();
        this.maxSleepMills = conf.clientPushSlowStartMaxSleepMills();
        this.congestControlInfoPerAddress = JavaUtils.newConcurrentHashMap();
    }

    /**
     * Returns the congestion context for an address, creating one on first use.
     *
     * @param hostAndPushPort the worker's {@code host:pushPort} key
     * @return the (possibly newly created) context for that address
     */
    @VisibleForTesting
    protected CongestControlContext getCongestControlContextByAddress(String hostAndPushPort) {
        return this.congestControlInfoPerAddress.computeIfAbsent(
            hostAndPushPort, address -> new CongestControlContext(this.maxInFlightPerWorker));
    }

    /** Records a successful push to {@code hostAndPushPort}, growing its window. */
    @Override
    public void onSuccess(String hostAndPushPort) {
        this.getCongestControlContextByAddress(hostAndPushPort).increaseCurrentMaxReqs();
    }

    /** Records a congestion signal from {@code hostAndPushPort}, shrinking its window. */
    @Override
    public void onCongestControl(String hostAndPushPort) {
        this.getCongestControlContextByAddress(hostAndPushPort).decreaseCurrentMaxReqs();
    }

    /**
     * Computes how long (ms) to sleep before the next push for the given context.
     *
     * <ul>
     *   <li>If the window has reached {@code maxInFlightPerWorker}, the result is 0.
     *   <li>Otherwise the result is {@code initialSleepMills - 60 * window}.
     *   <li>When the window is exactly 1, an extra 1000 ms is added per consecutive congestion
     *       signal received at that size, and the total is capped at {@code maxSleepMills}.
     * </ul>
     *
     * @param context the per-address congestion state
     * @return a non-negative sleep duration in milliseconds
     */
    protected long getSleepTime(CongestControlContext context) {
        int currentMaxReqs = context.getCurrentMaxReqsInFlight();
        if (currentMaxReqs >= this.maxInFlightPerWorker) {
            return 0L;
        }
        // Each additional permitted in-flight request shortens the sleep by 60 ms.
        long sleepInterval = this.initialSleepMills - 60L * (long) currentMaxReqs;
        if (currentMaxReqs == 1) {
            // At the minimum window, back off harder for every consecutive congestion signal,
            // but never beyond the configured maximum.
            sleepInterval = Math.min(
                sleepInterval + (long) context.getContinueCongestedNumber() * 1000L,
                this.maxSleepMills);
        }
        // Floor every path at 0: a negative value is not a meaningful sleep time and was
        // previously possible in the window == 1 branch when initialSleepMills < 60.
        return Math.max(sleepInterval, 0L);
    }

    /**
     * Blocks the calling thread long enough to respect the current window for the target address.
     *
     * @param pushState shared push state; a previously recorded exception is rethrown here
     * @param hostAndPushPort the worker's {@code host:pushPort} key
     * @throws IOException the exception already recorded on {@code pushState}, if any
     */
    @Override
    public void limitPushSpeed(PushState pushState, String hostAndPushPort) throws IOException {
        // Fail fast if an earlier push (or an earlier interrupted sleep) already recorded an error.
        if (pushState.exception.get() != null) {
            throw pushState.exception.get();
        }
        CongestControlContext congestControlContext = this.getCongestControlContextByAddress(hostAndPushPort);
        long sleepInterval = this.getSleepTime(congestControlContext);
        if (sleepInterval > 0L) {
            try {
                logger.debug("Will sleep {} ms to control the push speed to {}.", sleepInterval, hostAndPushPort);
                Thread.sleep(sleepInterval);
            } catch (InterruptedException e) {
                // Restore the interrupt status (Thread.sleep cleared it) so code further up the
                // stack can still observe the interruption. The error itself is recorded on
                // pushState and rethrown by the check at the top of this method on the next call.
                Thread.currentThread().interrupt();
                pushState.exception.set(new CelebornIOException(e));
            }
        }
    }

    /** Returns the current window size for {@code hostAndPushPort}, creating state if needed. */
    @Override
    public int getCurrentMaxReqsInFlight(String hostAndPushPort) {
        return this.getCongestControlContextByAddress(hostAndPushPort).getCurrentMaxReqsInFlight();
    }

    /** Drops all per-address state; subsequent calls start again from a window of 1. */
    @Override
    public void clear() {
        this.congestControlInfoPerAddress.clear();
    }

    /**
     * Congestion state for a single worker address.
     *
     * <p>The window starts at 1. Below {@code reqsInFlightBlockThreshold} it grows by one per
     * success; at or above it, it grows by one per "window-size" successes. A congestion signal
     * halves it (minimum 1) and moves the threshold to the new window size.
     *
     * <p>NOTE(review): above the threshold the window keeps growing with no explicit upper bound
     * in this class, including past the outer {@code maxInFlightPerWorker}. Confirm whether
     * callers cap it elsewhere.
     */
    protected static class CongestControlContext {
        /** Current window: the number of requests allowed in flight. */
        private final AtomicInteger currentMaxReqsInFlight = new AtomicInteger(1);
        /** Consecutive congestion signals received while the window was already at 1. */
        private final AtomicInteger continueCongestedNumber = new AtomicInteger(0);
        /** Successes counted toward the next +1 step while in the congestion-avoidance phase. */
        private int congestionAvoidanceFlag = 0;
        /** Window size at which growth switches from per-success to per-window increments. */
        private int reqsInFlightBlockThreshold;

        /**
         * @param reqsInFlightBlockThreshold initial threshold between the fast and slow growth
         *     phases
         */
        public CongestControlContext(int reqsInFlightBlockThreshold) {
            this.reqsInFlightBlockThreshold = reqsInFlightBlockThreshold;
        }

        /** Grows the window after a success and resets the consecutive-congestion counter. */
        public synchronized void increaseCurrentMaxReqs() {
            this.continueCongestedNumber.set(0);
            if (this.currentMaxReqsInFlight.get() >= this.reqsInFlightBlockThreshold) {
                // Congestion avoidance: +1 only after "window-size" successes.
                ++this.congestionAvoidanceFlag;
                if (this.congestionAvoidanceFlag >= this.currentMaxReqsInFlight.get()) {
                    this.currentMaxReqsInFlight.incrementAndGet();
                    this.congestionAvoidanceFlag = 0;
                }
            } else {
                // Slow start: +1 per success.
                this.currentMaxReqsInFlight.incrementAndGet();
            }
        }

        /**
         * Shrinks the window after a congestion signal. The window is halved (never below 1) and
         * the threshold is set to the result. A signal received at window 1 increments the
         * consecutive-congestion counter instead.
         */
        public synchronized void decreaseCurrentMaxReqs() {
            if (this.currentMaxReqsInFlight.get() <= 1) {
                this.currentMaxReqsInFlight.set(1);
                this.continueCongestedNumber.incrementAndGet();
            } else {
                this.currentMaxReqsInFlight.updateAndGet(pre -> pre / 2);
            }
            this.reqsInFlightBlockThreshold = this.currentMaxReqsInFlight.get();
            this.congestionAvoidanceFlag = 0;
        }

        /** Returns the current window size. */
        public int getCurrentMaxReqsInFlight() {
            return this.currentMaxReqsInFlight.get();
        }

        /** Returns how many consecutive congestion signals arrived while the window was 1. */
        public int getContinueCongestedNumber() {
            return this.continueCongestedNumber.get();
        }
    }
}

