small refactor*
This commit is contained in:
parent
b6b541e050
commit
24a4d30275
232 changed files with 2164 additions and 1906 deletions
server/pkg/go-nfs
344
server/pkg/go-nfs/conn.go
Normal file
344
server/pkg/go-nfs/conn.go
Normal file
|
@ -0,0 +1,344 @@
|
|||
package nfs
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
xdr2 "github.com/rasky/go-xdr/xdr2"
|
||||
"github.com/willscott/go-nfs-client/nfs/rpc"
|
||||
"github.com/willscott/go-nfs-client/nfs/xdr"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInputInvalid is returned when input cannot be parsed
|
||||
ErrInputInvalid = errors.New("invalid input")
|
||||
// ErrAlreadySent is returned when writing a header/status multiple times
|
||||
ErrAlreadySent = errors.New("response already started")
|
||||
)
|
||||
|
||||
// ResponseCode is a combination of accept_stat and reject_stat.
|
||||
type ResponseCode uint32
|
||||
|
||||
// ResponseCode Codes
|
||||
const (
|
||||
ResponseCodeSuccess ResponseCode = iota
|
||||
ResponseCodeProgUnavailable
|
||||
ResponseCodeProcUnavailable
|
||||
ResponseCodeGarbageArgs
|
||||
ResponseCodeSystemErr
|
||||
ResponseCodeRPCMismatch
|
||||
ResponseCodeAuthError
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
*Server
|
||||
writeSerializer chan []byte
|
||||
net.Conn
|
||||
}
|
||||
|
||||
var tracer = otel.Tracer("git.kmsign.ru/royalcat/tstor/server/pkg/go-nfs")
|
||||
|
||||
func (c *conn) serve() {
|
||||
ctx := context.Background() // TODO implement correct timeout on serve side
|
||||
|
||||
c.writeSerializer = make(chan []byte, 1)
|
||||
go c.serializeWrites(ctx)
|
||||
|
||||
bio := bufio.NewReader(c.Conn)
|
||||
for {
|
||||
w, err := c.readRequestHeader(ctx, bio)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// Clean close.
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
Log.Tracef("request: %v", w.req)
|
||||
err = c.handle(ctx, w)
|
||||
respErr := w.finish(ctx)
|
||||
if err != nil {
|
||||
Log.Errorf("error handling req: %v", err)
|
||||
// failure to handle at a level needing to close the connection.
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
if respErr != nil {
|
||||
Log.Errorf("error sending response: %v", respErr)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) serializeWrites(ctx context.Context) {
|
||||
// todo: maybe don't need the extra buffer
|
||||
writer := bufio.NewWriter(c.Conn)
|
||||
var fragmentBuf [4]byte
|
||||
var fragmentInt uint32
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-c.writeSerializer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// prepend the fragmentation header
|
||||
fragmentInt = uint32(len(msg))
|
||||
fragmentInt |= (1 << 31)
|
||||
binary.BigEndian.PutUint32(fragmentBuf[:], fragmentInt)
|
||||
n, err := writer.Write(fragmentBuf[:])
|
||||
if n < 4 || err != nil {
|
||||
return
|
||||
}
|
||||
n, err = writer.Write(msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n < len(msg) {
|
||||
panic("todo: ensure writes complete fully.")
|
||||
}
|
||||
if err = writer.Flush(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle a request. errors from this method indicate a failure to read or
|
||||
// write on the network stream, and trigger a disconnection of the connection.
|
||||
func (c *conn) handle(ctx context.Context, w *response) (err error) {
|
||||
ctx, span := tracer.Start(ctx, fmt.Sprintf("nfs.handle.%s", NFSProcedure(w.req.Header.Proc).String()))
|
||||
defer span.End()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
} else {
|
||||
span.SetStatus(codes.Ok, "")
|
||||
}
|
||||
}()
|
||||
|
||||
handler := c.Server.handlerFor(w.req.Header.Prog, w.req.Header.Proc)
|
||||
if handler == nil {
|
||||
Log.Errorf("No handler for %d.%d", w.req.Header.Prog, w.req.Header.Proc)
|
||||
if err := w.drain(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.err(ctx, w, &ResponseCodeProcUnavailableError{})
|
||||
}
|
||||
appError := handler(ctx, w, c.Server.Handler)
|
||||
if drainErr := w.drain(ctx); drainErr != nil {
|
||||
return drainErr
|
||||
}
|
||||
if appError != nil && !w.responded {
|
||||
Log.Errorf("call to %+v failed: %v", handler, appError)
|
||||
if err := c.err(ctx, w, appError); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !w.responded {
|
||||
Log.Errorf("Handler did not indicate response status via writing or erroring")
|
||||
if err := c.err(ctx, w, &ResponseCodeSystemError{}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) err(ctx context.Context, w *response, err error) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
if w.err == nil {
|
||||
w.err = err
|
||||
}
|
||||
|
||||
if w.responded {
|
||||
return nil
|
||||
}
|
||||
|
||||
rpcErr := w.errorFmt(err)
|
||||
if writeErr := w.writeHeader(rpcErr.Code()); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
|
||||
body, _ := rpcErr.MarshalBinary()
|
||||
return w.Write(body)
|
||||
}
|
||||
|
||||
type request struct {
|
||||
xid uint32
|
||||
rpc.Header
|
||||
Body io.Reader
|
||||
}
|
||||
|
||||
func (r *request) String() string {
|
||||
if r.Header.Prog == nfsServiceID {
|
||||
return fmt.Sprintf("RPC #%d (nfs.%s)", r.xid, NFSProcedure(r.Header.Proc))
|
||||
} else if r.Header.Prog == mountServiceID {
|
||||
return fmt.Sprintf("RPC #%d (mount.%s)", r.xid, MountProcedure(r.Header.Proc))
|
||||
}
|
||||
return fmt.Sprintf("RPC #%d (%d.%d)", r.xid, r.Header.Prog, r.Header.Proc)
|
||||
}
|
||||
|
||||
type response struct {
|
||||
*conn
|
||||
writer *bytes.Buffer
|
||||
responded bool
|
||||
err error
|
||||
errorFmt func(error) RPCError
|
||||
req *request
|
||||
}
|
||||
|
||||
func (w *response) writeXdrHeader() error {
|
||||
err := xdr.Write(w.writer, &w.req.xid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
respType := uint32(1)
|
||||
err = xdr.Write(w.writer, &respType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *response) writeHeader(code ResponseCode) error {
|
||||
if w.responded {
|
||||
return ErrAlreadySent
|
||||
}
|
||||
w.responded = true
|
||||
if err := w.writeXdrHeader(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status := rpc.MsgAccepted
|
||||
if code == ResponseCodeAuthError || code == ResponseCodeRPCMismatch {
|
||||
status = rpc.MsgDenied
|
||||
}
|
||||
|
||||
err := xdr.Write(w.writer, &status)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if status == rpc.MsgAccepted {
|
||||
// Write opaque_auth header.
|
||||
err = xdr.Write(w.writer, &rpc.AuthNull)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return xdr.Write(w.writer, &code)
|
||||
}
|
||||
|
||||
// Write a response to an xdr message
|
||||
func (w *response) Write(dat []byte) error {
|
||||
if !w.responded {
|
||||
if err := w.writeHeader(ResponseCodeSuccess); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
acc := 0
|
||||
for acc < len(dat) {
|
||||
n, err := w.writer.Write(dat[acc:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
acc += n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// drain reads the rest of the request frame if not consumed by the handler.
|
||||
func (w *response) drain(ctx context.Context) error {
|
||||
if reader, ok := w.req.Body.(*io.LimitedReader); ok {
|
||||
if reader.N == 0 {
|
||||
return nil
|
||||
}
|
||||
// todo: wrap body in a context reader.
|
||||
_, err := io.CopyN(io.Discard, w.req.Body, reader.N)
|
||||
if err == nil || err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
func (w *response) finish(ctx context.Context) error {
|
||||
select {
|
||||
case w.conn.writeSerializer <- w.writer.Bytes():
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) readRequestHeader(ctx context.Context, reader *bufio.Reader) (w *response, err error) {
|
||||
fragment, err := xdr.ReadUint32(reader)
|
||||
if err != nil {
|
||||
if xdrErr, ok := err.(*xdr2.UnmarshalError); ok {
|
||||
if xdrErr.Err == io.EOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if fragment&(1<<31) == 0 {
|
||||
Log.Warnf("Warning: haven't implemented fragment reconstruction.\n")
|
||||
return nil, ErrInputInvalid
|
||||
}
|
||||
reqLen := fragment - uint32(1<<31)
|
||||
if reqLen < 40 {
|
||||
return nil, ErrInputInvalid
|
||||
}
|
||||
|
||||
r := io.LimitedReader{R: reader, N: int64(reqLen)}
|
||||
|
||||
xid, err := xdr.ReadUint32(&r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqType, err := xdr.ReadUint32(&r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reqType != 0 { // 0 = request, 1 = response
|
||||
return nil, ErrInputInvalid
|
||||
}
|
||||
|
||||
req := request{
|
||||
xid,
|
||||
rpc.Header{},
|
||||
&r,
|
||||
}
|
||||
if err = xdr.Read(&r, &req.Header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
w = &response{
|
||||
conn: c,
|
||||
req: &req,
|
||||
errorFmt: basicErrorFormatter,
|
||||
// TODO: use a pool for these.
|
||||
writer: bytes.NewBuffer([]byte{}),
|
||||
}
|
||||
return w, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue