import {
  createContext,
  FC,
  PropsWithChildren,
  useContext,
  useEffect,
  useMemo,
  useState,
} from 'react';
import {
  DataStateType,
  ScatterViz,
  ScatterVizDataState,
  SampleIdentity,
  VisualizationResponse,
  MutualInformationElement,
  NumberOrString,
} from '@tensorleap/api-client';

import { useMergedObject } from '../core/useMergedObject';
import {
  extractScatterSettingOptions,
  ScatterSettingValues,
  settingValuesWithDefault,
  useScattersAssets,
} from './dashlet/PopulationExploration/hooks';
import { useDashletScatterContext } from './dashlet/PopulationExploration/DashletScatterContext';
import { UseMultiSelect, useMultiSelect } from '../core/useMultiSelect';
import { Setter } from '../core/types';
import { SortTypeEnum } from '../ui/charts/legend/LabelsLegendMenu';
import { useToggle } from '../core/useToggle';
import { DEFAULT_TRUNCATE_LONG_TAIL } from '../ui/charts/legend/LabelsLegend';

/** File name under which a visualization's payload is stored. */
export const VISUALIZATION_PAYLOAD_FILE_NAME = 'payload.json';

/**
 * Pair of axes to project onto.
 * String-valued, so each member's runtime value equals its name.
 */
export enum Plane {
  /** The X/Y plane. */
  XY = 'XY',
  /** The X/Z plane. */
  XZ = 'XZ',
  /** The Y/Z plane. */
  YZ = 'YZ',
}

/**
 * Legend entry currently under the pointer, stored as `legendHovered`.
 *
 * NOTE(review): presumably `key` names the legend being hovered (e.g. the
 * dot-color or size/shape setting) and `value` the hovered bucket value(s);
 * confirm against the legend components.
 */
export type HoveredLegendFilter = {
  key: string;
  value: NumberOrString | Array<NumberOrString>;
};

/**
 * Shape of the value exposed through `useScatterData()`.
 *
 * `ScatterDataProvider` fills it from the scatter payload plus its own UI
 * state. Outside any provider, consumers receive `contextDefaults` instead.
 */
export interface ScatterDataContextProps {
  // ---- Source data & identity -------------------------------------------

  /** The `scatter_data` section of the visualization payload. */
  scatterData: ScatterVizDataState;
  /** Epoch passed to the provider. */
  epoch: number;
  /** Session run id passed to the provider. */
  sessionRunId: string;
  /** The payload's `guid`. */
  visualizationUUID: string;
  /**
   * Display title. The provider uses the visualization's analyze type with
   * underscores turned into spaces, falling back to the default title.
   */
  title: string;
  /** Optional sample identity; `ScatterDataProvider` does not populate it. */
  sample?: SampleIdentity;
  /** Mutual-information elements from the payload's `mi_by_cluster`, if any. */
  miByCluster?: Record<string, Record<string, MutualInformationElement[]>>;
  /** Nested string map from the payload's `clusters_blob_path`, if any. */
  clusterBlobPaths?: Record<string, Record<string, string>>;

  // ---- Sample-visualization assets --------------------------------------

  /**
   * Sample ids reported by `useScattersAssets`. Presumably these are the
   * samples that have stored visualization assets; confirm.
   */
  samplesIdsWithAssets: Set<string>;
  /** Asset path prefix reported by `useScattersAssets`. */
  scatterSampleVisualizationsPrefix: string;
  /** Visualization displays reported by `useScattersAssets`. */
  visualizationDisplays: VisualizationDisplay[];

  // ---- Settings ----------------------------------------------------------

  /** Stored setting values merged with defaults from `settingsOptions`. */
  settings: ScatterSettingValues;
  /** Setting choices extracted from the payload and visualization displays. */
  settingsOptions: ReturnType<typeof extractScatterSettingOptions>;

  // ---- Interaction: mode & selection ------------------------------------

  /** Multi-selection over numeric ids (presumably sample indices; confirm). */
  selection: UseMultiSelect<number>;
  /** Effective mode: the pressed mode when set, otherwise the selected one. */
  scatterMode: ScatterMode;
  /** Sets the persistent (selected) mode. */
  setScatterMode: Setter<ScatterMode>;
  /** Sets or clears (with `undefined`) the overriding pressed mode. */
  setPressedScatterMode: Setter<ScatterMode | undefined>;

