import ndarray from "ndarray";
import ops from "ndarray-ops";
import { Tensor } from "onnxruntime-web";
import { round } from "lodash";

/**
 * Moves a face box's top edge upward by half of the box height, never above 0,
 * and recomputes the height from the unchanged bottom edge.
 *
 * The box is modified in place: `yMin` and `height` are rewritten, `yMax` is
 * left untouched.
 *
 * @param {{box: {yMin: number, yMax: number, height: number}}} face - Detection whose box is adjusted.
 * @returns {object} The same `face` object that was passed in.
 */
export const adjustFaceCoordinates = (face) => {
  const { box } = face;
  const raisedTop = box.yMin - box.height / 2;

  box.yMin = Math.max(raisedTop, 0);
  box.height = box.yMax - box.yMin;

  return face;
};

/**
 * Crops a region of the current video frame and scales it onto a new canvas.
 *
 * @param {HTMLVideoElement} video - Source video; its current frame is drawn.
 * @param {{xMin: number, yMin: number, width: number, height: number} | null | undefined} box -
 *   Crop region in video pixel coordinates. When null/undefined the whole frame
 *   (0, 0, videoWidth, videoHeight) is used.
 * @param {number} targetWidth - Output canvas width in pixels.
 * @param {number} targetHeight - Output canvas height in pixels.
 * @param {boolean} mirror - When true, the crop's x position is reflected
 *   horizontally (x = videoWidth - width - xMin). The drawn pixels themselves are
 *   not flipped.
 * @returns {CanvasRenderingContext2D} 2D context of the new output canvas.
 */
export const resizeCanvasCtx = (
  video,
  box,
  targetWidth,
  targetHeight,
  mirror
) => {
  const { videoWidth, videoHeight } = video;
  const region = box ?? {
    xMin: 0,
    yMin: 0,
    width: videoWidth,
    height: videoHeight,
  };

  const canvas = document.createElement("canvas");
  canvas.width = targetWidth;
  canvas.height = targetHeight;
  // `willReadFrequently` is a 2D-context attribute (createElement ignores it) and
  // only takes effect on the first getContext call, so request it here.
  const ctx = canvas.getContext("2d", { willReadFrequently: true });

  const xCoord = mirror ? videoWidth - region.width - region.xMin : region.xMin;

  if (videoWidth < 1400 && videoHeight < 1400) {
    ctx.drawImage(
      video,
      xCoord,
      region.yMin,
      region.width,
      region.height,
      0,
      0,
      targetWidth,
      targetHeight
    );
  } else {
    // Large frames: copy the crop at native resolution onto an intermediate
    // canvas first, then scale that canvas to the target size (the original
    // author found this two-step path faster for big frames).
    const cropCanvas = document.createElement("canvas");
    cropCanvas.width = region.width;
    cropCanvas.height = region.height;
    cropCanvas
      .getContext("2d")
      .drawImage(
        video,
        xCoord,
        region.yMin,
        region.width,
        region.height,
        0,
        0,
        cropCanvas.width,
        cropCanvas.height
      );
    ctx.drawImage(cropCanvas, 0, 0, targetWidth, targetHeight);
  }

  return ctx;
};

/**
 * Converts RGBA ImageData into a normalized float32 NCHW tensor.
 *
 * Each RGB channel is scaled to [0, 1] and then standardized with the
 * per-channel mean/std below (the standard ImageNet statistics). The alpha
 * channel is dropped.
 *
 * @param {ImageData} imageData - Row-major RGBA pixels (`height` rows of `width` pixels).
 * @returns {Tensor} float32 tensor with dims [1, 3, height, width].
 */
export const lightPreprocess = (imageData) => {
  const { data, width, height } = imageData;
  const mean = [0.485, 0.456, 0.406];
  const std = [0.229, 0.224, 0.225];

  // ImageData is row-major, so view it as [rows, cols, RGBA]. The previous
  // [width, height] labelling produced the same buffer order but swapped the
  // output dims for non-square inputs.
  const pixels = ndarray(new Float32Array(data), [height, width, 4]);
  const chw = ndarray(new Float32Array(3 * height * width), [
    1,
    3,
    height,
    width,
  ]);

  ops.divseq(pixels, 255);

  for (let c = 0; c < 3; c += 1) {
    const channel = pixels.pick(null, null, c);
    // In-place ops; do not rely on their return values.
    ops.subseq(channel, mean[c]);
    ops.divseq(channel, std[c]);
    ops.assign(chw.pick(0, c, null, null), channel);
  }

  // `chw` is local, so its buffer can back the tensor directly without a copy.
  return new Tensor("float32", chw.data, [1, 3, height, width]);
};

