import { AxisLeft } from '@visx/axis';
import { GridColumns } from '@visx/grid';
import { Group } from '@visx/group';
import { scaleBand, scaleLinear, scaleOrdinal } from '@visx/scale';
import { BarGroupHorizontal, BarStackHorizontal } from '@visx/shape';
import type {
  AnyScaleBand,
  BarGroupHorizontal as BarGroupHorizontalType,
  PositionScale,
} from '@visx/shape/lib/types';
import type { ScaleOrdinal } from 'd3-scale';
import type { ReactElement } from 'react';
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';

import type { BarTooltipProps } from './components';
import {
  AxisBottomPercentLabels,
  AxisLeftGroupLabel,
  BarPercentLabel,
  BarTooltip,
  ChartLegend,
} from './components';
import { barChartContainerCss } from './styles';
import {
  defaultMargin,
  getLegendColors,
  getLegendLabels,
  mapStringToNumber,
} from './utils/helpers';
import type { BarChartProps, BarData, BarItem, BarItemValue } from './utils/types';

/** Props of `StackedBarGroup`: `BarChartProps` with every field made required. */
export type StackedBarGroupProps = Required<BarChartProps>;

/**
 * One stacked bar inside a group. `subGroupName` is the x-axis key the bar represents; every
 * other property maps a stack value (a row's `stackKey` value) to the size of that segment.
 */
type SubGroupData = {
  subGroupName: string;
  [key: string]: BarItemValue;
};

/** All stacked bars sharing one y-axis value (`groupName`), one entry per x axis. */
type GroupData = {
  groupName: string;
  data: SubGroupData[];
};

/** Props of the internal `StackedBars` renderer. */
export type StackedBarsProps = {
  /** Group layout produced by `BarGroupHorizontal`; each entry's `index` points into `data`. */
  barGroups: BarGroupHorizontalType<string>[];
  /** Grouped chart data, indexed by `barGroups[i].index`. */
  data: GroupData[];
  /** Stack keys (segment names) drawn inside every stacked bar. */
  keys: string[];
  /** Passed as `height` to `BarStackHorizontal`. */
  yMax: number;
  /** Scale for the horizontal (value) axis, passed to `BarStackHorizontal`. */
  xScale: PositionScale;
  /** Band scale positioning the sub-groups (by `subGroupName`) inside a group. */
  yScale: AnyScaleBand;
  /** Maps a stack key to its fill colour. */
  colorScale: ScaleOrdinal<string, string>;
  /** Hides the per-sub-group labels drawn with `AxisLeft`. */
  hideLabels?: boolean;
  /** Forwarded to the total labels and to the tooltip. */
  showAsPercent?: boolean;
  /** Forwarded to the total labels and to the tooltip. */
  numberOfDecimalPlaces?: number;
  /** Right-to-left layout: flips the sub-group label anchor and is forwarded to the total labels. */
  isRTL?: boolean;
};

/**
 * Grouped, horizontally stacked bar chart.
 *
 * Rows are grouped by their `yAxis` value. Each group shows one horizontal bar per entry of
 * `xAxes`, and each bar is split into segments by the values found under `stackKey`.
 * Bottom value labels, left group labels, sub-group labels and the legend can each be hidden.
 */
