package ipmi

import (
	"context"
	"fmt"
)

const (
	// RmcpVersion is the value of the RMCP header Version byte. 06h denotes
	// RMCP version 1.0, which is the version IPMI-over-LAN uses.
	RmcpVersion uint8 = 6

	// Helpers for the ACK/Normal bit of the RMCP "Class of Message" byte.
	// NOTE(review): RMCP_TYPE_ACK is 0x01 although RMCP_TYPE_MASK selects
	// bit 7 (0x80); confirm which value callers expect before relying on it.
	RMCP_TYPE_MASK = 1 << 7
	RMCP_TYPE_NORM = 0
	RMCP_TYPE_ACK  = 1
)

// Rmcp holds the data that will be sent over UDP: an RMCP header followed by
// either an ASF message or an IPMI v1.5 / v2.0 session.
//
// Pack writes RmcpHeader (which must be non-nil) and then every non-nil
// section in field order: ASF, Session15, Session20. Unpack always sets
// RmcpHeader. It then sets ASF when the header's message class is ASF.
// Otherwise it sets Session20 or Session15, depending on the byte that
// follows the header.
type Rmcp struct {
	// Multi-byte RMCP/ASF fields are specified as being transmitted in "Network Byte Order" - meaning most-significant byte first.
	// RMCP and ASF-specified fields are therefore transferred **most-significant byte first**.
	RmcpHeader *RmcpHeader

	// Multi-byte RMCP/ASF fields are specified as being transmitted in "Network Byte Order".
	// Present only for messages of class ASF.
	ASF *ASF

	// The IPMI convention is to transfer multi-byte numeric fields least-significant Byte first. Therefore, unless otherwise specified:
	// Data in the IPMI Session Header and IPMI Message fields are transmitted **least-significant byte first**.
	// Session15 is used when the byte after the RMCP header is not AuthTypeRMCPPlus;
	// Session20 (RMCP+) is used when it is.
	Session15 *Session15
	Session20 *Session20
}

// Pack serializes the packet: the RMCP header first, then each non-nil
// section in the order ASF, IPMI v1.5 session, IPMI v2.0 session.
// RmcpHeader must be non-nil.
func (r *Rmcp) Pack() []byte {
	packed := r.RmcpHeader.Pack()

	if asf := r.ASF; asf != nil {
		packed = append(packed, asf.Pack()...)
	}
	if s15 := r.Session15; s15 != nil {
		packed = append(packed, s15.Pack()...)
	}
	if s20 := r.Session20; s20 != nil {
		packed = append(packed, s20.Pack()...)
	}

	return packed
}

// Unpack decodes an RMCP packet from msg into r.
//
// It always decodes the 4-byte RMCP header first. If the header's message
// class is ASF, the remainder is decoded as an ASF message. Otherwise the byte
// right after the header (the session auth type / format byte) selects the
// session version: AuthTypeRMCPPlus means an IPMI v2.0 session, and anything
// else means an IPMI v1.5 session.
//
// Once the header is decoded, any ASF or session section left over from an
// earlier Unpack on the same value is cleared. After a successful call exactly
// one of r.ASF, r.Session15 and r.Session20 is therefore non-nil.
func (r *Rmcp) Unpack(msg []byte) error {
	if len(msg) < 4 {
		return ErrUnpackedDataTooShortWith(len(msg), 4)
	}
	rmcpHeader := &RmcpHeader{}
	if err := rmcpHeader.Unpack(msg[:4]); err != nil {
		return fmt.Errorf("unpack RmcpHeader failed, err: %w", err)
	}
	r.RmcpHeader = rmcpHeader

	// Drop sections decoded by a previous call. Otherwise a reused Rmcp could
	// carry, e.g., a stale ASF next to a freshly decoded session.
	r.ASF, r.Session15, r.Session20 = nil, nil, nil

	if len(msg) < 4+1 {
		return fmt.Errorf("msg length too short, no session inside")
	}
	body := msg[4:]

	if r.RmcpHeader.MessageClass == MessageClassASF {
		asf := &ASF{}
		if err := asf.Unpack(body); err != nil {
			return fmt.Errorf("unpack ASF failed, err: %w", err)
		}
		r.ASF = asf
		return nil
	}

	if body[0] == byte(AuthTypeRMCPPlus) {
		// IPMI 2.0
		s20 := &Session20{}
		if err := s20.Unpack(body); err != nil {
			return fmt.Errorf("unpack IPMI 2.0 Session failed, err: %w", err)
		}
		r.Session20 = s20
		return nil
	}

	// IPMI 1.5
	s15 := &Session15{}
	if err := s15.Unpack(body); err != nil {
		return fmt.Errorf("unpack IPMI 1.5 Session failed, err: %w", err)
	}
	r.Session15 = s15
	return nil
}

