import {
  DatasetVersion,
  InputData,
  ModelGraph,
  Node,
  OutputData,
} from '@tensorleap/api-client';
import { Connection } from './interfaces/Connection';
import { Position } from '../core/position';
import { NodeWithLabels } from './descriptor/types';
import { reorganizeMap } from './autoorganize';
import { isKindOfLossNode } from './graph-calculation/utils';

/**
 * Value types accepted for node data properties, i.e. the values of
 * `nodeDataPropsToUpdate` passed to a {@link ChangeNodePropFunc}.
 */
export type NodePropertyType =
  | string
  | number
  | number[]
  | string[]
  | boolean
  | undefined
  | DatasetVersion;

/**
 * Callback that updates a single node's properties.
 *
 * - `nodeId`: id of the node to update.
 * - `nodeDataPropsToUpdate`: property name -> new value entries, presumably
 *   written into the node's `data` (inferred from the name; confirm).
 * - `nodePropsToUpdate`: optional top-level node fields to update; the type
 *   restricts this to `labels`.
 * - `override`: optional flag. NOTE(review): its exact semantics (e.g. replace
 *   vs. merge existing data) live in the implementation — confirm there.
 */
export type ChangeNodePropFunc = (_: {
  nodeId: string;
  nodeDataPropsToUpdate: Record<string, NodePropertyType>;
  nodePropsToUpdate?: Pick<NodeWithLabels, 'labels'>;
  override?: boolean;
}) => void;

/**
 * Callback that re-syncs a node's connections with its current socket names.
 *
 * NOTE(review): presumably an omitted name list leaves that side's
 * connections untouched, as `calcConnectionsToUpdate` does — confirm with the
 * implementation.
 */
export type UpdateConnectionFunc = (
  nodeId: string,
  currentInputNames?: string[],
  currentOutputNames?: string[]
) => void;

/**
 * Translates `position` by `offset`, returning a new position whose two
 * coordinates are the component-wise sums. Neither argument is mutated.
 */
export function addPositions(position: Position, offset: Position): Position {
  const [x, y] = position;
  const [dx, dy] = offset;
  return [x + dx, y + dy];
}

/**
 * Builds the legacy ("old Rete") `ModelGraph` from a node map and a flat
 * connection list.
 *
 * Every node is shallow-copied with fresh, empty `inputs`/`outputs`, which are
 * then filled from `connections`. Socket keys are `${nodeId}-${socketName}`,
 * except on the Model node (see `isModelGraphLayer`), whose keys are the bare
 * socket names. Each connection is recorded on both ends:
 * - in the producer's (`outputNodeId`) outputs, pointing at the consumer's
 *   input key;
 * - in the consumer's (`inputNodeId`) inputs, pointing at the producer's
 *   output key.
 *
 * If an endpoint node is missing from `nodesMap`, that end is written to a
 * throwaway object and so is dropped.
 */
export function createOldReteData(
  nodesMap: ROMap<string, Node>,
  connections: Connection[]
): ModelGraph {
  const toSocketKey = (nodeId: string, socketName: string): string =>
    isModelGraphLayer(nodeId) ? socketName : `${nodeId}-${socketName}`;

  const nodes = {} as ModelGraph['nodes'];
  for (const node of Array.from(nodesMap.values())) {
    nodes[node.id] = { ...node, inputs: {}, outputs: {} };
  }

  for (const connection of connections) {
    const { outputNodeId, outputName, inputNodeId, inputName } = connection;
    const outputKey = toSocketKey(outputNodeId, outputName);
    const inputKey = toSocketKey(inputNodeId, inputName);

    const { outputs: producerOutputs = {} } = nodes[outputNodeId] || {};
    const { inputs: consumerInputs = {} } = nodes[inputNodeId] || {};

    if (!producerOutputs[outputKey]) {
      producerOutputs[outputKey] = { connections: [] };
    }
    if (!consumerInputs[inputKey]) {
      consumerInputs[inputKey] = { connections: [] };
    }

    producerOutputs[outputKey].connections.push({
      node: inputNodeId,
      input: inputKey,
    });
    consumerInputs[inputKey].connections.push({
      node: outputNodeId,
      output: outputKey,
    });
  }

  return { id: 'vizard@0.1.0', nodes };
}