  // ---- Legend -------------------------------------------------------------

  showLegendNames: boolean;
  toggleShowLegendNames: () => void;
  legendTruncatedLongtail: number;
  setLegendTruncatedLongtail: Setter<number>;
  sizeOrShapeOrderMethod: SortTypeEnum;
  setSizeOrShapeOrderMethod: Setter<SortTypeEnum>;
  legendHovered?: HoveredLegendFilter;
  setLegendHovered: Setter<HoveredLegendFilter | undefined>;
}

/** Shared do-nothing callback used for every setter in the fallback value. */
const noopSetter = (): void => {};

/** Empty test-state scatter payload used when no provider is mounted. */
const fallbackScatterData: ScatterVizDataState = {
  data_state: DataStateType.Test,
  scatter_data: [],
  samples: [],
  metadata: {},
};

/** Blank setting values (nothing chosen, no preview). */
const fallbackSettings: ScatterSettingValues = {
  sizeOrShape: '',
  dotColor: '',
  previewBy: undefined,
};

/** No available choices for any setting. */
const fallbackSettingsOptions: ScatterDataContextProps['settingsOptions'] = {
  sizeOrShape: [],
  dotColor: [],
  previewBy: [],
};

/**
 * Value seen by `useScatterData()` consumers rendered outside a
 * `ScatterDataProvider`. Its `title` is also the provider's fallback title.
 */
export const contextDefaults: ScatterDataContextProps = {
  scatterData: fallbackScatterData,
  title: 'Sample Selection',
  epoch: 0,
  sessionRunId: '',
  visualizationUUID: '',
  samplesIdsWithAssets: new Set(),
  scatterSampleVisualizationsPrefix: '',
  visualizationDisplays: [],
  settings: fallbackSettings,
  settingsOptions: fallbackSettingsOptions,
  // NOTE(review): `null` at runtime despite the non-nullable type. Any
  // consumer touching `selection` outside a provider will throw.
  selection: (null as unknown) as UseMultiSelect<number>,
  scatterMode: 'grab',
  setScatterMode: noopSetter,
  setPressedScatterMode: noopSetter,
  showLegendNames: false,
  toggleShowLegendNames: noopSetter,
  legendTruncatedLongtail: DEFAULT_TRUNCATE_LONG_TAIL,
  setLegendTruncatedLongtail: noopSetter,
  sizeOrShapeOrderMethod: SortTypeEnum.ASC_ALPHABETICALLY,
  setSizeOrShapeOrderMethod: noopSetter,
  setLegendHovered: noopSetter,
};

/** React context carrying `ScatterDataContextProps`; defaults to `contextDefaults`. */
const ScatterDataContext = createContext<ScatterDataContextProps>(contextDefaults);

/**
 * Interaction modes of the scatter plot.
 * `ScatterDataProvider` starts in `'grab'`.
 */
export type ScatterMode =
  | 'grab'
  | 'box-selection'
  | 'magic-selection';

/** A visualization display identified by its type and name. */
type VisualizationDisplay = {
  visType: string;
  visName: string;
};

/** Props of `ScatterDataProvider`; `children` receive the context. */
type ScatterDataProviderProps = PropsWithChildren<{
  /** Project used when looking up scatter assets. */
  projectId: string;
  /** Epoch used for asset lookup and exposed on the context. */
  epoch: number;
  /** Session run id used for asset lookup and exposed on the context. */
  sessionRunId: string;
  /** Response whose first payload entry is read as the `ScatterViz`. */
  scatterVisualization: VisualizationResponse;
}>;

/**
 * Supplies scatter-plot data, settings and interaction state to descendants
 * through `ScatterDataContext`; read it with `useScatterData()`.
 *
 * - Payload: the first entry of `scatterVisualization.data.payload`, treated
 *   as a `ScatterViz`. Its `guid` becomes `visualizationUUID`. It is also the
 *   key under which this scatter's setting options are registered with the
 *   dashlet scatter context; they are unregistered on change or unmount.
 * - Assets: `useScattersAssets` is queried with `projectId`, `sessionRunId`
 *   and `epoch`.
 * - Settings: stored values from the dashlet context are merged with
 *   defaults derived from the payload's setting options.
 * - Local UI state: scatter mode, multi-selection and legend options.
 */
