import React, { Fragment, useContext, useMemo, useState } from "react"

import { useQuery } from "react-query"
import { useNavigate, useParams } from "react-router-dom"

import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined"
import { Box, CircularProgress, Skeleton } from "@mui/material"
import Grid from "@mui/material/Grid"
import { DateRangePicker, Tooltip } from "@synapse-analytics/synapse-ui"
import { AxiosError, AxiosResponse } from "axios"
import moment from "moment"

import { AccuracyCard } from "../../../../components/AccuracyCard"
import { KonanEmptyState } from "../../../../components/KonanEmptyState"
import { KonanPageHeader } from "../../../../components/KonanPageHeader"
import { ConfusionMatrix } from "../../../../components/graphs/ConfusionMatrix"
import { CoveragePieChart } from "../../../../components/graphs/CoveragePieChart"
import { MetricsBarChart } from "../../../../components/graphs/MetricsBarChart"
import { useDateQuery } from "../../../../hooks/useDateQuery"
import { KonanAPI } from "../../../../services/KonanAPI"
import { CurrentProjectAndModelContext } from "../../../../store/CurrentProjectAndModelContext"
import { ClassificationEvaluation } from "../../../../types/generated/api/ClassificationEvaluation"
import { DeploymentLatestFeedbackDateTime } from "../../../../types/generated/api/DeploymentLatestFeedbackDateTime"
import { DeploymentPredictionsSummary } from "../../../../types/generated/api/DeploymentPredictionsSummary"
import { ListClassificationModelOutputs } from "../../../../types/generated/api/ListClassificationModelOutputs"
import { PaginatedListClassificationModelOutputsList } from "../../../../types/generated/api/PaginatedListClassificationModelOutputsList"
import { customColorScheme } from "../../../../utils/genericHelpers"
import {
  convertCoverageToPieChart,
  convertDataToConfusionMatrix,
  getModelPredictionsCoverageState,
  getModelPredictionsGraphsState,
} from "../../../../utils/modelDetailsHelpers"

import styles from "./ClassificationMetrics.module.scss"

/**
 * Route parameters read via `useParams` on this page.
 * `id` is the project UUID segment (sent to the API as `project_uuid`).
 */
type ParamsType = Record<"id", string>

type CoverageBlockType = {
  name: string
  predictedValue: number
  actualValue: number
  colorIndex: number
}

/**
 * Renders one legend entry for a classification label: a colour swatch, the label name in bold,
 * and its occurrence counts formatted as "<predicted> / <actual>".
 */
function CoverageBlock({ name, predictedValue, actualValue, colorIndex }: CoverageBlockType): React.ReactElement {
  // Swatch colour is taken from the shared `customColorScheme()` palette.
  const swatchColor = customColorScheme()[colorIndex]

  return (
    <Grid item>
      <Grid container justifyContent="center" alignItems="center" spacing={1}>
        <Grid item>
          <span key="chip" className={styles.coverageBlock} style={{ background: swatchColor }} />
        </Grid>
        <Grid item className={styles.coverageBlockName}>
          <strong key="name">{name}</strong>
        </Grid>
        <Grid item className={styles.coverageBlockValueContainer}>
          <span key="value1" className={styles.coverageBlockValue}>{`${predictedValue} / `}</span>
          <span key="value2" className={styles.coverageBlockValue}>{actualValue}</span>
        </Grid>
      </Grid>
    </Grid>
  )
}

/** Default height (px) of every metric card on this page. */
const GRAPH_HEIGHT = 300
/** Height (px) of a per-label bar chart before any label rows are added. */
const METRICS_BASE_HEIGHT = 50
/** Height (px) added to a per-label bar chart for each classification label. */
const LABEL_ROW_HEIGHT = 40

/** Per-label bar charts rendered in the last row of cards, in display order. */
const LABEL_METRIC_CHARTS = [
  { title: "F1 Score", graphKey: "f1_score" },
  { title: "PRECISION", graphKey: "precision" },
  { title: "RECALL", graphKey: "recall" },
] as const

/**
 * Classification metrics page for the model selected in `CurrentProjectAndModelContext`.
 *
 * Renders a page header with a date-range picker, then one of:
 * - a spinner until the active-configuration query settles,
 * - an empty state linking to the configuration page when there is no active configuration,
 * - a per-label "predictions / actuals" legend followed by accuracy, coverage, confusion-matrix
 *   and F1 / precision / recall cards for the selected date range.
 */
export function ClassificationMetrics(): React.ReactElement {
  // `id` can be undefined when the route does not supply it, so the queries below guard on truthiness.
  const { id: projectId } = useParams<ParamsType>()
  const navigate = useNavigate()

  const [startDate, setStartDate, endDate, setEndDate] = useDateQuery()

  // Starts as true so the page opens in a loading state; cleared when the active-config query settles.
  // NOTE(review): `onSettled` only runs once that query is enabled and executes. If it never becomes
  // enabled (e.g. a "generic" project), the spinner stays up. Confirm this page is unreachable then.
  const [isActiveConfigLoading, setIsActiveConfigLoading] = useState<boolean>(true)

  const { currentProject, currentModel } = useContext(CurrentProjectAndModelContext)

  const { data: activeConfiguration } = useQuery<AxiosResponse<ListClassificationModelOutputs>, AxiosError>(
    ["activeConfig", currentModel],
    () => KonanAPI.fetchActiveModelConfiguration(currentModel),
    {
      enabled: currentModel !== "" && !!currentProject && currentProject.type !== "generic",
      retry: false,
      onSettled: () => {
        setIsActiveConfigLoading(false)
      },
    },
  )

  // Evaluation for the selected range: labels, confusion matrix, coverage, accuracy and counts.
  const { data: evaluationData, isLoading } = useQuery<AxiosResponse<ClassificationEvaluation>, AxiosError>(
    ["classification-evaluation", startDate, endDate, currentModel],
    () =>
      KonanAPI.fetchModelClassificationEvaluationMetrics(
        startDate?.toISOString() as string,
        endDate?.toISOString() as string,
        currentModel,
      ),
    { enabled: currentModel !== "" && startDate != null && endDate != null },
  )

  // Request summary from the project's creation date until now. Its `successful` count tells whether
  // the project has ever had successful predictions, which feeds the empty states below.
  // NOTE(review): the key includes startDate/endDate even though the request ignores them, so changing
  // the range refetches this summary. Confirm that is intended.
  const { isLoading: isApiRequestsLoading, data: apiRequests } = useQuery<
    AxiosResponse<DeploymentPredictionsSummary>,
    AxiosError
  >(
    ["apiRequests", projectId, startDate, endDate],
    () =>
      KonanAPI.fetchApiRequestsSummary({
        project_uuid: projectId as string,
        start_date: currentProject?.created_at as string,
        // TODO:: check the requirements for this condition
        // Evaluated when the request runs, not when the component last rendered.
        end_date: moment().toISOString(),
      }),
    {
      enabled: !!projectId && currentProject?.created_at != null && endDate != null,
    },
  )

  // Model outputs for the selected range (predictions table data); only `count` is read on this page.
  const { isLoading: isModelPredictionsLoading, data: modelPredictions } = useQuery<
    AxiosResponse<PaginatedListClassificationModelOutputsList>,
    AxiosError
  >(
    ["classification-predictions", startDate, endDate, currentModel],
    () =>
      KonanAPI.fetchModelOutputPredictionsData(
        startDate?.toISOString() as string,
        endDate?.toISOString() as string,
        currentModel,
      ),
    {
      keepPreviousData: true,
      enabled: !!currentModel && !!startDate && !!endDate,
    },
  )

  // Latest feedback timestamp for the project; whether one exists feeds the metrics empty state.
  const { isLoading: isLastFeedbackLoading, data: lastFeedbackResponse } = useQuery<
    AxiosResponse<DeploymentLatestFeedbackDateTime>,
    AxiosError
  >(["lastFeedback", projectId], () => KonanAPI.fetchLastFeedback(projectId as string), {
    enabled: !!projectId,
  })

  const isDataLoading = isLoading || isModelPredictionsLoading || isLastFeedbackLoading || isApiRequestsLoading

  // Cards also stay in a loading state while a start date is picked but the end date is still missing.
  const isGraphLoading = isDataLoading || !!(startDate && !endDate)

  // Inclusive number of days in the selected range, shared by every card.
  // NOTE(review): a same-day range has a diff of 0 and (as before) yields `undefined` rather than 1.
  const rangeDiffInDays = endDate?.diff(startDate, "days")
  const rangeInDays = rangeDiffInDays ? rangeDiffInDays + 1 : undefined

  // Sort labels by name so every card lists them in the same order.
  const sortedLabels = useMemo(() => {
    const labels = evaluationData?.data.labels
    return labels && labels.length ? labels.toSorted((a, b) => a.name.localeCompare(b.name)) : []
  }, [evaluationData])

  // Empty-state descriptor for the accuracy, confusion-matrix and per-label metric cards.
  const metricsGraphsState = useMemo(() => {
    if (evaluationData?.data && apiRequests?.data && modelPredictions?.data) {
      const latestFeedback = lastFeedbackResponse?.data.latest_feedback_date_time
      const hasFeedback = latestFeedback != null && latestFeedback !== ""

      return getModelPredictionsGraphsState(
        apiRequests.data.successful ?? 0,
        modelPredictions.data.count ?? 0,
        evaluationData.data.observations_count,
        evaluationData.data.feedbacks_count,
        hasFeedback,
      )
    }
    return undefined
  }, [apiRequests, evaluationData, lastFeedbackResponse, modelPredictions])

  // Empty-state descriptor for the coverage card.
  const coverageGraphState = useMemo(() => {
    if (evaluationData?.data && apiRequests?.data && modelPredictions?.data) {
      return getModelPredictionsCoverageState(
        apiRequests.data.successful ?? 0,
        modelPredictions.data.count ?? 0,
        evaluationData.data.observations_count,
      )
    }
    return undefined
  }, [evaluationData, modelPredictions, apiRequests])

  // Confusion matrix + labels converted into the heatmap format expected by `ConfusionMatrix`.
  const confusionMatrixData = useMemo(() => {
    const matrix = evaluationData?.data.confusion_matrix
    return matrix && matrix.length ? convertDataToConfusionMatrix(matrix, sortedLabels) : []
  }, [evaluationData, sortedLabels])

  // Labels converted into the pie-chart format expected by `CoveragePieChart`.
  const coveragePieChartData = useMemo(() => {
    if (sortedLabels.length && evaluationData?.data) {
      return convertCoverageToPieChart(sortedLabels, evaluationData.data.observations_count ?? 1)
    }
    return []
  }, [evaluationData, sortedLabels])

  // Bar charts grow with the number of labels once they have data to draw.
  const barChartHeight =
    sortedLabels.length > 0 && metricsGraphsState?.showGraphs
      ? METRICS_BASE_HEIGHT + sortedLabels.length * LABEL_ROW_HEIGHT
      : GRAPH_HEIGHT

  const paletteSize = customColorScheme().length

  return (
    <Grid container spacing={2}>
      <KonanPageHeader
        title="Classification Metrics"
        subtitle={
          <Fragment>
            Insights into your classification model's performance.
            <br />
            Metrics are automatically calculated using the feedback provided to the model.
          </Fragment>
        }
        actions={[
          // Elements passed in an array need a stable key (presumably rendered as a list by the header).
          <DateRangePicker
            key="date-range-picker"
            startDate={startDate}
            endDate={endDate}
            onStartDateChange={setStartDate}
            onEndDateChange={setEndDate}
            disableFuture
          />,
        ]}
      />

      {isActiveConfigLoading ? (
        <Grid container direction="column" justifyContent="center" alignItems="center" spacing={1}>
          <Box mt={3} />
          <CircularProgress size={36} />
          <Box mt={1} />
        </Grid>
      ) : !activeConfiguration ? (
        <KonanEmptyState
          setAction={() => navigate(`/projects/${projectId}/decision-engines?page=Configuration`)}
          title="Classification settings have not been configured yet."
          subTitle={"Please configure your model."}
          buttonText={"Configure"}
        />
      ) : (
        <Fragment>
          {/* Per-label legend: colour swatch, name and predicted / actual occurrence counts. */}
          <Grid item xs={12}>
            <Grid container direction="row-reverse" justifyContent="flex-end" alignItems="flex-start" spacing={2}>
              {isDataLoading ? (
                <Grid item>
                  <Skeleton animation="wave" width={320} height={18} />
                </Grid>
              ) : sortedLabels.length > 0 ? (
                <Fragment>
                  <Grid item>
                    <Tooltip title="Predictions / Actuals" placement="top" display="auto" verticalAlign="middle">
                      <InfoOutlinedIcon fontSize="inherit" />
                    </Tooltip>
                  </Grid>

                  {sortedLabels.map((item, index) => (
                    <CoverageBlock
                      key={item.name}
                      name={item.name}
                      predictedValue={item.predicted_occurances}
                      actualValue={item.actual_occurances}
                      colorIndex={index % paletteSize}
                    />
                  ))}
                </Fragment>
              ) : null}
            </Grid>
          </Grid>

          {/* Model Output section */}
          <Grid item xs={12}>
            <Grid container spacing={2}>
              <Grid item xs={12} md={6} lg={3}>
                <AccuracyCard
                  accuracy={metricsGraphsState?.showGraphs ? evaluationData?.data.accuracy ?? null : null}
                  isLoading={isGraphLoading}
                  cardHeight={GRAPH_HEIGHT}
                  emptyState={metricsGraphsState}
                  range={rangeInDays}
                />
              </Grid>

              <Grid item xs={12} md={6} lg={3}>
                <CoveragePieChart
                  title="COVERAGE"
                  data={coverageGraphState?.showGraphs ? coveragePieChartData : []}
                  totalCoveragePercentage={evaluationData?.data.coverage ?? 0}
                  graphHeight={GRAPH_HEIGHT}
                  isLoading={isGraphLoading}
                  emptyState={coverageGraphState}
                  range={rangeInDays}
                />
              </Grid>

              <Grid item xs={12} lg={6}>
                <ConfusionMatrix
                  title="CONFUSION MATRIX"
                  data={metricsGraphsState?.showGraphs ? confusionMatrixData : []}
                  graphHeight={GRAPH_HEIGHT}
                  isLoading={isGraphLoading}
                  emptyState={metricsGraphsState}
                  range={rangeInDays}
                />
              </Grid>

              {LABEL_METRIC_CHARTS.map(({ title, graphKey }) => (
                <Grid item xs={12} md={4} key={graphKey}>
                  <MetricsBarChart
                    title={title}
                    data={metricsGraphsState?.showGraphs ? sortedLabels : []}
                    graphKey={graphKey}
                    graphHeight={barChartHeight}
                    isLoading={isGraphLoading}
                    emptyState={metricsGraphsState}
                    range={rangeInDays}
                  />
                </Grid>
              ))}
            </Grid>
          </Grid>
        </Fragment>
      )}
    </Grid>
  )
}
