import { Target } from '@resistapp/common/assays';
import { getResistanceGeneIndex } from '@resistapp/common/statistics/resistance-index';
import {
  Aggregation,
  calcExperimentalRiskScore,
  geometricMean,
  getAbundancesInScope,
  getRiskScore,
  getValues,
  IndexVersion,
  Scope,
} from '@resistapp/common/statistics/risk-score';
import { FullAbundance } from '@resistapp/common/types';
import { replaceZerosWithLod } from '@resistapp/common/utils';
import { mean } from 'lodash';

/**
 * Collection of experimental index / risk score variants computed from one set of abundances.
 * Populated by `buildExperimentalMetrix`; every field is `undefined` when the underlying
 * calculation yields no value.
 */
export interface ExperimentalMetrics {
  /** `buildAGRI` with Scope.ANALYSED, IndexVersion.ARI and Aggregation.GEOMETRIC. */
  ARGI: number | undefined;
  /** `buildAGRI` with Scope.DETECTED, IndexVersion.ARI and Aggregation.ARITHMETIC. */
  ARI: number | undefined;
  /** `buildAGRI` with Scope.ANALYSED, IndexVersion.ARGI and Aggregation.GEOMETRIC. */
  ARGI_GEOMETRIC: number | undefined;
  /** `buildAGRI` with Scope.ANALYSED, IndexVersion.ARGI and Aggregation.QUADRATIC. */
  ARGI_QUADRATIC: number | undefined;
  /** `buildAGRI` with Scope.ANALYSED, IndexVersion.ARGI and Aggregation.ARITHMETIC. */
  ARGI_ARITHMETIC: number | undefined;
  /** `calcExperimentalRiskScore` with Scope.ANALYSED, IndexVersion.RISK and Aggregation.MAX. */
  RISK_AQI: number | undefined;
  /** `calcExperimentalRiskScore` with Scope.ANALYSED, IndexVersion.RISK and Aggregation.ARITHMETIC_MAX. */
  RISK_MEAN_OF_ANTIB_MAX: number | undefined;
  /** `calcExperimentalRiskScore` with Scope.DETECTED, IndexVersion.RISK and Aggregation.GEOMETRIC. */
  RISK_DETECTED: number | undefined;
  /** `calcExperimentalRiskScore` with Scope.ANALYSED, IndexVersion.RISK and Aggregation.ARITHMETIC. */
  RISK_ARITHMETIC: number | undefined;
  /** Result of `getRiskScore`, with a nullish result mapped to `undefined`. */
  RISK_KARINA: number | undefined;
}

/**
 * Computes every experimental index and risk score variant for the given abundances.
 *
 * @param abundances - Abundances the metrics are calculated from.
 * @param antibioticTargets - Optional targets forwarded unchanged to `calcExperimentalRiskScore`.
 * @returns An object holding one value (or `undefined`) per experimental metric.
 */
export function buildExperimentalMetrix(
  abundances: FullAbundance[],
  antibioticTargets?: Target[],
): ExperimentalMetrics {
  // All index variants share the same abundances; only scope / version / aggregation differ.
  const indexFor = (scope: Scope, version: IndexVersion, aggregation: Aggregation) =>
    buildAGRI(abundances, scope, version, aggregation);

  // All experimental risk variants share the abundances, the RISK version and the targets.
  const riskFor = (scope: Scope, aggregation: Aggregation) =>
    calcExperimentalRiskScore(abundances, scope, IndexVersion.RISK, aggregation, antibioticTargets);

  // The values are computed in the same order as the fields appear in the result.
  const argi = indexFor(Scope.ANALYSED, IndexVersion.ARI, Aggregation.GEOMETRIC);
  // NOTE: the custom implementations of the other versions are not looking at the same set of genes
  const ari = indexFor(Scope.DETECTED, IndexVersion.ARI, Aggregation.ARITHMETIC);
  const argiGeometric = indexFor(Scope.ANALYSED, IndexVersion.ARGI, Aggregation.GEOMETRIC);
  const argiQuadratic = indexFor(Scope.ANALYSED, IndexVersion.ARGI, Aggregation.QUADRATIC);
  const argiArithmetic = indexFor(Scope.ANALYSED, IndexVersion.ARGI, Aggregation.ARITHMETIC);
  const riskAqi = riskFor(Scope.ANALYSED, Aggregation.MAX);
  const riskMeanOfAntibMax = riskFor(Scope.ANALYSED, Aggregation.ARITHMETIC_MAX);
  const riskDetected = riskFor(Scope.DETECTED, Aggregation.GEOMETRIC);
  const riskArithmetic = riskFor(Scope.ANALYSED, Aggregation.ARITHMETIC);
  const riskKarina = getRiskScore(abundances, undefined) ?? undefined;

  return {
    ARGI: argi,
    ARI: ari,
    ARGI_GEOMETRIC: argiGeometric,
    ARGI_QUADRATIC: argiQuadratic,
    ARGI_ARITHMETIC: argiArithmetic,
    RISK_AQI: riskAqi,
    RISK_MEAN_OF_ANTIB_MAX: riskMeanOfAntibMax,
    RISK_DETECTED: riskDetected,
    RISK_ARITHMETIC: riskArithmetic,
    RISK_KARINA: riskKarina,
  };
}

