import pandas as pd
import math
from datetime import datetime
from catboost import CatBoostRegressor
from sklearn.ensemble import RandomForestRegressor

from sklearn.model_selection import train_test_split
from bots.botlibs.labeling_lib import get_labels_r
from bots.botlibs.tester_lib import test_model_r

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType


def get_prices(symbol=None, base_dir='../utils/') -> pd.DataFrame:
    """Load close prices from a whitespace-separated CSV with <DATE>/<TIME>/<CLOSE> columns.

    Args:
        symbol: File stem of the CSV to read; defaults to ``hyper_params['symbol']``.
        base_dir: Directory prefix (including trailing separator) the file lives in.

    Returns:
        A frame indexed by a datetime index named ``'time'`` (built from the
        ``<DATE>`` and ``<TIME>`` columns) with a single ``'close'`` column,
        with rows containing NaN removed.
    """
    if symbol is None:
        symbol = hyper_params['symbol']

    # Raw string: '\s' is not a valid Python escape (SyntaxWarning on 3.12+).
    raw = pd.read_csv(base_dir + symbol + '.csv', sep=r'\s+')

    timestamps = pd.to_datetime(raw['<DATE>'] + ' ' + raw['<TIME>'], format='mixed')
    prices = pd.DataFrame({'time': timestamps, 'close': raw['<CLOSE>']})
    # The index is already datetime64 here, so no further conversion is needed.
    prices.set_index('time', inplace=True)
    return prices.dropna()


def get_features(data: pd.DataFrame, periods=None) -> pd.DataFrame:
    """Append rolling standard-deviation features to a price frame.

    For each window length in ``periods`` a column named ``'0'``, ``'1'``, ...
    (in the order of ``periods``) is added, holding ``data.rolling(window).std()``.
    Rows containing NaN (e.g. the warm-up of the longest window) are dropped.

    Args:
        data: Price frame, expected to hold a single column (such as the output
            of get_prices()), because the rolling std of the whole frame is
            written into one column.
        periods: Window lengths; defaults to ``hyper_params['periods']``.

    Returns:
        A new frame: the input column(s) followed by one feature column per period.
    """
    if periods is None:
        periods = hyper_params['periods']

    features = data.copy()
    # Roll over the untouched input rather than `features`, so previously added
    # feature columns never feed into later ones.
    for idx, window in enumerate(periods):
        features[str(idx)] = data.rolling(window).std()
    return features.dropna()


def meta_learners(data, models_number: int, iterations: int, depth: int):
    """Score every row by how badly an ensemble of CatBoost regressors predicts its label.

    Rows are restricted to ``hyper_params['backward'] < index < hyper_params['forward']``.
    Features are all columns between the first and the last one (presumably the
    first is the price column and the last is ``'labels'``, as produced by
    get_labels_r -- confirm). Each of ``models_number`` regressors is fit on a
    random 50/50 train/validation split and then predicts the full feature set.

    Args:
        data: Labelled frame with a datetime index and a ``'labels'`` column.
        models_number: Number of CatBoost models to average over; must be >= 1.
        iterations: CatBoost ``iterations`` per model.
        depth: CatBoost tree ``depth`` per model.

    Returns:
        A filtered copy of ``data`` with an extra ``'meta_labels'`` column: the mean
        absolute residual ``|labels - prediction|`` across all models.

    Raises:
        ValueError: if ``models_number`` < 1 (otherwise every meta label would be NaN).
    """
    if models_number < 1:
        raise ValueError('models_number must be >= 1, got ' + str(models_number))

    in_window = (data.index < hyper_params['forward']) & (data.index > hyper_params['backward'])
    data = data[in_window].copy()

    X = data[data.columns[1:-1]]
    y = data['labels']
    abs_errors = pd.Series(0.0, index=data.index)

    for _ in range(models_number):
        X_train, X_val, y_train, y_val = train_test_split(
            X, y, train_size=0.5, test_size=0.5, shuffle=True)

        # The validation half drives CatBoost's best-iteration selection.
        meta_m = CatBoostRegressor(iterations=iterations,
                                   depth=depth,
                                   verbose=False,
                                   use_best_model=True)
        meta_m.fit(X_train, y_train, eval_set=(X_val, y_val), plot=False)

        # Residuals on the full sample; no need to copy X just to subtract.
        abs_errors += (y - meta_m.predict(X)).abs()

    data['meta_labels'] = abs_errors / models_number
    return data


