import isElectron from "is-electron";
import Jimp from "jimp";
import * as ortWeb from "onnxruntime-web";
import { pixelToFloat, imageDataToTensor, arrayToImageData } from "../shared/utility";
import { IInferenceEngine } from "./IInferenceEngine";

/**
 * Inference engine that runs the super-resolution model inside the renderer
 * through onnxruntime-web, using the execution provider named by `backendType`
 * (the factories in this file pass "wasm" or "webgl").
 */
class OnnxWebEngine implements IInferenceEngine {
  constructor(backendType: string) {
    this._backendType = backendType;
    this._modelPath = OnnxWebEngine.getModelPath(this._backendType);
  }

  /**
   * Lazily creates the onnxruntime-web session for `_modelPath`.
   *
   * Concurrent callers share one in-flight creation. If creation fails, the
   * cached promise is cleared so a later call retries instead of replaying the
   * same rejection forever.
   */
  initializeAsync = async () => {
    if (this._initializePromise === undefined) {
      this._initializePromise = (async () => {
        // These are global onnxruntime-web flags (ortWeb.env), shared by every session.
        ortWeb.env.logLevel = "verbose";
        ortWeb.env.webgl.pack = true;
        this._session = await ortWeb.InferenceSession.create(this._modelPath, {
          executionProviders: [this._backendType],
        });
        console.log(`created ONNXWeb session for model: ${this._modelPath}\n`);
        console.log("input names: ", this._session.inputNames);
        console.log("output names: ", this._session.outputNames);
      })().catch((error: unknown) => {
        // Forget the failed attempt so the next initializeAsync() starts over.
        this._initializePromise = undefined;
        throw error;
      });
    }
    await this._initializePromise;
  };

  /**
   * Loads the image at `inputImageUrl`, feeds it to the model as `lr_input`,
   * and returns the `sr_output` tensor encoded as a base64 data string.
   *
   * @param inputImageUrl anything `Jimp.read` accepts.
   * @returns the upscaled image from `outputImage.getBase64Async`.
   * @throws if the session could not be created or inference fails.
   */
  runAsync = async (inputImageUrl: string) => {
    await this.initializeAsync();
    const session = this._session;
    if (session === undefined) {
      throw new Error(`ONNXWeb session was not created for model: ${this._modelPath}`);
    }

    const inputImage = await Jimp.read(inputImageUrl);

    // NCHW: one image, three colour planes, spatial size taken from the decoded
    // bitmap so the tensor shape always matches the pixels actually supplied.
    // NOTE(review): assumes imageDataToTensor sizes its output from `dims` — confirm.
    const dims = [1, 3, inputImage.bitmap.height, inputImage.bitmap.width];
    const imageData = imageDataToTensor(inputImage, dims);
    // Create the tensor object from onnxruntime-web.
    const imageTensor = new ortWeb.Tensor("float32", imageData, dims);
    const imageFeeds = { lr_input: imageTensor };
    const srResults = await session.run(imageFeeds);
    const outputImage = await arrayToImageData(srResults.sr_output.data as Float32Array,
                                                srResults.sr_output.dims[2],
                                                srResults.sr_output.dims[3]);
    return await outputImage.getBase64Async(outputImage.getMIME());
  };

  /** Model file used for a given onnxruntime-web backend ("webgl" gets the smaller model). */
  static getModelPath(backendType: string) {
    return backendType === "webgl" ? "./smallSRGANTrained.onnx" : "./srgan_generator.onnx";
  }

  _backendType: string;
  _modelPath: string;
  /** Shared in-flight (or completed) session creation; reset to undefined after a failure. */
  _initializePromise?: Promise<void>;
  _session?: ortWeb.InferenceSession = undefined;
}

/**
 * Reports whether the onnxruntime-web "wasm" backend can be used.
 * The wasm backend is treated as always available, so this always resolves to `true`.
 */
export async function checkWasmEngineAvailable() {
  const wasmAlwaysAvailable = true;
  return wasmAlwaysAvailable;
}

/**
 * Probes whether the page can obtain a WebGL 1 or WebGL 2 rendering context.
 *
 * A throwaway canvas is created, and the "webgl" and then "webgl2" contexts are
 * both requested from it before the result is decided.
 *
 * @returns resolves to `true` when either request yields a genuine
 *          WebGLRenderingContext / WebGL2RenderingContext instance.
 */
export async function checkWebGlEngineAvailable() {
  const probeCanvas = document.createElement("canvas");
  const webgl1 = probeCanvas.getContext("webgl");
  const webgl2 = probeCanvas.getContext("webgl2");

  // Truthiness is checked first so `instanceof` is never evaluated on a null context.
  const isWebGl1 = (ctx: unknown) => Boolean(ctx) && ctx instanceof WebGLRenderingContext;
  const isWebGl2 = (ctx: unknown) => Boolean(ctx) && ctx instanceof WebGL2RenderingContext;

  return isWebGl1(webgl1) || isWebGl2(webgl2);
}

/** Builds an onnxruntime-web engine that runs on the "wasm" execution provider. */
export function createWasmEngine() {
  const engine = new OnnxWebEngine("wasm");
  return engine;
}

/** Builds an onnxruntime-web engine that runs on the "webgl" execution provider. */
export function createWebGlEngine() {
  const engine = new OnnxWebEngine("webgl");
  return engine;
}

