Files
supertonic/go/helper.go
T

953 lines
24 KiB
Go

package main
import (
"encoding/json"
"fmt"
"math"
"math/rand"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/go-audio/audio"
"github.com/go-audio/wav"
ort "github.com/yalue/onnxruntime_go"
)
// Config structures
type SpecProcessorConfig struct {
NFFT int `json:"n_fft"`
WinLength int `json:"win_length"`
HopLength int `json:"hop_length"`
NMels int `json:"n_mels"`
Eps float64 `json:"eps"`
NormMean float64 `json:"norm_mean"`
NormStd float64 `json:"norm_std"`
}
type EncoderConfig struct {
SpecProcessor SpecProcessorConfig `json:"spec_processor"`
}
type AEConfig struct {
SampleRate int `json:"sample_rate"`
BaseChunkSize int `json:"base_chunk_size"`
Encoder EncoderConfig `json:"encoder"`
}
type StyleTokenLayerConfig struct {
NStyle int `json:"n_style"`
StyleValueDim int `json:"style_value_dim"`
}
type StyleEncoderConfig struct {
StyleTokenLayer StyleTokenLayerConfig `json:"style_token_layer"`
}
type ProjOutConfig struct {
Idim int `json:"idim"`
Odim int `json:"odim"`
}
type TextEncoderConfig struct {
ProjOut ProjOutConfig `json:"proj_out"`
}
type TTLConfig struct {
ChunkCompressFactor int `json:"chunk_compress_factor"`
LatentDim int `json:"latent_dim"`
StyleEncoder StyleEncoderConfig `json:"style_encoder"`
TextEncoder TextEncoderConfig `json:"text_encoder"`
}
type DPStyleEncoderConfig struct {
StyleTokenLayer StyleTokenLayerConfig `json:"style_token_layer"`
}
type DPConfig struct {
LatentDim int `json:"latent_dim"`
ChunkCompressFactor int `json:"chunk_compress_factor"`
StyleEncoder DPStyleEncoderConfig `json:"style_encoder"`
}
type Config struct {
AE AEConfig `json:"ae"`
TTL TTLConfig `json:"ttl"`
DP DPConfig `json:"dp"`
}
// VoiceStyleData holds voice style JSON structure
type VoiceStyleData struct {
StyleTTL struct {
Data [][][]float64 `json:"data"`
Dims []int64 `json:"dims"`
Type string `json:"type"`
} `json:"style_ttl"`
StyleDP struct {
Data [][][]float64 `json:"data"`
Dims []int64 `json:"dims"`
Type string `json:"type"`
} `json:"style_dp"`
}
// UnicodeProcessor for text processing
type UnicodeProcessor struct {
indexer []int64
}
// NewUnicodeProcessor creates a new UnicodeProcessor
func NewUnicodeProcessor(unicodeIndexerPath string) (*UnicodeProcessor, error) {
indexer, err := loadJSONInt64(unicodeIndexerPath)
if err != nil {
return nil, fmt.Errorf("failed to load unicode indexer: %w", err)
}
return &UnicodeProcessor{indexer: indexer}, nil
}
// Call processes text list to text IDs and mask
func (up *UnicodeProcessor) Call(textList []string) ([][]int64, [][][]float64) {
// Preprocess texts
processedTexts := make([]string, len(textList))
for i, text := range textList {
processedTexts[i] = preprocessText(text)
}
// Get text lengths
textLengths := make([]int64, len(processedTexts))
maxLen := 0
for i, text := range processedTexts {
textLengths[i] = int64(len([]rune(text)))
if int(textLengths[i]) > maxLen {
maxLen = int(textLengths[i])
}
}
// Create text IDs
textIDs := make([][]int64, len(processedTexts))
for i, text := range processedTexts {
row := make([]int64, maxLen)
runes := []rune(text)
for j, r := range runes {
unicodeVal := int(r)
if unicodeVal < len(up.indexer) {
row[j] = up.indexer[unicodeVal]
} else {
row[j] = -1
}
}
textIDs[i] = row
}
// Create text mask
textMask := lengthToMask(textLengths, maxLen)
return textIDs, textMask
}
// Text chunking utilities
const maxChunkLength = 300
var abbreviations = []string{
"Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
"St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
"Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D.",
}
func chunkText(text string, maxLen int) []string {
if maxLen == 0 {
maxLen = maxChunkLength
}
text = strings.TrimSpace(text)
if text == "" {
return []string{""}
}
// Split by paragraphs
paragraphs := regexp.MustCompile(`\n\s*\n`).Split(text, -1)
var chunks []string
for _, para := range paragraphs {
para = strings.TrimSpace(para)
if para == "" {
continue
}
if len(para) <= maxLen {
chunks = append(chunks, para)
continue
}
// Split by sentences
sentences := splitSentences(para)
var current strings.Builder
currentLen := 0
for _, sentence := range sentences {
sentence = strings.TrimSpace(sentence)
if sentence == "" {
continue
}
sentenceLen := len(sentence)
if sentenceLen > maxLen {
// If sentence is longer than maxLen, split by comma or space
if current.Len() > 0 {
chunks = append(chunks, strings.TrimSpace(current.String()))
current.Reset()
currentLen = 0
}
// Try splitting by comma
parts := strings.Split(sentence, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
partLen := len(part)
if partLen > maxLen {
// Split by space as last resort
words := strings.Fields(part)
var wordChunk strings.Builder
wordChunkLen := 0
for _, word := range words {
wordLen := len(word)
if wordChunkLen+wordLen+1 > maxLen && wordChunk.Len() > 0 {
chunks = append(chunks, strings.TrimSpace(wordChunk.String()))
wordChunk.Reset()
wordChunkLen = 0
}
if wordChunk.Len() > 0 {
wordChunk.WriteString(" ")
wordChunkLen++
}
wordChunk.WriteString(word)
wordChunkLen += wordLen
}
if wordChunk.Len() > 0 {
chunks = append(chunks, strings.TrimSpace(wordChunk.String()))
}
} else {
if currentLen+partLen+1 > maxLen && current.Len() > 0 {
chunks = append(chunks, strings.TrimSpace(current.String()))
current.Reset()
currentLen = 0
}
if current.Len() > 0 {
current.WriteString(", ")
currentLen += 2
}
current.WriteString(part)
currentLen += partLen
}
}
continue
}
if currentLen+sentenceLen+1 > maxLen && current.Len() > 0 {
chunks = append(chunks, strings.TrimSpace(current.String()))
current.Reset()
currentLen = 0
}
if current.Len() > 0 {
current.WriteString(" ")
currentLen++
}
current.WriteString(sentence)
currentLen += sentenceLen
}
if current.Len() > 0 {
chunks = append(chunks, strings.TrimSpace(current.String()))
}
}
if len(chunks) == 0 {
return []string{""}
}
return chunks
}
func splitSentences(text string) []string {
// Go's regexp doesn't support lookbehind, so we use a simpler approach
// Split on sentence boundaries and then check if they're abbreviations
re := regexp.MustCompile(`([.!?])\s+`)
// Find all matches
matches := re.FindAllStringIndex(text, -1)
if len(matches) == 0 {
return []string{text}
}
var sentences []string
lastEnd := 0
for _, match := range matches {
// Get the text before the punctuation
beforePunc := text[lastEnd:match[0]]
// Check if this ends with an abbreviation
isAbbrev := false
for _, abbrev := range abbreviations {
if strings.HasSuffix(strings.TrimSpace(beforePunc+text[match[0]:match[0]+1]), abbrev) {
isAbbrev = true
break
}
}
if !isAbbrev {
// This is a real sentence boundary
sentences = append(sentences, text[lastEnd:match[1]])
lastEnd = match[1]
}
}
// Add the remaining text
if lastEnd < len(text) {
sentences = append(sentences, text[lastEnd:])
}
if len(sentences) == 0 {
return []string{text}
}
return sentences
}
// Utility functions
func preprocessText(text string) string {
// Simple normalization (Go doesn't have built-in NFKD normalization)
// For full Unicode normalization, use golang.org/x/text/unicode/norm
return text
}
func lengthToMask(lengths []int64, maxLen int) [][][]float64 {
bsz := len(lengths)
mask := make([][][]float64, bsz)
for i := 0; i < bsz; i++ {
row := make([]float64, maxLen)
for j := 0; j < maxLen; j++ {
if int64(j) < lengths[i] {
row[j] = 1.0
} else {
row[j] = 0.0
}
}
mask[i] = [][]float64{row}
}
return mask
}
func getTextMask(textLengths []int64, maxLen int) [][][]float64 {
return lengthToMask(textLengths, maxLen)
}
func getLatentMask(wavLengths []int64, cfg Config) [][][]float64 {
baseChunkSize := int64(cfg.AE.BaseChunkSize)
chunkCompressFactor := int64(cfg.TTL.ChunkCompressFactor)
latentSize := baseChunkSize * chunkCompressFactor
latentLengths := make([]int64, len(wavLengths))
maxLen := int64(0)
for i, wavLen := range wavLengths {
latentLengths[i] = (wavLen + latentSize - 1) / latentSize
if latentLengths[i] > maxLen {
maxLen = latentLengths[i]
}
}
return lengthToMask(latentLengths, int(maxLen))
}
func writeWavFile(filename string, audioData []float64, sampleRate int) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
// Convert float64 to int
intData := make([]int, len(audioData))
for i, sample := range audioData {
// Clamp to [-1, 1] and convert to 16-bit int
clamped := math.Max(-1.0, math.Min(1.0, sample))
intData[i] = int(clamped * 32767)
}
encoder := wav.NewEncoder(file, sampleRate, 16, 1, 1)
buf := &audio.IntBuffer{
Data: intData,
Format: &audio.Format{SampleRate: sampleRate, NumChannels: 1},
SourceBitDepth: 16,
}
if err := encoder.Write(buf); err != nil {
return err
}
return encoder.Close()
}
// Style holds style tensors
type Style struct {
TtlTensor *ort.Tensor[float32]
DpTensor *ort.Tensor[float32]
}
func (s *Style) Destroy() {
if s.TtlTensor != nil {
s.TtlTensor.Destroy()
}
if s.DpTensor != nil {
s.DpTensor.Destroy()
}
}
// LoadVoiceStyle loads voice style from JSON files
func LoadVoiceStyle(voiceStylePaths []string, verbose bool) (*Style, error) {
bsz := len(voiceStylePaths)
// Read first file to get dimensions
firstData, err := os.ReadFile(voiceStylePaths[0])
if err != nil {
return nil, fmt.Errorf("failed to read voice style file: %w", err)
}
var firstStyle VoiceStyleData
if err := json.Unmarshal(firstData, &firstStyle); err != nil {
return nil, fmt.Errorf("failed to parse voice style JSON: %w", err)
}
ttlDims := firstStyle.StyleTTL.Dims
dpDims := firstStyle.StyleDP.Dims
ttlDim1 := ttlDims[1]
ttlDim2 := ttlDims[2]
dpDim1 := dpDims[1]
dpDim2 := dpDims[2]
// Pre-allocate arrays with full batch size
ttlSize := int(int64(bsz) * ttlDim1 * ttlDim2)
dpSize := int(int64(bsz) * dpDim1 * dpDim2)
ttlFlat := make([]float32, ttlSize)
dpFlat := make([]float32, dpSize)
// Fill in the data
for i := 0; i < bsz; i++ {
data, err := os.ReadFile(voiceStylePaths[i])
if err != nil {
return nil, fmt.Errorf("failed to read voice style file: %w", err)
}
var voiceStyle VoiceStyleData
if err := json.Unmarshal(data, &voiceStyle); err != nil {
return nil, fmt.Errorf("failed to parse voice style JSON: %w", err)
}
// Flatten TTL data
ttlOffset := int(int64(i) * ttlDim1 * ttlDim2)
idx := 0
for _, batch := range voiceStyle.StyleTTL.Data {
for _, row := range batch {
for _, val := range row {
ttlFlat[ttlOffset+idx] = float32(val)
idx++
}
}
}
// Flatten DP data
dpOffset := int(int64(i) * dpDim1 * dpDim2)
idx = 0
for _, batch := range voiceStyle.StyleDP.Data {
for _, row := range batch {
for _, val := range row {
dpFlat[dpOffset+idx] = float32(val)
idx++
}
}
}
}
ttlShape := []int64{int64(bsz), ttlDim1, ttlDim2}
dpShape := []int64{int64(bsz), dpDim1, dpDim2}
ttlTensor, err := ort.NewTensor(ttlShape, ttlFlat)
if err != nil {
return nil, fmt.Errorf("failed to create TTL tensor: %w", err)
}
dpTensor, err := ort.NewTensor(dpShape, dpFlat)
if err != nil {
ttlTensor.Destroy()
return nil, fmt.Errorf("failed to create DP tensor: %w", err)
}
if verbose {
fmt.Printf("Loaded %d voice styles\n\n", bsz)
}
return &Style{
TtlTensor: ttlTensor,
DpTensor: dpTensor,
}, nil
}
// TextToSpeech generates speech from text
type TextToSpeech struct {
cfg Config
textProcessor *UnicodeProcessor
dpOrt *ort.DynamicAdvancedSession
textEncOrt *ort.DynamicAdvancedSession
vectorEstOrt *ort.DynamicAdvancedSession
vocoderOrt *ort.DynamicAdvancedSession
SampleRate int
baseChunkSize int
chunkCompress int
ldim int
}
func (tts *TextToSpeech) sampleNoisyLatent(durOnnx []float32) ([][][]float64, [][][]float64) {
bsz := len(durOnnx)
maxDur := float64(0)
for _, d := range durOnnx {
if float64(d) > maxDur {
maxDur = float64(d)
}
}
wavLenMax := maxDur * float64(tts.SampleRate)
wavLengths := make([]int64, bsz)
for i, d := range durOnnx {
wavLengths[i] = int64(float64(d) * float64(tts.SampleRate))
}
chunkSize := tts.baseChunkSize * tts.chunkCompress
latentLen := int((wavLenMax + float64(chunkSize) - 1) / float64(chunkSize))
latentDim := tts.ldim * tts.chunkCompress
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
noisyLatent := make([][][]float64, bsz)
for b := 0; b < bsz; b++ {
batch := make([][]float64, latentDim)
for d := 0; d < latentDim; d++ {
row := make([]float64, latentLen)
for t := 0; t < latentLen; t++ {
// Box-Muller transform for normal distribution
// Add epsilon to avoid log(0)
const eps = 1e-10
u1 := math.Max(eps, rng.Float64())
u2 := rng.Float64()
row[t] = math.Sqrt(-2.0*math.Log(u1)) * math.Cos(2.0*math.Pi*u2)
}
batch[d] = row
}
noisyLatent[b] = batch
}
latentMask := getLatentMask(wavLengths, tts.cfg)
// Apply mask
for b := 0; b < bsz; b++ {
for d := 0; d < latentDim; d++ {
for t := 0; t < latentLen; t++ {
noisyLatent[b][d][t] *= latentMask[b][0][t]
}
}
}
return noisyLatent, latentMask
}
func (tts *TextToSpeech) _infer(textList []string, style *Style, totalStep int) ([]float32, []float32, error) {
bsz := len(textList)
// Process text
textIDs, textMask := tts.textProcessor.Call(textList)
textIDsShape := []int64{int64(bsz), int64(len(textIDs[0]))}
textMaskShape := []int64{int64(bsz), 1, int64(len(textMask[0][0]))}
textIDsTensor := IntArrayToTensor(textIDs, textIDsShape)
defer textIDsTensor.Destroy()
textMaskTensor := ArrayToTensor(textMask, textMaskShape)
defer textMaskTensor.Destroy()
// Predict duration
dpOutputs := []ort.Value{nil}
err := tts.dpOrt.Run(
[]ort.Value{textIDsTensor, style.DpTensor, textMaskTensor},
dpOutputs,
)
if err != nil {
return nil, nil, fmt.Errorf("failed to run duration predictor: %w", err)
}
durTensor := dpOutputs[0].(*ort.Tensor[float32])
defer durTensor.Destroy()
durOnnx := durTensor.GetData()
// Encode text
textIDsTensor2 := IntArrayToTensor(textIDs, textIDsShape)
defer textIDsTensor2.Destroy()
textEncOutputs := []ort.Value{nil}
err = tts.textEncOrt.Run(
[]ort.Value{textIDsTensor2, style.TtlTensor, textMaskTensor},
textEncOutputs,
)
if err != nil {
return nil, nil, fmt.Errorf("failed to run text encoder: %w", err)
}
textEmbTensor := textEncOutputs[0].(*ort.Tensor[float32])
defer textEmbTensor.Destroy()
// Sample noisy latent
xt, latentMask := tts.sampleNoisyLatent(durOnnx)
latentShape := []int64{int64(bsz), int64(len(xt[0])), int64(len(xt[0][0]))}
latentMaskShape := []int64{int64(bsz), 1, int64(len(latentMask[0][0]))}
// Prepare constant arrays
totalStepArray := make([]float32, bsz)
for b := 0; b < bsz; b++ {
totalStepArray[b] = float32(totalStep)
}
scalarShape := []int64{int64(bsz)}
totalStepTensor, _ := ort.NewTensor(scalarShape, totalStepArray)
defer totalStepTensor.Destroy()
// Denoising loop
for step := 0; step < totalStep; step++ {
currentStepArray := make([]float32, bsz)
for b := 0; b < bsz; b++ {
currentStepArray[b] = float32(step)
}
currentStepTensor, _ := ort.NewTensor(scalarShape, currentStepArray)
noisyLatentTensor := ArrayToTensor(xt, latentShape)
latentMaskTensor := ArrayToTensor(latentMask, latentMaskShape)
textMaskTensor2 := ArrayToTensor(textMask, textMaskShape)
vectorEstOutputs := []ort.Value{nil}
err = tts.vectorEstOrt.Run(
[]ort.Value{noisyLatentTensor, textEmbTensor, style.TtlTensor, latentMaskTensor, textMaskTensor2,
currentStepTensor, totalStepTensor},
vectorEstOutputs,
)
if err != nil {
return nil, nil, fmt.Errorf("failed to run vector estimator: %w", err)
}
denoisedTensor := vectorEstOutputs[0].(*ort.Tensor[float32])
denoisedData := denoisedTensor.GetData()
// Update latent
idx := 0
for b := 0; b < bsz; b++ {
for d := 0; d < len(xt[b]); d++ {
for t := 0; t < len(xt[b][d]); t++ {
xt[b][d][t] = float64(denoisedData[idx])
idx++
}
}
}
noisyLatentTensor.Destroy()
latentMaskTensor.Destroy()
textMaskTensor2.Destroy()
currentStepTensor.Destroy()
denoisedTensor.Destroy()
}
// Generate waveform
finalLatentTensor := ArrayToTensor(xt, latentShape)
defer finalLatentTensor.Destroy()
vocoderOutputs := []ort.Value{nil}
err = tts.vocoderOrt.Run(
[]ort.Value{finalLatentTensor},
vocoderOutputs,
)
if err != nil {
return nil, nil, fmt.Errorf("failed to run vocoder: %w", err)
}
wavBatchTensor := vocoderOutputs[0].(*ort.Tensor[float32])
defer wavBatchTensor.Destroy()
wav := wavBatchTensor.GetData()
return wav, durOnnx, nil
}
// Call synthesizes speech from a single text with automatic chunking
func (tts *TextToSpeech) Call(text string, style *Style, totalStep int, silenceDuration float32) ([]float32, float32, error) {
chunks := chunkText(text, 0)
var wavCat []float32
var durCat float32
for i, chunk := range chunks {
wav, duration, err := tts._infer([]string{chunk}, style, totalStep)
if err != nil {
return nil, 0, err
}
dur := duration[0]
wavLen := int(float32(tts.SampleRate) * dur)
wavChunk := wav[:wavLen]
if i == 0 {
wavCat = wavChunk
durCat = dur
} else {
silenceLen := int(silenceDuration * float32(tts.SampleRate))
silence := make([]float32, silenceLen)
wavCat = append(wavCat, silence...)
wavCat = append(wavCat, wavChunk...)
durCat += silenceDuration + dur
}
}
return wavCat, durCat, nil
}
// Batch synthesizes speech from multiple texts
func (tts *TextToSpeech) Batch(textList []string, style *Style, totalStep int) ([]float32, []float32, error) {
return tts._infer(textList, style, totalStep)
}
func (tts *TextToSpeech) Destroy() {
if tts.dpOrt != nil {
tts.dpOrt.Destroy()
}
if tts.textEncOrt != nil {
tts.textEncOrt.Destroy()
}
if tts.vectorEstOrt != nil {
tts.vectorEstOrt.Destroy()
}
if tts.vocoderOrt != nil {
tts.vocoderOrt.Destroy()
}
}
// LoadTextToSpeech loads TTS components
func LoadTextToSpeech(onnxDir string, useGPU bool, cfg Config) (*TextToSpeech, error) {
if useGPU {
return nil, fmt.Errorf("GPU mode is not supported yet")
}
fmt.Println("Using CPU for inference\n")
// Load models
dpPath := filepath.Join(onnxDir, "duration_predictor.onnx")
textEncPath := filepath.Join(onnxDir, "text_encoder.onnx")
vectorEstPath := filepath.Join(onnxDir, "vector_estimator.onnx")
vocoderPath := filepath.Join(onnxDir, "vocoder.onnx")
dpOrt, err := ort.NewDynamicAdvancedSession(dpPath, []string{"text_ids", "style_dp", "text_mask"},
[]string{"duration"}, nil)
if err != nil {
return nil, fmt.Errorf("failed to load duration predictor: %w", err)
}
textEncOrt, err := ort.NewDynamicAdvancedSession(textEncPath, []string{"text_ids", "style_ttl", "text_mask"},
[]string{"text_emb"}, nil)
if err != nil {
return nil, fmt.Errorf("failed to load text encoder: %w", err)
}
vectorEstOrt, err := ort.NewDynamicAdvancedSession(vectorEstPath,
[]string{"noisy_latent", "text_emb", "style_ttl", "latent_mask", "text_mask", "current_step", "total_step"},
[]string{"denoised_latent"}, nil)
if err != nil {
return nil, fmt.Errorf("failed to load vector estimator: %w", err)
}
vocoderOrt, err := ort.NewDynamicAdvancedSession(vocoderPath, []string{"latent"},
[]string{"wav_tts"}, nil)
if err != nil {
return nil, fmt.Errorf("failed to load vocoder: %w", err)
}
// Load text processor
unicodeIndexerPath := filepath.Join(onnxDir, "unicode_indexer.json")
textProcessor, err := NewUnicodeProcessor(unicodeIndexerPath)
if err != nil {
return nil, err
}
textToSpeech := &TextToSpeech{
cfg: cfg,
textProcessor: textProcessor,
dpOrt: dpOrt,
textEncOrt: textEncOrt,
vectorEstOrt: vectorEstOrt,
vocoderOrt: vocoderOrt,
SampleRate: cfg.AE.SampleRate,
baseChunkSize: cfg.AE.BaseChunkSize,
chunkCompress: cfg.TTL.ChunkCompressFactor,
ldim: cfg.TTL.LatentDim,
}
return textToSpeech, nil
}
// InitializeONNXRuntime initializes ONNX Runtime environment
func InitializeONNXRuntime() error {
libPath := os.Getenv("ONNXRUNTIME_LIB_PATH")
if libPath == "" {
libPath = "/usr/local/lib/libonnxruntime.so"
if _, err := os.Stat("/usr/local/lib/libonnxruntime.dylib"); err == nil {
libPath = "/usr/local/lib/libonnxruntime.dylib"
} else if _, err := os.Stat("/usr/lib/libonnxruntime.so"); err == nil {
libPath = "/usr/lib/libonnxruntime.so"
}
}
ort.SetSharedLibraryPath(libPath)
if err := ort.InitializeEnvironment(); err != nil {
return fmt.Errorf("failed to initialize ONNX Runtime: %w\nHint: Set ONNXRUNTIME_LIB_PATH environment variable", err)
}
return nil
}
// sanitizeFilename creates a safe filename from text
func sanitizeFilename(text string, maxLen int) string {
if len(text) > maxLen {
text = text[:maxLen]
}
result := make([]rune, 0, len(text))
for _, r := range text {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') {
result = append(result, r)
} else {
result = append(result, '_')
}
}
return string(result)
}
// extractWavSegment extracts a single audio segment from batch output
func extractWavSegment(wav []float32, duration float32, sampleRate int, index int, batchSize int) []float64 {
wavLen := int(float64(sampleRate) * float64(duration))
wavPerBatch := len(wav) / batchSize
wavStart := index * wavPerBatch
wavEnd := wavStart + wavLen
if wavEnd > len(wav) {
wavEnd = len(wav)
}
wavOut := make([]float64, wavLen)
for j := 0; j < wavLen && wavStart+j < len(wav); j++ {
wavOut[j] = float64(wav[wavStart+j])
}
return wavOut
}
// Timer measures execution time
func Timer(name string, fn func() interface{}) interface{} {
start := time.Now()
fmt.Printf("%s...\n", name)
result := fn()
elapsed := time.Since(start).Seconds()
fmt.Printf(" -> %s completed in %.2f sec\n", name, elapsed)
return result
}
// LoadCfgs loads configuration from JSON file
func LoadCfgs(onnxDir string) (Config, error) {
cfgPath := filepath.Join(onnxDir, "tts.json")
data, err := os.ReadFile(cfgPath)
if err != nil {
return Config{}, err
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return Config{}, err
}
return cfg, nil
}
// JSON loading helpers
func loadJSONInt64(filePath string) ([]int64, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
var result []int64
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
// Tensor conversion utilities
func ArrayToTensor(array [][][]float64, shape []int64) *ort.Tensor[float32] {
// Flatten array
totalSize := int64(1)
for _, dim := range shape {
totalSize *= dim
}
flat := make([]float32, totalSize)
idx := 0
for b := 0; b < len(array); b++ {
for d := 0; d < len(array[b]); d++ {
for t := 0; t < len(array[b][d]); t++ {
flat[idx] = float32(array[b][d][t])
idx++
}
}
}
tensor, err := ort.NewTensor(shape, flat)
if err != nil {
panic(err)
}
return tensor
}
func IntArrayToTensor(array [][]int64, shape []int64) *ort.Tensor[int64] {
// Flatten array
totalSize := int64(1)
for _, dim := range shape {
totalSize *= dim
}
flat := make([]int64, totalSize)
idx := 0
for b := 0; b < len(array); b++ {
for t := 0; t < len(array[b]); t++ {
flat[idx] = array[b][t]
idx++
}
}
tensor, err := ort.NewTensor(shape, flat)
if err != nil {
panic(err)
}
return tensor
}