/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.streams.processor.internals.assignment;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.TopicPartitionInfo;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.assignment.AssignmentConfigs;
import org.apache.kafka.streams.processor.assignment.ProcessId;
import org.apache.kafka.streams.processor.internals.InternalTopicManager;
import org.apache.kafka.streams.processor.internals.TopologyMetadata;
import org.apache.kafka.streams.processor.internals.assignment.ClientState;
import org.apache.kafka.streams.processor.internals.assignment.Graph;
import org.apache.kafka.streams.processor.internals.assignment.MinTrafficGraphConstructor;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructor;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructorFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Computes the cross-rack traffic cost of task placements and rearranges active and standby
 * tasks across clients to reduce that cost, by building a task/client graph and solving a
 * min-cost-flow problem on it.
 *
 * <p>{@link #canEnableRackAwareAssignor()} populates the partition-to-rack mapping that the cost
 * function relies on. The cost function throws {@link IllegalStateException} when a client or a
 * partition has no rack information, so callers are expected to check it first.
 */
public class RackAwareTaskAssignor {
    // NOTE(review): these cost constants are not referenced inside this class; presumably callers
    // use them when optimizing stateless tasks — confirm.
    public static final int STATELESS_TRAFFIC_COST = 1;
    public static final int STATELESS_NON_OVERLAP_COST = 0;

    private static final Logger log = LoggerFactory.getLogger(RackAwareTaskAssignor.class);

    /** Maximum number of passes {@link #optimizeStandbyTasks} makes over all client pairs. */
    public static final int STANDBY_OPTIMIZER_MAX_ITERATION = 4;

    /** Value of {@code rackAwareAssignmentStrategy()} that disables rack-aware assignment. */
    private static final String STRATEGY_NONE = "none";

    private final Cluster fullMetadata;
    private final Map<TaskId, Set<TopicPartition>> partitionsForTask;
    private final Map<TaskId, Set<TopicPartition>> changelogPartitionsForTask;
    private final Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup;
    private final AssignmentConfigs assignmentConfigs;
    // Filled lazily by the validation methods behind canEnableRackAwareAssignor().
    private final Map<TopicPartition, Set<String>> racksForPartition;
    // Filled by validateClientRack() from the constructor.
    private final Map<ProcessId, String> rackForProcess;
    private final InternalTopicManager internalTopicManager;
    private final boolean validClientRack;
    private final Time time;
    // Cached result of canEnableRackAwareAssignor(); null until first computed.
    private Boolean canEnable = null;

    /**
     * Creates the assignor and immediately validates client rack information.
     *
     * @param racksForProcessConsumer per process, the (optional) rack reported by each consumer
     * @param time clock used to time the optimizations; must not be {@code null}
     * @throws NullPointerException if {@code time} is {@code null}
     */
    public RackAwareTaskAssignor(final Cluster fullMetadata,
                                 final Map<TaskId, Set<TopicPartition>> partitionsForTask,
                                 final Map<TaskId, Set<TopicPartition>> changelogPartitionsForTask,
                                 final Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup,
                                 final Map<ProcessId, Map<String, Optional<String>>> racksForProcessConsumer,
                                 final InternalTopicManager internalTopicManager,
                                 final AssignmentConfigs assignmentConfigs,
                                 final Time time) {
        this.fullMetadata = fullMetadata;
        this.partitionsForTask = partitionsForTask;
        this.changelogPartitionsForTask = changelogPartitionsForTask;
        this.tasksForTopicGroup = tasksForTopicGroup;
        this.internalTopicManager = internalTopicManager;
        this.assignmentConfigs = assignmentConfigs;
        this.racksForPartition = new HashMap<>();
        this.rackForProcess = new HashMap<>();
        this.time = Objects.requireNonNull(time, "Time was not specified");
        this.validClientRack = validateClientRack(racksForProcessConsumer, assignmentConfigs, this.rackForProcess);
    }

    /** Returns whether every process reported a single, consistent rack at construction time. */
    public boolean validClientRack() {
        return validClientRack;
    }

    /**
     * Returns whether rack-aware assignment can be used.
     *
     * <p>The result is always {@code false} when the strategy is {@code "none"}. Otherwise it is
     * computed once and cached. It requires:
     * <ul>
     *   <li>valid client racks;</li>
     *   <li>rack information for every source partition;</li>
     *   <li>when standby replicas are configured, rack information for every changelog partition.</li>
     * </ul>
     * As a side effect, this populates the partition-to-rack mapping used by the cost functions.
     */
    public synchronized boolean canEnableRackAwareAssignor() {
        if (STRATEGY_NONE.equals(assignmentConfigs.rackAwareAssignmentStrategy())) {
            return false;
        }
        if (canEnable != null) {
            return canEnable;
        }
        canEnable = validClientRack && validateTopicPartitionRack(false);
        if (assignmentConfigs.numStandbyReplicas() == 0 || !canEnable) {
            return canEnable;
        }
        canEnable = validateTopicPartitionRack(true);
        return canEnable;
    }

    /**
     * Collects the topics whose partition metadata must be described from the brokers.
     *
     * @param topicsToDescribe output set that receives topic names
     * @param changelog if {@code true}, every changelog topic is added. Otherwise the source
     *                  partitions are walked: racks are recorded from the cluster metadata where
     *                  replicas are listed, and the topic is added only when no replicas are listed.
     * @return {@code false} if a source partition is missing from the cluster metadata or one of
     *         its replicas has no rack; {@code true} otherwise
     */
    boolean populateTopicsToDescribe(final Set<String> topicsToDescribe, final boolean changelog) {
        if (changelog) {
            changelogPartitionsForTask.values().stream()
                .flatMap(Collection::stream)
                .forEach(tp -> topicsToDescribe.add(tp.topic()));
            return true;
        }
        for (final Set<TopicPartition> topicPartitions : partitionsForTask.values()) {
            for (final TopicPartition topicPartition : topicPartitions) {
                final PartitionInfo partitionInfo = fullMetadata.partition(topicPartition);
                if (partitionInfo == null) {
                    log.error("TopicPartition {} doesn't exist in cluster", topicPartition);
                    return false;
                }
                final Node[] replicas = partitionInfo.replicas();
                if (replicas == null || replicas.length == 0) {
                    // No replicas in the cached metadata: fall back to describing the topic.
                    topicsToDescribe.add(topicPartition.topic());
                    continue;
                }
                for (final Node node : replicas) {
                    if (!node.hasRack()) {
                        log.warn("Node {} for topic partition {} doesn't have rack", node, topicPartition);
                        return false;
                    }
                    racksForPartition.computeIfAbsent(topicPartition, k -> new HashSet<>()).add(node.rack());
                }
            }
        }
        return true;
    }

    /**
     * Ensures rack information exists for every needed partition.
     *
     * <p>Topics that cannot be resolved from cluster metadata are described through the
     * {@link InternalTopicManager}.
     *
     * @param changelogTopics whether to validate changelog partitions instead of source partitions
     * @return {@code true} if every relevant partition has at least one replica and all replicas have a rack
     */
    private boolean validateTopicPartitionRack(final boolean changelogTopics) {
        final Set<String> topicsToDescribe = new HashSet<>();
        if (!populateTopicsToDescribe(topicsToDescribe, changelogTopics)) {
            return false;
        }
        if (topicsToDescribe.isEmpty()) {
            return true;
        }
        log.info("Fetching PartitionInfo for topics {}", topicsToDescribe);
        try {
            final Map<String, List<TopicPartitionInfo>> topicPartitionInfo =
                internalTopicManager.getTopicPartitionInfo(topicsToDescribe);
            if (topicsToDescribe.size() > topicPartitionInfo.size()) {
                // Leave only the topics that could not be described, for the log message.
                topicsToDescribe.removeAll(topicPartitionInfo.keySet());
                log.error("Failed to describe topic for {}", topicsToDescribe);
                return false;
            }
            for (final Map.Entry<String, List<TopicPartitionInfo>> entry : topicPartitionInfo.entrySet()) {
                for (final TopicPartitionInfo partitionInfo : entry.getValue()) {
                    final int partition = partitionInfo.partition();
                    final List<Node> replicas = partitionInfo.replicas();
                    if (replicas == null || replicas.isEmpty()) {
                        log.error("No replicas found for topic partition {}: {}", entry.getKey(), partition);
                        return false;
                    }
                    final TopicPartition topicPartition = new TopicPartition(entry.getKey(), partition);
                    for (final Node node : replicas) {
                        if (!node.hasRack()) {
                            // Previously this failed silently; log like populateTopicsToDescribe does.
                            log.warn("Node {} for topic partition {} doesn't have rack", node, topicPartition);
                            return false;
                        }
                        racksForPartition.computeIfAbsent(topicPartition, k -> new HashSet<>()).add(node.rack());
                    }
                }
            }
        } catch (final Exception e) {
            log.error("Failed to describe topics {}", topicsToDescribe, e);
            return false;
        }
        return true;
    }

    /**
     * Checks that every consumer of every process reports a rack, and that all consumers of a
     * process report the same rack.
     *
     * <p>Validated racks are written into {@code rackForProcess}. Errors are logged unless the
     * strategy is {@code "none"} (the inconsistent-rack error is always logged).
     *
     * @param racksForProcessConsumer per process, the optional rack of each consumer; {@code null} yields {@code false}
     * @param assignmentConfigs configs used to read the rack-aware strategy
     * @param rackForProcess output map from process to its rack
     * @return {@code true} if all processes have a single consistent rack
     */
    public static boolean validateClientRack(final Map<ProcessId, Map<String, Optional<String>>> racksForProcessConsumer,
                                             final AssignmentConfigs assignmentConfigs,
                                             final Map<ProcessId, String> rackForProcess) {
        if (racksForProcessConsumer == null) {
            return false;
        }
        for (final Map.Entry<ProcessId, Map<String, Optional<String>>> entry : racksForProcessConsumer.entrySet()) {
            final ProcessId processId = entry.getKey();
            KeyValue<String, String> previousRackInfo = null;
            for (final Map.Entry<String, Optional<String>> rackEntry : entry.getValue().entrySet()) {
                if (!rackEntry.getValue().isPresent()) {
                    if (!STRATEGY_NONE.equals(assignmentConfigs.rackAwareAssignmentStrategy())) {
                        log.error("RackId doesn't exist for process {} and consumer {}", processId, rackEntry.getKey());
                    }
                    return false;
                }
                final String rack = rackEntry.getValue().get();
                if (previousRackInfo == null) {
                    previousRackInfo = KeyValue.pair(rackEntry.getKey(), rack);
                } else if (!previousRackInfo.value.equals(rack)) {
                    log.error("Consumers {} and {} for same process {} has different rackId {} and {}. File a ticket for this bug",
                        previousRackInfo.key, rackEntry.getKey(), processId, previousRackInfo.value, rack);
                    return false;
                }
            }
            if (previousRackInfo == null) {
                if (!STRATEGY_NONE.equals(assignmentConfigs.rackAwareAssignmentStrategy())) {
                    log.error("RackId doesn't exist for process {}", processId);
                }
                return false;
            }
            rackForProcess.put(processId, previousRackInfo.value);
        }
        return true;
    }

    /** Returns a read-only view of the process-to-rack mapping. */
    public Map<ProcessId, String> racksForProcess() {
        return Collections.unmodifiableMap(rackForProcess);
    }

    /** Returns a read-only view of the partition-to-racks mapping collected so far. */
    public Map<TopicPartition, Set<String>> racksForPartition() {
        return Collections.unmodifiableMap(racksForPartition);
    }

    /**
     * Looks up the rack of a client.
     *
     * @throws IllegalStateException if the client has no rack recorded
     */
    private String rackOf(final ProcessId processId) {
        final String rack = rackForProcess.get(processId);
        if (rack == null) {
            throw new IllegalStateException("Client " + processId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
        }
        return rack;
    }

    /**
     * Computes the cost of placing {@code taskId} on {@code processId}.
     *
     * <p>The cost is:
     * <ul>
     *   <li>{@code trafficCost} for every task partition whose replica racks do not include the client's rack;</li>
     *   <li>plus {@code nonOverlapCost} if the task is not already assigned there.</li>
     * </ul>
     *
     * @param isStandby whether to use the task's changelog partitions rather than its source partitions
     * @throws IllegalStateException if the client, the task's partitions, or a partition's racks are unknown
     */
    private int getCost(final TaskId taskId,
                        final ProcessId processId,
                        final boolean inCurrentAssignment,
                        final int trafficCost,
                        final int nonOverlapCost,
                        final boolean isStandby) {
        final String clientRack = rackOf(processId);
        final Set<TopicPartition> topicPartitions =
            isStandby ? changelogPartitionsForTask.get(taskId) : partitionsForTask.get(taskId);
        if (topicPartitions == null || topicPartitions.isEmpty()) {
            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
        }
        int cost = 0;
        for (final TopicPartition tp : topicPartitions) {
            final Set<String> tpRacks = racksForPartition.get(tp);
            if (tpRacks == null || tpRacks.isEmpty()) {
                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
            }
            if (!tpRacks.contains(clientRack)) {
                cost += trafficCost;
            }
        }
        if (!inCurrentAssignment) {
            cost += nonOverlapCost;
        }
        return cost;
    }

    /** Returns the graph cost of the current active task assignment. */
    long activeTasksCost(final SortedSet<TaskId> activeTasks,
                         final SortedMap<ProcessId, ClientState> clientStates,
                         final int trafficCost,
                         final int nonOverlapCost) {
        return tasksCost(activeTasks, clientStates, trafficCost, nonOverlapCost, ClientState::hasActiveTask, false, false);
    }

    /** Returns the graph cost of the current standby task assignment. */
    long standByTasksCost(final SortedSet<TaskId> standbyTasks,
                          final SortedMap<ProcessId, ClientState> clientStates,
                          final int trafficCost,
                          final int nonOverlapCost) {
        return tasksCost(standbyTasks, clientStates, trafficCost, nonOverlapCost, ClientState::hasStandbyTask, true, true);
    }

    /**
     * Builds a min-traffic graph for the given tasks and returns its total cost.
     *
     * <p>No min-cost flow is solved here, so the graph is not rearranged. Returns 0 for an empty task set.
     */
    private long tasksCost(final SortedSet<TaskId> tasks,
                           final SortedMap<ProcessId, ClientState> clientStates,
                           final int trafficCost,
                           final int nonOverlapCost,
                           final BiPredicate<ClientState, TaskId> hasAssignedTask,
                           final boolean hasReplica,
                           final boolean isStandby) {
        if (tasks.isEmpty()) {
            return 0L;
        }
        final List<ProcessId> clientList = new ArrayList<>(clientStates.keySet());
        final List<TaskId> taskIdList = new ArrayList<>(tasks);
        final Graph<Integer> graph = new MinTrafficGraphConstructor<ClientState>().constructTaskGraph(
            clientList, taskIdList, clientStates, new HashMap<>(), new HashMap<>(),
            hasAssignedTask, this::getCost, trafficCost, nonOverlapCost, hasReplica, isStandby);
        return graph.totalCost();
    }

    /**
     * Reassigns active tasks across all clients by solving min-cost flow.
     *
     * <p>The graph constructor is chosen by {@link RackAwareGraphConstructorFactory}.
     * {@code clientStates} is updated in place.
     *
     * @return the total graph cost after solving, or 0 if there are no active tasks
     */
    public long optimizeActiveTasks(final SortedSet<TaskId> activeTasks,
                                    final SortedMap<ProcessId, ClientState> clientStates,
                                    final int trafficCost,
                                    final int nonOverlapCost) {
        if (activeTasks.isEmpty()) {
            return 0L;
        }
        log.info("Assignment before active task optimization is {}\n with cost {}", clientStates,
            activeTasksCost(activeTasks, clientStates, trafficCost, nonOverlapCost));
        final long startTime = time.milliseconds();
        final List<ProcessId> clientList = new ArrayList<>(clientStates.keySet());
        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
        final Map<TaskId, ProcessId> taskClientMap = new HashMap<>();
        final Map<ProcessId, Integer> originalAssignedTaskNumber = new HashMap<>();
        final RackAwareGraphConstructor<ClientState> graphConstructor =
            RackAwareGraphConstructorFactory.create(assignmentConfigs, tasksForTopicGroup);
        final Graph<Integer> graph = graphConstructor.constructTaskGraph(clientList, taskIdList, clientStates,
            taskClientMap, originalAssignedTaskNumber, ClientState::hasActiveTask, this::getCost,
            trafficCost, nonOverlapCost, false, false);
        graph.solveMinCostFlow();
        final long cost = graph.totalCost();
        graphConstructor.assignTaskFromMinCostFlow(graph, clientList, taskIdList, clientStates,
            originalAssignedTaskNumber, taskClientMap, ClientState::assignActive, ClientState::unassignActive,
            ClientState::hasActiveTask);
        final long duration = time.milliseconds() - startTime;
        log.info("Assignment after {} milliseconds for active task optimization is {}\n with cost {}",
            duration, clientStates, cost);
        return cost;
    }

    /**
     * Swaps standby tasks between pairs of clients on different racks to reduce cost.
     *
     * <p>Each round visits every such pair. For each pair, min-cost flow is solved over the
     * standbys that {@code moveStandbyTask} allows to move in both directions. Rounds stop when a
     * round moves nothing or after {@link #STANDBY_OPTIMIZER_MAX_ITERATION} rounds.
     * {@code clientStates} is updated in place.
     *
     * @return the standby task cost after optimization
     * @throws IllegalStateException if a client has no rack recorded
     */
    public long optimizeStandbyTasks(final SortedMap<ProcessId, ClientState> clientStates,
                                     final int trafficCost,
                                     final int nonOverlapCost,
                                     final MoveStandbyTaskPredicate moveStandbyTask) {
        final BiFunction<ClientState, ClientState, List<TaskId>> getMovableTasks = (source, destination) ->
            source.standbyTasks().stream()
                .filter(task -> !destination.hasAssignedTask(task))
                .filter(task -> moveStandbyTask.canMove(source, destination, task, clientStates))
                .sorted()
                .collect(Collectors.toList());

        final long startTime = time.milliseconds();
        final List<ProcessId> clientList = new ArrayList<>(clientStates.keySet());
        final SortedSet<TaskId> standbyTasks = new TreeSet<>();
        clientStates.values().forEach(clientState -> standbyTasks.addAll(clientState.standbyTasks()));
        log.info("Assignment before standby task optimization is {}\n with cost {}", clientStates,
            standByTasksCost(standbyTasks, clientStates, trafficCost, nonOverlapCost));

        boolean taskMoved = true;
        int round = 0;
        final MinTrafficGraphConstructor<ClientState> graphConstructor = new MinTrafficGraphConstructor<>();
        while (taskMoved && round < STANDBY_OPTIMIZER_MAX_ITERATION) {
            taskMoved = false;
            round++;
            for (int i = 0; i < clientList.size(); i++) {
                final ClientState clientState1 = clientStates.get(clientList.get(i));
                for (int j = i + 1; j < clientList.size(); j++) {
                    final ClientState clientState2 = clientStates.get(clientList.get(j));
                    // Resolve both racks the same way so a missing rack fails consistently,
                    // regardless of which side of the pair it is on.
                    final String rack1 = rackOf(clientState1.processId());
                    final String rack2 = rackOf(clientState2.processId());
                    // Only pairs of clients on different racks are considered.
                    if (rack1.equals(rack2)) {
                        continue;
                    }
                    final List<TaskId> movable1 = getMovableTasks.apply(clientState1, clientState2);
                    final List<TaskId> movable2 = getMovableTasks.apply(clientState2, clientState1);
                    // A swap needs movable tasks in both directions.
                    if (movable1.isEmpty() || movable2.isEmpty()) {
                        continue;
                    }
                    final List<TaskId> taskIdList = Stream.concat(movable1.stream(), movable2.stream())
                        .sorted()
                        .collect(Collectors.toList());
                    final List<ProcessId> clients = Stream.of(clientList.get(i), clientList.get(j))
                        .sorted()
                        .collect(Collectors.toList());
                    final Map<TaskId, ProcessId> taskClientMap = new HashMap<>();
                    final Map<ProcessId, Integer> originalAssignedTaskNumber = new HashMap<>();
                    final Graph<Integer> graph = graphConstructor.constructTaskGraph(clients, taskIdList, clientStates,
                        taskClientMap, originalAssignedTaskNumber, ClientState::hasStandbyTask, this::getCost,
                        trafficCost, nonOverlapCost, true, true);
                    graph.solveMinCostFlow();
                    taskMoved |= graphConstructor.assignTaskFromMinCostFlow(graph, clients, taskIdList, clientStates,
                        originalAssignedTaskNumber, taskClientMap, ClientState::assignStandby,
                        ClientState::unassignStandby, ClientState::hasStandbyTask);
                }
            }
        }
        final long cost = standByTasksCost(standbyTasks, clientStates, trafficCost, nonOverlapCost);
        final long duration = time.milliseconds() - startTime;
        log.info("Assignment after {} rounds and {} milliseconds for standby task optimization is {}\n with cost {}",
            round, duration, clientStates, cost);
        return cost;
    }

    /** Cost of placing a task on a client; matches the signature of the private {@code getCost}. */
    @FunctionalInterface
    public interface CostFunction {
        int getCost(TaskId taskId,
                    ProcessId processId,
                    boolean inCurrentAssignment,
                    int trafficCost,
                    int nonOverlapCost,
                    boolean isStandby);
    }

    /** Decides whether a standby task may be moved from {@code source} to {@code destination}. */
    @FunctionalInterface
    public interface MoveStandbyTaskPredicate {
        boolean canMove(ClientState source,
                        ClientState destination,
                        TaskId taskId,
                        Map<ProcessId, ClientState> clientStateMap);
    }
}

