import { useMsal } from "@azure/msal-react"
import { getEnvVar } from "../../env"
import { loginRequest } from "../../MsalAuth/authConfig"
import { v4 as uuidv4 } from "uuid"
import { Dispatch, useRef } from "react"
import { AccountInfo, IPublicClientApplication } from "@azure/msal-browser"

/**
 * Serializes a parameter map into a URL query string (without the leading `?`).
 *
 * Output format: key1=val1&key2=val2&key3=val3
 * If array param: arr_key=val1&arr_key=val2&arr_key=val3
 *
 * Entries whose value is `undefined` are skipped and an empty array contributes
 * nothing, so the result never contains empty segments (no leading, trailing
 * or doubled `&`). An empty map yields an empty string.
 *
 * NOTE(review): keys and values are concatenated verbatim, without any
 * percent-encoding, so callers have to pass values that are already URL-safe.
 */
const combineUrlParameters = (params: {
  [key: string]: string | string[] | undefined
}): string =>
  Object.entries(params)
    .flatMap(([key, value]) => {
      if (value === undefined) {
        return []
      }
      // Treat a single string as a one-element list so both shapes share one path.
      const values = typeof value === "string" ? [value] : value
      return values.map((val) => `${key}=${val}`)
    })
    .join("&")

/**
 * Obtains an access token for the backend API on behalf of `account`.
 *
 * The request reuses `loginRequest` but overrides its scopes with the API
 * application's `access_as_user` scope. A silent acquisition is attempted
 * first; if it throws for any reason, a popup acquisition is started instead
 * and its outcome (result or rejection) is handed to the caller.
 */
const acquireApiAccessToken = async (
  msalInstance: IPublicClientApplication,
  account: AccountInfo
) => {
  const apiClientId = getEnvVar("REACT_APP_API_APP_CLIENT_ID")
  const tokenRequest = {
    ...loginRequest,
    account,
    scopes: [`api://${apiClientId}/access_as_user`],
  }

  try {
    const silentResult = await msalInstance.acquireTokenSilent(tokenRequest)
    return silentResult
  } catch {
    // Silent acquisition failed: fall back to the interactive popup flow.
    const popupResult = await msalInstance.acquireTokenPopup(tokenRequest)
    return popupResult
  }
}

/**
 * Hook exposing `send`, which calls the AJAX API with an MSAL bearer token
 * attached to the request.
 *
 * `send` resolves to the `Response` object once `fetch` resolves; the response
 * status is not inspected here. It resolves to `false` when anything on the way
 * throws: building the URL, acquiring the access token, or the `fetch` call
 * itself. The error is logged to the console in that case.
 *
 * `urlParams` are appended to `uri` with `?` or `&`, depending on whether `uri`
 * already contains a query string; nothing is appended when they serialize to
 * an empty string.
 */
export const useApiQueryWithMsalAuth = () => {
  const { instance: msalInstance, accounts } = useMsal()

  return {
    send: async ({
      uri,
      method,
      headers,
      urlParams,
      body,
    }: {
      uri: string
      method: "GET" | "POST" | "DELETE" | "PUT" | "PATCH"
      headers: { [header: string]: string }
      urlParams?: { [key: string]: string | string[] | undefined }
      body?: BodyInit | undefined
    }): Promise<Response | false> => {
      try {
        // Token acquisition lives inside the try block so that an auth failure
        // honours the documented "returns false" contract instead of rejecting.
        const authResult = await acquireApiAccessToken(
          msalInstance,
          accounts[0]
        )

        const query = urlParams ? combineUrlParameters(urlParams) : ""
        const querySeparator = uri.includes("?") ? "&" : "?"
        const querySuffix = query ? `${querySeparator}${query}` : ""

        return await fetch(
          `${getEnvVar("REACT_APP_AJAX_API_URL")}${uri}${querySuffix}`,
          {
            method,
            headers: {
              ...headers,
              // Placed last so a caller-supplied Authorization header cannot win.
              Authorization: `Bearer ${authResult.accessToken}`,
            },
            body,
          }
        )
      } catch (error) {
        console.error("API call failed: ", error)
        return false
      }
    },
  }
}

/** Kinds of WebSocket-driven actions known to the connection pool. */
export enum AllWebSocketActionTypes {
  AISummarize = "summarize",
  AIComparison = "comparison",
  AICustomRequest = "custom-request",
  AITranslate = "translate",
}

/**
 * Groups of action types that must not be streamed at the same time: opening a
 * connection for one member closes the pooled connections of every member of
 * the same group (connections of the very same type included).
 */
const conflictingWebSocketGroups: AllWebSocketActionTypes[][] = [
  [
    AllWebSocketActionTypes.AISummarize,
    AllWebSocketActionTypes.AIComparison,
    AllWebSocketActionTypes.AICustomRequest,
    AllWebSocketActionTypes.AITranslate,
  ],
]

/** A connection tracked by the pool, identified by its streaming id. */
type WebSocketPoolRecord = {
  streamingId: string
  actionType: AllWebSocketActionTypes
  connection: WebSocket
}

/** Module-level registry of the connections currently tracked. */
let webSocketsPool: WebSocketPoolRecord[] = []

/**
 * Closes and evicts every pooled connection whose action type shares a
 * conflict group with `newWebSocketPoolRecord`, then registers the new record.
 *
 * @param newWebSocketPoolRecord record describing the freshly opened connection
 */
