import MetaTrader5 as mt5
import logging
import math
from datetime import datetime

# NOTE(review): STRICT_MODE is not read anywhere in this module's visible code;
# confirm whether external callers rely on it before changing or removing it.
STRICT_MODE = True

# Shared logger for every risk-management message emitted by this module.
logger = logging.getLogger("RiskManager")

class RiskManager:
    """Per-trade sizing, signal validation and open-position risk controls.

    Settings are read from ``config['risk_management']`` (defaults in
    parentheses):

    * ``risk_per_trade`` (0.01) -- FRACTION of balance risked per trade,
      i.e. 0.01 means 1%.
    * ``max_total_risk_percent`` (5.0) -- cap on aggregate open risk in
      PERCENT of balance, i.e. 5.0 means 5%. Note the unit differs from
      ``risk_per_trade``.
    * ``max_lot_size`` (0.10) / ``min_lot_size`` (0.01) -- lot bounds that are
      combined with the broker's ``volume_max`` / ``volume_min``.
    * ``min_risk_reward_ratio`` (1.5) -- minimum reward/risk accepted by
      :meth:`validate_signal`.

    ``calculate_position_size`` queries the ``MetaTrader5`` module directly.
    The other methods work through an ``mt5_interface`` object whose positions,
    symbol info and ticks are indexed like dicts.
    """

    def __init__(self, config):
        self.config = config
        # ``or {}`` also tolerates an explicit ``risk_management: null`` entry.
        rm = config.get('risk_management') or {}
        self.risk_percent = rm.get('risk_per_trade', 0.01)  # 1% default (fraction of balance)
        self.max_open_risk_percent = rm.get('max_total_risk_percent', 5.0)  # 5% max overall risk (percent units)
        self.max_lots = rm.get('max_lot_size', 0.10)  # 0.10 lots max
        self.min_lots = rm.get('min_lot_size', 0.01)
        self.tolerance = 0.05  # 5% tolerance (not read by this class)
        self.debug_mode = True
        # Only cleared by reset_daily_limits(); nothing in this class adds to it.
        self.partial_2r_done_tickets = set()
        self.min_rr = rm.get('min_risk_reward_ratio', 1.5)  # 1.5 minimum RR

        logger.info(f"[RISK-INIT] Loaded config: risk_percent={self.risk_percent}, max_open_risk_percent={self.max_open_risk_percent}, max_lots={self.max_lots}, min_lots={self.min_lots}, min_rr={self.min_rr}")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _position_side(direction):
        """Return 'buy', 'sell' or None for an MT5 type (0/1) or a BUY/SELL/LONG/SHORT label."""
        label = str(direction).upper()
        if label in ('0', 'BUY', 'LONG'):
            return 'buy'
        if label in ('1', 'SELL', 'SHORT'):
            return 'sell'
        return None

    @staticmethod
    def _value_per_point(symbol_info):
        """Return the account-currency value of a one-point move for one lot (0.0 if unknown)."""
        point = symbol_info.point
        tick_value = getattr(symbol_info, 'trade_tick_value', 0.0) or 0.0
        tick_size = getattr(symbol_info, 'trade_tick_size', 0.0) or 0.0
        if tick_value > 0 and tick_size > 0:
            # Tick value is already converted to account currency, but it is
            # quoted per tick_size, which need not equal one point.
            return tick_value * point / tick_size
        # Fallback: quote-currency value. This is only exact when the quote
        # currency is the account currency.
        contract_size = getattr(symbol_info, 'trade_contract_size', 0.0) or 0.0
        return point * contract_size

    @staticmethod
    def _round_down_to_step(volume, step):
        """Floor ``volume`` to a multiple of ``step``; the epsilon absorbs float noise (0.07/0.01 -> 7)."""
        steps = math.floor(volume / step + 1e-9)
        return round(steps * step, 8)

    # --------------------------------------------------------------- public API

    def calculate_position_size(self, entry, stop_loss, symbol):
        """Return the lot size that risks ``risk_per_trade`` of the account balance.

        The stop distance is valued in account currency via the symbol's tick
        value / tick size, so symbols quoted in a currency other than the
        account currency are sized correctly. The lot is capped at the tighter
        of the broker's ``volume_max`` and ``max_lot_size``. It is then rounded
        DOWN to the broker's ``volume_step``, so the realised risk never exceeds
        the budget.

        Args:
            entry: Intended entry price.
            stop_loss: Stop-loss price.
            symbol: MT5 symbol name.

        Returns:
            The lot size as a float. Returns 0.0 (trade blocked) when symbol or
            account data is unavailable, the stop distance or point value is
            invalid, or the risk budget cannot afford the minimum lot.
        """
        symbol_info = mt5.symbol_info(symbol)
        if symbol_info is None:
            logger.error(f"Symbol info not found for {symbol}")
            return 0.0

        account_info = mt5.account_info()
        if account_info is None:
            logger.error("Account info not available")
            return 0.0

        account_balance = account_info.balance
        risk_amount = account_balance * self.risk_percent

        logger.info(f"[RISK-CALC] Account balance: {account_balance}, Risk percent: {self.risk_percent}, Risk amount: {risk_amount:.2f}")

        point = symbol_info.point
        # Broker and config bounds are both enforced; the tighter one wins.
        min_lot = max(getattr(symbol_info, 'volume_min', self.min_lots), self.min_lots)
        max_lot = min(getattr(symbol_info, 'volume_max', 100.0), self.max_lots)
        lot_step = getattr(symbol_info, 'volume_step', 0.0) or 0.01

        value_per_point = self._value_per_point(symbol_info)
        if not point or point <= 0 or entry == stop_loss or value_per_point <= 0:
            logger.error("Invalid stop distance or pip value for position sizing.")
            return 0.0
        stop_distance_points = abs(entry - stop_loss) / point

        raw_lot_size = risk_amount / (stop_distance_points * value_per_point)
        final_lot = self._round_down_to_step(min(raw_lot_size, max_lot), lot_step)

        if final_lot < min_lot:
            # Rounding up to the minimum would risk more than risk_per_trade
            # (and could exceed max_lot_size when the broker minimum is larger).
            logger.warning(f"Lot size {final_lot} below minimum allowed {min_lot}. Trade blocked.")
            return 0.0

        logger.info(f"Calculated Position Size: {final_lot} lots for {symbol} (Risk: {risk_amount:.2f}, SL Dist: {stop_distance_points:.5f})")
        return final_lot

    def validate_signal(self, signal, mt5_interface):
        """Check a trade signal's price geometry and reward/risk ratio.

        Args:
            signal: Dict with ``direction`` (1/-1 or BUY/SELL/LONG/SHORT in
                upper or lower case), ``entry``, ``stop_loss`` and ``take_profit``.
            mt5_interface: Accepted for interface compatibility; not used.

        Returns:
            True when the prices are ordered correctly for the direction and
            reward/risk is at least ``min_rr``, otherwise False (reason logged).
        """
        for k in ['entry', 'stop_loss', 'take_profit']:
            if k not in signal or signal[k] is None:
                logger.error(f"Signal missing or invalid {k}")
                return False

        # .get: a signal without a direction is rejected below instead of raising KeyError.
        dir_val = signal.get('direction')
        entry = signal['entry']
        stop = signal['stop_loss']
        tp = signal['take_profit']

        if dir_val in [1, "BUY", "buy", "LONG", "long"]:
            if not (tp > entry > stop):
                logger.error("Directional price logic failed for BUY")
                return False
            risk = entry - stop
            reward = tp - entry
        elif dir_val in [-1, "SELL", "sell", "SHORT", "short"]:
            if not (tp < entry < stop):
                logger.error("Directional price logic failed for SELL")
                return False
            risk = stop - entry
            reward = entry - tp
        else:
            logger.error(f"Invalid direction value: {dir_val}")
            return False

        if risk == 0:
            logger.error("Risk is zero. Trade blocked.")
            return False

        rr = reward / risk
        if rr < self.min_rr:
            logger.error(f"Risk-Reward ratio {rr:.2f} is below minimum allowed {self.min_rr}. Trade blocked.")
            return False

        logger.info(f"Trade validation passed. R:R = {rr:.2f}")
        return True

    def _position_risk(self, pos, symbol_info):
        """Return the account-currency loss ``pos`` would realise if stopped out."""
        point = symbol_info['point']
        tick_value = symbol_info['trade_tick_value']
        # Tick value is quoted per tick_size; fall back to point when absent.
        tick_size = symbol_info.get('trade_tick_size') or point
        open_price = pos['price_open']
        sl = pos['sl']
        side = self._position_side(pos.get('type'))
        if sl and side == 'buy':
            distance = max(open_price - sl, 0.0)  # a stop above entry locks in profit
        elif sl and side == 'sell':
            distance = max(sl - open_price, 0.0)  # a stop below entry locks in profit
        else:
            # Unknown side, or sl == 0 (MT5 reports 0.0 when no stop is set):
            # keep the conservative absolute distance rather than report zero.
            distance = abs(open_price - sl)
        return pos['volume'] * (distance / tick_size) * tick_value

    def total_open_risk(self, mt5_interface):
        """Sum, in account currency, what every open position would lose at its stop.

        Positions without symbol info, or with ``sl`` of None, are skipped.
        Stops already at or beyond entry in the profitable direction contribute
        zero. Returns 0.0 (and warns) when positions or account info cannot be
        retrieved.
        """
        positions = mt5_interface.get_open_positions() if hasattr(mt5_interface, 'get_open_positions') else None
        account_info = mt5_interface.get_account_info() if hasattr(mt5_interface, 'get_account_info') else None

        # An empty position list is a valid "no exposure" answer, not a failure.
        if positions is None or not account_info:
            logger.warning("[GLOBAL RISK] Could not retrieve open positions or account info. Assuming zero open risk.")
            return 0.0

        total_risk = 0.0
        for pos in positions:
            symbol_info = mt5_interface.get_symbol_info(pos['symbol']) if hasattr(mt5_interface, 'get_symbol_info') else None
            if not symbol_info or pos['sl'] is None:
                continue
            total_risk += self._position_risk(pos, symbol_info)

        logger.info(f"[TOTAL OPEN RISK] {total_risk:.2f} (account currency)")
        return total_risk

    def would_exceed_max_open_risk(self, new_trade_risk, mt5_interface):
        """Return True if adding ``new_trade_risk`` would breach the open-risk cap.

        The cap is ``balance * max_open_risk_percent / 100`` (percent units).
        Fails open: if account info is unavailable the trade is allowed (False).
        """
        account_info = mt5_interface.get_account_info() if hasattr(mt5_interface, 'get_account_info') else None
        if not account_info:
            logger.warning("[GLOBAL RISK] Could not retrieve account info. Allowing trade by default.")
            return False

        max_allowed = account_info['balance'] * self.max_open_risk_percent / 100.0
        total_risk = self.total_open_risk(mt5_interface)

        if total_risk + new_trade_risk > max_allowed:
            logger.error(f"[GLOBAL RISK BLOCK] Trade blocked: Total open risk ({total_risk + new_trade_risk:.2f}) would exceed max allowed ({max_allowed:.2f}, {self.max_open_risk_percent}% of balance).")
            return True
        return False

    def manage_partial_close_and_breakeven(self, symbol, mt5_interface=None, notification_manager=None):
        """Move each position's stop loss to its entry price once price has run 1R in its favour.

        No partial close is performed. Positions with no stop, or with a stop
        already at/beyond entry (break-even or trailing in profit), are left
        alone so a protective stop is never loosened.

        Args:
            symbol: Symbol whose open positions are managed.
            mt5_interface: Object providing get_open_positions, get_symbol_info,
                get_symbol_tick and (optionally) modify_position. Required.
            notification_manager: Accepted for interface compatibility; not used.
        """
        if mt5_interface is None:
            logger.error("mt5_interface must be provided to manage_partial_close_and_breakeven.")
            return

        positions = mt5_interface.get_open_positions(symbol=symbol) if hasattr(mt5_interface, 'get_open_positions') else []
        symbol_info = mt5_interface.get_symbol_info(symbol) if hasattr(mt5_interface, 'get_symbol_info') else None

        if symbol_info is None:
            logger.error(f"Symbol info not found for {symbol}")
            return

        for pos in positions or []:
            entry = pos['price_open']
            sl = pos['sl']
            ticket = pos['ticket']
            direction = pos['type']
            comment = pos.get('comment') or ''

            side = self._position_side(direction)
            if side is None:
                logger.warning(f"[MANAGE-BREAKEVEN] Unknown direction '{direction}' for ticket {ticket}, skipping.")
                continue
            is_buy = side == 'buy'

            # None or 0.0 (MT5's "no stop"): there is no 1R distance to measure.
            if not sl:
                logger.debug(f"[MANAGE-BREAKEVEN] Ticket {ticket} has no stop loss, skipping.")
                continue

            # A stop at/through entry is already break-even or trailing in
            # profit; "moving it to breakeven" would loosen it.
            if (is_buy and sl >= entry) or (not is_buy and sl <= entry):
                logger.debug(f"[BREAKEVEN] Ticket {ticket} SL {sl} already at or beyond entry {entry}, skipping.")
                continue

            # NOTE(review): this class never writes 'breakeven_done' into a
            # comment; presumably another component sets it -- confirm.
            if 'breakeven_done' in comment:
                logger.debug(f"[BREAKEVEN] Ticket {ticket} already marked as breakeven_done, skipping.")
                continue

            risk = entry - sl if is_buy else sl - entry
            if risk < 1e-6:
                logger.info(f"[MANAGE-BREAKEVEN] Position {ticket} has invalid risk (entry={entry}, sl={sl}), skipping.")
                continue

            tick = mt5_interface.get_symbol_tick(symbol)
            if not tick:
                logger.error(f"[MANAGE-BREAKEVEN] No tick data for {symbol} when managing ticket {ticket}")
                continue

            # A long closes at the bid, a short at the ask.
            current_price = tick['bid'] if is_buy else tick['ask']
            rr_1 = entry + risk if is_buy else entry - risk
            reached_1r = current_price >= rr_1 if is_buy else current_price <= rr_1
            if not reached_1r:
                logger.debug(f"[BREAKEVEN] Ticket {ticket} has not reached 1R yet. Current price: {current_price}, RR1: {rr_1}")
                continue

            logger.info(f"[BREAKEVEN] Moving SL to breakeven for ticket {ticket} at 1R")
            if not hasattr(mt5_interface, 'modify_position'):
                logger.warning(f"[BREAKEVEN] mt5_interface has no modify_position; cannot move SL for ticket {ticket}")
                continue
            if mt5_interface.modify_position(ticket, sl=entry, tp=pos.get('tp')):
                logger.info(f"[BREAKEVEN] Successfully moved SL to breakeven for ticket {ticket}")
            else:
                logger.error(f"[BREAKEVEN] Failed to move SL to breakeven for ticket {ticket}")

    def enforce_risk_on_open_trades(self, symbol, mt5_interface=None, notification_manager=None):
        """Enforce risk management rules on open trades."""
        # Intentionally a no-op: break-even handling lives in
        # manage_partial_close_and_breakeven.
        pass

    def reset_daily_limits(self):
        """Reset daily risk limits (clears the per-ticket 2R tracking set)."""
        logger.info("[DAILY-RESET] Resetting daily risk limits")
        self.partial_2r_done_tickets.clear()
