// src/utils/sudoku/modelLoader.js
import * as tf from '@tensorflow/tfjs';

// Model instance that trainOnUserData fine-tunes; starts as null.
// NOTE(review): nothing in this file assigns it, so trainOnUserData currently always
// returns early — confirm whether another module is expected to set it.
let loadedModel = null;
// Not referenced by any code in this file — TODO confirm whether it is still needed.
const MODEL_VERSION = '1.0';
// localStorage key under which saveTrainingExample stores user-labelled examples and
// from which trainOnUserData reads (and then removes) them.
const LOCAL_STORAGE_KEY = 'sudoku_model_data';
// Location of a model.json (presumably the base MNIST model, judging by the path).
// Not referenced by any code in this file while loadModel() is disabled.
const LOCAL_MODEL_PATH = '/sudoku/models/mnist/model.json';

/**
 * Normalise, contrast-stretch and binarise a preprocessed cell tensor.
 *
 * Steps (inside tf.tidy, so every intermediate tensor is released):
 *  1. min-max normalise to [0, 1]; the 1e-6 epsilon avoids a division by zero
 *     when the image is constant;
 *  2. stretch contrast around the mean by a factor of 1.5 and clip back to [0, 1];
 *  3. threshold at the global mean of the stretched image, producing a 0/1 mask.
 *
 * Step 3 is a global binarisation, not an edge-sharpening filter.
 *
 * @param {tf.Tensor} tensor - image tensor, e.g. the output of preprocessImage.
 *   It is not disposed by this function.
 * @returns {tf.Tensor} float tensor with the same shape as `tensor`, containing only 0 and 1.
 */
const enhanceImageProcessing = (tensor) => {
  return tf.tidy(() => {
    // Reduce once and reuse for both logging and normalisation.
    const min = tensor.min();
    const max = tensor.max();

    // Log original tensor stats (dataSync forces a synchronous device read; debug aid).
    console.log('Original tensor stats:', {
      shape: tensor.shape,
      mean: tensor.mean().dataSync()[0],
      max: max.dataSync()[0]
    });

    // Normalize to [0,1] range
    const normalized = tensor.sub(min).div(max.sub(min).add(1e-6));

    // Enhance contrast around the mean, then clip back into [0,1]
    const mean = normalized.mean();
    const contrast = normalized.sub(mean).mul(1.5).add(mean).clipByValue(0, 1);

    // Binarise: pixels brighter than the global mean become 1, all others 0
    const threshold = contrast.mean();
    const binary = contrast.greater(threshold).toFloat();

    // Log processed tensor stats
    console.log('Processed tensor stats:', {
      mean: binary.mean().dataSync()[0],
      max: binary.max().dataSync()[0]
    });

    return binary;
  });
};

/**
 * Turn per-class scores into a confidence value that is penalised when the top
 * prediction does not clearly beat the runner-up.
 *
 * If the gap between the best and second-best score exceeds `minMargin`, the best
 * score is returned unchanged. Otherwise it is scaled linearly by margin / minMargin,
 * so two near-tied classes yield a confidence close to 0.
 *
 * @param {ArrayLike<number>} predictions - per-class scores; not mutated.
 * @param {number} [minMargin=0.3] - gap (must be > 0) above which the top score is trusted as-is.
 * @returns {number} adjusted confidence. Returns 0 for an empty input. A single score
 *   is compared against an implicit runner-up of 0.
 */
export const calculateConfidence = (predictions, minMargin = 0.3) => {
  // Copy before sorting so the caller's array/typed array is left untouched.
  const sorted = Array.from(predictions).sort((a, b) => b - a);
  if (sorted.length === 0) {
    return 0;
  }

  const topConfidence = sorted[0];
  // With a single class there is no competitor; treat the runner-up as 0 instead of
  // letting `undefined` turn the result into NaN.
  const runnerUp = sorted.length > 1 ? sorted[1] : 0;

  const margin = topConfidence - runnerUp;
  return margin > minMargin ? topConfidence : topConfidence * (margin / minMargin);
};

