import { Fragment } from 'react';
import { capitalize, nth, range } from 'lodash';
import * as dateFns from 'date-fns';
import { add } from 'date-fns';
import classNames from 'classnames';
import { common } from '@gosupersimple/types';
import { useTooltip } from '@visx/tooltip';
import { localPoint } from '@visx/event';

import { getFormattedDateForPeriod } from '@/lib/date';

import { ChartTooltip } from '../common';

import styles from './cohort-chart.module.scss';

/**
 * Picks the CSS-module class that sets a value cell's colour intensity.
 *
 * @param percentage - Retention on a 0–100 scale.
 * @returns `cellIntensity1` for exactly 100, then `cellIntensity2`/`3`/`4` for
 *   values of at least 75/50/25. Anything lower, including NaN, gets `cellIntensity5`.
 *   Values above 100 land in the `cellIntensity2` band.
 */
const getCellStyle = (percentage: number) => {
  if (percentage === 100) {
    return styles.cellIntensity1;
  }

  // Ordered from highest to lowest threshold; the first band the value reaches wins.
  const bands = [
    { min: 75, className: styles.cellIntensity2 },
    { min: 50, className: styles.cellIntensity3 },
    { min: 25, className: styles.cellIntensity4 },
  ];
  const matched = bands.find(({ min }) => percentage >= min);

  return matched === undefined ? styles.cellIntensity5 : matched.className;
};

/**
 * Formats the start of the `period`-th `periodInterval` after `cohortDate`.
 * For example, period 2 of a January cohort with `'month'` formats a date in March.
 *
 * The date-fns duration key is built by pluralising the interval (`'month'` → `months`).
 * NOTE(review): this assumes every `common.CohortTimeInterval` value has a plural
 * `Duration` field; re-check if new interval values are added.
 *
 * @param cohortDate - Start of the cohort.
 * @param period - Number of intervals to advance (the step index).
 * @param periodInterval - Unit to advance by; also used to choose the output format.
 * @param timezone - Optional timezone, forwarded unchanged to `getFormattedDateForPeriod`.
 */
const getPeriodLabel = (
  cohortDate: Date,
  period: number,
  periodInterval: common.CohortTimeInterval,
  timezone?: string,
) => {
  const durationKey = `${periodInterval}s`;
  const periodStart = add(cohortDate, { [durationKey]: period });
  return getFormattedDateForPeriod(periodStart, periodInterval, timezone);
};

/**
 * Absolute number of calendar `interval`s between two cohort start dates.
 * CohortChart uses it to detect gaps between consecutive cohorts.
 *
 * @param curr - Start of the current cohort.
 * @param prev - Start of the other cohort to compare against; when absent the result is 0.
 * @param interval - `'day'`, `'week'`, `'month'` or `'year'`; any other value yields 0.
 * @returns The unsigned calendar-unit difference, or 0 when it is not computed.
 */
const cohortDateDiff = (curr: Date, prev?: Date, interval?: string) => {
  if (prev === undefined) {
    return 0;
  }

  let signedDiff: number;
  switch (interval) {
    case 'day':
      signedDiff = dateFns.differenceInCalendarDays(prev, curr);
      break;
    case 'week':
      signedDiff = dateFns.differenceInCalendarWeeks(prev, curr);
      break;
    case 'month':
      signedDiff = dateFns.differenceInCalendarMonths(prev, curr);
      break;
    case 'year':
      signedDiff = dateFns.differenceInCalendarYears(prev, curr);
      break;
    default:
      return 0;
  }

  // Only the distance matters; cohorts may be ordered either way.
  return Math.abs(signedDiff);
};

/**
 * Payload rendered in the hover tooltip for a single cohort value cell.
 */
type TooltipData = {
  /** Formatted start of the cohort (row) the hovered cell belongs to. */
  title: string;
  /** Formatted start date of the hovered step's period. */
  periodStart: string;
  /** Granularity of the step columns; shown capitalised next to `step`. */
  timeInterval: common.CohortTimeInterval;
  /** Zero-based step (column) index. */
  step: number;
  /** Number of users in the cohort. */
  size: number;
  /** Number of the cohort's users counted at this step. */
  stepSize: number;
  /**
   * `stepSize / size` as a fraction (1 = whole cohort). This is NOT on a 0–100
   * scale; the tooltip multiplies it by 100 for display.
   */
  percentage: number;
};

/** Props for {@link CohortChart}. */
interface CohortChartProps {
  /**
   * One entry per cohort, rendered top to bottom in the given order.
   * `size` is the cohort's user count; `stepSizes[i]` is the count at step `i`.
   * The chart only reads these arrays, so readonly arrays are accepted.
   */
  data: readonly { cohortStart: Date; size: number; stepSizes: readonly number[] }[];
  /** NOTE(review): currently not read by CohortChart. */
  height?: number;
  /** Granularity used to label cohort rows and detect gaps; CohortChart defaults it to `'month'`. */
  cohortTimeInterval?: common.CohortTimeInterval;
  /** Granularity of the step columns; CohortChart defaults it to `'month'`. */
  eventTimeInterval?: common.CohortTimeInterval;
}

/**
 * Ratio between an element's layout width (`offsetWidth`, which ignores CSS
 * transforms) and its rendered width (`getBoundingClientRect().width`).
 *
 * Multiplying on-screen offsets by this value converts them into the element's
 * own layout pixels. NOTE(review): presumably needed because an ancestor may apply
 * a CSS `scale()` transform; confirm.
 *
 * @returns The ratio, or 1 when the rendered width is 0 (e.g. a hidden or collapsed
 *   element). Without this fallback the result would be NaN/Infinity and corrupt the
 *   tooltip position.
 */
