import { Box, useTheme } from "@mui/material";
import { memo, useCallback, useEffect, useRef, useState } from "react";
import { ChartDataset } from "chart.js";
import { CarChartLegend } from "components/ChartLegend";
import { CarChipGroup } from "components/Chip";
import Plotly from "plotly.js-gl3d-dist-min";
import {
  getYieldCurveChartData,
  SimulationYieldsChartProps,
} from "./SimulationYieldsChart2d";

/** Minimum height of the chart container (passed to its `minHeight` style). */
const chartHeight = 448;

/** One Plotly trace, with every field optional as Plotly's typings allow. */
type Trace = Partial<Plotly.PlotData>;
/** Plotly layout options, with every field optional. */
type Layout = Partial<Plotly.Layout>;

/** A complete Plotly figure: the traces plus the layout they are drawn with. */
type LastState = {
  traces: Trace[];
  layout: Layout;
};

/** Plotly config shared by every draw of the 3D chart. */
const plotConfig: Partial<Plotly.Config> = {
  displayModeBar: true,
  displaylogo: false,
  responsive: true,
};

/**
 * 3D rendering of the simulation yield curves.
 *
 * Each visible dataset is drawn as a single-colour Plotly `surface` ribbon
 * occupying x ∈ [curveIndex, curveIndex + 1].
 * - y is the point index within the curve; its axis is ticked with `labels`.
 * - z is the point value, formatted as a percentage. Missing points are drawn
 *   at 0.
 *
 * A chip group toggles curve visibility through `props.setHiddenDatasetIds`,
 * and a legend is rendered below the plot.
 *
 * The Plotly figure is redrawn from an effect whenever its content changes.
 * It is torn down with `Plotly.purge` when the container unmounts.
 */
