package ioutils

import (
	"context"
	"io"
	"os"
	"sync"

	"github.com/royalcat/ctxio"
)

// DiskCacheReader wraps a ctxio.Reader and mirrors every byte pulled from it
// into a temporary file, so that data already consumed from the source can be
// served again through ReadAt. Read and ReadAt serialize on m.
type DiskCacheReader struct {
	m sync.Mutex

	// tr is a tee over the source reader; everything read through it is also
	// written to fr.
	tr ctxio.Reader
	// fr is the temporary file that backs the cache.
	fr *os.File

	// to counts the bytes pulled through tr (and therefore written to fr).
	to int64
	// fo is an offset counter that nothing outside ReadAt reads.
	fo int64
}

// Compile-time checks that *DiskCacheReader satisfies the ctxio interfaces it
// is meant to implement.
var (
	_ ctxio.ReaderAt = (*DiskCacheReader)(nil)
	_ ctxio.Reader   = (*DiskCacheReader)(nil)
	_ ctxio.Closer   = (*DiskCacheReader)(nil)
)

// NewDiskCacheReader returns a DiskCacheReader that reads from r and mirrors
// every byte it consumes into a fresh temporary file, so the data can later be
// re-read with ReadAt. The caller must call Close to close and delete the
// temporary file.
func NewDiskCacheReader(r ctxio.Reader) (*DiskCacheReader, error) {
	// Create the cache file directly in the system temp directory (honouring
	// $TMPDIR via os.TempDir) instead of inside a per-reader subdirectory:
	// Close only removes the file itself, so an intermediate directory would
	// be leaked on every reader (and also whenever file creation failed).
	fr, err := os.CreateTemp("", "tstor-dtb_tmp")
	if err != nil {
		return nil, err
	}

	tr := ctxio.TeeReader(r, ctxio.WrapIoWriter(fr))
	return &DiskCacheReader{fr: fr, tr: tr}, nil
}

// ReadAt serves p from the cache file at offset off. If the requested range
// extends past what has been pulled from the source so far, the missing bytes
// are first drained through the tee (which appends them to the cache file).
// Bytes drained here are consumed from the source and will not be returned by
// a later Read. As with os.File.ReadAt, a short read comes with a non-nil
// error (io.EOF when the source ended before off+len(p)).
func (dtr *DiskCacheReader) ReadAt(ctx context.Context, p []byte, off int64) (int, error) {
	dtr.m.Lock()
	defer dtr.m.Unlock()

	// end is the exclusive upper bound of the requested range.
	end := off + int64(len(p))

	// dtr.to is the number of bytes pulled through the tee, i.e. the current
	// size of the cache file; both Read and ReadAt advance it. Only fetch what
	// the cache is still missing, so repeated or overlapping ReadAt calls (and
	// interleaved Read calls) neither under-fill nor over-drain the source.
	if end > dtr.to {
		w, err := ctxio.CopyN(ctx, ctxio.Discard, dtr.tr, end-dtr.to)
		dtr.to += w
		// io.EOF only means the source is shorter than requested; the file
		// read below reports the resulting short read to the caller.
		if err != nil && err != io.EOF {
			return 0, err
		}
	}

	return dtr.fr.ReadAt(p, off)
}

// Read pulls the next bytes straight from the source through the tee, so
// everything returned here is also appended to the cache file and becomes
// available to ReadAt.
func (dtr *DiskCacheReader) Read(ctx context.Context, p []byte) (n int, err error) {
	dtr.m.Lock()
	defer dtr.m.Unlock()

	// The cache file is not consulted; the tee keeps it up to date.
	n, err = dtr.tr.Read(ctx, p)
	dtr.to += int64(n)

	return n, err
}

// Close closes the cache file and deletes it from disk. Removal is attempted
// even when closing fails, so a close error does not leave the temporary file
// behind; if both steps fail, the close error is returned. The source reader
// passed to NewDiskCacheReader is not closed.
func (dtr *DiskCacheReader) Close(ctx context.Context) error {
	closeErr := dtr.fr.Close()
	removeErr := os.Remove(dtr.fr.Name())

	if closeErr != nil {
		return closeErr
	}
	return removeErr
}