/**
 * Inference engine that delegates to a native onnxruntime session owned by the
 * Electron main process, reached through the `window.ipc` bridge.
 *
 * `backendType` is forwarded verbatim over IPC (the factories in this file pass
 * "cpu" or "cuda").
 */
class OnnxNativeEngine implements IInferenceEngine {
  constructor(backendType: string) {
    this._backendType = backendType;
    this._modelPath = OnnxNativeEngine.getModelPath(this._backendType);
  }

  /**
   * Asks the main process to create the inference session, at most once at a time.
   *
   * Concurrent callers share one in-flight request. If it fails, the cached
   * promise is cleared so a later call retries instead of replaying the same
   * rejection forever.
   */
  initializeAsync = async () => {
    if (this._initializePromise === undefined) {
      this._initializePromise = (async () => {
        console.log("init inference session");
        await (window as any).ipc.invoke("onnxnative-initialize-session", this._backendType, this._modelPath);
      })().catch((error: unknown) => {
        // Forget the failed attempt so the next initializeAsync() starts over.
        this._initializePromise = undefined;
        throw error;
      });
    }
    await this._initializePromise;
  };

  /**
   * Loads the image at `inputImageUrl`, sends it to the main process as a planar
   * float32 NCHW tensor, and returns the result encoded as a base64 data string.
   *
   * @param inputImageUrl anything `Jimp.read` accepts.
   * @returns the upscaled image from `outputImage.getBase64Async`.
   * @throws if session initialisation or the IPC inference call fails.
   */
  runAsync = async (inputImageUrl: string) => {
    await this.initializeAsync();

    const image = await Jimp.read(inputImageUrl);
    const { width, height } = image.bitmap;
    // NCHW: one image, three colour planes, spatial size taken from the decoded
    // bitmap so the tensor shape always matches the pixels actually supplied
    // (a hard-coded 128x128 silently truncated/zero-padded other sizes).
    const dims = [1, 3, height, width];
    const float32Data = OnnxNativeEngine.rgbaToPlanarFloat32(image.bitmap.data, width * height);

    const outputFloatData: { data: Float32Array; dims: number[] } = await (window as any).ipc.invoke(
      "onnxnative-run-inference",
      this._backendType,
      float32Data,
      dims,
    );

    const outputImage = await arrayToImageData(outputFloatData.data, outputFloatData.dims[2], outputFloatData.dims[3]);
    return await outputImage.getBase64Async(outputImage.getMIME());
  };

  /**
   * Converts interleaved RGBA bytes into one Float32Array laid out as
   * [all R, all G, all B], dropping alpha and mapping each value through
   * `pixelToFloat`. Single pass, no intermediate arrays.
   */
  private static rgbaToPlanarFloat32(rgba: ArrayLike<number>, pixelCount: number): Float32Array {
    const planar = new Float32Array(3 * pixelCount);
    for (let pixel = 0; pixel < pixelCount; pixel++) {
      const src = pixel * 4; // 4 bytes per pixel; byte 3 (alpha) is skipped
      planar[pixel] = pixelToFloat(rgba[src]);
      planar[pixelCount + pixel] = pixelToFloat(rgba[src + 1]);
      planar[2 * pixelCount + pixel] = pixelToFloat(rgba[src + 2]);
    }
    return planar;
  }

  /** Model file used for a given native backend ("cuda" gets the smaller model). */
  static getModelPath(backendType: string) {
    return backendType === "cuda" ? "./smallSRGANTrained.onnx" : "./srgan_generator.onnx";
  }

  _backendType: string;
  _modelPath: string;
  /** Shared in-flight (or completed) session creation; reset to undefined after a failure. */
  _initializePromise?: Promise<void>;
}

/**
 * Asks the Electron main process whether the native onnxruntime `backend`
 * can load the model this file pairs with that backend.
 *
 * @param backend backend name forwarded over IPC ("cpu" / "cuda" in this file).
 * @returns resolves to `false` outside Electron, when the `window.ipc` bridge is
 *          missing, or when the main-process check rejects (the error is logged);
 *          otherwise resolves to whatever the
 *          "onnxnative-check-backend-available" handler returns.
 */
async function checkNativeBackendAvailable(backend: string) {
  if (!isElectron()) {
    return false;
  }

  const ipc = (window as any).ipc;
  if (typeof ipc?.invoke !== "function") {
    // Running inside Electron does not by itself guarantee the preload bridge exposed `ipc`.
    return false;
  }

  try {
    return await ipc.invoke("onnxnative-check-backend-available", backend, OnnxNativeEngine.getModelPath(backend));
  } catch (error: unknown) {
    // A capability probe should report "unavailable", not reject its caller.
    console.warn(`native backend "${backend}" availability check failed:`, error);
    return false;
  }
}

/** Resolves with the main process's verdict on whether the native "cpu" backend is usable. */
export async function checkNativeCpuEngineAvailable() {
  const verdict = await checkNativeBackendAvailable("cpu");
  return verdict;
}

/** Resolves with the main process's verdict on whether the native "cuda" backend is usable. */
export async function checkNativeCudaEngineAvailable() {
  const verdict = await checkNativeBackendAvailable("cuda");
  return verdict;
}

/** Builds a native (Electron main-process) engine bound to the "cpu" backend. */
export function createNativeCpuEngine() {
  const engine = new OnnxNativeEngine("cpu");
  return engine;
}

/** Builds a native (Electron main-process) engine bound to the "cuda" backend. */
export function createNativeCudaEngine() {
  const engine = new OnnxNativeEngine("cuda");
  return engine;
}