/**
 * Returns the distinct layer nodes that directly feed a loss node.
 *
 * A connection qualifies when its consumer (`inputNodeId`) satisfies
 * `isKindOfLossNode` and its producer (`outputNodeId`) is a layer-type node.
 * The producer is then collected. Connections that reference ids missing from
 * `nodes` are ignored. Each layer appears once, in order of first appearance
 * in `connections`.
 */
export function getLayersConnectedToLoss(
  connections: Connection[],
  nodes: ROMap<string, Node>
): Node[] {
  const feedingLayers = new Set<Node>();

  connections.forEach(({ inputNodeId, outputNodeId }) => {
    const consumer = nodes.get(inputNodeId);
    const producer = nodes.get(outputNodeId);
    if (!consumer || !producer) return;
    if (!isKindOfLossNode(consumer) || !isLayerTypeNode(producer)) return;
    feedingLayers.add(producer);
  });

  return Array.from(feedingLayers);
}

/** Arguments for `calcConnectionsToUpdate`. */
type calcConnectionsToUpdateProps = {
  /** Id of the node whose socket names are being reconciled. */
  nodeId: string;

  /** Connection list to partition. */
  connections: Connection[];
  /** The node's output socket names after the change; omit to leave its outgoing connections untouched. */
  currentOutputNames?: string[];
  /** The node's input socket names after the change; omit to leave its incoming connections untouched. */
  currentInputNames?: string[];
};

/** Partition of connections produced by `calcConnectionsToUpdate`. */
type calcConnectionsToUpdateResult = {
  /** Connections kept unchanged. */
  previous: Connection[];
  /** Copies of removed connections re-targeted to the node's single remaining socket name. */
  added: Connection[];
  /** Connections whose socket name no longer exists on the node (dropped or replaced). */
  removed: Connection[];
};

/**
 * Reconciles a node's connections with its (possibly renamed) socket names.
 *
 * Some connections touching `nodeId` are re-examined by
 * `calcInputOrOutputToUpdate`: those on a side whose current names were
 * supplied. Every other connection is kept unchanged in `previous`.
 * A connection qualifying on both sides (a self-loop) is treated as an input.
 *
 * @returns `previous` (kept), `added` (re-targeted replacements) and
 *   `removed` (dropped or replaced) connections.
 */
export function calcConnectionsToUpdate({
  connections,
  nodeId,
  currentInputNames,
  currentOutputNames,
}: calcConnectionsToUpdateProps): calcConnectionsToUpdateResult {
  const previous: Connection[] = [];
  const added: Connection[] = [];
  const removed: Connection[] = [];
  const incoming: Connection[] = [];
  const outgoing: Connection[] = [];

  for (const connection of connections) {
    if (currentInputNames && connection.inputNodeId === nodeId) {
      incoming.push(connection);
    } else if (currentOutputNames && connection.outputNodeId === nodeId) {
      outgoing.push(connection);
    } else {
      previous.push(connection);
    }
  }

  if (currentInputNames) {
    calcInputOrOutputToUpdate({
      connections: incoming,
      currentSocketNames: currentInputNames,
      socketType: 'input',
      previous,
      added,
      removed,
    });
  }

  if (currentOutputNames) {
    calcInputOrOutputToUpdate({
      connections: outgoing,
      currentSocketNames: currentOutputNames,
      socketType: 'output',
      previous,
      added,
      removed,
    });
  }

  return { previous, added, removed };
}