export const ScatterDataProvider: FC<ScatterDataProviderProps> = ({
  children,
  projectId,
  epoch,
  sessionRunId,
  scatterVisualization,
}): JSX.Element => {
  const {
    scatterSampleVisualizationsPrefix,
    visualizationDisplays,
    samplesIdsWithAssets,
  } = useScattersAssets({ projectId, sessionRunId, epoch });

  // NOTE(review): assumes the payload array is non-empty and its first entry
  // is a ScatterViz; an empty array makes the `payload.*` reads below throw.
  const payload = useMemo(
    () => scatterVisualization.data.payload[0] as ScatterViz,
    [scatterVisualization.data.payload]
  );

  const { settingsValues, register, unregister } = useDashletScatterContext();
  const scatterData = payload.scatter_data;

  // The pressed mode, while set, overrides the user-selected mode.
  const [selectedScatterMode, setScatterMode] = useState<ScatterMode>('grab');
  const [pressedScatterMode, setPressedScatterMode] = useState<ScatterMode>();
  const scatterMode = pressedScatterMode ?? selectedScatterMode;
  const selection = useMultiSelect<number>();

  const settingsOptions = useMemo(
    () => extractScatterSettingOptions(payload, visualizationDisplays ?? []),
    [payload, visualizationDisplays]
  );

  const settings = useMemo(
    () => settingValuesWithDefault(settingsValues, settingsOptions),
    [settingsValues, settingsOptions]
  );

  const [showLegendNames, toggleShowLegendNames] = useToggle(true);
  const [legendHovered, setLegendHovered] = useState<HoveredLegendFilter>();
  const [
    legendTruncatedLongtail,
    setLegendTruncatedLongtail,
  ] = useState<number>(DEFAULT_TRUNCATE_LONG_TAIL);
  const [sizeOrShapeOrderMethod, setSizeOrShapeOrderMethod] = useState(
    SortTypeEnum.ASC_ALPHABETICALLY
  );

  // Publish this scatter's options to the dashlet, keyed by payload guid.
  useEffect(() => {
    register(payload.guid, settingsOptions);
    return () => unregister(payload.guid);
  }, [settingsOptions, payload.guid, register, unregister]);

  /**
   * NOTE:
   * The optional chaining and fallback exist only because the storybook mocks
   * are incomplete; once they are updated to real `node-server` responses
   * this can be removed.
   */
  const title =
    scatterVisualization.info?.analyze_type?.replace(/_/g, ' ') ||
    contextDefaults.title;

  const value = useMergedObject({
    visualizationUUID: payload.guid,
    // `scatterData` is a required field, so both reads use plain access.
    clusterBlobPaths: scatterData.clusters_blob_path,
    miByCluster: scatterData.mi_by_cluster,
    scatterData,
    epoch,
    sessionRunId,
    settings,
    settingsOptions,
    title,
    samplesIdsWithAssets,
    scatterSampleVisualizationsPrefix,
    visualizationDisplays,
    selection,
    scatterMode,
    setScatterMode,
    setPressedScatterMode,
    showLegendNames,
    toggleShowLegendNames,
    legendTruncatedLongtail,
    setLegendTruncatedLongtail,
    sizeOrShapeOrderMethod,
    setSizeOrShapeOrderMethod,
    legendHovered,
    setLegendHovered,
  });

  return (
    <ScatterDataContext.Provider value={value}>
      {children}
    </ScatterDataContext.Provider>
  );
};
// Label shown for this context in React DevTools.
ScatterDataContext.displayName = 'ScatterDataContext';

/**
 * Returns the nearest `ScatterDataContext` value.
 *
 * When no `ScatterDataProvider` is above the caller, this is
 * `contextDefaults`, whose `selection` is `null` at runtime.
 */
export function useScatterData(): ScatterDataContextProps {
  const scatterContext = useContext(ScatterDataContext);
  return scatterContext;
}
