import '@tensorflow/tfjs-backend-webgl';
import * as tflite from '@tensorflow/tfjs-tflite';
import * as tf from '@tensorflow/tfjs-core';
import {
  combine,
  createEffect,
  createEvent,
  createStore,
  fromObservable,
  restore,
  Store,
} from 'effector';
import debug from 'debug';
import { bodyInFrame } from '../../utils/pose';
import { prop, objOf, mergeRight, assoc, pick } from 'ramda';
import { map as map$, Subject, groupBy, mergeMap, scan } from 'rxjs';
import {
  createDetector,
  Pose,
  PoseDetector,
  SupportedModels,
} from '@tensorflow-models/pose-detection';
import { throttle } from 'patronum';
import { getPoseConfig } from './config';
import { ModelName } from './classifiers';

// Debug loggers, each enabled individually by its `debug` namespace.
const logger = debug('pose');
const landmarkLogger = debug('pose:landmarks');
const confidenceLogger = debug('pose:confidence');

/**
 * CDN directory the TFLite runtime loads its WASM binaries from.
 *
 * NOTE(review): pinned to 0.0.1-alpha.9 — presumably this has to stay in sync
 * with the installed `@tensorflow/tfjs-tflite` package; confirm when upgrading.
 */
const TFLITE_WASM_PATH =
  'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite@0.0.1-alpha.9/dist/';

tflite.setWasmPath(TFLITE_WASM_PATH);

/** One confidence reading reported for a single classifier model. */
type Estimations = {
  /** Model that produced the reading. */
  model: ModelName;
  /** Confidence value reported for that model. */
  confidence: number;
};

/** Readings pushed by `confidenceFx` and folded by `agregatedConfidence$`. */
const confidence$: Subject<Estimations> = new Subject();

/**
 * Prepares TensorFlow.js and builds the BlazePose detector.
 *
 * Enables production mode, applies the flags from `getPoseConfig()`, switches
 * to the WebGL backend and waits for TF to be ready before creating the
 * detector with the configured model options.
 *
 * @returns the created pose detector
 */
const initializeRuntime = async (): Promise<PoseDetector> => {
  logger('initializing runtime');
  const poseConfig = getPoseConfig();

  tf.enableProdMode();
  tf.env().setFlags(poseConfig.flags);

  await tf.setBackend('webgl');
  await tf.ready();

  const detector = await createDetector(
    SupportedModels.BlazePose,
    poseConfig.model
  );
  logger('runtime initialized');

  return detector;
};

/**
 * Requests a video-only camera stream from the user-facing camera.
 *
 * The configured `videoInput` constraints are spread after `facingMode`, so
 * they take precedence if they also set it.
 *
 * @returns the stream granted by `getUserMedia`
 */
const getMediaStream = async (): Promise<MediaStream> =>
  navigator.mediaDevices.getUserMedia({
    audio: false,
    video: {
      facingMode: 'user',
      ...getPoseConfig().videoInput,
    },
  });

/** Stores the given element as `$pose.videoElement`. */
export const setVideoElement = createEvent<HTMLVideoElement>();

/** Camera-initialization trigger (its wiring is not part of this section). */
export const initializeCamera = createEvent<void>('initialize_camera');

/** Camera-release trigger (its wiring is not part of this section). */
export const releaseMediaStream = createEvent<void>('release_media_stream');

/** Detector-initialization trigger (its wiring is not part of this section). */
export const initializePoseDetector = createEvent<void>(
  'initialize_pose_detector'
);

/**
 * Marks the prediction loop as running in `$pose`.
 * The boolean payload is stored as `mirrorLandmarks`; calling the event with
 * no argument stores `false`.
 */
export const startPredictionLoop = createEvent<void>(
  'start_prediction_loop'
).prepend<boolean>((mirrorLandmarks = false) => mirrorLandmarks);

/** Marks the prediction loop as stopped in `$pose`. */
export const stopPredictionLoop = createEvent<void>('stop_prediction_loop');

/** Fired by `predictionLoopFx` on the next animation frame with its timestamp. */
export const predictionLoop = createEvent<number>('prediction_loop');

/** Classifier selection (no subscriber is visible in this section). */
export const useClassifiers = createEvent<string[]>('use_classifiers');

/** Emits each pose detected by `predictionLoopFx`. */
export const tfLandmarks = createEvent<Pose>();

/**
 * Loads a TFLite model served under `/models/` using a single thread.
 *
 * @param filename file name of the model inside `/models/`
 * @returns the loaded TFLite model
 */
export const _loadTfModelFx = createEffect(async (filename: string) => {
  logger('loading model', filename);

  const modelUrl = `/models/${filename}`;
  const loadOptions = { numThreads: 1 };

  return await tflite.loadTFLiteModel(modelUrl, loadOptions);
});

