Add mux control frames

This commit is contained in:
Qtozdec
2026-04-10 16:26:15 +03:00
parent ca0191d0de
commit 913cabe222
4 changed files with 279 additions and 124 deletions
+16 -14
View File
@@ -103,18 +103,19 @@ func Run(ctx context.Context, roomURL, keyHex string, socksPort int, duo bool, s
})
for i := 0; i < peerCount; i++ {
peerID := i
peer, err := telemost.NewPeer(roomURL, names.Generate(), c.onData)
if err != nil {
return err
}
peer.SetEndedCallback(func(reason string) {
log.Printf("Client peer %d reported conference end: %s", i, reason)
log.Printf("Client peer %d reported conference end: %s", peerID, reason)
cancel()
})
c.peers = append(c.peers, peer)
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
log.Printf("Client peer %d reconnected - resetting multiplexer state", i)
log.Printf("Client peer %d reconnected - resetting multiplexer state", peerID)
c.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := c.cipher.Encrypt(frame)
@@ -130,11 +131,11 @@ func Run(ctx context.Context, roomURL, keyHex string, socksPort int, duo bool, s
log.Println("Client multiplexer reset complete")
})
log.Printf("Connecting peer %d to Telemost...", i)
log.Printf("Connecting peer %d to Telemost...", peerID)
if err := peer.Connect(runCtx); err != nil {
return err
}
log.Printf("Peer %d connected", i)
log.Printf("Peer %d connected", peerID)
c.wg.Add(1)
go func() {
@@ -145,17 +146,18 @@ func Run(ctx context.Context, roomURL, keyHex string, socksPort int, duo bool, s
time.Sleep(100 * time.Millisecond)
resetFrame := make([]byte, 12)
binary.BigEndian.PutUint32(resetFrame[0:4], c.clientID)
binary.BigEndian.PutUint16(resetFrame[4:6], 0xFFFF)
binary.BigEndian.PutUint16(resetFrame[6:8], 0xFFFF)
binary.BigEndian.PutUint32(resetFrame[8:12], 0)
encrypted, _ := cipher.Encrypt(resetFrame)
for _, peer := range c.peers {
peer.Send(encrypted)
resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient)
encrypted, err := cipher.Encrypt(resetFrame)
if err != nil {
log.Printf("Failed to encrypt reset signal: %v", err)
} else {
for _, peer := range c.peers {
if err := peer.Send(encrypted); err != nil {
log.Printf("Failed to send reset signal to server: %v", err)
}
}
log.Printf("Sent reset signal to server (clientID=%d)", c.clientID)
}
log.Printf("Sent reset signal to server (clientID=%d)", c.clientID)
err = c.runSOCKS5(runCtx, socksPort, socksUser, socksPass)
+74 -29
View File
@@ -6,12 +6,25 @@ package mux
import (
"encoding/binary"
"errors"
"sync"
"time"
"github.com/openlibrecommunity/olcrtc/internal/logger"
)
const (
ControlStreamID uint16 = 0xFFFF
ControlLength uint16 = 0xFFFF
ControlResetClient uint32 = 1
)
type ControlFrame struct {
ClientID uint32
Type uint32
}
type Stream struct {
ID uint16
ClientID uint32
@@ -144,24 +157,47 @@ func (m *Multiplexer) CloseStream(sid uint16) error {
return m.onSend(frame)
}
func (m *Multiplexer) HandleFrame(frame []byte) {
if len(frame) < 12 {
if len(frame) >= 8 {
clientID := binary.BigEndian.Uint32(frame[0:4])
sid := binary.BigEndian.Uint16(frame[4:6])
length := binary.BigEndian.Uint16(frame[6:8])
func (m *Multiplexer) SendClientReset() error {
if m.clientID == 0 {
return errors.New("client reset requires a non-zero client id")
}
return m.onSend(BuildControlFrame(m.clientID, ControlResetClient))
}
if sid == 0xFFFF && length == 0xFFFF {
m.mu.Lock()
for streamSid, stream := range m.streams {
if stream.ClientID == clientID {
stream.closed = true
delete(m.streams, streamSid)
}
}
m.mu.Unlock()
}
}
func BuildControlFrame(clientID uint32, controlType uint32) []byte {
frame := make([]byte, 12)
binary.BigEndian.PutUint32(frame[0:4], clientID)
binary.BigEndian.PutUint16(frame[4:6], ControlStreamID)
binary.BigEndian.PutUint16(frame[6:8], ControlLength)
binary.BigEndian.PutUint32(frame[8:12], controlType)
return frame
}
func ParseControlFrame(frame []byte) (ControlFrame, bool) {
if len(frame) < 12 {
return ControlFrame{}, false
}
sid := binary.BigEndian.Uint16(frame[4:6])
length := binary.BigEndian.Uint16(frame[6:8])
if sid != ControlStreamID || length != ControlLength {
return ControlFrame{}, false
}
return ControlFrame{
ClientID: binary.BigEndian.Uint32(frame[0:4]),
Type: binary.BigEndian.Uint32(frame[8:12]),
}, true
}
func (m *Multiplexer) HandleFrame(frame []byte) {
control, ok := ParseControlFrame(frame)
if ok {
m.handleControlFrame(control)
return
}
if len(frame) < 12 {
return
}
@@ -170,18 +206,6 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
length := binary.BigEndian.Uint16(frame[6:8])
seq := binary.BigEndian.Uint32(frame[8:12])
if sid == 0xFFFF && length == 0xFFFF {
m.mu.Lock()
for streamSid, stream := range m.streams {
if stream.ClientID == clientID {
stream.closed = true
delete(m.streams, streamSid)
}
}
m.mu.Unlock()
return
}
if length == 0 {
m.mu.Lock()
if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID {
@@ -270,6 +294,27 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
}
}
func (m *Multiplexer) handleControlFrame(control ControlFrame) {
switch control.Type {
case ControlResetClient:
m.ResetClient(control.ClientID)
default:
logger.Debug("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID)
}
}
func (m *Multiplexer) ResetClient(clientID uint32) {
m.mu.Lock()
defer m.mu.Unlock()
for streamSid, stream := range m.streams {
if stream.ClientID == clientID {
stream.closed = true
delete(m.streams, streamSid)
}
}
}
// waitForBufferSpace releases m.mu and waits until the stream's recvBuf has
// room for `need` more bytes, then re-acquires the lock. Returns the (possibly
// re-fetched) stream, or nil if the stream disappeared / was reset / closed.
+61
View File
@@ -0,0 +1,61 @@
package mux
import (
"encoding/binary"
"testing"
)
func TestParseControlFrame(t *testing.T) {
frame := BuildControlFrame(42, ControlResetClient)
control, ok := ParseControlFrame(frame)
if !ok {
t.Fatal("expected control frame")
}
if control.ClientID != 42 {
t.Fatalf("ClientID = %d, want 42", control.ClientID)
}
if control.Type != ControlResetClient {
t.Fatalf("Type = %d, want %d", control.Type, ControlResetClient)
}
}
func TestHandleControlResetClient(t *testing.T) {
m := New(0, func([]byte) error { return nil })
dataFrame := make([]byte, 13)
binary.BigEndian.PutUint32(dataFrame[0:4], 42)
binary.BigEndian.PutUint16(dataFrame[4:6], 7)
binary.BigEndian.PutUint16(dataFrame[6:8], 1)
binary.BigEndian.PutUint32(dataFrame[8:12], 0)
dataFrame[12] = 0xAA
m.HandleFrame(dataFrame)
if stream := m.GetStream(7); stream == nil {
t.Fatal("expected data stream before reset")
}
m.HandleFrame(BuildControlFrame(42, ControlResetClient))
if stream := m.GetStream(7); stream != nil {
t.Fatal("expected data stream to be removed by client reset")
}
}
func TestSendClientReset(t *testing.T) {
var sent []byte
m := New(99, func(frame []byte) error {
sent = append([]byte(nil), frame...)
return nil
})
if err := m.SendClientReset(); err != nil {
t.Fatalf("SendClientReset failed: %v", err)
}
control, ok := ParseControlFrame(sent)
if !ok {
t.Fatal("expected sent control frame")
}
if control.ClientID != 99 || control.Type != ControlResetClient {
t.Fatalf("control = %#v", control)
}
}
+128 -81
View File
@@ -3,7 +3,6 @@ package server
import (
"context"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
@@ -27,6 +26,8 @@ type Server struct {
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
streamPumps map[uint16]net.Conn
pumpMu sync.Mutex
peerIdx atomic.Uint32
wg sync.WaitGroup
dnsServer string
@@ -76,6 +77,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
s := &Server{
cipher: cipher,
connections: make(map[uint16]net.Conn),
streamPumps: make(map[uint16]net.Conn),
peers: make([]*telemost.Peer, 0),
dnsServer: dnsServer,
}
@@ -122,18 +124,19 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
})
for i := 0; i < peerCount; i++ {
peerID := i
peer, err := telemost.NewPeer(roomURL, names.Generate(), s.onData)
if err != nil {
return err
}
peer.SetEndedCallback(func(reason string) {
log.Printf("Server peer %d reported conference end: %s", i, reason)
log.Printf("Server peer %d reported conference end: %s", peerID, reason)
cancel()
})
s.peers = append(s.peers, peer)
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
log.Printf("Server peer %d reconnected - resetting multiplexer state", i)
log.Printf("Server peer %d reconnected - resetting multiplexer state", peerID)
s.connMu.Lock()
for sid, conn := range s.connections {
@@ -160,11 +163,11 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
log.Println("Server multiplexer reset complete")
})
log.Printf("Connecting peer %d to Telemost...", i)
log.Printf("Connecting peer %d to Telemost...", peerID)
if err := peer.Connect(runCtx); err != nil {
return err
}
log.Printf("Peer %d connected", i)
log.Printf("Peer %d connected", peerID)
s.wg.Add(1)
go func() {
@@ -189,49 +192,29 @@ func (s *Server) onData(data []byte) {
return
}
if len(plaintext) >= 12 {
clientID := binary.BigEndian.Uint32(plaintext[0:4])
sid := binary.BigEndian.Uint16(plaintext[4:6])
length := binary.BigEndian.Uint16(plaintext[6:8])
if sid == 0xFFFF && length == 0xFFFF {
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID)
s.connMu.Lock()
for streamSid, conn := range s.connections {
stream := s.mux.GetStream(streamSid)
if stream != nil && stream.ClientID == clientID {
if conn != nil {
conn.Close()
}
delete(s.connections, streamSid)
}
}
s.connMu.Unlock()
}
} else if len(plaintext) >= 8 {
clientID := binary.BigEndian.Uint32(plaintext[0:4])
sid := binary.BigEndian.Uint16(plaintext[4:6])
length := binary.BigEndian.Uint16(plaintext[6:8])
if sid == 0xFFFF && length == 0xFFFF {
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID)
s.connMu.Lock()
for streamSid, conn := range s.connections {
stream := s.mux.GetStream(streamSid)
if stream != nil && stream.ClientID == clientID {
if conn != nil {
conn.Close()
}
delete(s.connections, streamSid)
}
}
s.connMu.Unlock()
}
if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient {
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", control.ClientID)
s.closeClientConnections(control.ClientID)
}
s.mux.HandleFrame(plaintext)
}
func (s *Server) closeClientConnections(clientID uint32) {
s.connMu.Lock()
defer s.connMu.Unlock()
for streamSid, conn := range s.connections {
stream := s.mux.GetStream(streamSid)
if stream != nil && stream.ClientID == clientID {
if conn != nil {
conn.Close()
}
delete(s.connections, streamSid)
}
}
}
func (s *Server) run(ctx context.Context) error {
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
@@ -263,50 +246,78 @@ func (s *Server) run(ctx context.Context) error {
sids := s.mux.GetStreams()
for _, sid := range sids {
go func(sid uint16) {
data := s.mux.ReadStream(sid)
if len(data) > 0 {
s.connMu.RLock()
conn, exists := s.connections[sid]
s.connMu.RUnlock()
if s.mux.StreamClosed(sid) {
s.closeStreamConnection(sid)
continue
}
if exists && conn != nil {
if _, err := conn.Write(data); err != nil {
s.mux.CloseStream(sid)
conn.Close()
s.connMu.Lock()
delete(s.connections, sid)
s.connMu.Unlock()
}
} else {
var req ConnectRequest
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port)
s.connMu.Lock()
if oldConn, exists := s.connections[sid]; exists && oldConn != nil {
oldConn.Close()
}
s.connMu.Unlock()
go s.handleConnect(sid, req)
}
}
}
if s.hasConnection(sid) {
continue
}
if s.mux.StreamClosed(sid) {
s.connMu.Lock()
conn, exists := s.connections[sid]
if exists && conn != nil {
conn.Close()
delete(s.connections, sid)
}
s.connMu.Unlock()
}
}(sid)
data := s.mux.ReadStream(sid)
if len(data) == 0 {
continue
}
var req ConnectRequest
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port)
s.closeStreamConnection(sid)
go s.handleConnect(ctx, sid, req)
}
}
}
}
func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
func (s *Server) hasConnection(sid uint16) bool {
s.connMu.RLock()
defer s.connMu.RUnlock()
conn := s.connections[sid]
return conn != nil
}
func (s *Server) closeStreamConnection(sid uint16) {
s.connMu.Lock()
conn := s.connections[sid]
if conn != nil {
conn.Close()
delete(s.connections, sid)
}
s.connMu.Unlock()
}
func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) {
s.connMu.Lock()
conn := s.connections[sid]
if conn == expected {
conn.Close()
delete(s.connections, sid)
}
s.connMu.Unlock()
}
func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool {
s.pumpMu.Lock()
defer s.pumpMu.Unlock()
if current := s.streamPumps[sid]; current == conn {
return false
} else if current != nil {
current.Close()
}
s.streamPumps[sid] = conn
return true
}
func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) {
s.pumpMu.Lock()
if s.streamPumps[sid] == conn {
delete(s.streamPumps, sid)
}
s.pumpMu.Unlock()
}
func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) {
startTime := time.Now()
addr := fmt.Sprintf("%s:%d", req.Addr, req.Port)
logger.Verbose("Handling connect request sid=%d to %s", sid, addr)
@@ -347,6 +358,7 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed)
s.mux.SendData(sid, []byte{0x00})
s.startStreamPump(ctx, sid, conn)
go func() {
defer func() {
@@ -386,6 +398,41 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
}()
}
func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
if !s.markStreamPump(sid, conn) {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.unmarkStreamPump(sid, conn)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
data := s.mux.ReadStream(sid)
if len(data) > 0 {
if _, err := conn.Write(data); err != nil {
s.mux.CloseStream(sid)
s.closeStreamConnectionIfCurrent(sid, conn)
return
}
}
if s.mux.StreamClosed(sid) {
s.closeStreamConnectionIfCurrent(sid, conn)
return
}
}
}
}()
}
func (s *Server) canSendData() bool {
for _, peer := range s.peers {
if !peer.CanSend() {