/**
 * Modified from @tensorflow/tfjs-core/dist/io/http.js
 */

import { loadWeightsAsArrayBuffer } from "@tensorflow/tfjs-core/dist/io/weights_loader";
import CryptoJS from "crypto-js";
import * as tf from "@tensorflow/tfjs";

/**
 * A `tf.io.IOHandler` that loads a TensorFlow.js model over HTTP whose model
 * JSON file is RC4-encrypted.
 *
 * Adapted from tfjs-core's HTTP IO handler. The model JSON is fetched as text,
 * decrypted with CryptoJS RC4 using `SHA3(modelKey)` as the key, and parsed.
 * Weight shard files are fetched as-is through `loadWeightsAsArrayBuffer`;
 * this class does not decrypt them.
 */
export class EncryptedHttpLoader implements tf.io.IOHandler {
  protected readonly path: string;
  protected readonly requestInit: RequestInit;

  private readonly fetch: typeof fetch;
  private readonly weightUrlConverter?: (weightName: string) => Promise<string>;

  // Kept for parity with tfjs's HTTP handler; not referenced by the load path
  // in this class.
  readonly DEFAULT_METHOD = "POST";

  static readonly URL_SCHEME_REGEX = /^https?:\/\//;
  protected readonly modelKey: string;
  private readonly weightPathPrefix?: string;
  private readonly loadOptions: tf.io.LoadOptions;

  /**
   * @param modelKey passphrase whose SHA3 digest is used as the RC4 key for
   *   the model JSON.
   * @param path URL of the encrypted model JSON.
   * @param loadOptions tfjs load options (`weightPathPrefix`,
   *   `weightUrlConverter`, `requestInit`, ...). `requestInit` must not carry
   *   a body.
   * @throws if `path` is empty (via `tf.util.assert`) or if
   *   `loadOptions.requestInit.body` is set.
   */
  constructor(modelKey: string, path: string, loadOptions?: tf.io.LoadOptions) {
    const options: tf.io.LoadOptions = loadOptions ?? {};
    this.weightPathPrefix = options.weightPathPrefix;
    this.weightUrlConverter = options.weightUrlConverter;

    this.fetch = tf.util.fetch;

    tf.util.assert(
      path != null && path.length > 0,
      () => "URL path for http must not be null, undefined or " + "empty."
    );

    // Parity with tfjs: when an array is passed, getWeightUrls uses path[1] as
    // the base URL for weight files.
    if (Array.isArray(path)) {
      tf.util.assert(
        path.length === 2,
        () =>
          "URL paths for http must have a length of 2, " +
          `(actual length is ${path.length}).`
      );
    }
    this.modelKey = modelKey;
    this.path = path;

    if (options.requestInit != null && options.requestInit.body != null) {
      throw new Error(
        "requestInit is expected to have no pre-existing body, but has one."
      );
    }
    this.requestInit = options.requestInit || {};
    this.loadOptions = options;
  }

  /**
   * Fetch, decrypt and validate the model JSON.
   *
   * @returns the parsed model JSON.
   * @throws Error if the request fails, the body cannot be read, decryption
   *   or parsing fails (typically a wrong `modelKey`), or the JSON has neither
   *   `modelTopology` nor `weightsManifest`.
   */
  private async loadModelJSON(): Promise<tf.io.ModelJSON> {
    const modelConfigRequest = await this.fetch(this.path, this.requestInit);

    if (!modelConfigRequest.ok) {
      throw new Error(
        `Request to ${this.path} failed with status code ` +
          `${modelConfigRequest.status}. Please verify this URL points to ` +
          `the model JSON of the model to load.`
      );
    }

    // The encrypted model JSON is read as text and handed to the decryptor.
    let cryptedJSON: string;
    try {
      cryptedJSON = await modelConfigRequest.text();
    } catch (e) {
      const message = `Failed to parse model JSON of response from ${this.path}. Please make sure the server is serving valid  JSON for this request.`;
      throw new Error(message);
    }

    const modelJSON = this.decryptModelJSON(cryptedJSON);

    // We do not allow both modelTopology and weightsManifest to be missing.
    const modelTopology = modelJSON.modelTopology;
    const weightsManifest = modelJSON.weightsManifest;
    if (modelTopology == null && weightsManifest == null) {
      throw new Error(
        `The JSON from HTTP path ${this.path} contains neither model ` +
          `topology or manifest for weights.`
      );
    }

    return modelJSON;
  }

  /**
   * Decrypt the RC4-encrypted model JSON text and parse it.
   *
   * A wrong key generally yields garbage bytes, which makes `JSON.parse` throw
   * (or, rarely, produce a non-object); both cases are reported as an Error
   * instead of surfacing later as an unrelated TypeError.
   *
   * @param cryptedJSON encrypted text as served for the model JSON.
   * @throws Error if decryption/parsing fails or the result is not an object.
   */
  private decryptModelJSON(cryptedJSON: string): tf.io.ModelJSON {
    let parsed: unknown;
    try {
      const pass = CryptoJS.SHA3(this.modelKey);
      // RC4.decrypt returns a CryptoJS WordArray; turn it into raw bytes and
      // decode them as UTF-8 text before parsing.
      const decrypted = CryptoJS.RC4.decrypt(cryptedJSON, pass);
      const bytes = convertWordArrayToUint8Array(decrypted);
      parsed = JSON.parse(new TextDecoder().decode(bytes));
    } catch (error) {
      const reason = error instanceof Error ? error.message : String(error);
      throw new Error(
        `Failed to decrypt model JSON from ${this.path}; ` +
          `the model key is probably wrong (${reason}).`
      );
    }

    if (parsed == null || typeof parsed !== "object") {
      throw new Error(
        `Decrypted model JSON from ${this.path} is not an object; ` +
          `the model key is probably wrong.`
      );
    }
    return parsed as tf.io.ModelJSON;
  }

