mirror of
https://github.com/supertone-inc/supertonic.git
synced 2026-06-02 01:38:48 +02:00
830 lines
26 KiB
Rust
830 lines
26 KiB
Rust
// ============================================================================
|
||
// TTS Helper Module - All utility functions and structures
|
||
// ============================================================================
|
||
|
||
use ndarray::{Array, Array3};
|
||
use serde::{Deserialize, Serialize};
|
||
use serde_json;
|
||
use std::fs::File;
|
||
use std::io::BufReader;
|
||
use std::path::Path;
|
||
use anyhow::{Result, Context};
|
||
use unicode_normalization::UnicodeNormalization;
|
||
use hound::{WavWriter, WavSpec, SampleFormat};
|
||
use rand_distr::{Distribution, Normal};
|
||
use regex::Regex;
|
||
|
||
// ============================================================================
|
||
// Configuration Structures
|
||
// ============================================================================
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct Config {
|
||
pub ae: AEConfig,
|
||
pub ttl: TTLConfig,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct AEConfig {
|
||
pub sample_rate: i32,
|
||
pub base_chunk_size: i32,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct TTLConfig {
|
||
pub chunk_compress_factor: i32,
|
||
pub latent_dim: i32,
|
||
}
|
||
|
||
/// Load configuration from JSON file
|
||
pub fn load_cfgs<P: AsRef<Path>>(onnx_dir: P) -> Result<Config> {
|
||
let cfg_path = onnx_dir.as_ref().join("tts.json");
|
||
let file = File::open(cfg_path)?;
|
||
let reader = BufReader::new(file);
|
||
let cfgs: Config = serde_json::from_reader(reader)?;
|
||
Ok(cfgs)
|
||
}
|
||
|
||
// ============================================================================
|
||
// Voice Style Data Structure
|
||
// ============================================================================
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct VoiceStyleData {
|
||
pub style_ttl: StyleComponent,
|
||
pub style_dp: StyleComponent,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct StyleComponent {
|
||
pub data: Vec<Vec<Vec<f32>>>,
|
||
pub dims: Vec<usize>,
|
||
#[serde(rename = "type")]
|
||
pub dtype: String,
|
||
}
|
||
|
||
// ============================================================================
|
||
// Unicode Text Processor
|
||
// ============================================================================
|
||
|
||
pub struct UnicodeProcessor {
|
||
indexer: Vec<i64>,
|
||
}
|
||
|
||
impl UnicodeProcessor {
|
||
pub fn new<P: AsRef<Path>>(unicode_indexer_json_path: P) -> Result<Self> {
|
||
let file = File::open(unicode_indexer_json_path)?;
|
||
let reader = BufReader::new(file);
|
||
let indexer: Vec<i64> = serde_json::from_reader(reader)?;
|
||
Ok(UnicodeProcessor { indexer })
|
||
}
|
||
|
||
pub fn call(&self, text_list: &[String]) -> (Vec<Vec<i64>>, Array3<f32>) {
|
||
let processed_texts: Vec<String> = text_list
|
||
.iter()
|
||
.map(|t| preprocess_text(t))
|
||
.collect();
|
||
|
||
let text_ids_lengths: Vec<usize> = processed_texts
|
||
.iter()
|
||
.map(|t| t.chars().count())
|
||
.collect();
|
||
|
||
let max_len = *text_ids_lengths.iter().max().unwrap_or(&0);
|
||
|
||
let mut text_ids = Vec::new();
|
||
for text in &processed_texts {
|
||
let mut row = vec![0i64; max_len];
|
||
let unicode_vals = text_to_unicode_values(text);
|
||
for (j, &val) in unicode_vals.iter().enumerate() {
|
||
if val < self.indexer.len() {
|
||
row[j] = self.indexer[val];
|
||
} else {
|
||
row[j] = -1;
|
||
}
|
||
}
|
||
text_ids.push(row);
|
||
}
|
||
|
||
let text_mask = get_text_mask(&text_ids_lengths);
|
||
|
||
(text_ids, text_mask)
|
||
}
|
||
}
|
||
|
||
pub fn preprocess_text(text: &str) -> String {
|
||
// TODO: Need advanced normalizer for better performance
|
||
let mut text: String = text.nfkd().collect();
|
||
|
||
// FIXME: this should be fixed for non-English languages
|
||
|
||
// Remove emojis (wide Unicode range)
|
||
let emoji_pattern = Regex::new(r"[\x{1F600}-\x{1F64F}\x{1F300}-\x{1F5FF}\x{1F680}-\x{1F6FF}\x{1F700}-\x{1F77F}\x{1F780}-\x{1F7FF}\x{1F800}-\x{1F8FF}\x{1F900}-\x{1F9FF}\x{1FA00}-\x{1FA6F}\x{1FA70}-\x{1FAFF}\x{2600}-\x{26FF}\x{2700}-\x{27BF}\x{1F1E6}-\x{1F1FF}]+").unwrap();
|
||
text = emoji_pattern.replace_all(&text, "").to_string();
|
||
|
||
// Replace various dashes and symbols
|
||
let replacements = [
|
||
("–", "-"), // en dash
|
||
("‑", "-"), // non-breaking hyphen
|
||
("—", "-"), // em dash
|
||
("¯", " "), // macron
|
||
("_", " "), // underscore
|
||
("\u{201C}", "\""), // left double quote
|
||
("\u{201D}", "\""), // right double quote
|
||
("\u{2018}", "'"), // left single quote
|
||
("\u{2019}", "'"), // right single quote
|
||
("´", "'"), // acute accent
|
||
("`", "'"), // grave accent
|
||
("[", " "), // left bracket
|
||
("]", " "), // right bracket
|
||
("|", " "), // vertical bar
|
||
("/", " "), // slash
|
||
("#", " "), // hash
|
||
("→", " "), // right arrow
|
||
("←", " "), // left arrow
|
||
];
|
||
|
||
for (from, to) in &replacements {
|
||
text = text.replace(from, to);
|
||
}
|
||
|
||
// Remove combining diacritics // FIXME: this should be fixed for non-English languages
|
||
let diacritics_pattern = Regex::new(r"[\u{0302}\u{0303}\u{0304}\u{0305}\u{0306}\u{0307}\u{0308}\u{030A}\u{030B}\u{030C}\u{0327}\u{0328}\u{0329}\u{032A}\u{032B}\u{032C}\u{032D}\u{032E}\u{032F}]").unwrap();
|
||
text = diacritics_pattern.replace_all(&text, "").to_string();
|
||
|
||
// Remove special symbols
|
||
let special_symbols = ["♥", "☆", "♡", "©", "\\"];
|
||
for symbol in &special_symbols {
|
||
text = text.replace(symbol, "");
|
||
}
|
||
|
||
// Replace known expressions
|
||
let expr_replacements = [
|
||
("@", " at "),
|
||
("e.g.,", "for example, "),
|
||
("i.e.,", "that is, "),
|
||
];
|
||
|
||
for (from, to) in &expr_replacements {
|
||
text = text.replace(from, to);
|
||
}
|
||
|
||
// Fix spacing around punctuation
|
||
text = Regex::new(r" ,").unwrap().replace_all(&text, ",").to_string();
|
||
text = Regex::new(r" \.").unwrap().replace_all(&text, ".").to_string();
|
||
text = Regex::new(r" !").unwrap().replace_all(&text, "!").to_string();
|
||
text = Regex::new(r" \?").unwrap().replace_all(&text, "?").to_string();
|
||
text = Regex::new(r" ;").unwrap().replace_all(&text, ";").to_string();
|
||
text = Regex::new(r" :").unwrap().replace_all(&text, ":").to_string();
|
||
text = Regex::new(r" '").unwrap().replace_all(&text, "'").to_string();
|
||
|
||
// Remove duplicate quotes
|
||
while text.contains("\"\"") {
|
||
text = text.replace("\"\"", "\"");
|
||
}
|
||
while text.contains("''") {
|
||
text = text.replace("''", "'");
|
||
}
|
||
while text.contains("``") {
|
||
text = text.replace("``", "`");
|
||
}
|
||
|
||
// Remove extra spaces
|
||
text = Regex::new(r"\s+").unwrap().replace_all(&text, " ").to_string();
|
||
text = text.trim().to_string();
|
||
|
||
// If text doesn't end with punctuation, quotes, or closing brackets, add a period
|
||
if !text.is_empty() {
|
||
let ends_with_punct = Regex::new(r#"[.!?;:,'"\u{201C}\u{201D}\u{2018}\u{2019})\]}…。」』】〉》›»]$"#).unwrap();
|
||
if !ends_with_punct.is_match(&text) {
|
||
text.push('.');
|
||
}
|
||
}
|
||
|
||
text
|
||
}
|
||
|
||
pub fn text_to_unicode_values(text: &str) -> Vec<usize> {
|
||
text.chars().map(|c| c as usize).collect()
|
||
}
|
||
|
||
pub fn length_to_mask(lengths: &[usize], max_len: Option<usize>) -> Array3<f32> {
|
||
let bsz = lengths.len();
|
||
let max_len = max_len.unwrap_or_else(|| *lengths.iter().max().unwrap_or(&0));
|
||
|
||
let mut mask = Array3::<f32>::zeros((bsz, 1, max_len));
|
||
for (i, &len) in lengths.iter().enumerate() {
|
||
for j in 0..len.min(max_len) {
|
||
mask[[i, 0, j]] = 1.0;
|
||
}
|
||
}
|
||
mask
|
||
}
|
||
|
||
pub fn get_text_mask(text_ids_lengths: &[usize]) -> Array3<f32> {
|
||
let max_len = *text_ids_lengths.iter().max().unwrap_or(&0);
|
||
length_to_mask(text_ids_lengths, Some(max_len))
|
||
}
|
||
|
||
/// Sample noisy latent from normal distribution and apply mask
|
||
pub fn sample_noisy_latent(
|
||
duration: &[f32],
|
||
sample_rate: i32,
|
||
base_chunk_size: i32,
|
||
chunk_compress: i32,
|
||
latent_dim: i32,
|
||
) -> (Array3<f32>, Array3<f32>) {
|
||
let bsz = duration.len();
|
||
let max_dur = duration.iter().fold(0.0f32, |a, &b| a.max(b));
|
||
|
||
let wav_len_max = (max_dur * sample_rate as f32) as usize;
|
||
let wav_lengths: Vec<usize> = duration
|
||
.iter()
|
||
.map(|&d| (d * sample_rate as f32) as usize)
|
||
.collect();
|
||
|
||
let chunk_size = (base_chunk_size * chunk_compress) as usize;
|
||
let latent_len = (wav_len_max + chunk_size - 1) / chunk_size;
|
||
let latent_dim_val = (latent_dim * chunk_compress) as usize;
|
||
|
||
let mut noisy_latent = Array3::<f32>::zeros((bsz, latent_dim_val, latent_len));
|
||
|
||
let normal = Normal::new(0.0, 1.0).unwrap();
|
||
let mut rng = rand::thread_rng();
|
||
|
||
for b in 0..bsz {
|
||
for d in 0..latent_dim_val {
|
||
for t in 0..latent_len {
|
||
noisy_latent[[b, d, t]] = normal.sample(&mut rng);
|
||
}
|
||
}
|
||
}
|
||
|
||
let latent_lengths: Vec<usize> = wav_lengths
|
||
.iter()
|
||
.map(|&len| (len + chunk_size - 1) / chunk_size)
|
||
.collect();
|
||
|
||
let latent_mask = length_to_mask(&latent_lengths, Some(latent_len));
|
||
|
||
// Apply mask
|
||
for b in 0..bsz {
|
||
for d in 0..latent_dim_val {
|
||
for t in 0..latent_len {
|
||
noisy_latent[[b, d, t]] *= latent_mask[[b, 0, t]];
|
||
}
|
||
}
|
||
}
|
||
|
||
(noisy_latent, latent_mask)
|
||
}
|
||
|
||
// ============================================================================
|
||
// WAV File I/O
|
||
// ============================================================================
|
||
|
||
pub fn write_wav_file<P: AsRef<Path>>(
|
||
filename: P,
|
||
audio_data: &[f32],
|
||
sample_rate: i32,
|
||
) -> Result<()> {
|
||
let spec = WavSpec {
|
||
channels: 1,
|
||
sample_rate: sample_rate as u32,
|
||
bits_per_sample: 16,
|
||
sample_format: SampleFormat::Int,
|
||
};
|
||
|
||
let mut writer = WavWriter::create(filename, spec)?;
|
||
|
||
for &sample in audio_data {
|
||
let clamped = sample.max(-1.0).min(1.0);
|
||
let val = (clamped * 32767.0) as i16;
|
||
writer.write_sample(val)?;
|
||
}
|
||
|
||
writer.finalize()?;
|
||
Ok(())
|
||
}
|
||
|
||
// ============================================================================
|
||
// Text Chunking
|
||
// ============================================================================
|
||
|
||
const MAX_CHUNK_LENGTH: usize = 300;
|
||
|
||
const ABBREVIATIONS: &[&str] = &[
|
||
"Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
|
||
"St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
|
||
"Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D.",
|
||
];
|
||
|
||
pub fn chunk_text(text: &str, max_len: Option<usize>) -> Vec<String> {
|
||
let max_len = max_len.unwrap_or(MAX_CHUNK_LENGTH);
|
||
let text = text.trim();
|
||
|
||
if text.is_empty() {
|
||
return vec![String::new()];
|
||
}
|
||
|
||
// Split by paragraphs
|
||
let para_re = Regex::new(r"\n\s*\n").unwrap();
|
||
let paragraphs: Vec<&str> = para_re.split(text).collect();
|
||
let mut chunks = Vec::new();
|
||
|
||
for para in paragraphs {
|
||
let para = para.trim();
|
||
if para.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
if para.len() <= max_len {
|
||
chunks.push(para.to_string());
|
||
continue;
|
||
}
|
||
|
||
// Split by sentences
|
||
let sentences = split_sentences(para);
|
||
let mut current = String::new();
|
||
let mut current_len = 0;
|
||
|
||
for sentence in sentences {
|
||
let sentence = sentence.trim();
|
||
if sentence.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
let sentence_len = sentence.len();
|
||
if sentence_len > max_len {
|
||
// If sentence is longer than max_len, split by comma or space
|
||
if !current.is_empty() {
|
||
chunks.push(current.trim().to_string());
|
||
current.clear();
|
||
current_len = 0;
|
||
}
|
||
|
||
// Try splitting by comma
|
||
let parts: Vec<&str> = sentence.split(',').collect();
|
||
for part in parts {
|
||
let part = part.trim();
|
||
if part.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
let part_len = part.len();
|
||
if part_len > max_len {
|
||
// Split by space as last resort
|
||
let words: Vec<&str> = part.split_whitespace().collect();
|
||
let mut word_chunk = String::new();
|
||
let mut word_chunk_len = 0;
|
||
|
||
for word in words {
|
||
let word_len = word.len();
|
||
if word_chunk_len + word_len + 1 > max_len && !word_chunk.is_empty() {
|
||
chunks.push(word_chunk.trim().to_string());
|
||
word_chunk.clear();
|
||
word_chunk_len = 0;
|
||
}
|
||
|
||
if !word_chunk.is_empty() {
|
||
word_chunk.push(' ');
|
||
word_chunk_len += 1;
|
||
}
|
||
word_chunk.push_str(word);
|
||
word_chunk_len += word_len;
|
||
}
|
||
|
||
if !word_chunk.is_empty() {
|
||
chunks.push(word_chunk.trim().to_string());
|
||
}
|
||
} else {
|
||
if current_len + part_len + 1 > max_len && !current.is_empty() {
|
||
chunks.push(current.trim().to_string());
|
||
current.clear();
|
||
current_len = 0;
|
||
}
|
||
|
||
if !current.is_empty() {
|
||
current.push_str(", ");
|
||
current_len += 2;
|
||
}
|
||
current.push_str(part);
|
||
current_len += part_len;
|
||
}
|
||
}
|
||
continue;
|
||
}
|
||
|
||
if current_len + sentence_len + 1 > max_len && !current.is_empty() {
|
||
chunks.push(current.trim().to_string());
|
||
current.clear();
|
||
current_len = 0;
|
||
}
|
||
|
||
if !current.is_empty() {
|
||
current.push(' ');
|
||
current_len += 1;
|
||
}
|
||
current.push_str(sentence);
|
||
current_len += sentence_len;
|
||
}
|
||
|
||
if !current.is_empty() {
|
||
chunks.push(current.trim().to_string());
|
||
}
|
||
}
|
||
|
||
if chunks.is_empty() {
|
||
vec![String::new()]
|
||
} else {
|
||
chunks
|
||
}
|
||
}
|
||
|
||
fn split_sentences(text: &str) -> Vec<String> {
|
||
// Rust's regex doesn't support lookbehind, so we use a simpler approach
|
||
// Split on sentence boundaries and then check if they're abbreviations
|
||
let re = Regex::new(r"([.!?])\s+").unwrap();
|
||
|
||
// Find all matches
|
||
let matches: Vec<_> = re.find_iter(text).collect();
|
||
if matches.is_empty() {
|
||
return vec![text.to_string()];
|
||
}
|
||
|
||
let mut sentences = Vec::new();
|
||
let mut last_end = 0;
|
||
|
||
for m in matches {
|
||
// Get the text before the punctuation
|
||
let before_punc = &text[last_end..m.start()];
|
||
|
||
// Check if this ends with an abbreviation
|
||
let mut is_abbrev = false;
|
||
for abbrev in ABBREVIATIONS {
|
||
let combined = format!("{}{}", before_punc.trim(), &text[m.start()..m.start()+1]);
|
||
if combined.ends_with(abbrev) {
|
||
is_abbrev = true;
|
||
break;
|
||
}
|
||
}
|
||
|
||
if !is_abbrev {
|
||
// This is a real sentence boundary
|
||
sentences.push(text[last_end..m.end()].to_string());
|
||
last_end = m.end();
|
||
}
|
||
}
|
||
|
||
// Add the remaining text
|
||
if last_end < text.len() {
|
||
sentences.push(text[last_end..].to_string());
|
||
}
|
||
|
||
if sentences.is_empty() {
|
||
vec![text.to_string()]
|
||
} else {
|
||
sentences
|
||
}
|
||
}
|
||
|
||
// ============================================================================
|
||
// Utility Functions
|
||
// ============================================================================
|
||
|
||
pub fn timer<F, T>(name: &str, f: F) -> Result<T>
|
||
where
|
||
F: FnOnce() -> Result<T>,
|
||
{
|
||
let start = std::time::Instant::now();
|
||
println!("{}...", name);
|
||
let result = f()?;
|
||
let elapsed = start.elapsed().as_secs_f64();
|
||
println!(" -> {} completed in {:.2} sec", name, elapsed);
|
||
Ok(result)
|
||
}
|
||
|
||
pub fn sanitize_filename(text: &str, max_len: usize) -> String {
|
||
let text = if text.len() > max_len {
|
||
&text[..max_len]
|
||
} else {
|
||
text
|
||
};
|
||
|
||
text.chars()
|
||
.map(|c| {
|
||
if c.is_ascii_alphanumeric() {
|
||
c
|
||
} else {
|
||
'_'
|
||
}
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
// ============================================================================
|
||
// ONNX Runtime Integration
|
||
// ============================================================================
|
||
|
||
use ort::{
|
||
session::Session,
|
||
value::Value,
|
||
};
|
||
|
||
pub struct Style {
|
||
pub ttl: Array3<f32>,
|
||
pub dp: Array3<f32>,
|
||
}
|
||
|
||
pub struct TextToSpeech {
|
||
cfgs: Config,
|
||
text_processor: UnicodeProcessor,
|
||
dp_ort: Session,
|
||
text_enc_ort: Session,
|
||
vector_est_ort: Session,
|
||
vocoder_ort: Session,
|
||
pub sample_rate: i32,
|
||
}
|
||
|
||
impl TextToSpeech {
|
||
pub fn new(
|
||
cfgs: Config,
|
||
text_processor: UnicodeProcessor,
|
||
dp_ort: Session,
|
||
text_enc_ort: Session,
|
||
vector_est_ort: Session,
|
||
vocoder_ort: Session,
|
||
) -> Self {
|
||
let sample_rate = cfgs.ae.sample_rate;
|
||
TextToSpeech {
|
||
cfgs,
|
||
text_processor,
|
||
dp_ort,
|
||
text_enc_ort,
|
||
vector_est_ort,
|
||
vocoder_ort,
|
||
sample_rate,
|
||
}
|
||
}
|
||
|
||
fn _infer(
|
||
&mut self,
|
||
text_list: &[String],
|
||
style: &Style,
|
||
total_step: usize,
|
||
speed: f32,
|
||
) -> Result<(Vec<f32>, Vec<f32>)> {
|
||
let bsz = text_list.len();
|
||
|
||
// Process text
|
||
let (text_ids, text_mask) = self.text_processor.call(text_list);
|
||
|
||
let text_ids_array = {
|
||
let text_ids_shape = (bsz, text_ids[0].len());
|
||
let mut flat = Vec::new();
|
||
for row in &text_ids {
|
||
flat.extend_from_slice(row);
|
||
}
|
||
Array::from_shape_vec(text_ids_shape, flat)?
|
||
};
|
||
|
||
let text_ids_value = Value::from_array(text_ids_array)?;
|
||
let text_mask_value = Value::from_array(text_mask.clone())?;
|
||
let style_dp_value = Value::from_array(style.dp.clone())?;
|
||
|
||
// Predict duration
|
||
let dp_outputs = self.dp_ort.run(ort::inputs!{
|
||
"text_ids" => &text_ids_value,
|
||
"style_dp" => &style_dp_value,
|
||
"text_mask" => &text_mask_value
|
||
})?;
|
||
|
||
let (_, duration_data) = dp_outputs["duration"].try_extract_tensor::<f32>()?;
|
||
let mut duration: Vec<f32> = duration_data.to_vec();
|
||
|
||
// Apply speed factor to duration
|
||
for dur in duration.iter_mut() {
|
||
*dur /= speed;
|
||
}
|
||
|
||
// Encode text
|
||
let style_ttl_value = Value::from_array(style.ttl.clone())?;
|
||
let text_enc_outputs = self.text_enc_ort.run(ort::inputs!{
|
||
"text_ids" => &text_ids_value,
|
||
"style_ttl" => &style_ttl_value,
|
||
"text_mask" => &text_mask_value
|
||
})?;
|
||
|
||
let (text_emb_shape, text_emb_data) = text_enc_outputs["text_emb"].try_extract_tensor::<f32>()?;
|
||
let text_emb = Array3::from_shape_vec(
|
||
(text_emb_shape[0] as usize, text_emb_shape[1] as usize, text_emb_shape[2] as usize),
|
||
text_emb_data.to_vec()
|
||
)?;
|
||
|
||
// Sample noisy latent
|
||
let (mut xt, latent_mask) = sample_noisy_latent(
|
||
&duration,
|
||
self.sample_rate,
|
||
self.cfgs.ae.base_chunk_size,
|
||
self.cfgs.ttl.chunk_compress_factor,
|
||
self.cfgs.ttl.latent_dim,
|
||
);
|
||
|
||
// Prepare constant arrays
|
||
let total_step_array = Array::from_elem(bsz, total_step as f32);
|
||
|
||
// Denoising loop
|
||
for step in 0..total_step {
|
||
let current_step_array = Array::from_elem(bsz, step as f32);
|
||
|
||
let xt_value = Value::from_array(xt.clone())?;
|
||
let text_emb_value = Value::from_array(text_emb.clone())?;
|
||
let latent_mask_value = Value::from_array(latent_mask.clone())?;
|
||
let text_mask_value2 = Value::from_array(text_mask.clone())?;
|
||
let current_step_value = Value::from_array(current_step_array)?;
|
||
let total_step_value = Value::from_array(total_step_array.clone())?;
|
||
|
||
let vector_est_outputs = self.vector_est_ort.run(ort::inputs!{
|
||
"noisy_latent" => &xt_value,
|
||
"text_emb" => &text_emb_value,
|
||
"style_ttl" => &style_ttl_value,
|
||
"latent_mask" => &latent_mask_value,
|
||
"text_mask" => &text_mask_value2,
|
||
"current_step" => ¤t_step_value,
|
||
"total_step" => &total_step_value
|
||
})?;
|
||
|
||
let (denoised_shape, denoised_data) = vector_est_outputs["denoised_latent"].try_extract_tensor::<f32>()?;
|
||
xt = Array3::from_shape_vec(
|
||
(denoised_shape[0] as usize, denoised_shape[1] as usize, denoised_shape[2] as usize),
|
||
denoised_data.to_vec()
|
||
)?;
|
||
}
|
||
|
||
// Generate waveform
|
||
let final_latent_value = Value::from_array(xt)?;
|
||
let vocoder_outputs = self.vocoder_ort.run(ort::inputs!{
|
||
"latent" => &final_latent_value
|
||
})?;
|
||
|
||
let (_, wav_data) = vocoder_outputs["wav_tts"].try_extract_tensor::<f32>()?;
|
||
let wav: Vec<f32> = wav_data.to_vec();
|
||
|
||
Ok((wav, duration))
|
||
}
|
||
|
||
pub fn call(
|
||
&mut self,
|
||
text: &str,
|
||
style: &Style,
|
||
total_step: usize,
|
||
speed: f32,
|
||
silence_duration: f32,
|
||
) -> Result<(Vec<f32>, f32)> {
|
||
let chunks = chunk_text(text, None);
|
||
|
||
let mut wav_cat: Vec<f32> = Vec::new();
|
||
let mut dur_cat: f32 = 0.0;
|
||
|
||
for (i, chunk) in chunks.iter().enumerate() {
|
||
let (wav, duration) = self._infer(&[chunk.clone()], style, total_step, speed)?;
|
||
|
||
let dur = duration[0];
|
||
let wav_len = (self.sample_rate as f32 * dur) as usize;
|
||
let wav_chunk = &wav[..wav_len.min(wav.len())];
|
||
|
||
if i == 0 {
|
||
wav_cat.extend_from_slice(wav_chunk);
|
||
dur_cat = dur;
|
||
} else {
|
||
let silence_len = (silence_duration * self.sample_rate as f32) as usize;
|
||
let silence = vec![0.0f32; silence_len];
|
||
|
||
wav_cat.extend_from_slice(&silence);
|
||
wav_cat.extend_from_slice(wav_chunk);
|
||
dur_cat += silence_duration + dur;
|
||
}
|
||
}
|
||
|
||
Ok((wav_cat, dur_cat))
|
||
}
|
||
|
||
pub fn batch(
|
||
&mut self,
|
||
text_list: &[String],
|
||
style: &Style,
|
||
total_step: usize,
|
||
speed: f32,
|
||
) -> Result<(Vec<f32>, Vec<f32>)> {
|
||
self._infer(text_list, style, total_step, speed)
|
||
}
|
||
}
|
||
|
||
// ============================================================================
|
||
// Component Loading Functions
|
||
// ============================================================================
|
||
|
||
/// Load voice style from JSON files
|
||
pub fn load_voice_style(voice_style_paths: &[String], verbose: bool) -> Result<Style> {
|
||
let bsz = voice_style_paths.len();
|
||
|
||
// Read first file to get dimensions
|
||
let first_file = File::open(&voice_style_paths[0])
|
||
.context("Failed to open voice style file")?;
|
||
let first_reader = BufReader::new(first_file);
|
||
let first_data: VoiceStyleData = serde_json::from_reader(first_reader)?;
|
||
|
||
let ttl_dims = &first_data.style_ttl.dims;
|
||
let dp_dims = &first_data.style_dp.dims;
|
||
|
||
let ttl_dim1 = ttl_dims[1];
|
||
let ttl_dim2 = ttl_dims[2];
|
||
let dp_dim1 = dp_dims[1];
|
||
let dp_dim2 = dp_dims[2];
|
||
|
||
// Pre-allocate arrays with full batch size
|
||
let ttl_size = bsz * ttl_dim1 * ttl_dim2;
|
||
let dp_size = bsz * dp_dim1 * dp_dim2;
|
||
let mut ttl_flat = vec![0.0f32; ttl_size];
|
||
let mut dp_flat = vec![0.0f32; dp_size];
|
||
|
||
// Fill in the data
|
||
for (i, path) in voice_style_paths.iter().enumerate() {
|
||
let file = File::open(path).context("Failed to open voice style file")?;
|
||
let reader = BufReader::new(file);
|
||
let data: VoiceStyleData = serde_json::from_reader(reader)?;
|
||
|
||
// Flatten TTL data
|
||
let ttl_offset = i * ttl_dim1 * ttl_dim2;
|
||
let mut idx = 0;
|
||
for batch in &data.style_ttl.data {
|
||
for row in batch {
|
||
for &val in row {
|
||
ttl_flat[ttl_offset + idx] = val;
|
||
idx += 1;
|
||
}
|
||
}
|
||
}
|
||
|
||
// Flatten DP data
|
||
let dp_offset = i * dp_dim1 * dp_dim2;
|
||
idx = 0;
|
||
for batch in &data.style_dp.data {
|
||
for row in batch {
|
||
for &val in row {
|
||
dp_flat[dp_offset + idx] = val;
|
||
idx += 1;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
let ttl_style = Array3::from_shape_vec((bsz, ttl_dim1, ttl_dim2), ttl_flat)?;
|
||
let dp_style = Array3::from_shape_vec((bsz, dp_dim1, dp_dim2), dp_flat)?;
|
||
|
||
if verbose {
|
||
println!("Loaded {} voice styles\n", bsz);
|
||
}
|
||
|
||
Ok(Style {
|
||
ttl: ttl_style,
|
||
dp: dp_style,
|
||
})
|
||
}
|
||
|
||
/// Load TTS components
|
||
pub fn load_text_to_speech(onnx_dir: &str, use_gpu: bool) -> Result<TextToSpeech> {
|
||
if use_gpu {
|
||
anyhow::bail!("GPU mode is not supported yet");
|
||
}
|
||
println!("Using CPU for inference\n");
|
||
|
||
let cfgs = load_cfgs(onnx_dir)?;
|
||
|
||
let dp_path = format!("{}/duration_predictor.onnx", onnx_dir);
|
||
let text_enc_path = format!("{}/text_encoder.onnx", onnx_dir);
|
||
let vector_est_path = format!("{}/vector_estimator.onnx", onnx_dir);
|
||
let vocoder_path = format!("{}/vocoder.onnx", onnx_dir);
|
||
|
||
let dp_ort = Session::builder()?
|
||
.commit_from_file(&dp_path)?;
|
||
let text_enc_ort = Session::builder()?
|
||
.commit_from_file(&text_enc_path)?;
|
||
let vector_est_ort = Session::builder()?
|
||
.commit_from_file(&vector_est_path)?;
|
||
let vocoder_ort = Session::builder()?
|
||
.commit_from_file(&vocoder_path)?;
|
||
|
||
let unicode_indexer_path = format!("{}/unicode_indexer.json", onnx_dir);
|
||
let text_processor = UnicodeProcessor::new(&unicode_indexer_path)?;
|
||
|
||
Ok(TextToSpeech::new(
|
||
cfgs,
|
||
text_processor,
|
||
dp_ort,
|
||
text_enc_ort,
|
||
vector_est_ort,
|
||
vocoder_ort,
|
||
))
|
||
}
|