import * as tf from '@tensorflow/tfjs';
import { Keypoint, Pose, calculators } from '@tensorflow-models/pose-detection';
import { applyTo, compose, divide, map, propOr, sum } from 'ramda';

/** One landmark row: position (`z` defaults to 0) followed by its confidence score (defaults to 0). */
export type LandmarkVector = [x: number, y: number, z: number, score: number];

/**
 * Flattens a pose into a single table of `[x, y, z, score]` rows.
 *
 * Rows derived from `pose.keypoints` come first. They are normalized against a
 * fixed 360×270 frame via `calculators.keypointsToNormalizedKeypoints`. Rows
 * from `pose.keypoints3D` follow; if that field is absent, nothing is appended.
 * A missing `z` or `score` is written as 0.
 */
const prepareLandmarks = (pose: Pose): LandmarkVector[] => {
  const toRow = ({ x, y, z = 0, score = 0 }: Keypoint): LandmarkVector => [
    x,
    y,
    z,
    score,
  ];

  // NOTE(review): hard-coded frame size; confirm it matches the input video.
  const frameSize = { width: 360, height: 270 };

  const worldRows = (pose.keypoints3D ?? []).map(toRow);
  const normalizedRows = calculators
    .keypointsToNormalizedKeypoints(pose.keypoints, frameSize)
    .map(toRow);

  return normalizedRows.concat(worldRows);
};

/**
 * A feature extractor: receives a list of landmark sets and yields a 1-D
 * tensor. Most embedders in this module read only the first set.
 */
type Embedder = (landmarkSets: LandmarkVector[][]) => tf.Tensor1D;

/**
 * Builds a tensor from the pose's landmark table wrapped in an outer array of
 * length one. The resulting shape is `[1, rows, 4]`, where `rows` is the
 * number of rows produced by `prepareLandmarks`.
 */
export const createUniversalEmbedding = (pose: Pose) =>
  tf.tensor([prepareLandmarks(pose)]);

/**
 * Mean confidence score over `pose.keypoints`.
 *
 * Each score is divided by the keypoint count before being accumulated, so an
 * empty keypoint list yields 0. A score that is `null`, `undefined` or `NaN`
 * contributes 0.
 */
export const bodyInFrame = (pose: Pose) => {
  const { keypoints } = pose;
  const count = keypoints.length;

  let total = 0;
  for (const keypoint of keypoints) {
    const { score } = keypoint;
    // Same fallback rule as R.propOr/R.defaultTo: nil or NaN becomes 0.
    const usableScore = score == null || Number.isNaN(score) ? 0 : score;
    total += usableScore / count;
  }

  return total;
};

/**
 * Builds an embedder that measures the planar angle at landmark `center`
 * between the rays pointing to landmarks `a` and `b`.
 *
 * Only the first landmark set of the input is read. Only the x and y
 * components of each row are used; z and score are ignored.
 *
 * @param center - row index of the vertex landmark
 * @param a - row index of the landmark at the end of the first ray
 * @param b - row index of the landmark at the end of the second ray
 * @returns an embedder yielding `[θ, θ, θ]`, where θ is the angle in radians
 *   in `[0, π]`. θ is NaN when either ray has zero length. The angle is
 *   repeated three times, presumably to match 3-component point features
 *   (TODO confirm).
 */
export const angle =
  (center: number, a: number, b: number): Embedder =>
  ([landmarks]) => {
    const [centerX, centerY] = landmarks[center];
    const [aX, aY] = landmarks[a];
    const [bX, bY] = landmarks[b];

    // Plain arithmetic instead of tf.sub(...).dataSync(). The tensor version
    // allocated two tensors per call that were never disposed.
    const rayAX = aX - centerX;
    const rayAY = aY - centerY;
    const rayBX = bX - centerX;
    const rayBY = bY - centerY;

    const cosTheta =
      (rayAX * rayBX + rayAY * rayBY) /
      (Math.hypot(rayAX, rayAY) * Math.hypot(rayBX, rayBY));

    // Rounding can push |cos| marginally above 1, where Math.acos returns NaN.
    const theta = Math.acos(Math.min(1, Math.max(-1, cosTheta)));

    return tf.tensor1d([theta, theta, theta]);
  };

/**
 * Builds an embedder returning the element-wise mean of two landmark rows
 * (x, y, z and score are all averaged) from the first landmark set.
 *
 * @param first - row index of the first landmark
 * @param second - row index of the second landmark
 * @returns an embedder yielding `(row[first] + row[second]) / 2`. The caller
 *   owns, and must dispose, the returned tensor.
 */
export const averagePointArray =
  (first: number, second: number): Embedder =>
  ([landmarks]) => {
    const firstPoint = landmarks[first];
    const secondPoint = landmarks[second];

    // tf.tidy releases the intermediate sum tensor; only the average escapes.
    return tf.tidy((): tf.Tensor1D => tf.add(firstPoint, secondPoint).div(2));
  };

/**
 * Builds an embedder that returns row `index` of the first landmark set,
 * unchanged, as a tensor.
 */
export const landmark =
  (index: number): Embedder =>
  (landmarkSets) => {
    const [primarySet] = landmarkSets;
    const row = primarySet[index];
    return tf.tensor(row);
  };

/**
 * Like `landmark`, but reads row `index` from the second landmark set of the
 * input instead of the first.
 */
export const worldLandmark = (index: number): Embedder => {
  const pickRow = landmark(index);
  return (landmarkSets) => pickRow([landmarkSets[1]]);
};

/**
 * Builds an embedder that returns the displacement vector from row `first` to
 * row `second` of the first landmark set: `row[second] - row[first]`. Note
 * that this is a component-wise vector, not a scalar length.
 */
export const distance =
  (first: number, second: number): Embedder =>
  ([landmarks]) =>
    tf.sub(landmarks[second], landmarks[first]);

/**
 * Builds an embedder that returns the displacement vector between two
 * midpoints. Both midpoints are taken from the first landmark set.
 *
 * @param first - first row index of the starting midpoint
 * @param second - second row index of the starting midpoint
 * @param third - first row index of the ending midpoint
 * @param fourth - second row index of the ending midpoint
 * @returns an embedder yielding `mid(third, fourth) - mid(first, second)`.
 *   The result is component-wise, not a scalar length.
 */
export const midPointDistance =
  (first: number, second: number, third: number, fourth: number): Embedder =>
  (landmarks) =>
    // The two midpoint tensors are intermediates; tf.tidy disposes them so
    // only the returned difference outlives this call.
    tf.tidy((): tf.Tensor1D => {
      const endMidpoint = averagePointArray(third, fourth)(landmarks);
      const startMidpoint = averagePointArray(first, second)(landmarks);

      return endMidpoint.sub(startMidpoint);
    });