def fit_final_models(dataset, tol=1e-2) -> list:
    """Fit the main and meta random-forest models, then backtest them.

    The main model learns the second-to-last column (presumably ``'labels'``)
    only on rows whose ``'meta_labels'`` is below ``tol``. The meta model learns
    the last column (``'meta_labels'`` as appended by meta_learners) on every row.
    Both use the columns between the first and the last two as features.

    Args:
        dataset: Output of meta_learners().
        tol: Rows with ``meta_labels < tol`` are used to train the main model.

    Returns:
        ``[R2, model, meta_model]`` where ``R2`` is the score returned by
        test_model_r, replaced by -1.0 when it is NaN.

    Raises:
        ValueError: if no row has ``meta_labels < tol``.
    """
    feature_cols = dataset.columns[1:-2]
    label_col, meta_label_col = dataset.columns[-2], dataset.columns[-1]

    # Rows the meta-learners reproduced with a small error are treated as clean.
    clean = dataset[dataset['meta_labels'] < tol]
    if clean.empty:
        raise ValueError('No rows with meta_labels < ' + str(tol) + '; cannot fit the main model')

    # Main model: features -> labels on the clean rows only (no validation split).
    model = RandomForestRegressor(n_estimators=50, max_depth=10)
    model.fit(clean[feature_cols], clean[label_col])
    # Meta model: features -> meta_labels on every row.
    meta_model = RandomForestRegressor(n_estimators=50, max_depth=10)
    meta_model.fit(dataset[feature_cols], dataset[meta_label_col])

    data = get_features(get_prices())
    R2 = test_model_r(data,
                      [model, meta_model],
                      hyper_params['stop_loss'],
                      hyper_params['take_profit'],
                      hyper_params['forward'],
                      hyper_params['backward'],
                      hyper_params['markup'],
                      plt=False)

    if math.isnan(R2):
        R2 = -1.0
        print('R2 is fixed to -1.0')
    print('R2: ' + str(R2))
    return [R2, model, meta_model]


def _mql_feature_filler(func_name: str, periods_array: str) -> str:
    """Return MQL5 source for ``void <func_name>(double &features[])``.

    The generated function walks ``periods_array`` from its last element to its
    first, copies that many H1 closes (starting at bar 1), appends
    ``MathStandardDeviation`` of them to ``features`` and finally calls
    ``ArraySetAsSeries(features, true)``.
    """
    lines = [
        'void ' + func_name + '( double &features[]) {',
        '   double pr[], ret[];',
        '   ArrayResize(ret, 1);',
        '   for(int i=ArraySize(' + periods_array + ')-1; i>=0; i--) {',
        '       CopyClose(NULL,PERIOD_H1,1,' + periods_array + '[i],pr);',
        '       ret[0] = MathStandardDeviation(pr);',
        '       ArrayInsert(features, ret, ArraySize(features), 0, WHOLE_ARRAY); }',
        '   ArraySetAsSeries(features, true);',
        '}',
    ]
    return '\n'.join(lines) + '\n\n'


