import React from 'react';
import { Bubble } from 'react-chartjs-2';

import ChartFilters, { type Filter } from './filters';

/**
 * Returns the smallest entry of `values` (or the largest, when `findMax` is
 * set), shifted by `offset` and rounded to one decimal place.
 *
 * An empty `values` array yields 0; the offset is not applied in that case.
 */
function findMinOrMaxValue({
  values,
  findMax,
  offset = 0,
}: {
  offset?: number;
  values: number[];
  findMax?: boolean;
}) {
  if (values.length === 0) {
    return 0;
  }

  // Pick the comparison once instead of branching on every element.
  const beats = findMax
    ? (candidate: number, best: number) => candidate > best
    : (candidate: number, best: number) => candidate < best;

  let best = values[0];
  for (const candidate of values.slice(1)) {
    if (beats(candidate, best)) {
      best = candidate;
    }
  }

  // Scale by 10, round, scale back: rounds to one decimal place.
  return Math.round((best + offset) * 10) / 10;
}

import {
  Chart as ChartJS,
  LinearScale,
  LogarithmicScale,
  PointElement,
  Tooltip,
  Legend,
  ChartOptions,
  TooltipOptions,
  LegendOptions,
} from 'chart.js';

// Chart.js components this bubble chart relies on: the linear (y) and
// logarithmic (x) scales, point elements, and the tooltip/legend plugins.
const registeredChartParts = [LinearScale, LogarithmicScale, PointElement, Tooltip, Legend];
ChartJS.register(...registeredChartParts);

/** One bubble on the chart, as produced by `formatData`. */
interface IChartItem {
  /** Model price (`pricePerToken`), plotted on the x axis. */
  x: number;
  /** Context adherence score, plotted on the y axis. */
  y: number;
  /** Bubble radius, computed by `calculateRadius`. */
  r: number;
  /** Text shown in the tooltip and, on wide wrappers, above the bubble. */
  label: string;
  /** The model's context length, shown in the tooltip. */
  contextLength: number;
  /**
   * Copied from the model by `formatData`; not read by the chart code in
   * this file, but declared so the type matches the points actually stored.
   */
  promptType?: string;
}

/**
 * Radius bounds per width breakpoint, checked in order: the first entry whose
 * `maxWidth` is strictly greater than the screen width wins.
 */
const radiusBoundsByMaxWidth: ReadonlyArray<{ maxWidth: number; min: number; max: number }> = [
  { maxWidth: 500, min: 4, max: 16 },
  { maxWidth: 600, min: 4, max: 18 },
  { maxWidth: 768, min: 5, max: 25 },
  { maxWidth: 1024, min: 6, max: 30 },
];

/** Bounds used when no breakpoint matches (widths of 1024 and above). */
const defaultRadiusBounds = { min: 7, max: 35 };

/**
 * Bubble radius for a model: `contextLength / 8`, clamped to bounds that
 * shrink on narrower screens.
 *
 * @param contextLength - the model's context length
 * @param screenWidth - available width used to pick the clamp bounds
 * @returns the clamped radius
 */
const calculateRadius = (contextLength: number, screenWidth: number) => {
  const { min, max } =
    radiusBoundsByMaxWidth.find(({ maxWidth }) => screenWidth < maxWidth) ?? defaultRadiusBounds;
  const unclamped = contextLength / 8;

  if (unclamped < min) return min;
  if (unclamped > max) return max;
  return unclamped;
};
/**
 * Recursively makes every property of `T` optional.
 *
 * Function-typed members (e.g. Chart.js tooltip callbacks) are kept as-is.
 * Functions are also `object`s, so recursing into them would map over their
 * (empty) key set and erase the call signature, which leaves callbacks
 * without a contextual parameter type.
 */
type DeepPartial<T> = {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- generic "any function" test
  [P in keyof T]?: T[P] extends (...args: any[]) => any
    ? T[P]
    : T[P] extends object
      ? DeepPartial<T[P]>
      : T[P];
};

