import React, { useEffect, useCallback } from "react";
import {
  useNodesState,
  useEdgesState,
  addEdge,
  Edge,
  Node,
  Connection,
  MarkerType,
  ReactFlow,
  Position,
} from "@xyflow/react";
import "@xyflow/react/dist/style.css";
import { Edges, Metric } from "../../Interfaces";
import { useBatchContext } from "../../../context/batch/BatchContext";
import { MetricEdgeType } from "../../../utils/enum";
import { getAnalysisPlanNodeName, getAnalysisPlanEdgeLabel } from "../../../utils/nameUtils";
import dagre from "dagre";
import CustomNodeComponent from "./CustomNode";
import { Box, Center, Spinner, Text } from "@chakra-ui/react";

// Size of every node box: fed to dagre for layout and applied to each laid-out node's style.
const nodeWidth = 150;
const nodeHeight = 60;

// React Flow pro options: hides the attribution badge.
const proOptions = { hideAttribution: true };

/**
 * Data payload carried by every node in this graph; `label` is the text shown in the node.
 * Extends Record<string, unknown> — presumably to satisfy the constraint @xyflow/react
 * places on node data (verify against the installed React Flow version).
 */
interface CustomNodeData extends Record<string, unknown> {
  label: string;
}

/** React Flow node whose `data` is a CustomNodeData. */
type CustomNode = Node<CustomNodeData>;

// Helper function to determine font size based on label length
const getFontSize = (label: string): number => {
  const maxLength = 30; // Define the maximum length before reducing font size
  return label.length > maxLength ? 10 : 12; // 10px for long labels, 12px otherwise
};

/**
 * Lays the graph out left-to-right with dagre.
 *
 * Returns NEW node objects positioned by their top-left corner (dagre reports node
 * centres), with left-side target / right-side source handles and a fixed width/height
 * style merged over any existing style. The input nodes are not mutated; `edges` is
 * returned as-is.
 *
 * @param nodes - nodes to position; every id is registered with dagre before layout.
 * @param edges - edges used only to drive the layout ranks.
 * @returns the positioned nodes and the unchanged edges.
 */
const getLayoutedElements = (nodes: CustomNode[], edges: Edge[]) => {
  const dagreGraph = new dagre.graphlib.Graph();

  dagreGraph.setDefaultEdgeLabel(() => ({}));

  dagreGraph.setGraph({
    rankdir: "LR", // Left to Right
    nodesep: 20, // Separation between nodes within a rank
    ranksep: 80, // Separation between ranks
  });

  for (const node of nodes) {
    dagreGraph.setNode(node.id, { width: nodeWidth, height: nodeHeight });
  }

  for (const edge of edges) {
    dagreGraph.setEdge(edge.source, edge.target);
  }

  dagre.layout(dagreGraph);

  // Fixed shift applied to every node so the graph does not hug the canvas origin.
  const xOffset = 100;
  const yOffset = 120;

  const layoutedNodes: CustomNode[] = nodes.map((node) => {
    const { x, y } = dagreGraph.node(node.id);

    // Build a fresh object instead of assigning handle positions onto the caller's node.
    return {
      ...node,
      targetPosition: Position.Left,
      sourcePosition: Position.Right,
      position: {
        // dagre gives the centre; React Flow expects the top-left corner.
        x: x - nodeWidth / 2 + xOffset,
        y: y - nodeHeight / 2 + yOffset,
      },
      style: {
        ...node.style, // Preserve existing styles
        width: nodeWidth,
        height: nodeHeight,
      },
    };
  });

  return { nodes: layoutedNodes, edges };
};

type BatchContextValue = ReturnType<typeof useBatchContext>;

/** Batch-context values that decide which metric edges are drawn and how edges are labelled. */
interface GraphFilters {
  selectedEdgeToExclude: BatchContextValue["selectedEdgeToExclude"];
  selectedContextMap: BatchContextValue["selectedContextMap"];
  selectedPurpose: BatchContextValue["selectedPurpose"];
}

/** Benchmark descriptor: `id` becomes the node id, `name` feeds the node label. */
interface BenchmarkRef {
  id: string;
  name: string;
}

