import {
  useState,
  useRef,
  useCallback,
  useEffect,
  Ref,
  useImperativeHandle,
} from "react";
import { z } from "zod";
import { ChatSessionResponse } from "./useChatSession";

/**
 * Every state a chat-session socket managed by `useConnectChatSession` can be
 * in, listed in the order the hook moves through them: not connected, socket
 * opening, socket open with the session id sent, and handshake acknowledged.
 */
const connectionStatuses = [
  "disconnected", "connecting", "connected", "ready",
] as const;

/** Union of the literal strings listed in `connectionStatuses`. */
export type ConnectionStatus = typeof connectionStatuses[number];

/**
 * Imperative handle that `useConnectChatSession` exposes through the `ref`
 * it is given.
 */
export interface ConnectionRef {
  /** Opens a socket for `session`, closing any socket that is already open. */
  startConnection: (session: ChatSessionResponse) => void;
  /** Closes the current socket and stops automatic reconnection. */
  closeConnection: () => void;
}

/**
 * Schema for the server's handshake reply to the `{ chatSessionId }` message
 * sent when the socket opens. Built once at module scope rather than per message.
 */
const handshakeReplySchema = z.object({
  status: z.string(),
});

/** Delay before reconnecting after a close the consumer did not request. */
const RECONNECT_DELAY_MS = 3000;

/**
 * Parses a raw WebSocket payload as a handshake reply.
 *
 * @param data - The `MessageEvent.data` value (string, Blob or ArrayBuffer).
 * @returns The parsed reply, or `null` when the payload is not a JSON string
 *   or does not match `handshakeReplySchema`.
 */
function parseHandshakeReply(
  data: unknown
): z.infer<typeof handshakeReplySchema> | null {
  // Binary frames would otherwise be coerced to "[object Blob]" and throw.
  if (typeof data !== "string") {
    return null;
  }
  let json: unknown;
  try {
    json = JSON.parse(data);
  } catch {
    return null;
  }
  const result = handshakeReplySchema.safeParse(json);
  return result.success ? result.data : null;
}

/**
 * Manages the WebSocket connection for a chat session and exposes imperative
 * `startConnection` / `closeConnection` controls through `ref`.
 *
 * Status progresses `connecting` → `connected` (socket open, session id sent)
 * → `ready` (server replied `{ status: "success" }`). Messages that arrive
 * before that reply and are not handshake-shaped are buffered and delivered,
 * in order, once the session is ready. If the socket closes without
 * `closeConnection` being called, a reconnect to the same session is scheduled
 * after `RECONNECT_DELAY_MS`. The socket is closed when the component unmounts.
 *
 * @param ref - Ref that receives the `ConnectionRef` handle.
 * @param onMessage - Called with each message's raw `event.data`. The most
 *   recently passed callback is always used, including by an open socket.
 * @returns An object holding the current `status`.
 */
export function useConnectChatSession(
  ref: Ref<ConnectionRef>,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- public signature; narrowing it would break existing callers
  onMessage: (message: any) => void
) {
  const [status, setStatus] = useState<ConnectionStatus>("disconnected");

  // Set once the consumer closes the connection or the component unmounts;
  // suppresses automatic reconnection.
  const isEnded = useRef<boolean>(false);
  // The socket this hook currently owns. Events from any other (replaced or
  // closed) socket are ignored.
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectTimeout = useRef<number | null>(null);
  // True between the socket opening and the server's successful handshake reply.
  const checkStatus = useRef<boolean>(false);
  // Messages that arrived before the handshake completed; replayed once ready.
  const messageQueue = useRef<MessageEvent[]>([]);
  // Latest onMessage, so socket handlers and reconnect timers never call a
  // callback captured by an earlier render.
  const onMessageRef = useRef(onMessage);

  useEffect(() => {
    onMessageRef.current = onMessage;
  }, [onMessage]);

  // Cancels a pending reconnect attempt, if any.
  const clearReconnect = useCallback(() => {
    if (reconnectTimeout.current !== null) {
      window.clearTimeout(reconnectTimeout.current);
      reconnectTimeout.current = null;
    }
  }, []);

  // Stops reconnecting, detaches and closes the current socket, and drops any
  // handshake state. Does not touch React state, so it is safe during unmount.
  const teardown = useCallback(() => {
    isEnded.current = true;
    clearReconnect();
    const ws = wsRef.current;
    // Detach before closing so this socket's close event is ignored.
    wsRef.current = null;
    ws?.close();
    checkStatus.current = false;
    messageQueue.current = [];
  }, [clearReconnect]);

  // Opens a socket for chatSession and wires up the handshake, message
  // delivery and auto-reconnect. Used for the first connection and reconnects.
  const connect = useCallback(
    (chatSession: ChatSessionResponse) => {
      clearReconnect();
      setStatus("connecting");

      const ws = new WebSocket(chatSession.websocketEndpoint);
      wsRef.current = ws;
      // A socket stays current until another connect() or teardown replaces
      // it; a replaced socket must not update status or schedule reconnects.
      const isCurrent = () => wsRef.current === ws;

      ws.onopen = () => {
        if (!isCurrent()) return;
        setStatus("connected");
        ws.send(JSON.stringify({ chatSessionId: chatSession.chatSessionId }));
        checkStatus.current = true;
      };

      ws.onmessage = (event: MessageEvent) => {
        if (!isCurrent()) return;

        if (!checkStatus.current) {
          onMessageRef.current(event.data);
          return;
        }

        const reply = parseHandshakeReply(event.data);
        if (reply === null) {
          // Not a handshake reply: hold it until the session is ready.
          messageQueue.current.push(event);
          return;
        }
        if (reply.status !== "success") {
          // NOTE(review): a handshake-shaped reply with any other status is
          // dropped and the session stays un-ready — confirm this is intended.
          return;
        }

        setStatus("ready");
        checkStatus.current = false;
        const pending = messageQueue.current;
        messageQueue.current = [];
        for (const queued of pending) {
          onMessageRef.current(queued.data);
        }
      };

      ws.onclose = () => {
        if (!isCurrent()) return;
        setStatus("disconnected");

        if (isEnded.current) return;

        reconnectTimeout.current = window.setTimeout(() => {
          reconnectTimeout.current = null;
          connect(chatSession);
        }, RECONNECT_DELAY_MS);
      };

      // No onerror handler: a connection failure is followed by a close event,
      // which drives the reconnect logic above.
    },
    [clearReconnect]
  );

  const startConnection = useCallback(
    (chatSession: ChatSessionResponse) => {
      // Drop the old socket, any pending reconnect and messages queued for
      // the previous session before opening the new one.
      teardown();
      isEnded.current = false;
      connect(chatSession);
    },
    [teardown, connect]
  );

  const closeConnection = useCallback(() => {
    teardown();
    setStatus("disconnected");
  }, [teardown]);

  // Close the socket and cancel reconnects when the component unmounts.
  useEffect(() => teardown, [teardown]);

  useImperativeHandle(
    ref,
    () => ({ startConnection, closeConnection }),
    [startConnection, closeConnection]
  );

  return {
    status,
  };
}
