mirror of
https://github.com/superseriousbusiness/gotosocial
synced 2024-12-29 22:23:10 +00:00
202 lines
5.3 KiB
Go
202 lines
5.3 KiB
Go
|
package pgproto3
|
||
|
|
||
|
import (
|
||
|
"encoding/binary"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
)
|
||
|
|
||
|
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
|
||
|
type Frontend struct {
|
||
|
cr ChunkReader
|
||
|
w io.Writer
|
||
|
|
||
|
// Backend message flyweights
|
||
|
authenticationOk AuthenticationOk
|
||
|
authenticationCleartextPassword AuthenticationCleartextPassword
|
||
|
authenticationMD5Password AuthenticationMD5Password
|
||
|
authenticationSASL AuthenticationSASL
|
||
|
authenticationSASLContinue AuthenticationSASLContinue
|
||
|
authenticationSASLFinal AuthenticationSASLFinal
|
||
|
backendKeyData BackendKeyData
|
||
|
bindComplete BindComplete
|
||
|
closeComplete CloseComplete
|
||
|
commandComplete CommandComplete
|
||
|
copyBothResponse CopyBothResponse
|
||
|
copyData CopyData
|
||
|
copyInResponse CopyInResponse
|
||
|
copyOutResponse CopyOutResponse
|
||
|
copyDone CopyDone
|
||
|
dataRow DataRow
|
||
|
emptyQueryResponse EmptyQueryResponse
|
||
|
errorResponse ErrorResponse
|
||
|
functionCallResponse FunctionCallResponse
|
||
|
noData NoData
|
||
|
noticeResponse NoticeResponse
|
||
|
notificationResponse NotificationResponse
|
||
|
parameterDescription ParameterDescription
|
||
|
parameterStatus ParameterStatus
|
||
|
parseComplete ParseComplete
|
||
|
readyForQuery ReadyForQuery
|
||
|
rowDescription RowDescription
|
||
|
portalSuspended PortalSuspended
|
||
|
|
||
|
bodyLen int
|
||
|
msgType byte
|
||
|
partialMsg bool
|
||
|
authType uint32
|
||
|
}
|
||
|
|
||
|
// NewFrontend creates a new Frontend.
|
||
|
func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
|
||
|
return &Frontend{cr: cr, w: w}
|
||
|
}
|
||
|
|
||
|
// Send sends a message to the backend.
|
||
|
func (f *Frontend) Send(msg FrontendMessage) error {
|
||
|
_, err := f.w.Write(msg.Encode(nil))
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// translateEOFtoErrUnexpectedEOF maps io.EOF to io.ErrUnexpectedEOF and
// returns every other error, including nil, unchanged.
//
// Receive applies it to both its header and body reads, so callers see
// io.ErrUnexpectedEOF rather than io.EOF whenever the reader runs out of data.
func translateEOFtoErrUnexpectedEOF(err error) error {
	switch err {
	case io.EOF:
		return io.ErrUnexpectedEOF
	default:
		return err
	}
}
|
||
|
|
||
|
// Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
|
||
|
func (f *Frontend) Receive() (BackendMessage, error) {
|
||
|
if !f.partialMsg {
|
||
|
header, err := f.cr.Next(5)
|
||
|
if err != nil {
|
||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||
|
}
|
||
|
|
||
|
f.msgType = header[0]
|
||
|
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||
|
f.partialMsg = true
|
||
|
}
|
||
|
|
||
|
msgBody, err := f.cr.Next(f.bodyLen)
|
||
|
if err != nil {
|
||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||
|
}
|
||
|
|
||
|
f.partialMsg = false
|
||
|
|
||
|
var msg BackendMessage
|
||
|
switch f.msgType {
|
||
|
case '1':
|
||
|
msg = &f.parseComplete
|
||
|
case '2':
|
||
|
msg = &f.bindComplete
|
||
|
case '3':
|
||
|
msg = &f.closeComplete
|
||
|
case 'A':
|
||
|
msg = &f.notificationResponse
|
||
|
case 'c':
|
||
|
msg = &f.copyDone
|
||
|
case 'C':
|
||
|
msg = &f.commandComplete
|
||
|
case 'd':
|
||
|
msg = &f.copyData
|
||
|
case 'D':
|
||
|
msg = &f.dataRow
|
||
|
case 'E':
|
||
|
msg = &f.errorResponse
|
||
|
case 'G':
|
||
|
msg = &f.copyInResponse
|
||
|
case 'H':
|
||
|
msg = &f.copyOutResponse
|
||
|
case 'I':
|
||
|
msg = &f.emptyQueryResponse
|
||
|
case 'K':
|
||
|
msg = &f.backendKeyData
|
||
|
case 'n':
|
||
|
msg = &f.noData
|
||
|
case 'N':
|
||
|
msg = &f.noticeResponse
|
||
|
case 'R':
|
||
|
var err error
|
||
|
msg, err = f.findAuthenticationMessageType(msgBody)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
case 's':
|
||
|
msg = &f.portalSuspended
|
||
|
case 'S':
|
||
|
msg = &f.parameterStatus
|
||
|
case 't':
|
||
|
msg = &f.parameterDescription
|
||
|
case 'T':
|
||
|
msg = &f.rowDescription
|
||
|
case 'V':
|
||
|
msg = &f.functionCallResponse
|
||
|
case 'W':
|
||
|
msg = &f.copyBothResponse
|
||
|
case 'Z':
|
||
|
msg = &f.readyForQuery
|
||
|
default:
|
||
|
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
|
||
|
}
|
||
|
|
||
|
err = msg.Decode(msgBody)
|
||
|
return msg, err
|
||
|
}
|
||
|
|
||
|
// Authentication message type constants.
//
// Each value is the auth request code carried in the first four bytes of an
// Authentication ('R') backend message body. See PostgreSQL's
// src/include/libpq/pqcomm.h for all constants.
const (
	AuthTypeOk                = 0  // authentication succeeded
	AuthTypeCleartextPassword = 3  // cleartext password requested
	AuthTypeMD5Password       = 5  // MD5-hashed password requested
	AuthTypeSCMCreds          = 6  // not implemented by Frontend
	AuthTypeGSS               = 7  // not implemented by Frontend
	AuthTypeGSSCont           = 8  // not implemented by Frontend
	AuthTypeSSPI              = 9  // not implemented by Frontend
	AuthTypeSASL              = 10 // SASL exchange start
	AuthTypeSASLContinue      = 11 // SASL exchange continuation
	AuthTypeSASLFinal         = 12 // SASL exchange final message
)
|
||
|
|
||
|
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
|
||
|
if len(src) < 4 {
|
||
|
return nil, errors.New("authentication message too short")
|
||
|
}
|
||
|
f.authType = binary.BigEndian.Uint32(src[:4])
|
||
|
|
||
|
switch f.authType {
|
||
|
case AuthTypeOk:
|
||
|
return &f.authenticationOk, nil
|
||
|
case AuthTypeCleartextPassword:
|
||
|
return &f.authenticationCleartextPassword, nil
|
||
|
case AuthTypeMD5Password:
|
||
|
return &f.authenticationMD5Password, nil
|
||
|
case AuthTypeSCMCreds:
|
||
|
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
|
||
|
case AuthTypeGSS:
|
||
|
return nil, errors.New("AuthTypeGSS is unimplemented")
|
||
|
case AuthTypeGSSCont:
|
||
|
return nil, errors.New("AuthTypeGSSCont is unimplemented")
|
||
|
case AuthTypeSSPI:
|
||
|
return nil, errors.New("AuthTypeSSPI is unimplemented")
|
||
|
case AuthTypeSASL:
|
||
|
return &f.authenticationSASL, nil
|
||
|
case AuthTypeSASLContinue:
|
||
|
return &f.authenticationSASLContinue, nil
|
||
|
case AuthTypeSASLFinal:
|
||
|
return &f.authenticationSASLFinal, nil
|
||
|
default:
|
||
|
return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// GetAuthType returns the authType used in the current state of the frontend.
|
||
|
// See SetAuthType for more information.
|
||
|
func (f *Frontend) GetAuthType() uint32 {
|
||
|
return f.authType
|
||
|
}
|