/**
 * Classifies one side (inputs or outputs) of a node's connections against the
 * node's current socket names. Results are appended to the shared `previous`,
 * `added` and `removed` arrays, which are mutated in place.
 *
 * - A connection whose socket name is still present is kept (`previous`).
 * - Otherwise, if every connection on this side used the same socket name and
 *   exactly one socket name remains, the socket was effectively renamed. The
 *   connection goes to `removed` and a copy pointing at the remaining name is
 *   pushed to `added`.
 * - Otherwise the connection is dropped (`removed`).
 *
 * @param connections Connections on this side of the node only.
 * @param currentSocketNames Socket names the node has after the change.
 * @param socketType Which endpoint of each connection belongs to the node.
 */
function calcInputOrOutputToUpdate({
  connections,
  currentSocketNames,
  socketType,
  added,
  removed,
  previous,
}: {
  connections: Connection[];
  currentSocketNames: string[];
  socketType: 'output' | 'input';
} & calcConnectionsToUpdateResult): void {
  const inputOrOutputName = socketType === 'output' ? 'outputName' : 'inputName';
  const remainingSocketNames = new Set(currentSocketNames);

  // Loop-invariant: a one-to-one rename is only unambiguous when this side used
  // a single socket name before and has a single socket name now.
  const isConnectionsHasSameSocketName =
    new Set(connections.map((c) => c[inputOrOutputName])).size <= 1;
  const replaceOneByOne =
    isConnectionsHasSameSocketName && currentSocketNames.length === 1;

  for (const conn of connections) {
    if (remainingSocketNames.has(conn[inputOrOutputName])) {
      previous.push(conn);
    } else if (replaceOneByOne) {
      removed.push(conn);
      added.push({ ...conn, [inputOrOutputName]: currentSocketNames[0] });
    } else {
      removed.push(conn);
    }
  }
}

/**
 * Id of the synthetic "Model" node that stands in for the collapsed (hidden)
 * layers of a grouped model graph.
 */
export const MODEL_LAYER_ID = '0';

/**
 * A layer through which data enters the model. `inputsKeys` holds the input
 * keys fed by at least one non-layer node; it is empty when the layer has no
 * input sockets at all.
 */
type FirstLayer = { node: Node; inputsKeys: Set<string> };
/**
 * Collapses a full model graph into a "grouped" graph in which the hidden
 * layers are replaced by a single synthetic Model node (`MODEL_LAYER_ID`).
 *
 * - Kept visible (deep-cloned): every non-layer node and every last layer (a
 *   layer with no layer consumers).
 * - Hidden: all other layer ids. They are recorded in a `Set` stored on the
 *   Model node as `data.hidden_layers`.
 * - Layers with a truthy `data.output_blocks` are hidden but otherwise not
 *   processed.
 * - Connections crossing the hidden/visible boundary are re-pointed at the
 *   Model node. The Model node is only created if something attaches to it.
 *
 * The resulting node map is passed through `reorganizeMap`.
 */
export function modelGraphToGroupedModelGraph(
  modelGraph: ModelGraph
): ModelGraph {
  const groupedModelGraph: ModelGraph = { id: '', nodes: {} };
  const firstLayers: FirstLayer[] = [];
  const lastLayers: Node[] = [];
  const hiddenLayersWithDecorators: Node[] = [];
  const hiddenLayers = new Set<string>();

  for (const node of Object.values(modelGraph.nodes)) {
    if (!isLayerTypeNode(node)) {
      // Non-layer nodes are copied through unchanged.
      groupedModelGraph.nodes[node.id] = structuredClone(node);
      continue;
    }

    hiddenLayers.add(node.id);
    if (node.data['output_blocks']) continue;

    handleFirstLayer(node, modelGraph, firstLayers);

    if (isLastLayer(node, modelGraph)) {
      // Last layers stay visible, so they are un-hidden again.
      lastLayers.push(node);
      groupedModelGraph.nodes[node.id] = structuredClone(node);
      hiddenLayers.delete(node.id);
    }

    if (isHiddenLayerWithConnectedDocorator(node, modelGraph)) {
      hiddenLayersWithDecorators.push(node);
    }
  }

  attachInputsLayers(firstLayers, groupedModelGraph);
  attachOutputLayers(lastLayers, groupedModelGraph, modelGraph);
  attachHiddenLayersWithDecorators(
    hiddenLayersWithDecorators,
    groupedModelGraph,
    modelGraph
  );

  const modelNode = groupedModelGraph.nodes[MODEL_LAYER_ID];
  if (modelNode) {
    modelNode.data['hidden_layers'] = hiddenLayers;
  }

  groupedModelGraph.nodes = reorganizeMap(groupedModelGraph.nodes);
  return groupedModelGraph;
}

