From 591780d79e1ab812f93419fbb996e40c01b37e2e Mon Sep 17 00:00:00 2001
From: Hanzhang Ma <hanzhang@plunder.dbs.ifi.lmu.de>
Date: Thu, 9 May 2024 13:42:22 +0200
Subject: [PATCH] init diffusion

---
 main.ipynb | 171 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 171 insertions(+)
 create mode 100644 main.ipynb

diff --git a/main.ipynb b/main.ipynb
new file mode 100644
index 0000000..8b7714e
--- /dev/null
+++ b/main.ipynb
@@ -0,0 +1,171 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Network Helper\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch.nn as nn\n",
+    "import inspect\n",
+    "import torch\n",
+    "import math"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def exists(x):\n",
+    "    return x is not None\n",
+    "\n",
+    "def default(val, d):\n",
+    "    if exists(val):\n",
+    "        return val\n",
+    "    return d() if inspect.isfunction(d) else d\n",
+    "\n",
+    "class Residual(nn.Module):\n",
+    "    def __init__(self, fn):\n",
+    "        super().__init__()\n",
+    "        self.fn = fn\n",
+    "\n",
+    "    def forward(self, x, *args, **kwargs):\n",
+    "        return self.fn(x, *args, **kwargs) + x\n",
+    "\n",
+    "# 上采样(反卷积)\n",
+    "def Upsample(dim):\n",
+    "    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)\n",
+    "\n",
+    "def Downsample(dim):\n",
+    "    return nn.Conv2d(dim, dim, 4, 2 ,1)"
+   ]
+  },
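+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick sanity check for the helpers (the tensor sizes below are arbitrary): with kernel 4, stride 2, padding 1, `Upsample` should double and `Downsample` should halve the spatial resolution.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.randn(2, 16, 8, 8)    # (batch, channels, height, width)\n",
+    "print(Upsample(16)(x).shape)    # torch.Size([2, 16, 16, 16]) -- spatial size doubled\n",
+    "print(Downsample(16)(x).shape)  # torch.Size([2, 16, 4, 4])   -- spatial size halved\n",
+    "print(default(None, lambda: 42), default(7, 0))  # 42 7 -- default() calls a function lazily"
+   ]
+  },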
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Positional embedding\n",
+    "\n",
+    "目的是让网络知道\n",
+    "当前是哪一个step. \n",
+    "ddpm采用正弦位置编码\n",
+    "\n",
+    "输入是shape为(batch_size, 1)的tensor, batch中每一个sample所处的t,并且将这个tensor转换为shape为(batch_size, dim)的tensor.\n",
+    "这个tensor会被加到每一个残差模块中.\n",
+    "\n",
+    "总之就是将$t$编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step"
+   ]
+  },
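+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Written out (this matches the implementation below, with $d =$ `dim` and $i = 0, \\dots, d/2 - 1$):\n",
+    "\n",
+    "$$\\mathrm{emb}(t)_i = \\sin\\big(t \\cdot 10000^{-i/(d/2-1)}\\big), \\qquad \\mathrm{emb}(t)_{d/2+i} = \\cos\\big(t \\cdot 10000^{-i/(d/2-1)}\\big)$$"
+   ]
+  },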
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class SinusolidalPositionEmbedding(nn.Module):\n",
+    "    def __init__(self, dim):\n",
+    "        super().__init__()\n",
+    "        self.dim = dim\n",
+    "\n",
+    "    def forward(self, time):\n",
+    "        device = time.device\n",
+    "        half_dim = self.dim // 2\n",
+    "        embeddings = math.log(10000) / (half_dim - 1)\n",
+    "        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)\n",
+    "        embeddings = time[:, :, None] * embeddings[None, None, :]\n",
+    "        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)\n",
+    "        return embeddings\n",
+    "        "
+   ]
+  },
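+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A small usage example (the timesteps and `dim=32` are arbitrary): each scalar timestep becomes a `dim`-dimensional vector.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "t = torch.tensor([0, 10, 100, 999])  # one timestep per sample, shape (4,)\n",
+    "emb = SinusoidalPositionEmbedding(32)(t)\n",
+    "print(emb.shape)  # torch.Size([4, 32]) -- first 16 dims are sin, last 16 are cos"
+   ]
+  },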
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## ResNet/ConvNeXT block"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Block(nn.Module):\n",
+    "    def __init__(self, dim, dim_out, groups = 8):\n",
+    "        super().__init__()\n",
+    "        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)\n",
+    "        self.norm = nn.GroupNorm(groups, dim_out)\n",
+    "        self.act = nn.SiLU()\n",
+    "    \n",
+    "    def forward(self, x, scale_shift = None):\n",
+    "        x = self.proj(x)\n",
+    "        x = self.norm(x)\n",
+    "\n",
+    "        if exists(scale_shift):\n",
+    "            scale, shift = scale_shift\n",
+    "            x = x * (scale + 1) + shift\n",
+    "\n",
+    "        x = self.act(x)\n",
+    "        return x\n",
+    "\n",
+    "class ResnetBlock(nn.Module):\n",
+    "    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):\n",
+    "        super().__init__()\n",
+    "        self.mlp = (\n",
+    "            nn.Sequential(\n",
+    "                nn.SiLU(), \n",
+    "                nn.Linear(time_emb_dim, dim_out)\n",
+    "            )\n",
+    "            if exists(time_emb_dim) else None\n",
+    "        )\n",
+    "        self.block1 = Block(dim, dim_out, groups=groups)\n",
+    "        self.block2 = Block(dim_out, dim_out=dim_out, groups=groups)\n",
+    "        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n",
+    "    \n",
+    "    def forward(self, x, time_emb = None):\n",
+    "        h = self.block1(x)\n",
+    "\n",
+    "        if exists(self.mlp) and exists(time_emb):\n",
+    "            time_emb = self.mlp(time_emb)\n",
+    "            h = rearrange(time_emb, 'b n -> b () n') + h\n",
+    "\n",
+    "        h = self.block2(h)\n",
+    "        return h + self.res_conv(x)\n",
+    "    \n",
+    "    "
+   ]
+  }
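+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "An end-to-end shape check on dummy data (all sizes arbitrary): the block maps `dim` input channels to `dim_out` output channels while mixing in the time embedding.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.randn(4, 16, 8, 8)                # (batch, dim, h, w)\n",
+    "t = torch.tensor([0, 10, 100, 999])         # (batch,)\n",
+    "t_emb = SinusoidalPositionEmbedding(32)(t)  # (batch, 32)\n",
+    "block = ResnetBlock(16, 64, time_emb_dim=32)\n",
+    "print(block(x, t_emb).shape)  # torch.Size([4, 64, 8, 8])"
+   ]
+  }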
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "arch2vec39",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}