// RmcpHeader represents the RMCP Message Header.
// 13.1.3
//
// On the wire it is 4 bytes: Version, one reserved byte, SequenceNumber, and
// the "Class of Message" byte. Pack and Unpack split that last byte into
// ACKFlag (bit 7) and MessageClass.
type RmcpHeader struct {
	// 06h = RMCP Version 1.0
	// IPMI-over-LAN uses version 1 of the RMCP protocol and packet format
	Version uint8

	// RMCP Messages with class=IPMI should be sent with an RMCP Sequence Number of FFh
	// to indicate that an RMCP ACK message should not be generated by the message receiver.
	SequenceNumber uint8

	// Class of Message byte:
	// This field identifies the format of the messages that follow this header.
	// All messages of class ASF (6) conform to the formats defined in this
	// specification and can be extended via an OEM IANA.
	// Bit 7 RMCP ACK
	// 0 - Normal RMCP message
	// 1 - RMCP ACK message
	ACKFlag bool
	// Bit 6:5 Reserved
	// Bit 4:0 Message Class
	// 0-5 = Reserved
	// 6 = ASF
	// 7 = IPMI
	// 8 = OEM defined
	// all other = Reserved
	MessageClass MessageClass // ASF, IPMI or OEM; the ACK bit is held in ACKFlag, not here
}

// NewRmcpHeader returns a header for IPMI-class messages. It uses RMCP
// version 1.0 (RmcpVersion) and sequence number FFh, which tells the receiver
// not to generate an RMCP ACK. The ACK flag is left cleared.
func NewRmcpHeader() *RmcpHeader {
	h := new(RmcpHeader)
	h.Version = RmcpVersion
	h.SequenceNumber = 0xff
	h.MessageClass = MessageClassIPMI
	return h
}

// NewRmcpHeaderASF returns a header for ASF-class messages (such as the RMCP
// ping). It uses RMCP version 1.0 (RmcpVersion), sequence number FFh and a
// cleared ACK flag.
func NewRmcpHeaderASF() *RmcpHeader {
	h := new(RmcpHeader)
	h.Version = RmcpVersion
	h.SequenceNumber = 0xff
	h.MessageClass = MessageClassASF
	return h
}

// Pack serializes the header into its 4-byte wire form: Version, a reserved
// zero byte, SequenceNumber, and the Class of Message byte (bit 7 = ACK flag,
// bits 6:5 reserved, bits 4:0 = message class).
func (r *RmcpHeader) Pack() []byte {
	msg := make([]byte, 4)
	packUint8(r.Version, msg, 0)
	// msg[1] is reserved and stays zero.
	packUint8(r.SequenceNumber, msg, 2)

	// Only bits 4:0 carry the class. Masking keeps an out-of-range
	// MessageClass from setting the reserved bits or the ACK bit.
	classOfMessage := uint8(r.MessageClass) & 0x1f
	if r.ACKFlag {
		classOfMessage |= 0x80
	}
	packUint8(classOfMessage, msg, 3)
	return msg
}

// Unpack decodes the 4-byte RMCP header at the start of msg.
//
// The Class of Message byte carries the ACK flag in bit 7, reserved bits in
// 6:5 and the message class in bits 4:0. Only bits 4:0 are kept in
// MessageClass, so a sender that sets the reserved bits does not produce an
// unrecognized class.
func (r *RmcpHeader) Unpack(msg []byte) error {
	if len(msg) < 4 {
		return ErrUnpackedDataTooShortWith(len(msg), 4)
	}

	// Bounds were checked above, so the helper errors are safe to discard.
	r.Version, _, _ = unpackUint8(msg, 0)
	// byte 1 is reserved
	r.SequenceNumber, _, _ = unpackUint8(msg, 2)

	classOfMessage, _, _ := unpackUint8(msg, 3)
	r.ACKFlag = isBit7Set(classOfMessage)
	r.MessageClass = MessageClass(classOfMessage & 0x1f) // bits 4:0 only
	return nil
}