/** Tooltip and legend options shared by every render of the bubble chart. */
const defaultPlugins: {
  tooltip: DeepPartial<TooltipOptions>;
  legend: DeepPartial<LegendOptions<any>>;
} = {
  tooltip: {
    callbacks: {
      label: (context) => {
        // Chart.js types `raw` as `unknown`; the points rendered here are the
        // objects built by formatData, which carry these fields.
        const raw = context.raw as { label: string; x: number; contextLength: number };
        return `${raw.label} - $${raw.x} - Context length: ${raw.contextLength}k`;
      },
    },
  },
  legend: {
    position: 'top' as const,
    labels: {
      padding: 20,
      font: {
        size: 16,
      },
    },
  },
};

/**
 * Builds the Chart.js options for the bubble chart.
 *
 * - x axis: logarithmic price axis from 0 up to the highest price plus 5
 *   (rounded to one decimal); only a fixed set of prices gets a tick label.
 * - y axis: linear score axis spanning the lowest/highest score, widened by
 *   `minYTickOffset` / `maxYTickOffset` (both default to 0), with 0.1 steps.
 *
 * @param xScaleDataset - x values (prices) of the plotted points
 * @param yScaleDataset - y values (scores) of the plotted points
 * @returns bubble chart options including the shared tooltip/legend plugins
 */
const getChartOptions = ({
  xScaleDataset,
  yScaleDataset,
  minYTickOffset = 0,
  maxYTickOffset = 0,
}: {
  maxYTickOffset?: number;
  minYTickOffset?: number;
  xScaleDataset: number[];
  yScaleDataset: number[];
}): ChartOptions<'bubble'> => {
  const yMin = findMinOrMaxValue({
    values: yScaleDataset,
    findMax: false,
    offset: minYTickOffset,
  });
  const yMax = findMinOrMaxValue({
    values: yScaleDataset,
    findMax: true,
    offset: maxYTickOffset,
  });
  const xMax = findMinOrMaxValue({
    values: xScaleDataset,
    findMax: true,
    offset: 5,
  });

  // Only these prices are labelled on the x axis; every other tick is drawn
  // with an empty label.
  const labelledPrices = new Set([0, 1, 5, 10, 15, 20, 25]);
  const formatPriceTick = (value: number) =>
    labelledPrices.has(value) ? `$${value.toFixed(1)}` : '';

  // One decimal place, except the topmost tick which is left unlabelled.
  const formatScoreTick = (value: number) =>
    value === yMax ? '' : parseFloat(`${value}`).toFixed(1);

  return {
    scales: {
      x: {
        type: 'logarithmic' as const,
        position: 'bottom' as const,
        title: {
          display: true,
          text: '$/M prompt tokens (log scale)',
          font: { size: 18, weight: 'bold' },
        },
        ticks: {
          font: { size: 16 },
          callback: formatPriceTick,
        },
        max: xMax,
        min: 0,
      },
      y: {
        type: 'linear',
        position: 'left',
        title: {
          display: true,
          text: 'Context Adherence Score',
          font: { size: 18, weight: 'bold' },
        },
        ticks: {
          stepSize: 0.1,
          font: { size: 16 },
          callback: formatScoreTick,
        },
        max: yMax,
        min: yMin,
      },
    },
    plugins: defaultPlugins as any,
    maintainAspectRatio: false,
  };
};

/**
 * Chart.js plugin that writes each point's `label` just above its bubble
 * once all datasets have been drawn. Datasets that are not visible are
 * skipped.
 */
const labelPlugin = {
  id: 'labelPlugin',
  afterDatasetsDraw(chart: any) {
    const ctx = chart.ctx;

    // Save/restore so the text styles set here do not leak into whatever is
    // drawn next on the same canvas context.
    ctx.save();
    try {
      ctx.fillStyle = 'black';
      ctx.textAlign = 'center';
      ctx.textBaseline = 'middle';
      ctx.font = '14px Arial';

      chart.data.datasets.forEach((dataset: any, datasetIndex: number) => {
        // isDatasetVisible (unlike reading meta.hidden directly) also honours
        // a dataset configured with `hidden: true`.
        if (!chart.isDatasetVisible(datasetIndex)) {
          return;
        }
        const meta = chart.getDatasetMeta(datasetIndex);
        meta.data.forEach((element: any, pointIndex: number) => {
          const datapoint = dataset.data[pointIndex];
          const { x, y } = element.getCenterPoint();
          // Place the text 5px above the top edge of the bubble (radius `r`).
          ctx.fillText(datapoint.label, x, y - datapoint.r - 5);
        });
      });
    } finally {
      ctx.restore();
    }
  },
};