/**
 * Heuristic blank-cell test. A cell counts as empty when it is both dark
 * (mean < 0.1) and nearly uniform (standard deviation < 0.1).
 *
 * NOTE(review): the thresholds assume pixel values roughly in [0, 1] on a dark
 * background — confirm the polarity of the images callers pass in.
 *
 * @param {tf.Tensor} tensor - cell image tensor; not disposed by this function.
 * @returns {tf.Tensor} scalar boolean tensor; the caller owns it and must dispose it.
 */
export const isEmptyCell = (tensor) => {
  return tf.tidy(() => {
    const mean = tensor.mean();
    // tf.moments is exposed as a tf.* op, not as a chained Tensor method, so it must be
    // called through the namespace (tensor.moments() is not a function).
    const std = tf.moments(tensor).variance.sqrt();
    return mean.less(0.1).logicalAnd(std.less(0.1));
  });
};

/**
 * Convert a cell image into a model-ready tensor.
 *
 * Pipeline: pixels -> grayscale (average of the colour channels) -> bilinear resize
 * to 28x28 -> scale to [0, 1] -> add a leading batch dimension.
 *
 * @param {*} imageData - any input accepted by tf.browser.fromPixels
 *   (presumably ImageData from a cell crop — confirm with callers).
 * @returns {tf.Tensor4D} float tensor of shape [1, 28, 28, 1].
 */
const preprocessImage = (imageData) => {
  return tf.tidy(() => {
    // fromPixels always returns a rank-3 [height, width, channels] tensor
    // (3 channels by default), so no rank-2 case can occur below.
    let tensor = tf.browser.fromPixels(imageData);

    // Collapse colour channels to one grayscale channel, keeping the channel axis.
    if (tensor.shape[2] > 1) {
      tensor = tf.mean(tensor, -1, true);
    }

    // Resize and normalise; the tensor is still rank 3 here, so one expandDims(0)
    // yields the [batch, height, width, channels] layout the model expects.
    tensor = tf.image.resizeBilinear(tensor, [28, 28])
      .toFloat()
      .div(255.0)
      .expandDims(0);

    // Log shape for debugging
    console.log('Preprocessed tensor shape:', tensor.shape);

    return tensor;
  });
};

/**
 * Attempt to restore the fine-tuned model that trainOnUserData saves under
 * 'indexeddb://user-trained-sudoku-model'.
 *
 * Any failure (including "no such model") is logged and mapped to null rather
 * than rethrown. Not called by any code in this file.
 *
 * @returns {Promise<tf.LayersModel|null>} the stored model, or null if it could not be loaded.
 */
const loadUserModel = async () => {
  let userModel;
  try {
    userModel = await tf.loadLayersModel('indexeddb://user-trained-sudoku-model');
    console.log('Loaded user-trained model');
  } catch (error) {
    console.log('No user-trained model found, using base model');
    userModel = null;
  }
  return userModel;
};

/**
 * Obtain the digit-recognition model.
 *
 * Loading is currently switched off: this logs a notice and always resolves to
 * null, so callers must handle the "no model" case.
 *
 * @returns {Promise<null>}
 */
export const loadModel = async () => {
  const disabledModel = null;
  console.log('Model loading disabled');
  return disabledModel;
};

/**
 * Recognise the digit in a single Sudoku cell image.
 *
 * Flow:
 *  1. obtain the model, bailing out if there is none;
 *  2. preprocess the image;
 *  3. run the blank-cell check on the un-binarised tensor;
 *  4. enhance/binarise the tensor and run the model;
 *  5. apply a 0.3 confidence floor.
 * Every tensor created by this call is disposed before it returns.
 *
 * @param {*} imageData - any input accepted by tf.browser.fromPixels (via preprocessImage).
 * @returns {Promise<{digit: number, confidence: number}>} `digit` is 0 in four cases:
 *   - the cell looks empty (confidence 1.0);
 *   - the best score is below 0.3 (confidence = that score);
 *   - no model is available (confidence 0);
 *   - an error occurs (confidence 0).
 */
