import { attach, restore, sample } from 'effector';
import { equals, includes, map } from 'ramda';
import { TFLiteModel } from '@tensorflow/tfjs-tflite';
import * as tf from '@tensorflow/tfjs-core';
import { $usedClassifiers, confidenceFx, tfLandmarks, _loadTfModelFx } from '.';
import { createUniversalEmbedding } from '../../utils/pose';

/**
 * Names of every pose classifier this module wires up.
 *
 * Each entry is passed to `createClassifier`, which requests `<entry>.tflite`
 * through `_loadTfModelFx` and tags its confidence reports with the entry as
 * `model`. The order of this list is the order of the default-exported
 * `classifiers` array, so reordering entries changes that array.
 *
 * `as const` keeps every entry as a string-literal type so `ModelName` below
 * can be derived from this list.
 */
const MODELS = [
  'BendForwardFlexible',
  'BendForwardSide',
  'BicycleCrunchLeft',
  'BicycleCrunchRight',
  'BurpeeDown',
  'BurpeeUp',
  'CrossFeetTapBackLeft',
  'CrossFeetTapBackRight',
  'CrossFeetTapFrontLeft',
  'CrossFeetTapFrontRight',
  'FeetTapLeft',
  'FeetTapRight',
  'GluteBridgeDown',
  'GluteBridgeUp',
  'HighKneeLeft',
  'HighKneeRight',
  'JumpingJackOut',
  'KickBackLeft',
  'KickBackMiddle',
  'KickBackRight',
  'KneePushDown',
  'KneePushUp',
  'LegRaiseDown',
  'LegRaiseUp',
  'LungeLeft',
  'LungeRight',
  'MountainClimberLeft',
  'MountainClimberRight',
  'Plank',
  'PunchLeft',
  'PunchRight',
  'PushDown',
  'PushUp',
  'RunningLeft',
  'RunningMiddle',
  'RunningRight',
  'SideLungeLeft',
  'SideLungeRight',
  'SitupDown',
  'SitupUp',
  'SquatDown',
  'StandingCrunchKneeLeft',
  'StandingCrunchKneeRight',
  'StandingCrunchLegLeft',
  'StandingCrunchLegRight',
  'StandingFront',
  'StandingFrontBoxing',
  'StandingFrontFlexible',
  'StandingFrontWide',
  'StandingFrontWithHandDown',
  'StandingHandDownOrBoxing',
  'StandingKickLeft',
  'StandingKickRight',
  'StandingSide',
  'StretchHands',
  'WalkOutDown',
  'WindmillLeft',
  'WindmillMiddle',
  'WindmillRight',
  // NOTE(review): these two break the otherwise alphabetical order. Moving
  // them would also move their classifiers within `classifiers`.
  'HandRaiseLeft',
  'HandRaiseRight',
] as const;

/**
 * Any classifier name from `MODELS`, plus `'NoAction'`.
 *
 * NOTE(review): `'NoAction'` is not in `MODELS`, so this module creates no
 * classifier (and requests no `.tflite` file) for it. Presumably it is a
 * sentinel used elsewhere for "no pose recognised" — confirm with callers.
 */
export type ModelName = 'NoAction' | typeof MODELS[number];

/**
 * Wires up a single pose classifier.
 *
 * - `loadModelFx` calls the shared `_loadTfModelFx` with `` `${name}.tflite` ``.
 *   Any params passed to `loadModelFx` itself are ignored.
 * - `$model` holds the last value produced by `loadModelFx`. It is `null`
 *   until the first successful load.
 * - While `name` is included in `$usedClassifiers` and `$model` is non-null,
 *   every `tfLandmarks` pose is embedded and run through the model. The first
 *   element of the output is then sent to `confidenceFx` as
 *   `{ confidence, model: name }`.
 * - If building the embedding or running the model throws, the error is
 *   logged and a confidence of `0` is reported instead.
 *
 * @param name - Classifier name; also the base name of the `.tflite` file.
 * @returns The model store and the effect that populates it.
 */
const createClassifier = (name: ModelName) => {
  const loadModelFx = attach({
    effect: _loadTfModelFx,
    mapParams: () => `${name}.tflite`,
    name,
  });

  const $model = restore(loadModelFx, null);

  const $classifierUsed = $usedClassifiers.map(includes(name));

  // Poses are forwarded only while this classifier is listed as in use.
  const usedPoseReceived = sample({
    clock: tfLandmarks,
    source: $classifierUsed,
    filter: equals(true),
    fn: (_, pose) => pose,
  });

  sample({
    clock: usedPoseReceived,
    source: $model,
    filter: (model): model is TFLiteModel => model !== null,
    fn: (model: TFLiteModel, pose) => {
      // tf.tidy disposes the tensors created inside the callback (embedding
      // and prediction); only the plain number escapes.
      const confidence = tf.tidy(() => {
        try {
          // Embedding construction is inside the try so that a failure here is
          // reported the same way as a prediction failure.
          const embed = createUniversalEmbedding(pose);
          const result = model.predict(embed) as tf.Tensor;

          // The first output element is the confidence. An empty output falls
          // back to 0 instead of forwarding `undefined`.
          return result.dataSync()[0] ?? 0;
        } catch (e) {
          console.error(`Classifier "${name}" failed to predict`, e);
          return 0;
        }
      });

      return { confidence, model: name };
    },
    target: confidenceFx,
    greedy: true,
  });

  return { $model, loadModelFx };
};

/**
 * Every classifier defined by this module: one `{ $model, loadModelFx }` pair
 * per entry of `MODELS`, in the same order as `MODELS`.
 */
const classifiers = map((modelName: ModelName) => createClassifier(modelName), MODELS);

export default classifiers;