/**
 * Attaches the externally-fed inputs of each first layer to the Model node.
 *
 * First layers already visible in the grouped graph (i.e. also last layers)
 * are skipped. For the others:
 * - A layer with no input sockets gets one fresh, unconnected Model input
 *   keyed `${layerId}-input${n}`, where n is the current number of Model
 *   inputs.
 * - Otherwise each key in `inputsKeys` is copied onto the Model node. The
 *   source nodes' output connections that targeted the layer are re-pointed
 *   at the Model node, with the input key unchanged.
 *
 * `data.inputsInfo` records, per Model input key, the originating layer id
 * and input key.
 */
function attachInputsLayers(
  firstLayers: FirstLayer[],
  groupedModelGraph: ModelGraph
): void {
  firstLayers.forEach(({ node: firstInputLayer, inputsKeys }) => {
    if (groupedModelGraph.nodes[firstInputLayer.id]) return;

    generateDefaultModelLayer(groupedModelGraph);
    const modelLayer = groupedModelGraph.nodes[MODEL_LAYER_ID];
    const firstInputLayerKeys = Object.keys(firstInputLayer.inputs);

    if (firstInputLayerKeys.length === 0) {
      const inputKey = `${firstInputLayer.id}-input${
        Object.keys(modelLayer.inputs).length
      }`;
      modelLayer.data['inputsInfo'][inputKey] = {
        nodeId: firstInputLayer.id,
        key: inputKey,
      };
      modelLayer.inputs[inputKey] = { connections: [] };
      return;
    }

    firstInputLayerKeys.forEach((firstInputLayerKey) => {
      if (!inputsKeys.has(firstInputLayerKey)) return;

      // firstInputLayer belongs to the source graph. Clone the socket so later
      // edits to the grouped graph cannot write back into that graph.
      const modelInput = structuredClone(
        firstInputLayer.inputs[firstInputLayerKey]
      );
      modelLayer.inputs[firstInputLayerKey] = modelInput;

      modelLayer.data['inputsInfo'][firstInputLayerKey] = {
        nodeId: firstInputLayer.id,
        key: firstInputLayerKey,
      };

      modelInput.connections.forEach(({ node, output }) => {
        // The source may be absent from the grouped graph (e.g. a hidden layer
        // sharing this socket); skip it instead of throwing.
        const sourceOutput = groupedModelGraph.nodes[node]?.outputs[output];
        if (!sourceOutput) return;

        sourceOutput.connections = sourceOutput.connections.map(
          ({ node: targetId, input }) => ({
            node: firstInputLayer.id === targetId ? MODEL_LAYER_ID : targetId,
            input,
          })
        );
      });
    });
  });
}

/**
 * Re-wires the inputs of each visible last layer in the grouped graph.
 *
 * Input connections coming from a non-layer node are kept as-is. A connection
 * coming from a layer (now hidden) is redirected to the Model node: the Model
 * node gains (or reuses) an output with the same output key, and that output
 * records the link to the last layer's input.
 */
