package huff0

import (
	"fmt"
	"runtime"
	"sync"
)

// Compress1X compresses in as a single stream.
// Output produced here can be decoded using Decompress1X.
// The supplied Scratch carries re-use state between calls, so when one Scratch
// is shared across independent encodes its re-use policy must be set accordingly.
func Compress1X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
	if s, err = s.prepare(in); err != nil {
		return nil, false, err
	}
	return compress(in, s, s.compress1X)
}

// Compress4X compresses in as 4 independent blocks, each encoded the same way
// Compress1X encodes its input.
// Output produced here can be decoded using Decompress4X.
// The supplied Scratch carries re-use state between calls, so when one Scratch
// is shared across independent encodes its re-use policy must be set accordingly.
func Compress4X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
	const (
		// The goroutine based encoder stays compiled in, but is switched off.
		// TODO: compress4Xp only slightly faster.
		useParallel = false
		// Inputs shorter than this would always take the serial path.
		parallelThreshold = 8 << 10
	)
	if s, err = s.prepare(in); err != nil {
		return nil, false, err
	}
	if useParallel && len(in) >= parallelThreshold && runtime.GOMAXPROCS(0) != 1 {
		return compress(in, s, s.compress4Xp)
	}
	return compress(in, s, s.compress4X)
}

// compress is the shared driver behind Compress1X and Compress4X.
// It decides whether the previous table (s.prevTable) is re-used or a new
// table is built and emitted in front of the data, and runs compressor over in.
//
// It returns the encoded bytes (s.Out), whether the previous table was re-used,
// and an error. ErrUseRLE is returned when in is a single repeated symbol;
// ErrIncompressible is returned when the input is judged not worth encoding,
// when re-use is mandatory but impossible, or when the output is not smaller
// than the wanted size.
//
// On success s.OutData is set to the encoded data, and s.OutTable to the
// emitted table when a new table was used.
func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)) (out []byte, reUsed bool, err error) {
	// Nuke previous table if we cannot reuse anyway.
	if s.Reuse == ReusePolicyNone {
		s.prevTable = s.prevTable[:0]
	}

	// Create histogram, if none was provided.
	// A non-zero s.maxCount is taken to mean the counts are already in place,
	// so only the check against the previous table is needed.
	maxCount := s.maxCount
	var canReuse = false
	if maxCount == 0 {
		maxCount, canReuse = s.countSimple(in)
	} else {
		canReuse = s.canUseTable(s.prevTable)
	}

	// We want the output size to be less than this:
	// the input length, reduced by len(in) >> WantLogLess when that is set.
	wantSize := len(in)
	if s.WantLogLess > 0 {
		wantSize -= wantSize >> s.WantLogLess
	}

	// Reset for next run.
	s.clearCount = true
	s.maxCount = 0
	if maxCount >= len(in) {
		// A single symbol count can never exceed the input length.
		if maxCount > len(in) {
			return nil, false, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
		}
		if len(in) == 1 {
			return nil, false, ErrIncompressible
		}
		// One symbol, use RLE
		return nil, false, ErrUseRLE
	}
	if maxCount == 1 || maxCount < (len(in)>>7) {
		// Each symbol present maximum once or too well distributed.
		return nil, false, ErrIncompressible
	}
	if s.Reuse == ReusePolicyMust && !canReuse {
		// We must reuse, but we can't.
		return nil, false, ErrIncompressible
	}
	if (s.Reuse == ReusePolicyPrefer || s.Reuse == ReusePolicyMust) && canReuse {
		// Encode with the previous table swapped in, then restore the current one.
		keepTable := s.cTable
		keepTL := s.actualTableLog
		s.cTable = s.prevTable
		s.actualTableLog = s.prevTableLog
		s.Out, err = compressor(in)
		s.cTable = keepTable
		s.actualTableLog = keepTL
		if err == nil && len(s.Out) < wantSize {
			s.OutData = s.Out
			return s.Out, true, nil
		}
		// Either the compressor failed or the result was too big.
		// With ReusePolicyPrefer the error is dropped and a new table is tried below.
		if s.Reuse == ReusePolicyMust {
			return nil, false, ErrIncompressible
		}
		// Do not attempt to re-use later.
		s.prevTable = s.prevTable[:0]
	}

	// Calculate new table.
	err = s.buildCTable()
	if err != nil {
		return nil, false, err
	}

	// Disabled sanity check of the freshly built table.
	if false && !s.canUseTable(s.cTable) {
		panic("invalid table generated")
	}

	if s.Reuse == ReusePolicyAllow && canReuse {
		// Compare the estimated size using the old table against the new
		// table plus hSize, the current length of s.Out.
		hSize := len(s.Out)
		oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen])
		newSize := s.cTable.estimateSize(s.count[:s.symbolLen])
		if oldSize <= hSize+newSize || hSize+12 >= wantSize {
			// Retain cTable even if we re-use.
			keepTable := s.cTable
			keepTL := s.actualTableLog

			s.cTable = s.prevTable
			s.actualTableLog = s.prevTableLog
			s.Out, err = compressor(in)

			// Restore ctable.
			s.cTable = keepTable
			s.actualTableLog = keepTL
			if err != nil {
				return nil, false, err
			}
			if len(s.Out) >= wantSize {
				return nil, false, ErrIncompressible
			}
			s.OutData = s.Out
			return s.Out, true, nil
		}
	}

	// Use new table
	// NOTE(review): write presumably leaves the serialized table in s.Out,
	// since s.Out is recorded as s.OutTable right after — confirm in cTable.write.
	err = s.cTable.write(s)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	s.OutTable = s.Out

	// Compress using new table
	s.Out, err = compressor(in)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	if len(s.Out) >= wantSize {
		s.OutTable = nil
		return nil, false, ErrIncompressible
	}
	// Move current table into previous.
	// The old prevTable backing array is recycled as the (emptied) cTable.
	s.prevTable, s.prevTableLog, s.cTable = s.cTable, s.actualTableLog, s.prevTable[:0]
	// The data part is whatever follows the table bytes in s.Out.
	s.OutData = s.Out[len(s.OutTable):]
	return s.Out, false, nil
}

