import { useEffect, useMemo, useState } from 'react';
import { first, max } from 'lodash';
import { useVersionControl } from '../../../core/VersionControlContext';
import {
  JobStatus,
  JobType,
  ScatterViz,
  ScatterVizDataState,
  VisualizationResponse,
} from '@tensorleap/api-client';

import { mapToEsFilters } from '../../../model-tests/modelTestHelpers';
import { VisualizationFilter } from '../../../core/types/filters';
import { useMergedObject } from '../../../core/useMergedObject';
import { useInsightsContext } from '../../../insights/InsightsContext';
import {
  useTriggerAndCheckPopulationExplorationStatus,
  DataStatus,
} from './useTriggerAndCheckPopulationExplorationStatus';
import { useFetchJson } from '../../../core/data-fetching/fetch-json';
import { LoadingStatus } from '../../../core/data-fetching/loading-status';

/** Passed to `useFetchJson` as `refreshIntervalAfterSuccessMS` (1 minute). */
const REFRESH_INTERVAL_MS = 60 * 1000;
/** Passed to `useFetchJson` as `maxIsLoadingTimeoutMs` (15 minutes). */
const MAX_IS_LOADING_TIMEOUT_MS = 15 * 60 * 1000;

/**
 * Epoch-selection state for the population-exploration view, combined with
 * the fetched-visualization fields of `UseFetchVisData`.
 *
 * NOTE(review): `usePopulationExploration` also returns `numOfSamples` and
 * `setNumOfSamples`, which this type does not declare. Confirm whether
 * consumers typed against it should see them.
 */
export type UsePopulationExploration = {
  /** Epochs reported for the session run. */
  epochs: number[];
  /** Epoch the visualization is actually fetched for. */
  selectedEpoch: number;
  /** Requests an epoch; one not present in `epochs` falls back to the latest. */
  selectEpoch: (epoch: number) => void;
} & UseFetchVisData;

/** Input for `usePopulationExploration`. */
type UsePopulationExplorationProps = {
  projectId: string;
  sessionRunId: string;
  /** Forwarded to the visualization fetch; part of the insights registration key. */
  dashletId: string;
  /** Converted with `mapToEsFilters` into the population request parameters. */
  filters: VisualizationFilter[];
  /** Optional; forwarded unchanged into the population request parameters. */
  projectionMetric?: string;
};

/** Initial `numOfSamples` requested by `usePopulationExploration`. */
const DEFAULT_NUM_OF_SAMPLES = 2_000;

/**
 * State and data for the population-exploration dashlet.
 *
 * Resolves which epoch to show, holds the requested sample count, and fetches
 * the matching visualization through `useFetchVisData`.
 *
 * Until the caller picks an epoch via `selectEpoch`, or when the picked epoch
 * is not in the session run's epoch list, the latest epoch is used. With no
 * epochs at all it falls back to 0.
 *
 * @returns The available epochs, the resolved epoch and its setter, the
 * sample count and its setter, and the `UseFetchVisData` fields.
 */
export function usePopulationExploration({
  projectId,
  sessionRunId,
  dashletId,
  filters,
  projectionMetric,
}: UsePopulationExplorationProps) {
  const { getSessionRunEpochs } = useVersionControl();
  const epochs = useMemo(() => getSessionRunEpochs(sessionRunId), [
    sessionRunId,
    getSessionRunEpochs,
  ]);

  // `undefined` means "no explicit choice yet". Using 0 as that sentinel
  // collided with a real epoch 0: a run whose epochs include 0 would start on
  // epoch 0 instead of the latest epoch.
  const [selectedEpoch, setSelectedEpoch] = useState<number | undefined>(
    undefined
  );
  const selected = useMemo(() => {
    if (selectedEpoch !== undefined && epochs.includes(selectedEpoch)) {
      return selectedEpoch;
    }
    // lodash `max` returns undefined for an empty array.
    return max(epochs) ?? 0;
  }, [selectedEpoch, epochs]);

  const [numOfSamples, setNumOfSamples] = useState<number>(
    DEFAULT_NUM_OF_SAMPLES
  );

  const visData = useFetchVisData({
    projectId,
    sessionRunId,
    epoch: selected,
    numOfSamples,
    dashletId,
    filters,
    projectionMetric,
  });

  return {
    epochs,
    selectedEpoch: selected,
    selectEpoch: setSelectedEpoch,
    numOfSamples,
    setNumOfSamples,
    ...visData,
  };
}

