// Helpers for computing, coloring and formatting the ARGI, reduction and risk metrics.
import { Target } from '@resistapp/common/assays';
import { friendlyFoldChange } from '@resistapp/common/friendly';
import { getLog2FoldChange } from '@resistapp/common/statistics/fold-change';
import {
  getResistanceGeneIndexAndLevel,
  getResistanceLevel,
  ResistanceLevel,
} from '@resistapp/common/statistics/resistance-index';
import { getRiskScore } from '@resistapp/common/statistics/risk-score';
import { FullAbundance, MetricMode } from '@resistapp/common/types';
import { interpolateRgb, piecewise } from 'd3-interpolate';
import { isNil } from 'lodash';
import { theme } from '../components/shared/theme';
import { resistanceLevelMetadata } from '../data-utils/resistance-level';

/**
 * Numeric bounds for each metric mode, keyed by `MetricMode`.
 *
 * Within this file the REDUCTION `min`/`max` are used by `getMetricColor` to normalize a value
 * before it is passed to the color scale, and `getMetricTextColor` uses half of the RISK `max`
 * and half of the REDUCTION `min`/`max` as thresholds for switching between black and white text.
 *
 * Declared `as const`, so every number and tick label keeps its literal type.
 */
export const metricRange = {
  [MetricMode.ARGI]: {
    min: 0,
    max: 100,
  },
  [MetricMode.REDUCTION]: {
    min: -6,
    max: 6,
    // NOTE(review): optionalMin/optionalMax/mapLegendTicks are not read anywhere in this file;
    // presumably they drive a narrower legend axis elsewhere — confirm against the consumers.
    optionalMin: -5,
    optionalMax: 5,
    // Tick labels cover optionalMin..optionalMax in steps of 1.
    mapLegendTicks: ['-5', '-4', '-3', '-2', '-1', '0', '1', '2', '3', '4', '5'],
  },
  [MetricMode.RISK]: {
    min: 0,
    max: 100,
  },
} as const;

/**
 * Tuple of `[metric value, resistance level]`.
 *
 * Either element may be `null`. `getMetricAndLevel` only fills in the level for the ARGI mode;
 * for the other modes the second element is always `null`.
 */
export type MetricAndLevel = [number | null, ResistanceLevel | null];

/**
 * Computes the metric value (and, where applicable, the resistance level) for the given mode.
 *
 * - ARGI: resistance index and level from `getResistanceGeneIndexAndLevel(abundances, targets)`.
 * - RISK: risk score from `getRiskScore`, computed on `afterAbundances` when provided and on
 *   `abundances` otherwise; no level.
 * - Any other mode (REDUCTION): result of `getLog2FoldChange(abundances, afterAbundances, target)`;
 *   no level.
 *
 * @param abundances - Abundances of the (before) sample set.
 * @param afterAbundances - Optional abundances of the after sample set.
 * @param targets - Targets the metric is computed for.
 * @param metricMode - Which metric to compute.
 * @returns `[metric, level]`, where `level` is `null` for every mode except ARGI.
 */
export function getMetricAndLevel(
  abundances: FullAbundance[],
  afterAbundances: FullAbundance[] | undefined,
  targets: Target[],
  metricMode: MetricMode,
): MetricAndLevel {
  switch (metricMode) {
    case MetricMode.ARGI: {
      const indexAndLevel = getResistanceGeneIndexAndLevel(abundances, targets);
      return [indexAndLevel.resistanceIndex, indexAndLevel.resistanceLevel];
    }
    case MetricMode.RISK: {
      // Prefer the "after" samples when they exist, otherwise fall back to the only set we have.
      const riskInput = afterAbundances || abundances;
      return [getRiskScore(riskInput, targets), null];
    }
    default: {
      // A specific target is only forwarded when exactly one target is selected.
      const singleTarget = targets.length === 1 ? targets[0] : undefined;
      return [getLog2FoldChange(abundances, afterAbundances, singleTarget), null];
    }
  }
}

/**
 * Color interpolators for the metric modes that use a continuous color scale.
 * Each entry maps a normalized position `t` (0 at the start color, 1 at the end color) to a color
 * string. There is no ARGI entry: `getMetricColor` resolves ARGI colors through
 * `resistanceLevelMetadata` instead.
 */