// compress1X encodes src as one stream, using s.Out as the destination
// that the encoded bytes are appended to.
func (s *Scratch) compress1X(src []byte) ([]byte, error) {
	dst := s.Out
	return s.compress1xDo(dst, src)
}

// compress1xDo encodes src with s.cTable through a bitWriter that starts
// out with dst as its output, and returns the writer's output after closing it.
// Symbols are fed to the writer back to front: first the 0-3 bytes beyond the
// last multiple of four, then the remaining bytes in groups of four.
func (s *Scratch) compress1xDo(dst, src []byte) ([]byte, error) {
	w := bitWriter{out: dst}
	table := s.cTable[:256]

	// Split src into a head whose length is divisible by 4 and a short tail.
	tail := len(src) & 3
	aligned := len(src) - tail

	// Encode last bytes, final byte first.
	for i := len(src) - 1; i >= aligned; i-- {
		w.encSymbol(table, src[i])
	}

	start := aligned - 4
	if s.actualTableLog > 8 {
		// Longer codes: flush before every pair of symbols.
		for ; start >= 0; start -= 4 {
			quad := src[start : start+4]
			w.flush32()
			w.encTwoSymbols(table, quad[3], quad[2])
			w.flush32()
			w.encTwoSymbols(table, quad[1], quad[0])
		}
	} else {
		// Codes of at most 8 bits: one flush per four symbols.
		for ; start >= 0; start -= 4 {
			quad := src[start : start+4]
			w.flush32()
			w.encTwoSymbols(table, quad[3], quad[2])
			w.encTwoSymbols(table, quad[1], quad[0])
		}
	}

	err := w.close()
	return w.out, err
}

// sixZeros is the placeholder appended for the 6 byte jump table
// (three 16 bit little endian stream lengths).
var sixZeros [6]byte