// Registered once at module scope: React Flow warns when `nodeTypes` is a new object on
// every render.
const nodeTypes = {
  custom: CustomNodeComponent, // Register the custom node component
};

// Style shared by every rendered node; per-node styles and selection colours are layered on top.
const baseNodeStyle: React.CSSProperties = {
  borderRadius: 10,
  boxShadow: '2px 2px 5px rgba(0,0,0,0.3)',
  fontWeight: 'bold',
  transition: "background 0.3s, border 0.3s", // Smooth transition for all nodes
  display: 'flex',
  alignItems: 'center',
  justifyContent: 'center',
  padding: '0 10px', // Horizontal padding for label
  overflow: 'hidden',
};

/** Root "Compute <metric>" node; every benchmark branch hangs off it. */
const createRootNode = (metric: Metric): CustomNode => ({
  id: metric.header.id!,
  data: { label: `Compute ${metric.header.name!}` },
  position: { x: 0, y: 0 }, // placeholder; replaced by getLayoutedElements
  type: "default",
  sourcePosition: Position.Right, // Edge will come out from the right side
  targetPosition: Position.Left, // Edge will enter from the left side
  style: {
    background: '#f0f8ff',
    color: '#333',
    padding: 12,
    border: '2px solid #1e90ff',
    borderRadius: 10,
    boxShadow: '2px 2px 5px rgba(0,0,0,0.3)',
    fontWeight: 'bold',
  },
});

/** Plain benchmark/step node whose font size depends on the label length. */
const createPlanNode = (id: string, label: string): CustomNode => ({
  id,
  data: { label },
  position: { x: 0, y: 0 }, // placeholder; replaced by getLayoutedElements
  type: "default",
  sourcePosition: Position.Right,
  targetPosition: Position.Left,
  style: {
    fontSize: `${getFontSize(label)}px`,
  },
});

/** Straight edge with a closed arrow head. */
const createPlanEdge = (
  id: string,
  source: string,
  target: string,
  label: string
): Edge => ({
  id,
  source,
  target,
  type: "straight",
  markerEnd: {
    type: MarkerType.ArrowClosed,
  },
  label,
});

/**
 * True when the edge's attribute name is in `selectedEdgeToExclude` or has a truthy entry
 * in `selectedContextMap`, or its fundamental relationship name is in
 * `selectedEdgeToExclude`. Both branch kinds use this so they filter identically.
 */
const isEdgeExcluded = (edge: Edges, filters: GraphFilters): boolean => {
  const attributeName = edge.related_attribute?.name;
  if (
    attributeName != null &&
    (filters.selectedEdgeToExclude.includes(attributeName) ||
      Boolean(filters.selectedContextMap[attributeName]))
  ) {
    return true;
  }
  const relationshipName = edge.fundamental_relationship?.header?.name;
  return (
    relationshipName != null &&
    filters.selectedEdgeToExclude.includes(relationshipName)
  );
};

/**
 * Maps an attribute or fundamental-relationship edge to the id and label of the step node
 * it contributes. Returns null for any other edge kind.
 */
const resolveChainStep = (edge: Edges): { id: string; label: string } | null => {
  if (
    edge.edge_type === MetricEdgeType.ATTRIBUTE_EDGE &&
    edge.related_attribute
  ) {
    return {
      id: edge.related_attribute.id!,
      label: getAnalysisPlanNodeName(
        edge.related_attribute.name!,
        MetricEdgeType.ATTRIBUTE_EDGE
      ),
    };
  }
  if (
    edge.edge_type === MetricEdgeType.FUNDAMENTAL_EDGE &&
    edge.fundamental_relationship
  ) {
    return {
      id: edge.fundamental_relationship.header?.id!,
      label: getAnalysisPlanNodeName(
        edge.fundamental_relationship.header?.name!,
        MetricEdgeType.FUNDAMENTAL_EDGE
      ),
    };
  }
  return null;
};

/**
 * Appends the metric's non-excluded attribute/fundamental edges as a linear chain
 * starting at `benchmarkNodeId`, ordered by descending `edge_weight`.
 *
 * Step node ids are `${benchmarkNodeId}|${stepId}`, so the same attribute can appear
 * once under each benchmark. A step whose composite id already exists is skipped, which
 * prevents duplicate node ids and self-loops. Mutates `nodes` and `edges` in place.
 */
