Show HN: WaveletLM – A wavelet-based, attention-free model with O(n log n) scaling

#open-source
Original source: hackernews · Summary and analysis by Genesis Park

Summary

WaveletLM is a language model that removes the attention mechanism and instead mixes tokens via wavelet transforms, giving O(n log n) computational complexity. The 883M-parameter model uses about 4.9 GB of memory for inference on an RTX 5090 and generates 28.8 tokens per second. With planned quantization kernels, throughput is expected to improve by up to 2.2x over the baseline.

Body

WaveletLM is a wavelet-based, attention-free language model that mixes tokens through a learned lifting wavelet decomposition, a Fast Walsh-Hadamard Transform (FWHT), per-scale gated spectral mixing with SwiGLU activation, an inverse FWHT, and wavelet reconstruction. Combined with expanded MLPs and sparse product-key memory, this yields an architecture with no attention and O(n log n) scaling in sequence length.

Contents: Installation · Training · Inference · Sample Generations · Architecture · Results · Future Plans · License · References

Installation

Requires Python 3.10+, PyTorch 2.8+, and CUDA.

```bash
git clone https://github.com/ramongougis/WaveletLM.git
cd WaveletLM
pip install torch datasets
```

Training

Results from all runs are tracked in runs.md. The full default run takes ~14 h on an RTX 5090; drop epochs to 1 for a quick smoke test.

Inference

Obtain weights from HuggingFace, then replace best_model.pt in the commands below with the path to the file. The current best 883M-parameter model requires 4,918 MiB for inference and generates at 28.8 tokens/s on a 5090.

Recommended generation command:

```bash
python generate.py --checkpoint best_model.pt --strategies \
    --prompt "Your prompt here"
```

Default generation:

```bash
python generate.py --checkpoint best_model.pt
```

Additional options:

```bash
python generate.py --checkpoint best_model.pt \
    --prompt "Your prompt goes here." --num_tokens 1024 --seed 1337 \
    --n 1 --temperature 1.0 --strategies --ptq8 --num_tokens 9000
```

All strategies can be run together with --strategies, or enabled individually. Use --help for a complete list.

Use all inference strategies:

```bash
python generate.py --checkpoint best_model.pt --strategies
```

Some strategy options:

```bash
python generate.py --checkpoint best_model.pt --entropy_adaptive \
    --lookahead_k 3 --lookahead_depth 5 --best_of_n 5 --clean_spacing \
    --wavelet_coherence
```

Near-lossless uniform 8-bit PTQ:

```bash
python generate.py --checkpoint best_model.pt --ptq8
```

PTQ effects:

- +0.0001 BPB hit (negligible quality impact)
- 10% less inference VRAM
- 50% smaller checkpoint file
- 12% lower tok/s currently; with bit-packed kernels, PTQ is expected to run 1.4-2.2x faster than the 28.8 tok/s baseline (see Future Plans → Bit-Packed PTQ Kernels and runs.md)
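The README describes --ptq8 as uniform 8-bit weight PTQ that currently dequantizes inside forward(). As a rough illustration of that dequantize-in-forward scheme, here is a minimal sketch; the PTQ8Linear wrapper and its names are hypothetical, not the repository's implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PTQ8Linear(nn.Module):
    """Uniform 8-bit weight PTQ sketch: store int8 weights plus a per-channel
    scale, and dequantize inside forward(). Hypothetical, for illustration."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        w = linear.weight.data                                  # (out, in)
        scale = w.abs().amax(dim=1, keepdim=True) / 127.0
        scale = scale.clamp(min=1e-8)                           # avoid div by 0
        self.register_buffer("w_int8", torch.round(w / scale).to(torch.int8))
        self.register_buffer("scale", scale)
        self.bias = linear.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize-then-matmul: this per-call dequantization is the current
        # throughput cost that bit-packed kernels are meant to remove.
        w = self.w_int8.to(x.dtype) * self.scale.to(x.dtype)
        return F.linear(x, w, self.bias)
```

Storing int8 weights plus a small scale is what roughly halves the checkpoint size, while the per-call dequantization explains the current ~12% tok/s penalty.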
Sample Generations

Selections from the WikiText-103 best-PPL run with strategies enabled. Post-processing spacing fixes were added on 2026-04-26; generations in the linked file from before that date contain raw WikiText artifacts (@-@ for hyphens, spaces around commas, etc.), so the samples below have been formatted manually for readability.

Sample A, prompt "The history of":

> The history of the city is reflected in its architecture, which includes the historic Old Town and New Castle County Courthouse Square Historic District. The building was designed by John H. Stevens, who also designed the Albany-Fulton Celebration in 1906 and built a steel-hulled shipyard on the lake shore. In 1909, Fort Wayne became home to the first naval base for military aviation in the United States. The Navy Yard at the former Naval Air Station Pensacola (now known as the Georgia Tech Research Institute), located on the campus' eastern boundary, was constructed in 1925–1926 near the site of the original schoolhouse.

Sample B, prompt "The album was released":

> The album was released on August 25, 2007 by Sony Music Entertainment and features several songs from the record including "Never Say Die", "The Show", "Don't Cry for Me Argentina" and a cover of "I Can Only Imagine (But You Are Not Alone)". In 2009, Tyler performed at the Rock in Rio festival held in Lisbon, Portugal; a performance which included performances of his own compositions by Queen Latifah and Ali Zardari. On July 1, 2010 he sang "O Holy Night" during the closing ceremony of the 2012 Summer Olympics opening ceremony at London Olympic Stadium as part of the celebrations of the 75th anniversary of the United States Declaration of Independence.

Sample C, prompt "The species was first described":

> The species was first described by Swedish zoologist Carl Linnaeus in 1758 as Agaricus adustus. The genus name is derived from the Latin words perma "to tie", and pous ("like") means "with a large head". In 1821, French mycologists Jean-Baptiste de Lacaille placed it in section Cricetae of the order Carnivora. He later renamed it Spongiforma punctata after the Greek kribensis. P. nigriceps was originally classified under its current binomial name Sirmuellera speciosa; however, this has been rejected on the grounds that Lactarius deterrimus may not be closely related to Boletus edulis or other similar fungi in general.

Architecture

Core components (a minimal sketch of the mixing path follows these lists):

- Learnable lifting wavelet decomposition: Haar-initialized predict/update networks decompose each sequence into multi-scale coefficients per block, trained end-to-end with causality preserved via zero-padded dilation. Decompose/reconstruct weights are shared; untying them had negligible performance impact, so sharing is kept to save parameters.
- Fast Walsh-Hadamard Transform (FWHT): a fixed orthogonal O(C log C) cross-channel rotation that takes over attention's channel-mixing role. Its cost is independent of sequence length.
- Per-scale gated spectral mixer (SwiGLU): mixes each wavelet scale independently in Walsh-Hadamard space via a gated linear layer. Runs in fixed O(S²) per layer for S scales (S = levels + 1), versus attention's O(N²) in sequence length.
- Expanded MLP (expansion ≥ 20): hidden-layer width multiplier for the MLP layers; the relationship with BPB is logarithmic.
- Decompose bypass: a causal cumulative mean of pre-decompose hidden states, projected per scale and added as a bias to the post-decompose coefficients.

Additional key components (always-on architectural pieces):

- LayerNorms near both ends of each block, and one before the LM head
- Two residual connections per block with learned scalar gating (learned_residual in config.json)
- Per-scale weights applied after the inverse FWHT, one trainable scalar per wavelet scale
- Feature padding to the next power of 2, required for the Walsh-Hadamard transform (C → Cp = next_pow2(C))
- Causal zero-padded dilation in the lifting predict/update steps, preserving autoregressive causality at every level
- Per-layer embedding: a learned per-channel residual of the token embedding added at each block, letting deeper blocks reach back to the input representation
- Product Key Memory / Fast-Weight Product Key Memory: sparse key-value memory modules complementing the dense MLP, with optional inference-time fast-weight updates
- Low-rank factorization: a rank-r U·V^T perturbation added to the spectral mixer; rank = 4 yields a measurable BPB improvement at trivial parameter cost
- Exponential parametrization: reparameterizes mixer weights through exp(), stabilizing training under high learning rates that would otherwise produce NaNs
- Cross-scale gating (routing mode): a learned, identity-initialized (S, S) routing matrix that mixes per-scale inputs before each gate, enabling conditional cross-scale interactions
- Per-scale mixer widths: asymmetric per-scale mixer capacity (coarse scales full width, fine scales reduced). At [1, 1, 1, 0.5, 0.5, 0.5]: small BPB improvement plus ~23% per-epoch speedup
- Wavelet crawl: a softmax-weighted mixture of K candidate dilations per level around the base 2^l, letting the model discover off-power-of-2 receptive fields. K = 3 (±1) is the stable sweet spot
- Shared lifting weights: one lifting wavelet module shared across all blocks. Essentially free on BPB; cuts training VRAM by ~5–10% at L=2
- Looped blocks (Universal Transformer-style): one shared block applied K times in place of L stacked blocks. Reduces BPB at fixed parameter count, but compute is usually better spent on more epochs of the stacked model
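To make the mixing path concrete, here is a minimal, self-contained sketch of the three core operations named above: a Haar-initialized lifting level, the FWHT, and a SwiGLU-gated per-scale mixer. This is an illustration under simplifying assumptions (module and argument names are hypothetical, and the real model's causal zero-padded dilation and per-block networks are omitted), not the repository's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fwht(x: torch.Tensor) -> torch.Tensor:
    """Orthonormal fast Walsh-Hadamard transform over the last dim.
    O(C log C); the last dim must be a power of 2 (hence feature padding).
    Orthonormal scaling makes it self-inverse: fwht(fwht(x)) == x."""
    *lead, c = x.shape
    assert c & (c - 1) == 0, "pad features to next_pow2(C) first"
    y, h = x.reshape(-1, c), 1
    while h < c:
        y = y.reshape(-1, c // (2 * h), 2, h)
        a, b = y[:, :, 0, :], y[:, :, 1, :]
        y = torch.stack((a + b, a - b), dim=2).reshape(-1, c)
        h *= 2
    return (y / c ** 0.5).reshape(*lead, c)

class LiftingLevel(nn.Module):
    """One learnable lifting level, Haar-initialized: split into even/odd
    positions, predict the detail, then update the approximation."""
    def __init__(self, dim: int):
        super().__init__()
        self.predict = nn.Linear(dim, dim, bias=False)  # Haar: predict(e) = e
        self.update = nn.Linear(dim, dim, bias=False)   # Haar: update(d) = d/2
        nn.init.eye_(self.predict.weight)
        nn.init.eye_(self.update.weight)
        with torch.no_grad():
            self.update.weight.mul_(0.5)

    def forward(self, x):                    # x: (B, T, C), T even
        even, odd = x[:, 0::2], x[:, 1::2]
        detail = odd - self.predict(even)    # Haar init: odd - even
        approx = even + self.update(detail)  # Haar init: (odd + even) / 2
        return approx, detail                # approx feeds the next level

class GatedScaleMixer(nn.Module):
    """SwiGLU-style gated linear mix of one wavelet scale, applied in
    Walsh-Hadamard space; cost is independent of sequence length."""
    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Linear(dim, dim, bias=False)
        self.value = nn.Linear(dim, dim, bias=False)

    def forward(self, coeffs):               # coeffs: (B, T_scale, Cp)
        z = fwht(coeffs)                     # into spectral space
        z = F.silu(self.gate(z)) * self.value(z)
        return fwht(z)                       # inverse FWHT (self-inverse)
```

Stacking `levels` LiftingLevels on the approximation path yields S = levels + 1 coefficient streams (one final approximation plus a detail per level), each mixed by its own GatedScaleMixer; reconstruction then inverts the lifting steps (even = approx − update(detail), odd = detail + predict(even)).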
Additional optional features (all configurable in config.json; an illustrative fragment follows this list):

- Data-dependent EMA decompose bypass (decompose_bypass_ema): σ-gated adaptive IIR replacement for the cumulative running mean. Promising at 1 epoch (-0.30 nats val loss) but regressed at 5 epochs (BPB 1.0226 vs the 1.0201 baseline). Rejected for release; investigation plan in plans/ema_post_release.md
- Cross-layer decompose-bypass state carry (decompose_bypass_cross_window)
- Stable-parametrization master flag (stable_parametrization)
- Spectral-norm constraint on mixer weights (stab_spectral_norm)
- MLP final-layer variance scaling (stab_ff_scaling)
- √C embedding output scaling (stab_embed_scaling)
- Projection-out residual-stream scaling (stab_proj_out_scaling)
- Mixer init-epsilon scaling (stab_mixer_eps_scaling)
- Per-level lifting init damping (stab_lifting_level_scaling)
- Multi-basis (K parallel) lifting wavelets (multi_basis_lifting, multi_basis_inits)
- Untied reconstruction weights (untied_reconstruction)
- Linear-only lifting networks, no GELU (lifting_linear_only)
- Stacked spectral-mixer depth (mixer_depth, mixer_depth_stabilizers, mixer_depth_residuals)
- LoopLM mode, full-stack iterated inference (loop_iterations)
- Weight tying between embedding and LM head (tie_embedding_to_lm_head)
- Output-projection skip when C equals Cp (skip_proj_out)
- Gradient checkpointing (gradient_checkpointing)
- Stochastic depth (stochastic_depth_rate)
- Per-component dropouts (dropout_embedding, dropout_projection, dropout_mixer, dropout_mlp, dropout_lm_head)
- Lifting-network hidden-dim multiplier (lifting_hidden_mult)
- Lifting initialization choice, Haar / zero / random (lifting_init)
- Lifting dropout (lifting_dropout)
- Spectral-mixer gate toggle and activation (use_mixer_gate, mixer_gate_activation)
- Non-learned fixed-Haar fallback for the wavelet (wavelet_mode="haar")
- Multinodal feature-bagging mode and its sub-flags (multinodal_enabled, multinodal_num_cells, multinodal_cell_dim, multinodal_seeds, multinodal_combination, multinodal_cross_cell_gating, multinodal_features_per_cell, multinodal_bagged_eps)
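For orientation, a hypothetical config.json fragment using flags named above. The field names come from the text; the values shown are illustrative, not the released configuration:

```json
{
  "learned_residual": true,
  "wavelet_mode": "learned",
  "lifting_init": "haar",
  "lifting_hidden_mult": 1,
  "lifting_dropout": 0.0,
  "use_mixer_gate": true,
  "mixer_gate_activation": "swiglu",
  "untied_reconstruction": false,
  "tie_embedding_to_lm_head": false,
  "gradient_checkpointing": true,
  "stochastic_depth_rate": 0.0,
  "dropout_embedding": 0.0,
  "dropout_projection": 0.0,
  "dropout_mixer": 0.0,
  "dropout_mlp": 0.0,
  "dropout_lm_head": 0.0
}
```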
It is important to note that WaveletLM has not been fully optimized:

- it is underregularized, with a 0.8 train/val loss gap,
- the 5 dropout parameters have not been swept,
- weight decay needs further tuning,
- longer training time is needed, and
- parameter compression has not yet been applied.

My current run budget is limited; other researchers are encouraged to train the model with these changes to more accurately gauge its potential performance. See Areas for Improvement below for more on optimization, and Future Plans for ways to push WaveletLM further post-release.

Results

WikiText-103:

| Model | Type | Trained on | Params | PPL |
|---|---|---|---|---|
| GPT-2 XL | Transformer | WebText (40GB) | 1.5B | 17.51 |
| Transformer-XL Large* | Transformer + recurrence* | WikiText-103 (0.5GB)* | 257M* | 18.32* |
| GPT-2 Large | Transformer | WebText (40GB) | 774M | 19.31 |
| S4* | SSM* | WikiText-103 (0.5GB)* | 130M* | 20.93* |
| GPT-2 Medium | Transformer | WebText (40GB) | 355M | 22.11 |
| WaveletLM | Wavelet mixer | WikiText-103 (0.5GB) | 883M | 23.8† |
| Transformer-XL Standard* | Transformer + recurrence* | WikiText-103 (0.5GB)* | 151M* | 24.02* |
| GPT-2 | Transformer | WebText (40GB) | 124M | 29.41 |

\* Starred models were trained and evaluated on WikiText-103 only (direct comparison to WaveletLM). WaveletLM uses GPT-2 BPE for tokenization.
† Best of 3 seeds (PPL 23.749); mean PPL 23.818. See runs.md for a record of all training runs, logs, configs, and benchmark results with fully reproducible point-in-time code snapshots.

PG-19:

| Model | Type | Params | PPL |
|---|---|---|---|
| Perceiver AR | Cross-attn + latents | 974M | 28.94 |
| Block-Recurrent Transformer | Transformer + recurrence | ~200M | 29.05 |
| Compressive Transformer | Transformer + compressive memory | 257M | 33.66 |
| Transformer-XL | Transformer + recurrence | 257M | 36.36 |
| WaveletLM (1 epoch) | Wavelet mixer | ~808M | TBD (pending pre-release run) |

All models in this table were trained and evaluated on PG-19 with its standard SentencePiece tokenization; unlike the others, WaveletLM was trained for one epoch only. Comparison numbers for both datasets are sourced from their respective papers (see References below).

Areas for Improvement

Longer training time, more regularization, and parameter compression are the surest ways to immediately improve the model's performance.

- More training time: more research and more resources are needed to uncover the effects of longer training.
- Regularization: WaveletLM is vastly underregularized, with a 0.8 train/val loss gap at 5+ epochs. Dropout and weight-decay sweeps are budget-limited and involve tuning weight_decay, dropout_embedding, dropout_projection, dropout_mixer, dropout_mlp, and dropout_lm_head in tandem.
- Parameter compression: of WaveletLM's 883M parameters, around 55% (488M) live in two highly compressible components: dense MLPs (335.6M) and product-key memory modules (PKM: 76M + FwPKM: 76M). Further work is needed to determine how compressible each component is during training, which makes this complementary to PTQ. (A sketch of the product-key lookup follows this list.)
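For context on why product-key memory is parameter-heavy yet cheap to query, here is a minimal sketch of a product-key lookup in the style of Lample et al. (2019), which PKM-type modules build on. The dimensions and names are hypothetical, not WaveletLM's PKM/FwPKM code:

```python
import torch
import torch.nn as nn

class ProductKeyMemorySketch(nn.Module):
    """Product-key memory lookup sketch: the query splits in half, each half
    scores n_keys sub-keys, and the top-k of the combined (topk x topk)
    candidate scores select rows of a large value table."""

    def __init__(self, dim: int = 2048, n_keys: int = 256, topk: int = 8):
        super().__init__()
        self.n_keys, self.topk = n_keys, topk
        self.query = nn.Linear(dim, dim)
        half = dim // 2
        self.sub_keys = nn.Parameter(torch.randn(2, n_keys, half) * half ** -0.5)
        # The value table is where the parameters live: n_keys**2 rows.
        self.values = nn.EmbeddingBag(n_keys * n_keys, dim, mode="sum")

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, T, C)
        b, t, c = x.shape
        q = self.query(x).reshape(b * t, 2, c // 2)
        s1 = q[:, 0] @ self.sub_keys[0].T                 # (B*T, n_keys)
        s2 = q[:, 1] @ self.sub_keys[1].T
        v1, i1 = s1.topk(self.topk, dim=-1)               # per-half top-k
        v2, i2 = s2.topk(self.topk, dim=-1)
        # Scores/indices of all topk*topk sub-key combinations, then re-rank.
        scores = (v1[:, :, None] + v2[:, None, :]).flatten(1)
        index = (i1[:, :, None] * self.n_keys + i2[:, None, :]).flatten(1)
        best, pos = scores.topk(self.topk, dim=-1)
        weights = torch.softmax(best, dim=-1)
        out = self.values(index.gather(1, pos), per_sample_weights=weights)
        return out.view(b, t, c)
```

Only topk rows of the value table are touched per token, which is what makes large PKM tables cheap to query at inference yet attractive targets for compression.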
Future Plans

The PG-19 run above was trained for a single epoch using the WikiText-optimized config; published baselines on the same dataset were likely trained for many more epochs or with far more effective compute. Once possible, the first post-release goal is to train on PG-19 for 2 epochs and, loss permitting, 5 epochs, to better gauge language modeling on a large dataset at the current parameter size. Also planned: the best WaveletLM config trained on Pile-ArXiv, BookCorpusOpen, OpenWebText, and other datasets to gauge its performance there, and side-by-side benchmarks against Transformer, Mamba, RWKV, and other modern architectures on WikiText-103 at matched compute, fully optimized.

B200 scale-up: the 883M RTX 5090 headline run scales up naturally to a B200:

- C: 2048 → 4096
- layers: 2 → 4–8
- mlp_expansion: 20 → 50–200
- pkm_num_keys & fwpkm_num_keys: 16384 → 65536 each
- fp16 → FP8 via Blackwell tensor cores (NYI)

The goal is a 10–15B-parameter configuration, trained individually on WikiText-103 and PG-19, and also on a multi-dataset mix of WikiText-103, PG-19, Pile-ArXiv, BookCorpusOpen, TinyStories, and OpenWebText. Inference would fit on a single RTX 4090 at fp16, and in roughly half the VRAM with uniform 8-bit PTQ. See runs.md for the pending run entry.

Optimizer sweep: Adagrad (lr=0.01) is the validated optimizer for the released model but has not been directly compared against properly tuned alternatives. WaveletLM is matrix-parameter-heavy (the MLP at expansion=20 produces Linear(2048, 40960) weights, plus per-scale mixers and lifting matrices), so Muon (Jordan et al., 2025), which orthogonalizes matrix gradient updates via Newton-Schulz iteration and reports 1.5–2x wall-clock speedups vs AdamW on small transformers, is a strong candidate. The plan is a 2-phase sweep (1-epoch LR screening + 5-epoch finalist validation) across Adagrad, AdamW, and Muon; even a 30% wall-clock speedup compounds across every subsequent ablation and the B200 scale-up. See plans/other_post_release_plans.md §6. (A sketch of the Newton-Schulz step is below.)

Bit-packed PTQ kernels: the current PTQ path dequantizes int8 weights to fp16 inside forward() and runs a standard fp16 matmul, which pays the dequantization cost on every forward pass; bit-packed kernels are expected to remove this overhead and reach the 1.4–2.2x speedup cited above.
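For reference, a sketch of the Newton-Schulz orthogonalization step that Muon applies to matrix gradients. The quintic coefficients follow the commonly cited open-source Muon implementation and should be treated as an assumption; real implementations add momentum, bf16 casting, and other bookkeeping omitted here:

```python
import torch

def newton_schulz_orthogonalize(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately map a matrix gradient to the nearest orthogonal matrix
    using matmul-only iterations instead of an SVD (Muon-style sketch)."""
    assert g.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315      # quintic iteration coefficients
    x = g / (g.norm() + 1e-7)              # normalize so the iteration converges
    transposed = g.shape[0] > g.shape[1]
    if transposed:
        x = x.T                            # iterate on the short side
    for _ in range(steps):
        m = x @ x.T
        x = a * x + (b * m + c * (m @ m)) @ x
    return x.T if transposed else x
```

Muon applies this orthogonalized update (with momentum) to 2-D weight matrices only, which is why WaveletLM's large Linear(2048, 40960) MLP weights make it an attractive candidate.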

This analysis was written by the Genesis Park editorial team with the help of AI. The original article is available via the source link.
