import * as d3 from 'd3';
import { formatCurrency } from '../../../utils/calculations';
import { getFilterDescription } from '../../../config/plotConfig';
import createLegend from '../../../components/visualization/Legend';
import { calculateRegression, calculateCorrelation } from '../../../utils/calculations';

// Sequential palette (light -> dark purple) used to encode local point density.
const DENSITY_COLORS = ['#AB51E3', '#8B2FC9', '#6818A5', '#5A108F', '#4A0A77'];

const FONT_FAMILY = "'Roboto Mono', monospace";

// Axis titles for the fields this chart knows about; any other field name is
// displayed unchanged.
const AXIS_LABELS = new Map([
  ['AGE', 'Age (Years)'],
  ['LOS', 'Length of Stay (Days)'],
  ['TC', 'Total Cost ($)'],
]);

/** Returns the axis title for a data field, falling back to the raw field name. */
const axisLabelFor = field => (AXIS_LABELS.has(field) ? AXIS_LABELS.get(field) : field);

/**
 * True when a raw cell value can be plotted as a finite number.
 * Rejects null/undefined and blank strings (which unary `+` would coerce to 0
 * and plot at the origin) as well as NaN and +/-Infinity.
 */
const isPlottable = value => {
  if (value === null || value === undefined) return false;
  if (typeof value === 'string' && value.trim() === '') return false;
  return Number.isFinite(+value);
};

/**
 * Clips the line y = slope * x + intercept to the rectangle spanned by
 * xDomain x yDomain (data coordinates).
 *
 * @param {number} slope
 * @param {number} intercept
 * @param {number[]} xDomain - [x0, x1] of the x scale.
 * @param {number[]} yDomain - [y0, y1] of the y scale.
 * @returns {{x: number, y: number}[] | null} The two endpoints of the visible
 *   segment, or null when the line is not finite or never enters the rectangle.
 */
const clipRegressionSegment = (slope, intercept, xDomain, yDomain) => {
  if (!Number.isFinite(slope) || !Number.isFinite(intercept)) return null;

  const yMin = Math.min(yDomain[0], yDomain[1]);
  const yMax = Math.max(yDomain[0], yDomain[1]);
  let lo = Math.min(xDomain[0], xDomain[1]);
  let hi = Math.max(xDomain[0], xDomain[1]);

  if (slope === 0) {
    if (intercept < yMin || intercept > yMax) return null;
  } else {
    // Restrict x to where the line lies inside [yMin, yMax]. Clamping only the
    // y values would move the endpoints vertically and change the slope.
    const xAtYMin = (yMin - intercept) / slope;
    const xAtYMax = (yMax - intercept) / slope;
    lo = Math.max(lo, Math.min(xAtYMin, xAtYMax));
    hi = Math.min(hi, Math.max(xAtYMin, xAtYMax));
    if (lo > hi) return null;
  }

  // The final clamp only absorbs floating-point error at the crossing points.
  const yAt = xv => Math.max(yMin, Math.min(yMax, slope * xv + intercept));
  return [
    { x: lo, y: yAt(lo) },
    { x: hi, y: yAt(hi) },
  ];
};

/**
 * Returns the tooltip div shared by all scatter plots, creating it on first use.
 * Reusing one element avoids appending a new div to <body> on every re-render.
 */
const getSharedTooltip = () => {
  const existing = d3.select('body').select('div.scatter-plot-tooltip');
  if (!existing.empty()) {
    // Hide it in case a previous chart was torn down while a point was hovered.
    return existing.interrupt().style('opacity', 0);
  }
  return d3.select('body').append('div')
    .attr('class', 'tooltip scatter-plot-tooltip')
    .style('opacity', 0)
    .style('background-color', 'var(--bg-color)')
    .style('color', 'var(--text-color)')
    .style('border', '1px solid var(--sub-color)')
    .style('font-family', FONT_FAMILY)
    .style('padding', '15px')
    .style('border-radius', '8px')
    .style('font-size', '1.1em')
    .style('box-shadow', '0 4px 12px rgba(0,0,0,0.15)');
};

/** Applies the shared text/stroke theme to a rendered D3 axis group. */
const styleAxis = (axisGroup, axisFontSize) => {
  axisGroup.selectAll('text')
    .style('font-size', `${axisFontSize}px`)
    .style('fill', 'var(--text-color)')
    .style('font-family', FONT_FAMILY);
  axisGroup.selectAll('line, path')
    .style('stroke', 'var(--axis-color)');
};

