import * as tf from "@tensorflow/tfjs";
import * as tmImage from "@teachablemachine/image";

// import { base64StringToArrayBuffer } from "./strings";

import { ModelProvider } from "../services";

/**
 * Turn a list of class probabilities into the rows returned by
 * {@link TfModel#predict}.
 *
 * The row with the largest probability (the first one on ties) gets
 * `weight: "bold"`; every other row gets `"normal"`. An empty list yields an
 * empty result instead of throwing.
 *
 * @param {ArrayLike<number>} probabilities One probability per class.
 * @param {Array<{name: string}>} classes Class descriptors, indexed like `probabilities`.
 * @returns {Array<{label: string, value: number, color: string, weight: string}>}
 */
function formatPredictionRows(probabilities, classes) {
    const values = Array.from(probabilities);

    // Index of the first largest probability; stays -1 only for an empty list.
    let topIndex = -1;
    let topValue = -Infinity;
    values.forEach((probability, index) => {
        if (probability > topValue) {
            topValue = probability;
            topIndex = index;
        }
    });

    return values.map((probability, index) => {
        // Percentage with one decimal place.
        const percent = Math.round(probability * 1000) / 10;
        return {
            label: classes[index].name,
            // Clamp the displayed value into [0.1, 99.9].
            value: percent === 100 ? 99.9 : (percent < 0.1 ? 0.1 : percent),
            color: `graph-${index % 4}`,
            weight: index === topIndex ? "bold" : "normal",
        };
    });
}

/**
 * Wraps a loaded model together with its provider and class list and exposes
 * a provider-independent `predict`.
 */
export class TfModel {

    /**
     * @param {*} provider Compared against `ModelProvider.TeachableMachine` and `ModelProvider.Tensorflow`.
     * @param {*} model The loaded model; must offer `predict`.
     * @param {Array<{name: string}>} classes Class descriptors whose order matches the model output.
     */
    constructor(provider, model, classes) {
        this.provider = provider;
        this.model = model;
        this.classes = classes;
    }

    /**
     * Classify an image and return one row per class.
     *
     * @param {*} image For the Teachable Machine provider this is handed to
     *     `model.predict` unchanged. For the Tensorflow provider it is passed
     *     to `fetch` (presumably a data/blob URL - confirm against callers).
     * @returns {Promise<Array<{label: string, value: number, color: string, weight: string}>>}
     * @throws {Error} "Invalid provider!" when the provider is not supported.
     */
    async predict(image) {
        if (this.provider === ModelProvider.TeachableMachine) {
            const prediction = await this.model.predict(image, false);
            return formatPredictionRows(prediction.map((p) => p.probability), this.classes);
        }

        if (this.provider === ModelProvider.Tensorflow) {
            const response = await fetch(image);
            const blob = await response.blob();
            const url = URL.createObjectURL(blob);

            try {
                // Only the image decoding is callback based; everything after it
                // runs in this async function so failures reject the caller's promise.
                const img = await new Promise((resolve, reject) => {
                    const element = new Image();
                    element.onload = () => resolve(element);
                    element.onerror = reject;
                    element.src = url;
                });

                // tidy() frees the intermediate tensors of the preprocessing chain.
                const input = tf.tidy(() =>
                    tf.browser.fromPixels(img).resizeNearestNeighbor([24, 24]).div(255.0).expandDims()
                );
                let output;
                try {
                    output = this.model.predict(input);
                    const probabilities = await output.data();
                    return formatPredictionRows(probabilities, this.classes);
                }
                finally {
                    // Tensors are not garbage collected; release them explicitly.
                    input.dispose();
                    if (output !== undefined) {
                        tf.dispose(output);
                    }
                }
            }
            finally {
                // The object URL keeps the blob alive until it is revoked.
                URL.revokeObjectURL(url);
            }
        }

        throw new Error("Invalid provider!");
    }

}

/**
 * Load (or reuse) the underlying model described by `modelInfo` and wrap it
 * in a `TfModel`.
 *
 * - Teachable Machine: loads `model.json` and `metadata.json` from
 *   `modelInfo.externalUrl` (the two file names are appended to it verbatim).
 * - Tensorflow: when both `modelInfo.metadata` and `modelInfo.weights` are
 *   set they are converted to files and loaded as a layers model; otherwise
 *   `modelInfo.model` is used as is.
 *
 * @param {object} modelInfo Model description; `provider` selects the branch.
 * @param {Array} classes Passed through to the `TfModel` constructor.
 * @returns {Promise<TfModel>}
 * @throws {Error} When `modelInfo.provider` is neither supported provider.
 */