/**
 * Scales the full current video frame to a square model input and packs its
 * raw RGB values (0-255, no normalization) into a float32 NCHW tensor.
 *
 * @param {HTMLVideoElement} video - Source video.
 * @param {number} modelResolution - Side length of the square model input.
 * @param {boolean} mirror - Forwarded to `resizeCanvasCtx`.
 * @returns {Tensor} float32 tensor with dims [1, 3, modelResolution, modelResolution].
 */
export const headPreprocess = (video, modelResolution, mirror) => {
  // `resizeCanvasCtx` reads box fields, so pass an explicit full-frame region
  // rather than null (which would throw).
  const fullFrame = {
    xMin: 0,
    yMin: 0,
    width: video.videoWidth,
    height: video.videoHeight,
  };
  const resizedCtx = resizeCanvasCtx(
    video,
    fullFrame,
    modelResolution,
    modelResolution,
    mirror
  );

  const { data, width, height } = resizedCtx.getImageData(
    0,
    0,
    modelResolution,
    modelResolution
  );

  // ImageData is row-major: [rows, cols, RGBA].
  const pixels = ndarray(new Float32Array(data), [height, width, 4]);
  const chw = ndarray(new Float32Array(3 * height * width), [
    1,
    3,
    height,
    width,
  ]);

  for (let c = 0; c < 3; c += 1) {
    ops.assign(chw.pick(0, c, null, null), pixels.pick(null, null, c));
  }

  // `chw` is local, so its buffer can back the tensor directly without a copy.
  return new Tensor("float32", chw.data, [1, 3, height, width]);
};

/**
 * Draws the detected head box onto an overlay canvas and reports whether the
 * detection is confident and lies well inside the frame.
 *
 * The overlay is cleared on every call. When a detection is present, its box is
 * stroked and filled in translucent yellow regardless of the score.
 *
 * The function is `async` only for backward compatibility; it awaits nothing.
 *
 * @param {{score: {data: ArrayLike<number>}, box: {data: ArrayLike<number>}}} tensor -
 *   Model outputs. `box.data` is read as [x0, y0, x1, y1] in model-input pixels.
 * @param {CanvasRenderingContext2D} ctx - Overlay drawing context.
 * @param {object} [options]
 * @param {number} [options.modelResolution=320] - Side length of the square model
 *   input the box coordinates are expressed in.
 * @param {number} [options.scoreThreshold=0.85] - A detection is valid only if its
 *   raw score is strictly greater than this.
 * @returns {Promise<boolean>} true when a detection exists, its score exceeds the
 *   threshold, and the scaled box lies inside the centred 90% of the canvas.
 */
export const headPostprocess = async (
  tensor,
  ctx,
  { modelResolution = 320, scoreThreshold = 0.85 } = {}
) => {
  const dx = ctx.canvas.width / modelResolution;
  const dy = ctx.canvas.height / modelResolution;

  ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);

  // No detection: nothing to draw.
  if (!tensor.score.data.length) {
    return false;
  }

  // NOTE(review): assumes the model emits a probability in [0, 1]. The previous
  // code compared a percentage (round(score * 100, 1)) against 0.85, which let
  // almost every detection through.
  const [score] = tensor.score.data;

  // Round in model-input space, then scale to canvas pixels.
  const [rawX0, rawY0, rawX1, rawY1] = tensor.box.data;
  const x0 = round(rawX0) * dx;
  const x1 = round(rawX1) * dx;
  const y0 = round(rawY0) * dy;
  const y1 = round(rawY1) * dy;

  // The box must fit fully inside the centred 90% of the canvas.
  const validXRegion = 0.9 * ctx.canvas.width;
  const validYRegion = 0.9 * ctx.canvas.height;

  const validXMin = (ctx.canvas.width - validXRegion) / 2;
  const validXMax = validXMin + validXRegion;
  const validYMin = (ctx.canvas.height - validYRegion) / 2;
  const validYMax = validYMin + validYRegion;

  ctx.strokeStyle = "rgb(255, 255, 0)";
  ctx.lineWidth = 3;
  ctx.strokeRect(x0, y0, x1 - x0, y1 - y0);

  // Translucent fill over the same rectangle.
  ctx.fillStyle = "rgba(255, 255, 0, 0.2)";
  ctx.fillRect(x0, y0, x1 - x0, y1 - y0);

  return (
    score > scoreThreshold &&
    x0 >= validXMin &&
    x1 <= validXMax &&
    y0 >= validYMin &&
    y1 <= validYMax
  );
};
