diff --git a/adaptive_range_finder.py b/adaptive_range_finder.py
new file mode 100644
index 0000000..eba56c1
--- /dev/null
+++ b/adaptive_range_finder.py
@@ -0,0 +1,197 @@
+import numpy as np
+import torch
+from typing import Union
+
+def adaptive_randomized_range_finder(
+    A: Union[np.ndarray, torch.Tensor],
+    epsilon: float,
+    r: int = 10
+) -> Union[np.ndarray, torch.Tensor]:
+    """
+    Implements Algorithm 4.2, the adaptive randomized range finder
+    (works with both PyTorch and NumPy).
+
+    Computes an orthonormal basis Q for A such that the approximation error
+    ||(I - QQ*)A|| is below epsilon with high probability. Dispatches
+    automatically to the CPU (NumPy) or the GPU (PyTorch).
+    """
+
+    # --- 1. Detect the framework and adapt ---
+    is_torch = False
+    device = None
+    dtype = None
+
+    if isinstance(A, torch.Tensor):
+        is_torch = True
+        device = A.device
+        dtype = A.dtype
+        m, n = A.shape
+    else:
+        # NumPy mode
+        m, n = A.shape
+        dtype = A.dtype
+
+    # --- 2. Helpers that hide the framework differences ---
+    def make_random(shape):
+        if is_torch:
+            return torch.randn(shape, device=device, dtype=dtype)
+        else:
+            return np.random.normal(size=shape).astype(dtype)
+
+    def calc_norm(vec):
+        if is_torch:
+            return torch.norm(vec)
+        else:
+            return np.linalg.norm(vec)
+
+    def calc_dot(v1, v2):
+        if is_torch:
+            return torch.dot(v1, v2)
+        else:
+            return np.dot(v1, v2)
+
+    def mat_mul_vec(mat, vec):
+        # Matrix-vector product
+        return mat @ vec
+
+    # --- Step 1: initialization ---
+    # Draw standard Gaussian vectors omega^(1) ... omega^(r)
+    Omega = make_random((n, r))
+
+    # --- Step 2: initial sampling ---
+    # Compute Y = A * Omega.
+    # The samples are kept in a list so the sample set can grow dynamically.
+    Y = []
+    for i in range(r):
+        omega_col = Omega[:, i]  # i-th random vector
+        y_col = mat_mul_vec(A, omega_col)
+        Y.append(y_col)
+
+    # --- Steps 3 & 4: initialize the loop variables ---
+    j = 0
+    Q = []  # orthonormal basis vectors
+
+    # Error threshold: epsilon / (10 * sqrt(2/pi)), with sqrt(2/pi) ≈ 0.798
+    const_factor = 0.79788456
+    limit = epsilon / (10 * const_factor)
+
+    # --- Step 5: main loop ---
+    # Keep extending the basis while some vector in the look-ahead window
+    # still carries significant energy.
+    while True:
+        # Bounds check (guards against an extremely rare infinite loop)
+        if j >= n:
+            break
+
+        # Vectors in the current look-ahead window Y[j : j+r].
+        # The loop body below appends new samples, but guard against an
+        # empty window anyway.
+        current_window = Y[j : j+r]
+        if not current_window:
+            break
+
+        # Norm of every vector in the window; .item() yields plain
+        # Python floats for the comparison
+        norms = [calc_norm(y).item() for y in current_window]
+        max_norm = max(norms)
+
+        # Stopping criterion
+        if max_norm <= limit:
+            break
+
+        # --- Step 7: projection (Gram-Schmidt) ---
+        # Y[j] has already been orthogonalized against the previous Q
+        # (in step 13), but re-orthogonalize for numerical stability.
+        y_current = Y[j]
+
+        # Double orthogonalization, the key to numerical stability:
+        # one pass is usually enough, two passes are cheap insurance.
+        for _ in range(2):
+            for q_prev in Q:
+                projection = calc_dot(q_prev, y_current)
+                y_current = y_current - q_prev * projection
+
+        # --- Step 8: normalization ---
+        norm_y = calc_norm(y_current)
+
+        if norm_y < 1e-15:
+            # Linearly dependent sample; skip it
+            j += 1
+            continue
+
+        q_new = y_current / norm_y
+        Q.append(q_new)
+
+        # --- Step 10: draw a fresh Gaussian vector ---
+        omega_new = make_random((n,))
+
+        # --- Step 11: compute the new sample ---
+        # y_new = (I - Q Q*) A omega_new: first A * omega ...
+        y_new = mat_mul_vec(A, omega_new)
+
+        # ... then orthogonalize it immediately against the current Q
+        for q in Q:
+            y_new = y_new - q * calc_dot(q, y_new)
+
+        Y.append(y_new)
+
+        # --- Steps 12 & 13: update the look-ahead window ---
+        # Y[i] = Y[i] - q_new * <q_new, Y[i]> for i = j+1 ... j+r.
+        # Because of the append above (Python slices are half-open and
+        # len(Y) has grown), simply update every vector stored after
+        # position j.
+        for i in range(j + 1, len(Y)):
+            proj = calc_dot(q_new, Y[i])
+            Y[i] = Y[i] - q_new * proj
+
+        j += 1
+
+    # --- Step 16: assemble the final matrix ---
+    if not Q:
+        # Return an empty matrix
+        if is_torch:
+            return torch.zeros((m, 0), device=device, dtype=dtype)
+        else:
+            return np.zeros((m, 0), dtype=dtype)
+
+    # Stack the basis vectors into columns
+    if is_torch:
+        Q_matrix = torch.stack(Q, dim=1)
+    else:
+        Q_matrix = np.column_stack(Q)
+
+    return Q_matrix
+
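+# A note on the stopping threshold used above: for a look-ahead window of r
+# samples, Halko, Martinsson & Tropp (2011, Sec. 4.3) bound the range-finder
+# error by
+#
+#     ||(I - QQ*)A|| <= 10 * sqrt(2/pi) * max_{i=1..r} ||(I - QQ*)A w^(i)||
+#
+# with probability at least 1 - min(m, n) * 10**(-r). Driving every norm in
+# the window below epsilon / (10 * sqrt(2/pi)) therefore certifies
+# ||(I - QQ*)A|| <= epsilon with that probability.
+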
+# --- Unit test / usage example ---
+if __name__ == "__main__":
+    # 1. Create a synthetic matrix with a known rank.
+    # Assume m=1000, n=100, true rank = 10.
+    np.random.seed(42)  # fix the seed so the results are reproducible
+    m, n = 1000, 100
+    true_rank = 10
+
+    # Build the low-rank matrix A = U * S * V.T
+    U_true, _ = np.linalg.qr(np.random.normal(size=(m, true_rank)))
+    V_true, _ = np.linalg.qr(np.random.normal(size=(n, true_rank)))
+    S_true = np.diag(np.linspace(10, 1, true_rank))  # singular values from 10 down to 1
+    A = U_true @ S_true @ V_true.T
+
+    print(f"Original matrix shape: {A.shape}, true rank: {true_rank}")
+
+    # 2. Run the algorithm
+    target_epsilon = 1e-2
+    Q_approx = adaptive_randomized_range_finder(A, epsilon=target_epsilon)
+
+    # 3. Inspect the result
+    found_rank = Q_approx.shape[1]
+    print(f"Rank found by the algorithm (number of columns of Q): {found_rank}")
+
+    # 4. Check the approximation error || (I - QQ*)A ||.
+    # I - QQ* projects onto the orthogonal complement of range(Q),
+    # i.e. A minus its projection onto Q: A - Q(Q*A)
+    diff = A - Q_approx @ (Q_approx.T @ A)
+    error_norm = np.linalg.norm(diff, ord=2)  # spectral norm
+
+    print(f"Approximation error (spectral norm): {error_norm:.6f}")
+    print(f"Target error: {target_epsilon}")
+
+    if error_norm < target_epsilon * 10:  # allow some random fluctuation
+        print(">> Test passed: error within the acceptable range.")
+    else:
+        print(">> Test warning: error is large; check the parameters.")
\ No newline at end of file
diff --git a/colab.ipynb b/colab.ipynb
new file mode 100644
index 0000000..1e0027e
--- /dev/null
+++ b/colab.ipynb
@@ -0,0 +1,438 @@
+{
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        "\"Open"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 8,
+      "metadata": {
+        "id": "l-fJd1vOXhyE",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "edd98885-16b4-4f3b-fc23-811e81f61f1a"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Cloning into 'DiffStateGrad'...\n",
+            "remote: Enumerating objects: 502, done.\u001b[K\n",
+            "remote: Counting objects: 100% (502/502), done.\u001b[K\n",
+            "remote: Compressing objects: 100% (382/382), done.\u001b[K\n",
+            "remote: Total 502 (delta 205), reused 329 (delta 102), pack-reused 0 (from 0)\u001b[K\n",
+            "Receiving objects: 100% (502/502), 18.86 MiB | 21.56 MiB/s, done.\n",
+            "Resolving deltas: 100% (205/205), done.\n"
+          ]
+        }
+      ],
+      "source": [
+        "!git clone https://github.com/rzirvi1665/DiffStateGrad.git"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 10,
+      "metadata": {
+        "id": "0YkGfhsIX7wz",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "43b1789a-61f5-45b6-857f-77c77a2bfdb4"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "/content/DiffStateGrad\n"
+          ]
+        }
+      ],
+      "source": [
+        "cd /content/DiffStateGrad"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 16,
+      "metadata": {
+        "id": "gcyW21fxX-_J"
+      },
+      "outputs": [],
+      "source": [
+        "!mkdir -p models/ldm"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 17,
+      "metadata": {
+        "id": "OS1g6QlIYHBl",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "17f5123c-3995-4e81-d466-9dde048c06a9"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "DEBUG output created by Wget 1.21.2 on linux-gnu.\n",
+            "\n",
+            "Reading HSTS entries from /root/.wget-hsts\n",
+            "URI encoding = ‘UTF-8’\n",
+            "Converted file name 'ffhq.zip' (UTF-8) -> 'ffhq.zip' (UTF-8)\n",
+            "--2026-02-01 09:53:52--  https://ommer-lab.com/files/latent-diffusion/ffhq.zip\n",
+            "Resolving ommer-lab.com (ommer-lab.com)... 141.84.41.65\n",
+            "Caching ommer-lab.com => 141.84.41.65\n",
+            "Connecting to ommer-lab.com (ommer-lab.com)|141.84.41.65|:443... connected.\n",
+            "Created socket 3.\n",
+            "Releasing 0x00005c9b3e555010 (new refcount 1).\n",
+            "Initiating SSL handshake.\n",
+            "Handshake successful; connected socket 3 to SSL handle 0x00005c9b3e556b10\n",
+            "certificate:\n",
+            "  subject: CN=ommer-lab.com\n",
+            "  issuer:  CN=R12,O=Let's Encrypt,C=US\n",
+            "X509 certificate successfully verified and matches host ommer-lab.com\n",
+            "\n",
+            "---request begin---\n",
+            "GET /files/latent-diffusion/ffhq.zip HTTP/1.1\n",
+            "Host: ommer-lab.com\n",
+            "User-Agent: Wget/1.21.2\n",
+            "Accept: */*\n",
+            "Accept-Encoding: identity\n",
+            "Connection: Keep-Alive\n",
+            "\n",
+            "---request end---\n",
+            "HTTP request sent, awaiting response... \n",
+            "---response begin---\n",
+            "HTTP/1.1 200 OK\n",
+            "Date: Sun, 01 Feb 2026 09:53:52 GMT\n",
+            "Server: Apache/2.4.52 (Ubuntu)\n",
+            "Last-Modified: Mon, 21 Feb 2022 11:25:33 GMT\n",
+            "ETag: \"85777df6-5d8857da75fd3\"\n",
+            "Accept-Ranges: bytes\n",
+            "Content-Length: 2239200758\n",
+            "Keep-Alive: timeout=5, max=100\n",
+            "Connection: Keep-Alive\n",
+            "Content-Type: application/zip\n",
+            "\n",
+            "---response end---\n",
+            "200 OK\n",
+            "Registered socket 3 for persistent reuse.\n",
+            "Length: 2239200758 (2.1G) [application/zip]\n",
+            "Saving to: ‘./models/ldm/ffhq.zip’\n",
+            "\n",
+            "ffhq.zip            100%[===================>]   2.08G  22.1MB/s    in 1m 58s  \n",
+            "\n",
+            "2026-02-01 09:55:50 (18.1 MB/s) - ‘./models/ldm/ffhq.zip’ saved [2239200758/2239200758]\n",
+            "\n",
+            "FINISHED --2026-02-01 09:55:50--\n",
+            "Total wall clock time: 1m 59s\n",
+            "Downloaded: 1 files, 2.1G in 1m 58s (18.1 MB/s)\n"
+          ]
+        }
+      ],
+      "source": [
+        "!wget https://ommer-lab.com/files/latent-diffusion/ffhq.zip -P ./models/ldm"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 18,
+      "metadata": {
+        "id": "OP3DAnw_arPQ",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "9b6772ce-4d15-4375-978e-2943cb7ed562"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Archive:  models/ldm/ffhq.zip\n",
+            "  inflating: ./models/ldm/model.ckpt  \n"
+          ]
+        }
+      ],
+      "source": [
+        "!unzip models/ldm/ffhq.zip -d ./models/ldm"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 19,
+      "metadata": {
+        "id": "6lfhOedYY4Cw",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "496a8a63-9e4b-4230-8d69-9fbb49b9a566"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "--2026-02-01 09:56:21--  https://ommer-lab.com/files/latent-diffusion/vq-f4.zip\n",
+            "Resolving ommer-lab.com (ommer-lab.com)... 141.84.41.65\n",
+            "Connecting to ommer-lab.com (ommer-lab.com)|141.84.41.65|:443... connected.\n",
+            "HTTP request sent, awaiting response... 200 OK\n",
+            "Length: 696655056 (664M) [application/zip]\n",
+            "Saving to: ‘./models/first_stage_models/vq-f4/vq-f4.zip’\n",
+            "\n",
+            "vq-f4.zip           100%[===================>] 664.38M  21.9MB/s    in 38s     \n",
+            "\n",
+            "2026-02-01 09:57:00 (17.4 MB/s) - ‘./models/first_stage_models/vq-f4/vq-f4.zip’ saved [696655056/696655056]\n",
+            "\n",
+            "Archive:  models/first_stage_models/vq-f4/vq-f4.zip\n",
+            "  inflating: ./models/first_stage_models/vq-f4/model.ckpt  \n"
+          ]
+        }
+      ],
+      "source": [
+        "!mkdir -p models/first_stage_models/vq-f4\n",
+        "!wget https://ommer-lab.com/files/latent-diffusion/vq-f4.zip -P ./models/first_stage_models/vq-f4\n",
+        "!unzip models/first_stage_models/vq-f4/vq-f4.zip -d ./models/first_stage_models/vq-f4"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 14,
+      "metadata": {
+        "id": "mCkKUFupZugA",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "6836aa22-f75e-4ab1-e4bb-39a511b5558a"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Cloning into 'bkse'...\n",
+            "remote: Enumerating objects: 577, done.\u001b[K\n",
+            "remote: Counting objects: 100% (577/577), done.\u001b[K\n",
+            "remote: Compressing objects: 100% (328/328), done.\u001b[K\n",
+            "remote: Total 577 (delta 334), reused 461 (delta 232), pack-reused 0 (from 0)\u001b[K\n",
+            "Receiving objects: 100% (577/577), 1.05 MiB | 4.42 MiB/s, done.\n",
+            "Resolving deltas: 100% (334/334), done.\n",
+            "Cloning into 'motionblur'...\n",
+            "remote: Enumerating objects: 36, done.\u001b[K\n",
+            "remote: Total 36 (delta 0), reused 0 (delta 0), pack-reused 36 (from 1)\u001b[K\n",
+            "Receiving objects: 100% (36/36), 511.08 KiB | 2.19 MiB/s, done.\n",
+            "Resolving deltas: 100% (12/12), done.\n"
+          ]
+        }
+      ],
+      "source": [
+        "!git clone https://github.com/VinAIResearch/blur-kernel-space-exploring bkse\n",
+        "\n",
+        "!git clone https://github.com/LeviBorodenko/motionblur motionblur"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "uXCv0ulNZysA"
+      },
+      "source": [
+        "Install dependencies via:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "axQF9EtCZz_N"
+      },
+      "outputs": [],
+      "source": [
+        "!conda env create -f environment.yaml"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "E4Zj_xeBaOc6"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install lpips"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "Xci7NlggaVg_"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install pytorch_lightning"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "EjyXS3fLb5Q0"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install \"pip<24.0\""
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "0P2Dvwk4cR2-"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install pytorch-lightning==1.7.7"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "Eh62WLWlcvAY"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install taming-transformers==0.0.1"
+      ]
+    },
+    {
+      "cell_type": "code",
"execution_count": null, + "metadata": { + "id": "s03oh4z0cjRs" + }, + "outputs": [], + "source": [ + "!pip install torchmetrics==0.9.3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "h53FwWZ_aHrb", + "outputId": "c338b23b-3f83-4762-a314-7879618c00ec" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "2026-02-01 10:04:01.171669: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "E0000 00:00:1769940241.191691 10657 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "E0000 00:00:1769940241.197727 10657 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "W0000 00:00:1769940241.213039 10657 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1769940241.213063 10657 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1769940241.213067 10657 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1769940241.213072 10657 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "2026-02-01 10:04:01.217894: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "Global seed set to 42\n", + "Device set to cuda:0.\n", + "Loading model from models/ldm/model.ckpt\n", + "LatentDiffusion: Running in eps-prediction mode\n", + "DiffusionWrapper has 274.06 M params.\n", + "Keeping EMAs of 370.\n", + "making attention of type 'vanilla' with 512 in_channels\n", + "Working with z of shape (1, 3, 64, 64) = 12288 dimensions.\n", + "making attention of type 'vanilla' with 512 in_channels\n", + "Restored from models/first_stage_models/vq-f4/model.ckpt with 0 missing and 55 unexpected keys\n", + "Training LatentDiffusion as an unconditional model.\n", + "Operation: gaussian_blur / Noise: gaussian\n", + "Conditioning sampler : resample\n", + "Inference for image 60004\n", + "Data shape for DDIM sampling is (1, 3, 64, 64), eta 0.0\n", + "DDIM Sampler: 38% 189/500 [02:00<03:44, 1.39it/s]" + ] + } + ], + "source": [ + "!python3 diffstategrad_sample_condition.py" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "authorship_tag": "ABX9TyOcfrP0A7K8uCc6Us91iHvx", + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/diffstategrad_sample_condition.py b/diffstategrad_sample_condition.py index a45ed6f..9c2954f 100644 --- 
a/diffstategrad_sample_condition.py
+++ b/diffstategrad_sample_condition.py
@@ -144,7 +144,7 @@ def make_folder(sample_path, opt):
 parser.add_argument('--ddim_eta', default=0.0, type=float)
 parser.add_argument('--n_samples_per_class', default=1, type=int)
 parser.add_argument('--ddim_scale', default=1.0, type=float)
-parser.add_argument('--image_id', default=60000, type=int)
+parser.add_argument('--image_id', default=60004, type=int)
 parser.add_argument('--var_cutoff', default=0.99, type=float)
 parser.add_argument('--pixel_lr', default=1e-2, type=float)
 parser.add_argument('--latent_lr', default=5e-3, type=float)
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
index 6a9c4f4..028ebb7 100644
--- a/ldm/models/autoencoder.py
+++ b/ldm/models/autoencoder.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 from contextlib import contextmanager
 
-from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from taming.modules.vqvae.quantize import VectorQuantizer
 from ldm.modules.diffusionmodules.model import Encoder, Decoder
 from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
@@ -37,8 +37,7 @@ def __init__(self,
         self.decoder = Decoder(**ddconfig)
         self.loss = instantiate_from_config(lossconfig)
-        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
-                                        remap=remap,
-                                        sane_index_shape=sane_index_shape)
+        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
         self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
         if colorize_nlabels is not None:
@@ -76,7 +75,7 @@ def ema_scope(self, context=None):
             print(f"{context}: Restored training weights")
 
     def init_from_ckpt(self, path, ignore_keys=list()):
-        sd = torch.load(path, map_location="cpu")["state_dict"]
+        sd = torch.load(path, map_location="cpu", weights_only=False)["state_dict"]
         keys = list(sd.keys())
         for k in keys:
             for ik in ignore_keys:
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 18383ec..946c9ff 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -23,7 +23,7 @@
 from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
 from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
 from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
-from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.diffstategrad_ddim import DDIMSampler
 
 __conditioning_keys__ = {'concat': 'c_concat',
diff --git a/ldm/models/diffusion/diffstategrad_ddim.py b/ldm/models/diffusion/diffstategrad_ddim.py
index 3170ba5..0750910 100644
--- a/ldm/models/diffusion/diffstategrad_ddim.py
+++ b/ldm/models/diffusion/diffstategrad_ddim.py
@@ -31,7 +31,7 @@ def compute_rank_for_explained_variance(singular_values, explained_variance_cutoff):
         rank = np.searchsorted(cumulative_variance, explained_variance_cutoff) + 1
         total_rank += rank
     return int(total_rank / 3)
-
+import time
+from randomized_svd import randomized_svd
 def compute_svd_and_adaptive_rank(z_t, var_cutoff):
     """
     Compute SVD and adaptive rank for the input tensor.
@@ -44,14 +44,33 @@
         tuple: (U, s, Vh, adaptive_rank) where U, s, Vh are SVD components
                and adaptive_rank is the computed rank
     """
-    # Compute SVD of current image representation
+
+    # Time the exact (deterministic) SVD of the current image representation
+    start_time = time.perf_counter()
     U, s, Vh = torch.linalg.svd(z_t[0], full_matrices=False)
+    end_time = time.perf_counter()
+    time1 = end_time - start_time
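+
+    # torch.linalg.svd treats the leading dimension of the (C, H, W) input
+    # as a batch, i.e. it computes C independent (H, W) SVDs in one call;
+    # this is the baseline that the channel-wise randomized_svd is timed
+    # against below.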
+
+    # Time the randomized SVD (rSVD) on the same input for comparison.
+    # The result goes into separate variables so that the exact
+    # decomposition computed above is what the function returns.
+    start_time = time.perf_counter()
+    U_r, s_r, Vh_r = randomized_svd(z_t[0], epsilon=0.1)
+    end_time = time.perf_counter()
+    time2 = end_time - start_time
+
+    execution_time = time1 - time2
+    print(f"relative time (svd - rsvd): {execution_time:.6f} seconds")
 
     # Compute adaptive rank
     s_numpy = s.detach().cpu().numpy()
     adaptive_rank = compute_rank_for_explained_variance([s_numpy], var_cutoff)
-
+    print("rank:", adaptive_rank)
     return U, s, Vh, adaptive_rank
 
 def apply_diffstategrad(norm_grad, iteration_count, period, U=None, s=None, Vh=None, adaptive_rank=None):
diff --git a/randomized_svd.py b/randomized_svd.py
new file mode 100644
index 0000000..e44243e
--- /dev/null
+++ b/randomized_svd.py
@@ -0,0 +1,162 @@
+import numpy as np
+import torch
+from typing import Tuple, Union
+
+# --- Modular import ---
+from adaptive_range_finder import adaptive_randomized_range_finder
+
+def _randomized_svd_2d_padded(
+    A: Union[np.ndarray, torch.Tensor],
+    epsilon: float
+) -> Tuple[Union[np.ndarray, torch.Tensor], ...]:
+    """
+    Internal helper: run the adaptive SVD with automatic PyTorch/NumPy
+    dispatch and zero-padding of the outputs.
+    """
+    # 1. Detect the framework
+    is_torch = isinstance(A, torch.Tensor)
+
+    # Shapes
+    m, n = A.shape
+    min_dim = min(m, n)
+
+    # 2. Adaptively compute the range (this calls adaptive_range_finder);
+    #    Q has the same type as A (GPU tensor or NumPy array)
+    Q = adaptive_randomized_range_finder(A, epsilon=epsilon)
+
+    # Compute B = Q.T @ A (the @ operator works in both frameworks)
+    B = Q.T @ A
+
+    # 3. Standard SVD of the small matrix B (framework-specific)
+    if is_torch:
+        # PyTorch path
+        # S_hat: (k, k), Sigma: (k,), Vt: (k, n)
+        # Note: the U returned by torch.linalg.svd is S_hat here
+        S_hat, Sigma_small, Vt_small = torch.linalg.svd(B, full_matrices=False)
+
+        # Recover U_small = Q @ S_hat
+        U_small = Q @ S_hat
+
+        # Current rank k
+        k = Sigma_small.shape[0]
+
+        # --- Padding (PyTorch) ---
+        if k < min_dim:
+            # Pad S
+            Sigma_final = torch.zeros(min_dim, dtype=A.dtype, device=A.device)
+            Sigma_final[:k] = Sigma_small
+
+            # Pad U
+            U_final = torch.zeros((m, min_dim), dtype=A.dtype, device=A.device)
+            U_final[:, :k] = U_small
+
+            # Pad Vt
+            Vt_final = torch.zeros((min_dim, n), dtype=A.dtype, device=A.device)
+            Vt_final[:k, :] = Vt_small
+
+            return U_final, Sigma_final, Vt_final
+        else:
+            # Truncate (guards against k > min_dim from floating-point noise)
+            return U_small[:, :min_dim], Sigma_small[:min_dim], Vt_small[:min_dim, :]
+
+    else:
+        # NumPy path (same logic)
+        S_hat, Sigma_small, Vt_small = np.linalg.svd(B, full_matrices=False)
+        U_small = Q @ S_hat
+        k = Sigma_small.shape[0]
+
+        if k < min_dim:
+            Sigma_final = np.zeros(min_dim, dtype=A.dtype)
+            Sigma_final[:k] = Sigma_small
+
+            U_final = np.zeros((m, min_dim), dtype=A.dtype)
+            U_final[:, :k] = U_small
+
+            Vt_final = np.zeros((min_dim, n), dtype=A.dtype)
+            Vt_final[:k, :] = Vt_small
+            return U_final, Sigma_final, Vt_final
+        else:
+            return U_small[:, :min_dim], Sigma_small[:min_dim], Vt_small[:min_dim, :]
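+
+# How the two stages above fit together: Q is an (approximate) orthonormal
+# basis for the range of A, so A ≈ Q(Q^T A) = QB. A dense SVD of the small
+# matrix B = U_hat Σ V^T then gives A ≈ (Q U_hat) Σ V^T, i.e. U = Q U_hat.
+# Only B, whose row count equals the discovered rank, ever goes through a
+# full SVD.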
+
+def randomized_svd(
+    data: Union[np.ndarray, torch.Tensor],
+    epsilon: float = 1e-2
+) -> Tuple[Union[np.ndarray, torch.Tensor], ...]:
+    """
+    Implements Algorithm 5.1: channel-wise randomized SVD (supports
+    batch/channel-wise inputs). Fully compatible with PyTorch GPU tensor
+    pipelines; no CPU round-trips are needed.
+
+    Output shapes (for an input of shape (3, 64, 64)):
+        U:  (3, 64, 64)
+        S:  (3, 64)   (zero-padded for alignment)
+        Vh: (3, 64, 64)
+    """
+
+    # 1. Basic input info
+    is_torch = isinstance(data, torch.Tensor)
+    input_shape = data.shape
+
+    # 2. Channel-wise processing
+    if len(input_shape) == 3:
+        # (C, H, W) mode
+        C, H, W = input_shape
+        min_dim = min(H, W)
+
+        # Allocate the output containers
+        if is_torch:
+            # Allocate directly on the GPU
+            U_batch = torch.zeros((C, H, min_dim), dtype=data.dtype, device=data.device)
+            S_batch = torch.zeros((C, min_dim), dtype=data.dtype, device=data.device)
+            Vt_batch = torch.zeros((C, min_dim, W), dtype=data.dtype, device=data.device)
+        else:
+            U_batch = np.zeros((C, H, min_dim), dtype=data.dtype)
+            S_batch = np.zeros((C, min_dim), dtype=data.dtype)
+            Vt_batch = np.zeros((C, min_dim, W), dtype=data.dtype)
+
+        for i in range(C):
+            # Process one channel at a time; data[i] keeps its tensor
+            # properties (it stays a GPU tensor when the input is one)
+            u, s, vt = _randomized_svd_2d_padded(data[i], epsilon)
+
+            U_batch[i] = u
+            S_batch[i] = s
+            Vt_batch[i] = vt
+
+        return U_batch, S_batch, Vt_batch
+
+    elif len(input_shape) == 2:
+        # 2D mode: call the helper directly
+        return _randomized_svd_2d_padded(data, epsilon)
+
+    else:
+        raise ValueError(f"Only 2D or 3D inputs are supported; got shape: {input_shape}")
+
+# --- Verification code (checks that the GPU path works end to end) ---
+if __name__ == "__main__":
+    if torch.cuda.is_available():
+        print("Testing CUDA GPU mode...")
+        device = "cuda:0"
+
+        # 1. Create GPU data of shape (3, 64, 64) with true rank 10
+        rank = 10
+        U = torch.randn(3, 64, rank, device=device)
+        S = torch.randn(3, rank, device=device)
+        V = torch.randn(3, rank, 64, device=device)
+        z_t = U @ torch.diag_embed(S) @ V
+
+        print(f"Input data on: {z_t.device}")
+
+        # 2. Run the algorithm.
+        # Expectation: no errors, and the outputs stay on the GPU.
+        U_out, S_out, Vh_out = randomized_svd(z_t, epsilon=1e-2)
+
+        print(f"Output U on: {U_out.device}")
+        print(f"Output shapes: {U_out.shape}, {S_out.shape}, {Vh_out.shape}")
+
+        if U_out.is_cuda:
+            print("✅ Test passed: end-to-end GPU computation succeeded!")
+        else:
+            print("❌ Test failed: data fell back to the CPU.")
+    else:
+        print("No GPU detected; skipping the GPU test.")
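+
+    # --- Added sketch (not part of the original test): a NumPy/CPU sanity
+    # check that runs regardless of CUDA availability. The rank-5 test
+    # matrix and the epsilon value below are illustrative choices.
+    rank_np = 5
+    A_cpu = np.random.normal(size=(64, rank_np)) @ np.random.normal(size=(rank_np, 64))
+    U_c, S_c, Vt_c = randomized_svd(A_cpu, epsilon=1e-6)
+    recon = U_c @ np.diag(S_c) @ Vt_c
+    print(f"NumPy reconstruction error: {np.linalg.norm(A_cpu - recon):.2e}")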