export async function createTfModel(modelInfo, classes) {
    const { provider } = modelInfo;
    let model;

    switch (provider) {
        case ModelProvider.TeachableMachine: {
            const baseUrl = modelInfo.externalUrl;
            model = await tmImage.load(`${baseUrl}model.json`, `${baseUrl}metadata.json`);
            break;
        }
        case ModelProvider.Tensorflow: {
            const hasSerializedModel = modelInfo.metadata && modelInfo.weights;
            if (!hasSerializedModel) {
                // Nothing serialized: the caller already holds a model instance.
                model = modelInfo.model;
                break;
            }
            const topologyFile = await convertFile(modelInfo.metadata, "metadata.json", "application/json");
            const weightsFile = await convertFile(modelInfo.weights, "model.weights.bin", "application/octet-stream");
            const ioHandler = tf.io.browserFiles([topologyFile, weightsFile]);
            model = await tf.loadLayersModel(ioHandler);
            break;
        }
        default:
            throw new Error("This is rejected!");
    }

    return new TfModel(provider, model, classes);
}

/**
 * Train a fresh convolutional classifier on the images of the given classes.
 *
 * Every entry of `clazz.images` is fetched, decoded into an image element
 * tagged with the class name, and added to the training set. The model is
 * then fitted for 10 epochs.
 *
 * @param {string} modelName Currently unused; kept for interface compatibility.
 * @param {Array<{name: string, images: Array}>} classes Classes with their
 *     image sources (each source is passed to `fetch`).
 * @returns {Promise<*>} The trained model returned by `createConvModel`.
 * @throws {Error} When the classes contain no images at all.
 */
export async function train(modelName, classes) {
    const images = [];

    for (const clazz of classes) {
        for (const imageObj of clazz.images) {
            const response = await fetch(imageObj);
            const blob = await response.blob();
            const url = URL.createObjectURL(blob);
            try {
                images.push(await loadImage(url, clazz.name));
            }
            finally {
                // Once the image has loaded (or failed) the URL is no longer
                // needed; revoking it lets the browser release the blob.
                URL.revokeObjectURL(url);
            }
        }
    }

    if (images.length === 0) {
        // Fail early with a clear message instead of fitting on undefined data.
        throw new Error("Cannot train a model without any images!");
    }

    const model = await createConvModel(classes.length);

    let xs, ys;
    try {
        for (const image of images) {
            const [nextXs, nextYs] = await makeDataset(xs, ys, image, classes);
            // Each step returns new concatenated tensors; free the superseded
            // ones (unless the very same tensor was handed back).
            if (xs !== undefined && nextXs !== xs) {
                xs.dispose();
            }
            if (ys !== undefined && nextYs !== ys) {
                ys.dispose();
            }
            xs = nextXs;
            ys = nextYs;
        }

        await model.fit(xs, ys, {
            epochs: 10
        });
    }
    finally {
        // Training data is not needed once fitting finished or failed.
        if (xs !== undefined) {
            xs.dispose();
        }
        if (ys !== undefined) {
            ys.dispose();
        }
    }

    return model;
}

/**
 * Build a plain "layers-model" description object for `model`.
 *
 * The topology is taken from `model.getClassName()` / `model.getConfig()` and
 * the single weights-manifest entry points at `./model.weights.bin` and
 * carries whatever `model.getWeights()` returns.
 *
 * @param {*} model Object offering `getClassName`, `getConfig` and `getWeights`.
 * @returns {object} The artifacts description.
 */
export const createModelArtifacts = (model) => {
    const layersVersion = `tfjs-layers ${tf.version["tfjs-layers"]}`;

    const modelTopology = {
        class_name: model.getClassName(),
        config: model.getConfig(),
        keras_version: layersVersion,
        backend: "tensorflow.js",
    };

    const weightsManifest = [
        {
            paths: ["./model.weights.bin"],
            weights: model.getWeights(),
        },
    ];

    return {
        format: "layers-model",
        generatedBy: `TensorFlow.js ${layersVersion}`,
        convertedBy: "null",
        modelTopology,
        weightsManifest,
    };
};

/**
 * Encode the weights of `model` into a single binary payload.
 *
 * Each weight is paired with its value (same index in `model.getWeights`)
 * under the weight's `originalName`, then handed to `tf.io.encodeWeights`.
 *
 * @param {*} model Object offering `weights`, `trainableWeights` and `getWeights`.
 * @returns {Promise<*>} The `data` part of the `tf.io.encodeWeights` result.
 */
export const createModelWeightsData = async (model) => {
    // Fixed to false: all weights are exported, not just trainable ones.
    const trainableOnly = false;

    const sourceWeights = trainableOnly ? model.trainableWeights : model.weights;
    const tensors = model.getWeights(trainableOnly);

    const namedWeights = sourceWeights
        .map((weight, position) => ({ weight, tensor: tensors[position] }))
        // Optionally skip non-trainable weights.
        .filter(({ weight }) => !trainableOnly || weight.trainable)
        .map(({ weight, tensor }) => ({ name: weight.originalName, tensor }));

    const { data } = await tf.io.encodeWeights(namedWeights);
    return data;
};

