xmpp/jid/jid.go

451 lines
13 KiB
Go

// Copyright 2014 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.
package jid
import (
"bytes"
"encoding/xml"
"errors"
"net"
"strconv"
"strings"
"unicode/utf8"
"golang.org/x/net/idna"
"golang.org/x/text/secure/precis"
)
// Errors returned while splitting, constructing, or modifying JIDs.
var (
	// errForbiddenLocalpart is returned by localChecks when the localpart
	// contains one of the characters RFC 7622 §3.3.1 disallows.
	errForbiddenLocalpart = errors.New("localpart contains forbidden characters")
	// errInvalidDomainLen is returned when the normalized domainpart is empty
	// or longer than 1023 bytes.
	errInvalidDomainLen = errors.New("the domainpart must be between 1 and 1023 bytes")
	// errInvalidUTF8 is returned when any part of a JID is not valid UTF-8.
	errInvalidUTF8 = errors.New("JID contains invalid UTF-8")
	// errLongLocalpart is returned when the prepared localpart exceeds 1023 bytes.
	errLongLocalpart = errors.New("the localpart must be smaller than 1024 bytes")
	// errLongResourcepart is returned when the prepared resourcepart exceeds
	// 1023 bytes.
	errLongResourcepart = errors.New("the resourcepart must be smaller than 1024 bytes")
	// errNoLocalpart is returned when a string JID begins with '@'.
	errNoLocalpart = errors.New("the localpart must be larger than 0 bytes")
	// errNoResourcepart is returned when a string JID ends with '/'.
	errNoResourcepart = errors.New("the resourcepart must be larger than 0 bytes")
)
// JID represents an XMPP address (Jabber ID) comprising a localpart,
// domainpart, and resourcepart. All parts of a JID are guaranteed to be valid
// UTF-8 and will be represented in their canonical form which gives comparison
// the greatest chance of succeeding.
//
// Internally the three parts are stored back to back in a single byte slice,
// without the '@' and '/' separators; String re-inserts them.
type JID struct {
// locallen is the length in bytes of the localpart at the start of data.
// It is 0 when the JID has no localpart.
locallen int
// domainlen is the length in bytes of the domainpart, which immediately
// follows the localpart in data.
domainlen int
// data holds localpart, domainpart, and resourcepart concatenated; the
// resourcepart is everything from locallen+domainlen to the end.
data []byte
}
// Parse constructs a new JID from the given string representation.
//
// The string is first split on its '/' and '@' separators (see SplitString)
// and the resulting parts are then validated and canonicalized by New. Any
// error from either step is returned together with the zero JID.
func Parse(s string) (JID, error) {
	local, domain, resource, splitErr := SplitString(s)
	if splitErr != nil {
		return JID{}, splitErr
	}
	return New(local, domain, resource)
}
// MustParse is like Parse but panics if the JID cannot be parsed.
// It simplifies safe initialization of JIDs from known-good constant strings.
//
// The panic message quotes the input with backquotes when that is possible,
// and with a Go double-quoted string otherwise.
func MustParse(s string) JID {
	j, err := Parse(s)
	if err == nil {
		return j
	}
	quoted := strconv.Quote(s)
	if strconv.CanBackquote(s) {
		quoted = "`" + s + "`"
	}
	panic(`jid: Parse(` + quoted + `): ` + err.Error())
}
// New constructs a new JID from the given localpart, domainpart, and
// resourcepart.
//
// The steps run in this order, and the first error wins:
//   - UTF-8 validation of the localpart and resourcepart;
//   - domainpart normalization (normalizeDomainpart);
//   - localpart preparation with precis.UsernameCaseMapped;
//   - resourcepart preparation with precis.OpaqueString;
//   - the extra localpart rules (localChecks);
//   - the resourcepart length limit (resourceChecks).
//
// An empty localpart or resourcepart means "absent" and skips its PRECIS
// step.
func New(localpart, domainpart, resourcepart string) (JID, error) {
	// Reject malformed UTF-8 before doing any real work. The domainpart gets
	// the same check inside normalizeDomainpart.
	if !utf8.ValidString(localpart) || !utf8.ValidString(resourcepart) {
		return JID{}, errInvalidUTF8
	}

	domain, err := normalizeDomainpart(domainpart)
	if err != nil {
		return JID{}, err
	}

	buf := make([]byte, 0, len(localpart)+len(domain)+len(resourcepart))
	if len(localpart) > 0 {
		if buf, err = precis.UsernameCaseMapped.Append(buf, []byte(localpart)); err != nil {
			return JID{}, err
		}
	}
	localLen := len(buf)

	buf = append(buf, domain...)

	if len(resourcepart) > 0 {
		if buf, err = precis.OpaqueString.Append(buf, []byte(resourcepart)); err != nil {
			return JID{}, err
		}
	}

	// The post-preparation checks must see the canonical bytes, because PRECIS
	// mapping can introduce characters the raw input did not contain.
	if err = localChecks(buf[:localLen]); err != nil {
		return JID{}, err
	}
	if err = resourceChecks(buf[localLen+len(domain):]); err != nil {
		return JID{}, err
	}

	return JID{locallen: localLen, domainlen: len(domain), data: buf}, nil
}
// WithLocal returns a copy of the JID with a new localpart.
// This elides validation of the domainpart and resourcepart.
//
// An empty localpart removes the localpart. A non-empty localpart must be
// valid UTF-8; it is prepared with precis.UsernameCaseMapped and then checked
// with localChecks. On any error the receiver is returned unchanged, as
// WithDomain does.
func (j JID) WithLocal(localpart string) (JID, error) {
	rest := j.data[j.locallen:]
	data := make([]byte, 0, len(localpart)+len(rest))
	if localpart != "" {
		if !utf8.ValidString(localpart) {
			return j, errInvalidUTF8
		}
		var err error
		data, err = precis.UsernameCaseMapped.Append(data, []byte(localpart))
		if err != nil {
			return j, err
		}
	}
	// At this point data contains exactly the prepared localpart. Validate it
	// before committing anything so a failure cannot leak a half-built JID.
	if err := localChecks(data); err != nil {
		return j, err
	}
	localLen := len(data)
	data = append(data, rest...)
	j.locallen = localLen
	j.data = data
	return j, nil
}
// WithDomain returns a copy of the JID with a new domainpart.
// This elides validation of the localpart and resourcepart.
//
// The new domainpart is normalized with normalizeDomainpart. If that fails,
// the receiver is returned unchanged along with the error. The result is
// built in a freshly allocated slice, so it does not share storage with j.
func (j JID) WithDomain(domainpart string) (JID, error) {
	domain, err := normalizeDomainpart(domainpart)
	if err != nil {
		return j, err
	}
	local := j.data[:j.locallen]
	resource := j.data[j.locallen+j.domainlen:]

	buf := make([]byte, 0, len(local)+len(domain)+len(resource))
	buf = append(buf, local...)
	buf = append(buf, domain...)
	buf = append(buf, resource...)

	j.domainlen = len(domain)
	j.data = buf
	return j, nil
}
// WithResource returns a copy of the JID with a new resourcepart.
// This elides validation of the localpart and domainpart.
//
// An empty resourcepart yields the bare JID. A non-empty resourcepart must be
// valid UTF-8; it is prepared with precis.OpaqueString and then checked with
// resourceChecks. On any error the receiver is returned unchanged, consistent
// with WithLocal and WithDomain.
func (j JID) WithResource(resourcepart string) (JID, error) {
	bare := j.Bare()
	if resourcepart == "" {
		return bare, nil
	}
	if !utf8.ValidString(resourcepart) {
		return j, errInvalidUTF8
	}
	bareLen := len(bare.data)
	// Copy into a new slice so that appending the resourcepart never writes
	// into the backing array shared with j.
	data := make([]byte, bareLen, bareLen+len(resourcepart))
	copy(data, bare.data)
	data, err := precis.OpaqueString.Append(data, []byte(resourcepart))
	if err != nil {
		return j, err
	}
	if err := resourceChecks(data[bareLen:]); err != nil {
		return j, err
	}
	bare.data = data
	return bare, nil
}
// Bare returns a copy of the JID without a resourcepart. This is sometimes
// called a "bare" JID.
//
// The returned JID views a prefix of j's underlying byte slice rather than
// copying it.
func (j JID) Bare() JID {
	end := j.locallen + j.domainlen
	bare := JID{locallen: j.locallen, domainlen: j.domainlen}
	bare.data = j.data[:end]
	return bare
}
// Domain returns a copy of the JID without a resourcepart or localpart.
//
// The returned JID views the domainpart bytes of j's underlying slice rather
// than copying them.
func (j JID) Domain() JID {
	start := j.locallen
	end := start + j.domainlen
	return JID{domainlen: j.domainlen, data: j.data[start:end]}
}
// Localpart gets the localpart of a JID (eg "username").
// It returns "" when the JID has no localpart.
func (j JID) Localpart() string {
	local := j.data[:j.locallen]
	return string(local)
}
// Domainpart gets the domainpart of a JID (eg. "example.net").
func (j JID) Domainpart() string {
	start, end := j.locallen, j.locallen+j.domainlen
	return string(j.data[start:end])
}
// Resourcepart gets the resourcepart of a JID.
// It returns "" when the JID has no resourcepart.
func (j JID) Resourcepart() string {
	start := j.locallen + j.domainlen
	return string(j.data[start:])
}
// Copy makes a copy of the given JID. j.Equal(j.Copy()) will always return
// true.
//
// Only the JID value is copied. The returned JID shares j's underlying byte
// storage.
func (j JID) Copy() JID {
	dup := j
	return dup
}
// Network satisfies the net.Addr interface by returning the name of the network
// ("xmpp").
func (JID) Network() string {
	const network = "xmpp"
	return network
}
// String converts an JID to its string representation.
//
// The result has the form "localpart@domainpart/resourcepart". The '@' is
// omitted when there is no localpart, and the '/' is omitted when there is no
// resourcepart. The output is built in a single pre-sized buffer.
func (j JID) String() string {
	domainEnd := j.locallen + j.domainlen
	resource := j.data[domainEnd:]

	var b strings.Builder
	// At most two separators are added to the stored bytes.
	b.Grow(len(j.data) + 2)
	if j.locallen > 0 {
		b.Write(j.data[:j.locallen])
		b.WriteByte('@')
	}
	b.Write(j.data[j.locallen:domainEnd])
	if len(resource) > 0 {
		b.WriteByte('/')
		b.Write(resource)
	}
	return b.String()
}
// Equal performs an octet-for-octet comparison with the given JID.
//
// Two JIDs are equal when their stored bytes are identical and the
// localpart/domainpart boundaries fall at the same offsets. The cheap
// boundary comparisons run first.
func (j JID) Equal(j2 JID) bool {
	return j.locallen == j2.locallen &&
		j.domainlen == j2.domainlen &&
		bytes.Equal(j.data, j2.data)
}
// MarshalXML satisfies the xml.Marshaler interface and marshals the JID as
// XML chardata.
//
// It writes start, then the JID's String form as character data, then the
// matching end element, and finally flushes e. The first encoding or flush
// error is returned.
func (j JID) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
	tokens := []xml.Token{start, xml.CharData(j.String()), start.End()}
	for _, tok := range tokens {
		if err = e.EncodeToken(tok); err != nil {
			return err
		}
	}
	return e.Flush()
}
// UnmarshalXML satisfies the xml.Unmarshaler interface and unmarshals the JID
// from the elements chardata.
//
// If decoding or parsing fails, the error is returned and j is left
// unchanged.
func (j *JID) UnmarshalXML(d *xml.Decoder, start xml.StartElement) (err error) {
	var payload struct {
		CharData string `xml:",chardata"`
	}
	if err = d.DecodeElement(&payload, &start); err != nil {
		return err
	}
	parsed, err := Parse(payload.CharData)
	if err != nil {
		return err
	}
	*j = parsed
	return nil
}
// MarshalXMLAttr satisfies the xml.MarshalerAttr interface and marshals the JID
// as an XML attribute whose value is the JID's String form. It never fails.
func (j JID) MarshalXMLAttr(name xml.Name) (xml.Attr, error) {
	attr := xml.Attr{Name: name}
	attr.Value = j.String()
	return attr, nil
}
// UnmarshalXMLAttr satisfies the xml.UnmarshalerAttr interface and unmarshals
// an XML attribute into a valid JID (or returns an error).
//
// An empty attribute value is ignored and leaves j unchanged. If the value
// cannot be parsed, the error is returned and j is also left unchanged,
// matching UnmarshalXML. Previously j was overwritten with the zero JID in
// that case.
func (j *JID) UnmarshalXMLAttr(attr xml.Attr) error {
	if attr.Value == "" {
		return nil
	}
	parsed, err := Parse(attr.Value)
	if err != nil {
		return err
	}
	*j = parsed
	return nil
}
// SplitString splits out the localpart, domainpart, and resourcepart from a
// string representation of a JID.
//
// Only the '/' and '@' separators are interpreted. The parts are not otherwise
// validated, and their lengths are not checked here; New performs those
// checks. A trailing '/' (empty resourcepart) or a leading '@' (empty
// localpart) is reported as an error.
func SplitString(s string) (localpart, domainpart, resourcepart string, err error) {
	const strict = true
	localpart, domainpart, resourcepart, err = splitString(s, strict)
	return localpart, domainpart, resourcepart, err
}
// splitString divides s into its three JID parts using the RFC 7622 §3.2
// parsing steps.
//
// When safe is true, an empty resourcepart after '/' produces
// errNoResourcepart, and an empty localpart before '@' produces
// errNoLocalpart. When safe is false, both are accepted. In the
// errNoLocalpart case the already-split resourcepart is still returned.
func splitString(s string, safe bool) (localpart, domainpart, resourcepart string, err error) {
	// RFC 7622 §3.1. Fundamentals:
	//
	//     Implementation Note: When dividing a JID into its component parts,
	//     an implementation needs to match the separator characters '@' and
	//     '/' before applying any transformation algorithms, which might
	//     decompose certain Unicode code points to the separator characters.
	//
	// RFC 7622 §3.2 then defines the domainpart as what remains after:
	//
	//     1. Remove any portion from the first '/' character to the end of the
	//        string (if there is a '/' character present).
	if slash := strings.IndexByte(s, '/'); slash >= 0 {
		// A '/' must be followed by a non-empty resourcepart.
		if safe && slash == len(s)-1 {
			return "", "", "", errNoResourcepart
		}
		s, resourcepart = s[:slash], s[slash+1:]
	}

	//     2. Remove any portion from the beginning of the string to the first
	//        '@' character (if there is an '@' character present).
	at := strings.IndexByte(s, '@')
	if at < 0 {
		// Without an '@' sign there is no localpart.
		return "", s, resourcepart, nil
	}
	if safe && at == 0 {
		// A leading '@' means an empty (invalid) localpart.
		return "", "", resourcepart, errNoLocalpart
	}
	return s[:at], s[at+1:], resourcepart, nil
}
// localChecks applies the localpart rules that the PRECIS profile does not
// cover. It takes the already-prepared localpart bytes and returns
// errLongLocalpart if they exceed 1023 bytes, or errForbiddenLocalpart if
// they contain a character forbidden by RFC 7622 §3.3.1.
//
// These characters cannot be added to the profile's disallowed set, because
// the profile checks its input before mapping, and mapping may itself produce
// one of them.
func localChecks(localpart []byte) error {
	switch {
	case len(localpart) > 1023:
		return errLongLocalpart
	case bytes.ContainsAny(localpart, `"&'/:<>@`):
		return errForbiddenLocalpart
	default:
		return nil
	}
}
// resourceChecks enforces the 1023-byte limit on an already-prepared
// resourcepart. It returns errLongResourcepart if the limit is exceeded.
func resourceChecks(resourcepart []byte) error {
	const maxLen = 1023
	if len(resourcepart) <= maxLen {
		return nil
	}
	return errLongResourcepart
}
// normalizeDomainpart validates domainpart and returns its canonical form.
//
// IP literals are returned unchanged. An IPv6 address must be enclosed in
// square brackets; an IPv4 address must be in bare dotted-decimal form. Any
// other value has one trailing '.' removed, is converted with
// idna.Display.ToUnicode, and must then be between 1 and 1023 bytes long.
//
// On failure the (possibly partially processed) domainpart is returned
// together with one of these errors:
//   - errInvalidUTF8;
//   - the error from the idna package;
//   - errInvalidDomainLen.
func normalizeDomainpart(domainpart string) (string, error) {
	if !utf8.ValidString(domainpart) {
		return domainpart, errInvalidUTF8
	}

	// IP literals are classified by whether their text contains a colon, not
	// by net.IP.To4. IPv4-mapped IPv6 addresses such as "::ffff:192.0.2.1"
	// have a non-nil To4, which would otherwise cause two errors:
	//   - the bracketed IPv6 literal "[::ffff:192.0.2.1]" would be rejected;
	//   - the unbracketed "::ffff:192.0.2.1" would be accepted as IPv4.
	// net.ParseIP only succeeds on colon-free text for dotted-decimal IPv4, so
	// the colon test separates the two address families exactly.

	// A bracketed IPv6 address is a valid domainpart as-is.
	if l := len(domainpart) - 1; l > 1 && domainpart[0] == '[' && domainpart[l] == ']' {
		if inner := domainpart[1:l]; strings.Contains(inner, ":") && net.ParseIP(inner) != nil {
			return domainpart, nil
		}
	}

	// A bare dotted-decimal IPv4 address is a valid domainpart as-is.
	if !strings.Contains(domainpart, ":") && net.ParseIP(domainpart) != nil {
		return domainpart, nil
	}

	// RFC 7622 §3.2. Domainpart
	//
	//     If the domainpart includes a final character considered to be a label
	//     separator (dot) by [RFC1034], this character MUST be stripped from
	//     the domainpart before the JID of which it is a part is used for the
	//     purpose of routing an XML stanza, comparing against another JID, or
	//     constructing an XMPP URI or IRI [RFC5122]. In particular, such a
	//     character MUST be stripped before any other canonicalization steps
	//     are taken.
	domainpart = strings.TrimSuffix(domainpart, ".")

	// RFC 7622 §3.2.1. Preparation
	//
	//     An entity that prepares a string for inclusion in an XMPP domainpart
	//     slot MUST ensure that the string consists only of Unicode code points
	//     that are allowed in NR-LDH labels or U-labels as defined in
	//     [RFC5890]. This implies that the string MUST NOT include A-labels as
	//     defined in [RFC5890]; each A-label MUST be converted to a U-label
	//     during preparation of a string for inclusion in a domainpart slot.
	//
	// RFC 7622 §3.2.2. Enforcement
	//
	//     An entity that performs enforcement in XMPP domainpart slots MUST
	//     prepare a string as described in Section 3.2.1 and MUST also apply
	//     the normalization, case-mapping, and width-mapping rules defined in
	//     [RFC5892].
	//
	// Per EID 4534 this is actually talking about RFC 5895.
	var err error
	domainpart, err = idna.Display.ToUnicode(domainpart)
	if err != nil {
		return domainpart, err
	}

	if l := len(domainpart); l < 1 || l > 1023 {
		return domainpart, errInvalidDomainLen
	}
	return domainpart, nil
}