import { groupBy, first, range, entries, sortBy, last, fromPairs, get } from 'lodash';
import {
  differenceInCalendarDays,
  differenceInCalendarWeeks,
  differenceInCalendarMonths,
  differenceInCalendarYears,
  parseISO,
} from 'date-fns';

import { Exploration, Pipeline } from '@/explore/types';
import { getCohortOperation, getIdFields, getTimeKeyFields } from '@/explore/edit-cohort/utils';
import { dereferencePipeline } from '@/explore/pipeline/utils';
import { getFinalStateOrThrow, PipelineStateContext } from '@/explore/pipeline/state';

/** One row of the cohort query result, as consumed by `convertData`. */
interface CohortRow {
  /** Cohort start; grouped on as a key and parsed with `parseISO`. */
  cohort_date: string;
  /** Cohort size; `convertData` reads it from the row with the lowest `period_number`. */
  cohort_size: number;
  /** Period index within the cohort, used as the step index. */
  period_number: number;
  /** Amount subtracted from the running cohort size at this period. */
  period_dropped: number;
}

/**
 * Converts raw cohort query rows into one retention series per cohort.
 *
 * Rows are grouped by `cohort_date`. Within a cohort, the running size starts at
 * the `cohort_size` of the row with the lowest `period_number`. For each period
 * `i` (from 0 up to the number of calendar `eventTimeInterval` units between the
 * cohort start and `maxDate`), the `period_dropped` recorded for period `i` is
 * subtracted, and `stepSizes[i]` is the running size after that subtraction.
 * Periods after the last reported `period_number` are only emitted while the
 * running size is still positive. Cohorts starting after `maxDate` get an empty
 * `stepSizes`.
 *
 * @param rows - Cohort query rows. If the same `period_number` appears more than
 *   once within a cohort, the later row's `period_dropped` is used.
 * @param eventTimeInterval - `'day'`, `'week'` or `'month'`; any other value is
 *   measured in calendar years.
 * @param maxDate - Date each series is extended up to (defaults to now).
 * @returns One `{ cohortStart, size, stepSizes }` entry per distinct `cohort_date`.
 */
export const convertData = (
  rows: CohortRow[],
  eventTimeInterval: string = 'month',
  maxDate: Date = new Date(),
) =>
  entries(groupBy(rows, 'cohort_date')).map(([cohortStart, steps]) => {
    const sortedSteps = sortBy(steps, 'period_number');
    const cohortSize = first(sortedSteps)?.cohort_size ?? 0;
    // Highest reported period, i.e. up to the last churn event.
    const totalSteps = last(sortedSteps)?.period_number ?? 0;
    const periodMap = fromPairs(
      sortedSteps.map(({ period_number, period_dropped }) => [period_number, period_dropped]),
    );
    const startDate = parseISO(cohortStart);
    const maxSteps =
      eventTimeInterval === 'day'
        ? differenceInCalendarDays(maxDate, startDate)
        : eventTimeInterval === 'week'
          ? differenceInCalendarWeeks(maxDate, startDate)
          : eventTimeInterval === 'month'
            ? differenceInCalendarMonths(maxDate, startDate)
            : differenceInCalendarYears(maxDate, startDate);
    // lodash `range(n)` counts *down* for negative `n` (range(-2) -> [0, -1]), so a
    // cohort starting two or more intervals after `maxDate` would otherwise get
    // spurious steps. Clamp so any cohort starting after `maxDate` gets none.
    const stepCount = Math.max(maxSteps + 1, 0);
    let sizeInStep = cohortSize;
    const stepSizes = range(stepCount).reduce<number[]>((acc, step) => {
      sizeInStep -= get(periodMap, step) ?? 0;
      // Omit values beyond the last churn event (cohort period) once the size hits zero.
      if (step <= totalSteps || sizeInStep > 0) {
        acc.push(sizeInStep);
      }
      return acc;
    }, []);

    return { cohortStart: startDate, size: cohortSize, stepSizes };
  });

/**
 * Returns `true` when at least one of the pipeline's operations is tagged
 * `'cohort'`, and `false` otherwise.
 */
export const pipelineHasCohortOperation = (pipeline: Pipeline) =>
  pipeline.operations.some(({ operation }) => operation === 'cohort');

/**
 * Checks that the pipeline's cohort operation is fully and correctly configured.
 * Cohort data will not be fetched if invalid.
 *
 * The operation's `pipeline` parameter is dereferenced against `exploration` and
 * its final field state is computed. `cohortId` must then match an ID field, and
 * both `cohortTimeKey` and `eventTimeKey` must match time key fields. Errors
 * raised by `dereferencePipeline` / `getFinalStateOrThrow` are not caught here.
 *
 * @returns `{ isValid: true }`, or `{ isValid: false, error }` describing the
 *   first failed check.
 */
export const validateCohortOperation = (
  pipeline: Pipeline,
  exploration: Exploration,
  ctx: PipelineStateContext,
) => {
  const cohortOp = getCohortOperation(pipeline);
  if (cohortOp === undefined) {
    return { isValid: false, error: 'Please fill in the operation parameters.' };
  }

  const { parameters } = cohortOp;
  const source = dereferencePipeline(parameters.pipeline, exploration);
  const finalFields = getFinalStateOrThrow(source.baseModelId, source.operations, ctx).fields;

  // True when some candidate field has exactly the wanted key.
  const includesKey = (candidates: readonly { key?: unknown }[], wanted: unknown) =>
    candidates.some(({ key }) => key === wanted);

  if (!includesKey(getIdFields(finalFields), parameters.cohortId)) {
    return { isValid: false, error: 'Please make sure the cohort ID field is valid.' };
  }
  if (!includesKey(getTimeKeyFields(finalFields), parameters.cohortTimeKey)) {
    return { isValid: false, error: 'Please make sure the cohort time field is valid.' };
  }
  if (!includesKey(getTimeKeyFields(finalFields), parameters.eventTimeKey)) {
    return { isValid: false, error: 'Please make sure the event time field is valid.' };
  }

  return { isValid: true };
};
