import * as tf from "@tensorflow/tfjs"
import * as cocoSsd from "@tensorflow-models/coco-ssd"
import * as poseDetection from "@tensorflow-models/pose-detection"
import { RefObject, useEffect, useRef, useState } from "react"
import Webcam from "react-webcam"

/**
 * Status values exported for consumers; not referenced in this file
 * (presumably the webcam/proctoring state — confirm with callers).
 */
export type Status =
  | "on"
  | "off"
  | "blocked"

/** Violation kinds that `useFaceDetection` can report through `violationType`. */
export type Violation =
  | "noFace"
  | "unfairMeans"

/** Delay between successive detection passes, in milliseconds (15 seconds). */
const INTERVAL = 15_000

/** Options accepted by `useFaceDetection`. */
interface Props {
  /** When true, the hook does not start any detection. */
  disabled?: boolean
  /** Ref to the webcam component whose `<video>` frames are analysed. */
  videoRef: RefObject<Webcam>
}

/**
 * Starting values for the hook's detection state; `clearViolations`
 * resets the state back to these same values.
 */
const initialState = {
  /** One person is assumed to be in frame before any detection has run. */
  peoplePresent: 1,
  /** The face is assumed visible before the first pose check. */
  faceDetected: true,
  /** No prohibited object has been seen yet. */
  unfairMeansDetected: false,
}

/** Number of most-recent person counts kept when deciding whether several people are present. */
const PEOPLE_WINDOW = 7
/** Checks within PEOPLE_WINDOW that must see more than one person to report multiple people. */
const MULTIPLE_PEOPLE_THRESHOLD = 3
/** Number of most-recent prohibited-object checks kept. */
const UNFAIR_MEANS_WINDOW = 7
/** Checks within UNFAIR_MEANS_WINDOW that must see a prohibited object to flag unfair means. */
const UNFAIR_MEANS_THRESHOLD = 3
/** Number of most-recent face checks kept. */
const FACE_WINDOW = 8
/** Number of consecutive failed face checks after which the face is reported missing. */
const FACE_AFK_THRESHOLD = 6
/** Minimum summed confidence of the two eye keypoints for a face check to pass. */
const MIN_COMBINED_EYE_SCORE = 0.6
/**
 * Minimum horizontal distance between the eye keypoints for a face check to pass;
 * eyes closer than this are treated as the head being turned away. Units are those
 * of the keypoint x coordinates (presumably pixels — confirm against detector config).
 */
const MIN_EYE_SEPARATION = 15
/** COCO-SSD classes whose presence counts as possible use of unfair means. */
const PROHIBITED_OBJECT_CLASSES = new Set(["book", "cell phone", "remote"])
/** `HTMLMediaElement.readyState` value HAVE_ENOUGH_DATA. */
const HAVE_ENOUGH_DATA = 4

/**
 * Returns the last `size` items of `[...history, value]` as a new array;
 * `history` itself is not modified.
 */
function pushToWindow<T>(history: readonly T[], value: T, size: number): T[] {
  return [...history, value].slice(-size)
}

/**
 * A single face check: passes when both eye keypoints exist, their summed score
 * exceeds MIN_COMBINED_EYE_SCORE, and they are more than MIN_EYE_SEPARATION apart
 * horizontally. A missing pose fails the check.
 */
function isFacingCamera(pose: poseDetection.Pose | undefined): boolean {
  if (!pose) return false

  const leftEye = pose.keypoints.find(k => k.name === "left_eye")
  const rightEye = pose.keypoints.find(k => k.name === "right_eye")
  if (!leftEye || !rightEye) return false

  const combinedScore = (leftEye.score || 0) + (rightEye.score || 0)
  return (
    combinedScore > MIN_COMBINED_EYE_SCORE &&
    Math.abs(leftEye.x - rightEye.x) > MIN_EYE_SEPARATION
  )
}

/**
 * Runs pose detection (MoveNet) and object detection (COCO-SSD) on the webcam
 * feed every INTERVAL ms and reports proctoring violations.
 *
 * Raw per-check results are smoothed over a short history so that a single noisy
 * frame does not flip the reported state.
 *
 * @param disabled - when true, no models are loaded and no detection runs;
 *   changing it tears detection down or starts it up.
 * @param videoRef - ref to the `react-webcam` component whose `<video>` is sampled.
 * @returns detection flags, the current `violationType` ("noFace" takes precedence
 *   over "unfairMeans"), and `clearViolations` to reset all history.
 */