/** Result of `useFetchVisData`. */
type UseFetchVisData = {
  /**
   * Scatter visualization merged from the scatter and scatter-cluster
   * artifacts; undefined until the scatter JSON has loaded.
   */
  fullVisualization?: VisualizationResponse;
  /** Loading status of the displayed scatter artifacts. */
  loadingStatus: LoadingStatus;
  /** `baseDigest` of the status entry the displayed scatter comes from. */
  lastReadyDigest?: string;
};

/**
 * Inputs of `useFetchVisData`.
 *
 * `sessionRunId`, `epoch`, `numOfSamples`, `projectionMetric` and `filters`
 * go into the population request parameters. `projectId` is passed to the
 * status hook. `dashletId` is used only in the insights registration key.
 */
type useFetchVisDataParams = {
  projectId: string;
  sessionRunId: string;
  /** Sent as `fromEpoch` in the population request parameters. */
  epoch: number;
  dashletId: string;
  numOfSamples: number;
  filters: VisualizationFilter[];
  projectionMetric?: string;
};

/**
 * Fetches the population-exploration visualization for one session run and
 * epoch.
 *
 * Steps:
 * - Builds the population request parameters.
 * - Classifies whether the epoch's data may still be changing (`dataStatus`).
 * - Hands both to `useTriggerAndCheckPopulationExplorationStatus`.
 * - From the returned status entries, picks one entry to display the scatter
 *   from and one to read the insights from.
 * - Registers those insights with the insights context.
 * - Loads and merges the scatter JSON artifacts.
 *
 * @returns The merged scatter visualization (once loaded), the scatter
 * loading status, and the `baseDigest` of the entry the scatter comes from.
 */
