package nfs

import (
	"bufio"
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"

	xdr2 "github.com/rasky/go-xdr/xdr2"
	"github.com/willscott/go-nfs-client/nfs/rpc"
	"github.com/willscott/go-nfs-client/nfs/xdr"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/codes"
)

var (
	// ErrInputInvalid is returned when input cannot be parsed
	ErrInputInvalid = errors.New("invalid input")
	// ErrAlreadySent is returned when writing a header/status multiple times
	ErrAlreadySent = errors.New("response already started")
)

// ResponseCode is a combination of accept_stat and reject_stat.
type ResponseCode uint32

// ResponseCode Codes
const (
	ResponseCodeSuccess ResponseCode = iota
	ResponseCodeProgUnavailable
	ResponseCodeProcUnavailable
	ResponseCodeGarbageArgs
	ResponseCodeSystemErr
	ResponseCodeRPCMismatch
	ResponseCodeAuthError
)

// conn is the state of a single client connection: the server it belongs to,
// the underlying network connection, and the channel through which finished
// responses are handed to the writer goroutine.
type conn struct {
	*Server
	writeSerializer chan []byte
	net.Conn
}

var tracer = otel.Tracer("git.kmsign.ru/royalcat/tstor/pkg/go-nfs")

// serve runs the request loop for one connection: it starts the write
// serializer, then reads, handles and answers requests one at a time until the
// stream ends or fails. The connection is closed on every exit path.
func (c *conn) serve() {
	// TODO implement correct timeout on serve side
	ctx, cancel := context.WithCancel(context.Background())
	// Cancelling on return stops the serializeWrites goroutine; with a bare
	// background context it would wait on the channel forever.
	defer cancel()

	c.writeSerializer = make(chan []byte, 1)
	go func() {
		// If the writer stops (for instance after a failed write), cancel ctx
		// so a pending finish cannot block on a channel nobody drains, and
		// close the connection so a pending read in the loop below returns.
		defer cancel()
		defer c.Close()
		c.serializeWrites(ctx)
	}()

	bio := bufio.NewReader(c.Conn)
	for {
		w, err := c.readRequestHeader(ctx, bio)
		if err != nil {
			if !errors.Is(err, io.EOF) {
				// Anything but a clean close leaves the stream at an unknown
				// position, so the connection cannot be reused.
				Log.Errorf("error reading request: %v", err)
			}
			c.Close()
			return
		}
		Log.Tracef("request: %v", w.req)

		err = c.handle(ctx, w)
		// The response buffer is queued even when handling failed, so whatever
		// was written before the failure is still handed to the writer.
		respErr := w.finish(ctx)
		if err != nil {
			Log.Errorf("error handling req: %v", err)
			// failure to handle at a level needing to close the connection.
			c.Close()
			return
		}
		if respErr != nil {
			Log.Errorf("error sending response: %v", respErr)
			c.Close()
			return
		}
	}
}

// serializeWrites is the only writer on the connection. It receives complete
// response messages from c.writeSerializer, prefixes each one with a 4-byte
// big-endian header (bit 31 set, low 31 bits holding the message length), and
// flushes it to the connection. It returns when ctx is done, when the channel
// is closed, or after the first failed write.
func (c *conn) serializeWrites(ctx context.Context) {
	// todo: maybe don't need the extra buffer
	writer := bufio.NewWriter(c.Conn)
	var fragmentBuf [4]byte
	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-c.writeSerializer:
			if !ok {
				return
			}
			// prepend the fragmentation header
			binary.BigEndian.PutUint32(fragmentBuf[:], uint32(len(msg))|(1<<31))

			// bufio.Writer.Write reports an error whenever it writes fewer
			// bytes than requested, so checking the error covers short writes.
			_, err := writer.Write(fragmentBuf[:])
			if err == nil {
				_, err = writer.Write(msg)
			}
			if err == nil {
				err = writer.Flush()
			}
			if err != nil {
				Log.Errorf("error writing response: %v", err)
				return
			}
		}
	}
}

// Handle a request. errors from this method indicate a failure to read or
// write on the network stream, and trigger a disconnection of the connection.
//
// The request is dispatched to the handler registered for its program and
// procedure. Whatever part of the request body the handler left unread is
// drained afterwards. If the handler neither wrote a response nor produced an
// error that was turned into one, a system error response is written.
func (c *conn) handle(ctx context.Context, w *response) (err error) {
	ctx, span := tracer.Start(ctx, fmt.Sprintf("nfs.handle.%s", NFSProcedure(w.req.Header.Proc).String()))
	defer span.End()
	// Record the final outcome (read from the named result) on the span.
	defer func() {
		if err != nil {
			span.SetStatus(codes.Error, err.Error())
		} else {
			span.SetStatus(codes.Ok, "")
		}
	}()

	handler := c.Server.handlerFor(w.req.Header.Prog, w.req.Header.Proc)
	if handler == nil {
		Log.Errorf("No handler for %d.%d", w.req.Header.Prog, w.req.Header.Proc)
		if drainErr := w.drain(ctx); drainErr != nil {
			return drainErr
		}
		return c.err(ctx, w, &ResponseCodeProcUnavailableError{})
	}

	appError := handler(ctx, w, c.Server.Handler)
	// Consume the rest of the frame so the next request starts at a frame
	// boundary, regardless of how much the handler read.
	if drainErr := w.drain(ctx); drainErr != nil {
		return drainErr
	}
	if appError != nil && !w.responded {
		// Log the request rather than the handler func value, which would
		// only print as an address.
		Log.Errorf("call to %v failed: %v", w.req, appError)
		if writeErr := c.err(ctx, w, appError); writeErr != nil {
			return writeErr
		}
	}
	if !w.responded {
		Log.Errorf("Handler did not indicate response status via writing or erroring")
		if writeErr := c.err(ctx, w, &ResponseCodeSystemError{}); writeErr != nil {
			return writeErr
		}
	}
	return nil
}