const addNewWSConnectionAndCloseConflicting = (
  newWebSocketPoolRecord: WebSocketPoolRecord
) => {
  const newActionType = newWebSocketPoolRecord.actionType
  const conflictsWithNew = (actionType: AllWebSocketActionTypes) =>
    conflictingWebSocketGroups.some(
      (group) => group.includes(newActionType) && group.includes(actionType)
    )

  const survivors: WebSocketPoolRecord[] = []
  for (const wsRecord of webSocketsPool) {
    if (conflictsWithNew(wsRecord.actionType)) {
      wsRecord.connection.close(
        undefined,
        "Conflicting with newly opened connection"
      )
    } else {
      survivors.push(wsRecord)
    }
  }

  // The new record has to be tracked as well; otherwise the pool stays empty
  // and a later conflicting connection has nothing to find and close.
  webSocketsPool = [...survivors, newWebSocketPoolRecord]
}

/**
 * Closes the pooled connection registered under `streamingId` and removes its
 * record from the pool. Does nothing when no such record exists, e.g. because
 * it was already evicted as a conflicting connection.
 *
 * @param streamingId id of the record to drop
 */
const closeAndRemoveConnection = (streamingId: string) => {
  const wsRecordInd = webSocketsPool.findIndex(
    (wsRecord) => wsRecord.streamingId === streamingId
  )
  // findIndex signals "not found" with -1 (truthy) and a match on the first
  // record with 0 (falsy), so the index must be compared explicitly.
  if (wsRecordInd === -1) {
    return
  }
  webSocketsPool[wsRecordInd]?.connection.close()
  webSocketsPool.splice(wsRecordInd, 1)
}

/**
 * Shape that incoming WebSocket messages are parsed into (the JSON payload of
 * each message event is cast to this interface).
 */
export interface WebSocketActionMessage {
  /**
   * Identifier of the action the message belongs to.
   * NOTE(review): presumably one of the AllWebSocketActionTypes values — confirm against the backend.
   */
  action_type: string
  /** Text chunk carried by the message; the WebSocket hook in this file concatenates chunks in arrival order. */
  result: string
  /** Error description; when truthy, the WebSocket hook reports it through its error callback and skips the message. */
  error?: string
  /** Id of the stream the message belongs to; the WebSocket hook ignores messages whose id is missing or differs from its own. */
  streaming_id?: string | null
}

/**
 * Helper to authorize and establish a WebSocket connection. Uses an MSAL access token.
 * Automatically closes all existing conflicting connections.
 * If you add a new WebSocket flow, please register your WS action in conflictingWebSocketGroups
 * so other conflicting connections (including the type you're adding) will be closed automatically.
 *
 * Returned API:
 * - `run` acquires a token, opens a connection to `uri` (with a generated
 *   `streaming_id`, the optional `urlParams` and the `access_token` in the
 *   query string), registers it in the pool and wires the callbacks. Every
 *   accepted message appends its `result` to the accumulated text, which is
 *   pushed to `setCurrentState`. Messages carrying `error` are reported through
 *   `onWebSocketError`; messages for another `streaming_id` are ignored.
 *   The returned promise rejects if the token cannot be acquired.
 * - `stop` closes the connection currently held by this hook instance, if any.
 */
export const useWebSocketWithMsalAuth = () => {
  const connectionRef = useRef<WebSocket | null>(null)
  const { instance: msalInstance, accounts } = useMsal()

  return {
    run: async ({
      uri,
      actionType,
      urlParams,
      setCurrentState,
      onWebSocketOpen,
      onWebSocketClose,
      onWebSocketError,
    }: {
      uri: string
      actionType: AllWebSocketActionTypes
      urlParams?: { [key: string]: string | string[] | undefined }
      setCurrentState: Dispatch<string>
      onWebSocketOpen?: (ws: WebSocket) => void
      onWebSocketClose?: () => void
      onWebSocketError?: (errorMessage: string) => void
    }) => {
      const authResult = await acquireApiAccessToken(msalInstance, accounts[0])
      const streamingId = uuidv4()

      // Build the query from non-empty segments only, so that an empty
      // urlParams object cannot leave a doubled "&" in the URL.
      const querySegments = [`streaming_id=${streamingId}`]
      const extraParams = urlParams ? combineUrlParameters(urlParams) : ""
      if (extraParams) {
        querySegments.push(extraParams)
      }
      querySegments.push(`access_token=${authResult.accessToken}`)
      const querySeparator = uri.includes("?") ? "&" : "?"

      // Handlers below close over this local socket rather than the shared ref:
      // the ref may be re-pointed to a newer socket by a later run() call.
      const ws = new WebSocket(
        `${getEnvVar("REACT_APP_WEBSOCKET_API_URL")}${uri}` +
          `${querySeparator}${querySegments.join("&")}`
      )
      connectionRef.current = ws

      addNewWSConnectionAndCloseConflicting({
        actionType,
        connection: ws,
        streamingId,
      })

      ws.onclose = () => {
        closeAndRemoveConnection(streamingId)
        // Only clear the ref when it still points at this socket; a stale
        // socket closing late must not wipe the handle of its successor.
        if (connectionRef.current === ws) {
          connectionRef.current = null
        }
        onWebSocketClose?.()
      }

      ws.onopen = () => {
        onWebSocketOpen?.(ws)
      }

      const messages: string[] = []
      ws.onmessage = (e) => {
        let data: WebSocketActionMessage | null
        try {
          data = JSON.parse(e.data) as WebSocketActionMessage | null
        } catch (error) {
          // A malformed frame is dropped instead of throwing inside the handler.
          console.error("Failed to parse WebSocket message: ", error)
          return
        }
        if (data === null || typeof data !== "object") {
          return
        }

        if (data.error) {
          onWebSocketError?.(data.error)
          return
        }
        if (!data.streaming_id || data.streaming_id !== streamingId) {
          return
        }

        messages.push(data.result)
        setCurrentState(messages.join(""))
      }
    },
    stop: () => {
      connectionRef.current?.close()
    },
  }
}