function useFetchVisData({
  projectId,
  sessionRunId,
  epoch,
  dashletId,
  numOfSamples,
  filters,
  projectionMetric,
}: useFetchVisDataParams): UseFetchVisData {
  const { getSessionRunEpochs, selectedSessionRunMap } = useVersionControl();

  // Parameters for the population exploration request. Memoized so the same
  // object is reused while the inputs are unchanged; `filters` is compared by
  // reference.
  const populationParams = useMemo(
    () => ({
      sessionRunId,
      fromEpoch: epoch,
      batchSize: 1,
      numOfSamples,
      projectionMetric,
      filters: mapToEsFilters(filters),
    }),
    [sessionRunId, epoch, numOfSamples, projectionMetric, filters]
  );

  // Classifies the requested epoch:
  //   NoData   - the session run has no epochs yet;
  //   Ready    - a later epoch already exists, the run has no jobs, or no
  //              training job is in `Started` status;
  //   Updating - `epoch` is not older than the latest epoch and a training
  //              job is in `Started` status.
  const dataStatus = useMemo(() => {
    const epochs = getSessionRunEpochs(sessionRunId);
    if (epochs.length === 0) return DataStatus.NoData;
    if (epoch < (max(epochs) || 0)) return DataStatus.Ready;
    // `epoch` is the latest known epoch (or newer) here; report Updating only
    // while a training job is in `Started` status.
    const jobs = selectedSessionRunMap.get(sessionRunId)?.jobs;
    if (!jobs?.length) return DataStatus.Ready;

    const isRunning = jobs.some(
      ({ status, type }) =>
        status === JobStatus.Started && type === JobType.Training
    );
    return isRunning ? DataStatus.Updating : DataStatus.Ready;
  }, [selectedSessionRunMap, sessionRunId, epoch, getSessionRunEpochs]);

  const { registerInsights, unregisterInsights } = useInsightsContext();

  // Status entries for `populationParams`. The memo below treats the first
  // entry as the latest one. Presumably the hook returns entries
  // newest-first; TODO confirm.
  const statusQueue = useTriggerAndCheckPopulationExplorationStatus({
    dataStatus,
    populationParams,
    projectId,
  });

  const {
    insightState,
    scatterState,
    scatterLoadingStatus,
    insightLoadingStatus,
  } = useMemo(() => {
    // First entry whose job result already has an `insights` artifact.
    const insightState = statusQueue.find(
      ({ jobResult: status }) => status?.readyArtifacts.insights
    );
    const lastState = first(statusQueue);

    // Prefer an entry carrying the latest entry's digest, but only if both its
    // scatter and scatterClusters artifacts are ready.
    let scatterState = statusQueue.find(({ jobResult: status, digest }) =>
      digest === lastState?.digest
        ? status?.readyArtifacts.scatter &&
          status.readyArtifacts.scatterClusters
        : false
    );
    // Otherwise fall back to the first entry with at least a plain scatter
    // artifact (its clusters may not be ready).
    const ifNotFoundArtifactsWithCluster = !scatterState;
    if (ifNotFoundArtifactsWithCluster) {
      scatterState = statusQueue.find(
        ({ jobResult: status }) => status?.readyArtifacts.scatter
      );
    }

    // Scatter status, in calcLoadingStatus priority order:
    //   'ready'      - the shown entry is the latest one and has clusters;
    //   'loading'    - no entry has a scatter yet;
    //   'updating'   - the shown entry shares the latest entry's baseDigest;
    //   'refreshing' - otherwise.
    const isScatterReady =
      lastState &&
      lastState.digest === scatterState?.digest &&
      scatterState.jobResult?.readyArtifacts.scatterClusters;
    const isScatterLoading = !scatterState;
    const isScatterUpdating =
      lastState?.baseDigest === scatterState?.baseDigest;

    const scatterLoadingStatus = calcLoadingStatus(
      !!isScatterReady,
      isScatterLoading,
      isScatterUpdating
    );

    // Insight status, relative to the shown scatter entry:
    //   'ready'      - the insight entry has the same digest;
    //   'loading'    - no entry has insights yet;
    //   'updating'   - the insight entry shares the scatter entry's baseDigest;
    //   'refreshing' - otherwise.
    const isInsightReady =
      scatterState && scatterState.digest === insightState?.digest;
    const isInsightLoading = !insightState;
    const isInsightUpdating =
      scatterState?.baseDigest === insightState?.baseDigest;

    const insightLoadingStatus = calcLoadingStatus(
      !!isInsightReady,
      isInsightLoading,
      isInsightUpdating
    );

    return {
      insightState,
      scatterState,
      scatterLoadingStatus,
      insightLoadingStatus,
    } as const;
  }, [statusQueue]);

  // The insights URL comes from the insight entry, while the CSV (`analysis`)
  // URL comes from the scatter entry.
  // NOTE(review): these may belong to different entries; confirm intended.
  const insightBlobUrl = insightState?.jobResult?.readyArtifacts.insights;
  const csvBlobUrl = scatterState?.jobResult?.readyArtifacts.analysis;
  // Registration key digest: the insight entry's baseDigest, falling back to
  // the scatter entry's.
  const insightBaseDigest =
    insightState?.baseDigest ?? scatterState?.baseDigest;

  // Register this dashlet's insights with the insights context. Nothing is
  // registered until a baseDigest is known. The cleanup unregisters the same
  // key whenever a dependency changes or the component unmounts.
  useEffect(() => {
    if (!insightBaseDigest) return;
    const registerKey = {
      sessionRunId,
      epoch,
      digest: insightBaseDigest,
      dashletId,
    };
    registerInsights(
      registerKey,
      insightLoadingStatus,
      insightBlobUrl,
      csvBlobUrl
    );
    return () => unregisterInsights(registerKey);
  }, [
    insightBlobUrl,
    insightBaseDigest,
    insightLoadingStatus,
    csvBlobUrl,
    dashletId,
    sessionRunId,
    epoch,
    registerInsights,
    unregisterInsights,
  ]);

  // Load the scatter JSON (and the scatter-cluster JSON, when that artifact is
  // ready) of the chosen scatter entry.
  const fullVisualization = useLoadAndCalcScatter(
    scatterState?.jobResult?.readyArtifacts.scatter,
    scatterState?.jobResult?.readyArtifacts.scatterClusters
  );

  return useMergedObject({
    fullVisualization,
    loadingStatus: scatterLoadingStatus,
    lastReadyDigest: scatterState?.baseDigest,
  });
}