/** A model as passed to the chart, before conversion into chart points. */
interface IModel {
  /** Display name; the bubble label when neither `label` nor `shortName` is set. */
  name: string;
  /** Used as the bubble label in preference to `name`. */
  shortName?: string;
  /** Used as the bubble label in preference to `shortName` and `name`. */
  label?: string;
  /** Price, plotted on the x axis (titled "$/M prompt tokens"). */
  pricePerToken: number;
  /** Plotted on the y axis ("Context Adherence Score"). */
  contextAdherenceScore: number;
  /** Drives bubble size; shown in the tooltip with a `k` suffix. */
  contextLength: number;
  /** Matched against the `type` filter. */
  type: string;
  /** Matched against the `responseCost` range filter. */
  responseCost: number;
  /** Matched against the `promptType` filter. */
  promptType?: string;
}

/**
 * Converts models into bubble chart points.
 *
 * The label prefers `label`, then `shortName`, then `name` (falsy values such
 * as an empty string fall through). The radius comes from `calculateRadius`
 * for the given width.
 *
 * @param models - models to plot
 * @param screenWidth - width passed to `calculateRadius`
 * @returns one point per model, in input order
 */
const formatData = (models: IModel[], screenWidth: number) =>
  models.map(
    ({ pricePerToken, contextAdherenceScore, contextLength, label, shortName, name, promptType }) => ({
      x: pricePerToken,
      y: contextAdherenceScore,
      r: calculateRadius(contextLength, screenWidth),
      label: label || shortName || name,
      contextLength,
      promptType,
    }),
  );

/** Props of the bubble chart component. */
type Props = {
  /** All models; the active filters select which of these are plotted. */
  models: IModel[];
  /** Passed through to `ChartFilters`. */
  enabledFilters: string[];
  /** Applied when the chart mounts and passed through to `ChartFilters`. */
  defaultSelectedChartFilters?: Filter[];
  /** Added to the lowest plotted score to get the y-axis minimum (default 0). */
  minYTickOffset?: number;
  /** Added to the highest plotted score to get the y-axis maximum (default 0). */
  maxYTickOffset?: number;
};

/**
 * Price ranges (compared against a point's `x`, i.e. `pricePerToken`) and the
 * legend suffixes for the three cost buckets. `low.value` and `normal.value`
 * are exclusive upper bounds. `high.value` is only informational: every point
 * at or above `normal.value` goes into the last bucket.
 */
const modelsByCostRange = {
  low: {
    value: 1,
    label: '(< $1.0)',
  },
  normal: {
    value: 3,
    label: '($1.0 - $3.0)',
  },
  high: {
    value: 3,
    label: '(> $3.0)',
  },
};

/** One Chart.js bubble dataset per cost bucket. */
type CostRangeDataset = {
  label: string;
  data: IChartItem[];
  backgroundColor: string;
};

/**
 * Returns whether `value` lies inside a `{ from, to }` range filter value,
 * both ends inclusive. A `null`/`undefined` bound is open-ended, so
 * `{ to: 3 }` means "at most 3" and `{ from: 3 }` means "at least 3".
 * Anything that is not a range object (including `null`) matches everything.
 */
function isWithinRange(value: number, range: unknown): boolean {
  if (range === null || typeof range !== 'object') {
    return true;
  }
  const { from, to } = range as { from?: number | null; to?: number | null };
  const aboveLowerBound = from === null || from === undefined || value >= from;
  const belowUpperBound = to === null || to === undefined || value <= to;
  return aboveLowerBound && belowUpperBound;
}

/**
 * Returns whether `model` passes every filter. Filters are AND-ed together;
 * within one filter, matching any of its values is enough. Filter names not
 * handled here are ignored.
 */
function matchesFilters(model: IModel, filters: Filter[]): boolean {
  return filters.every((filter) => {
    switch (filter.name) {
      case 'type':
        return filter.values.includes(model.type);
      case 'promptType':
        return filter.values.includes(model.promptType);
      case 'contextLength':
      case 'pricePerToken':
      case 'responseCost': {
        // Read the numeric field outside the callback, where `filter.name`
        // is still narrowed to one of the three keys above.
        const value = model[filter.name];
        return filter.values.some((range) => isWithinRange(value, range));
      }
      default:
        return true; // Ignore any filters not explicitly handled
    }
  });
}

