import React, { useState, useEffect, useMemo } from 'react';
import ReactApexChart from 'react-apexcharts';
import { Assessment } from "../../../enums/Assessment";
import {Typography} from "@mui/material";

const ConfusionMatrix = ({reviews}) => {
    const [chartData, setChartData] = useState(null);

    useEffect(() => {
        if (!reviews) return;

        const confusionMatrix = {
            [Assessment.VALID]: { [Assessment.VALID]: 0, [Assessment.UNFAIR]: 0, [Assessment.VOID]: 0 },
            [Assessment.UNFAIR]: { [Assessment.VALID]: 0, [Assessment.UNFAIR]: 0, [Assessment.VOID]: 0 },
            [Assessment.VOID]: { [Assessment.VALID]: 0, [Assessment.UNFAIR]: 0, [Assessment.VOID]: 0 },
        };

        const groupedReviews = reviews.reduce((acc, review) => {
            if (!acc[review.clause]) {
                acc[review.clause] = [];
            }
            acc[review.clause].push(review);
            return acc;
        }, {});

        for (const clause in groupedReviews) {
            const clauseReviews = groupedReviews[clause];
            const llmReview = clauseReviews.filter(review => review.author.kind === 'LLM').sort((a, b) => new Date(b.createdAt) - new Date(a.createdAt))[0];
            const humanReview = clauseReviews.filter(review => review.author.kind !== 'LLM').sort((a, b) => new Date(b.createdAt) - new Date(a.createdAt))[0];

            if (!llmReview) continue;

            const predictedLabel = llmReview.assessment;
            const trueLabel = humanReview ? humanReview.assessment : predictedLabel;

            confusionMatrix[trueLabel][predictedLabel]++;
        }

        const heatmapData = Object.keys(confusionMatrix).map(trueLabel => ({
            name: trueLabel,
            data: Object.keys(confusionMatrix[trueLabel]).map(predictedLabel => ({
                x: predictedLabel,
                y: confusionMatrix[trueLabel][predictedLabel] / Object.values(confusionMatrix[trueLabel]).reduce((sum, value) => sum + value, 0)
            })),
        })).reverse();

        setChartData(heatmapData);
    }, [reviews]);

    if (!chartData) {
        return <></>;
    }

    return (
        <div>
            <h2>Confusion Matrix</h2>
            <p>This chart shows how the predictions (columns) align with the true labels (rows).</p>
            {reviews.length > 0 ?
                <ReactApexChart options={{
                    colors: ["#008FFB"],
                    dataLabels: {
                        enabled: true,
                        formatter: (val) => `${(val * 100).toFixed(2)}%`,
                        style: {
                            colors: ['#000'],
                        },
                    },
                    xaxis: {
                        title: {
                            text: 'Predicted Labels',
                        },
                    },
                    yaxis: {
                        title: {
                            text: 'True Labels',
                        },
                    },
                    tooltip: {
                        y: {
                            formatter: (val) => `${(val * 100).toFixed(2)}%`,
                        },
                    },
                }} series={chartData} type="heatmap" height={350} /> : <Typography>Not enough data.</Typography>
            }
        </div>
    );
};

export default ConfusionMatrix;