export const StackedBarGroup = (props: StackedBarGroupProps): ReactElement => {
  const {
    width,
    height,
    stackKey,
    yAxis,
    xAxes,
    xMax,
    yMax,
    xScaleTicks,
    leftPosition,
    data,
    dataColors,
    hideXAxisLabels,
    hideYAxisLabels,
    hideXPercentLabels,
    showAsPercent,
    numberOfDecimalPlaces,
    hideLegend,
    isRTL,
  } = props;

  // Accessor used by the outer band scale: groups are positioned by their name.
  const groupNameOf = useCallback(({ groupName }: GroupData) => groupName, []);

  const groups = useMemo(
    () => mapDataToStackedBarsFormat(data, yAxis, xAxes, stackKey),
    [data, yAxis, xAxes, stackKey]
  );

  // Stack segment names double as legend entries; each one gets a colour from `dataColors`.
  const stackLabels = useMemo(() => getLegendLabels(data, stackKey), [data, stackKey]);
  const stackColors = useMemo(
    () => getLegendColors(stackLabels, dataColors),
    [stackLabels, dataColors]
  );

  // Outer scale: one vertical band per group.
  const groupScale = useMemo(() => {
    const groupNames = groups.map(groupNameOf);
    return scaleBand<string>({ range: [0, yMax], domain: groupNames, padding: 0.2 });
  }, [yMax, groups, groupNameOf]);

  // Value scale: from 0 up to the longest stacked bar.
  const valueScale = useMemo(() => {
    const longestStack = getBiggestStackValue(data, stackKey, yAxis);
    return scaleLinear<number>({ domain: [0, longestStack], range: [0, xMax] });
  }, [xMax, data, yAxis, stackKey]);

  // Inner scale: positions each x-axis bar inside its group's band.
  const subGroupScale = useMemo(
    () =>
      scaleBand<string>({
        range: [0, groupScale.bandwidth()],
        domain: xAxes,
        padding: 0.1,
        paddingOuter: 1,
      }),
    [groupScale, xAxes]
  );

  const stackColorScale = useMemo(
    () =>
      scaleOrdinal<string, string>({
        domain: stackLabels,
        range: stackColors.slice(0, stackLabels.length),
      }),
    [stackLabels, stackColors]
  );

  const renderStackedBars = (barGroups: BarGroupHorizontalType<string>[]) => (
    <StackedBars
      barGroups={barGroups}
      data={groups}
      keys={stackLabels}
      yMax={yMax}
      xScale={valueScale}
      yScale={subGroupScale}
      colorScale={stackColorScale}
      hideLabels={hideXAxisLabels}
      showAsPercent={showAsPercent}
      numberOfDecimalPlaces={numberOfDecimalPlaces}
      isRTL={isRTL}
    />
  );

  return (
    <figure className={barChartContainerCss}>
      <svg width={width} height={height}>
        <Group top={defaultMargin.top} left={leftPosition}>
          <GridColumns scale={valueScale} height={height - 60} numTicks={xScaleTicks} />
          <BarGroupHorizontal
            data={groups}
            keys={xAxes}
            width={xMax}
            y0={groupNameOf}
            y0Scale={groupScale}
            y1Scale={subGroupScale}
            xScale={valueScale}
            color={stackColorScale}
          >
            {renderStackedBars}
          </BarGroupHorizontal>
          {hideXPercentLabels ? null : (
            <AxisBottomPercentLabels
              yMax={yMax}
              xScale={valueScale}
              numTicks={xScaleTicks}
              hidePercentSignal={!showAsPercent}
            />
          )}
        </Group>
        {hideYAxisLabels ? null : <AxisLeftGroupLabel yScale={groupScale} />}
      </svg>
      {hideLegend ? null : <ChartLegend colorScale={stackColorScale} keys={stackLabels} />}
    </figure>
  );
};

/**
 * Renders the stacked bars of every group laid out by `BarGroupHorizontal`. It also draws the
 * optional per-sub-group labels, the total label after each stack and a hover tooltip.
 *
 * Each group is drawn exactly once. `data[group].data` already holds one row per sub-group, and
 * `yScale` positions those rows inside the group's band, so one `BarStackHorizontal` covers the
 * whole group. Drawing it once per entry of `barGroup.bars` (one per x axis) would only stack
 * identical copies on top of each other.
 */
