import { FormEvent, useCallback, useEffect, useMemo } from 'react';
import { useController, useForm } from 'react-hook-form';
import { PopulationExplorationParams } from '@tensorleap/api-client';
import { createStyles, makeStyles } from '../../../ui/mui';
import { useCurrentProject } from '../../../core/CurrentProjectContext';
import { PaneActions } from './PaneActions';
import { useEpochRange, useSelectedSessionRuns } from './useSelectedModels';
import { Input } from '../../../ui/atoms/Input';
import { Select } from '../../../ui/atoms/Select';
import api from '../../../core/api-client';
import clsx from 'clsx';
import { useVersionControl } from '../../../core/VersionControlContext';

/**
 * Styles for the population-exploration form.
 * `childMargin` puts a 10px bottom margin under every direct child of the
 * element it is applied to.
 */
const useStyles = makeStyles(() => {
  const childMargin = { '& > *': { marginBottom: 10 } };
  return createStyles({ childMargin });
});

/** Values managed by react-hook-form for the population-exploration form. */
export interface PopulationExplorationInputs {
  /** `cid` of the selected session run; sent to the API as `sessionRunId`. */
  model: string;
  /** Epoch value sent to the API as `fromEpoch`. */
  fromEpoch: number;
  /**
   * Sent to the API as `numOfSamples`.
   * NOTE(review): the form coerces this with unary `+` on submit, which suggests
   * the number input may deliver a string at runtime — confirm.
   */
  amount: number;
  /** Sent to the API as `batchSize` (also coerced with unary `+` on submit). */
  batchSize: number;
}
/** Props accepted by the population-exploration form component. */
export interface PopulationExplorationProps {
  /**
   * Forwarded unchanged to `PaneActions`; presumably dismisses the enclosing
   * tooltip/pane — confirm against `PaneActions`.
   */
  closeTooltip: () => void;
}
/** Validation rule shared by every field: a value must be present. */
const REQUIRED_RULE = { value: true, message: 'Value is required' };
/** Validation rule shared by the numeric inputs: value must be at least 1. */
const MIN_ONE_RULE = {
  value: 1,
  message: 'Value must be greater or equal to 1',
};

/**
 * Form that collects the parameters of a population-exploration job and
 * submits them through `api.populationExploration` for the current project.
 *
 * Defaults:
 * - model: the first run returned by `useSelectedSessionRuns`;
 * - fromEpoch: the last entry of `getEpochRange(model)`;
 * - amount: 800;
 * - batchSize: the first run's session `trainingParams.batch_size`, else 8.
 *
 * Submission is a no-op (with a console error) when no current version is
 * loaded; errors thrown while submitting are logged to the console.
 *
 * @param closeTooltip - forwarded to `PaneActions`.
 */
export function PopulationExploration({
  closeTooltip,
}: PopulationExplorationProps): JSX.Element {
  const classes = useStyles();
  const {
    currentVersion: { cid: versionId } = {},
    fetchValidProjectCid,
  } = useCurrentProject();
  const projectId = fetchValidProjectCid();

  const selectedSessionRuns = useSelectedSessionRuns();
  const { sessionRunsToSessionMap } = useVersionControl();
  const { getEpochRange, getLastEpoch } = useEpochRange();

  // Session of the first selected run; only used to seed the batch-size default.
  const defaultSession = useMemo(() => {
    if (!selectedSessionRuns.length) {
      return undefined;
    }
    return sessionRunsToSessionMap.get(selectedSessionRuns[0].cid);
  }, [selectedSessionRuns, sessionRunsToSessionMap]);

  const {
    control,
    handleSubmit,
    trigger,
    formState: { errors },
  } = useForm<PopulationExplorationInputs>({
    mode: 'onBlur',
  });

  // Validation rules live only in useController: each field is registered
  // exactly once (react-hook-form does not support also calling `register`
  // on a field that is controlled by useController).
  const { field: modelField } = useController({
    control,
    name: 'model',
    rules: { required: REQUIRED_RULE },
    defaultValue: selectedSessionRuns[0]?.cid || '',
  });
  const { field: amountField } = useController({
    control,
    name: 'amount',
    rules: { required: REQUIRED_RULE, min: MIN_ONE_RULE },
    defaultValue: 800,
  });
  const { field: batchSizeField } = useController({
    control,
    name: 'batchSize',
    rules: { required: REQUIRED_RULE, min: MIN_ONE_RULE },
    defaultValue: defaultSession?.trainingParams?.batch_size || 8,
  });

  const fromEpochOptions = useMemo(() => getEpochRange(modelField.value), [
    getEpochRange,
    modelField.value,
  ]);

  const { field: fromEpochField } = useController({
    control,
    name: 'fromEpoch',
    rules: { required: REQUIRED_RULE },
    defaultValue: fromEpochOptions[fromEpochOptions.length - 1],
  });

  const handleModelChange = useCallback(
    (value?: string) => {
      const model = selectedSessionRuns.find(({ cid }) => cid === value);
      // Keep the field string-typed ('' is also its default) when nothing matches.
      const modelId = model?.cid ?? '';
      modelField.onChange(modelId);
      // A different model has a different epoch range: reset to its last epoch.
      fromEpochField.onChange(getLastEpoch(modelId));
    },
    [fromEpochField, getLastEpoch, modelField, selectedSessionRuns]
  );

  // Validate every field when the effect first runs so that `errors` (and
  // therefore the submit button state) reflect the default values.
  useEffect(() => {
    void trigger();
  }, [trigger]);

  const onSubmit = useCallback(
    async (event: FormEvent<HTMLFormElement>) => {
      try {
        await handleSubmit<PopulationExplorationInputs>(
          async ({ model: sessionRunId, amount, batchSize, fromEpoch }) => {
            if (!versionId) {
              console.error(
                'how populationExploration was submitted without versionId?'
              );
              return;
            }

            const populationExplorationParams: PopulationExplorationParams = {
              projectId,
              sessionRunId,
              // Coerce explicitly: number inputs may hand back strings.
              batchSize: +batchSize,
              numOfSamples: +amount,
              fromEpoch,
              digest: '', // todo
            };

            const job = await api.populationExploration(
              populationExplorationParams
            );
            return job;
          }
        )(event);
      } catch (e) {
        console.error(e);
      }
    },
    [handleSubmit, projectId, versionId]
  );

  return (
    <form
      onSubmit={onSubmit}
      className={clsx('w-full h-full', classes.childMargin)}
    >
      <Select
        label="Model"
        options={selectedSessionRuns}
        optionID="cid"
        optionToLabel={(model) => model?.name}
        error={errors.model?.message}
        {...modelField}
        onChange={handleModelChange}
      />
      <Select
        label="From Epoch"
        optionToLabel={(n: number) => n.toString()}
        options={fromEpochOptions}
        error={errors.fromEpoch?.message}
        {...fromEpochField}
      />
      <Input
        type="number"
        label="Amount"
        min={1}
        error={errors.amount?.message}
        {...amountField}
      />
      <Input
        type="number"
        label="Batch size"
        min={1}
        required
        error={errors.batchSize?.message}
        {...batchSizeField}
      />
      <PaneActions
        closeTooltip={closeTooltip}
        enableSubmit={
          !errors.model &&
          !errors.amount &&
          !errors.batchSize &&
          !errors.fromEpoch
        }
      />
    </form>
  );
}
