mirror of
				https://github.com/go-gitea/gitea.git
				synced 2025-11-04 10:44:12 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			276 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			276 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright 2011 The Go Authors. All rights reserved.
 | 
						|
// Use of this source code is governed by a BSD-style
 | 
						|
// license that can be found in the LICENSE file.
 | 
						|
 | 
						|
package ldap
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/tls"
 | 
						|
	"errors"
 | 
						|
	"log"
 | 
						|
	"net"
 | 
						|
	"sync"
 | 
						|
 | 
						|
	"github.com/gogits/gogs/modules/asn1-ber"
 | 
						|
)
 | 
						|
 | 
						|
// Operation codes stored in messagePacket.Op. They tell the
// processMessages goroutine what to do with each queued message.
const (
	MessageQuit     = iota // end the processing loop (sent by Close)
	MessageRequest         // register a result channel and write the request packet
	MessageResponse        // deliver a packet read by reader to its result channel
	MessageFinish          // close and forget the result channel for a message ID
)
 | 
						|
 | 
						|
type messagePacket struct {
 | 
						|
	Op        int
 | 
						|
	MessageID uint64
 | 
						|
	Packet    *ber.Packet
 | 
						|
	Channel   chan *ber.Packet
 | 
						|
}
 | 
						|
 | 
						|
// Conn represents an LDAP Connection
 | 
						|
type Conn struct {
 | 
						|
	conn          net.Conn
 | 
						|
	isTLS         bool
 | 
						|
	isClosing     bool
 | 
						|
	Debug         debugging
 | 
						|
	chanConfirm   chan bool
 | 
						|
	chanResults   map[uint64]chan *ber.Packet
 | 
						|
	chanMessage   chan *messagePacket
 | 
						|
	chanMessageID chan uint64
 | 
						|
	wgSender      sync.WaitGroup
 | 
						|
	wgClose       sync.WaitGroup
 | 
						|
	once          sync.Once
 | 
						|
}
 | 
						|
 | 
						|
// Dial connects to the given address on the given network using net.Dial
 | 
						|
// and then returns a new Conn for the connection.
 | 
						|
func Dial(network, addr string) (*Conn, error) {
 | 
						|
	c, err := net.Dial(network, addr)
 | 
						|
	if err != nil {
 | 
						|
		return nil, NewError(ErrorNetwork, err)
 | 
						|
	}
 | 
						|
	conn := NewConn(c)
 | 
						|
	conn.start()
 | 
						|
	return conn, nil
 | 
						|
}
 | 
						|
 | 
						|
// DialTLS connects to the given address on the given network using tls.Dial
 | 
						|
// and then returns a new Conn for the connection.
 | 
						|
func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
 | 
						|
	c, err := tls.Dial(network, addr, config)
 | 
						|
	if err != nil {
 | 
						|
		return nil, NewError(ErrorNetwork, err)
 | 
						|
	}
 | 
						|
	conn := NewConn(c)
 | 
						|
	conn.isTLS = true
 | 
						|
	conn.start()
 | 
						|
	return conn, nil
 | 
						|
}
 | 
						|
 | 
						|
// NewConn returns a new Conn using conn for network I/O.
 | 
						|
