import { makeStyles } from '@material-ui/core'
import { axisBottom, axisLeft, axisRight, AxisScale, axisTop } from 'd3-axis'
import { select as d3select, Selection } from 'd3-selection'
import React, { memo, useEffect, useMemo, useState } from 'react'
import { useCallbackRef } from 'use-callback-ref'

/**
 * d3-axis generator factories, keyed by the side of the plot the axis is for.
 * The keys are also the accepted values of `FastAxis`'s `orientation` prop.
 */
const axes = {
  left: axisLeft,
  right: axisRight,
  top: axisTop,
  bottom: axisBottom,
} as const

/**
 * Styles for the axis `<g>`: hides the `.domain` path and the `line` inside
 * every `.tick` that d3-axis draws, leaving the tick labels visible.
 */
const useStyles = makeStyles({
  root: {
    '& .domain': { display: 'none' },
    '& .tick line': { display: 'none' },
  },
})

/**
 * SVG `<g>` element into which d3 imperatively renders an axis for `scale`.
 * The domain line and tick marks are hidden via `useStyles`, so only the
 * tick labels are shown.
 *
 * The axis is redrawn whenever `orientation`, `scale` or `ticks` change.
 *
 * @param orientation - key of `axes` selecting the d3 axis generator
 *   (`'top' | 'bottom' | 'left' | 'right'`). Defaults to `'bottom'`.
 * @param scale - d3 scale (number or Date domain) used to position ticks.
 * @param left - horizontal translation of the group. Defaults to 0.
 * @param top - vertical translation of the group. Defaults to 0.
 * @param ticks - forwarded to `axis.ticks()`; when `undefined` the scale's
 *   default tick count is used.
 */
export const FastAxis = memo(
  ({
    orientation = 'bottom',
    scale,
    left = 0,
    top = 0,
    ticks = undefined,
  }: {
    orientation?: keyof typeof axes
    scale: AxisScale<number | Date>
    left?: number
    top?: number
    ticks?: number
  }) => {
    const classes = useStyles()

    // d3 selection wrapping the rendered <g>; undefined until it is attached
    const [selection, setSelection] = useState<
      Selection<SVGGElement, unknown, null, undefined>
    >()

    // d3 axis generator. Keyed on orientation as well as scale so that a
    // change of the `orientation` prop actually switches generators; building
    // one is cheap (it only creates a configured closure).
    const axis = useMemo(
      () => axes[orientation]<number | Date>(scale).tickSizeOuter(0),
      [orientation, scale]
    )

    // Draw as soon as the <g> element is attached, then keep the selection
    // around so the effect below can redraw into it.
    const gRef = useCallbackRef<SVGGElement>(null, (gElement) => {
      if (gElement) {
        const gSelection = d3select(gElement)
        axis.ticks(ticks)
        gSelection.call(axis)
        setSelection(gSelection)
      }
    })

    // Redraw when the generator, the tick setting or the target changes.
    // `ticks` is applied here (not during render) and listed as a dependency
    // so that changing only the `ticks` prop triggers a redraw.
    useEffect(
      function redrawAxis() {
        if (selection) {
          axis.ticks(ticks)
          selection.call(axis)
        }
      },
      [axis, ticks, selection]
    )

    return (
      <g
        ref={gRef}
        className={classes.root}
        transform={`translate(${left} ${top})`}
        style={{
          fontSize: 14,
        }}
      />
    )
  }
)

FastAxis.displayName = 'FastAxis'