export const SimulationYieldsChart3d = memo(function SimulationYieldsChart3d(
  props: SimulationYieldsChartProps,
) {
  const theme = useTheme();
  // The container element is held in state so the effects below run once it
  // has mounted.
  const [chartEl, setChartEl] = useState<HTMLDivElement | null>(null);
  // Serialized figure last drawn into `chartEl`. Re-renders that produce an
  // identical figure can then skip the WebGL redraw.
  const lastFigureKey = useRef<string>();

  const { datasets, xLabel, yLabel, labels } = getYieldCurveChartData(
    props.data,
    props.hiddenDatasetIds,
  );

  const finalDatasets = datasets.map<
    ChartDataset<"line", (number | null)[]> & { color: string }
  >((ds) => ({
    id: ds.id,
    label: ds.label,
    data: ds.data,
    borderColor: ds.color,
    backgroundColor: ds.color,
    pointBackgroundColor: ds.color,
    pointBorderColor: ds.color,
    borderWidth: 2,
    pointRadius: 3,
    color: ds.color,
    tension: 0.4,
    hidden: props.hiddenDatasetIds.includes(ds.id ?? ""),
  }));

  const visibleDs = finalDatasets.filter((ds) => !ds.hidden);

  const traces = visibleDs.map<Trace>((ds, curveIdx) => ({
    // Two x columns per point turn each curve into a ribbon one unit wide.
    x: ds.data.map(() => [curveIdx, curveIdx + 1]),
    y: ds.data.map((_, pointIdx) => [pointIdx, pointIdx]),
    z: ds.data.map((value) => [value ?? 0, value ?? 0]),
    hoverinfo: "y+z+name",
    hovertemplate: ` <br>\n  <b>${ds.id}</b>  \n<br>\n<br>\n  %{y}  \n<br>\n  ${yLabel}: %{z}  \n<br> `,
    name: ds.id?.split(" ").at(-1),
    // A constant colorscale paints the whole surface in the curve's colour.
    colorscale: [
      [0, String(ds.borderColor)],
      [1, String(ds.borderColor)],
    ],
    type: "surface",
    showscale: false,
  }));

  const layout: Layout = {
    // A fixed uirevision asks Plotly to keep user interactions (for example
    // the camera) across Plotly.react updates.
    uirevision: 1,
    showlegend: true,
    autosize: true,
    width: chartEl?.clientWidth,
    height: chartEl?.clientHeight,
    font: {
      family: theme.typography.fontFamily,
      color: theme.palette.softBlack,
      size: 13,
    },
    hoverlabel: {
      font: {
        family: theme.typography.fontFamily,
        color: theme.palette.softBlack,
        size: 15,
      },
      bgcolor: theme.palette.gray1,
      bordercolor: theme.palette.gray6,
    },
    scene: {
      aspectmode: "auto",
      xaxis: {
        title: { text: props.datasetsLabel },
        // Ticks sit in the middle of each one-unit-wide ribbon.
        tickvals: visibleDs.map((_, idx) => idx + 0.5),
        ticktext: visibleDs.map((ds) => ds.label?.split(" ").at(-1) ?? ""),
      },
      yaxis: {
        title: { text: xLabel },
        tickvals: visibleDs.at(0)?.data.map((_, idx) => idx) ?? [],
        ticktext: labels,
      },
      zaxis: {
        title: { text: yLabel },
        tickformat: ".0%",
        hoverformat: ".0%",
      },
    },
    margin: {
      t: 0,
      b: 0,
      r: 0,
      l: 0,
      pad: 0,
    },
  };

  // When the container goes away (unmount or element swap), purge the plot so
  // Plotly's listeners and WebGL scene do not outlive it. Resetting the key
  // forces a full draw into the next element.
  useEffect(() => {
    if (!chartEl) {
      return undefined;
    }
    return () => {
      Plotly.purge(chartEl);
      lastFigureKey.current = undefined;
    };
  }, [chartEl]);

  // Push the current figure to Plotly whenever its content changes, not only
  // when the number of visible curves changes. On a fresh element
  // Plotly.react performs the initial draw.
  useEffect(() => {
    if (!chartEl) {
      return;
    }
    const figure: LastState = { traces, layout };
    const figureKey = JSON.stringify(figure);
    if (figureKey === lastFigureKey.current) {
      return;
    }
    lastFigureKey.current = figureKey;
    void Plotly.react(chartEl, figure.traces, figure.layout, plotConfig);
  }, [chartEl, traces, layout]);

  const refCallback = useCallback((el: HTMLDivElement | null) => {
    setChartEl(el);
  }, []);

  return (
    <Box sx={{ display: "flex", gap: 4, width: "100%", pt: 1.5 }}>
      <Box
        sx={{
          width: 220,
          display: "flex",
          flexDirection: "column",
          backgroundColor: props.isWhiteContext ? "gray1" : "white",
          border: "1px solid",
          borderColor: "gray3",
          borderRadius: "5px",
          px: 3,
          py: 3,
          gap: 1,
        }}
      >
        <CarChipGroup
          label="Curves"
          direction="column"
          items={datasets.map((ds) => ({
            label: ds.label ?? "",
            value: ds.id ?? "",
            isChecked: !props.hiddenDatasetIds.includes(ds.id ?? ""),
          }))}
          onClick={(value) => {
            // Toggle membership of the clicked id in the hidden list.
            if (props.hiddenDatasetIds.includes(value)) {
              props.setHiddenDatasetIds(
                props.hiddenDatasetIds.filter((i) => i !== value),
              );
            } else {
              props.setHiddenDatasetIds([...props.hiddenDatasetIds, value]);
            }
          }}
        />
      </Box>
      <Box
        sx={{
          display: "flex",
          flexDirection: "column",
          alignItems: "stretch",
          width: "100%",
        }}
      >
        <Box
          sx={{
            flex: "auto",
            width: "100%",
            display: "flex",
            flexDirection: "column",
            justifyContent: "center",
            border: "1px solid",
            borderColor: "gray3",
            backgroundColor: "white",
            borderRadius: "5px",
            overflow: "hidden",
            minHeight: chartHeight,
          }}
          ref={refCallback}
        />
        <CarChartLegend
          sx={{ mt: 1, ml: 9 }}
          labelVariant="caption"
          items={finalDatasets.map((ds, idx) => ({
            label: ds.label ?? "",
            color: ds.color,
            datasetIndex: idx,
            dataIndex: 0,
            hidden: ds.hidden,
          }))}
        />
      </Box>
    </Box>
  );
});