func NewConn(conn net.Conn) *Conn {
 | 
						|
	return &Conn{
 | 
						|
		conn:          conn,
 | 
						|
		chanConfirm:   make(chan bool),
 | 
						|
		chanMessageID: make(chan uint64),
 | 
						|
		chanMessage:   make(chan *messagePacket, 10),
 | 
						|
		chanResults:   map[uint64]chan *ber.Packet{},
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (l *Conn) start() {
 | 
						|
	go l.reader()
 | 
						|
	go l.processMessages()
 | 
						|
	l.wgClose.Add(1)
 | 
						|
}
 | 
						|
 | 
						|
// Close closes the connection.
 | 
						|
func (l *Conn) Close() {
 | 
						|
	l.once.Do(func() {
 | 
						|
		l.isClosing = true
 | 
						|
		l.wgSender.Wait()
 | 
						|
 | 
						|
		l.Debug.Printf("Sending quit message and waiting for confirmation")
 | 
						|
		l.chanMessage <- &messagePacket{Op: MessageQuit}
 | 
						|
		<-l.chanConfirm
 | 
						|
		close(l.chanMessage)
 | 
						|
 | 
						|
		l.Debug.Printf("Closing network connection")
 | 
						|
		if err := l.conn.Close(); err != nil {
 | 
						|
			log.Print(err)
 | 
						|
		}
 | 
						|
 | 
						|
		l.conn = nil
 | 
						|
		l.wgClose.Done()
 | 
						|
	})
 | 
						|
	l.wgClose.Wait()
 | 
						|
}
 | 
						|
 | 
						|
// Returns the next available messageID
 | 
						|
func (l *Conn) nextMessageID() uint64 {
 | 
						|
	if l.chanMessageID != nil {
 | 
						|
		if messageID, ok := <-l.chanMessageID; ok {
 | 
						|
			return messageID
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return 0
 | 
						|
}
 | 
						|
 | 
						|
// StartTLS sends the command to start a TLS session and then creates a new TLS Client
 | 
						|
func (l *Conn) StartTLS(config *tls.Config) error {
 | 
						|
	messageID := l.nextMessageID()
 | 
						|
 | 
						|
	if l.isTLS {
 | 
						|
		return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
 | 
						|
	}
 | 
						|
 | 
						|
	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
 | 
						|
	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
 | 
						|
	request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
 | 
						|
	request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
 | 
						|
	packet.AppendChild(request)
 | 
						|
	l.Debug.PrintPacket(packet)
 | 
						|
 | 
						|
	_, err := l.conn.Write(packet.Bytes())
 | 
						|
	if err != nil {
 | 
						|
		return NewError(ErrorNetwork, err)
 | 
						|
	}
 | 
						|
 | 
						|
	packet, err = ber.ReadPacket(l.conn)
 | 
						|
	if err != nil {
 | 
						|
		return NewError(ErrorNetwork, err)
 | 
						|
	}
 | 
						|
 | 
						|
	if l.Debug {
 | 
						|
		if err := addLDAPDescriptions(packet); err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		ber.PrintPacket(packet)
 | 
						|
	}
 | 
						|
 | 
						|
	if packet.Children[1].Children[0].Value.(uint64) == 0 {
 | 
						|
		conn := tls.Client(l.conn, config)
 | 
						|
		l.isTLS = true
 | 
						|
		l.conn = conn
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) {
 | 
						|
	if l.isClosing {
 | 
						|
		return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
 | 
						|
	}
 | 
						|
	out := make(chan *ber.Packet)
 | 
						|
	message := &messagePacket{
 | 
						|
		Op:        MessageRequest,
 | 
						|
		MessageID: packet.Children[0].Value.(uint64),
 | 
						|
		Packet:    packet,
 | 
						|
		Channel:   out,
 | 
						|
	}
 | 
						|
	l.sendProcessMessage(message)
 | 
						|
	return out, nil
 | 
						|
}
 | 
						|
 | 
						|
func (l *Conn) finishMessage(messageID uint64) {
 | 
						|
	if l.isClosing {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	message := &messagePacket{
 | 
						|
		Op:        MessageFinish,
 | 
						|
		MessageID: messageID,
 | 
						|
	}
 | 
						|
	l.sendProcessMessage(message)
 | 
						|
}
 | 
						|
 | 
						|
func (l *Conn) sendProcessMessage(message *messagePacket) bool {
 | 
						|
	if l.isClosing {
 | 
						|
		return false
 | 
						|
	}
 | 
						|
	l.wgSender.Add(1)
 | 
						|
	l.chanMessage <- message
 | 
						|
	l.wgSender.Done()
 | 
						|
	return true
 | 
						|
}
 | 
						|
 | 
						|
func (l *Conn) processMessages() {
 | 
						|
	defer func() {
 | 
						|
		for messageID, channel := range l.chanResults {
 | 
						|
			l.Debug.Printf("Closing channel for MessageID %d", messageID)
 | 
						|
			close(channel)
 | 
						|
			delete(l.chanResults, messageID)
 | 
						|
		}
 | 
						|
		close(l.chanMessageID)
 | 
						|
		l.chanConfirm <- true
 | 
						|
		close(l.chanConfirm)
 | 
						|
	}()
 | 
						|
 | 
						|
	var messageID uint64 = 1
 | 
						|
	for {
 | 
						|
		select {
 | 
						|
		case l.chanMessageID <- messageID:
 | 
						|
			messageID++
 | 
						|
		case messagePacket, ok := <-l.chanMessage:
 | 
						|
			if !ok {
 | 
						|
				l.Debug.Printf("Shutting down - message channel is closed")
 | 
						|
				return
 | 
						|
			}
 | 
						|
			switch messagePacket.Op {
 | 
						|
			case MessageQuit:
 | 
						|
				l.Debug.Printf("Shutting down - quit message received")
 | 
						|
				return
 | 
						|
			case MessageRequest:
 | 
						|
				// Add to message list and write to network
 | 
						|
				l.Debug.Printf("Sending message %d", messagePacket.MessageID)
 | 
						|
				l.chanResults[messagePacket.MessageID] = messagePacket.Channel
 | 
						|
				// go routine
 | 
						|
				buf := messagePacket.Packet.Bytes()
 | 
						|
 | 
						|
				_, err := l.conn.Write(buf)
 | 
						|
				if err != nil {
 | 
						|
					l.Debug.Printf("Error Sending Message: %s", err.Error())
 | 
						|
					break
 | 
						|
				}
 | 
						|
			case MessageResponse:
 | 
						|
				l.Debug.Printf("Receiving message %d", messagePacket.MessageID)
 | 
						|
				if chanResult, ok := l.chanResults[messagePacket.MessageID]; ok {
 | 
						|
					chanResult <- messagePacket.Packet
 | 
						|
				} else {
 | 
						|
					log.Printf("Received unexpected message %d", messagePacket.MessageID)
 | 
						|
					ber.PrintPacket(messagePacket.Packet)
 | 
						|
				}
 | 
						|
			case MessageFinish:
 | 
						|
				// Remove from message list
 | 
						|
				l.Debug.Printf("Finished message %d", messagePacket.MessageID)
 | 
						|
				close(l.chanResults[messagePacket.MessageID])
 | 
						|
				delete(l.chanResults, messagePacket.MessageID)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (l *Conn) reader() {
 | 
						|
	defer func() {
 | 
						|
		l.Close()
 | 
						|
	}()
 | 
						|
 | 
						|
	for {
 | 
						|
		packet, err := ber.ReadPacket(l.conn)
 | 
						|
		if err != nil {
 | 
						|
			l.Debug.Printf("reader: %s", err.Error())
 | 
						|
			return
 | 
						|
		}
 | 
						|
		addLDAPDescriptions(packet)
 | 
						|
		message := &messagePacket{
 | 
						|
			Op:        MessageResponse,
 | 
						|
			MessageID: packet.Children[0].Value.(uint64),
 | 
						|
			Packet:    packet,
 | 
						|
		}
 | 
						|
		if !l.sendProcessMessage(message) {
 | 
						|
			return
 | 
						|
		}
 | 
						|
 | 
						|
	}
 | 
						|
}
 |