/**
 * Renders a density-coloured scatter plot with a regression line into `container`.
 *
 * @param {object} container - D3 selection to append the <svg> to; its node's
 *   clientWidth/clientHeight determine the chart size.
 * @param {object[]} data - Row objects; rows whose config.x / config.y values are
 *   missing, blank or non-numeric are skipped.
 * @param {{x: string, y: string, label: string}} config - Field names for each
 *   axis and the base chart title.
 * @param {object} responsiveConfig - Provides `pointRadius`, `fontSize.axis`,
 *   `fontSize.title` (margins are overridden with fixed values here); it is
 *   also forwarded to createLegend.
 * @param {*} sex - Forwarded to getFilterDescription for the title.
 * @param {*} status - Forwarded to getFilterDescription for the title.
 * @param {boolean} [jitterEnabled=false] - Randomly offset overlapping points.
 * @returns {{svg, width, height, x, y, colorScale, validData} | undefined}
 *   Chart handles, or undefined when no row is plottable.
 */
const createScatterPlot = (container, data, config, responsiveConfig, sex, status, jitterEnabled = false) => {
  // Fixed margins leave room for the title, axis labels and legend.
  const { margin, pointRadius, fontSize } = {
    ...responsiveConfig,
    margin: {
      ...responsiveConfig.margin,
      top: 50,
      bottom: 70,
      right: 50,
      left: 70
    }
  };

  const validData = data.filter(d => isPlottable(d[config.x]) && isPlottable(d[config.y]));
  if (validData.length === 0) return;

  const xValue = d => +d[config.x];
  const yValue = d => +d[config.y];

  const width = container.node().clientWidth - margin.left - margin.right;
  const availableHeight = container.node().clientHeight - margin.top - margin.bottom;
  // Cap the plot at a 4:3 aspect ratio. If the container has no usable height
  // yet (e.g. clientHeight is 0), use that ratio instead of collapsing to 0px.
  const height = availableHeight > 0 ? Math.min(availableHeight, width * 0.75) : width * 0.75;

  const svg = container.append('svg')
    .attr('width', width + margin.left + margin.right)
    .attr('height', height + margin.top + margin.bottom)
    .append('g')
    .attr('transform', `translate(${margin.left},${margin.top})`);

  const x = d3.scaleLinear()
    .domain(d3.extent(validData, xValue))
    .range([0, width])
    .nice();

  const y = d3.scaleLinear()
    .domain(d3.extent(validData, yValue))
    .range([height, 0])
    .nice();

  // Count how many rows share each exact (x,y) pair and each x / y value;
  // these counts drive the jitter spread.
  const positionMap = new Map();
  const xPositionMap = new Map();
  const yPositionMap = new Map();
  const increment = (map, key) => map.set(key, (map.get(key) || 0) + 1);

  validData.forEach(d => {
    increment(positionMap, `${xValue(d)},${yValue(d)}`);
    increment(xPositionMap, `${xValue(d)}`);
    increment(yPositionMap, `${yValue(d)}`);
  });

  // Bin points on a pixel grid (20-30px cells) to estimate local density for colouring.
  const roundTo = Math.max(20, Math.min(30, width * 0.03));
  const densityKeyFor = d => {
    const xPos = Math.round(x(xValue(d)) / roundTo) * roundTo;
    const yPos = Math.round(y(yValue(d)) / roundTo) * roundTo;
    return `${xPos},${yPos}`;
  };

  const densityMap = new Map();
  validData.forEach(d => increment(densityMap, densityKeyFor(d)));

  const maxDensity = Math.max(...densityMap.values());
  const colorScale = d3.scaleSequential()
    .domain([1, maxDensity])
    .interpolator(t => {
      // Piecewise-linear interpolation across the palette stops.
      const last = DENSITY_COLORS.length - 1;
      const i = Math.min(Math.floor(t * last), last - 1);
      const f = t * last - i;
      return d3.interpolate(DENSITY_COLORS[i], DENSITY_COLORS[i + 1])(f);
    });

  const regression = calculateRegression(validData, config.x, config.y);
  const correlation = calculateCorrelation(validData, config.x, config.y);

  // Draw the regression line clipped to the plot area (skipped if not finite
  // or entirely outside the visible domain).
  const regressionSegment = clipRegressionSegment(
    regression.slope,
    regression.intercept,
    x.domain(),
    y.domain()
  );

  if (regressionSegment) {
    svg.append("path")
      .datum(regressionSegment)
      .attr("class", "regression-line")
      .attr("d", d3.line().x(p => x(p.x)).y(p => y(p.y)))
      .style("stroke", "red")
      .style("stroke-width", Math.max(2, Math.min(4, width * 0.004)))
      .style("fill", "none");
  }

  const tooltip = getSharedTooltip();

  // Points, optionally jittered so overlapping observations remain visible.
  svg.selectAll('circle')
    .data(validData)
    .enter()
    .append('circle')
    .attr('cx', d => {
      const cx = x(xValue(d));
      if (!jitterEnabled) return cx;

      // Only spread points when several share this x value.
      const xCount = xPositionMap.get(`${xValue(d)}`);
      if (xCount > 1) {
        const jitterRange = Math.min(width * 0.015, Math.log2(xCount + 1) * pointRadius * 4);
        return cx + (Math.random() - 0.5) * jitterRange;
      }
      return cx;
    })
    .attr('cy', d => {
      const cy = y(yValue(d));
      if (!jitterEnabled) return cy;

      // Vertical jitter is applied only for age vs length of stay.
      if (config.x === 'AGE' && config.y === 'LOS') {
        const count = positionMap.get(`${xValue(d)},${yValue(d)}`);
        if (count > 1) {
          const jitterRange = Math.min(height * 0.01, Math.log2(count + 1) * pointRadius * 3);
          return cy + (Math.random() - 0.5) * jitterRange;
        }
      }
      return cy;
    })
    .attr('r', pointRadius)
    .style('fill', d => colorScale(densityMap.get(densityKeyFor(d))))
    .style('stroke', 'rgba(0,0,0,0.2)')
    .style('stroke-width', 1)
    .style('opacity', 0.8)
    .on('mouseover', function(event, d) {
      const density = densityMap.get(densityKeyFor(d));

      d3.select(this)
        .attr('r', pointRadius * 1.5)
        .style('stroke-width', 2)
        .style('opacity', 1);

      tooltip.transition()
        .duration(200)
        .style('opacity', .9);
      tooltip.html(`Sex: ${d.SEX}<br/>
                  ${config.x}: ${d[config.x]}<br/>
                  ${config.y}: ${config.y === 'TC' ? formatCurrency(+d[config.y]) : d[config.y]}<br/>
                  Points in this area: ${density}`)
        .style('left', (event.pageX + 10) + 'px')
        .style('top', (event.pageY - 28) + 'px');
    })
    .on('mouseout', function() {
      d3.select(this)
        .attr('r', pointRadius)
        .style('stroke-width', 1)
        .style('opacity', 0.8);
      tooltip.transition()
        .duration(500)
        .style('opacity', 0);
    });

  // Axes
  const xAxis = svg.append('g')
    .attr('transform', `translate(0,${height})`)
    .call(d3.axisBottom(x));
  styleAxis(xAxis, fontSize.axis);

  const yAxis = svg.append('g')
    .call(d3.axisLeft(y)
      .tickFormat(config.y === 'TC' ? formatCurrency : d3.format('d')));
  styleAxis(yAxis, fontSize.axis);

  // Axis labels
  svg.append('text')
    .attr('x', width / 2)
    .attr('y', height + margin.bottom * 0.7)
    .style('text-anchor', 'middle')
    .style('font-size', `${fontSize.axis}px`)
    .style('fill', 'var(--text-color)')
    .style('font-family', FONT_FAMILY)
    .text(axisLabelFor(config.x));

  svg.append('text')
    .attr('transform', 'rotate(-90)')
    .attr('x', -height / 2)
    .attr('y', -margin.left * 0.85)
    .style('text-anchor', 'middle')
    .style('font-size', `${fontSize.axis}px`)
    .style('fill', 'var(--text-color)')
    .style('font-family', FONT_FAMILY)
    .text(axisLabelFor(config.y));

  // Title, qualified by the active sex/status filters when there are any.
  const filterDesc = getFilterDescription(sex, status);
  const titleText = filterDesc ? `${config.label} for ${filterDesc} Patients` : config.label;

  svg.append('text')
    .attr('x', width / 2)
    .attr('y', -margin.top / 3)
    .style('text-anchor', 'middle')
    .style('font-size', `${fontSize.title}px`)
    .style('fill', 'var(--text-color)')
    .style('font-family', FONT_FAMILY)
    .text(titleText);

  createLegend(svg, width, colorScale, true, correlation, responsiveConfig, validData.length);

  return { svg, width, height, x, y, colorScale, validData };
};

export default createScatterPlot;