export const predictDigit = async (imageData) => {
  let tensor = null;
  let enhanced = null;
  let output = null;
  try {
    const model = await loadModel();
    if (!model) {
      // loadModel() may resolve to null; without this guard model.predict would throw
      // a TypeError for every cell after allocating tensors.
      console.log('No model available, skipping prediction');
      return { digit: 0, confidence: 0 };
    }

    // Preprocess image with proper shape handling
    tensor = preprocessImage(imageData);
    console.log('Input tensor shape:', tensor.shape);

    // Blank-cell check, done before inference and on the un-binarised tensor. The
    // enhanced tensor is thresholded at its own mean, so a non-constant blank cell
    // (noise, uneven lighting) comes out with many pixels set and its mean no longer
    // signals emptiness. A low standard deviation (nearly uniform cell) does, for
    // either background polarity.
    // NOTE(review): 0.1 mirrors isEmptyCell's threshold — tune against real cell crops.
    const [mean, std] = tf.tidy(() => {
      const { mean: meanTensor, variance } = tf.moments(tensor);
      return [meanTensor.dataSync()[0], variance.sqrt().dataSync()[0]];
    });
    if (std < 0.1) {
      console.log('Cell detected as empty:', { mean, std });
      return { digit: 0, confidence: 1.0 };
    }

    // Enhanced preprocessing
    enhanced = enhanceImageProcessing(tensor);

    // Keep a handle on the model output so it can be disposed after reading it.
    output = model.predict(enhanced);
    const predictionArray = Array.from(await output.data());
    console.log('Raw predictions:', predictionArray);

    // Rank classes by probability; the array index is the digit.
    const sortedPreds = predictionArray
      .map((prob, digit) => ({ digit, prob }))
      .sort((a, b) => b.prob - a.prob);

    console.log('Top 3 predictions:', sortedPreds.slice(0, 3));

    const { digit, prob: maxConfidence } = sortedPreds[0];

    // Check confidence threshold
    if (maxConfidence < 0.3) {
      console.log('Low confidence prediction:', { maxConfidence, digit });
      return { digit: 0, confidence: maxConfidence };
    }

    console.log('Final prediction:', { digit, confidence: maxConfidence });
    return { digit, confidence: maxConfidence };
  } catch (error) {
    console.error('Error predicting digit:', error);
    return { digit: 0, confidence: 0 };
  } finally {
    // Release backend memory held by this call's tensors; tf.dispose skips nulls.
    tf.dispose([tensor, enhanced, output]);
  }
};

/**
 * Convert pixel data into a JSON-safe form.
 *
 * ImageData and tfjs PixelData keep their pixels in a typed array. JSON.stringify
 * does not emit a typed array as an array, and ImageData loses its pixels entirely.
 * This helper copies the pixels into a plain number array instead. Inputs without
 * numeric `width`/`height` and a `data` field are returned unchanged.
 * NOTE(review): canvas/image elements would still not survive serialisation.
 */
const toStorableImage = (imageData) => {
  if (
    imageData &&
    imageData.data != null &&
    typeof imageData.width === 'number' &&
    typeof imageData.height === 'number'
  ) {
    return {
      width: imageData.width,
      height: imageData.height,
      data: Array.from(imageData.data)
    };
  }
  return imageData;
};

/**
 * Append a user-corrected example to the training set kept in localStorage
 * under LOCAL_STORAGE_KEY.
 *
 * Best effort: failures (e.g. storage unavailable or over quota) are logged, not thrown.
 *
 * @param {*} imageData - the cell image; ImageData/PixelData pixels are stored as a plain array.
 * @param {number} correctDigit - the label supplied by the user.
 * @returns {Promise<void>}
 */
