mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-06-02 06:23:37 +02:00
Add mux control frames
This commit is contained in:
+16
-14
@@ -103,18 +103,19 @@ func Run(ctx context.Context, roomURL, keyHex string, socksPort int, duo bool, s
|
||||
})
|
||||
|
||||
for i := 0; i < peerCount; i++ {
|
||||
peerID := i
|
||||
peer, err := telemost.NewPeer(roomURL, names.Generate(), c.onData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peer.SetEndedCallback(func(reason string) {
|
||||
log.Printf("Client peer %d reported conference end: %s", i, reason)
|
||||
log.Printf("Client peer %d reported conference end: %s", peerID, reason)
|
||||
cancel()
|
||||
})
|
||||
c.peers = append(c.peers, peer)
|
||||
|
||||
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
|
||||
log.Printf("Client peer %d reconnected - resetting multiplexer state", i)
|
||||
log.Printf("Client peer %d reconnected - resetting multiplexer state", peerID)
|
||||
|
||||
c.mux.UpdateSendFunc(func(frame []byte) error {
|
||||
encrypted, err := c.cipher.Encrypt(frame)
|
||||
@@ -130,11 +131,11 @@ func Run(ctx context.Context, roomURL, keyHex string, socksPort int, duo bool, s
|
||||
log.Println("Client multiplexer reset complete")
|
||||
})
|
||||
|
||||
log.Printf("Connecting peer %d to Telemost...", i)
|
||||
log.Printf("Connecting peer %d to Telemost...", peerID)
|
||||
if err := peer.Connect(runCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Peer %d connected", i)
|
||||
log.Printf("Peer %d connected", peerID)
|
||||
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
@@ -145,17 +146,18 @@ func Run(ctx context.Context, roomURL, keyHex string, socksPort int, duo bool, s
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
resetFrame := make([]byte, 12)
|
||||
binary.BigEndian.PutUint32(resetFrame[0:4], c.clientID)
|
||||
binary.BigEndian.PutUint16(resetFrame[4:6], 0xFFFF)
|
||||
binary.BigEndian.PutUint16(resetFrame[6:8], 0xFFFF)
|
||||
binary.BigEndian.PutUint32(resetFrame[8:12], 0)
|
||||
encrypted, _ := cipher.Encrypt(resetFrame)
|
||||
|
||||
for _, peer := range c.peers {
|
||||
peer.Send(encrypted)
|
||||
resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient)
|
||||
encrypted, err := cipher.Encrypt(resetFrame)
|
||||
if err != nil {
|
||||
log.Printf("Failed to encrypt reset signal: %v", err)
|
||||
} else {
|
||||
for _, peer := range c.peers {
|
||||
if err := peer.Send(encrypted); err != nil {
|
||||
log.Printf("Failed to send reset signal to server: %v", err)
|
||||
}
|
||||
}
|
||||
log.Printf("Sent reset signal to server (clientID=%d)", c.clientID)
|
||||
}
|
||||
log.Printf("Sent reset signal to server (clientID=%d)", c.clientID)
|
||||
|
||||
err = c.runSOCKS5(runCtx, socksPort, socksUser, socksPass)
|
||||
|
||||
|
||||
+74
-29
@@ -6,12 +6,25 @@ package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
ControlStreamID uint16 = 0xFFFF
|
||||
ControlLength uint16 = 0xFFFF
|
||||
|
||||
ControlResetClient uint32 = 1
|
||||
)
|
||||
|
||||
type ControlFrame struct {
|
||||
ClientID uint32
|
||||
Type uint32
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
ID uint16
|
||||
ClientID uint32
|
||||
@@ -144,24 +157,47 @@ func (m *Multiplexer) CloseStream(sid uint16) error {
|
||||
return m.onSend(frame)
|
||||
}
|
||||
|
||||
func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
if len(frame) < 12 {
|
||||
if len(frame) >= 8 {
|
||||
clientID := binary.BigEndian.Uint32(frame[0:4])
|
||||
sid := binary.BigEndian.Uint16(frame[4:6])
|
||||
length := binary.BigEndian.Uint16(frame[6:8])
|
||||
func (m *Multiplexer) SendClientReset() error {
|
||||
if m.clientID == 0 {
|
||||
return errors.New("client reset requires a non-zero client id")
|
||||
}
|
||||
return m.onSend(BuildControlFrame(m.clientID, ControlResetClient))
|
||||
}
|
||||
|
||||
if sid == 0xFFFF && length == 0xFFFF {
|
||||
m.mu.Lock()
|
||||
for streamSid, stream := range m.streams {
|
||||
if stream.ClientID == clientID {
|
||||
stream.closed = true
|
||||
delete(m.streams, streamSid)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
func BuildControlFrame(clientID uint32, controlType uint32) []byte {
|
||||
frame := make([]byte, 12)
|
||||
binary.BigEndian.PutUint32(frame[0:4], clientID)
|
||||
binary.BigEndian.PutUint16(frame[4:6], ControlStreamID)
|
||||
binary.BigEndian.PutUint16(frame[6:8], ControlLength)
|
||||
binary.BigEndian.PutUint32(frame[8:12], controlType)
|
||||
return frame
|
||||
}
|
||||
|
||||
func ParseControlFrame(frame []byte) (ControlFrame, bool) {
|
||||
if len(frame) < 12 {
|
||||
return ControlFrame{}, false
|
||||
}
|
||||
|
||||
sid := binary.BigEndian.Uint16(frame[4:6])
|
||||
length := binary.BigEndian.Uint16(frame[6:8])
|
||||
if sid != ControlStreamID || length != ControlLength {
|
||||
return ControlFrame{}, false
|
||||
}
|
||||
|
||||
return ControlFrame{
|
||||
ClientID: binary.BigEndian.Uint32(frame[0:4]),
|
||||
Type: binary.BigEndian.Uint32(frame[8:12]),
|
||||
}, true
|
||||
}
|
||||
|
||||
func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
control, ok := ParseControlFrame(frame)
|
||||
if ok {
|
||||
m.handleControlFrame(control)
|
||||
return
|
||||
}
|
||||
|
||||
if len(frame) < 12 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -170,18 +206,6 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
length := binary.BigEndian.Uint16(frame[6:8])
|
||||
seq := binary.BigEndian.Uint32(frame[8:12])
|
||||
|
||||
if sid == 0xFFFF && length == 0xFFFF {
|
||||
m.mu.Lock()
|
||||
for streamSid, stream := range m.streams {
|
||||
if stream.ClientID == clientID {
|
||||
stream.closed = true
|
||||
delete(m.streams, streamSid)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
m.mu.Lock()
|
||||
if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID {
|
||||
@@ -270,6 +294,27 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) handleControlFrame(control ControlFrame) {
|
||||
switch control.Type {
|
||||
case ControlResetClient:
|
||||
m.ResetClient(control.ClientID)
|
||||
default:
|
||||
logger.Debug("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) ResetClient(clientID uint32) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for streamSid, stream := range m.streams {
|
||||
if stream.ClientID == clientID {
|
||||
stream.closed = true
|
||||
delete(m.streams, streamSid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForBufferSpace releases m.mu and waits until the stream's recvBuf has
|
||||
// room for `need` more bytes, then re-acquires the lock. Returns the (possibly
|
||||
// re-fetched) stream, or nil if the stream disappeared / was reset / closed.
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseControlFrame(t *testing.T) {
|
||||
frame := BuildControlFrame(42, ControlResetClient)
|
||||
|
||||
control, ok := ParseControlFrame(frame)
|
||||
if !ok {
|
||||
t.Fatal("expected control frame")
|
||||
}
|
||||
if control.ClientID != 42 {
|
||||
t.Fatalf("ClientID = %d, want 42", control.ClientID)
|
||||
}
|
||||
if control.Type != ControlResetClient {
|
||||
t.Fatalf("Type = %d, want %d", control.Type, ControlResetClient)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleControlResetClient(t *testing.T) {
|
||||
m := New(0, func([]byte) error { return nil })
|
||||
|
||||
dataFrame := make([]byte, 13)
|
||||
binary.BigEndian.PutUint32(dataFrame[0:4], 42)
|
||||
binary.BigEndian.PutUint16(dataFrame[4:6], 7)
|
||||
binary.BigEndian.PutUint16(dataFrame[6:8], 1)
|
||||
binary.BigEndian.PutUint32(dataFrame[8:12], 0)
|
||||
dataFrame[12] = 0xAA
|
||||
|
||||
m.HandleFrame(dataFrame)
|
||||
if stream := m.GetStream(7); stream == nil {
|
||||
t.Fatal("expected data stream before reset")
|
||||
}
|
||||
|
||||
m.HandleFrame(BuildControlFrame(42, ControlResetClient))
|
||||
if stream := m.GetStream(7); stream != nil {
|
||||
t.Fatal("expected data stream to be removed by client reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendClientReset(t *testing.T) {
|
||||
var sent []byte
|
||||
m := New(99, func(frame []byte) error {
|
||||
sent = append([]byte(nil), frame...)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := m.SendClientReset(); err != nil {
|
||||
t.Fatalf("SendClientReset failed: %v", err)
|
||||
}
|
||||
control, ok := ParseControlFrame(sent)
|
||||
if !ok {
|
||||
t.Fatal("expected sent control frame")
|
||||
}
|
||||
if control.ClientID != 99 || control.Type != ControlResetClient {
|
||||
t.Fatalf("control = %#v", control)
|
||||
}
|
||||
}
|
||||
+128
-81
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -27,6 +26,8 @@ type Server struct {
|
||||
mux *mux.Multiplexer
|
||||
connections map[uint16]net.Conn
|
||||
connMu sync.RWMutex
|
||||
streamPumps map[uint16]net.Conn
|
||||
pumpMu sync.Mutex
|
||||
peerIdx atomic.Uint32
|
||||
wg sync.WaitGroup
|
||||
dnsServer string
|
||||
@@ -76,6 +77,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
|
||||
s := &Server{
|
||||
cipher: cipher,
|
||||
connections: make(map[uint16]net.Conn),
|
||||
streamPumps: make(map[uint16]net.Conn),
|
||||
peers: make([]*telemost.Peer, 0),
|
||||
dnsServer: dnsServer,
|
||||
}
|
||||
@@ -122,18 +124,19 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
|
||||
})
|
||||
|
||||
for i := 0; i < peerCount; i++ {
|
||||
peerID := i
|
||||
peer, err := telemost.NewPeer(roomURL, names.Generate(), s.onData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peer.SetEndedCallback(func(reason string) {
|
||||
log.Printf("Server peer %d reported conference end: %s", i, reason)
|
||||
log.Printf("Server peer %d reported conference end: %s", peerID, reason)
|
||||
cancel()
|
||||
})
|
||||
s.peers = append(s.peers, peer)
|
||||
|
||||
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
|
||||
log.Printf("Server peer %d reconnected - resetting multiplexer state", i)
|
||||
log.Printf("Server peer %d reconnected - resetting multiplexer state", peerID)
|
||||
|
||||
s.connMu.Lock()
|
||||
for sid, conn := range s.connections {
|
||||
@@ -160,11 +163,11 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
|
||||
log.Println("Server multiplexer reset complete")
|
||||
})
|
||||
|
||||
log.Printf("Connecting peer %d to Telemost...", i)
|
||||
log.Printf("Connecting peer %d to Telemost...", peerID)
|
||||
if err := peer.Connect(runCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Peer %d connected", i)
|
||||
log.Printf("Peer %d connected", peerID)
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
@@ -189,49 +192,29 @@ func (s *Server) onData(data []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(plaintext) >= 12 {
|
||||
clientID := binary.BigEndian.Uint32(plaintext[0:4])
|
||||
sid := binary.BigEndian.Uint16(plaintext[4:6])
|
||||
length := binary.BigEndian.Uint16(plaintext[6:8])
|
||||
|
||||
if sid == 0xFFFF && length == 0xFFFF {
|
||||
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID)
|
||||
s.connMu.Lock()
|
||||
for streamSid, conn := range s.connections {
|
||||
stream := s.mux.GetStream(streamSid)
|
||||
if stream != nil && stream.ClientID == clientID {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
delete(s.connections, streamSid)
|
||||
}
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
} else if len(plaintext) >= 8 {
|
||||
clientID := binary.BigEndian.Uint32(plaintext[0:4])
|
||||
sid := binary.BigEndian.Uint16(plaintext[4:6])
|
||||
length := binary.BigEndian.Uint16(plaintext[6:8])
|
||||
|
||||
if sid == 0xFFFF && length == 0xFFFF {
|
||||
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID)
|
||||
s.connMu.Lock()
|
||||
for streamSid, conn := range s.connections {
|
||||
stream := s.mux.GetStream(streamSid)
|
||||
if stream != nil && stream.ClientID == clientID {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
delete(s.connections, streamSid)
|
||||
}
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient {
|
||||
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", control.ClientID)
|
||||
s.closeClientConnections(control.ClientID)
|
||||
}
|
||||
|
||||
s.mux.HandleFrame(plaintext)
|
||||
}
|
||||
|
||||
func (s *Server) closeClientConnections(clientID uint32) {
|
||||
s.connMu.Lock()
|
||||
defer s.connMu.Unlock()
|
||||
|
||||
for streamSid, conn := range s.connections {
|
||||
stream := s.mux.GetStream(streamSid)
|
||||
if stream != nil && stream.ClientID == clientID {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
delete(s.connections, streamSid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) run(ctx context.Context) error {
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
@@ -263,50 +246,78 @@ func (s *Server) run(ctx context.Context) error {
|
||||
sids := s.mux.GetStreams()
|
||||
|
||||
for _, sid := range sids {
|
||||
go func(sid uint16) {
|
||||
data := s.mux.ReadStream(sid)
|
||||
if len(data) > 0 {
|
||||
s.connMu.RLock()
|
||||
conn, exists := s.connections[sid]
|
||||
s.connMu.RUnlock()
|
||||
if s.mux.StreamClosed(sid) {
|
||||
s.closeStreamConnection(sid)
|
||||
continue
|
||||
}
|
||||
|
||||
if exists && conn != nil {
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
s.mux.CloseStream(sid)
|
||||
conn.Close()
|
||||
s.connMu.Lock()
|
||||
delete(s.connections, sid)
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
} else {
|
||||
var req ConnectRequest
|
||||
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
|
||||
log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port)
|
||||
s.connMu.Lock()
|
||||
if oldConn, exists := s.connections[sid]; exists && oldConn != nil {
|
||||
oldConn.Close()
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
go s.handleConnect(sid, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.hasConnection(sid) {
|
||||
continue
|
||||
}
|
||||
|
||||
if s.mux.StreamClosed(sid) {
|
||||
s.connMu.Lock()
|
||||
conn, exists := s.connections[sid]
|
||||
if exists && conn != nil {
|
||||
conn.Close()
|
||||
delete(s.connections, sid)
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
}(sid)
|
||||
data := s.mux.ReadStream(sid)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var req ConnectRequest
|
||||
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
|
||||
log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port)
|
||||
s.closeStreamConnection(sid)
|
||||
go s.handleConnect(ctx, sid, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
|
||||
func (s *Server) hasConnection(sid uint16) bool {
|
||||
s.connMu.RLock()
|
||||
defer s.connMu.RUnlock()
|
||||
conn := s.connections[sid]
|
||||
return conn != nil
|
||||
}
|
||||
|
||||
func (s *Server) closeStreamConnection(sid uint16) {
|
||||
s.connMu.Lock()
|
||||
conn := s.connections[sid]
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
delete(s.connections, sid)
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) {
|
||||
s.connMu.Lock()
|
||||
conn := s.connections[sid]
|
||||
if conn == expected {
|
||||
conn.Close()
|
||||
delete(s.connections, sid)
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool {
|
||||
s.pumpMu.Lock()
|
||||
defer s.pumpMu.Unlock()
|
||||
if current := s.streamPumps[sid]; current == conn {
|
||||
return false
|
||||
} else if current != nil {
|
||||
current.Close()
|
||||
}
|
||||
s.streamPumps[sid] = conn
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) {
|
||||
s.pumpMu.Lock()
|
||||
if s.streamPumps[sid] == conn {
|
||||
delete(s.streamPumps, sid)
|
||||
}
|
||||
s.pumpMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) {
|
||||
startTime := time.Now()
|
||||
addr := fmt.Sprintf("%s:%d", req.Addr, req.Port)
|
||||
logger.Verbose("Handling connect request sid=%d to %s", sid, addr)
|
||||
@@ -347,6 +358,7 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
|
||||
log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed)
|
||||
|
||||
s.mux.SendData(sid, []byte{0x00})
|
||||
s.startStreamPump(ctx, sid, conn)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
@@ -386,6 +398,41 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
|
||||
if !s.markStreamPump(sid, conn) {
|
||||
return
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer s.unmarkStreamPump(sid, conn)
|
||||
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
data := s.mux.ReadStream(sid)
|
||||
if len(data) > 0 {
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
s.mux.CloseStream(sid)
|
||||
s.closeStreamConnectionIfCurrent(sid, conn)
|
||||
return
|
||||
}
|
||||
}
|
||||
if s.mux.StreamClosed(sid) {
|
||||
s.closeStreamConnectionIfCurrent(sid, conn)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) canSendData() bool {
|
||||
for _, peer := range s.peers {
|
||||
if !peer.CanSend() {
|
||||
|
||||
Reference in New Issue
Block a user