import { Accordion, AccordionPanel, Box, Button, Heading, Text } from 'grommet';
import produce from 'immer';
import React, { useEffect, useState } from 'react';
import { useDispatch, useSelector } from 'react-redux';
import { DatasetType, EventFeatureType, FlagFeatureType, LayerType, ModelType, NeuralNetworkType, SignalFeatureType } from '../../../codegen/models/models';
import { IRequestable } from '../../../core/api';
import EditableText from '../../../core/components/EditableText';
import ErrorCat from '../../../core/components/ErrorCat';
import InputRow from '../../../core/components/InputRow';
import ModelEditor from '../../../core/components/ModelEditor';
import { generateModel, SchemaInterface } from '../../../core/components/Utils';
import { IApplicationState } from '../../../core/state';
import { addLayer, updateModel } from '../actions';


/**
 * Editor for a neural-network model's algorithm settings.
 *
 * Renders, in order: the optimiser settings, the loss-function settings, the
 * remaining algorithm settings (everything except optimizer, lossFunction and
 * layers), the layer list, and the layer add/remove controls.
 * Returns null when `model.algorithm` is not set.
 */
export const NeuralNetworkBuilder = ({ model, schema }: { model: ModelType, schema: SchemaInterface }) => {
    const { algorithm } = model
    if (!algorithm) {
        return null
    }

    const network = algorithm as unknown as NeuralNetworkType
    const sectionSchemas = schema.properties
    // Index of the final layer (-1 when the network has no layers), handed to LayerControls
    const lastLayerIndex = network.layers.length - 1

    return (
        <Box className='NeuralNetworkBuilder' gap="small">
            <Box pad='xsmall'>
                <Heading level='3' margin='small'>Optimiser Setup</Heading>
                <ModelEditor
                    schema={sectionSchemas.optimizer}
                    model={model}
                    property='algorithm.optimizer'
                    allowRange
                />
                <Heading level='3' margin='small'>Loss Function Setup</Heading>
                <ModelEditor
                    schema={sectionSchemas.lossFunction}
                    model={model}
                    property='algorithm.lossFunction'
                    allowRange
                />
                <Heading level='3' margin='small'>Data Setup</Heading>
                {/* Everything on the algorithm not covered by the sections above or the layer list */}
                <ModelEditor
                    schema={schema}
                    model={model}
                    property='algorithm'
                    excluded={['lossFunction', 'optimizer', 'layers']}
                    allowRange
                />
            </Box>
            <Heading level='3'>Layers</Heading>
            <Box>
                <Layers model={model} neuralNetwork={network} schema={sectionSchemas.layers} />
            </Box>
            <LayerControls model={model} schema={sectionSchemas.layers} index={lastLayerIndex} />
        </Box>
    )
};

/**
 * One layer of the network, rendered as an accordion panel titled with the layer name.
 *
 * The first and last layers (index 0 and numLayers - 1) only show an
 * explanatory message. Interior layers get:
 *  - a "Layer Size" field,
 *  - a schema-driven editor for the remaining layer properties,
 *  - a "Remove Layer" button, when `onRemoved` is supplied.
 *
 * Size edits are dispatched straight to the store via `updateLayerSizes`, which
 * also updates the neighbouring layers' sizeIn/sizeOut so shapes stay chained.
 * NOTE(review): `onUpdated` is part of the props contract but is not called here.
 */
export const Layer = ({ layer, model, schema, onRemoved, background, index, numLayers }:
    {
        layer: LayerType,
        model: ModelType,
        schema: SchemaInterface,
        onUpdated: (p: LayerType) => any,
        onRemoved?: () => void,
        background?: string,
        index: number
        numLayers: number
    }) => {
    // sizeIn/sizeOut are edited through the "Layer Size" field, not the generic editor
    const excluded = ["sizeIn", "sizeOut"]
    const dispatch = useDispatch()
    const isEdgeLayer = index === 0 || index === numLayers - 1
    return <AccordionPanel label={layer.layerName}>
        <Box background={background}>
            {
                isEdgeLayer ?
                    <Box pad='large' fill align='center'><Text>Layer automatically generated based on other layers and input/output data</Text></Box> :
                    <Box gap='small' pad='small'>
                        <InputRow label="Layer Size">
                            <EditableText
                                initialValue={layer.sizeOut?.toString()}
                                focus={false}
                                placeholder='Enter Layer Size...'
                                onFinishEdit={v => {
                                    if (!v) {
                                        return
                                    }
                                    // Ignore non-numeric or non-positive input rather than writing NaN/0 into the model
                                    const size = parseInt(v, 10)
                                    if (!Number.isInteger(size) || size <= 0) {
                                        return
                                    }
                                    dispatch(updateModel(updateLayerSizes(size, model, index)))
                                }}
                                onAbortEdit={() => {}}
                            />
                        </InputRow>
                        <ModelEditor schema={schema} model={model} property={`algorithm.layers[${index}]`} excluded={excluded} allowRange={true} />
                        {
                            onRemoved &&
                            <Button secondary label='Remove Layer' onClick={onRemoved} />
                        }
                    </Box>
            }
        </Box>
    </AccordionPanel>
}