// compress4X splits src into 4 segments of (len(src)+3)/4 bytes (the last one
// takes what remains), encodes each with compress1xDo and appends the result
// to s.Out. The encoded lengths of the first three segments are stored as
// 16 bit little endian values in a 6 byte jump table placed before the streams.
//
// Returns ErrIncompressible if src is shorter than 12 bytes, or if one of the
// first three encoded segments is too long for its 16 bit jump table slot.
// Errors from compress1xDo are returned unchanged. On error the returned
// slice is nil.
func (s *Scratch) compress4X(src []byte) ([]byte, error) {
	if len(src) < 12 {
		return nil, ErrIncompressible
	}
	// Largest length that fits in a jump table slot.
	const maxStoredLen = 1<<16 - 1
	segmentSize := (len(src) + 3) / 4

	// Add placeholder for output length
	offsetIdx := len(s.Out)
	s.Out = append(s.Out, sixZeros[:]...)

	for i := 0; i < 4; i++ {
		toDo := src
		if len(toDo) > segmentSize {
			toDo = toDo[:segmentSize]
		}
		src = src[len(toDo):]

		var err error
		idx := len(s.Out)
		s.Out, err = s.compress1xDo(s.Out, toDo)
		if err != nil {
			return nil, err
		}
		if i == 3 {
			// Last length is not written.
			break
		}
		length := len(s.Out) - idx
		if length > maxStoredLen {
			// Storing only the low 16 bits would silently corrupt the
			// jump table, so refuse to emit this block instead.
			return nil, ErrIncompressible
		}
		// Write compressed length as little endian before block.
		s.Out[i*2+offsetIdx] = byte(length)
		s.Out[i*2+offsetIdx+1] = byte(length >> 8)
	}

	return s.Out, nil
}

// compress4Xp will compress 4 streams using separate goroutines.
//
// src is split into 4 segments of (len(src)+3)/4 bytes (the last one takes
// what remains). Each segment is encoded into its own s.tmpOut buffer, and the
// buffers are then appended to s.Out behind a 6 byte jump table holding the
// lengths of the first three streams as 16 bit little endian values.
//
// The jump table is appended after whatever s.Out already contains, in the
// same way compress4X does it, so bytes placed in s.Out before the call are
// kept.
//
// Returns ErrIncompressible if src is shorter than 12 bytes, or if one of the
// first three encoded streams is too long for its 16 bit jump table slot.
// The first error reported by a segment encoder is returned unchanged.
func (s *Scratch) compress4Xp(src []byte) ([]byte, error) {
	if len(src) < 12 {
		return nil, ErrIncompressible
	}
	// Largest length that fits in a jump table slot.
	const maxStoredLen = 1<<16 - 1

	// Add placeholder for output length
	offsetIdx := len(s.Out)
	s.Out = append(s.Out, sixZeros[:]...)

	segmentSize := (len(src) + 3) / 4
	var wg sync.WaitGroup
	var errs [4]error
	wg.Add(4)
	for i := 0; i < 4; i++ {
		toDo := src
		if len(toDo) > segmentSize {
			toDo = toDo[:segmentSize]
		}
		src = src[len(toDo):]

		// Separate goroutine for each block.
		// Each one only writes its own s.tmpOut[i] and errs[i] slot.
		go func(i int, toDo []byte) {
			defer wg.Done()
			s.tmpOut[i], errs[i] = s.compress1xDo(s.tmpOut[i][:0], toDo)
		}(i, toDo)
	}
	wg.Wait()
	for i := 0; i < 4; i++ {
		if errs[i] != nil {
			return nil, errs[i]
		}
		o := s.tmpOut[i]
		// Write compressed length as little endian before block.
		if i < 3 {
			// Last length is not written.
			if len(o) > maxStoredLen {
				// Storing only the low 16 bits would silently corrupt
				// the jump table, so refuse to emit this block instead.
				return nil, ErrIncompressible
			}
			s.Out[offsetIdx+i*2] = byte(len(o))
			s.Out[offsetIdx+i*2+1] = byte(len(o) >> 8)
		}

		// Write output.
		s.Out = append(s.Out, o...)
	}
	return s.Out, nil
}

