import { useState } from 'react';
import styled from 'styled-components/macro';
import { motion, AnimatePresence } from 'framer-motion';
import { linearRegression, linearRegressionLine, rSquared } from 'simple-statistics';
import randomGen from 'random-seed';
import Star from 'components/charts/Star';
import Label from 'components/charts/Label';

// Extra viewBox space added to the chart's width and height to make room for the axis labels.
const AXIS_PADDING = 20;

// SVG text used on both axes: entries rendered with the `tick` prop take the
// theme's dark colour, everything else (the axis titles) takes the text colour.
const AxisLabel = styled.text`
  font-size: 4px;
  font-family: ${({ theme }) => theme.fonts.body};
  fill: ${({ tick, theme }) => (tick ? theme.colors.dark : theme.colors.text)};
`;

// One grid cell; its fill comes from the theme while its opacity is animated by the chart.
const HeatMapSection = styled(motion.rect)`
  fill: ${({ theme }) => theme.colors.highlight};
`;

// Animated text used for the value printed next to the trendline.
const RLabel = styled(motion.text)`
  font-size: 6px;
`;

/**
 * Nudges points that share identical coordinates apart so each one stays visible.
 *
 * Points are grouped by exact `(x, y)` match. The result is returned group by
 * group, in the order each coordinate first appears in `dataset`; the input
 * objects are copied, never mutated.
 *
 * @param {Array<{id: *, x: number, y: number}>} dataset - Points to spread.
 * @param {number} boundsX - Width of the box the non-radial jitter is drawn from.
 * @param {number} boundsY - Height of the box the non-radial jitter is drawn from.
 * @param {boolean} isRadial - When truthy, coincident points are arranged evenly
 *   on a circle around their shared position instead of being jittered.
 * @returns {Array<Object>} Copies of the points with adjusted `x` and `y`.
 */
function spreadMatchingPoints(dataset, boundsX, boundsY, isRadial) {
  const RADIAL_RADIUS = 4;

  // Bucket coincident points in a single pass; a Map preserves first-seen order.
  const groups = new Map();

  for (const point of dataset) {
    const key = JSON.stringify([point.x, point.y]);
    const group = groups.get(key);

    if (group) {
      group.push(point);
    } else {
      groups.set(key, [point]);
    }
  }

  return [...groups.values()].flatMap(group =>
    group.map((point, i) => {
      if (isRadial) {
        // A lone point overlaps nothing, so it stays exactly where it was plotted
        // rather than being pushed out onto the circle.
        const radius = group.length > 1 ? RADIAL_RADIUS : 0;
        const angle = (i / group.length) * 2 * Math.PI;

        return {
          ...point,
          x: point.x + Math.cos(angle) * radius,
          y: point.y + Math.sin(angle) * radius,
        };
      }

      // The generator is seeded with the point's id, presumably so that a given
      // point receives the same offset on every render.
      const rand = randomGen(point.id);

      return {
        ...point,
        x: point.x + rand.intBetween(0, boundsX) - (boundsX / 2),
        y: point.y + rand.intBetween(0, boundsY) - (boundsY / 2),
      };
    })
  );
}

