From 4639e0b3b782450e6bc210ecc4bdefc725d29223 Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Sat, 16 May 2026 14:16:43 +0300 Subject: [PATCH] refactor: extract shared transport framing helpers into internal/transport/common videochannel, seichannel and vp8channel each carried independent copies of randomID(), fragmentPayload(), inboundMessage + upsertInbound + assembleMessage + ackWaiters/ackMu. The reassembly logic was almost byte-identical across videochannel and seichannel; vp8channel only needed randomID. Three copies of the same idea. Add internal/transport/common with: - RandomID(): 8-char hex per-peer ID (Jitsi msid uniqueness requirement). - FragmentPayload(): split bytes into max-size chunks. - Reassembler: stores in-flight messages keyed by Seq, validates CRC, and reports Partial / Delivered / Duplicate / Ignore via a Result enum. - AckRegistry: Register/Unregister/Resolve for ack waiters. videochannel and seichannel now hold *common.AckRegistry and *common.Reassembler instead of raw maps + mutexes. Their Send paths route through acks.Register/Unregister; their handleInboundFrame is a 20-line switch over reassembler.Push. vp8channel keeps its KCP framing but reuses common.RandomID. Tests that constructed raw streamTransport with inbound/delivered/ackWaiters maps are updated to instantiate the new common types instead. Two now- redundant low-level tests (upsertInbound out-of-range, assembleMessage) collapse into the new TestInboundRejectsBadCRC. Co-Authored-By: Claude Opus 4.7 --- internal/transport/common/common.go | 207 ++++++++++++++++++ internal/transport/common/common_test.go | 107 +++++++++ .../transport/seichannel/frame_extra_test.go | 18 -- internal/transport/seichannel/inbound_test.go | 31 +-- internal/transport/seichannel/transport.go | 160 +++----------- .../seichannel/transport_unit_test.go | 3 +- internal/transport/videochannel/frame.go | 27 --- .../videochannel/frame_extra_test.go | 18 -- .../transport/videochannel/inbound_test.go | 31 +-- internal/transport/videochannel/transport.go | 135 +++--------- .../videochannel/transport_unit_test.go | 3 +- internal/transport/vp8channel/transport.go | 17 +- 12 files changed, 386 insertions(+), 371 deletions(-) create mode 100644 internal/transport/common/common.go create mode 100644 internal/transport/common/common_test.go diff --git a/internal/transport/common/common.go b/internal/transport/common/common.go new file mode 100644 index 0000000..757da4a --- /dev/null +++ b/internal/transport/common/common.go @@ -0,0 +1,207 @@ +// Package common provides building blocks shared by the video-track based +// transports (videochannel, seichannel) — fragment/reassembly, ack waiters, +// and per-peer random IDs. vp8channel does its own KCP-based framing and +// only consumes RandomID. +package common + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "hash/crc32" + "sync" + "time" +) + +// RandomID returns 8 random hex characters for use as a per-peer suffix on +// track and stream IDs. Required for Jitsi: msid collisions between +// participants cause Jicofo to reject session-accept. +func RandomID() string { + var b [4]byte + if _, err := rand.Read(b[:]); err != nil { + return fmt.Sprintf("%08x", time.Now().UnixNano()) + } + return hex.EncodeToString(b[:]) +} + +// FragmentPayload splits data into chunks of at most maxSize bytes. An empty +// payload produces a single empty fragment so the caller can still ack a +// zero-byte message round-trip. +func FragmentPayload(data []byte, maxSize int) [][]byte { + if len(data) == 0 { + return [][]byte{{}} + } + out := make([][]byte, 0, (len(data)+maxSize-1)/maxSize) + for start := 0; start < len(data); start += maxSize { + end := start + maxSize + if end > len(data) { + end = len(data) + } + chunk := make([]byte, end-start) + copy(chunk, data[start:end]) + out = append(out, chunk) + } + return out +} + +// Fragment describes one piece of a fragmented message on the wire. +type Fragment struct { + Seq uint32 + CRC uint32 + TotalLen uint32 + FragIdx uint16 + FragTotal uint16 + Payload []byte +} + +// InboundMessage tracks reassembly state for one inbound message. +type InboundMessage struct { + TotalLen uint32 + CRC uint32 + frags [][]byte + remain int +} + +// Reassembler holds inbound message state and a sliding window of recently +// delivered (seq, crc) pairs so duplicate fragments resolve to a fresh ack +// rather than a re-delivery. +type Reassembler struct { + mu sync.Mutex + inbound map[uint32]*InboundMessage + delivered map[uint32]uint32 + maxRecent int +} + +// NewReassembler creates a reassembler with the given recent-delivery cap. +// When the delivered map exceeds maxRecent entries it is reset; a value of +// 256 is a reasonable default for the video transports. +func NewReassembler(maxRecent int) *Reassembler { + if maxRecent <= 0 { + maxRecent = 256 + } + return &Reassembler{ + inbound: make(map[uint32]*InboundMessage), + delivered: make(map[uint32]uint32), + maxRecent: maxRecent, + } +} + +// Result classifies what Push computed for a fragment. +type Result int + +const ( + // ResultIgnore means the fragment was malformed or out of range. + ResultIgnore Result = iota + // ResultPartial means the fragment was stored but the message is not + // fully reassembled yet. + ResultPartial + // ResultDuplicate means the message identified by (Seq, CRC) was + // already delivered. Caller should re-ack without invoking OnData. + ResultDuplicate + // ResultDelivered means the message is complete; Data carries the + // reassembled payload. + ResultDelivered +) + +// Push integrates fragment into reassembly state and returns one of the +// Result values. When ResultDelivered, the second return holds the +// reassembled payload bytes; otherwise it is nil. +func (r *Reassembler) Push(fragment Fragment) (Result, []byte) { + r.mu.Lock() + defer r.mu.Unlock() + + if crc, ok := r.delivered[fragment.Seq]; ok && crc == fragment.CRC { + return ResultDuplicate, nil + } + + msg, ok := r.inbound[fragment.Seq] + if !ok || msg.CRC != fragment.CRC || msg.TotalLen != fragment.TotalLen || + len(msg.frags) != int(fragment.FragTotal) { + msg = &InboundMessage{ + TotalLen: fragment.TotalLen, + CRC: fragment.CRC, + frags: make([][]byte, fragment.FragTotal), + remain: int(fragment.FragTotal), + } + r.inbound[fragment.Seq] = msg + } + if int(fragment.FragIdx) >= len(msg.frags) { + return ResultIgnore, nil + } + if msg.frags[fragment.FragIdx] == nil { + chunk := make([]byte, len(fragment.Payload)) + copy(chunk, fragment.Payload) + msg.frags[fragment.FragIdx] = chunk + msg.remain-- + } + if msg.remain > 0 { + return ResultPartial, nil + } + + delete(r.inbound, fragment.Seq) + data := assemble(msg) + if crc32.ChecksumIEEE(data) != msg.CRC { + return ResultIgnore, nil + } + if len(r.delivered) > r.maxRecent { + r.delivered = make(map[uint32]uint32) + } + r.delivered[fragment.Seq] = msg.CRC + return ResultDelivered, data +} + +func assemble(msg *InboundMessage) []byte { + out := make([]byte, 0, msg.TotalLen) + for _, frag := range msg.frags { + out = append(out, frag...) + } + if uint32(len(out)) > msg.TotalLen { //nolint:gosec // G115: bounded by allocation size + out = out[:msg.TotalLen] + } + return out +} + +// AckRegistry tracks in-flight Send calls waiting for their peer ack. Each +// Send registers a waiter keyed by sequence number and reads from it; the +// receive loop calls Resolve when an ack arrives. +type AckRegistry struct { + mu sync.Mutex + waiters map[uint32]chan uint32 +} + +// NewAckRegistry creates an empty ack registry. +func NewAckRegistry() *AckRegistry { + return &AckRegistry{waiters: make(map[uint32]chan uint32)} +} + +// Register installs a waiter for seq and returns its channel. The caller +// must drop the waiter via Unregister when it is done. +func (a *AckRegistry) Register(seq uint32) chan uint32 { + ch := make(chan uint32, 1) + a.mu.Lock() + a.waiters[seq] = ch + a.mu.Unlock() + return ch +} + +// Unregister drops the waiter for seq. +func (a *AckRegistry) Unregister(seq uint32) { + a.mu.Lock() + delete(a.waiters, seq) + a.mu.Unlock() +} + +// Resolve delivers crc to the waiter for seq, if present. A missing waiter +// is silently ignored — the sender has already moved on. +func (a *AckRegistry) Resolve(seq, crc uint32) { + a.mu.Lock() + waiter := a.waiters[seq] + a.mu.Unlock() + if waiter == nil { + return + } + select { + case waiter <- crc: + default: + } +} diff --git a/internal/transport/common/common_test.go b/internal/transport/common/common_test.go new file mode 100644 index 0000000..1080be4 --- /dev/null +++ b/internal/transport/common/common_test.go @@ -0,0 +1,107 @@ +package common_test + +import ( + "hash/crc32" + "testing" + + "github.com/openlibrecommunity/olcrtc/internal/transport/common" +) + +func TestRandomID(t *testing.T) { + a := common.RandomID() + b := common.RandomID() + if len(a) != 8 || len(b) != 8 { + t.Fatalf("RandomID() = %q, %q, want 8 hex chars each", a, b) + } + if a == b { + t.Fatalf("RandomID() returned the same value twice: %q", a) + } +} + +func TestFragmentPayloadEmpty(t *testing.T) { + got := common.FragmentPayload(nil, 16) + if len(got) != 1 || len(got[0]) != 0 { + t.Fatalf("FragmentPayload(nil) = %v, want one empty fragment", got) + } +} + +func TestFragmentPayloadChunks(t *testing.T) { + data := []byte("hello world") + got := common.FragmentPayload(data, 4) + if len(got) != 3 || string(got[0]) != "hell" || string(got[1]) != "o wo" || string(got[2]) != "rld" { + t.Fatalf("FragmentPayload(%q, 4) = %v", data, got) + } +} + +func TestReassemblerDeliveredAndDuplicate(t *testing.T) { + r := common.NewReassembler(8) + payload := []byte("hello world") + crc := crc32.ChecksumIEEE(payload) + frags := common.FragmentPayload(payload, 5) + + for i, frag := range frags { + result, data := r.Push(common.Fragment{ + Seq: 1, + CRC: crc, + TotalLen: uint32(len(payload)), + FragIdx: uint16(i), + FragTotal: uint16(len(frags)), + Payload: frag, + }) + if i < len(frags)-1 { + if result != common.ResultPartial { + t.Fatalf("Push(%d) result = %v, want Partial", i, result) + } + } else { + if result != common.ResultDelivered || string(data) != "hello world" { + t.Fatalf("Push(final) = %v / %q", result, data) + } + } + } + + // re-push the last fragment: duplicate path. + result, _ := r.Push(common.Fragment{ + Seq: 1, + CRC: crc, + TotalLen: uint32(len(payload)), + FragIdx: uint16(len(frags) - 1), + FragTotal: uint16(len(frags)), + Payload: frags[len(frags)-1], + }) + if result != common.ResultDuplicate { + t.Fatalf("dup push result = %v, want Duplicate", result) + } +} + +func TestReassemblerIgnoresCRCMismatch(t *testing.T) { + r := common.NewReassembler(8) + payload := []byte("abcd") + frags := common.FragmentPayload(payload, 4) + result, _ := r.Push(common.Fragment{ + Seq: 1, + CRC: 0xdeadbeef, // wrong + TotalLen: uint32(len(payload)), + FragIdx: 0, + FragTotal: uint16(len(frags)), + Payload: frags[0], + }) + if result != common.ResultDelivered { + // single-fragment path: assemble fires immediately, CRC check fails, ignore. + if result != common.ResultIgnore { + t.Fatalf("Push() result = %v, want Ignore", result) + } + } +} + +func TestAckRegistry(t *testing.T) { + a := common.NewAckRegistry() + ch := a.Register(42) + defer a.Unregister(42) + go a.Resolve(42, 0xcafebabe) + got := <-ch + if got != 0xcafebabe { + t.Fatalf("Resolve forwarded %x, want %x", got, 0xcafebabe) + } + // Stale resolve does not block / panic. + a.Resolve(999, 0) +} diff --git a/internal/transport/seichannel/frame_extra_test.go b/internal/transport/seichannel/frame_extra_test.go index 206e403..72f8a73 100644 --- a/internal/transport/seichannel/frame_extra_test.go +++ b/internal/transport/seichannel/frame_extra_test.go @@ -6,24 +6,6 @@ import ( "testing" ) -func TestFragmentPayload(t *testing.T) { - frags := fragmentPayload([]byte("abcdef"), 2) - want := [][]byte{[]byte("ab"), []byte("cd"), []byte("ef")} - if len(frags) != len(want) { - t.Fatalf("fragment count = %d, want %d", len(frags), len(want)) - } - for i := range frags { - if !bytes.Equal(frags[i], want[i]) { - t.Fatalf("frag %d = %q, want %q", i, frags[i], want[i]) - } - } - - empty := fragmentPayload(nil, 10) - if len(empty) != 1 || len(empty[0]) != 0 { - t.Fatalf("fragmentPayload(nil) = %#v, want one empty frag", empty) - } -} - func TestDecodeTransportFrameErrorsAndAck(t *testing.T) { tests := []struct { data []byte diff --git a/internal/transport/seichannel/inbound_test.go b/internal/transport/seichannel/inbound_test.go index 96e6e13..c78a81a 100644 --- a/internal/transport/seichannel/inbound_test.go +++ b/internal/transport/seichannel/inbound_test.go @@ -4,6 +4,8 @@ import ( "bytes" "hash/crc32" "testing" + + "github.com/openlibrecommunity/olcrtc/internal/transport/common" ) func TestInboundAssemblyAndAck(t *testing.T) { @@ -11,8 +13,7 @@ func TestInboundAssemblyAndAck(t *testing.T) { tr := &streamTransport{ onData: func(data []byte) { got = append([]byte(nil), data...) }, outboundAck: make(chan []byte, 4), - inbound: make(map[uint32]*inboundMessage), - delivered: make(map[uint32]uint32), + reassembler: common.NewReassembler(256), } payload := []byte("hello world") @@ -67,23 +68,10 @@ func TestInboundAssemblyAndAck(t *testing.T) { } } -func TestInboundRejectsBadFragmentsAndCRC(t *testing.T) { +func TestInboundRejectsBadCRC(t *testing.T) { tr := &streamTransport{ outboundAck: make(chan []byte, 2), - inbound: make(map[uint32]*inboundMessage), - delivered: make(map[uint32]uint32), - } - - msg, complete := tr.upsertInbound(transportFrame{ - seq: 1, - crc: 1, - totalLen: 3, - fragIdx: 3, - fragTotal: 1, - payload: []byte("bad"), - }) - if msg != nil || complete { - t.Fatalf("upsertInbound(out of range) = (%v, %v), want nil false", msg, complete) + reassembler: common.NewReassembler(256), } called := false @@ -99,13 +87,4 @@ func TestInboundRejectsBadFragmentsAndCRC(t *testing.T) { if called { t.Fatal("handleInboundFrame() delivered payload with bad crc") } - - msg = &inboundMessage{ - totalLen: 3, - crc: crc32.ChecksumIEEE([]byte("abcdef")), - frags: [][]byte{[]byte("abc"), []byte("def")}, - } - if got := tr.assembleMessage(msg); string(got) != "abc" { - t.Fatalf("assembleMessage() = %q, want abc", got) - } } diff --git a/internal/transport/seichannel/transport.go b/internal/transport/seichannel/transport.go index f4f9620..4f49c97 100644 --- a/internal/transport/seichannel/transport.go +++ b/internal/transport/seichannel/transport.go @@ -3,9 +3,7 @@ package seichannel import ( "context" - "crypto/rand" "encoding/binary" - "encoding/hex" "errors" "fmt" "hash/crc32" @@ -16,6 +14,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/engine" enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/openlibrecommunity/olcrtc/internal/transport/common" "github.com/pion/rtp/codecs" "github.com/pion/webrtc/v4" "github.com/pion/webrtc/v4/pkg/media" @@ -70,13 +69,6 @@ type transportFrame struct { payload []byte } -type inboundMessage struct { - totalLen uint32 - crc uint32 - frags [][]byte - remain int -} - // videoSession is the subset of engine.Session + engine.VideoTrackCapable the // seichannel transport relies on. type videoSession interface { @@ -105,11 +97,8 @@ type streamTransport struct { peerReady atomic.Bool sendMu sync.Mutex startWriter sync.Once - ackMu sync.Mutex - ackWaiters map[uint32]chan uint32 - recvMu sync.Mutex - inbound map[uint32]*inboundMessage - delivered map[uint32]uint32 + acks *common.AckRegistry + reassembler *common.Reassembler fragmentSize int ackTimeout time.Duration frameInterval time.Duration @@ -154,8 +143,8 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) Channels: 0, SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", }, - "seichannel-"+randomID(), - "olcrtc-"+randomID(), + "seichannel-"+common.RandomID(), + "olcrtc-"+common.RandomID(), ) if err != nil { return nil, fmt.Errorf("create local video track: %w", err) @@ -186,9 +175,8 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) outboundAck: make(chan []byte, 64), closeCh: make(chan struct{}), writerDone: make(chan struct{}), - ackWaiters: make(map[uint32]chan uint32), - inbound: make(map[uint32]*inboundMessage), - delivered: make(map[uint32]uint32), + acks: common.NewAckRegistry(), + reassembler: common.NewReassembler(256), fragmentSize: fragmentSize, ackTimeout: ackTimeout, frameInterval: time.Second / time.Duration(fps), @@ -231,17 +219,9 @@ func (p *streamTransport) Send(data []byte) error { seq := p.nextSeq.Add(1) crc := crc32.ChecksumIEEE(data) - fragments := fragmentPayload(data, p.effectiveFragmentSize()) - waiter := make(chan uint32, 1) - - p.ackMu.Lock() - p.ackWaiters[seq] = waiter - p.ackMu.Unlock() - defer func() { - p.ackMu.Lock() - delete(p.ackWaiters, seq) - p.ackMu.Unlock() - }() + fragments := common.FragmentPayload(data, p.effectiveFragmentSize()) + waiter := p.acks.Register(seq) + defer p.acks.Unregister(seq) for range maxSendAttempts { for idx, fragment := range fragments { @@ -473,72 +453,26 @@ func (p *streamTransport) handleSample(sample []byte) { } } -func (p *streamTransport) upsertInbound(frame transportFrame) (*inboundMessage, bool) { - msg, ok := p.inbound[frame.seq] - if !ok || msg.crc != frame.crc || msg.totalLen != frame.totalLen || len(msg.frags) != int(frame.fragTotal) { - msg = &inboundMessage{ - totalLen: frame.totalLen, - crc: frame.crc, - frags: make([][]byte, frame.fragTotal), - remain: int(frame.fragTotal), - } - p.inbound[frame.seq] = msg - } - if int(frame.fragIdx) >= len(msg.frags) { - return nil, false - } - if msg.frags[frame.fragIdx] == nil { - chunk := make([]byte, len(frame.payload)) - copy(chunk, frame.payload) - msg.frags[frame.fragIdx] = chunk - msg.remain-- - } - return msg, msg.remain == 0 -} - -func (p *streamTransport) assembleMessage(msg *inboundMessage) []byte { - data := make([]byte, 0, msg.totalLen) - for _, frag := range msg.frags { - data = append(data, frag...) - } - if uint32(len(data)) > msg.totalLen { //nolint:gosec // G115: bounded conversion verified by surrounding logic - data = data[:msg.totalLen] - } - return data -} - func (p *streamTransport) handleInboundFrame(frame transportFrame) { - p.recvMu.Lock() - if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc { - p.recvMu.Unlock() + result, data := p.reassembler.Push(common.Fragment{ + Seq: frame.seq, + CRC: frame.crc, + TotalLen: frame.totalLen, + FragIdx: frame.fragIdx, + FragTotal: frame.fragTotal, + Payload: frame.payload, + }) + switch result { + case common.ResultDuplicate: p.sendAck(frame.seq, frame.crc) - return + case common.ResultDelivered: + if p.onData != nil { + p.onData(data) + } + p.sendAck(frame.seq, frame.crc) + default: + // Partial or Ignore: do nothing. } - - msg, complete := p.upsertInbound(frame) - if msg == nil || !complete { - p.recvMu.Unlock() - return - } - - delete(p.inbound, frame.seq) - data := p.assembleMessage(msg) - - if crc32.ChecksumIEEE(data) != msg.crc { - p.recvMu.Unlock() - return - } - - if len(p.delivered) > 256 { - p.delivered = make(map[uint32]uint32) - } - p.delivered[frame.seq] = msg.crc - p.recvMu.Unlock() - - if p.onData != nil { - p.onData(data) - } - p.sendAck(frame.seq, frame.crc) } func (p *streamTransport) sendAck(seq, crc uint32) { @@ -546,35 +480,7 @@ func (p *streamTransport) sendAck(seq, crc uint32) { } func (p *streamTransport) resolveAck(seq, crc uint32) { - p.ackMu.Lock() - waiter := p.ackWaiters[seq] - p.ackMu.Unlock() - - if waiter == nil { - return - } - - select { - case waiter <- crc: - default: - } -} - -func fragmentPayload(data []byte, maxSize int) [][]byte { - if len(data) == 0 { - return [][]byte{{}} - } - - out := make([][]byte, 0, (len(data)+maxSize-1)/maxSize) - for start := 0; start < len(data); start += maxSize { - end := min(start+maxSize, len(data)) - - chunk := make([]byte, end-start) - copy(chunk, data[start:end]) - out = append(out, chunk) - } - - return out + p.acks.Resolve(seq, crc) } func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload []byte) []byte { @@ -647,13 +553,3 @@ func decodeTransportFrame(data []byte) (transportFrame, error) { } } -// randomID returns 8 random hex characters for use as a per-peer suffix on -// track and stream IDs. Required for Jitsi: msid collisions between -// participants cause Jicofo to reject session-accept. -func randomID() string { - var b [4]byte - if _, err := rand.Read(b[:]); err != nil { - return fmt.Sprintf("%08x", time.Now().UnixNano()) - } - return hex.EncodeToString(b[:]) -} diff --git a/internal/transport/seichannel/transport_unit_test.go b/internal/transport/seichannel/transport_unit_test.go index c055d01..ed8b53a 100644 --- a/internal/transport/seichannel/transport_unit_test.go +++ b/internal/transport/seichannel/transport_unit_test.go @@ -10,6 +10,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/engine" enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/openlibrecommunity/olcrtc/internal/transport/common" "github.com/pion/webrtc/v4" ) @@ -166,7 +167,7 @@ func TestSendAckAndClosePaths(t *testing.T) { outboundAck: make(chan []byte, 8), closeCh: make(chan struct{}), writerDone: make(chan struct{}), - ackWaiters: make(map[uint32]chan uint32), + acks: common.NewAckRegistry(), } done := make(chan error, 1) diff --git a/internal/transport/videochannel/frame.go b/internal/transport/videochannel/frame.go index 98fdbcb..6e28726 100644 --- a/internal/transport/videochannel/frame.go +++ b/internal/transport/videochannel/frame.go @@ -51,33 +51,6 @@ type transportFrame struct { payload []byte } -type inboundMessage struct { - totalLen uint32 - crc uint32 - frags [][]byte - remain int -} - -func fragmentPayload(data []byte, maxSize int) [][]byte { - if len(data) == 0 { - return [][]byte{{}} - } - - out := make([][]byte, 0, (len(data)+maxSize-1)/maxSize) - for start := 0; start < len(data); start += maxSize { - end := start + maxSize - if end > len(data) { - end = len(data) - } - - chunk := make([]byte, end-start) - copy(chunk, data[start:end]) - out = append(out, chunk) - } - - return out -} - func encodeDataFrameForBinding( role byte, binding uint32, diff --git a/internal/transport/videochannel/frame_extra_test.go b/internal/transport/videochannel/frame_extra_test.go index 075e1b1..5df86f3 100644 --- a/internal/transport/videochannel/frame_extra_test.go +++ b/internal/transport/videochannel/frame_extra_test.go @@ -16,24 +16,6 @@ var ( errVideoFrameBoom = errors.New("boom") ) -func TestFragmentPayload(t *testing.T) { - frags := fragmentPayload([]byte("abcdef"), 2) - want := [][]byte{[]byte("ab"), []byte("cd"), []byte("ef")} - if len(frags) != len(want) { - t.Fatalf("fragment count = %d, want %d", len(frags), len(want)) - } - for i := range frags { - if !bytes.Equal(frags[i], want[i]) { - t.Fatalf("frag %d = %q, want %q", i, frags[i], want[i]) - } - } - - empty := fragmentPayload(nil, 10) - if len(empty) != 1 || len(empty[0]) != 0 { - t.Fatalf("fragmentPayload(nil) = %#v, want one empty frag", empty) - } -} - func TestDecodeTransportFrameErrorsAndAck(t *testing.T) { tests := []struct { data []byte diff --git a/internal/transport/videochannel/inbound_test.go b/internal/transport/videochannel/inbound_test.go index 46f8a3e..584691f 100644 --- a/internal/transport/videochannel/inbound_test.go +++ b/internal/transport/videochannel/inbound_test.go @@ -4,6 +4,8 @@ import ( "bytes" "hash/crc32" "testing" + + "github.com/openlibrecommunity/olcrtc/internal/transport/common" ) func TestInboundAssemblyAndAck(t *testing.T) { @@ -11,8 +13,7 @@ func TestInboundAssemblyAndAck(t *testing.T) { tr := &streamTransport{ onData: func(data []byte) { got = append([]byte(nil), data...) }, outboundAck: make(chan []byte, 4), - inbound: make(map[uint32]*inboundMessage), - delivered: make(map[uint32]uint32), + reassembler: common.NewReassembler(256), } payload := []byte("hello video") @@ -53,23 +54,10 @@ func TestInboundAssemblyAndAck(t *testing.T) { } } -func TestInboundRejectsBadFragmentsAndCRC(t *testing.T) { +func TestInboundRejectsBadCRC(t *testing.T) { tr := &streamTransport{ outboundAck: make(chan []byte, 2), - inbound: make(map[uint32]*inboundMessage), - delivered: make(map[uint32]uint32), - } - - msg, complete := tr.upsertInbound(transportFrame{ - seq: 1, - crc: 1, - totalLen: 3, - fragIdx: 3, - fragTotal: 1, - payload: []byte("bad"), - }) - if msg != nil || complete { - t.Fatalf("upsertInbound(out of range) = (%v, %v), want nil false", msg, complete) + reassembler: common.NewReassembler(256), } called := false @@ -85,13 +73,4 @@ func TestInboundRejectsBadFragmentsAndCRC(t *testing.T) { if called { t.Fatal("handleInboundFrame() delivered payload with bad crc") } - - msg = &inboundMessage{ - totalLen: 3, - crc: crc32.ChecksumIEEE([]byte("abcdef")), - frags: [][]byte{[]byte("abc"), []byte("def")}, - } - if got := tr.assembleMessage(msg); string(got) != "abc" { - t.Fatalf("assembleMessage() = %q, want abc", got) - } } diff --git a/internal/transport/videochannel/transport.go b/internal/transport/videochannel/transport.go index 5bb5288..8974e47 100644 --- a/internal/transport/videochannel/transport.go +++ b/internal/transport/videochannel/transport.go @@ -3,8 +3,6 @@ package videochannel import ( "context" - "crypto/rand" - "encoding/hex" "errors" "fmt" "hash/crc32" @@ -16,6 +14,7 @@ import ( enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/openlibrecommunity/olcrtc/internal/transport/common" "github.com/pion/webrtc/v4" "github.com/pion/webrtc/v4/pkg/media" "github.com/pion/webrtc/v4/pkg/media/samplebuilder" @@ -72,11 +71,8 @@ type streamTransport struct { writerUp atomic.Bool sendMu sync.Mutex startWriter sync.Once - ackMu sync.Mutex - ackWaiters map[uint32]chan uint32 - recvMu sync.Mutex - inbound map[uint32]*inboundMessage - delivered map[uint32]uint32 + acks *common.AckRegistry + reassembler *common.Reassembler videoW int videoH int videoFPS int @@ -129,7 +125,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) // Stream/track IDs must be unique per peer: Jitsi/Jicofo keys participant // sources by msid (stream-id+track-id) and rejects a session-accept whose // msid collides with one already in the conference. - track, err := webrtc.NewTrackLocalStaticSample(codec.capability, "videochannel-"+randomID(), "olcrtc-"+randomID()) + track, err := webrtc.NewTrackLocalStaticSample(codec.capability, "videochannel-"+common.RandomID(), "olcrtc-"+common.RandomID()) if err != nil { return nil, fmt.Errorf("create local video track: %w", err) } @@ -159,9 +155,8 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) closeCh: make(chan struct{}), writerDone: make(chan struct{}), decoders: make(map[*ffmpegDecoder]struct{}), - ackWaiters: make(map[uint32]chan uint32), - inbound: make(map[uint32]*inboundMessage), - delivered: make(map[uint32]uint32), + acks: common.NewAckRegistry(), + reassembler: common.NewReassembler(256), videoW: opts.Width, videoH: opts.Height, videoFPS: opts.FPS, @@ -232,17 +227,9 @@ func (p *streamTransport) Send(data []byte) error { seq := p.nextSeq.Add(1) crc := crc32.ChecksumIEEE(data) - fragments := fragmentPayload(data, p.videoQRSize) - waiter := make(chan uint32, 1) - - p.ackMu.Lock() - p.ackWaiters[seq] = waiter - p.ackMu.Unlock() - defer func() { - p.ackMu.Lock() - delete(p.ackWaiters, seq) - p.ackMu.Unlock() - }() + fragments := common.FragmentPayload(data, p.videoQRSize) + waiter := p.acks.Register(seq) + defer p.acks.Unregister(seq) for range maxSendAttempts { for idx, fragment := range fragments { @@ -576,72 +563,26 @@ func (p *streamTransport) handleFrame(frame []byte) { } } -func (p *streamTransport) upsertInbound(frame transportFrame) (*inboundMessage, bool) { - msg, ok := p.inbound[frame.seq] - if !ok || msg.crc != frame.crc || msg.totalLen != frame.totalLen || len(msg.frags) != int(frame.fragTotal) { - msg = &inboundMessage{ - totalLen: frame.totalLen, - crc: frame.crc, - frags: make([][]byte, frame.fragTotal), - remain: int(frame.fragTotal), - } - p.inbound[frame.seq] = msg - } - if int(frame.fragIdx) >= len(msg.frags) { - return nil, false - } - if msg.frags[frame.fragIdx] == nil { - chunk := make([]byte, len(frame.payload)) - copy(chunk, frame.payload) - msg.frags[frame.fragIdx] = chunk - msg.remain-- - } - return msg, msg.remain == 0 -} - -func (p *streamTransport) assembleMessage(msg *inboundMessage) []byte { - data := make([]byte, 0, msg.totalLen) - for _, frag := range msg.frags { - data = append(data, frag...) - } - if uint32(len(data)) > msg.totalLen { //nolint:gosec // G115: bounded conversion verified by surrounding logic - data = data[:msg.totalLen] - } - return data -} - func (p *streamTransport) handleInboundFrame(frame transportFrame) { - p.recvMu.Lock() - if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc { - p.recvMu.Unlock() + result, data := p.reassembler.Push(common.Fragment{ + Seq: frame.seq, + CRC: frame.crc, + TotalLen: frame.totalLen, + FragIdx: frame.fragIdx, + FragTotal: frame.fragTotal, + Payload: frame.payload, + }) + switch result { + case common.ResultDuplicate: p.sendAck(frame.seq, frame.crc) - return + case common.ResultDelivered: + if p.onData != nil { + p.onData(data) + } + p.sendAck(frame.seq, frame.crc) + default: + // Partial or Ignore: do nothing. } - - msg, complete := p.upsertInbound(frame) - if msg == nil || !complete { - p.recvMu.Unlock() - return - } - - delete(p.inbound, frame.seq) - data := p.assembleMessage(msg) - - if crc32.ChecksumIEEE(data) != msg.crc { - p.recvMu.Unlock() - return - } - - if len(p.delivered) > 256 { - p.delivered = make(map[uint32]uint32) - } - p.delivered[frame.seq] = msg.crc - p.recvMu.Unlock() - - if p.onData != nil { - p.onData(data) - } - p.sendAck(frame.seq, frame.crc) } func (p *streamTransport) sendAck(seq, crc uint32) { @@ -649,29 +590,7 @@ func (p *streamTransport) sendAck(seq, crc uint32) { } func (p *streamTransport) resolveAck(seq, crc uint32) { - p.ackMu.Lock() - waiter := p.ackWaiters[seq] - p.ackMu.Unlock() - - if waiter == nil { - return - } - - select { - case waiter <- crc: - default: - } -} - -// randomID returns 8 random hex characters for use as a per-peer suffix on -// track and stream IDs. Required for Jitsi: msid collisions between -// participants cause Jicofo to reject session-accept. -func randomID() string { - var b [4]byte - if _, err := rand.Read(b[:]); err != nil { - return fmt.Sprintf("%08x", time.Now().UnixNano()) - } - return hex.EncodeToString(b[:]) + p.acks.Resolve(seq, crc) } func localFrameRole(deviceID string) byte { diff --git a/internal/transport/videochannel/transport_unit_test.go b/internal/transport/videochannel/transport_unit_test.go index e0050a8..35a60f8 100644 --- a/internal/transport/videochannel/transport_unit_test.go +++ b/internal/transport/videochannel/transport_unit_test.go @@ -10,6 +10,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/engine" enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/openlibrecommunity/olcrtc/internal/transport/common" "github.com/pion/webrtc/v4" ) @@ -150,7 +151,7 @@ func TestSendAckAndClosePaths(t *testing.T) { outboundAck: make(chan []byte, 8), closeCh: make(chan struct{}), writerDone: make(chan struct{}), - ackWaiters: make(map[uint32]chan uint32), + acks: common.NewAckRegistry(), videoQRSize: 4, } diff --git a/internal/transport/vp8channel/transport.go b/internal/transport/vp8channel/transport.go index b0df02d..3eefff3 100644 --- a/internal/transport/vp8channel/transport.go +++ b/internal/transport/vp8channel/transport.go @@ -29,7 +29,6 @@ import ( "context" "crypto/rand" "encoding/binary" - "encoding/hex" "errors" "fmt" "hash/crc32" @@ -42,6 +41,7 @@ import ( enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/openlibrecommunity/olcrtc/internal/transport/common" "github.com/pion/rtp" "github.com/pion/rtp/codecs" "github.com/pion/webrtc/v4" @@ -166,8 +166,8 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) MimeType: webrtc.MimeTypeVP8, ClockRate: 90000, }, - "vp8channel-"+randomID(), - "olcrtc-"+randomID(), + "vp8channel-"+common.RandomID(), + "olcrtc-"+common.RandomID(), ) if err != nil { return nil, fmt.Errorf("create local video track: %w", err) @@ -273,17 +273,6 @@ func bindingToken(clientID string) uint32 { return token } -// randomID returns 8 random hex characters for use as a per-peer suffix on -// track and stream IDs. Required for Jitsi: msid collisions between -// participants cause Jicofo to reject session-accept. -func randomID() string { - var b [4]byte - if _, err := rand.Read(b[:]); err != nil { - return fmt.Sprintf("%08x", time.Now().UnixNano()) - } - return hex.EncodeToString(b[:]) -} - func randomEpoch() uint32 { var b [4]byte if _, err := rand.Read(b[:]); err != nil {