// countSimple adds one to s.count for every byte of in and sets s.symbolLen
// to one past the highest symbol with a non-zero count.
// It returns the biggest count, and whether s.prevTable has a code (nBits != 0)
// for every symbol present. With an empty s.prevTable reuse is always false.
// Does not update s.clearCount.
func (s *Scratch) countSimple(in []byte) (max int, reuse bool) {
	for _, b := range in {
		s.count[b]++
	}

	prev := s.prevTable
	havePrev := len(prev) > 0
	// Re-use is only possible at all when there is a previous table.
	reuse = havePrev

	var biggest uint32
	for sym, n := range s.count[:] {
		if n == 0 {
			continue
		}
		if n > biggest {
			biggest = n
		}
		s.symbolLen = uint16(sym) + 1
		// A symbol outside the previous table, or without a code in it, rules out re-use.
		if havePrev && (sym >= len(prev) || prev[sym].nBits == 0) {
			reuse = false
		}
	}
	return int(biggest), reuse
}

// canUseTable reports whether c covers s.symbolLen symbols and holds a code
// (nBits != 0) for every symbol that has a non-zero count in s.count.
func (s *Scratch) canUseTable(c cTable) bool {
	if int(s.symbolLen) > len(c) {
		return false
	}
	counts := s.count[:s.symbolLen]
	for sym := range counts {
		if counts[sym] == 0 {
			continue
		}
		if c[sym].nBits == 0 {
			return false
		}
	}
	return true
}

// validateTable reports whether c covers s.symbolLen symbols and, for every
// symbol with a non-zero count in s.count, holds a code whose length is
// non-zero and at most s.actualTableLog.
func (s *Scratch) validateTable(c cTable) bool {
	if int(s.symbolLen) > len(c) {
		return false
	}
	limit := s.actualTableLog
	for sym, n := range s.count[:s.symbolLen] {
		if n == 0 {
			continue
		}
		if bits := c[sym].nBits; bits == 0 || bits > limit {
			return false
		}
	}
	return true
}

// minTableLog provides the minimum logSize to safely represent a distribution.
// It is the smaller of two bounds: one derived from s.br.remain() and one
// derived from s.symbolLen.
func (s *Scratch) minTableLog() uint8 {
	fromSrc := highBit32(uint32(s.br.remain())) + 1
	fromSymbols := highBit32(uint32(s.symbolLen-1)) + 2
	if fromSymbols <= fromSrc {
		return uint8(fromSymbols)
	}
	return uint8(fromSrc)
}

// optimalTableLog calculates the table log to use and stores it in s.actualTableLog.
// It starts from s.TableLog, lowers it to a limit derived from s.br.remain(),
// raises it to s.minTableLog() and finally clamps it to [minTablelog, tableLogMax].
func (s *Scratch) optimalTableLog() {
	chosen := s.TableLog
	lowest := s.minTableLog()
	srcLimit := uint8(highBit32(uint32(s.br.remain()-1))) - 1

	// Accuracy can be reduced
	if chosen > srcLimit {
		chosen = srcLimit
	}
	if chosen < lowest {
		chosen = lowest
	}

	// Need a minimum to safely represent all symbol values
	if chosen < minTablelog {
		chosen = minTablelog
	}
	if chosen > tableLogMax {
		chosen = tableLogMax
	}
	s.actualTableLog = chosen
}

// cTableEntry is the encoder table entry for a single symbol.
type cTableEntry struct {
	val   uint16 // code value; buildCTable hands out consecutive values to symbols sharing a code length
	nBits uint8  // code length in bits; 0 means the symbol has no code (see canUseTable)
	// We have 8 bits extra
}

// huffNodesMask is ANDed with node positions in huffSort to keep them inside
// the node table.
// NOTE(review): this is only a wrap-around mask if huffNodesLen is a power
// of two — confirm where huffNodesLen is declared.
const huffNodesMask = huffNodesLen - 1