const StackedBars = (props: StackedBarsProps) => {
  const {
    barGroups,
    data,
    keys,
    yMax,
    xScale,
    yScale,
    colorScale,
    hideLabels,
    showAsPercent,
    numberOfDecimalPlaces,
    isRTL,
  } = props;
  const [tooltipData, setTooltipData] = useState<BarTooltipProps | null>(null);

  // Pending "hide tooltip" timer. A ref survives re-renders. A plain `let` would be re-created
  // (as undefined) by every render, e.g. the one triggered by setTooltipData, losing the handle.
  const tooltipTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);

  const cancelTooltipHide = () => {
    if (tooltipTimeoutRef.current !== null) {
      clearTimeout(tooltipTimeoutRef.current);
      tooltipTimeoutRef.current = null;
    }
  };

  // On unmount, cancel a pending hide so the timer never updates state of an unmounted component.
  useEffect(
    () => () => {
      if (tooltipTimeoutRef.current !== null) {
        clearTimeout(tooltipTimeoutRef.current);
      }
    },
    []
  );

  // Hide the tooltip 300ms after the pointer leaves a bar. Moving onto a bar in time cancels it.
  const handleMouseLeave = () => {
    cancelTooltipHide();
    tooltipTimeoutRef.current = setTimeout(() => {
      tooltipTimeoutRef.current = null;
      setTooltipData(null);
    }, 300);
  };

  /**
   * Shows the tooltip for one stack segment at the pointer position.
   *
   * @param top - pointer `clientY`
   * @param left - pointer `clientX`
   * @param barKey - stack key of the hovered segment, also used as its colour key
   * @param barGroupIndex - index of the group in `data`
   * @param barDataIndex - index of the sub-group within that group
   */
  const handleMouseMove = (
    top: number,
    left: number,
    barKey: string,
    barGroupIndex: number,
    barDataIndex: number
  ) => {
    cancelTooltipHide();

    const group = data[barGroupIndex];
    const subGroup = group?.data[barDataIndex];
    if (!group || !subGroup) {
      return;
    }

    setTooltipData({
      tooltipData: {
        value: subGroup[barKey]!,
        colorKey: barKey,
        groupName: group.groupName,
        subGroupName: subGroup.subGroupName,
      },
      tooltipTop: top,
      tooltipLeft: left,
      colorScale,
      showAsPercent,
      numberOfDecimalPlaces,
    });
  };

  return (
    <>
      {barGroups.map(barGroup => {
        // BarGroupHorizontal lays out one bar per x axis; with none there is nothing to draw.
        if (barGroup.bars.length === 0) {
          return null;
        }
        const stackedData = data[barGroup.index]!.data;
        return (
          <Group key={`bar-group-${barGroup.index}`} top={barGroup.y0}>
            <BarStackHorizontal
              data={stackedData}
              keys={keys}
              height={yMax}
              y={d => d.subGroupName}
              xScale={xScale}
              yScale={yScale}
              color={colorScale}
            >
              {barStacks =>
                barStacks.map((barStack, stackIndex) => {
                  // Only the segments of the last stack key carry the sub-group total label.
                  const isLastStackBar = stackIndex === barStacks.length - 1;
                  return barStack.bars.map(barData => (
                    <Group key={`stacked-bar-${barGroup.index}-${barData.index}`}>
                      <rect
                        x={barData.x}
                        y={barData.y}
                        width={barData.width}
                        height={barData.height}
                        fill={barData.color}
                        onMouseLeave={handleMouseLeave}
                        onMouseMove={e =>
                          handleMouseMove(
                            e.clientY,
                            e.clientX,
                            barStack.key,
                            barGroup.index,
                            barData.index
                          )
                        }
                      />
                      {isLastStackBar && (
                        <BarPercentLabel
                          barX={barData.x}
                          barY={barData.y}
                          barWidth={barData.width}
                          barHeight={barData.height}
                          showAsPercent={showAsPercent}
                          numberOfDecimalPlaces={numberOfDecimalPlaces}
                          value={getXAxisStackedBarsTotal(barData.bar.data)}
                          isRTL={isRTL}
                        />
                      )}
                    </Group>
                  ));
                })
              }
            </BarStackHorizontal>

            {!hideLabels && (
              <AxisLeft
                hideTicks
                scale={yScale}
                hideAxisLine
                tickLabelProps={() => ({
                  fill: 'black',
                  fontSize: 11,
                  // Bar labels are right aligned against their respective bar. The "end" is actually the "start"
                  // of the text for RTL languages, so anchor should be 'start' for RTL and 'end' for LTR.
                  textAnchor: isRTL ? 'start' : 'end',
                  dy: '0.33em',
                })}
              />
            )}
          </Group>
        );
      })}
      {tooltipData && <BarTooltip {...tooltipData} />}
    </>
  );
};