const appendChain = (
  metric: Metric,
  benchmarkNodeId: string,
  nodes: CustomNode[],
  edges: Edge[],
  filters: GraphFilters
): void => {
  const sortedEdges = [...(metric.edges ?? [])].sort(
    (a, b) => (b.edge_weight ?? 0) - (a.edge_weight ?? 0)
  );

  let previousNodeId = benchmarkNodeId;
  let isFirstStep = true;

  for (const edge of sortedEdges) {
    // Benchmark edges become branch roots, never chain steps.
    if (
      edge.edge_type === MetricEdgeType.BENCHMARK_EDGE ||
      isEdgeExcluded(edge, filters)
    ) {
      continue;
    }

    const step = resolveChainStep(edge);
    if (!step?.id || !step.label) {
      continue;
    }

    const nodeId = `${benchmarkNodeId}|${step.id}`;
    // Compare against the composite id that is actually stored.
    if (nodes.some((node) => node.id === nodeId)) {
      continue;
    }

    nodes.push(createPlanNode(nodeId, step.label));
    edges.push(
      createPlanEdge(
        `${previousNodeId}-${step.id}`,
        previousNodeId,
        nodeId,
        getAnalysisPlanEdgeLabel(
          isFirstStep ? "Benchmark->First" : "After",
          filters.selectedPurpose
        ) || "Next"
      )
    );

    previousNodeId = nodeId;
    isFirstStep = false;
  }
};

/**
 * Adds `benchmark` as a child of the root node (skipping a node or edge id that is
 * already present), then appends the metric's step chain beneath it.
 */
const addBenchmarkBranch = (
  metric: Metric,
  rootNodeId: string,
  benchmark: BenchmarkRef,
  nodes: CustomNode[],
  edges: Edge[],
  filters: GraphFilters
): void => {
  if (!nodes.some((node) => node.id === benchmark.id)) {
    nodes.push(
      createPlanNode(
        benchmark.id,
        getAnalysisPlanNodeName(benchmark.name, MetricEdgeType.BENCHMARK_EDGE)
      )
    );
  }

  const rootEdgeId = `${rootNodeId}-${benchmark.id}`;
  if (!edges.some((edge) => edge.id === rootEdgeId)) {
    edges.push(
      createPlanEdge(
        rootEdgeId,
        rootNodeId,
        benchmark.id,
        getAnalysisPlanEdgeLabel("Metric->Benchmark", filters.selectedPurpose)
      )
    );
  }

  appendChain(metric, benchmark.id, nodes, edges, filters);
};

/**
 * Builds the laid-out analysis-plan graph for one metric:
 *  - the root "Compute <metric>" node;
 *  - a branch for each of the metric's own benchmark edges whose name is in
 *    `customBenchmarkName` (each branch uses the benchmark that matched);
 *  - a branch for each entry of `customBenchmark` whose name is not already one of the
 *    metric's benchmark edges.
 */
const buildAnalysisPlanGraph = (
  metric: Metric,
  customBenchmarkName: BatchContextValue["customBenchmarkName"],
  customBenchmark: BatchContextValue["customBenchmark"],
  filters: GraphFilters
) => {
  const rootNode = createRootNode(metric);
  const nodes: CustomNode[] = [rootNode];
  const edges: Edge[] = [];
  const metricEdges = metric.edges ?? [];

  for (const edge of metricEdges) {
    const header = edge.benchmark?.header;
    if (
      edge.edge_type !== MetricEdgeType.BENCHMARK_EDGE ||
      !header?.id ||
      !header.name ||
      !customBenchmarkName.includes(header.name)
    ) {
      continue;
    }
    addBenchmarkBranch(
      metric,
      rootNode.id,
      { id: header.id, name: header.name },
      nodes,
      edges,
      filters
    );
  }

  for (const benchmark of customBenchmark) {
    const isMetricBenchmark = metricEdges.some(
      (edge) =>
        edge.edge_type === MetricEdgeType.BENCHMARK_EDGE &&
        edge.benchmark?.header?.name === benchmark.name
    );
    if (!isMetricBenchmark) {
      addBenchmarkBranch(metric, rootNode.id, benchmark, nodes, edges, filters);
    }
  }

  return getLayoutedElements(nodes, edges);
};