/**
 * Loads scatter.json and scatter_cluster.json and merges them into one
 * visualization.
 *
 * scatter.json is the lighter artifact and can render before the cluster
 * artifact arrives. The merge rules:
 *  - `scatter_data`: fields that exist only in the cluster response are kept;
 *    every field that scatter.json has takes the scatter.json value.
 *  - `scatter_data.metadata`, the payload `guid` and the response `info`: taken
 *    from the cluster response once loaded, otherwise from scatter.json.
 *  - Everything else on the response and payload comes from scatter.json.
 *
 * @param scatterBlobUrl URL of scatter.json.
 * @param scatterClusterBlobUrl URL of scatter_cluster.json, when available.
 * @returns The merged visualization, or `undefined` while scatter.json has not
 * loaded.
 */
function useLoadAndCalcScatter(
  scatterBlobUrl?: string,
  scatterClusterBlobUrl?: string
) {
  const fetchTimings = {
    refreshIntervalAfterSuccessMS: REFRESH_INTERVAL_MS,
    maxIsLoadingTimeoutMs: MAX_IS_LOADING_TIMEOUT_MS,
  };

  const { data: baseResponse } = useFetchJson<VisualizationResponse>({
    url: scatterBlobUrl,
    ...fetchTimings,
  });
  const { data: clusterResponse } = useFetchJson<VisualizationResponse>({
    url: scatterClusterBlobUrl,
    ...fetchTimings,
  });

  return useMemo((): VisualizationResponse | undefined => {
    if (!baseResponse) return undefined;

    // Identity/filter fields (metadata, guid, info) come from here.
    const preferred = clusterResponse || baseResponse;

    const clusterPayload = clusterResponse?.data.payload[0] as
      | ScatterViz
      | undefined;
    const basePayload = baseResponse.data.payload[0] as ScatterViz;
    const preferredPayload = preferred.data.payload[0] as ScatterViz;

    const mergedScatterData: ScatterVizDataState = {
      ...clusterPayload?.scatter_data,
      ...basePayload.scatter_data,
      metadata: preferredPayload.scatter_data.metadata,
    };
    const mergedPayload: ScatterViz = {
      ...basePayload,
      guid: preferredPayload.guid,
      scatter_data: mergedScatterData,
    };

    return {
      ...baseResponse,
      info: preferred.info,
      data: { ...baseResponse.data, payload: [mergedPayload] },
    };
  }, [baseResponse, clusterResponse]);
}

/**
 * Collapses readiness flags into a loading-status label.
 *
 * Flags are checked in priority order: ready, then loading, then updating.
 * When none is set, the status is 'refreshing'.
 */
function calcLoadingStatus(
  isReady: boolean,
  isLoading: boolean,
  isUpdating: boolean
): 'ready' | 'loading' | 'updating' | 'refreshing' {
  if (isReady) return 'ready';
  if (isLoading) return 'loading';
  if (isUpdating) return 'updating';
  return 'refreshing';
}