/**
 * Renders every layer of the network inside an accordion.
 *
 * The first and last layers are shown on a contrasting background and get no
 * remove handler; the layers in between can be removed. When the network has
 * no layers, a default input/hidden/output stack is built from the selected
 * dataset (see `buildBaseNetwork`) and written back to the store.
 *
 * Shows an error instead when the layer schema or the model's dataset cannot be found.
 */
export const Layers = ({ model, neuralNetwork, schema }:
    { model: ModelType, neuralNetwork: NeuralNetworkType, schema: SchemaInterface }) => {
    const dispatch = useDispatch()
    const datasets = useSelector<IApplicationState, IRequestable<DatasetType[]>>(s => s.data.datasets)
    const layerSchema = schema.items
    const numLayers = neuralNetwork.layers.length
    const currentDataset = datasets.value?.find(d => d.id === model.datasetId)

    // Seed an empty network with the default layers. This runs in an effect because
    // dispatching to the store during render is a side effect React warns about
    // ("Cannot update a component while rendering a different component") and can
    // loop. The hook sits above the early returns so hook order is stable.
    useEffect(() => {
        if (numLayers !== 0 || !layerSchema || !currentDataset) {
            return
        }
        const newLayers = buildBaseNetwork(model, currentDataset, layerSchema)
        dispatch(updateModel(produce(model, m => {
            (m.algorithm as unknown as NeuralNetworkType).layers = newLayers
        })))
    }, [numLayers, layerSchema, currentDataset, model, dispatch])

    if (!layerSchema) {
        return <ErrorCat message={'Could not load layer schema'}/>
    }
    if (!currentDataset) {
        return <ErrorCat message={'Could not find a Selected Dataset!'}/>
    }

    // Replace the layer at `index` in an immutable copy of the model
    const replaceLayer = (index: number, replacement: LayerType) => {
        dispatch(updateModel(produce(model, m => { (m.algorithm as unknown as NeuralNetworkType).layers[index] = replacement })))
    }
    // Drop the layer at `index` from an immutable copy of the model
    const removeLayer = (index: number) => {
        dispatch(updateModel(produce(model, m => { (m.algorithm as unknown as NeuralNetworkType).layers.splice(index, 1) })))
    }

    return <Accordion>
        {
            neuralNetwork.layers.map((item: LayerType, index: number) => {
                const isEdgeLayer = index === 0 || index === numLayers - 1
                const layerPanel = <Layer
                    index={index}
                    numLayers={numLayers}
                    layer={item}
                    model={model}
                    schema={layerSchema}
                    onUpdated={p => { replaceLayer(index, p) }}
                    onRemoved={isEdgeLayer ? undefined : () => { removeLayer(index) }}
                    background={isEdgeLayer ? "background-contrast" : undefined}
                />
                return <Box key={`layer-${index}`}>
                    {isEdgeLayer ? <Box background="background-contrast">{layerPanel}</Box> : layerPanel}
                </Box>
            })
        }
    </Accordion>
}

/**
 * Row of buttons for editing the layer list.
 *
 * "Remove All" empties `algorithm.layers`. "Add Layer" dispatches `addLayer`
 * with `index` as the layer index and an auto-generated name `layer_<n>`,
 * where n starts at 1 and is tracked in local component state (so it restarts
 * whenever this component remounts).
 * Shows a plain message if the layers schema has no `items` schema.
 */