/**
 * Renders the analysis plan of the selected metric as a left-to-right React Flow graph.
 *
 * The graph is rebuilt from the batch context whenever the metric, the benchmark
 * selection, or any edge filter changes. Shift+click toggles a node id in
 * `selectedNodeIds`; nodes whose id is in that list get the highlight style.
 */
const MKGTraversalContainer: React.FC = () => {
  const {
    metricInfoMap,
    selectedMetricHeader,
    selectedEdgeToExclude,
    customBenchmark,
    selectedContextMap,
    selectedNodeIds,
    setSelectedNodeIds,
    customBenchmarkName,
    analysisPlanLoader,
    selectedPurpose,
  } = useBatchContext();

  const [nodes, setNodes, onNodesChange] = useNodesState<CustomNode>([]);
  const [edges, setEdges, onEdgesChange] = useEdgesState<Edge>([]);

  useEffect(() => {
    const metric = selectedMetricHeader
      ? metricInfoMap[selectedMetricHeader.id!]
      : undefined;

    if (!metric) {
      setNodes([]);
      setEdges([]);
      return;
    }

    const graph = buildAnalysisPlanGraph(metric, customBenchmarkName, customBenchmark, {
      selectedEdgeToExclude,
      selectedContextMap,
      selectedPurpose,
    });
    setNodes(graph.nodes);
    setEdges(graph.edges);
  }, [
    selectedMetricHeader,
    metricInfoMap,
    selectedEdgeToExclude,
    customBenchmark,
    selectedContextMap,
    customBenchmarkName,
    selectedPurpose,
    setNodes,
    setEdges,
  ]);

  // Function to handle connection events
  const onConnect = useCallback(
    (params: Connection) => setEdges((eds) => addEdge(params, eds)),
    [setEdges]
  );

  // Shift+click toggles the clicked node's id in the shared selection; plain clicks are ignored.
  const onNodeClick = useCallback(
    (event: React.MouseEvent, node: Node) => {
      if (!event.shiftKey) {
        return;
      }
      // Functional update, so the callback never needs the current selection in its deps.
      setSelectedNodeIds((prevSelected) =>
        prevSelected.includes(node.id)
          ? prevSelected.filter((id) => id !== node.id)
          : [...prevSelected, node.id]
      );
    },
    [setSelectedNodeIds]
  );

  // Layer the base style, the node's own style, and selection-dependent colours.
  // The colours are applied last, so they override any colour set on the node itself.
  const styledNodes = nodes.map((node) => ({
    ...node,
    style: selectedNodeIds.includes(node.id)
      ? {
          ...baseNodeStyle,
          ...node.style,
          background: '#e0f7fa',        // Light cyan background for selected nodes
          color: '#004d40',              // Dark green text color
          border: '2px solid #004d40',   // Dark green border
        }
      : {
          ...baseNodeStyle,
          ...node.style,
          background: '#fff2f8',        // Light pink background for normal nodes
          color: '#333',                 // Default text color
          border: '2px solid #dfc2f0',   // Light purple border for normal nodes
        },
  }));

  return (
    <Box style={{ width: "100%", height: "455px" }} key={selectedMetricHeader?.id}>
      {analysisPlanLoader ? (
        <Center height="100%">
          {/* Ensure the spinner is centered */}
          <Spinner size="xl" color="purple.700" />
        </Center>
      ) : selectedMetricHeader ? (
        <ReactFlow
          nodes={styledNodes}
          edges={edges}
          onNodesChange={onNodesChange}
          onEdgesChange={onEdgesChange}
          onNodeClick={onNodeClick}
          onConnect={onConnect}
          fitView={false}
          proOptions={proOptions}
          nodeTypes={nodeTypes}
        />
      ) : (
        <Center height="100%">
          <Text color="gray.500" fontSize="lg">
            No metric selected
          </Text>
        </Center>
      )}
    </Box>
  );
};

export default MKGTraversalContainer;