export function useFaceDetection({ disabled = false, videoRef }: Props) {
  const [peoplePresent, setPeoplePresent] = useState(initialState.peoplePresent)
  const [faceDetected, setFaceDetected] = useState(initialState.faceDetected)
  const [unfairMeansDetected, setUnfairMeansDetected] = useState(
    initialState.unfairMeansDetected
  )

  // Rolling histories of raw per-check results, used to smooth the reported state.
  const faceHistory = useRef<boolean[]>([])
  const peopleHistory = useRef<number[]>([])
  const unfairMeansHistory = useRef<boolean[]>([])

  // Derived directly from state (no effect needed). A "multiple people" violation
  // based on `peoplePresent` is currently not reported.
  const violationType: Violation | null = !faceDetected
    ? "noFace"
    : unfairMeansDetected
      ? "unfairMeans"
      : null

  /** Returns the webcam's <video> element if a full frame is available, else null. */
  const getReadyVideo = () => {
    const video = videoRef.current?.video
    return video && video.readyState === HAVE_ENOUGH_DATA ? video : null
  }

  /** Runs COCO-SSD on the current frame and updates people / unfair-means state. */
  const detectPerson = (
    detector: cocoSsd.ObjectDetection,
    isCancelled: () => boolean
  ) => {
    const video = getReadyVideo()
    if (!video) return

    detector
      .detect(video)
      .then(objects => {
        // Detection was torn down while inference was running.
        if (isCancelled()) return

        const peopleInFrame = objects.filter(o => o.class === "person").length
        const prohibitedObjectVisible = objects.some(o =>
          PROHIBITED_OBJECT_CLASSES.has(o.class)
        )

        peopleHistory.current = pushToWindow(
          peopleHistory.current,
          peopleInFrame,
          PEOPLE_WINDOW
        )
        const multiplePeopleChecks = peopleHistory.current.filter(n => n > 1).length
        // Only 1 or 2 is ever reported; 2 stands for "more than one person".
        setPeoplePresent(multiplePeopleChecks >= MULTIPLE_PEOPLE_THRESHOLD ? 2 : 1)

        unfairMeansHistory.current = pushToWindow(
          unfairMeansHistory.current,
          prohibitedObjectVisible,
          UNFAIR_MEANS_WINDOW
        )
        const sightings = unfairMeansHistory.current.filter(Boolean).length
        setUnfairMeansDetected(sightings >= UNFAIR_MEANS_THRESHOLD)
      })
      .catch(console.error)
  }

  /** Runs MoveNet on the current frame and updates the face-detected state. */
  const detectPose = (
    detector: poseDetection.PoseDetector,
    isCancelled: () => boolean
  ) => {
    const video = getReadyVideo()
    if (!video) return

    detector
      .estimatePoses(video, { maxPoses: 1, flipHorizontal: false })
      .then(poses => {
        // Detection was torn down while inference was running.
        if (isCancelled()) return

        faceHistory.current = pushToWindow(
          faceHistory.current,
          isFacingCamera(poses[0]),
          FACE_WINDOW
        )

        // Report the face as missing only when each of the last FACE_AFK_THRESHOLD
        // checks failed; any passing check among them (e.g. the user coming back) clears it.
        const recent = faceHistory.current.slice(-FACE_AFK_THRESHOLD)
        const isAFK =
          recent.length >= FACE_AFK_THRESHOLD && recent.every(passed => !passed)
        setFaceDetected(!isAFK)
      })
      .catch(console.error)
  }

  useEffect(() => {
    if (disabled) return

    // Models load asynchronously, so teardown may happen before they are ready.
    // `cancelled` lets late-arriving work release itself instead of starting
    // intervals that nothing would ever clear.
    let cancelled = false
    const isCancelled = () => cancelled
    const intervals: ReturnType<typeof setInterval>[] = []
    const disposers: Array<() => void> = []

    const initFaceDetection = async () => {
      const detector = await poseDetection.createDetector(
        poseDetection.SupportedModels.MoveNet,
        { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING }
      )
      if (cancelled) {
        detector.dispose()
        return
      }
      disposers.push(() => detector.dispose())
      intervals.push(
        setInterval(() => detectPose(detector, isCancelled), INTERVAL)
      )
    }

    const initObjectDetection = async () => {
      const detector = await cocoSsd.load()
      if (cancelled) {
        detector.dispose()
        return
      }
      disposers.push(() => detector.dispose())
      intervals.push(
        setInterval(() => detectPerson(detector, isCancelled), INTERVAL)
      )
    }

    const init = async () => {
      const webGlAvailable = await tf.setBackend("webgl")
      console.log(`WEBGL READY: ${webGlAvailable}`)
      if (!webGlAvailable) return

      await tf.ready()
      if (cancelled) return

      // Started independently so that one model failing does not stop the other.
      initFaceDetection().catch(console.error)
      initObjectDetection().catch(console.error)
    }

    init().catch(console.error)

    return () => {
      cancelled = true
      intervals.forEach(id => clearInterval(id))
      disposers.forEach(dispose => dispose())
    }
    // detectPose / detectPerson only touch refs and stable state setters, so the
    // versions captured here stay valid; re-run only when `disabled` changes.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [disabled])

  /** Forgets all detection history and resets every flag to its initial value. */
  const clearViolations = () => {
    faceHistory.current = []
    peopleHistory.current = []
    unfairMeansHistory.current = []

    setFaceDetected(initialState.faceDetected)
    setPeoplePresent(initialState.peoplePresent)
    setUnfairMeansDetected(initialState.unfairMeansDetected)
  }

  return {
    faceDetected,
    unfairMeansDetected,
    clearViolations,
    // NOTE(review): peoplePresent is only ever set to 1 or 2, so this is always
    // true. Left unchanged to preserve behaviour — confirm the intended semantics.
    personDetected: peoplePresent !== 0,
    violationType,
  } as const
}
