import { Tensor } from "onnxruntime-web";

import { Point } from "../components/EditMode";

/**
 * Scale information for an image, as consumed by `modelData`.
 */
export interface ModelScaleProps {
  /** Factor that `modelData` multiplies into each click's x/y coordinate. */
  samScale: number;
  /** Image height; becomes the first element of the `orig_im_size` tensor. */
  height: number;
  /** Image width; becomes the second element of the `orig_im_size` tensor. */
  width: number;
}

/**
 * Arguments accepted by `modelData`.
 */
export interface ModelDataProps {
  /** Click prompts. `modelData` throws when this is omitted. */
  clicks?: Array<Point>;
  /** Passed through unchanged as the `image_embeddings` entry of the feed. */
  tensor: Tensor;
  /** Supplies the click scale factor and the image dimensions. */
  modelScale: ModelScaleProps;
}

/**
 * Builds the named tensor feed for the mask decoder from click prompts and a
 * precomputed image embedding.
 *
 * Each click is scaled by `modelScale.samScale`, and one extra padding point
 * at (0, 0) with label -1 is appended because no box prompt is supplied.
 *
 * @param clicks - Click prompts; an empty array yields only the padding point.
 * @param tensor - Image embedding, returned as-is under `image_embeddings`.
 * @param modelScale - Click scale factor plus image height/width.
 * @returns An object keyed by input name: `interm_embeddings`,
 *   `image_embeddings`, `point_coords`, `point_labels`, `orig_im_size`,
 *   `mask_input` and `has_mask_input`.
 * @throws Error("No tensor found!") when `clicks` is not provided.
 */
export const modelData = ({ clicks, tensor, modelScale }: ModelDataProps) => {
  let coordsTensor;
  let labelsTensor;

  // Only build the prompt tensors when click prompts were supplied.
  if (clicks) {
    const clickCount = clicks.length;
    // One slot per click plus the trailing padding point.
    const slotCount = clickCount + 1;

    // Typed arrays start zero-filled, so the padding point's (0.0, 0.0)
    // coordinates at the end need no explicit write.
    const coords = new Float32Array(slotCount * 2);
    const labels = new Float32Array(slotCount);

    // Copy each click, scaling its coordinates to what SAM expects.
    for (let idx = 0; idx < clickCount; idx += 1) {
      const { x, y, label } = clicks[idx];
      const offset = idx * 2;
      coords[offset] = x * modelScale.samScale;
      coords[offset + 1] = y * modelScale.samScale;
      labels[idx] = label;
    }

    // Label -1 marks the padding point used when there is no box input.
    labels[clickCount] = -1.0;

    coordsTensor = new Tensor("float32", coords, [1, slotCount, 2]);
    labelsTensor = new Tensor("float32", labels, [1, slotCount]);
  }

  // Image size as [height, width].
  const origImSize = new Tensor("float32", [
    modelScale.height,
    modelScale.width,
  ]);

  // Without click prompts neither tensor was created above.
  if (!(coordsTensor !== undefined && labelsTensor !== undefined)) {
    throw new Error("No tensor found!");
  }

  // No previous mask is available: pass an all-zero mask and a 0 flag.
  const emptyMask = new Tensor(
    "float32",
    new Float32Array(256 * 256),
    [1, 1, 256, 256],
  );
  const maskFlag = new Tensor("float32", [0]);

  // NOTE(review): the intermediate embeddings are filled with a single random
  // constant, which looks like a placeholder rather than real encoder
  // output - confirm before relying on the decoder result.
  const intermDims = [4, 1, 64, 64, 1280];
  const intermSize = intermDims.reduce((acc, dim) => acc * dim, 1);
  const intermValues = new Float32Array(intermSize).fill(Math.random());
  const intermTensor = new Tensor("float32", intermValues, intermDims);

  return {
    interm_embeddings: intermTensor,
    image_embeddings: tensor,
    point_coords: coordsTensor,
    point_labels: labelsTensor,
    orig_im_size: origImSize,
    mask_input: emptyMask,
    has_mask_input: maskFlag,
  };
};

/**
 * Reads an image's natural dimensions and derives the factor that maps its
 * longer side onto 1024 pixels.
 *
 * @param image - Image element whose `naturalWidth`/`naturalHeight` are read.
 * @returns `{ height, width, samScale }`, where `samScale` is
 *   `1024 / max(height, width)`. An image reporting 0x0 yields `Infinity`.
 */
export const handleImageScale = (image: HTMLImageElement) => {
  // Target length, in pixels, for the longer image side.
  const targetLongSide = 1024;
  const { naturalWidth: width, naturalHeight: height } = image;
  const longestSide = Math.max(height, width);
  return { height, width, samScale: targetLongSide / longestSide };
};