export default function Chart({ data, width, height, xAxis, yAxis, xAxisTitle, yAxisTitle, spreadOption, showHeatmap, showTrendline, className, onClickPoint }) {
  const [hoveredPoint, setHoveredPoint] = useState(null);
  const outerHeight = height + AXIS_PADDING;
  const outerWidth = width + AXIS_PADDING;

  const xChunkSize = (width - 0.5) / xAxis.length;
  const yChunkSize = (height - 0.5) / yAxis.length;

  const scaledData = data.map(r => ({
    ...r,
    x: r.x * (width - xChunkSize) + xChunkSize / 2,
    y: (1-r.y) * (height - yChunkSize) + yChunkSize / 2,
  }));

  const spreadDataset = spreadMatchingPoints(scaledData, xChunkSize - 1, yChunkSize - 1, spreadOption === 'radial');
  const samples = scaledData.map(d => [d.x,d.y])
  const linearReg = showTrendline && samples.length && linearRegression(samples);
  const trendline = linearReg?.m && linearRegressionLine(linearReg);
  const rValue = trendline && rSquared(samples, trendline);

  const [tx1, tx2] = [0, width];
  const [ty1, ty2] = trendline ? [trendline(tx1), trendline(tx2)] : [];

  const chunkSizes = data.reduce((prev,d) => {
    const key = JSON.stringify([d.x * (xAxis.length - 1), d.y * (yAxis.length - 1)]);
    return { ...prev, [key]: (prev[key] || 0) + 1 };
  }, {});

  const maxChunkSize = Math.max(...Object.values(chunkSizes));

  return (
    <motion.svg viewBox={`0 0 ${outerWidth} ${outerHeight}`} className={className}>
      {xAxis.map(([label], i) => (
        <AxisLabel
          tick
          key={`x-${i}`}
          x={AXIS_PADDING + i * xChunkSize + xChunkSize/2}
          y={height + 9}
          textAnchor="middle"
        >
          {label}
        </AxisLabel>
      ))}

      {yAxis.map(([label,], i) => {
        const x = 13;
        const y = height - (i * yChunkSize + yChunkSize/2);

        return (
          <AxisLabel
            tick
            key={`y-${i}`}
            x={x}
            y={y}
            textAnchor="middle"
            transform={`rotate(270 ${x} ${y})`}
          >
            {label}
          </AxisLabel>
        );
      })}

      <AxisLabel
        x={AXIS_PADDING + width / 2}
        y={height + 19}
        textAnchor="middle"
      >
        {xAxisTitle}
      </AxisLabel>

      <AxisLabel
        x={3}
        y={height / 2}
        textAnchor="middle"
        transform={`rotate(270 ${3} ${height / 2})`}
      >
        {yAxisTitle}
      </AxisLabel>

      <g transform={`translate(${AXIS_PADDING}, 0)`}>
        {yAxis.flatMap((_,y) => (
          xAxis.map((_,x) => {
            // We draw these squares even if showHeatmap is false because they form the gridlines too
            const key = JSON.stringify([x,y]);
            const numberInChunk = showHeatmap ? (chunkSizes[key] || 0) : 0;

            return (
              <HeatMapSection
                key={`heatmap-${y}-${x}`}
                x={xChunkSize * x}
                y={height - yChunkSize * (y+1)}
                width={xChunkSize}
                height={yChunkSize}
                strokeWidth={0.5}
                stroke='#ECECEC'
                initial={false}
                animate={{
                  fillOpacity: (numberInChunk / maxChunkSize) / 2,
                  transition: { duration: 0.6 }
                }}
              />
            );
          })
        ))}

        <AnimatePresence>
          {spreadDataset.map(r => {
            return (
              <Star
                key={r.id}
                color={r.color}
                onClick={() => onClickPoint(r)}
                onMouseOver={() => setHoveredPoint(r)}
                onMouseOut={() => setHoveredPoint(null)}
                initial={{
                  opacity: 0,
                  cx: r.x,
                  cy: height + 10,
                }}
                animate={{
                  opacity: 1,
                  cx: r.x,
                  cy: r.y,
                }}
                exit={{
                  opacity: 0,
                  cx: r.x,
                  cy: height + 10,
                }}
                transition={{ type: 'spring', stiffness: 200, damping: 20 }}
              />
            )
          })}
          {trendline && (
            <>
              <motion.line
                transition={{ type: 'spring', duration: 0.6 }}
                initial={{
                  opacity: 0,
                  x1: tx1,
                  x2: tx2,
                  y1: height + 10,
                  y2: height + 10,
                }}
                animate={{
                  opacity: 1,
                  x1: tx1,
                  x2: tx2,
                  y1: ty1,
                  y2: ty2,
                }}
                exit={{
                  opacity: 0,
                  x1: tx1,
                  x2: tx2,
                  y1: height + 10,
                  y2: height + 10,
                }}
                stroke="black"
              />
              <RLabel
                initial={{
                  x: tx2+30,
                  y: ty2 + (ty2 > height/2 ? -10 : 10),
                }}
                animate={{
                  opacity: 1,
                  x: tx2-2,
                  y: ty2 + (ty2 > height/2 ? -10 : 10)
                }}
                exit={{
                  opacity: 0,
                  x: tx2+30,
                  y: ty2 + (ty2 > height/2 ? -10 : 10),
                }}
                textAnchor="end"
              >
                r: {Math.round(rValue * 100) / 100}
              </RLabel>
            </>
          )}
        </AnimatePresence>

        {hoveredPoint ? (
          <Label
            x={hoveredPoint.x}
            y={hoveredPoint.y}
            reverse={hoveredPoint.x > width * 0.7}
          >
            {hoveredPoint.identifier ?? `Participant ${hoveredPoint.num}`}
          </Label>
        ) : null}
      </g>
    </motion.svg>
  );
}