// buildCTable builds a new encoding table in s.cTable from the histogram in
// s.count[:s.symbolLen] and updates s.actualTableLog to the longest code length used.
//
// Steps: pick a target table log (optimalTableLog), sort the symbols by
// decreasing count into s.nodes (huffSort), build the tree by repeatedly
// joining the two smallest nodes, derive code lengths from tree depth, limit
// the lengths with setMaxHeight, and finally assign code values per length.
//
// Returns an error only if the resulting maximum length exceeds tableLogMax.
func (s *Scratch) buildCTable() error {
	s.optimalTableLog()
	s.huffSort()
	// Make s.cTable hold s.symbolLen zeroed entries, re-using the backing
	// array when it has room for maxSymbolValue+1 entries.
	if cap(s.cTable) < maxSymbolValue+1 {
		s.cTable = make([]cTableEntry, s.symbolLen, maxSymbolValue+1)
	} else {
		s.cTable = s.cTable[:s.symbolLen]
		for i := range s.cTable {
			s.cTable[i] = cTableEntry{}
		}
	}

	// Leaves occupy huffNode[0:symbolLen]; internal nodes are created from startNode upwards.
	var startNode = int16(s.symbolLen)
	nonNullRank := s.symbolLen - 1

	nodeNb := startNode
	huffNode := s.nodes[1 : huffNodesLen+1]

	// This overlays the slice above, but allows "-1" index lookups.
	// Different from reference implementation.
	// huffNode0[i+1] is the same element as huffNode[i].
	huffNode0 := s.nodes[0 : huffNodesLen+1]

	// Skip trailing zero-count entries; nonNullRank ends at the last
	// (smallest) node with a non-zero count.
	for huffNode[nonNullRank].count == 0 {
		nonNullRank--
	}

	// Join the two smallest leaves into the first internal node.
	lowS := int16(nonNullRank)
	nodeRoot := nodeNb + lowS - 1
	lowN := nodeNb
	huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS-1].count
	huffNode[lowS].parent, huffNode[lowS-1].parent = uint16(nodeNb), uint16(nodeNb)
	nodeNb++
	lowS -= 2
	// Internal nodes not created yet get a large count, so the comparisons
	// below do not select them ahead of real nodes.
	for n := nodeNb; n <= nodeRoot; n++ {
		huffNode[n].count = 1 << 30
	}
	// fake entry, strong barrier
	// (this is the element reached when lowS is -1)
	huffNode0[0].count = 1 << 31

	// create parents
	// Each round takes the two smallest candidates, choosing between the next
	// leaf (lowS, moving down) and the next internal node (lowN, moving up).
	for nodeNb <= nodeRoot {
		var n1, n2 int16
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n1 = lowS
			lowS--
		} else {
			n1 = lowN
			lowN++
		}
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n2 = lowS
			lowS--
		} else {
			n2 = lowN
			lowN++
		}

		huffNode[nodeNb].count = huffNode0[n1+1].count + huffNode0[n2+1].count
		huffNode0[n1+1].parent, huffNode0[n2+1].parent = uint16(nodeNb), uint16(nodeNb)
		nodeNb++
	}

	// distribute weights (unlimited tree height)
	// A node is one bit deeper than its parent: internal nodes first, from
	// the root downwards, then the leaves.
	huffNode[nodeRoot].nbBits = 0
	for n := nodeRoot - 1; n >= startNode; n-- {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	for n := uint16(0); n <= nonNullRank; n++ {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	// Limit the code lengths; the returned value is the longest length in use.
	s.actualTableLog = s.setMaxHeight(int(nonNullRank))
	maxNbBits := s.actualTableLog

	// fill result into tree (val, nbBits)
	if maxNbBits > tableLogMax {
		return fmt.Errorf("internal error: maxNbBits (%d) > tableLogMax (%d)", maxNbBits, tableLogMax)
	}
	// nbPerRank[b] counts the leaves with code length b;
	// valPerRank[b] is the next code value to hand out for length b.
	var nbPerRank [tableLogMax + 1]uint16
	var valPerRank [16]uint16
	for _, v := range huffNode[:nonNullRank+1] {
		nbPerRank[v.nbBits]++
	}
	// determine starting value per rank
	{
		min := uint16(0)
		for n := maxNbBits; n > 0; n-- {
			// get starting value within each rank
			valPerRank[n] = min
			min += nbPerRank[n]
			min >>= 1
		}
	}

	// push nbBits per symbol, symbol order
	for _, v := range huffNode[:nonNullRank+1] {
		s.cTable[v.symbol].nBits = v.nbBits
	}

	// assign value within rank, symbol order
	t := s.cTable[:s.symbolLen]
	for n, val := range t {
		// The mask keeps the index inside the 16 entry valPerRank array.
		nbits := val.nBits & 15
		v := valPerRank[nbits]
		t[n].val = v
		valPerRank[nbits] = v + 1
	}

	return nil
}

// huffSort will sort symbols, decreasing order.
// The symbols of s.count[:s.symbolLen] are written as nodeElt values (count
// and symbol) into s.nodes, starting at s.nodes[1].
//
// Symbols are first bucketed by highBit32(count+1); each bucket gets a start
// position so that buckets with larger counts come first, and within a bucket
// an insertion sort keeps the counts in decreasing order.
func (s *Scratch) huffSort() {
	// rankPos tracks one bucket: base is where it starts in nodes,
	// current is the next free position inside it.
	type rankPos struct {
		base    uint32
		current uint32
	}

	// Clear nodes
	// NOTE(review): this only reslices s.nodes to huffNodesLen+1 entries and
	// does not zero them; the local nodes view skips s.nodes[0].
	nodes := s.nodes[:huffNodesLen+1]
	s.nodes = nodes
	nodes = nodes[1 : huffNodesLen+1]

	// Sort into buckets based on length of symbol count.
	var rank [32]rankPos
	for _, v := range s.count[:s.symbolLen] {
		r := highBit32(v+1) & 31
		rank[r].base++
	}
	// maxBitLength is log2(BlockSizeMax) + 1
	const maxBitLength = 18 + 1
	// Accumulate from the top, so rank[n].base becomes the number of symbols
	// in bucket n or above.
	for n := maxBitLength; n > 0; n-- {
		rank[n-1].base += rank[n].base
	}
	for n := range rank[:maxBitLength] {
		rank[n].current = rank[n].base
	}
	for n, c := range s.count[:s.symbolLen] {
		// Index one bucket higher than when counting: rank[r].base is then the
		// number of symbols in strictly larger buckets, i.e. where this bucket starts.
		r := (highBit32(c+1) + 1) & 31
		pos := rank[r].current
		rank[r].current++
		// pos is unsigned, so pos-1 wraps when pos is 0; the mask keeps the index in range.
		prev := nodes[(pos-1)&huffNodesMask]
		// Insertion sort within the bucket: shift smaller counts one slot back.
		for pos > rank[r].base && c > prev.count {
			nodes[pos&huffNodesMask] = prev
			pos--
			prev = nodes[(pos-1)&huffNodesMask]
		}
		nodes[pos&huffNodesMask] = nodeElt{count: c, symbol: byte(n)}
	}
}

// setMaxHeight limits the code lengths (nbBits) of the leaves in s.nodes to
// at most s.actualTableLog bits.
// lastNonNull is the index, in the view starting at s.nodes[1], of the last
// leaf to consider; its nbBits is taken as the largest length present.
//
// If that length is already within the limit it is returned and nothing is
// changed. Otherwise every longer code is cut down to the limit, the cost of
// doing so is repaid by lengthening other codes by one bit, any overshoot is
// corrected by shortening codes again, and s.actualTableLog is returned.
//
// Several comments below are carried over from the C reference implementation.
func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
	maxNbBits := s.actualTableLog
	huffNode := s.nodes[1 : huffNodesLen+1]
	//huffNode = huffNode[: huffNodesLen]

	largestBits := huffNode[lastNonNull].nbBits

	// early exit : no elt > maxNbBits
	if largestBits <= maxNbBits {
		return largestBits
	}
	totalCost := int(0)
	baseCost := int(1) << (largestBits - maxNbBits)
	n := uint32(lastNonNull)

	// Cut every code longer than maxNbBits down to maxNbBits and add up
	// what that costs.
	for huffNode[n].nbBits > maxNbBits {
		totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits))
		huffNode[n].nbBits = maxNbBits
		n--
	}
	// n stops at huffNode[n].nbBits <= maxNbBits

	for huffNode[n].nbBits == maxNbBits {
		n--
	}
	// n end at index of smallest symbol using < maxNbBits

	// renorm totalCost
	totalCost >>= largestBits - maxNbBits /* note : totalCost is necessarily a multiple of baseCost */

	// repay normalized cost
	{
		// noSymbol marks a rank without any symbol.
		const noSymbol = 0xF0F0F0F0
		// rankLast[k] is the position of the last node whose length is
		// maxNbBits-k, or noSymbol.
		var rankLast [tableLogMax + 2]uint32

		for i := range rankLast[:] {
			rankLast[i] = noSymbol
		}

		// Get pos of last (smallest) symbol per rank
		{
			currentNbBits := maxNbBits
			for pos := int(n); pos >= 0; pos-- {
				if huffNode[pos].nbBits >= currentNbBits {
					continue
				}
				currentNbBits = huffNode[pos].nbBits // < maxNbBits
				rankLast[maxNbBits-currentNbBits] = uint32(pos)
			}
		}

		// Each round lengthens one code by a bit, which repays
		// 1 << (nBitsToDecrease-1) of the cost.
		for totalCost > 0 {
			nBitsToDecrease := uint8(highBit32(uint32(totalCost))) + 1

			// Walk down the ranks until the comparison of the two candidates'
			// counts says to stop.
			for ; nBitsToDecrease > 1; nBitsToDecrease-- {
				highPos := rankLast[nBitsToDecrease]
				lowPos := rankLast[nBitsToDecrease-1]
				if highPos == noSymbol {
					continue
				}
				if lowPos == noSymbol {
					break
				}
				highTotal := huffNode[highPos].count
				lowTotal := 2 * huffNode[lowPos].count
				if highTotal <= lowTotal {
					break
				}
			}
			// only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !)
			// HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary
			// FIXME: try to remove
			for (nBitsToDecrease <= tableLogMax) && (rankLast[nBitsToDecrease] == noSymbol) {
				nBitsToDecrease++
			}
			totalCost -= 1 << (nBitsToDecrease - 1)
			if rankLast[nBitsToDecrease-1] == noSymbol {
				// this rank is no longer empty
				rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]
			}
			huffNode[rankLast[nBitsToDecrease]].nbBits++
			if rankLast[nBitsToDecrease] == 0 {
				/* special case, reached largest symbol */
				rankLast[nBitsToDecrease] = noSymbol
			} else {
				// Step to the previous node; the rank is empty if that node
				// has a different length.
				rankLast[nBitsToDecrease]--
				if huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease {
					rankLast[nBitsToDecrease] = noSymbol /* this rank is now empty */
				}
			}
		}

		for totalCost < 0 { /* Sometimes, cost correction overshoot */
			if rankLast[1] == noSymbol { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 (using maxNbBits) */
				for huffNode[n].nbBits == maxNbBits {
					n--
				}
				huffNode[n+1].nbBits--
				rankLast[1] = n + 1
				totalCost++
				continue
			}
			// Shorten the node following the current rank 1 node by one bit.
			huffNode[rankLast[1]+1].nbBits--
			rankLast[1]++
			totalCost++
		}
	}
	return maxNbBits
}

// nodeElt is one node of the tree built by buildCTable, either a leaf
// (one symbol) or an internal node joining two other nodes.
type nodeElt struct {
	count  uint32 // symbol count for a leaf; sum of the two joined counts for an internal node
	parent uint16 // index of the parent node, as assigned in buildCTable
	symbol byte   // symbol value, set by huffSort for leaves
	nbBits uint8  // code length in bits: the parent's nbBits plus one, possibly adjusted by setMaxHeight
}