// MessageType identifies a message under RMCP. For the RMCP types it is
// formed by OR-ing the ACK/Normal bit with the message class (see
// RmcpHeader.MessageType). MessageTypePing is the value BuildRmcpRequest
// writes into the ASF MessageType field of a ping.
type MessageType uint8

// Values of bit 7 (ACK/Normal) of the RMCP Class of Message byte.
const (
	MessageACKBit    uint8 = 1 << 7
	MessageNormalBit uint8 = 0
)

const (
	MessageTypeUndefined MessageType = 0
	MessageTypePing      MessageType = 0x80
	MessageTypeRMCPACK   MessageType = 0x80 + 0x06 // ACK bit | class ASF
	MessageTypeASF       MessageType = 0x06        // normal | class ASF
	MessageTypeIPMI      MessageType = 0x07        // normal | class IPMI
	MessageTypeOEM       MessageType = 0x08        // normal | class OEM
)

// MessageType combines the ACK/Normal bit and the message class to identify
// the type of message carried under RMCP (IPMI spec section 13, "Message Type
// Determination Under RMCP").
//
// An ACK is only recognized for class ASF (MessageTypeRMCPACK); any other ACK
// yields MessageTypeUndefined. Normal messages of class ASF and OEM map to
// their own types, and every other class is reported as MessageTypeIPMI.
func (r *RmcpHeader) MessageType() MessageType {
	mc := r.MessageClass

	switch {
	case r.ACKFlag && mc == MessageClassASF:
		return MessageTypeRMCPACK
	case r.ACKFlag:
		return MessageTypeUndefined
	case mc == MessageClassASF:
		return MessageTypeASF
	case mc == MessageClassOEM:
		return MessageTypeOEM
	default:
		// Covers MessageClassIPMI.
		// NOTE(review): reserved classes also land here and are reported as
		// IPMI rather than MessageTypeUndefined — confirm this is intended.
		return MessageTypeIPMI
	}
}

// MessageClass is the message class taken from bits 4:0 of the RMCP
// Class of Message header byte.
type MessageClass uint8

const (
	// 0-5 Reserved

	MessageClassASF  = 6
	MessageClassIPMI = 7
	MessageClassOEM  = 8

	// 9-31 Reserved (the class field is 5 bits wide)
)

// NormalACKFlag reports whether bit 7, the RMCP ACK bit, is set in mc.
//
// The mask is a single literal on purpose. In Go, & and << have equal
// precedence, so `x & 1 << 7` means `(x & 1) << 7`, which would test bit 0.
func (mc MessageClass) NormalACKFlag() bool {
	return uint8(mc)&0x80 != 0
}

// RmcpAckMessage describes an RMCP ACK message. Its fields are copied from
// the message being acknowledged.
// 13.2.1 RMCP ACK Messages
//
// NOTE(review): this type is not packed, unpacked or referenced anywhere in
// this file — confirm whether it is used elsewhere in the package.
type RmcpAckMessage struct {
	// Copied from received message
	Version uint8

	// Copied from received message
	SequenceNumber uint8

	// [7] - Set to 1 to indicate ACK packet
	// [6:0] - Copied from received message.
	ACKFlag      bool
	MessageClass MessageClass // Can be IPMI Messages, ASF, OEM
}

// ASF is the ASF message that follows the RMCP header when the header's
// message class is ASF. Its multi-byte fields are in network byte order
// (most-significant byte first). The wire layout written by Pack is:
// IANA (4 bytes), MessageType (1), MessageTag (1), one reserved byte,
// DataLength (1), then Data.
type ASF struct {
	IANA        uint32 // IANA Enterprise Number; BuildRmcpRequest sends 4542
	MessageType uint8  // e.g. MessageTypePing for the ping built by BuildRmcpRequest

	// 0-FEh, generated by remote console. This is an RMCP version of a sequence number.
	// Values 0-254 (0-FEh) are used for RMCP request/response messages.
	// 255 indicates the message is unidirectional and not part of a request/response pair.
	MessageTag uint8

	DataLength uint8 // number of bytes in Data; 00h for a ping with no payload

	Data []byte
}

