diff --git a/.gitignore b/.gitignore index 9ce7926..2f22db0 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ build/ *.d compile_commands.json .cache/ +ref_audio.* +ref_text.* +/voices \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..16faf82 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,241 @@ +# AGENTS.md + +## Build Commands + +### Basic Build (CPU only) +```bash +cmake -B build -DCMAKE_BUILD_TYPE=Release +cmake --build build --parallel $(nproc) +``` +For debug builds, use `-DCMAKE_BUILD_TYPE=Debug`. + +### With Vulkan GPU support +```bash +cmake -B build -DCMAKE_BUILD_TYPE=Release -DS2_VULKAN=ON +cmake --build build --parallel $(nproc) +``` + +### With CUDA GPU support +```bash +cmake -B build -DCMAKE_BUILD_TYPE=Release -DS2_CUDA=ON +cmake --build build --parallel $(nproc) +``` + +### With Metal GPU support (macOS) +```bash +cmake -B build -DCMAKE_BUILD_TYPE=Release -DS2_METAL=ON +cmake --build build --parallel $(nproc) +``` + +### Clean build +```bash +rm -rf build +``` + +### Submodule initialization +If you cloned without `--recurse-submodules`, initialize and update submodules: +```bash +git submodule update --init --recursive +``` + +### Run the executable +```bash +./build/s2 -m model.gguf -t tokenizer.json -text "Hello world" -o output.wav +``` + +## Linting and Formatting + +The project uses `clang-format` for code formatting. The ggml submodule provides a `.clang-format` configuration. To format all source files: + +```bash +find src include -name '*.cpp' -o -name '*.h' | xargs clang-format -i +``` + +Alternatively, use the ggml's formatting script if available. + +## Testing + +No unit tests are currently defined in the main s2.cpp project. The ggml submodule has its own test suite, but it's disabled by default (`GGML_BUILD_TESTS OFF`). To enable and run ggml tests: + +1. 
Set `GGML_BUILD_TESTS ON` in `CMakeLists.txt` (line 22) or via CMake command line: + ```bash + cmake -B build -DCMAKE_BUILD_TYPE=Release -DGGML_BUILD_TESTS=ON + ``` +2. Rebuild: `cmake --build build --parallel $(nproc)` +3. Run the ggml test executable: `./build/ggml/tests/test-*` (specific test binary) or use ctest if configured. + +Alternatively, you can run the test suite via `ctest` after building: +```bash +cd build && ctest --output-on-failure +``` + +## Project Architecture Overview + +s2.cpp implements a **Dual‑Autoregressive (Dual‑AR) text‑to‑speech inference engine** for Fish Audio's S2 Pro model. It is a pure C++17 GGML‑based pipeline that runs locally with CPU, Vulkan, or CUDA backends (no Python required). + +### Core Pipeline +``` +Text → Tokenizer → Prompt Builder → Slow‑AR Transformer → Fast‑AR Decoder → Audio Codec → WAV +``` + +### Key Components + +1. **Tokenizer** (`s2_tokenizer.cpp`): + - BPE tokenizer reading HuggingFace `tokenizer.json` (Qwen3 with Byte‑Level pre‑tokenization). + - Handles special tokens (`<|im_start|>`, `<|semantic:N|>`, `<|voice|>`, etc.). + +2. **Prompt Builder** (`s2_prompt.cpp`): + - Constructs `(num_codebooks + 1) × T` integer tensor combining text tokens and optional reference‑audio codes for voice cloning. + +3. **Slow‑AR Model** (`s2_model.cpp`): + - 36‑layer Qwen3‑based transformer (4.13B params) with GQA, RoPE, QK‑norm, and KV cache. + - Processes semantic tokens; outputs hidden state and logits for next semantic token. + - **Operations**: `load()`, `init_kv_cache()`, `prefill()`, `step()`. + +4. **Fast‑AR Decoder** (`s2_model.cpp`): + - 4‑layer transformer (0.42B params) that takes Slow‑AR hidden state plus prefix codebook tokens. + - Autoregressively predicts remaining acoustic codebook tokens (10 codebooks total). + - **Operation**: `fast_decode()`. + +5. **Audio Codec** (`s2_codec.cpp`): + - Convolutional encoder/decoder with RVQ (10 codebooks × 4096 entries). 
+ - Encodes reference audio to codes; decodes generated codes to 44.1 kHz mono waveform. + - Always runs on CPU (tiny workload). + +6. **Generation Loop** (`s2_generate.cpp`): + - Manages the autoregressive loop: prefill → while not EOS → sample semantic token → fast‑decode codebooks → store frame → step. + - Implements **Repetition‑Aware Sampling (RAS)** and semantic‑mask enforcement. + - Uses top‑k + top‑p + temperature sampling matching Fish‑Speech. + +7. **Pipeline** (`s2_pipeline.cpp`): + - Top‑level orchestrator: initializes tokenizer, model, codec; handles voice‑cloning flow; applies post‑processing (normalization, silence trimming). + - **HTTP server** (`s2_server.cpp`) exposes a `/generate` endpoint for remote synthesis. + +### Dual‑AR Design Rationale +- **Slow‑AR**: models long‑range linguistic dependencies (one semantic token per ~21.5 ms frame). +- **Fast‑AR**: models local acoustic correlations (10 codebook tokens per frame). +- This separation drastically reduces per‑step FLOPs compared to a monolithic AR model over all codebooks. + +### Memory & Execution Model +- Uses **GGML** tensors and allocators. +- Separate allocators for: KV‑cache (persistent), Slow‑AR compute buffer, Fast‑AR compute buffer, prefill temporary buffer. +- GPU backends run the transformer; codec stays on CPU. +- **posix_fadvise(DONTNEED)** on Linux to drop GGUF file from page cache after loading weights. 
+ +### File Structure +``` +include/ # Headers (one per component) +src/ # Implementations +├── s2_model.cpp # Slow‑AR + Fast‑AR +├── s2_codec.cpp # Audio codec +├── s2_tokenizer.cpp # Tokenizer +├── s2_generate.cpp # Generation loop +├── s2_prompt.cpp # Prompt builder +├── s2_pipeline.cpp # Top‑level pipeline +├── s2_sampler.cpp # Sampling utilities +├── s2_audio.cpp # WAV I/O & audio processing +├── s2_server.cpp # HTTP server +└── main.cpp # CLI entry‑point +``` + +### Important Data Structures +- `ModelHParams`: model hyper‑parameters (context length, vocab size, codebook size, etc.). +- `PromptTensor`: `(num_codebooks+1, T)` integer matrix for model input. +- `StepResult`: hidden state + logits from a Slow‑AR step. +- `GenerateParams`: generation settings (temperature, top‑p, top‑k, max tokens, etc.). + +### Voice‑Cloning Flow +1. Load reference audio (WAV/MP3) → encode to codes via codec. +2. Build prompt: `<|im_start|> <|voice|> transcript <|im_end|> reference‑codes <|im_start|> text <|im_end|>`. +3. Model learns speaker’s voice from reference codes and transcript. +4. **Voice profile persistence** (optional): encoded codes + transcript can be saved to a `.s2voice` binary file and reused later via `--voice <id>`. Profiles are stored in `./voices/` by default (configurable with `--voice-dir`) and are checked on load for compatibility (codebook size, sample rate, num_codebooks). + +### When Modifying +- The **GGUF file** contains both transformer weights and codec tensors (`c.*` prefix). +- Adding new source files requires updating `CMakeLists.txt` `S2_SOURCES`. +- Follow existing patterns for error handling (`bool` returns, `std::runtime_error` for fatal errors). +- Use GPU backend guards (`#ifdef GGML_USE_VULKAN`, `GGML_USE_CUDA`). + +## Code Style Guidelines + +### Language Standard +- C++17 +- Use standard library facilities where possible. 
+ +### Naming Conventions +- **Classes**: `PascalCase` (e.g., `SlowARModel`, `Tokenizer`) +- **Structs**: `PascalCase` (e.g., `ModelHParams`, `StepResult`) +- **Functions**: `snake_case` (e.g., `load`, `prefill`, `fast_decode`) +- **Variables**: `snake_case` (e.g., `hparams_`, `max_seq_len_`) +- **Member variables**: suffix with underscore `_` (e.g., `backend_`, `ctx_kv_`) +- **Constants**: no dedicated convention (a `k` prefix is not used); existing static constants appear to use `snake_case` +- **Namespaces**: lowercase (e.g., `s2`) + +### File Organization +- Header files in `include/` with `.h` extension. +- Source files in `src/` with `.cpp` extension. +- Each header should have `#pragma once` guard. +- Include order: + 1. Corresponding header (if in source file) + 2. Third-party library includes (e.g., `"../third_party/json.hpp"`) + 3. System includes (`<string>`, `<vector>`, etc.) + 4. Project includes (`"../include/..."`) + +### Indentation and Formatting +- Indent with 4 spaces (no tabs). +- Opening braces on the same line as function/class/struct definition. +- Use spaces around operators. +- Line length: aim for 80-100 characters, but not strictly enforced. +- Use `//` for single-line comments, `/* */` for multi-line. + +### Error Handling +- Use `bool` return values for operations that can fail (e.g., `load`, `prefill`). +- For unrecoverable errors (e.g., missing tensors), throw `std::runtime_error` with a descriptive message. +- Log errors to `stderr` using `std::fprintf(stderr, ...)` or `std::cerr`. + +### Memory Management +- Use RAII; avoid raw `new`/`delete`. +- The project uses ggml's allocators (`ggml_context`, `ggml_backend_buffer_t`). Ensure proper cleanup in destructors. +- Use `std::vector` for dynamic arrays. + +### Types +- Prefer `int32_t`, `uint32_t`, etc. from `<cstdint>` for fixed-width integers. +- Use `size_t` for sizes and indices. +- Use `float` for floating-point computations (ggml uses float). 
+ +### Includes +- Minimize includes in headers; forward declare where possible. +- Use C++ versions of C headers (e.g., `<cstdio>` instead of `<stdio.h>`). + +### Example Code Snippet +```cpp +#include "../include/s2_model.h" +#include <cstdio> +#include <string> + +namespace s2 { + +bool SlowARModel::load(const std::string & gguf_path, int32_t gpu_device, int32_t backend_type) { + // implementation + if (!success) { + std::fprintf(stderr, "Failed to load model from %s\n", gguf_path.c_str()); + return false; + } + return true; +} + +} // namespace s2 +``` + +## Additional Notes + +- The project depends on ggml as a submodule; do not modify ggml source files. +- When adding new source files, update `CMakeLists.txt` `S2_SOURCES` list. +- The codebase is cross-platform (Linux, Windows, macOS). Use preprocessor guards for platform-specific code (`#ifdef __linux__`, `#ifdef _WIN32`). +- For GPU backends, use `#ifdef GGML_USE_VULKAN` / `GGML_USE_CUDA` guards. + +## Commit Guidelines + +- Write concise commit messages in imperative mood. +- Ensure the build passes before committing. +- Format code with clang-format before committing. \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index eb89465..e1b37d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,7 @@ set(S2_SOURCES src/s2_generate.cpp src/s2_pipeline.cpp src/s2_server.cpp + src/s2_voice.cpp src/main.cpp ) diff --git a/README.md b/README.md index 899d6e4..0d2837b 100755 --- a/README.md +++ b/README.md @@ -133,6 +133,40 @@ Provide a short reference clip (5–30 seconds, WAV or MP3) and a transcript of By default, the engine uses fish-speech-aligned sampling defaults: `--min-tokens-before-end 0`, no trailing-silence trim, no peak normalization, and no dynamic loudness normalization. All of these behaviors are optional and can be enabled from the CLI. +### Voice profile persistence + +Encoded voice profiles can be saved and reused, eliminating the need to re‑encode the reference audio or re‑enter its transcript on every run. 
+ +**Save a voice profile** (clones voice and saves profile): +```bash +./build/s2 \ + -m s2-pro-q6_k.gguf \ + -t tokenizer.json \ + -pa reference.wav \ + -pt "Transcript of what the reference speaker says." \ + --voice alice \ + --save-voice \ + -text "Now synthesize this text in that voice." \ + -o output.wav +``` + +**Reuse a saved voice profile** (no reference audio needed): +```bash +./build/s2 \ + -m s2-pro-q6_k.gguf \ + -t tokenizer.json \ + --voice alice \ + -text "Another sentence in the same voice." \ + -o output.wav +``` + +**List saved profiles:** +```bash +./build/s2 --list-voices +``` + +Profiles are stored as `.s2voice` binary files in `./voices` (customizable with `--voice-dir`). They contain the encoded reference codes, transcript, and metadata (codebook size, sample rate, etc.). The format is **little‑endian** and portable across Windows, Linux, and macOS on x86/ARM. + ### GPU inference via Vulkan (AMD/Intel) ```bash @@ -168,6 +202,10 @@ By default, the engine uses fish-speech-aligned sampling defaults: `--min-tokens | `-text` | `"Hello world"` | Text to synthesize | | `-pa`, `--prompt-audio` | — | Reference audio file for voice cloning (WAV/MP3) | | `-pt`, `--prompt-text` | — | Transcript of the reference audio | +| `--voice` | — | Voice profile ID: loads the saved profile (instead of -pa/-pt), or names the profile to write when combined with `--save-voice` | +| `--save-voice` | `false` | Save encoded voice profile after cloning (requires --voice and -pa/-pt) | +| `--voice-dir` | `"./voices"` | Directory for voice profiles | +| `--list-voices` | — | List saved voice profiles and exit | | `-o`, `--output` | `out.wav` | Output WAV file path | | `-v`, `--vulkan` | `-1` (CPU) | Vulkan device index (`-1` = CPU only) | | `-c`, `--cuda` | `-1` (CPU) | CUDA device index (`-1` = CPU only) | diff --git a/include/s2_pipeline.h b/include/s2_pipeline.h index 9e1bb31..c8b9db6 100644 --- a/include/s2_pipeline.h +++ b/include/s2_pipeline.h @@ -5,6 +5,7 @@ #include "s2_generate.h" #include "s2_model.h" #include "s2_tokenizer.h" +#include 
"s2_voice.h" #include #include @@ -25,6 +26,11 @@ struct PipelineParams { bool trim_silence = false; bool normalize_output = false; bool normalize_dynamic = false; + + // Voice persistence + std::string voice_id; // load saved voice profile + bool save_voice = false; // save encoded voice profile after cloning + std::string voice_storage_dir = "./voices"; // where profiles are stored }; class Pipeline { @@ -42,8 +48,16 @@ class Pipeline { Tokenizer tokenizer_; SlowARModel model_; AudioCodec codec_; + VoiceProfileManager voice_mgr_; mutable std::mutex synthesize_mutex_; bool initialized_ = false; + + // Save voice profile from encoded codes and transcript + bool save_voice_profile(const std::string & voice_id, + const std::vector & codes, + int32_t T_prompt, + const std::string & transcript, + const PipelineParams & params); }; } diff --git a/include/s2_voice.h b/include/s2_voice.h new file mode 100644 index 0000000..d0b08aa --- /dev/null +++ b/include/s2_voice.h @@ -0,0 +1,77 @@ +#pragma once +// s2_voice.h — Voice profile persistence for S2 voice cloning +// +// Stores encoded reference codes and transcript, allowing reuse of a +// cloned voice without re‑encoding the reference audio each time. +// +// Binary file format (.s2voice) — little‑endian, portable across x86/ARM: +// Offset Size Content +// 0 8 bytes Magic header: 'S','2','V','O','I','C','E','\0' +// 8 4 bytes Version (uint32_t, currently 1) +// 12 4 bytes num_codebooks (int32_t) +// 16 4 bytes T_prompt (int32_t) +// 20 4 bytes sample_rate (int32_t) +// 24 4 bytes codebook_size (int32_t) +// 28 8 bytes transcript_len (uint64_t, includes null terminator) +// 36 8 bytes codes_size (uint64_t, bytes = codes.size() * sizeof(int32_t)) +// 44 transcript_len bytes Transcript (null‑terminated C‑string) +// … codes_size bytes Code data (row‑major int32_t array) +// +// Compatibility: a profile must match the current model's num_codebooks, +// codebook_size, and sample_rate. 
+ +#include +#include +#include + +namespace s2 { + +struct VoiceProfile { + std::string transcript; + std::vector codes; // row‑major: (num_codebooks, T_prompt) + int32_t num_codebooks = 0; + int32_t T_prompt = 0; + int32_t sample_rate = 44100; + int32_t codebook_size = 4096; + + // Metadata + std::string model_hash; // optional identifier of the source model + std::string timestamp; + + // Save to file + bool save(const std::string & path) const; + + // Load from file + static VoiceProfile load(const std::string & path); + + // Check compatibility with current codec + bool is_compatible(int32_t expected_num_codebooks, int32_t expected_codebook_size, + int32_t expected_sample_rate = 44100) const; +}; + +class VoiceProfileManager { +public: + VoiceProfileManager() = default; + + // Set storage directory (default: "./voices") + void set_storage_dir(const std::string & dir); + + // Save profile with given ID + bool save(const std::string & voice_id, const VoiceProfile & profile); + + // Load profile by ID (searches storage directory) + VoiceProfile load(const std::string & voice_id); + + // Delete profile + bool remove(const std::string & voice_id); + + // List available voice IDs + std::vector list() const; + +private: + std::string storage_dir_ = "./voices"; + + std::string get_path(const std::string & voice_id) const; +}; + +} // namespace s2 \ No newline at end of file diff --git a/openapi/README.md b/openapi/README.md new file mode 100644 index 0000000..55c2145 --- /dev/null +++ b/openapi/README.md @@ -0,0 +1,43 @@ +# OpenAPI Specification for s2.cpp HTTP Server + +This directory contains the OpenAPI 3.1 specification for the s2.cpp HTTP server mode. + +## File + +- `s2-openapi.yaml` – OpenAPI 3.1 YAML specification + +## Overview + +The s2.cpp binary includes an HTTP server mode (`--server`) that exposes a single endpoint for text-to-speech synthesis with optional voice cloning. 
+ +### Endpoint + +- `POST /generate` – Accepts multipart/form-data with text, optional reference audio, and generation parameters; returns a WAV audio file. + +### Server Defaults + +- Host: `127.0.0.1` +- Port: `3030` + +## Usage with OpenAPI Tools + +You can use tools like [Swagger UI](https://swagger.io/tools/swagger-ui/), [Redoc](https://redoc.ly/), or [OpenAPI Generator](https://openapi-generator.tech/) to render documentation, generate client code, or validate requests. + +Example with `curl`: + +```bash +curl -X POST http://127.0.0.1:3030/generate \ + -F "text=Hello world" \ + -F "params={\"temperature\":0.9}" \ + -o output.wav +``` + +## Notes + +- The server only implements the `/generate` endpoint; no health‑check or root endpoint is provided. +- Reference audio must be accompanied by its transcript (`reference_text` field). +- The `params` field is a JSON string that follows the `GenerateParams` schema defined in the OpenAPI document. + +## License + +Same as the s2.cpp project (MIT). \ No newline at end of file diff --git a/openapi/s2-openapi.yaml b/openapi/s2-openapi.yaml new file mode 100644 index 0000000..c067045 --- /dev/null +++ b/openapi/s2-openapi.yaml @@ -0,0 +1,149 @@ +openapi: 3.1.0 +info: + title: s2.cpp HTTP Server API + description: | + HTTP server for Fish Audio S2 Pro text-to-speech inference engine. + Provides a single endpoint for generating speech from text with optional voice cloning. + version: 1.0.0 + contact: + name: s2.cpp project + url: https://github.com/fishaudio/s2.cpp + license: + name: MIT + url: https://github.com/fishaudio/s2.cpp/blob/main/LICENSE.md + +servers: + - url: http://127.0.0.1:3030 + description: Default local server + +paths: + /generate: + post: + summary: Generate speech from text + description: | + Synthesize speech from given text, optionally using reference audio for voice cloning. + Request must be multipart/form-data. 
+ operationId: generateSpeech + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - text + properties: + text: + type: string + description: Text to synthesize (required) + example: "Hello, world!" + reference_text: + type: string + description: | + Transcript of the reference audio (required if reference audio is provided). + Aliases: ref_text, prompt_text. + example: "This is a reference audio transcript." + params: + type: string + description: | + JSON string containing generation parameters (optional). + The JSON object must match the GenerateParams schema. + See the GenerateParams component for field details. + example: '{"temperature": 0.9, "top_p": 0.9}' + reference: + type: string + format: binary + description: | + Reference audio file for voice cloning (optional). + Supported aliases: reference_audio, prompt_audio, ref_audio. + Must be a WAV or MP3 file. + additionalProperties: false + responses: + '200': + description: Generated audio as WAV file + content: + audio/wav: + schema: + type: string + format: binary + headers: + Content-Disposition: + schema: + type: string + example: 'attachment; filename="generated_audio.wav"' + '400': + description: Bad request (missing required field, invalid parameters, synthesis failure) + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + text/plain: + schema: + type: string + '500': + description: Internal server error (server initialization, pipeline failure) + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + text/plain: + schema: + type: string + +components: + schemas: + Error: + type: object + required: + - error + properties: + error: + type: string + description: Human-readable error message + example: + error: "No text field in multipart form" + + GenerateParams: + type: object + description: Generation parameters (JSON object) + properties: + max_new_tokens: + type: integer + minimum: 0 + default: 1024 
+ description: Maximum number of semantic tokens to generate + temperature: + type: number + format: float + minimum: 0.0 + default: 0.8 + description: Sampling temperature (higher = more random) + top_p: + type: number + format: float + minimum: 0.0 + default: 0.8 + description: Nucleus sampling probability + top_k: + type: integer + minimum: 0 + default: 30 + description: Top-k sampling (0 = disabled) + min_tokens_before_end: + type: integer + minimum: 0 + default: 0 + description: Minimum tokens before EOS token is allowed + n_threads: + type: integer + minimum: 1 + default: 4 + description: Number of CPU threads for generation + verbose: + type: boolean + default: true + description: Enable verbose logging + +tags: + - name: Speech + description: Text-to-speech generation endpoints \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 15f01d7..62622a1 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -16,12 +16,12 @@ namespace fs = ghc::filesystem; #include #endif -static void safe_print(const char* msg) { - fputs(msg, stdout); +static void safe_print(const std::string& msg) { + fputs(msg.c_str(), stdout); } -static void safe_print_error(const char* msg) { - fputs(msg, stderr); +static void safe_print_error(const std::string& msg) { + fputs(msg.c_str(), stderr); } void print_uso() { @@ -32,6 +32,10 @@ void print_uso() { safe_print(" -text Text to synthesize\n"); safe_print(" -pa, --prompt-audio Path to reference audio for cloning\n"); safe_print(" -pt, --prompt-text Text of the reference audio\n"); + safe_print(" --voice Use saved voice profile (instead of -pa/-pt)\n"); + safe_print(" --save-voice Save encoded voice profile after cloning (requires --voice)\n"); + safe_print(" --voice-dir Directory for voice profiles (default: ./voices)\n"); + safe_print(" --list-voices List saved voice profiles\n"); safe_print(" -o, --output Output WAV path\n"); safe_print(" -v, -c, --vulkan, --cuda Vulkan/Cuda device index (-1 = CPU)\n"); safe_print(" 
-threads Number of threads\n"); @@ -101,6 +105,7 @@ int main(int argc, char** argv) { params.backend_type = -1; bool use_server = false; + bool list_voices = false; s2::ServerParams serverParams; for (int i = 1; i < argc; ++i) { @@ -110,6 +115,10 @@ int main(int argc, char** argv) { else if (arg == "-text") { if (i+1 < argc) params.text = argv[++i]; } else if (arg == "-pa" || arg == "--prompt-audio") { if (i+1 < argc) params.prompt_audio_path = argv[++i]; } else if (arg == "-pt" || arg == "--prompt-text") { if (i+1 < argc) params.prompt_text = argv[++i]; } + else if (arg == "--voice") { if (i+1 < argc) params.voice_id = argv[++i]; } + else if (arg == "--save-voice") { params.save_voice = true; } + else if (arg == "--voice-dir") { if (i+1 < argc) params.voice_storage_dir = argv[++i]; } + else if (arg == "--list-voices") { list_voices = true; } else if (arg == "-o" || arg == "--output") { if (i+1 < argc) params.output_path = argv[++i]; } else if (arg == "-v" || arg == "--vulkan") { if (i+1 < argc) { try { params.gpu_device = std::stoi(argv[++i]); } catch(...) {} params.backend_type = 0; } } else if (arg == "-c" || arg == "--cuda") { if (i+1 < argc) { try { params.gpu_device = std::stoi(argv[++i]); } catch(...) 
{} params.backend_type = 1; } } @@ -138,6 +147,41 @@ int main(int argc, char** argv) { else if (arg == "-h" || arg == "--help") { print_uso(); return 0; } } + if (list_voices) { + fs::path dir(params.voice_storage_dir); + if (fs::exists(dir)) { + safe_print("Saved voice profiles:\n"); + for (const auto & entry : fs::directory_iterator(dir)) { + if (entry.path().extension() == ".s2voice") { + safe_print(" " + entry.path().stem().string() + "\n"); + } + } + } else { + safe_print("No voice profiles directory found.\n"); + } + return 0; + } + + // Validate voice profile options + if (!params.voice_id.empty()) { + if (!params.prompt_audio_path.empty()) { + safe_print_error("Warning: --voice overrides -pa/--prompt-audio, reference audio will be ignored.\n"); + } + if (!params.prompt_text.empty()) { + safe_print_error("Warning: --voice overrides -pt/--prompt-text, prompt text will be ignored.\n"); + } + } + if (params.save_voice) { + if (params.voice_id.empty()) { + safe_print_error("Error: --save-voice requires --voice .\n"); + return 1; + } + if (params.prompt_audio_path.empty() || params.prompt_text.empty()) { + safe_print_error("Error: --save-voice requires -pa/--prompt-audio and -pt/--prompt-text.\n"); + return 1; + } + } + if (params.tokenizer_path == "tokenizer.json") { std::string model_path = params.model_path; size_t slash = model_path.find_last_of("/\\"); diff --git a/src/s2_pipeline.cpp b/src/s2_pipeline.cpp index 70e015d..c217168 100644 --- a/src/s2_pipeline.cpp +++ b/src/s2_pipeline.cpp @@ -149,18 +149,50 @@ bool Pipeline::synthesize_raw(const PipelineParams & params, AudioData & ref_aud std::vector ref_codes; int32_t T_prompt = 0; - + std::string effective_prompt_text = params.prompt_text; + + // Set voice storage directory + voice_mgr_.set_storage_dir(params.voice_storage_dir); + + // 1. 
Encode reference audio if provided (takes precedence over voice_id) if (!ref_audio.samples.empty()) { + safe_print_ln("Encoding reference audio..."); if (!codec_.encode(ref_audio.samples.data(), (int32_t)ref_audio.samples.size(), params.gen.n_threads, ref_codes, T_prompt)) { safe_print_error_ln("Pipeline warning: encode failed, running without reference audio."); ref_codes.clear(); T_prompt = 0; + } else { + safe_print_ln("Encoded reference audio: " + std::to_string(T_prompt) + " frames"); + // Save voice profile if requested + if (params.save_voice && !params.voice_id.empty()) { + if (!save_voice_profile(params.voice_id, ref_codes, T_prompt, effective_prompt_text, params)) { + safe_print_error_ln("Warning: failed to save voice profile."); + } + } + } + } + // 2. Otherwise load existing voice profile if voice_id is provided + else if (!params.voice_id.empty()) { + safe_print_ln("Loading voice profile: " + params.voice_id); + try { + VoiceProfile profile = voice_mgr_.load(params.voice_id); + if (!profile.is_compatible(num_codebooks, model_.hparams().codebook_size, codec_.sample_rate())) { + safe_print_error_ln("Voice profile incompatible with current model/codec."); + return false; + } + ref_codes = profile.codes; + T_prompt = profile.T_prompt; + effective_prompt_text = profile.transcript; + safe_print_ln("Loaded voice profile: " + params.voice_id + " (" + std::to_string(T_prompt) + " frames)"); + } catch (const std::exception & e) { + safe_print_error_ln("Failed to load voice profile " + params.voice_id + ": " + e.what()); + return false; } } PromptTensor prompt = build_prompt( - tokenizer_, params.text, params.prompt_text, + tokenizer_, params.text, effective_prompt_text, ref_codes.empty() ? 
nullptr : ref_codes.data(), num_codebooks, T_prompt); @@ -186,4 +218,36 @@ bool Pipeline::synthesize_raw(const PipelineParams & params, AudioData & ref_aud return true; } +bool Pipeline::save_voice_profile(const std::string & voice_id, + const std::vector & codes, + int32_t T_prompt, + const std::string & transcript, + const PipelineParams & params) { + if (voice_id.empty()) return false; + + voice_mgr_.set_storage_dir(params.voice_storage_dir); + + VoiceProfile profile; + profile.transcript = transcript; + profile.codes = codes; + profile.num_codebooks = model_.hparams().num_codebooks; + profile.T_prompt = T_prompt; + profile.sample_rate = codec_.sample_rate(); + profile.codebook_size = model_.hparams().codebook_size; + + // Simple timestamp + std::time_t now = std::time(nullptr); + char buf[64]; + std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now)); + profile.timestamp = buf; + + if (voice_mgr_.save(voice_id, profile)) { + safe_print_ln("Saved voice profile: " + voice_id); + return true; + } else { + safe_print_error_ln("Failed to save voice profile: " + voice_id); + return false; + } +} + } diff --git a/src/s2_voice.cpp b/src/s2_voice.cpp new file mode 100644 index 0000000..419ee6d --- /dev/null +++ b/src/s2_voice.cpp @@ -0,0 +1,144 @@ +#include "../include/s2_voice.h" +#include "../third_party/filesystem.hpp" +#include +#include +#include +#include +#include +#include + +namespace fs = ghc::filesystem; + +namespace s2 { + +static const char MAGIC[8] = {'S','2','V','O','I','C','E','\0'}; +static const uint32_t VERSION = 1; + +bool VoiceProfile::save(const std::string & path) const { + std::ofstream out(path, std::ios::binary); + if (!out) return false; + + // Write header + out.write(MAGIC, sizeof(MAGIC)); + uint32_t version = VERSION; + out.write(reinterpret_cast(&version), sizeof(version)); + out.write(reinterpret_cast(&num_codebooks), sizeof(num_codebooks)); + out.write(reinterpret_cast(&T_prompt), sizeof(T_prompt)); + 
out.write(reinterpret_cast(&sample_rate), sizeof(sample_rate)); + out.write(reinterpret_cast(&codebook_size), sizeof(codebook_size)); + + uint64_t transcript_len = transcript.size() + 1; // include null terminator + out.write(reinterpret_cast(&transcript_len), sizeof(transcript_len)); + + uint64_t codes_size = static_cast(codes.size()) * sizeof(int32_t); + out.write(reinterpret_cast(&codes_size), sizeof(codes_size)); + + // Write transcript (null‑terminated) + out.write(transcript.c_str(), transcript_len); + + // Write codes + out.write(reinterpret_cast(codes.data()), static_cast(codes_size)); + + return out.good(); +} + +VoiceProfile VoiceProfile::load(const std::string & path) { + std::ifstream in(path, std::ios::binary); + if (!in) throw std::runtime_error("cannot open voice profile: " + path); + + char magic[8]; + in.read(magic, sizeof(magic)); + if (std::memcmp(magic, MAGIC, sizeof(MAGIC)) != 0) { + throw std::runtime_error("invalid voice profile magic"); + } + + uint32_t version; + in.read(reinterpret_cast(&version), sizeof(version)); + if (version != VERSION) { + throw std::runtime_error("unsupported voice profile version"); + } + + VoiceProfile profile; + in.read(reinterpret_cast(&profile.num_codebooks), sizeof(profile.num_codebooks)); + in.read(reinterpret_cast(&profile.T_prompt), sizeof(profile.T_prompt)); + in.read(reinterpret_cast(&profile.sample_rate), sizeof(profile.sample_rate)); + in.read(reinterpret_cast(&profile.codebook_size), sizeof(profile.codebook_size)); + + uint64_t transcript_len; + in.read(reinterpret_cast(&transcript_len), sizeof(transcript_len)); + + uint64_t codes_size; + in.read(reinterpret_cast(&codes_size), sizeof(codes_size)); + + // Read transcript + std::vector transcript_buf(transcript_len); + in.read(transcript_buf.data(), static_cast(transcript_len)); + if (transcript_buf.back() != '\0') { + throw std::runtime_error("transcript not null‑terminated"); + } + profile.transcript = transcript_buf.data(); + + // Read codes + size_t 
n_codes = codes_size / sizeof(int32_t); + profile.codes.resize(n_codes); + in.read(reinterpret_cast(profile.codes.data()), static_cast(codes_size)); + + if (!in) throw std::runtime_error("truncated voice profile"); + + return profile; +} + +bool VoiceProfile::is_compatible(int32_t expected_num_codebooks, int32_t expected_codebook_size, + int32_t expected_sample_rate) const { + return (num_codebooks == expected_num_codebooks) && + (codebook_size == expected_codebook_size) && + (sample_rate == expected_sample_rate); +} + +// --------------------------------------------------------------------------- +// VoiceProfileManager +// --------------------------------------------------------------------------- + +void VoiceProfileManager::set_storage_dir(const std::string & dir) { + storage_dir_ = dir; +} + +std::string VoiceProfileManager::get_path(const std::string & voice_id) const { + fs::path dir(storage_dir_); + if (!fs::exists(dir)) { + fs::create_directories(dir); + } + return (dir / (voice_id + ".s2voice")).string(); +} + +bool VoiceProfileManager::save(const std::string & voice_id, const VoiceProfile & profile) { + std::string path = get_path(voice_id); + return profile.save(path); +} + +VoiceProfile VoiceProfileManager::load(const std::string & voice_id) { + std::string path = get_path(voice_id); + return VoiceProfile::load(path); +} + +bool VoiceProfileManager::remove(const std::string & voice_id) { + std::string path = get_path(voice_id); + if (fs::exists(path)) { + return fs::remove(path); + } + return false; +} + +std::vector VoiceProfileManager::list() const { + std::vector result; + if (!fs::exists(storage_dir_)) return result; + + for (const auto & entry : fs::directory_iterator(storage_dir_)) { + if (entry.path().extension() == ".s2voice") { + result.push_back(entry.path().stem().string()); + } + } + return result; +} + +} // namespace s2 \ No newline at end of file