const getScale = (node: HTMLDivElement) => {
  const renderedWidth = node.getBoundingClientRect().width;
  return renderedWidth > 0 ? node.offsetWidth / renderedWidth : 1;
};

/**
 * Share of a cohort present at one step, as a fraction (1 = whole cohort).
 * An empty cohort yields 0 instead of NaN/Infinity, so its cells render "0%".
 */
const getRetentionFraction = (stepSize: number, cohortSize: number) =>
  cohortSize > 0 ? stepSize / cohortSize : 0;

/**
 * Cohort retention table.
 *
 * Renders one row per entry in `data`. Each row shows the cohort label, the cohort
 * size, and then one cell per step with `stepSize / size` as a rounded percentage,
 * shaded via `getCellStyle`. Rows shorter than the longest one are padded with
 * empty cells.
 *
 * When the next cohort starts more than one `cohortTimeInterval` later, a
 * full-width separator is drawn below the row. Hovering a value cell shows a
 * tooltip with the raw counts.
 *
 * `cohortTimeInterval` and `eventTimeInterval` default to `'month'`; `height` is
 * currently not used.
 */
export const CohortChart = (props: CohortChartProps) => {
  const { data, cohortTimeInterval = 'month', eventTimeInterval = 'month' } = props;

  const { showTooltip, hideTooltip, tooltipOpen, tooltipLeft, tooltipTop, tooltipData } =
    useTooltip<TooltipData>();

  // Number of step columns: the longest stepSizes array across all cohorts.
  const maxStep = data.length > 0 ? Math.max(...data.map((d) => d.stepSizes.length)) : 0;

  /**
   * Shows the tooltip for a hovered cell. Coordinates are measured relative to the
   * grid element (the cell's parent); if that element is missing, nothing is shown.
   */
  const handleMouseOver = (cell: HTMLElement, event: MouseEvent, tooltip: TooltipData) => {
    const grid = cell.parentElement;
    if (!(grid instanceof HTMLDivElement)) {
      return;
    }

    const coords = localPoint(grid, event);
    if (coords === null) {
      return;
    }

    // Convert on-screen pixels into the grid's layout pixels.
    const scale = getScale(grid);
    showTooltip({
      tooltipLeft: coords.x * scale,
      tooltipTop: coords.y * scale,
      tooltipData: tooltip,
    });
  };

  return (
    <div className={styles.cohortChart}>
      <div
        className={styles.grid}
        style={{
          gridTemplateColumns: `auto auto repeat(${maxStep}, 1fr)`,
        }}>
        <div className={styles.header}>Date</div>
        <div className={styles.header}>Users</div>
        {range(maxStep).map((step) => (
          <div key={`value-header-${step}`} className={styles.header}>
            <span className={styles.period}>{capitalize(eventTimeInterval)}</span> {step}
          </div>
        ))}
        {data.map((row, rowIdx) => {
          const cohortLabel = getFormattedDateForPeriod(row.cohortStart, cohortTimeInterval);
          const nextCohortStart = nth(data, rowIdx + 1)?.cohortStart;
          // More than one interval to the next cohort means some periods have no cohort row.
          const hasGapAfter =
            cohortDateDiff(row.cohortStart, nextCohortStart, cohortTimeInterval) > 1;

          return (
            <Fragment key={rowIdx}>
              <div>{cohortLabel}</div>
              <div>{row.size}</div>
              {range(maxStep).map((step) => {
                const cellKey = `value-${rowIdx}-${step}`;
                const stepSize = nth(row.stepSizes, step);

                // Row is shorter than the widest one: keep the grid slot empty.
                if (stepSize === undefined) {
                  return <div key={cellKey} />;
                }

                const fraction = getRetentionFraction(stepSize, row.size);

                return (
                  <div
                    key={cellKey}
                    className={classNames(
                      styles.cell,
                      styles.valueCell,
                      getCellStyle(fraction * 100),
                    )}
                    onMouseMove={(e) =>
                      handleMouseOver(e.currentTarget, e.nativeEvent, {
                        title: cohortLabel,
                        periodStart: getPeriodLabel(row.cohortStart, step, eventTimeInterval),
                        timeInterval: eventTimeInterval,
                        step,
                        stepSize,
                        size: row.size,
                        percentage: fraction,
                      })
                    }
                    onMouseOut={hideTooltip}>
                    {`${Math.round(fraction * 100)}%`}
                  </div>
                );
              })}
              {hasGapAfter && (
                <div
                  className={styles.separator}
                  style={{
                    // Span every grid column: 2 label columns + maxStep value columns.
                    gridColumnStart: 1,
                    gridColumnEnd: maxStep + 3,
                  }}
                />
              )}
            </Fragment>
          );
        })}
      </div>
      {tooltipOpen && tooltipData && (
        <ChartTooltip top={tooltipTop ?? 0} left={tooltipLeft ?? 0}>
          <div className={styles.tooltipTitle}>
            <b>{tooltipData.title}</b> cohort
          </div>
          <div className={styles.tooltipTitle}>
            {capitalize(tooltipData.timeInterval)} {tooltipData.step} ({tooltipData.periodStart})
          </div>
          <b>{Math.round(tooltipData.percentage * 100)}%</b>{' '}
          {`(${tooltipData.stepSize} of ${tooltipData.size})`}
        </ChartTooltip>
      )}
    </div>
  );
};
