import { canvasUtil, grayScaleConverter } from './imageHashUtils'
import { ImageHash } from './hash'

/** Default side length of the retained low-frequency DCT block used by `pHash`; the hash has size * size bits. */
const DEFAULT_SIZE = 8
/** Default multiplier for `pHash`: the image is resized to (size * factor) pixels square before the DCT. */
const DEFAULT_HIGH_FREQUENCY_FACTOR = 4

/**
 * Cosine basis table for an L-point DCT-II, stored row-major:
 * entry `u * L + x` holds `cos((2x + 1) * u * PI / (2L))`.
 */
type CosTable = Float64Array

/** Memoised cosine tables keyed by `L`; entries are never evicted. */
const cosCache: Map<number, CosTable> = new Map()

/**
 * Returns the DCT-II cosine basis for `L` points, computing it once per `L`.
 *
 * Indexing by `u * L + x` keeps every `(u, x)` pair distinct for any `L`
 * (a fixed-width key such as `(u << 8) + x` would collide once `L > 256`).
 *
 * @param L Number of points along one axis.
 * @returns Shared, cached table of length `L * L`; callers must not mutate it.
 */
function precomputeCos(L: number): CosTable {
  const cached = cosCache.get(L)
  if (cached !== undefined) {
    return cached
  }

  const piOver2L = Math.PI / (2 * L)
  const table = new Float64Array(L * L)

  for (let u = 0; u < L; u++) {
    const uTimesPiOver2L = u * piOver2L
    const row = u * L
    for (let x = 0; x < L; x++) {
      table[row + x] = Math.cos((2 * x + 1) * uTimesPiOver2L)
    }
  }

  cosCache.set(L, table)
  return table
}

/**
 * Unnormalised 2D DCT-II of a square matrix stored row-major.
 *
 * Output entry `u * L + v` is
 *   sum over x, y of matrix[x * L + y] * cos((2x+1)u*PI/2L) * cos((2y+1)v*PI/2L)
 * with no orthonormal scaling applied.
 *
 * @param matrix Row-major square matrix with `L * L` elements.
 * @returns Row-major array of `L * L` coefficients.
 * @throws RangeError If `matrix.length` is not a perfect square.
 */
function dctTransform(matrix: ArrayLike<number>): number[] {
  const L = Math.round(Math.sqrt(matrix.length))
  if (L * L !== matrix.length) {
    throw new RangeError(
      `dctTransform expects a square matrix, got ${matrix.length} elements`
    )
  }

  const cos = precomputeCos(L)
  const dct: number[] = new Array(L * L)

  for (let u = 0; u < L; u++) {
    const uRow = u * L
    for (let v = 0; v < L; v++) {
      const vRow = v * L
      let sum = 0

      for (let x = 0; x < L; x++) {
        const cosUX = cos[uRow + x]
        const rowStart = x * L
        for (let y = 0; y < L; y++) {
          // Same operand and summation order as before, so coefficients (and
          // therefore hashes) stay bit-for-bit identical.
          sum += matrix[rowStart + y] * cosUX * cos[vRow + y]
        }
      }

      dct[uRow + v] = sum
    }
  }

  return dct
}

/**
 * Upper median of `values`: the element at index `floor(n / 2)` after an
 * ascending numeric sort. For an even count this is the larger of the two
 * middle elements (no averaging).
 *
 * The input is left unmodified; an empty input yields `NaN`.
 */
function median(values: Float64Array): number {
  if (values.length === 0) {
    return NaN
  }
  // Sort a copy so callers' data is not reordered as a side effect.
  const ordered = values.slice().sort((a, b) => a - b)
  return ordered[Math.floor(ordered.length / 2)]
}

/**
 * Perceptual hash (pHash) of an image, returned as a hex string.
 *
 * The image is resized to a square of side `size * highFrequencyFactor`, passed
 * through `grayScaleConverter`, and transformed with a 2D DCT-II. The top-left
 * `size x size` block of coefficients is kept; each hash bit is 1 when its
 * coefficient is strictly greater than the upper median of that block.
 *
 * @param image Source image.
 * @param size Side of the retained coefficient block; the hash has `size * size` bits.
 * @param highFrequencyFactor Multiplier giving the side length of the resized image.
 * @returns Hex encoding produced by `ImageHash.toHexString`.
 * @throws RangeError If `size` is not a positive integer, or if
 *   `size * highFrequencyFactor` is not an integer at least as large as `size`.
 *   Every error (including ones from resizing/conversion) is logged and rethrown.
 */
export async function pHash(
  image: HTMLImageElement,
  size = DEFAULT_SIZE,
  highFrequencyFactor = DEFAULT_HIGH_FREQUENCY_FACTOR
) {
  try {
    const imageSize = size * highFrequencyFactor

    // Without these checks, a too-small or non-integer resize makes the
    // low-frequency extraction below read undefined DCT entries, silently
    // producing NaN coefficients and a meaningless hash.
    if (!Number.isInteger(size) || size < 1) {
      throw new RangeError(`pHash: size must be a positive integer, got ${size}`)
    }
    if (!Number.isInteger(imageSize) || imageSize < size) {
      throw new RangeError(
        `pHash: size * highFrequencyFactor must be an integer >= size, got ${imageSize}`
      )
    }

    const pixels = grayScaleConverter.convert(
      await canvasUtil.resizeImageAndGetData(image, imageSize, imageSize)
    )

    const dctOut = dctTransform(pixels)

    // Keep the top-left (lowest-frequency) size x size block, row-major.
    const dctLowFreq = new Float64Array(size * size)
    for (let row = 0; row < size; row++) {
      for (let col = 0; col < size; col++) {
        dctLowFreq[row * size + col] = dctOut[row * imageSize + col]
      }
    }

    // median() may sort its argument in place, so hand it a copy.
    const med = median(dctLowFreq.slice())

    const hash = new Uint8ClampedArray(size * size)
    for (let i = 0; i < hash.length; i++) {
      hash[i] = dctLowFreq[i] > med ? 1 : 0
    }

    return new ImageHash(hash).toHexString()
  } catch (error) {
    console.error('Error generating pHash:', error)
    throw error
  }
}

/**
 * Hamming distance between two hex-encoded perceptual hashes (for example,
 * strings returned by `pHash`).
 *
 * Any failure while decoding or comparing is logged and rethrown unchanged.
 */
export function comparePHashes(hashStrA: string, hashStrB: string) {
  try {
    const first = ImageHash.fromHexString(hashStrA)
    const second = ImageHash.fromHexString(hashStrB)
    const distance = first.hammingDistance(second)
    return distance
  } catch (err) {
    console.error('Error in comparePHashes function:', err)
    throw err
  }
}