function attachOutputLayers(
  lastLayers: Node[],
  groupedModelGraph: ModelGraph,
  originModelGraph: ModelGraph
): void {
  for (const outputLayer of lastLayers) {
    for (const [inputKey, { connections }] of Object.entries(
      outputLayer.inputs
    )) {
      const rewired: InputData[] = [];

      for (const { node, output } of connections) {
        if (!isLayerTypeNode(originModelGraph.nodes[node])) {
          rewired.push({ node, output });
          continue;
        }

        generateDefaultModelLayer(groupedModelGraph);
        const modelNode = groupedModelGraph.nodes[MODEL_LAYER_ID];
        if (!modelNode.outputs[output]) {
          modelNode.outputs[output] = { connections: [] };
        }
        modelNode.outputs[output].connections.push({
          node: outputLayer.id,
          input: inputKey,
        });

        rewired.push({ node: MODEL_LAYER_ID, output });
      }

      groupedModelGraph.nodes[outputLayer.id].inputs[inputKey] = {
        connections: rewired,
      };
    }
  }
}

/**
 * Re-wires non-layer consumers (decorators) of hidden layers to the Model node.
 *
 * For every output connection of each hidden layer that leads to a non-layer
 * node:
 * - The Model node gains (or reuses) an output with the same key, which
 *   records the consumer.
 * - The consumer's input socket in the grouped graph is rebuilt from the
 *   source graph, with every feeding connection now coming from the Model
 *   node (output key unchanged).
 */
function attachHiddenLayersWithDecorators(
  hiddenLayersWithDecorators: Node[],
  groupedModelGraph: ModelGraph,
  originModelGraph: ModelGraph
): void {
  for (const hiddenLayer of hiddenLayersWithDecorators) {
    for (const [outputKey, { connections }] of Object.entries(
      hiddenLayer.outputs
    )) {
      for (const { node: consumerId, input } of connections) {
        // Layer consumers stay hidden; only non-layer consumers are re-wired.
        if (isLayerTypeNode(originModelGraph.nodes[consumerId])) continue;

        generateDefaultModelLayer(groupedModelGraph);
        const modelNode = groupedModelGraph.nodes[MODEL_LAYER_ID];
        if (!modelNode.outputs[outputKey]) {
          modelNode.outputs[outputKey] = { connections: [] };
        }
        modelNode.outputs[outputKey].connections.push({
          node: consumerId,
          input: input,
        });

        const sourceFeeds =
          originModelGraph.nodes[consumerId].inputs[input].connections;
        groupedModelGraph.nodes[consumerId].inputs[input].connections =
          sourceFeeds.map(
            ({ output }): InputData => ({ node: MODEL_LAYER_ID, output })
          );
      }
    }
  }
}

/**
 * Ensures the grouped graph contains the synthetic Model node. When it is
 * missing, creates an empty one: type 'Layer', empty `inputsInfo`, no
 * sockets, position [0, 0]. An existing Model node is left untouched.
 */
function generateDefaultModelLayer(groupedModelGraph: ModelGraph): void {
  if (groupedModelGraph.nodes[MODEL_LAYER_ID]) return;

  groupedModelGraph.nodes[MODEL_LAYER_ID] = {
    id: MODEL_LAYER_ID,
    data: { type: 'Layer', inputsInfo: {} },
    name: 'Model',
    position: [0, 0],
    inputs: {},
    outputs: {},
  };
}

/**
 * Expands a grouped graph (see `modelGraphToGroupedModelGraph`) back into a
 * full model graph.
 *
 * All visible nodes except the Model node are deep-cloned into the result. If
 * there is no Model node, that is the whole result. Otherwise:
 * - Layers listed in the Model node's `data.hidden_layers` are restored from
 *   `lastFullModelGraph`.
 * - The Model node's input and output connections are re-pointed at the real
 *   layers.
 * - The node map is passed through `reorganizeMap`.
 *
 * Restored layers are deep-cloned. The restore helpers below rewrite their
 * `inputs`/`outputs` maps, and a shallow copy would share those maps with
 * (and thus mutate) `lastFullModelGraph`.
 */