def export_model_to_ONNX(**kwargs):
    """Export the main and meta models to ONNX and write a matching MQL5 include file.

    Keyword Args:
        model: Sequence whose items [1] and [2] are the main and meta sklearn
            models (as returned by fit_final_models).
        symbol: Symbol name used in file names and MQL identifiers.
        periods: Feature window lengths of the main model.
        periods_meta: Feature window lengths of the meta model.
        model_number: Suffix distinguishing several exported models.
        export_path: Output directory prefix (including trailing separator).

    Writes ``catmodel <symbol> <n>.onnx``, ``catmodel_m <symbol> <n>.onnx`` and
    ``<symbol> ONNX include <n>.mqh`` under ``export_path``.
    """
    model = kwargs.get('model')
    symbol = kwargs.get('symbol')
    periods = kwargs.get('periods')
    periods_meta = kwargs.get('periods_meta')
    model_number = kwargs.get('model_number')
    export_path = kwargs.get('export_path')

    tag = symbol + ' ' + str(model_number)     # used in file names
    suffix = symbol + '_' + str(model_number)  # used in MQL identifiers

    # Each ONNX graph's input width must equal the number of features its model
    # was trained on, so size it from the matching periods argument.
    main_input = [('float_input', FloatTensorType([None, len(periods)]))]
    meta_input = [('float_input', FloatTensorType([None, len(periods_meta)]))]

    onnx_model = convert_sklearn(model[1], initial_types=main_input)
    with open(export_path + 'catmodel ' + tag + '.onnx', "wb") as f:
        f.write(onnx_model.SerializeToString())
    onnx_model_meta = convert_sklearn(model[2], initial_types=meta_input)
    with open(export_path + 'catmodel_m ' + tag + '.onnx', "wb") as f:
        f.write(onnx_model_meta.SerializeToString())

    code = ''.join([
        # Raw string: '\S' and '\M' are invalid Python escapes.
        r'#include <Math\Stat\Math.mqh>' + '\n',
        '#resource "catmodel ' + tag + '.onnx" as uchar ExtModel_' + suffix + '[]\n',
        '#resource "catmodel_m ' + tag + '.onnx" as uchar ExtModel2_' + suffix + '[]\n\n',
        'int Periods' + suffix + '[' + str(len(periods)) + '] = {'
        + ','.join(map(str, periods)) + '};\n',
        'int Periods_m' + suffix + '[' + str(len(periods_meta)) + '] = {'
        + ','.join(map(str, periods_meta)) + '};\n\n',
        _mql_feature_filler('fill_arays' + suffix, 'Periods' + suffix),
        _mql_feature_filler('fill_arays_m' + suffix, 'Periods_m' + suffix),
    ])

    mqh_path = export_path + str(symbol) + ' ONNX include' + ' ' + str(model_number) + '.mqh'
    with open(mqh_path, "w") as file:
        file.write(code)
    print('The file ' + mqh_path + ' has been written to disk')


# Global configuration read by the functions above and by the driver below.
hyper_params = {
    # CSV file stem read by get_prices(); also embedded in exported file names.
    'symbol': 'EURUSD_H1_ORIG',
    # Directory the ONNX models and the .mqh include are written to.
    'export_path': 'C:/Users/guanlinqi/AppData/Roaming/MetaQuotes/Terminal/29E91DA909EB4475AB204481D1C2CE7D/MQL5/Include/Trend following/',
    # Suffix distinguishing exported model files / MQL identifiers.
    'model_number': 0,
    # markup / stop_loss / take_profit are passed straight to test_model_r
    # (their units are defined there -- presumably price units; confirm).
    'markup': 0.00010,
    'stop_loss': 0.02000,
    'take_profit': 0.02000,
    # Rolling-window lengths for the std-dev features: 5, 20, ..., 95.
    'periods': list(range(5, 100, 15)),
    # Training window (exclusive bounds) used by meta_learners; also passed to test_model_r.
    'backward': datetime(2010, 1, 1),
    'forward': datetime(2024, 1, 1),
}

# Driver: train candidate models, keep the best-scoring one, backtest it with
# plotting enabled and export it for MetaTrader.
models = []
for i in range(1):
    print('Learn ' + str(i) + ' model')
    labeled = get_labels_r(get_features(get_prices()), min=1, max=15)
    dataset = meta_learners(data=labeled, models_number=5, iterations=50, depth=3)
    models.append(fit_final_models(dataset, tol=1e-2))

# Each candidate is [score, model, meta_model]; an ascending sort on the score
# leaves the best candidate at the end of the list.
models.sort(key=lambda candidate: candidate[0])
best = models[-1]

backtest_args = (
    hyper_params['stop_loss'],
    hyper_params['take_profit'],
    hyper_params['forward'],
    hyper_params['backward'],
    hyper_params['markup'],
)
data = get_features(get_prices())
test_model_r(data, best[1:], *backtest_args, plt=True)

# The meta model is trained on the same feature columns as the main model,
# hence the same periods list for both.
export_model_to_ONNX(model=best,
                     symbol=hyper_params['symbol'],
                     periods=hyper_params['periods'],
                     periods_meta=hyper_params['periods'],
                     model_number=hyper_params['model_number'],
                     export_path=hyper_params['export_path'])