/**
 * Build and compile the sequential image classifier used for training.
 *
 * Layer stack: conv2d(32) -> conv2d(64) -> 2x2 max pooling -> flatten ->
 * dense(128, relu) -> dense(size, softmax). Input shape is [24, 24, 3].
 *
 * @param {number} size Number of units of the final softmax layer.
 * @returns {Promise<*>} The compiled sequential model.
 */
const createConvModel = async(size) => {
    const network = tf.sequential();

    // Each entry creates its layer only when it is about to be added, so
    // layers are created and attached strictly one after another.
    const layerFactories = [
        () => tf.layers.conv2d({
            inputShape: [24, 24, 3],
            filters: 32,
            kernelSize: 3,
            strides: 1,
            activation: "relu",
            kernelInitializer: "varianceScaling"
        }),
        () => tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: "relu" }),
        () => tf.layers.maxPooling2d({ poolSize: [2, 2] }),
        () => tf.layers.flatten(),
        () => tf.layers.dense({ units: 128, activation: "relu" }),
        () => tf.layers.dense({ units: size, activation: "softmax" }),
    ];

    for (const createLayer of layerFactories) {
        network.add(createLayer());
    }

    network.compile({
        loss: "categoricalCrossentropy",
        optimizer: tf.train.adam(),
        metrics: ["accuracy"]
    });

    return network;
};

/**
 * Append one image (and its one-hot label) to the running training tensors.
 *
 * The image is read with `tf.browser.fromPixels`, resized to 24x24, divided
 * by 255 and given a leading dimension. For every class whose `name` equals
 * `img.id`, a one-hot row for that class index is appended to the labels.
 * All work happens inside `tf.tidy`.
 *
 * @param {*} xs Features collected so far, or `undefined` for the first image.
 * @param {*} ys Labels collected so far, or `undefined` when none exist yet.
 * @param {*} img Image element whose `id` holds its class name.
 * @param {Array<{name: string}>} classes All classes, in label-index order.
 * @returns {Promise<Array>} `[xs, ys]` including the new sample.
 */
const makeDataset = async(xs, ys, img, classes) => {
    return tf.tidy(() => {
        const sample = tf.browser.fromPixels(img)
            .resizeNearestNeighbor([24, 24])
            .div(255.0)
            .expandDims();
        const features = (xs === undefined) ? sample : xs.concat(sample);

        let labels = ys;
        for (const [position, clazz] of classes.entries()) {
            if (clazz.name !== img.id) {
                continue;
            }
            const row = tf.oneHot(position, classes.length).expandDims();
            labels = (labels === undefined) ? row : labels.concat(row);
        }

        return [features, labels];
    });
};

/**
 * Load `url` into a new image element tagged with `className` as its `id`.
 *
 * @param {string} url Source assigned to the image's `src`.
 * @param {string} className Stored in the image's `id`.
 * @returns {Promise<*>} Resolves with the image once its "load" event fires;
 *     rejects with an Error naming the URL on its "error" event.
 */
const loadImage = async(url, className) => {
    const picture = new Image();
    picture.id = className;

    return new Promise((resolve, reject) => {
        const handleLoad = () => resolve(picture);
        const handleError = () => reject(new Error(`Failed to load image's URL: ${url}`));

        picture.addEventListener("load", handleLoad);
        picture.addEventListener("error", handleError);

        // Assign the source last so both listeners are in place beforehand.
        picture.src = url;
    });
};

/**
 * Fetch `url` and wrap the response body in a `File`.
 *
 * @param {string} url Resource passed to `fetch`.
 * @param {string} filename Name given to the resulting file.
 * @param {string} type MIME type stored on the resulting file.
 * @returns {Promise<File>}
 */
const convertFile = async(url, filename, type) => {
    const response = await fetch(url);
    const payload = await response.blob();
    return new File([payload], filename, { type });
};

// function extractMetadata(modelName) {
//     return JSON.stringify({
//         ...JSON.parse(localStorage.getItem(`tensorflowjs_models/${modelName}/model_metadata`)),
//         modelTopology : JSON.parse(indexedDB.getItem(`tensorflowjs_models/${modelName}/model_topology`)),
//         weightsManifest: [
//             { paths: ["./model.weights.bin"], weights: JSON.parse(localStorage.getItem(`tensorflowjs_models/${modelName}/weight_specs`)) }
//         ],
//     });
// }
//
// function extractWeightsFile(modelName) {
//     const weightBuffer = tf.io.CompositeArrayBuffer.join(
//         base64StringToArrayBuffer(indexedDB.getItem(`tensorflowjs_models/${modelName}/weight_data`))
//     );
//     return new Blob([weightBuffer], { type: "application/octet-stream" });
// }