export function groupedModelGraphToModelGraph(
  groupedModelGraph: ModelGraph,
  lastFullModelGraph: ModelGraph
): ModelGraph {
  const fullModelGraph: ModelGraph = { id: '', nodes: {} };
  Object.values(groupedModelGraph.nodes).forEach((node) => {
    if (node.id !== MODEL_LAYER_ID)
      fullModelGraph.nodes[node.id] = structuredClone(node);
  });

  const modelLayer = groupedModelGraph.nodes[MODEL_LAYER_ID];
  if (!modelLayer) return fullModelGraph;

  const hiddenLayers: Set<string> = modelLayer.data['hidden_layers'];

  Object.values(lastFullModelGraph.nodes).forEach((node) => {
    if (
      !fullModelGraph.nodes[node.id] &&
      isLayerTypeNode(node) &&
      hiddenLayers.has(node.id)
    ) {
      fullModelGraph.nodes[node.id] = structuredClone(node);
    }
  });

  const { inputsInfo } = modelLayer.data;

  restoreInputsConnections(modelLayer, inputsInfo, fullModelGraph);

  restoreOutputsConnections(modelLayer, fullModelGraph, hiddenLayers);

  fullModelGraph.nodes = reorganizeMap(fullModelGraph.nodes);

  return fullModelGraph;
}

/**
 * Re-points the Model node's outputs back to the hidden layers that produce
 * them.
 *
 * For each Model output key, the originating layer id is the part before the
 * first '-'. That layer's output is rebuilt from:
 * - its existing connections to other hidden layers (never visible in the
 *   grouped graph), plus
 * - the Model output's connections to nodes present in `fullModelGraph`.
 *
 * Each such consumer input whose first connection came from the Model node is
 * replaced with a single connection from the originating layer.
 *
 * Consumer inputs that are missing or have no connections are left alone
 * rather than throwing.
 */
function restoreOutputsConnections(
  modelLayer: Node,
  fullModelGraph: ModelGraph,
  hiddenLayers: Set<string>
): void {
  Object.keys(modelLayer.outputs).forEach((modelOutputKey) => {
    const modelLayerOrgOutput = modelLayer.outputs[modelOutputKey];

    // NOTE(review): assumes node ids contain no '-', since output keys are
    // built as `${nodeId}-${outputName}` — confirm for the id scheme in use.
    const [originNodeId] = modelOutputKey.split('-');
    const fullModelGraphOrgOutput =
      fullModelGraph.nodes[originNodeId]?.outputs?.[modelOutputKey];

    const orgLayerOutConnections: OutputData[] = [];
    fullModelGraphOrgOutput?.connections?.forEach(({ node, input }) => {
      if (hiddenLayers.has(node)) {
        orgLayerOutConnections.push({ input, node });
      }
    });

    modelLayerOrgOutput?.connections?.forEach(({ input, node }) => {
      const outputNode = fullModelGraph.nodes[node];
      if (!outputNode) {
        return;
      }

      orgLayerOutConnections.push({ input, node });
      const [inputConnection] = outputNode.inputs[input]?.connections ?? [];
      if (inputConnection?.node === MODEL_LAYER_ID) {
        outputNode.inputs[input] = {
          connections: [{ node: originNodeId, output: modelOutputKey }],
        };
      }
    });

    if (fullModelGraph.nodes[originNodeId]?.outputs?.[modelOutputKey]) {
      fullModelGraph.nodes[originNodeId].outputs[modelOutputKey] = {
        connections: orgLayerOutConnections,
      };
    }
  });
}

/**
 * Re-wires the Model node's inputs onto the hidden first layers they stand for.
 *
 * For each Model input key, `inputsInfo` names the layer (`nodeId`) and that
 * layer's own input key (`key`). The socket is copied (cloned) onto the layer.
 * Each source node's output connections that target the Model node on *this*
 * input key are then re-pointed to `nodeId`/`key`.
 *
 * Matching on the input key matters. One source output can feed several Model
 * inputs, and rewriting every Model-bound connection at once would send all of
 * them to whichever layer happened to be processed first.
 *
 * Sources missing from `fullModelGraph` are skipped.
 */
