import type { Edge, NodeId } from "../widget/types";
import { uniq } from "ramda";

/** Query shared by both lookups: seed node ids plus an optional list of depths to keep. */
type NeighbourQuery = (partIds: NodeId[], levels?: number[]) => NodeId[];

/**
 * Neighbourhood lookups produced by `useNeighbours`.
 *
 * `levels` selects depths measured from the seeds (0 = the seeds themselves,
 * 1 = one edge away, ...). When it is omitted, every node reachable from the
 * seeds is returned, seeds included.
 */
type Neighbours = {
  getPredecessors: NeighbourQuery;
  getSuccessors: NeighbourQuery;
};

/**
 * Builds an adjacency list for one traversal direction.
 *
 * "predecessors" steps from `edge.from` to `edge.to`; "successors" steps from
 * `edge.to` to `edge.from`. Each neighbour list keeps the order of `edges`,
 * which is what determines the order of the final results.
 */
const buildAdjacency = (
  edges: Edge[],
  direction: "predecessors" | "successors",
): Map<NodeId, NodeId[]> => {
  const adjacency = new Map<NodeId, NodeId[]>();
  for (const edge of edges) {
    const source = direction === "predecessors" ? edge.from : edge.to;
    const target = direction === "predecessors" ? edge.to : edge.from;
    const targets = adjacency.get(source);
    if (targets === undefined) {
      adjacency.set(source, [target]);
    } else {
      targets.push(target);
    }
  }
  return adjacency;
};

/** Returns the distinct neighbours of all nodes in `frontier`, in first-seen order. */
const expandFrontier = (
  adjacency: Map<NodeId, NodeId[]>,
  frontier: Set<NodeId>,
): Set<NodeId> => {
  const next = new Set<NodeId>();
  frontier.forEach((node) => {
    adjacency.get(node)?.forEach((neighbour) => next.add(neighbour));
  });
  return next;
};

/**
 * Returns `seeds` followed by every node reachable from them, ordered by the
 * first level at which each node appears. Each node is expanded at most once,
 * so the walk terminates even when the graph contains cycles.
 */
const collectReachable = (
  adjacency: Map<NodeId, NodeId[]>,
  seeds: Set<NodeId>,
): NodeId[] => {
  const visited = new Set<NodeId>(seeds);
  let frontier = seeds;
  while (frontier.size > 0) {
    const discovered = new Set<NodeId>();
    expandFrontier(adjacency, frontier).forEach((node) => {
      if (!visited.has(node)) {
        visited.add(node);
        discovered.add(node);
      }
    });
    frontier = discovered;
  }
  return Array.from(visited);
};

/**
 * Returns the nodes on each requested level (level 0 = `seeds`), concatenated
 * in the order the levels are requested. A node is on level k when some walk
 * of exactly k edges reaches it, so it may appear on several levels.
 * Expansion stops at the largest finite requested level, or earlier once no
 * further nodes are reachable; non-finite or out-of-range levels select nothing.
 */
const collectLevels = (
  adjacency: Map<NodeId, NodeId[]>,
  seeds: Set<NodeId>,
  levels: number[],
): NodeId[] => {
  // Math.max of an empty list is -Infinity, which disables expansion entirely.
  const maxLevel = Math.max(...levels.filter((level) => Number.isFinite(level)));
  const levelSets: Set<NodeId>[] = [seeds];
  let frontier = seeds;
  // Expand past level k only while k < maxLevel (so levels 1..ceil(maxLevel) get built).
  for (let level = 0; level < maxLevel && frontier.size > 0; level += 1) {
    frontier = expandFrontier(adjacency, frontier);
    levelSets.push(frontier);
  }

  const result: NodeId[] = [];
  levels.forEach((level) => {
    const nodes = levelSets[level];
    if (nodes !== undefined) {
      nodes.forEach((node) => result.push(node));
    }
  });
  return result;
};

/**
 * Builds neighbourhood queries over a directed edge list.
 *
 * Both queries start from `partIds` (level 0) and walk the edges level by level.
 * - `levels` omitted: every node reachable from `partIds`, including the seeds
 *   themselves, ordered by the first level at which each node appears.
 * - `levels` given: the union of the requested levels, in request order,
 *   deduplicated. An empty array yields `[]`.
 *
 * NOTE(review): `getPredecessors` follows edges `from -> to` and
 * `getSuccessors` follows them `to -> from`. That looks inverted relative to
 * the usual meaning of the names — confirm against the `Edge` convention.
 *
 * @param edges - Directed edges; read each time a query runs.
 * @returns The `getPredecessors` / `getSuccessors` query pair.
 */
export const useNeighbours = (edges: Edge[]): Neighbours => {
  const getNeighbours = (
    direction: "predecessors" | "successors",
    partIds: NodeId[],
    levels?: number[],
  ): NodeId[] => {
    // An explicit but empty level filter selects nothing; skip the traversal.
    if (levels !== undefined && levels.length === 0) {
      return [];
    }

    // Built per query from the current contents of `edges`.
    const adjacency = buildAdjacency(edges, direction);
    const seeds = new Set(partIds);
    const result =
      levels === undefined
        ? collectReachable(adjacency, seeds)
        : collectLevels(adjacency, seeds, levels);
    // A node can sit on several requested levels; keep its first occurrence.
    return uniq(result);
  };

  return {
    getPredecessors: (partIds, levels) =>
      getNeighbours("predecessors", partIds, levels),
    getSuccessors: (partIds, levels) =>
      getNeighbours("successors", partIds, levels),
  };
};