/**
 * Reshapes flat chart rows into the nested structure the chart renders.
 *
 * Rows are grouped by their `yAxis` value, giving one `GroupData` per distinct value. Each group
 * holds one `SubGroupData` per entry of `xAxes`, which maps every row's `stackKey` value to that
 * row's value for the x axis (missing values become 0). Every sub-group across all groups is then
 * padded with 0 for keys it lacks, so all stacks share the same key set.
 *
 * @param data - flat chart rows
 * @param yAxis - property whose value names the group a row belongs to
 * @param xAxes - properties that become the bars (sub-groups) of each group
 * @param stackKey - property whose value names the stack segment a row contributes to
 * @returns the groups, sorted by `groupName` with `localeCompare`
 */
function mapDataToStackedBarsFormat(
  data: BarData,
  yAxis: string,
  xAxes: string[],
  stackKey: string
): GroupData[] {
  // A Map (not a plain object) stops group names such as "constructor" or "toString" from
  // resolving to Object.prototype members, which would otherwise make `.push` throw.
  const rowsByGroup = new Map<string, BarItem[]>();
  data.forEach(item => {
    const { [yAxis]: yAxisValue, ...itemWithoutYAxis } = item;
    // String() matches the property-key coercion of the previous object-based grouping.
    const groupName = String(yAxisValue);
    const rows = rowsByGroup.get(groupName);
    if (rows) {
      rows.push(itemWithoutYAxis);
    } else {
      rowsByGroup.set(groupName, [itemWithoutYAxis]);
    }
  });

  // One sub-group per x axis, mapping each row's stack value to its value for that x axis.
  const toSubGroups = (items: BarItem[]): SubGroupData[] =>
    xAxes.map(xAxis => {
      const subGroup: SubGroupData = { subGroupName: xAxis };
      items.forEach(item => {
        const stackValue = item[stackKey]!;
        subGroup[stackValue] = item[xAxis] ?? 0;
      });
      return subGroup;
    });

  const groups: GroupData[] = Array.from(rowsByGroup, ([groupName, items]) => ({
    groupName,
    data: toSubGroups(items),
  }));

  // Pad every sub-group with 0 for keys present in any other sub-group.
  const allKeys = new Set(
    groups.flatMap(({ data: subGroups }) => subGroups.flatMap(item => Object.keys(item)))
  );
  groups.forEach(group => {
    group.data.forEach(subGroup => {
      allKeys.forEach(key => {
        subGroup[key] = subGroup[key] ?? 0;
      });
    });
  });

  return groups.sort((a, b) => a.groupName.localeCompare(b.groupName));
}

/**
 * Longest stacked bar in the data: the largest sum, over all rows, for any
 * (y-axis value, value column) pair. It serves as the upper bound of the value scale.
 *
 * Every property other than `stackKey` and `yAxis` is treated as a value column and converted
 * with `mapStringToNumber`. Returns 0 when there is nothing to sum.
 */
function getBiggestStackValue(data: BarData, stackKey: string, yAxis: string) {
  const stackTotals = new Map<string, number>();

  for (const row of data) {
    const groupValue = row[yAxis];
    for (const column of Object.keys(row)) {
      if (column === stackKey || column === yAxis) {
        continue;
      }
      const totalKey = `${groupValue}-${column}`;
      const runningTotal = stackTotals.get(totalKey) ?? 0;
      stackTotals.set(totalKey, runningTotal + mapStringToNumber(row[column]));
    }
  }

  let biggest = 0;
  stackTotals.forEach(total => {
    biggest = Math.max(biggest, total);
  });
  return biggest;
}

/**
 * Total length of one stacked bar: the sum of all its segment values (everything except
 * `subGroupName`), each converted with `mapStringToNumber`.
 */
function getXAxisStackedBarsTotal(data: SubGroupData) {
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  const { subGroupName: _label, ...segmentValues } = data;
  let total = 0;
  for (const segmentValue of Object.values(segmentValues)) {
    total += mapStringToNumber(segmentValue);
  }
  return total;
}