// err records err on the response and, unless a response has already been
// started, writes it to the response buffer as an RPC error (header followed by
// the marshalled error body). It does nothing and returns nil when ctx is
// already done. A non-nil return means the response could not be written.
func (c *conn) err(ctx context.Context, w *response, err error) error {
	select {
	case <-ctx.Done():
		return nil
	default:
	}

	// Only the first error is kept.
	if w.err == nil {
		w.err = err
	}
	if w.responded {
		return nil
	}

	rpcErr := w.errorFmt(err)
	if writeErr := w.writeHeader(rpcErr.Code()); writeErr != nil {
		return writeErr
	}

	// NOTE(review): the marshalling error is discarded, so a failure here
	// yields a response carrying only the header — confirm this is intended.
	body, _ := rpcErr.MarshalBinary()
	return w.Write(body)
}

// request is a single decoded RPC call: its transaction id, the call header,
// and a reader over the not yet consumed remainder of the request frame.
type request struct {
	xid uint32
	rpc.Header
	Body io.Reader
}

// String describes the request for logging, naming the procedure for the NFS
// and mount programs and falling back to numeric ids for any other program.
func (r *request) String() string {
	switch r.Header.Prog {
	case nfsServiceID:
		return fmt.Sprintf("RPC #%d (nfs.%s)", r.xid, NFSProcedure(r.Header.Proc))
	case mountServiceID:
		return fmt.Sprintf("RPC #%d (mount.%s)", r.xid, MountProcedure(r.Header.Proc))
	default:
		return fmt.Sprintf("RPC #%d (%d.%d)", r.xid, r.Header.Prog, r.Header.Proc)
	}
}

// response accumulates the reply to one request in memory. Nothing reaches the
// network until finish hands the buffer to the connection's writer.
type response struct {
	*conn
	writer    *bytes.Buffer
	responded bool // set once the reply header has been started
	err       error
	errorFmt  func(error) RPCError
	req       *request
}

// writeXdrHeader writes the start of an RPC reply: the transaction id of the
// request followed by the message type 1 (reply).
func (w *response) writeXdrHeader() error {
	if err := xdr.Write(w.writer, &w.req.xid); err != nil {
		return err
	}
	respType := uint32(1)
	return xdr.Write(w.writer, &respType)
}

// writeHeader writes the reply header for the given code. It may be called
// only once per response; later calls return ErrAlreadySent. Auth errors and
// RPC version mismatches are sent as denied replies, every other code as an
// accepted reply carrying a null verifier.
func (w *response) writeHeader(code ResponseCode) error {
	if w.responded {
		return ErrAlreadySent
	}
	// Marked before writing, so a failed header write is not retried.
	w.responded = true
	if err := w.writeXdrHeader(); err != nil {
		return err
	}

	status := rpc.MsgAccepted
	if code == ResponseCodeAuthError || code == ResponseCodeRPCMismatch {
		status = rpc.MsgDenied
	}
	if err := xdr.Write(w.writer, &status); err != nil {
		return err
	}

	if status == rpc.MsgAccepted {
		// Write opaque_auth header.
		if err := xdr.Write(w.writer, &rpc.AuthNull); err != nil {
			return err
		}
	}

	return xdr.Write(w.writer, &code)
}

// Write a response to an xdr message
//
// The first call on a response that has no header yet writes a success header
// before the data.
func (w *response) Write(dat []byte) error {
	if !w.responded {
		if err := w.writeHeader(ResponseCodeSuccess); err != nil {
			return err
		}
	}

	// bytes.Buffer.Write always consumes its whole argument, so a single call
	// is enough.
	_, err := w.writer.Write(dat)
	return err
}

// drain reads the rest of the request frame if not consumed by the handler.
func (w *response) drain(ctx context.Context) error { if reader, ok := w.req.Body.(*io.LimitedReader); ok { if reader.N == 0 { return nil } // todo: wrap body in a context reader. _, err := io.CopyN(io.Discard, w.req.Body, reader.N) if err == nil || err == io.EOF { return nil } return err } return io.ErrUnexpectedEOF } func (w *response) finish(ctx context.Context) error { select { case w.conn.writeSerializer <- w.writer.Bytes(): return nil case <-ctx.Done(): return ctx.Err() } } func (c *conn) readRequestHeader(ctx context.Context, reader *bufio.Reader) (w *response, err error) { fragment, err := xdr.ReadUint32(reader) if err != nil { if xdrErr, ok := err.(*xdr2.UnmarshalError); ok { if xdrErr.Err == io.EOF { return nil, io.EOF } } return nil, err } if fragment&(1<<31) == 0 { Log.Warnf("Warning: haven't implemented fragment reconstruction.\n") return nil, ErrInputInvalid } reqLen := fragment - uint32(1<<31) if reqLen < 40 { return nil, ErrInputInvalid } r := io.LimitedReader{R: reader, N: int64(reqLen)} xid, err := xdr.ReadUint32(&r) if err != nil { return nil, err } reqType, err := xdr.ReadUint32(&r) if err != nil { return nil, err } if reqType != 0 { // 0 = request, 1 = response return nil, ErrInputInvalid } req := request{ xid, rpc.Header{}, &r, } if err = xdr.Read(&r, &req.Header); err != nil { return nil, err } w = &response{ conn: c, req: &req, errorFmt: basicErrorFormatter, // TODO: use a pool for these. writer: bytes.NewBuffer([]byte{}), } return w, nil }