const colorScales = {
  // Three-stop scale: reduction100 -> reduction0 -> reductionMinus100.
  // The return type is asserted because d3's `piecewise` typings return `(t: number) => any`.
  [MetricMode.REDUCTION]: piecewise(interpolateRgb, [
    theme.colors.reduction100,
    theme.colors.reduction0,
    theme.colors.reductionMinus100,
  ]) as (t: number) => string,
  // Two-stop scale: riskScore0 -> riskScore100.
  [MetricMode.RISK]: interpolateRgb(theme.colors.riskScore0, theme.colors.riskScore100),
};
/**
 * Returns the display color for a metric value in the given mode.
 *
 * - `null` values get the neutral color `theme.colors.neutral700`.
 * - ARGI values are mapped to a resistance level and use that level's color.
 * - REDUCTION and RISK values are normalized against `metricRange` for the mode and looked up on
 *   the matching continuous color scale.
 *
 * Values outside the mode's `[min, max]` range saturate at the scale's end colors. The d3
 * interpolators extrapolate linearly for positions outside `[0, 1]`, which would otherwise
 * produce colors that are not part of the palette.
 *
 * @param value - Metric value, or `null` when there is no value.
 * @param metricMode - Mode the value belongs to.
 * @returns A CSS color string.
 */
export function getMetricColor(value: number | null, metricMode: MetricMode): string {
  if (value === null) {
    return theme.colors.neutral700;
  }
  if (metricMode === MetricMode.ARGI) {
    const level = getResistanceLevel(value);
    return resistanceLevelMetadata[level].color;
  }
  // Read the bounds from the shared range table instead of repeating them here.
  const { min, max } = metricRange[metricMode];
  const normalizedValue = (value - min) / (max - min);
  // Clamp so out-of-range (including infinite) values map to the end colors of the scale.
  const clampedValue = Math.min(1, Math.max(0, normalizedValue));

  return colorScales[metricMode](clampedValue);
}

/**
 * Returns the text color (`'black'`, `'white'`, or the resistance level's `textColor`) to use for
 * a metric value in the given mode.
 *
 * - `null` values use black text.
 * - ARGI values use the `textColor` of their resistance level.
 * - RISK values use white text above half of the RISK maximum, black otherwise.
 * - REDUCTION values use white text above half of the REDUCTION maximum or below half of the
 *   REDUCTION minimum, black otherwise.
 *
 * @param value - Metric value, or `null` when there is no value.
 * @param metricMode - Mode the value belongs to.
 * @returns A CSS color string for the text.
 */
export function getMetricTextColor(value: number | null, metricMode: MetricMode): string {
  if (value === null) {
    return 'black';
  }

  if (metricMode === MetricMode.ARGI) {
    return resistanceLevelMetadata[getResistanceLevel(value)].textColor;
  }

  if (metricMode === MetricMode.RISK) {
    const riskThreshold = metricRange[MetricMode.RISK].max / 2;
    return value > riskThreshold ? 'white' : 'black';
  }

  const upperThreshold = metricRange[MetricMode.REDUCTION].max / 2;
  const lowerThreshold = metricRange[MetricMode.REDUCTION].min / 2;
  const isBeyondThreshold = value > upperThreshold || value < lowerThreshold;
  return isBeyondThreshold ? 'white' : 'black';
}

/**
 * Formats a metric value for display.
 *
 * - Missing values (`null`/`undefined`) are rendered as `'-'`.
 * - ARGI values are shown with one decimal.
 * - REDUCTION values are formatted by `friendlyFoldChange`.
 * - Any other mode is shown with no decimals.
 *
 * @param metric - Metric value, or `null` when there is no value.
 * @param metricMode - Mode the value belongs to.
 * @returns The formatted string.
 */
export function friendlyMetricValue(metric: number | null, metricMode: MetricMode): string {
  if (isNil(metric)) {
    return '-';
  }
  if (metricMode === MetricMode.ARGI) {
    return metric.toFixed(1);
  }
  if (metricMode === MetricMode.REDUCTION) {
    return friendlyFoldChange(metric);
  }
  return metric.toFixed(0);
}
