A PyTorch platform for training and experimenting with neural networks.
- Bigram model — a simple neural network that learns character-to-character transition probabilities
- LSTM model — a recurrent neural network capable of learning longer-range character sequences using memory and context
- GRU model — a gated recurrent unit network for efficient sequence modeling with fewer parameters than LSTM
- Transformer model — a trainable transformer supporting both training and generation
- DistilGPT2 model — inference-only; uses a pre-trained Hugging Face transformer for high-quality text generation
The codebase is modular, config-driven, and supports training, checkpointing, early stopping, hyperparameter tuning, and generation from any model via CLI. Comprehensive unit tests are included for all major modules, including training, library, utilities, visualization, tuning, CLI, model behavior, and profiling (current coverage: 100%).
- Features
- Model Architectures
- Datasets
- Configuration
- Configuration Editor
- Pipeline GUI
- Model Export
- Hyperparameter Tuning
- Optuna Dashboard
- Profiling
- Usage
- Loss Visualization
- GPU Acceleration
- Dependencies
- Docker Usage
- Testing
- Future Improvements
- License
- Character-level or word-level tokenization across multiple input files (configurable via `token_level` in `model_options`)
- Dynamic vocabulary and index mapping
- Modular model registry for Bigram, LSTM, GRU, Transformer, and DistilGPT2 (inference-only)
- Configurable training and hyperparameter tuning via `config.json`
- Automatic hyperparameter tuning with Optuna
- Optuna Dashboard for visualizing hyperparameter optimization studies
- Adam optimizer with early stopping
- Automatic checkpoint rotation and resumption with metadata tracking
- Text generation strategies: random and prompt-based generation (configurable via `generator` in `generator_options`)
- Multinomial and argmax sampling for text generation (configurable via `sampler` in `generator_options`)
- Temperature scaling for controllable randomness in generation (configurable via `temperature` in `generator_options`)
- Comprehensive CLI with model selection, runtime, hyperparameter, model options, tuning, and visualization configuration
- Full unit test coverage (100%) for all modules
- Tests include generation and training for all models, tuning, visualization, CLI behavior, argument parsing helpers, and profiling
- Loss visualization with matplotlib, including smoothing and saving plots
- GPU-accelerated training by default
- Integrated dataset library with pre-configured datasets
- Support for local files, Hugging Face datasets, and built-in library datasets
- Built-in profiling for performance analysis
- Interactive GUI for editing `config.json`
- Pipeline GUI: End-to-end graphical interface for model selection, training, evaluation, and export
- Model Export System: Export models in multiple formats (Architecture, Package, Framework, Application)
- Statistics: 347 unit tests, 100% coverage, 1875 stmts / 0 miss
A lightweight model that uses an embedding table to predict the next character from the current character only. Fast and simple, but limited in predictive capability.
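A minimal sketch of the idea, for illustration only (the class and method names below are not taken from this repository):

```python
# Illustrative character-level bigram model: the embedding table row for the
# current character directly holds the logits for the next character.
import torch
import torch.nn as nn
import torch.nn.functional as F

class BigramSketch(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        self.logits_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.logits_table(idx)  # (batch, time, vocab_size)

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        for _ in range(max_new_tokens):
            logits = self(idx)[:, -1, :]              # only the last character matters
            probs = F.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_idx], dim=1)
        return idx
```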
A recurrent neural network using embedding, multi-layer LSTM, and projection back to vocab size. Learns long-range dependencies across sequences for improved generation.
A gated recurrent unit network that efficiently models sequences with fewer parameters than LSTM, providing a balance between speed and performance.
A trainable transformer model using self-attention mechanisms for sophisticated text generation. Architecture includes token and position embeddings, multi-head attention, feedforward layers, and stacking of encoder layers.
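A rough sketch of such an architecture built from PyTorch's stock layers (illustrative only; the repository's implementation and layer names may differ):

```python
# Illustrative decoder-style transformer: token + position embeddings,
# stacked self-attention/feedforward encoder layers, and a vocab projection.
import torch
import torch.nn as nn

class TransformerSketch(nn.Module):
    def __init__(self, vocab_size, embedding_dim=32, max_seq_len=128,
                 num_heads=2, ff_dim=128, num_layers=2):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embedding_dim)
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=num_heads,
            dim_feedforward=ff_dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, idx):
        batch, seq_len = idx.shape
        pos = torch.arange(seq_len, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(pos)
        # Causal mask so each position attends only to earlier positions.
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(idx.device)
        x = self.encoder(x, mask=mask)
        return self.head(x)  # (batch, seq_len, vocab_size)
```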
Integration with a pre-trained Hugging Face DistilGPT2 model for high-quality text generation. Inference-only (cannot be trained or fine-tuned).
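Under the hood this is standard Hugging Face inference; a generic example of generating text with DistilGPT2 (not the project's wrapper code) looks like:

```python
# Generic DistilGPT2 text generation with the transformers library.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("Once upon a time", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    temperature=1.0,
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no dedicated padding token
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```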
Three types of datasets are supported:
The dataset/ directory (included in this repo) contains 100 filtered sample texts for training, preprocessed from Project Gutenberg. These texts have been cleaned and filtered using length, markup, and English word ratio heuristics.
For more training data, you can download the full cleaned dataset (4,437 books) on Hugging Face.
The project includes a pre-configured library of datasets:
- News (0.03 GB) - AG News dataset
- Science (0.41 GB) - PubMed QA dataset
- Movies (0.49 GB) - IMDB dataset
- Yelp (0.51 GB) - Yelp Review Full dataset
- SQuAD (0.08 GB) - Stanford Question Answering Dataset
- Tiny Stories (1.89 GB) - Tiny Stories dataset
- Stack Overflow (5.75 GB) - Stack Overflow Questions dataset
- Wikipedia (18.81 GB) - English Wikipedia dataset
You can use any dataset from the Hugging Face Hub by specifying the dataset name and configuration in config.json. This allows for flexible experimentation with different text sources.
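For reference, the `huggingface` entry in the example config below corresponds to a `datasets` call along these lines (the project performs this loading internally):

```python
# Illustrative: what the example "huggingface" config block maps to.
from datasets import load_dataset

ds = load_dataset("pubmed_qa", "pqa_artificial", split="train")
texts = ds["question"]  # the "field" option selects which column provides the text
print(texts[0])
```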
All behavior is driven by a single config.json file. You can edit this file manually or with the user-friendly Configuration Editor.
Example config.json (click to expand)
{
"datasets": {
"source": "library",
"locations": {
"local": {
"directory": "dataset",
"extension": "txt"
},
"library": {
"data_name": "science"
},
"huggingface": {
"data_name": "pubmed_qa",
"config_name": "pqa_artificial",
"split": "train",
"field": "question"
}
}
},
"generator_options": {
"generator": "random",
"context_length": 128,
"sampler": "multinomial",
"temperature": 1.0
},
"model_options": {
"save_model": true,
"token_level": "char",
"patience": 10,
"max_checkpoints": 10
},
"models": {
"bigram": {
"runtime": {
"training": true,
"steps": 10000,
"interval": 100,
"max_new_tokens": 128
},
"hparams": {
"batch_size": 16,
"block_size": 32,
"lr": 0.001
}
},
"lstm": {
"runtime": {
"training": true,
"steps": 50000,
"interval": 500,
"max_new_tokens": 256
},
"hparams": {
"batch_size": 32,
"block_size": 64,
"lr": 0.0015,
"embedding_dim": 64,
"hidden_size": 128,
"num_layers": 2
}
},
"gru": {
"runtime": {
"training": true,
"steps": 50000,
"interval": 500,
"max_new_tokens": 256
},
"hparams": {
"batch_size": 32,
"block_size": 64,
"lr": 0.0015,
"embedding_dim": 64,
"hidden_size": 128,
"num_layers": 2
}
},
"transformer": {
"runtime": {
"training": true,
"steps": 100000,
"interval": 1000,
"max_new_tokens": 256
},
"hparams": {
"batch_size": 32,
"block_size": 64,
"lr": 0.001,
"embedding_dim": 32,
"max_seq_len": 128,
"num_heads": 2,
"ff_dim": 128,
"num_layers": 2
}
},
"distilgpt2": {
"runtime": {
"max_new_tokens": 256
},
"hparams": {
"block_size": 32
}
}
},
"pruners": {
"median": {
"n_startup_trials": 5,
"n_warmup_steps": 1000
},
"halving": {
"min_resource": 5,
"reduction_factor": 2,
"min_early_stopping_rate": 1
},
"hyperband": {
"min_resource": 5,
"reduction_factor": 2
}
},
"tuning_options": {
"auto_tuning": true,
"save_tuning": true,
"save_study": true,
"n_trials": 100,
"pruner": "hyperband",
"step_divisor": 10
},
"tuning_ranges": {
"batch_size": {
"type": "int",
"min": 16,
"max": 128,
"step": 16
},
"block_size": {
"type": "int",
"min": 32,
"max": 256,
"step": 32
},
"lr": {
"type": "float",
"min": 0.0001,
"max": 0.01,
"log": true
},
"embedding_dim": {
"type": "int",
"min": 16,
"max": 128,
"step": 16
},
"hidden_size": {
"type": "int",
"min": 32,
"max": 256,
"step": 32
},
"num_layers": {
"type": "categorical",
"values": [1, 2, 3, 4]
},
"max_seq_len": {
"type": "int",
"min": 32,
"max": 256,
"step": 32
},
"num_heads": {
"type": "int",
"min": 2,
"max": 8,
"step": 2
},
"ff_dim": {
"type": "int",
"min": 32,
"max": 256,
"step": 32
}
},
"visualization": {
"save_plot": true,
"show_plot": true,
"smooth_loss": true,
"smooth_val_loss": true,
"weight": 0.9
}
}

You can configure:
- Datasets (`datasets`): Source and location to pull from
- Generator Options (`generator_options`): Text generation and sampling strategies, along with relevant options
- Model Options (`model_options`): Model saving, tokenization level, early stopping patience, and checkpoint management
- Runtime (`runtime`): Training and generation settings for each model
- Hyperparameters (`hparams`): Model-specific architecture and optimization parameters
- Pruners (`pruners`): Configuration for Optuna pruners
- Tuning Options (`tuning_options`): Optuna tuning configuration, pruner, and trial settings
- Tuning Ranges (`tuning_ranges`): Hyperparameter search spaces for automatic tuning
- Visualization (`visualization`): Loss plotting, smoothing, and saving options
An interactive GUI for editing config.json is included to simplify configuration management.
- Launch with: `python -m run config`
- This opens a window where you can view and modify all settings from `config.json`.
- The config editor supports saving changes back to the file.
- The editor is built with DearPyGui.
A comprehensive end-to-end graphical interface for managing the entire model lifecycle, from selection to training, evaluation, and export.
- Launch with: `python -m run pipeline`
- Features:
- Model Selection: Choose from Bigram, LSTM, GRU, or Transformer models
- Training Control: Start and stop model training with real-time progress monitoring
- Evaluation Mode: Run trained models for text generation
- Integrated Terminal: Built-in terminal window for logging and debugging output
- Configuration Editor: Access the config editor directly from the pipeline interface
- Model Export: Export models in various formats (see Model Export)
- Real-time Updates: Live updates of training progress and terminal output
The pipeline GUI provides a unified interface for all model operations, making it easy to manage training workflows without switching between command-line tools.
The project includes a comprehensive model export system that allows you to export trained models in multiple formats for different use cases.
- Architecture Export
  - Exports the model architecture as a standalone Python module
  - Includes the model's hyperparameters from `config.json`
  - Generates a complete, runnable Python file with the model class and hyperparameter dataclass
  - Useful for sharing model architectures or using them in other projects
- Package Export
  - Bundles the trained model with all required dependencies using PyTorch's `PackageExporter` (see the sketch after this list)
  - Creates a `.pkg` file that can be loaded and used independently
  - Includes model weights, vocabulary mappings, and necessary code dependencies
  - Ideal for deployment or sharing trained models
- Framework Export
  - Exports the model as a `.zip` file containing the package structure
  - Allows for custom framework development around the exported model
  - Provides more flexibility for building custom implementations
- Application Export
  - Compiles models as standalone executable applications using PyInstaller
  - Creates executables for the platform where PyInstaller is run (Windows `.exe`, Linux, or macOS executables)
  - Includes model weights, vocabulary mappings, and all dependencies bundled into a single artifact
  - Applications are named based on the model and generation strategy (e.g., `transformer_random_generation`)
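For reference, a generic `torch.package` round trip looks like the sketch below; the file names and interned/externed module patterns are placeholders, not the repository's exact export layout:

```python
# Illustrative torch.package export and re-import of a model object.
import torch
from torch.package import PackageExporter, PackageImporter

model = torch.nn.Linear(8, 8)  # stand-in for a trained model

with PackageExporter("model.pkg") as exporter:
    exporter.extern("torch.**")                # rely on the environment's torch
    # exporter.intern("my_project.models.**") # custom model code would be bundled like this
    exporter.save_pickle("model", "model.pkl", model)

restored = PackageImporter("model.pkg").load_pickle("model", "model.pkl")
print(restored)
```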
Model export is available through both the Pipeline GUI and the CLI:
Pipeline GUI:
- Launch the pipeline: `python -m run pipeline`
- Select a model to export
- Choose an export format from the export menu
- The exported files will be saved to the `exports/` directory
CLI:
Export models directly from the command line using the `--export` argument:
python main.py --model transformer --export architecture
python main.py --model lstm --export package
python main.py --model gru --export framework
python main.py --model transformer --export application

Exports are organized by format followed by model name.
Automatic hyperparameter tuning is supported via Optuna.
- Enable tuning by setting `"auto_tuning": true` in `config.json`.
- Define search spaces in the `"tuning_ranges"` section (supports `int`, `float`, and `categorical` types); a minimal sketch follows this list.
- Tuning is integrated into the training workflow and can be controlled via the CLI or config.
- Results are saved and can be used to update model hyperparameters automatically if `"save_tuning": true`.
- Tunable fields include: `batch_size`, `block_size`, `lr`, `embedding_dim`, `hidden_size`, `num_layers`, `max_seq_len`, `num_heads`, `ff_dim`.
- Pruners and tuning options are configurable in `config.json`.
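For illustration, an Optuna objective over ranges like those in `tuning_ranges` can look like this (here `train_and_eval` is a placeholder for the project's training step, not its real function name):

```python
# Illustrative Optuna study mirroring the example tuning_ranges.
import optuna

def train_and_eval(batch_size, block_size, lr, num_layers):
    # Placeholder: the real project trains the selected model and returns validation loss.
    return lr * num_layers / (batch_size + block_size)

def objective(trial: optuna.Trial) -> float:
    batch_size = trial.suggest_int("batch_size", 16, 128, step=16)
    block_size = trial.suggest_int("block_size", 32, 256, step=32)
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    num_layers = trial.suggest_categorical("num_layers", [1, 2, 3, 4])
    return train_and_eval(batch_size, block_size, lr, num_layers)

study = optuna.create_study(
    direction="minimize",
    pruner=optuna.pruners.HyperbandPruner(min_resource=5, reduction_factor=2),
)
study.optimize(objective, n_trials=20)
print(study.best_params)
```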
Visualize your hyperparameter optimization studies interactively with Optuna Dashboard.
- Launch with: `python -m run dashboard`
- The dashboard will open at `localhost:8080`.
- Requires `optuna-dashboard` (install with `pip install optuna-dashboard`).
- Shows study history, parameter importance, and more.
A built-in profiling tool is included to help you analyze performance bottlenecks in your code.
- The profiler is located at `run/profiler.py` and is accessible via the modular run package.
- It runs the main application under `cProfile` and saves filtered, timestamped reports to `profiles/`.
- Reports exclude virtual environment and bootstrap code, and show the top functions by calls, time, and cumulative time.
- You can run the profiler with: `python -m run profiler`
- You can also import and use the profiler programmatically in your own scripts or tests (a generic `cProfile` sketch follows this list).
- The profiling tool is unit tested with 100% coverage.
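For reference, the underlying mechanism is standard `cProfile`; a generic, self-contained example (not the repository's `run/profiler.py`) looks like:

```python
# Profile a function and print the top entries by cumulative time.
import cProfile
import pstats

def work():
    return sum(i * i for i in range(100_000))

profiler = cProfile.Profile()
profiler.enable()
work()
profiler.disable()

stats = pstats.Stats(profiler)
stats.sort_stats("cumulative").print_stats(10)  # top 10 functions
```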
The easiest way to interact with the project is through the Pipeline GUI, which provides a unified interface for all operations:
python -m run pipeline

This launches a graphical interface where you can:
- Select models
- Start/stop training
- Run evaluations
- Export models
- View real-time logs and progress
For programmatic control, the project provides a flexible CLI for controlling model options, runtime and hyperparameters, tuning options, visualization, and exports:
- Generator Options (`--generator`, `--prompt`, `--context-length`, ...)
- Model Options (`--save-model`, `--token-level`, `--patience`, ...)
- Runtime (`--training`, `--steps`, `--interval`, ...)
- Hyperparameters (`--batch-size`, `--block-size`, `--lr`, ...)
- Tuning Options (`--auto-tuning`, `--save-tuning`, `--save-study`, ...)
- Visualization (`--save-plot`, `--show-plot`, `--smooth-loss`, ...)
- Export (`--export`)
For a full list of arguments, run:
python main.py --help

Available CLI Arguments: (click to expand)
- `--model`: Select model type (default: transformer; options: [bigram | lstm | gru | transformer | distilgpt2])
- `--generator`: Override text generation strategy (generator_options)
- `--prompt`: Override prompt for prompt-based generation (generator_options)
- `--context-length`: Override context length for prompt generation (generator_options)
- `--sampler`: Override sampling strategy for generation (generator_options)
- `--temperature`: Override temperature for generation (generator_options)
- `--save-model`: Override model saving (model_options)
- `--token-level`: Override tokenization level (model_options)
- `--patience`: Early stopping patience (model_options)
- `--max-checkpoints`: Maximum checkpoints to keep (model_options)
- `--training`: Toggle training mode (runtime)
- `--steps`: Number of training steps (runtime)
- `--interval`: Validation interval during training (runtime)
- `--max-new-tokens`: Maximum tokens to generate (runtime)
- `--batch-size`: Override batch size for training (hparams)
- `--block-size`: Override context window size (hparams)
- `--lr`: Override learning rate (hparams)
- `--embedding-dim`: Override embedding dimension size (hparams)
- `--hidden-size`: Override hidden layer size (hparams)
- `--num-layers`: Override number of model layers (hparams)
- `--max-seq-len`: Override maximum sequence length (hparams)
- `--num-heads`: Override number of attention heads (hparams)
- `--ff-dim`: Override feedforward dimension (hparams)
- `--auto-tuning`: Enable/disable hyperparameter tuning (tuning_options)
- `--save-tuning`: Enable/disable saving tuned hyperparameters (tuning_options)
- `--save-study`: Enable/disable saving Optuna study (tuning_options)
- `--n-trials`: Number of Optuna trials (tuning_options)
- `--pruner`: Pruner type for Optuna (tuning_options)
- `--step-divisor`: Step divisor for tuning (tuning_options)
- `--save-plot`: Enable/disable saving loss plots (visualization)
- `--show-plot`: Enable/disable showing loss plots (visualization)
- `--smooth-loss`: Enable/disable smoothing of loss curves (visualization)
- `--smooth-val-loss`: Enable/disable smoothing of validation loss (visualization)
- `--weight`: Smoothing weight (visualization)
- `--export`: Override export format (export) (options: [architecture | package | framework | application])
- The model argument will default to `--model transformer` if omitted.
- Any arguments omitted will default to the respective value defined in `config.json`.
- Boolean flags support flexible input: `true`, `false`, `on`, `off`, `yes`, `no`, `1`, `0`.
- distilgpt2 uses a pre-trained Hugging Face model and is inference-only (cannot be trained).
# Train the Transformer model
python main.py --model transformer --training true

After training, switch to generation mode by setting `"training": false` inside the appropriate section of `config.json`, then run the same command to generate text:

python main.py --model transformer

The project supports two text generation strategies:
- Random Generation (`"generator": "random"`): Starts with a randomly selected seed character/word and generates text from there
- Prompt-Based Generation (`"generator": "prompt"`): Uses a user-provided prompt and generates text continuing from that prompt
You can control the randomness of generation using the `temperature` setting in `config.json`: lower values make output more deterministic, while higher values make it more random.
For prompt-based generation, you can also configure `context_length` to control how much of the prompt history is used when generating the next token.
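Conceptually, temperature-scaled sampling boils down to the following sketch (illustrative; the project's generator code may differ):

```python
# Temperature scaling and sampling over next-token logits.
import torch
import torch.nn.functional as F

def sample_next(logits: torch.Tensor, temperature: float = 1.0, sampler: str = "multinomial") -> torch.Tensor:
    if sampler == "argmax":
        return logits.argmax(dim=-1, keepdim=True)
    probs = F.softmax(logits / temperature, dim=-1)  # lower temperature sharpens the distribution
    return torch.multinomial(probs, num_samples=1)

logits = torch.randn(1, 65)  # e.g. one batch element over a 65-character vocabulary
print(sample_next(logits, temperature=0.8))
```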
- Plotting both training and validation loss curves
- Optional exponential smoothing for clearer trends
- Saving plots with timestamped filenames to a model-specific directory
- Configurable via `config.json` (a smoothing sketch follows this list)
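The smoothing referred to above is a standard exponential moving average; a minimal, self-contained sketch (not the project's plotting module) is:

```python
# Plot a raw loss curve alongside its exponentially smoothed version.
import matplotlib.pyplot as plt

def smooth(values, weight=0.9):
    smoothed, last = [], values[0]
    for v in values:
        last = weight * last + (1 - weight) * v
        smoothed.append(last)
    return smoothed

losses = [2.1, 1.8, 1.9, 1.5, 1.6, 1.2, 1.3, 1.0]
plt.plot(losses, alpha=0.3, label="raw loss")
plt.plot(smooth(losses, weight=0.9), label="smoothed (weight=0.9)")
plt.legend()
plt.savefig("loss_plot.png")
```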
- GPU acceleration is used by default if a CUDA device is available.
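This typically amounts to the standard PyTorch device-selection pattern (illustrative; the project's exact helper may differ):

```python
# Pick CUDA when available and keep model and data on the same device.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 8).to(device)
batch = torch.randn(4, 8, device=device)
out = model(batch)
print(device, out.shape)
```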
- Python 3.10+
- PyTorch
- matplotlib
- datasets (Hugging Face)
- transformers (DistilGPT2)
- optuna (for hyperparameter tuning)
- dearpygui (for configuration editor and pipeline GUI)
- PyInstaller (for application export)
Install all production dependencies with:
pip install -r requirements.txt

For development, testing, and linting, additional dependencies are available:
- pytest (for unit testing)
- pytest-cov (for coverage reporting)
- coverage (for code coverage analysis)
- flake8 (for code linting)
- flake8-docstrings (for docstring style checking)
- isort (for import sorting)
Install development dependencies with:
pip install -r requirements-dev.txt

Or install both production and development dependencies:

pip install -r requirements.txt -r requirements-dev.txt

Note:
The torch package in requirements.txt is the CPU version of PyTorch.
If you want to use a GPU, install the appropriate CUDA-enabled version of PyTorch for your system before running `pip install -r requirements.txt`; the torch requirement is then already satisfied, which keeps the install quicker.
You can find the correct install command for your system and CUDA version at the official PyTorch installation page.
For example, to install PyTorch with CUDA 12.8 support:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

You can run this project in a Docker container for a fully isolated and reproducible environment.
docker build -t <your-image-name> .
docker run --rm -it <your-image-name>:latest

- The provided `Dockerfile` uses the `python:3.11-slim` base image for a smaller footprint.
- It explicitly installs the CPU-only version of PyTorch to avoid large image sizes. If you need GPU support, modify the Dockerfile and requirements accordingly.
- The `.dockerignore` file is configured to exclude unnecessary files (such as datasets, checkpoints, and virtual environments) from the image. If you add new large files or folders, update `.dockerignore` to keep your image size small.
- If you encounter issues with image size, check that you are not copying large files or using the GPU version of torch by accident.
You can modify the `CMD` in the `Dockerfile` to run other scripts or pass arguments as needed.
- The project includes comprehensive unit tests for all major modules: training, datasets, utility functions, loss visualization, tuning, model/CLI behavior, profiling, pipeline, UI components, services, and model export.
- Tests are written using `pytest` with `coverage` for reporting. Both are included in `requirements-dev.txt`.
- All unit tests are located in the `tests/` directory.
- Statistics: 347 unit tests, 100% coverage, 1875 stmts / 0 miss
- To run all tests:
pytest
- To check coverage:
coverage run -m pytest
coverage report -m
- You can also run a specific test file, for example:
pytest tests/test_ui.py
- Test output will show which tests passed or failed, and coverage will report which lines are tested.
- Coverage includes data processing, plotting, model logic, CLI argument parsing, tuning, profiling, pipeline GUI, UI components, services (architecture, exporter, terminal, callbacks), and more.
- Logger/debug lines are excluded from coverage analysis (see `.coveragerc` / `pyproject.toml`).
- Add learning rate scheduling during training
- Add visualization for transformer attention
- Add beam search to model generation
- Expand model sampling and generation options
This project is licensed under the Apache 2.0 License. See LICENSE for details.