/**
 * Splits chart points into the three cost-bucket datasets by price (`x`).
 * Within each dataset, points are ordered by radius, largest first —
 * presumably so smaller bubbles are painted over larger ones.
 */
function groupByCostRange(items: IChartItem[]): CostRangeDataset[] {
  const affordable: CostRangeDataset = {
    label: `Affordable Workhorses ${modelsByCostRange.low.label}`,
    data: [],
    backgroundColor: '#A2CFFF',
  };
  const midRange: CostRangeDataset = {
    label: `Mid-Range Challengers ${modelsByCostRange.normal.label}`,
    data: [],
    backgroundColor: '#8f74ff',
  };
  const pricey: CostRangeDataset = {
    label: `Pricey Top Dogs ${modelsByCostRange.high.label}`,
    data: [],
    backgroundColor: '#A03AA9',
  };

  const largestFirst = items.slice().sort((a, b) => b.r - a.r);
  for (const item of largestFirst) {
    if (item.x < modelsByCostRange.low.value) {
      affordable.data.push(item);
    } else if (item.x < modelsByCostRange.normal.value) {
      midRange.data.push(item);
    } else {
      pricey.data.push(item);
    }
  }

  return [affordable, midRange, pricey];
}

/**
 * Bubble chart of models: price on a log-scaled x axis, context adherence
 * score on the y axis, and bubble size from context length. Points are
 * coloured by cost bucket. A `ChartFilters` panel above the chart selects
 * which models are plotted.
 */
function Chart({
  models,
  enabledFilters,
  defaultSelectedChartFilters,
  maxYTickOffset,
  minYTickOffset,
}: Props) {
  const wrapperEl = React.useRef<HTMLDivElement>(null);
  const [items, setItems] = React.useState<IChartItem[]>([]);
  // Filters currently in effect, so a resize can rebuild the points without
  // discarding the user's selection.
  const activeFilters = React.useRef<Filter[]>(defaultSelectedChartFilters ?? []);

  // Width used to size bubbles. Falls back to the viewport width until the
  // wrapper is mounted (or when it reports a width of 0).
  const getWidth = () => wrapperEl.current?.clientWidth || window.innerWidth;

  const applyFilters = (filters: Filter[]) => {
    activeFilters.current = filters;
    const filteredModels = models.filter((model) => matchesFilters(model, filters));
    setItems(formatData(filteredModels, getWidth()));
  };

  // Bubble radii depend on the width, so rebuild the points on resize,
  // re-applying the active filters so the current selection is kept.
  const handleResize = () => {
    applyFilters(activeFilters.current);
  };

  React.useEffect(() => {
    // Apply the default selection once, on mount.
    applyFilters(defaultSelectedChartFilters ?? []);
  }, []);

  const chartOptions: ChartOptions<'bubble'> = {
    ...getChartOptions({
      xScaleDataset: items.map((i) => i.x),
      yScaleDataset: items.map((i) => i.y),
      maxYTickOffset,
      minYTickOffset,
    }),
    onResize: handleResize,
  };

  const data = {
    datasets: groupByCostRange(items),
  };

  // Bubble labels are skipped on narrow wrappers. The ref is read during
  // render, so this reflects the width as of the latest re-render. Before
  // the wrapper mounts the width is unknown, and labels stay enabled.
  const wrapperWidth = wrapperEl.current?.clientWidth;
  const hideBubbleLabels = wrapperWidth !== undefined && wrapperWidth < 768;

  return (
    <>
      <ChartFilters
        applyFilters={applyFilters}
        enabledFilters={enabledFilters}
        defaultSelectedChartFilters={defaultSelectedChartFilters}
      />
      <div className="relative my-2 w-full" ref={wrapperEl}>
        <Bubble
          data={data}
          options={chartOptions}
          plugins={hideBubbleLabels ? [] : [labelPlugin]}
          height={620}
        />
      </div>
    </>
  );
}

export default Chart;