/**
 * Logs a classifier's confidence and publishes it on `confidence$`, which
 * feeds `agregatedConfidence$`.
 */
export const confidenceFx = createEffect(
  ({ model, confidence }: Estimations) => {
    confidenceLogger(model, confidence);

    const reading: Estimations = { model, confidence };
    confidence$.next(reading);
  }
);

/**
 * Opens the camera, attaches the stream to `videoElement` and starts playback.
 *
 * Resolves once the element fires `loadeddata` for the new source.
 *
 * @param videoElement element that should display the camera stream
 * @returns the acquired camera stream, so it can be released later
 */
export const initializeCameraFx = createEffect(
  async (videoElement: HTMLVideoElement) => {
    logger('initializing camera');
    const stream = await getMediaStream();

    // Subscribe before attaching the source; `once` removes the listener after
    // the first `loadeddata`, so no manual cleanup is needed.
    const dataLoaded = new Promise<void>((resolve) => {
      videoElement.addEventListener('loadeddata', () => resolve(), {
        once: true,
      });
    });

    videoElement.srcObject = stream;

    // play() returns a promise that rejects (e.g. when autoplay is blocked).
    // Log it instead of leaving an unhandled rejection.
    void videoElement.play().catch((error: unknown) => {
      logger('video playback failed', error);
    });

    await dataLoaded;

    return stream;
  }
);

/**
 * A `<video>` element (created via `document.createElement`) that also plays
 * the camera stream produced by `initializeCameraFx`.
 *
 * The reducer configures the element in place:
 * - size from `getPoseConfig().videoInput`
 * - stream attached, muted, inline playback
 * - playback started
 *
 * It returns the same element, so the store value itself never changes.
 */
export const $rawVideoElement = createStore(document.createElement('video')).on(
  initializeCameraFx.doneData,
  (video, stream) => {
    const { videoInput } = getPoseConfig();

    video.width = videoInput.width;
    video.height = videoInput.height;
    video.srcObject = stream;
    video.muted = true;
    video.playsInline = true;

    // play() rejects e.g. when autoplay is blocked; log rather than leaving
    // an unhandled promise rejection.
    void video.play().catch((error: unknown) => {
      logger('raw video playback failed', error);
    });

    // Same reference => effector does not emit a store update.
    return video;
  }
);

/** Stops every track of the given stream (which turns the camera off). */
export const releaseMediaStreamFx = createEffect((stream: MediaStream) => {
  logger('releasing camera');

  for (const track of stream.getTracks()) {
    track.stop();
  }
});

/** Initializes the TF runtime and resolves with the BlazePose detector. */
export const initializePoseDetectorFx = createEffect(
  (): Promise<PoseDetector> => initializeRuntime()
);

/** Calls `reset()` on the given detector. */
export const resetPoseDetectorFx = createEffect(
  (poseDetector: PoseDetector): void => {
    poseDetector.reset();
  }
);

/**
 * 2D context of a canvas sized to the configured video input.
 *
 * The context is pre-scaled by (-1, 1), so images drawn at a negative x
 * offset appear horizontally mirrored.
 *
 * @throws Error when the canvas cannot provide a 2D context.
 */
export const $flipCanvasCtx = createStore(document.createElement('canvas')).map(
  (canvas) => {
    const { width, height } = getPoseConfig().videoInput;
    canvas.width = width;
    canvas.height = height;

    const context = canvas.getContext('2d');
    if (!context) {
      throw new Error('Error getting render context');
    }

    context.scale(-1, 1);
    return context;
  }
);

/**
 * Runs one pose-estimation step and schedules the next one.
 *
 * Steps:
 * 1. Draws the current video frame into `flipCtx` at x = -width. With a
 *    context pre-scaled by (-1, 1) this produces a mirrored copy of the frame.
 * 2. Estimates at most one pose on that canvas, using `Date.now()` as the
 *    frame timestamp.
 * 3. Logs a detected pose and emits it through `tfLandmarks`. Estimation
 *    errors are logged and swallowed.
 * 4. Always fires `predictionLoop` with the same timestamp on the next
 *    animation frame.
 *
 * NOTE(review): the 4th (number) and 5th (boolean) tuple elements are never
 * read. The boolean — presumably `mirrorLandmarks` — does not affect
 * processing here; confirm that is intended. Stopping the loop presumably
 * happens wherever `predictionLoop` is wired to this effect (not shown here).
 */