function restoreInputsConnections(
  modelLayer: Node,
  inputsInfo: Record<string, { nodeId: string; key: string }>,
  fullModelGraph: ModelGraph
): void {
  Object.keys(modelLayer.inputs).forEach((modelInputKey) => {
    const { nodeId, key } = inputsInfo[modelInputKey];
    const modelInput = modelLayer.inputs[modelInputKey];

    const targetLayer = fullModelGraph.nodes[nodeId];
    if (targetLayer?.inputs) {
      // Clone so the returned graph does not share socket objects with the
      // grouped graph.
      targetLayer.inputs[key] = structuredClone(modelInput);
    }

    modelInput.connections.forEach(({ node, output }) => {
      const sourceOutput = fullModelGraph.nodes[node]?.outputs[output];
      if (!sourceOutput) return;

      sourceOutput.connections = sourceOutput.connections.map(
        ({ input: inputLayerKey, node: targetNodeId }) =>
          isModelGraphLayer(targetNodeId) && inputLayerKey === modelInputKey
            ? { node: nodeId, input: key }
            : { node: targetNodeId, input: inputLayerKey }
      );
    });
  });
}

/**
 * Records `node` in `firstLayers` if data enters the model through it. That is
 * the case when it has no input sockets at all (recorded with an empty key
 * set), or when at least one input socket is fed by a non-layer node (only
 * those socket keys are recorded).
 */
function handleFirstLayer(
  node: Node,
  modelGraph: ModelGraph,
  firstLayers: FirstLayer[]
): void {
  const inputKeys = Object.keys(node.inputs);
  if (inputKeys.length === 0) {
    firstLayers.push({ node, inputsKeys: new Set() });
    return;
  }

  const externallyFedKeys = new Set<string>();
  for (const inputKey of inputKeys) {
    for (const { node: sourceId } of node.inputs[inputKey].connections) {
      if (!isLayerTypeNode(modelGraph.nodes[sourceId])) {
        externallyFedKeys.add(inputKey);
      }
    }
  }

  if (externallyFedKeys.size > 0) {
    firstLayers.push({ node, inputsKeys: externallyFedKeys });
  }
}

/**
 * True when `node` is a layer: its `data.type` is 'Layer' or 'CustomLayer', or
 * it is named 'Representation Block'.
 *
 * Accepts `undefined` and returns false for it. Callers pass raw
 * `graph.nodes[id]` lookups, and a dangling id would otherwise throw a
 * TypeError here.
 */
function isLayerTypeNode(node: Node | undefined): boolean {
  if (!node) return false;
  const type = node.data?.type;
  return (
    type === 'Layer' ||
    type === 'CustomLayer' ||
    node.name === 'Representation Block'
  );
}

/**
 * True when no output connection of `node` leads to a layer-type node. This
 * includes the case where it has no outputs at all, so the layer is a last
 * layer of the model.
 */
function isLastLayer(node: Node, modelGraph: ModelGraph): boolean {
  for (const outputKey of Object.keys(node.outputs)) {
    for (const { node: consumerId } of node.outputs[outputKey].connections) {
      if (isLayerTypeNode(modelGraph.nodes[consumerId])) return false;
    }
  }
  return true;
}

/**
 * True when `node` is not a last layer but at least one of its output
 * connections leads to an existing non-layer node (treated here as a
 * decorator). Connections to ids absent from `modelGraph` are ignored.
 */
function isHiddenLayerWithConnectedDocorator(
  node: Node,
  modelGraph: ModelGraph
): boolean {
  if (isLastLayer(node, modelGraph)) return false;

  for (const outputKey of Object.keys(node.outputs)) {
    for (const { node: consumerId } of node.outputs[outputKey].connections) {
      const consumer = modelGraph.nodes[consumerId];
      if (consumer && !isLayerTypeNode(consumer)) return true;
    }
  }
  return false;
}

/** True when `nodeId` is the id of the synthetic Model node (`MODEL_LAYER_ID`). */
export function isModelGraphLayer(nodeId: string): boolean {
  const modelId = MODEL_LAYER_ID;
  return modelId === nodeId;
}