// Pack serializes the ASF message: an 8-byte ASF header followed by Data.
// IANA is written most-significant byte first (network byte order).
// DataLength is written exactly as stored; it is not recomputed from len(Data).
func (asf *ASF) Pack() []byte {
	const hdrLen = 8
	buf := make([]byte, hdrLen+len(asf.Data))

	packUint32(asf.IANA, buf, 0) // network byte order: MSB first
	packUint8(asf.MessageType, buf, 4)
	packUint8(asf.MessageTag, buf, 5)
	// buf[6] is reserved and left as zero.
	packUint8(asf.DataLength, buf, 7)
	packBytes(asf.Data, buf, hdrLen)

	return buf
}

// Unpack decodes an ASF message from msg, which holds the bytes that follow
// the RMCP header.
//
// IANA is decoded most-significant byte first. This matches Pack and the
// RMCP/ASF network-byte-order convention. Data receives exactly DataLength
// bytes, and any trailing bytes in msg are ignored.
func (asf *ASF) Unpack(msg []byte) error {
	const hdrLen = 8
	if len(msg) < hdrLen {
		return ErrUnpackedDataTooShortWith(len(msg), hdrLen)
	}

	// Big-endian, to mirror Pack (packUint32, MSB first). Decoding it
	// little-endian would turn 4542 (0x000011BE) into 0xBE110000.
	asf.IANA = uint32(msg[0])<<24 | uint32(msg[1])<<16 | uint32(msg[2])<<8 | uint32(msg[3])
	asf.MessageType, _, _ = unpackUint8(msg, 4)
	asf.MessageTag, _, _ = unpackUint8(msg, 5)
	// byte 6 is reserved
	asf.DataLength, _, _ = unpackUint8(msg, 7)

	if len(msg) < hdrLen+int(asf.DataLength) {
		return ErrUnpackedDataTooShortWith(len(msg), hdrLen+int(asf.DataLength))
	}
	asf.Data, _, _ = unpackBytes(msg, hdrLen, int(asf.DataLength))
	return nil
}

// BuildRmcpRequest wraps reqCmd in an RMCP packet ready to be packed and sent.
//
// An *RmcpPingRequest is sent as an ASF-class message (IANA 4542, type
// MessageTypePing). Any other request is wrapped in an IPMI v2.0 session when
// c.v20 is set, and in an IPMI v1.5 session otherwise.
func (c *Client) BuildRmcpRequest(ctx context.Context, reqCmd Request) (*Rmcp, error) {
	payloadType, rawPayload, err := c.buildRawPayload(ctx, reqCmd)
	if err != nil {
		return nil, fmt.Errorf("buildRawPayload failed, err: %w", err)
	}
	c.DebugBytes("rawPayload", rawPayload, 16)

	// ASF
	if _, ok := reqCmd.(*RmcpPingRequest); ok {
		// The ASF Data Length field is one byte and must describe Data.
		// Hard-coding 0 would give a malformed packet for a non-empty payload.
		if len(rawPayload) > 0xff {
			return nil, fmt.Errorf("asf payload too long: %d bytes (max 255)", len(rawPayload))
		}
		return &Rmcp{
			RmcpHeader: NewRmcpHeaderASF(),
			ASF: &ASF{
				IANA:        4542,
				MessageType: uint8(MessageTypePing),
				MessageTag:  0,
				DataLength:  uint8(len(rawPayload)),
				Data:        rawPayload,
			},
		}, nil
	}

	// IPMI 2.0
	if c.v20 {
		session20, err := c.genSession20(payloadType, rawPayload)
		if err != nil {
			return nil, fmt.Errorf("genSession20 failed, err: %w", err)
		}
		return &Rmcp{
			RmcpHeader: NewRmcpHeader(),
			Session20:  session20,
		}, nil
	}

	// IPMI 1.5
	session15, err := c.genSession15(rawPayload)
	if err != nil {
		return nil, fmt.Errorf("genSession15 failed, err: %w", err)
	}
	return &Rmcp{
		RmcpHeader: NewRmcpHeader(),
		Session15:  session15,
	}, nil
}