/**
 * Computes one experimental variant of the resistance gene index.
 *
 * The abundances are first narrowed to the requested scope (relative values), then
 * aggregated according to the requested index version and aggregation:
 *
 * - ARI + ARITHMETIC + DETECTED: delegates to `getResistanceGeneIndex`.
 * - ARI + GEOMETRIC + ANALYSED: `5 + mean(log10(values))`, zeros counted as `-5` in log space.
 * - ARGI + GEOMETRIC: `geometricMean` of the log10 values (zeros counted as `-5`).
 * - ARGI + QUADRATIC: `quadraticMeanInLogSpace` of the values.
 * - ARGI + ARITHMETIC: `5 + log10(mean(values))`.
 *
 * @param abundances - Abundances to calculate the index from.
 * @param scope - Which abundances are taken into account.
 * @param index - Index version to compute.
 * @param aggregation - How the in-scope values are aggregated.
 * @returns The index value, or `undefined` when `getResistanceGeneIndex` yields a falsy result.
 * @throws Error when the combination of `index`, `aggregation` and `scope` is not supported.
 */
export function buildAGRI(
  abundances: FullAbundance[],
  scope: Scope,
  index: IndexVersion,
  aggregation: Aggregation,
): number | undefined {
  const inScopeAbundances = getAbundancesInScope(abundances, 'relative', scope, true);
  const valuesInScope = getValues(inScopeAbundances, 'relative', scope);
  if (index === IndexVersion.ARI && aggregation === Aggregation.ARITHMETIC && scope === Scope.DETECTED) {
    // `||` is kept on purpose: any falsy result (0, NaN, null, undefined) becomes undefined.
    return getResistanceGeneIndex(inScopeAbundances, undefined) || undefined;
  } else if (index === IndexVersion.ARI && aggregation === Aggregation.GEOMETRIC && scope === Scope.ANALYSED) {
    // Arithmetic mean in log10 space, shifted by +5 (zeros are counted as -5).
    return 5 + mean(toLogSpace(valuesInScope, -5));
  } else if (index === IndexVersion.ARGI && aggregation === Aggregation.GEOMETRIC) {
    return geometricMean(replaceZerosWithLod(toLogSpace(valuesInScope, -5)));
  } else if (index === IndexVersion.ARGI && aggregation === Aggregation.QUADRATIC) {
    return quadraticMeanInLogSpace(valuesInScope);
  } else if (index === IndexVersion.ARGI && aggregation === Aggregation.ARITHMETIC) {
    // log10 of the arithmetic mean, shifted by +5.
    return 5 + Math.log10(mean(valuesInScope));
  } else {
    // Name the offending arguments so an unsupported combination can be diagnosed from the error alone.
    throw new Error(
      `Unsupported experimental index combination: index=${String(index)}, ` +
        `aggregation=${String(aggregation)}, scope=${String(scope)}`,
    );
  }
}

/**
 * Converts each value to its base-10 logarithm.
 *
 * @param values - Values to convert.
 * @param zeroReplacement - Result used for falsy values (0 or NaN), for which the logarithm is not taken.
 * @returns A new array with one log-space entry per input value.
 */
function toLogSpace(values: number[], zeroReplacement: number): number[] {
  const logged: number[] = [];
  for (const value of values) {
    logged.push(value ? Math.log10(value) : zeroReplacement);
  }
  return logged;
}

/**
 * Computes the quadratic mean (root mean square) of the base-10 logarithms of the values.
 * Falsy values (0 or NaN) are counted as -5 in log space.
 *
 * @param values - Values to aggregate.
 * @returns The root mean square of the log10 values, or 0 for an empty input.
 */
export function quadraticMeanInLogSpace(values: number[]): number {
  const count = values.length;
  if (count === 0) {
    return 0;
  }
  let squaredTotal = 0;
  for (const raw of values) {
    const logged = raw ? Math.log10(raw) : -5;
    squaredTotal += logged ** 2;
  }
  return Math.sqrt(squaredTotal / count);
}
