package torrent

import (
	"context"
	"crypto/sha1"
	"fmt"
	"io"
	"io/fs"
	"log/slog"
	"os"
	"path/filepath"
	"slices"

	"git.kmsign.ru/royalcat/tstor/pkg/rlog"
	"github.com/dustin/go-humanize"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
	"golang.org/x/exp/maps"
	"golang.org/x/sys/unix"
)

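// Dedupe scans the storage for files with identical content and
// deduplicates them in place via the filesystem (FIDEDUPERANGE on
// Linux), returning the total number of bytes reclaimed. Files are
// grouped by size first, so only same-sized files are ever compared.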
func (s *fileStorage) Dedupe(ctx context.Context) (uint64, error) {
	ctx, span := tracer.Start(ctx, "Dedupe")
	defer span.End()

	log := s.log

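	// Group files by size first: files of different sizes can never
	// have identical content, so this cheaply narrows the candidates.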
	sizeMap := map[int64][]string{}
	err := s.iterFiles(ctx, func(ctx context.Context, path string, info fs.FileInfo) error {
		size := info.Size()
		sizeMap[size] = append(sizeMap[size], path)
		return nil
	})
	if err != nil {
		return 0, err
	}

	maps.DeleteFunc(sizeMap, func(k int64, v []string) bool {
		return len(v) <= 1
	})

	span.AddEvent("collected files with same size", trace.WithAttributes(
		attribute.Int("count", len(sizeMap)),
	))

	var deduped uint64

	i := 0
	for _, paths := range sizeMap {
		if i%100 == 0 {
			log.Info(ctx, "deduping in progress", slog.Int("current", i), slog.Int("total", len(sizeMap)))
		}
		i++

		if ctx.Err() != nil {
			return deduped, ctx.Err()
		}

		slices.Sort(paths)
		paths = slices.Compact(paths)
		if len(paths) <= 1 {
			continue
		}

		paths, err = applyErr(paths, filepath.Abs)
		if err != nil {
			return deduped, err
		}

		dedupedGroup, err := s.dedupeFiles(ctx, paths)
		if err != nil {
			log.Error(ctx, "Error applying dedupe", slog.Any("files", paths), rlog.Error(err))
			continue
		}

		if dedupedGroup > 0 {
			deduped += dedupedGroup
			log.Info(ctx, "deduped file group",
				slog.String("files", fmt.Sprint(paths)),
				slog.String("deduped", humanize.Bytes(dedupedGroup)),
				slog.String("deduped_total", humanize.Bytes(deduped)),
			)
		}

	}

	return deduped, nil
}

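// applyErr maps apply over in, returning early on the first error.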
func applyErr[E, O any](in []E, apply func(E) (O, error)) ([]O, error) {
	out := make([]O, 0, len(in))
	for _, p := range in {
		o, err := apply(p)
		if err != nil {
			return out, err
		}
		out = append(out, o)
	}
	return out, nil
}

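// dedupeFiles deduplicates paths[1:] against paths[0] using the Linux
// FIDEDUPERANGE ioctl. Files whose leading bytes hash differently are
// skipped early; the kernel then verifies that the remaining ranges
// are byte-identical before sharing extents, so a hash collision
// cannot corrupt data. It returns the number of bytes deduplicated.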
func (s *fileStorage) dedupeFiles(ctx context.Context, paths []string) (deduped uint64, err error) {
	ctx, span := tracer.Start(ctx, "dedupeFiles", trace.WithAttributes(
		attribute.StringSlice("files", paths),
	))
	defer func() {
		span.SetAttributes(attribute.Int64("deduped", int64(deduped)))
		if err != nil {
			span.RecordError(err)
		}
		span.End()
	}()

	log := s.log

	srcF, err := os.Open(paths[0])
	if err != nil {
		return deduped, fmt.Errorf("error opening file %s: %w", paths[0], err)
	}
	defer srcF.Close()
	srcStat, err := srcF.Stat()
	if err != nil {
		return deduped, fmt.Errorf("error stat file %s: %w", paths[0], err)
	}

	srcFd := int(srcF.Fd())
	srcSize := srcStat.Size()

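	// The filesystem block size bounds the minimum deduplicatable file
	// size and the alignment of the dedupe range below.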
	fsStat := unix.Statfs_t{}
	err = unix.Fstatfs(srcFd, &fsStat)
	if err != nil {
		span.RecordError(err)
		return deduped, fmt.Errorf("error statfs file %s: %w", paths[0], err)
	}

	if int64(fsStat.Bsize) > srcSize {
		// On btrfs a file smaller than one block is inlined into
		// metadata and cannot be deduplicated.
		return deduped, nil
	}

	srcHash, err := filehash(srcF)
	if err != nil {
		return deduped, fmt.Errorf("error hashing file %s: %w", paths[0], err)
	}

	// FIDEDUPERANGE operates on block-aligned ranges, so round the
	// length down to a whole number of filesystem blocks.
	dedupeLen := uint64(srcSize) / uint64(fsStat.Bsize) * uint64(fsStat.Bsize)

	span.SetAttributes(attribute.Int64("dedupe_length", int64(dedupeLen)))

	rng := unix.FileDedupeRange{
		Src_offset: 0,
		Src_length: dedupeLen,
		Info:       []unix.FileDedupeRangeInfo{},
	}

	for _, dst := range paths[1:] {
		if ctx.Err() != nil {
			return deduped, ctx.Err()
		}

		destF, err := os.OpenFile(dst, os.O_RDWR, 0)
		if err != nil {
			return deduped, fmt.Errorf("error opening file %s: %w", dst, err)
		}

		dstHash, err := filehash(destF)
		if err != nil {
			destF.Close()
			return deduped, fmt.Errorf("error hashing file %s: %w", dst, err)
		}

		if srcHash != dstHash {
			destF.Close()
			continue
		}

		// The descriptor must stay open until the FIDEDUPERANGE call
		// below has run; it is closed when this function returns.
		defer destF.Close()

		rng.Info = append(rng.Info, unix.FileDedupeRangeInfo{
			Dest_fd:     int64(destF.Fd()),
			Dest_offset: 0,
		})
	}

	if len(rng.Info) == 0 {
		return deduped, nil
	}

	log.Info(ctx, "found same files, deduping", slog.Any("files", paths), slog.String("size", humanize.Bytes(uint64(srcStat.Size()))))

	if ctx.Err() != nil {
		return deduped, ctx.Err()
	}

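	// FIDEDUPERANGE compares the source range against every destination
	// and shares extents only where the bytes are identical, so the hash
	// prefilter above can never cause data loss. Note that filesystems
	// may silently clip very long ranges (typically to 16 MiB per call).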
	err = unix.IoctlFileDedupeRange(srcFd, &rng)
	if err != nil {
		return deduped, fmt.Errorf("error calling FIDEDUPERANGE: %w", err)
	}

	for i := range rng.Info {
		// Bytes_deduped reports how many bytes the kernel actually
		// shared; it is zero when Status is FILE_DEDUPE_RANGE_DIFFERS.
		deduped += rng.Info[i].Bytes_deduped
	}

	return deduped, nil
}

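// compareBlockSize is the number of leading bytes hashed when
// prefiltering candidate files for deduplication.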
const compareBlockSize = 1024 * 128

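// filehash hashes the first compareBlockSize bytes of r with SHA-1.
// It is only a prefilter: FIDEDUPERANGE re-verifies full byte equality
// before any extents are shared.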
func filehash(r io.Reader) ([20]byte, error) {
	buf := make([]byte, compareBlockSize)
	_, err := io.ReadFull(r, buf)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return [20]byte{}, err
	}

	return sha1.Sum(buf), nil
}