// ParseRmcpResponse parses msg bytes.
// The response param should be passed as a pointer to a struct that implements the Response interface.
//
// How the packet is handled depends on what it carries:
//   - ASF messages: the ASF data is unpacked directly into response.
//   - IPMI v1.5 sessions, and IPMI v2.0 sessions with an IPMI payload: the
//     v2.0 payload is decrypted first when the session header marks it as
//     encrypted. The IPMI response is then decoded. A non-zero completion
//     code is returned as a *ResponseError; otherwise the data is unpacked
//     into response.
//   - IPMI v2.0 session-setup payloads (Open Session Response, RAKP 2,
//     RAKP 4): unpacked directly into response.
//   - Any other IPMI v2.0 payload type: reported as an error instead of
//     being silently ignored.
func (c *Client) ParseRmcpResponse(ctx context.Context, msg []byte, response Response) error {
	rmcp := &Rmcp{}
	if err := rmcp.Unpack(msg); err != nil {
		return fmt.Errorf("unpack rmcp failed, err: %w", err)
	}
	c.Debug("<<<<<< RMCP Response", rmcp)

	// unpackIPMIResponse decodes an IPMI message, turns a non-zero completion
	// code into a *ResponseError, and otherwise unpacks the data into
	// response. The v1.5 and v2.0 paths share it because they handle IPMI
	// responses identically.
	unpackIPMIResponse := func(ipmiPayload []byte) error {
		ipmiRes := IPMIResponse{}
		if err := ipmiRes.Unpack(ipmiPayload); err != nil {
			return fmt.Errorf("unpack ipmiRes failed, err: %w", err)
		}
		c.Debug("<<<< IPMI Response", ipmiRes)

		ccode := ipmiRes.CompletionCode
		if ccode != 0x00 {
			return &ResponseError{
				completionCode: CompletionCode(ccode),
				description:    fmt.Sprintf("ipmiRes CompletionCode (%#02x) is not normal: %s", ccode, StrCC(response, ccode)),
			}
		}

		// now ccode is 0x00, we can continue to deserialize response
		if err := response.Unpack(ipmiRes.Data); err != nil {
			return &ResponseError{
				completionCode: 0x00,
				description:    fmt.Sprintf("unpack response failed, err: %s", err),
			}
		}
		return nil
	}

	switch {
	case rmcp.ASF != nil:
		if int(rmcp.ASF.DataLength) != len(rmcp.ASF.Data) {
			return fmt.Errorf("asf Data Length not equal")
		}
		if err := response.Unpack(rmcp.ASF.Data); err != nil {
			return fmt.Errorf("unpack asf response failed, err: %w", err)
		}
		return nil

	case rmcp.Session15 != nil:
		return unpackIPMIResponse(rmcp.Session15.Payload)

	case rmcp.Session20 != nil:
		sessionHdr := rmcp.Session20.SessionHeader20

		switch sessionHdr.PayloadType {
		case
			PayloadTypeRmcpOpenSessionResponse,
			PayloadTypeRAKPMessage2,
			PayloadTypeRAKPMessage4:
			// Session Setup Payload Types
			if err := response.Unpack(rmcp.Session20.SessionPayload); err != nil {
				return fmt.Errorf("unpack session setup response failed, err: %w", err)
			}
			return nil

		case PayloadTypeIPMI:
			// Standard Payload Types
			ipmiPayload := rmcp.Session20.SessionPayload
			if sessionHdr.PayloadEncrypted {
				c.DebugBytes("decrypting", ipmiPayload, 16)
				d, err := c.decryptPayload(ipmiPayload)
				if err != nil {
					return fmt.Errorf("decrypt session payload failed, err: %w", err)
				}
				ipmiPayload = d
				c.DebugBytes("decrypted", ipmiPayload, 16)
			}
			return unpackIPMIResponse(ipmiPayload)

		default:
			// Without this, response would silently stay unparsed.
			return fmt.Errorf("unsupported IPMI 2.0 session payload type: %v", sessionHdr.PayloadType)
		}
	}

	return nil
}