export const LayerControls = ({ model, schema, index }: { model: ModelType, schema: SchemaInterface, index: number }) => {
    const [nextLayerNumber, setNextLayerNumber] = useState<number>(1)
    const dispatch = useDispatch()
    const itemSchema = schema.items
    if (!itemSchema) {
        return <Text>Could not load layer schema</Text>
    }

    const removeAllLayers = () => {
        const emptyLayers: LayerType[] = []
        dispatch(updateModel(produce(model, draft => {
            (draft.algorithm as unknown as NeuralNetworkType).layers = emptyLayers
        })))
    }

    const appendLayer = () => {
        const layerName = `layer_${nextLayerNumber}`
        dispatch(addLayer({ model, layerSchema: itemSchema, layerIndex: index, layerName }))
        setNextLayerNumber(nextLayerNumber + 1)
    }

    return (
        <Box align="center" fill="vertical" direction='row' gap="small">
            <Box basis='1/3'>
                <Button secondary label='Remove All' onClick={removeAllLayers} />
            </Box>
            <Box basis='2/3'>
                <Button secondary label='Add Layer' onClick={appendLayer} />
            </Box>
        </Box>
    )
}

//--------------UTILS-------------------
/**
 * Returns a copy of `model` with the layer at `index` resized to `v` units,
 * keeping the chain of layer shapes consistent:
 *  - layers[index].sizeOut     = v
 *  - layers[index].sizeIn      = layers[index - 1].sizeOut
 *  - layers[index + 1].sizeIn  = v
 *
 * Only interior layers (0 < index < layers.length - 1) can be resized. The
 * first and last layers are generated from the dataset, and touching their
 * missing neighbour would index outside the array. For any other index the
 * model is returned unchanged (immer returns the base object when the recipe
 * makes no changes).
 *
 * @param v     new output size for the layer
 * @param model model to copy; not mutated
 * @param index position of the layer in `algorithm.layers`
 */
const updateLayerSizes = (v: number, model: ModelType, index: number) => {
    return produce(model, m => {
        const layers = (m.algorithm as unknown as NeuralNetworkType).layers
        // Both neighbours must exist, otherwise layers[index - 1] / layers[index + 1] is undefined
        const isInterior = index > 0 && index < layers.length - 1
        if (!isInterior) {
            return
        }
        layers[index].sizeOut = v
        layers[index].sizeIn = layers[index - 1].sizeOut
        layers[index + 1].sizeIn = v
    })
}

/**
 * Builds the default three-layer stack (input -> hidden -> output) for an empty network.
 *
 * When the dataset's properties use schemaRef 'sift_dataset_properties.schema.json',
 * the number of features with featureType 'feature' sizes the input layer and the
 * number with featureType 'target' sizes the output layer. For any other dataset
 * both counts stay 0.
 *
 * @param model          currently unused; kept so existing call sites keep compiling
 * @param currentDataset dataset selected for the model
 * @param layerSchema    JSON schema each layer is generated from
 * @param hiddenSize     width of the hidden layer (default 16); also used as the
 *                       output layer's input size so the two always agree
 * @returns the input, hidden and output layers, in that order
 */
const buildBaseNetwork = (model: ModelType, currentDataset: DatasetType, layerSchema: SchemaInterface, hiddenSize: number = 16) => {
    let numFeatures = 0
    let numTargets = 0
    if (currentDataset?.properties?.schemaRef === 'sift_dataset_properties.schema.json') {
        currentDataset.properties.features.forEach((feature: SignalFeatureType | FlagFeatureType | EventFeatureType) => {
            if (feature.featureType === 'feature') { numFeatures += 1 }
            if (feature.featureType === 'target') { numTargets += 1 }
        })
    }
    const newLayers: LayerType[] = [
        buildLayer(layerSchema, 'input_layer', numFeatures, numFeatures),
        buildLayer(layerSchema, 'hidden_layer', numFeatures, hiddenSize),
        buildLayer(layerSchema, 'output_layer', hiddenSize, numTargets)
    ]
    return newLayers
}

/**
 * Creates a layer from the layer schema's generated defaults and then sets
 * its name, input/output sizes, a 'relu' activation and a linear layer type.
 *
 * @param layerSchema JSON schema the layer is generated from
 * @param layerName   display/identifier name for the layer
 * @param sizeIn      number of inputs to the layer
 * @param sizeOut     number of outputs from the layer
 */
const buildLayer = (layerSchema: SchemaInterface, layerName: string, sizeIn: number, sizeOut: number): LayerType => {
    const layer = generateModel(layerSchema, "model.schema.json#/definitions/layer") as LayerType
    // Identity and shape
    layer.layerName = layerName
    layer.sizeIn = sizeIn
    layer.sizeOut = sizeOut
    // Every default layer is a ReLU-activated linear layer
    layer.activation = 'relu'
    layer.layerType = { "schemaRef": "model.schema.json#/definitions/linear" }
    return layer
}

// NeuralNetworkBuilder is also exposed as this module's default export.
export default NeuralNetworkBuilder;