export const predictionLoopFx = createEffect(
  async ([detector, videoElement, flipCtx]: [
    PoseDetector,
    HTMLVideoElement,
    CanvasRenderingContext2D,
    number,
    boolean
  ]) => {
    const frameTimestamp = Date.now();
    const { canvas } = flipCtx;
    const { width, height } = canvas;

    flipCtx.drawImage(videoElement, -width, 0, width, height);

    try {
      const [pose] = await detector.estimatePoses(
        canvas,
        { flipHorizontal: false, maxPoses: 1 },
        frameTimestamp
      );

      if (pose) {
        landmarkLogger('landmarks', pose);
        tfLandmarks(pose);
      }
    } catch (error) {
      logger(error);
    }

    requestAnimationFrame(() => predictionLoop(frameTimestamp));
  }
);

/** State of the camera / pose-detection pipeline. */
export type $Pose = {
  /** Becomes true once `initializePoseDetectorFx` has produced a detector. */
  initialized: boolean;
  /** Detector produced by `initializePoseDetectorFx`; null before that. */
  detector: PoseDetector | null;
  /** Element passed to `setVideoElement`; null until then. */
  videoElement: HTMLVideoElement | null;
  /** True after `startPredictionLoop`, false after `stopPredictionLoop`. */
  predictionLoopRunning: boolean;
  /** Stream from `initializeCameraFx`; cleared when `releaseMediaStreamFx` completes. */
  mediaStream: MediaStream | null;
  /** Value passed to `startPredictionLoop`; initially true. */
  mirrorLandmarks: boolean;
};

const initialPoseState: $Pose = {
  initialized: false,
  detector: null,
  videoElement: null,
  predictionLoopRunning: false,
  mediaStream: null,
  mirrorLandmarks: true,
};

/**
 * Pipeline state store.
 *
 * Reducers that take the payload build a shallow copy via object spread.
 * Constant updates use ramda's curried `assoc`.
 */
export const $pose = createStore<$Pose>(initialPoseState)
  .on(initializePoseDetectorFx.doneData, (state, detector) => ({
    ...state,
    initialized: true,
    detector,
  }))
  .on(initializeCameraFx.doneData, (state, mediaStream) => ({
    ...state,
    mediaStream,
  }))
  .on(releaseMediaStreamFx.done, assoc('mediaStream', null))
  .on(startPredictionLoop, (state, mirrorLandmarks) => ({
    ...state,
    predictionLoopRunning: true,
    mirrorLandmarks,
  }))
  .on(stopPredictionLoop, assoc('predictionLoopRunning', false))
  .on(setVideoElement, (state, videoElement) => ({
    ...state,
    videoElement,
  }));

// Landmark events throttled with a 100 ms timeout (patronum `throttle`).
const throttledLandmarks = throttle({ source: tfLandmarks, timeout: 100 });

// Raw `bodyInFrame` result; it is scaled by 100 below, so presumably a 0–1 ratio.
const inFrameRatio = throttledLandmarks.map(bodyInFrame);

const inFramePercentage = inFrameRatio.map((ratio) => Math.round(ratio * 100));

/** Latest body-in-frame percentage, rounded to an integer; starts at 0. */
export const $bodyInFramePercentage = restore(inFramePercentage, 0);

/**
 * Classifier models whose confidences `$agregatedConfidence` exposes.
 *
 * NOTE(review): no code visible here updates this store (`useClassifiers`
 * is declared but not connected in this section) — confirm it is wired
 * elsewhere.
 */
const noClassifiers: ModelName[] = [];
export const $usedClassifiers = createStore(noClassifiers);

/**
 * Running record of the latest confidence per model.
 *
 * Each reading on `confidence$` becomes a `{ [model]: confidence }` patch.
 * Patches are merged into the accumulated record, so a newer reading for a
 * model overwrites the older one.
 */
export const agregatedConfidence$ = confidence$.pipe(
  // one inner stream per model name
  groupBy(prop('model')),
  mergeMap((perModel$) => {
    const toPatch = objOf(perModel$.key);

    return perModel$.pipe(map$(prop('confidence')), map$(toPatch));
  }),
  // no seed: the first patch becomes the initial accumulator
  scan<Record<ModelName, number>>(mergeRight)
);

/**
 * Latest accumulated confidences, starting as `{ NoAction: 1 }` before any
 * reading arrives.
 *
 * NOTE(review): `ModelName[number]` probably resolves to plain `string`
 * (indexing a string type by number). That widening may be what lets the
 * partial initial value type-check — confirm before tightening it.
 */
const $latestConfidence = restore(
  fromObservable<Record<ModelName[number], number>>(agregatedConfidence$),
  { NoAction: 1 }
);

/** Latest confidences restricted to the models listed in `$usedClassifiers`. */
export const $agregatedConfidence: Store<Record<ModelName, number>> = combine(
  $latestConfidence,
  $usedClassifiers,
  (confidence: Record<ModelName, number>, classifiers) =>
    pick(classifiers, confidence)
);