// RmcpStatusCode is an RMCP+ / RAKP message status code.
// 13.24 RMCP+ and RAKP Message Status Codes
type RmcpStatusCode uint8

const (
	RmcpStatusCodeNoErrors                   RmcpStatusCode = 0x00
	RmcpStatusCodeNoResToCreateSess          RmcpStatusCode = 0x01
	RmcpStatusCodeInvalidSessionID           RmcpStatusCode = 0x02
	RmcpStatusCodeInvalidPayloadType         RmcpStatusCode = 0x03
	RmcpStatusCodeInvalidAuthAlg             RmcpStatusCode = 0x04
	RmcpStatusCodeInvalidIntegrityAlg        RmcpStatusCode = 0x05
	RmcpStatusCodeNoMatchingAuthPayload      RmcpStatusCode = 0x06
	RmcpStatusCodeNoMatchingIntegrityPayload RmcpStatusCode = 0x07
	RmcpStatusCodeInactiveSessionID          RmcpStatusCode = 0x08
	RmcpStatusCodeInvalidRole                RmcpStatusCode = 0x09
	RmcpStatusCodeUnauthorizedRoleOfPriLevel RmcpStatusCode = 0x0a
	RmcpStatusCodeNoResToCreateSessAtRole    RmcpStatusCode = 0x0b
	RmcpStatusCodeInvalidNameLength          RmcpStatusCode = 0x0c
	RmcpStatusCodeUnauthorizedName           RmcpStatusCode = 0x0d
	RmcpStatusCodeUnauthorizedGUID           RmcpStatusCode = 0x0e
	RmcpStatusCodeInvalidIntegrityCheckValue RmcpStatusCode = 0x0f
	RmcpStatusCodeInvalidConfidentAlg        RmcpStatusCode = 0x10
	RmcpStatusCodeNoCipherSuiteMatch         RmcpStatusCode = 0x11
	RmcpStatusCodeIllegalParameter           RmcpStatusCode = 0x12
)

// String returns the specification's description of the status code. It
// returns "Unknown" for 0x13-0xff, which the specification reserves.
//
// A switch is used instead of a map literal so that no map is allocated on
// every call.
func (c RmcpStatusCode) String() string {
	switch c {
	case RmcpStatusCodeNoErrors:
		return "No errors"
	case RmcpStatusCodeNoResToCreateSess:
		return "Insufficient resources to create a session"
	case RmcpStatusCodeInvalidSessionID:
		return "Invalid Session ID"
	case RmcpStatusCodeInvalidPayloadType:
		return "Invalid payload type"
	case RmcpStatusCodeInvalidAuthAlg:
		return "Invalid authentication algorithm"
	case RmcpStatusCodeInvalidIntegrityAlg:
		return "Invalid integrity algorithm"
	case RmcpStatusCodeNoMatchingAuthPayload:
		return "No matching authentication payload"
	case RmcpStatusCodeNoMatchingIntegrityPayload:
		return "No matching integrity payload"
	case RmcpStatusCodeInactiveSessionID:
		return "Inactive Session ID"
	case RmcpStatusCodeInvalidRole:
		return "Invalid role"
	case RmcpStatusCodeUnauthorizedRoleOfPriLevel:
		return "Unauthorized role of privilege level requested"
	case RmcpStatusCodeNoResToCreateSessAtRole:
		return "Insufficient resources to create a session at the requested role"
	case RmcpStatusCodeInvalidNameLength:
		return "Invalid name length"
	case RmcpStatusCodeUnauthorizedName:
		return "Unauthorized name"
	case RmcpStatusCodeUnauthorizedGUID:
		return "Unauthorized GUID"
	case RmcpStatusCodeInvalidIntegrityCheckValue:
		return "Invalid integrity check value"
	case RmcpStatusCodeInvalidConfidentAlg:
		return "Invalid confidentiality algorithm"
	case RmcpStatusCodeNoCipherSuiteMatch:
		return "No Cipher Suite match with proposed security algorithms"
	case RmcpStatusCodeIllegalParameter:
		return "Illegal or unrecognized parameter"
	default:
		// 0x13 - 0xff: Reserved for future definition by this specification.
		return "Unknown"
	}
}