export const saveTrainingExample = async (imageData, correctDigit) => {
  try {
    const storage = window.localStorage;
    const trainingData = JSON.parse(storage.getItem(LOCAL_STORAGE_KEY) || '[]');

    trainingData.push({
      imageData: toStorableImage(imageData),
      digit: correctDigit,
      timestamp: Date.now()
    });

    storage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(trainingData));
    console.log(`Saved training example for digit ${correctDigit}`);
  } catch (error) {
    console.error('Error saving training example:', error);
  }
};

/**
 * Rebuild a stored training image as a tf.browser.fromPixels-compatible PixelData
 * object ({ width, height, data: Uint8Array } with 4 bytes per pixel).
 *
 * `data` may be a plain array or the index-keyed object that JSON.stringify produces
 * for a typed array. Returns null when usable pixel data did not survive storage.
 */
const decodeStoredImage = (stored) => {
  if (
    !stored ||
    stored.data == null ||
    typeof stored.width !== 'number' ||
    typeof stored.height !== 'number'
  ) {
    return null;
  }
  const data = Uint8Array.from(Object.values(stored.data));
  if (data.length !== stored.width * stored.height * 4) {
    return null;
  }
  return { width: stored.width, height: stored.height, data };
};

/**
 * Fine-tune `loadedModel` on the examples stored by saveTrainingExample.
 * On success, saves the model to IndexedDB and clears the stored examples.
 *
 * This is a no-op unless `loadedModel` is set and at least 10 usable examples exist.
 * NOTE(review): nothing in this file assigns `loadedModel`. fit() also requires a
 * compiled model — confirm where compilation happens.
 * Errors are logged, not thrown.
 *
 * @returns {Promise<void>}
 */
export const trainOnUserData = async () => {
  if (!loadedModel) return;

  const epochs = 5;
  let xs = null;
  let ys = null;

  try {
    const storage = window.localStorage;
    const storedExamples = JSON.parse(storage.getItem(LOCAL_STORAGE_KEY) || '[]');

    // Drop entries whose pixels did not survive serialisation; fromPixels would throw on them.
    const trainingData = storedExamples
      .map((item) => ({ image: decodeStoredImage(item.imageData), digit: item.digit }))
      .filter((item) => item.image !== null);
    if (trainingData.length < storedExamples.length) {
      console.log(`Skipping ${storedExamples.length - trainingData.length} stored examples without usable pixel data`);
    }

    if (trainingData.length < 10) {
      console.log('Not enough training data yet (need at least 10 examples)');
      return;
    }

    console.log(`Training on ${trainingData.length} examples`);

    // preprocessImage already returns [1, 28, 28, 1] per example. Concatenating on the
    // batch axis gives [N, 28, 28, 1]; tf.stack would add a second batch axis
    // ([N, 1, 28, 28, 1]). tidy frees the per-example tensors once merged.
    xs = tf.tidy(() => tf.concat(trainingData.map((item) => preprocessImage(item.image)), 0));

    // tidy also frees the intermediate int32 label tensor.
    ys = tf.tidy(() => tf.oneHot(
      tf.tensor1d(trainingData.map((item) => item.digit), 'int32'),
      10
    ));

    // Train the model
    await loadedModel.fit(xs, ys, {
      epochs,
      batchSize: 32,
      shuffle: true,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          console.log(`Training epoch ${epoch + 1}/${epochs}: loss = ${logs.loss.toFixed(4)}`);
        }
      }
    });

    // Save improved model
    await loadedModel.save('indexeddb://user-trained-sudoku-model');

    // Clear used training data
    storage.removeItem(LOCAL_STORAGE_KEY);

    console.log('Model updated with user data');
  } catch (error) {
    console.error('Error training model:', error);
  } finally {
    // Dispose training tensors even when fit/save throws; tf.dispose skips nulls.
    tf.dispose([xs, ys]);
  }
};