  /**
   * Load model artifacts via HTTP request(s).
   *
   * See the documentation to `tf.io.http` for details on the saved
   * artifacts.
   *
   * @returns The loaded model artifacts (if loading succeeds).
   */
  async load(): Promise<tf.io.ModelArtifacts> {
    const modelJSON = await this.loadModelJSON();
    return tf.io.getModelArtifactsForJSON(modelJSON, (weightsManifest) =>
      this.loadWeights(weightsManifest)
    );
  }

  /**
   * Resolve the URL of every weight file listed in the manifest, in manifest
   * order.
   *
   * With a `weightUrlConverter`, each file name is mapped through it.
   * Otherwise the URL is `weightPathPrefix` (or the directory of the weights
   * path) + file name + the weights path's query string.
   */
  private async getWeightUrls(
    weightsManifest: tf.io.WeightsManifestConfig
  ): Promise<string[]> {
    const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
    const [prefix, suffix] = parseUrl(weightPath);
    const pathPrefix = this.weightPathPrefix || prefix;

    const fetchURLs: string[] = [];
    const urlPromises: Array<Promise<string>> = [];
    for (const weightsGroup of weightsManifest) {
      for (const path of weightsGroup.paths) {
        if (this.weightUrlConverter != null) {
          urlPromises.push(this.weightUrlConverter(path));
        } else {
          fetchURLs.push(pathPrefix + path + suffix);
        }
      }
    }

    if (this.weightUrlConverter) {
      fetchURLs.push(...(await Promise.all(urlPromises)));
    }
    return fetchURLs;
  }

  /**
   * Fetch all weight files for the manifest (without decryption) and return
   * them together with the weight specs.
   */
  private async loadWeights(
    weightsManifest: tf.io.WeightsManifestConfig
  ): Promise<[tf.io.WeightsManifestEntry[], tf.io.WeightData]> {
    const fetchURLs = await this.getWeightUrls(weightsManifest);
    const weightSpecs = tf.io.getWeightSpecs(weightsManifest);

    const buffers = await loadWeightsAsArrayBuffer(fetchURLs, this.loadOptions);
    return [weightSpecs, buffers];
  }
}

/**
 * Split a URL into the directory part before its last path segment and the
 * query string that follows that segment.
 *
 * ```
 * parseUrl('http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file')
 * // => ['http://tfhub.dev/model/1/', '?tfjs-format=file']
 * ```
 *
 * The prefix always ends with "/". If the URL contains no "/", the prefix is
 * just "/". The suffix is "" unless a "?" appears after the last "/".
 *
 * @param url the model url to be parsed.
 * @returns `[prefix, suffix]`.
 */
export function parseUrl(url: string): [string, string] {
  const slashAt = url.lastIndexOf("/");
  const queryAt = url.lastIndexOf("?");

  // Clamp so a missing "/" (index -1) yields an empty directory, not a slice
  // that drops the last character.
  const directory = url.slice(0, Math.max(slashAt, 0));

  let query = "";
  if (queryAt > slashAt) {
    query = url.slice(queryAt);
  }

  return [`${directory}/`, query];
}

/**
 * Convert a CryptoJS-style WordArray into a Uint8Array of its bytes.
 *
 * Each 32-bit entry of `words` is unpacked most-significant byte first.
 * `sigBytes`, when present, is the number of bytes to keep. Otherwise every
 * byte of every word is kept. Missing words (beyond `words.length`) read as 0.
 *
 * @param wordArray object exposing `words` and optionally `sigBytes`.
 * @returns a new Uint8Array of length `sigBytes` (or `words.length * 4`).
 */
function convertWordArrayToUint8Array(wordArray: {
  words?: number[];
  sigBytes?: number;
}): Uint8Array {
  const hasOwn = (key: string): boolean =>
    Object.prototype.hasOwnProperty.call(wordArray, key);

  const words: number[] = hasOwn("words") ? wordArray.words ?? [] : [];
  const length = hasOwn("sigBytes")
    ? wordArray.sigBytes ?? 0
    : words.length * 4;

  const bytes = new Uint8Array(length);
  // One iteration per output byte: byte i lives in word i/4, at bit offset
  // 24 - 8 * (i % 4). The previous version looped `length` times writing four
  // bytes each, relying on out-of-range typed-array writes being dropped.
  for (let i = 0; i < length; i++) {
    const word = words[i >>> 2] ?? 0;
    bytes[i] = (word >>> (24 - (i % 4) * 8)) & 0xff;
  }
  return bytes;
}

