Wander's Whisper

-- 'Just do something, give destiny a reason to stir.'

Flow-matching

Wander's avatar

toy demo

模仿这篇文章做的一个小 demo，目标是在二维平面上将一堆高斯噪声中的 100 个点“挪”到正弦曲线 $y = \sin x$ 上。感觉最后效果很不错啊。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

# ---- Hyperparameters ----
dim = 2              # each data point is a 2-D vector
num_samples = 1000   # training points drawn from each distribution
num_steps = 50       # Euler steps used when integrating the learned ODE
lr = 1e-3            # Adam learning rate
epochs = 100000      # full-batch optimisation steps

# Target distribution (the x1 endpoints): points (x, sin x), x uniform in [0, 4*pi).
x1_samples = torch.rand(num_samples, 1) * 4 * torch.pi
y1_samples = x1_samples.sin()
target_data = torch.hstack((x1_samples, y1_samples))

# Source distribution (the x0 endpoints): isotropic Gaussian with std 2.
noise_data = 2 * torch.randn(num_samples, dim)

class VectorField(nn.Module):
    """Small MLP that predicts the flow-matching velocity v(x, t).

    The network input is the point ``x`` concatenated with the time ``t``
    (one extra column); the output has the same width as ``x``.

    Args:
        data_dim: width of each point. ``None`` (the default) falls back to
            the module-level ``dim``, so ``VectorField()`` builds exactly the
            same model as before.
        hidden_dim: width of the single hidden layer (64 by default).
    """

    def __init__(self, data_dim=None, hidden_dim=64):
        super().__init__()
        if data_dim is None:
            data_dim = dim
        self.net = nn.Sequential(
            nn.Linear(data_dim + 1, hidden_dim),  # input: x (data_dim) + t (1)
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim),
        )

    def forward(self, x, t):
        """Return the predicted velocity at points ``x`` and time ``t``.

        Args:
            x: tensor of shape (batch_size, data_dim).
            t: time as a Python number, a 0-d tensor, a tensor of shape
                (batch_size,), or a tensor of shape (batch_size, 1) or (1, 1).
                A single time value is broadcast to every row of ``x``.

        Returns:
            Tensor of shape (batch_size, data_dim).
        """
        # Match x's dtype/device; returns t unchanged if it already matches.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t.dim() < 2:
            # Scalars become (1, 1); (batch_size,) vectors become a column.
            t = t.reshape(-1, 1)
        # Broadcast a single time to every point. This is a no-op when t is
        # already (batch_size, 1).
        t = t.expand(x.shape[0], 1)
        return self.net(torch.cat([x, t], dim=1))
        
model = VectorField()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Train with the flow-matching objective. Each "epoch" is a single full-batch
# optimisation step over all num_samples points.
for epoch in range(epochs):
    # Pair noise and target points with INDEPENDENT random permutations.
    # A single shared permutation (one `idx` for both) would keep each noise
    # point tied to the same target point forever. Because the loss below is
    # a mean over the whole batch, that shuffle would also be a no-op.
    x0 = noise_data[torch.randperm(num_samples)]   # start points: noise
    x1 = target_data[torch.randperm(num_samples)]  # end points: sine curve

    # One time value per pair, shape (batch_size, 1), uniform in [0, 1).
    t = torch.rand(x0.size(0), 1)

    # Point at time t on the straight line from x0 to x1.
    xt = (1 - t) * x0 + t * x1

    # Regression target: the constant velocity of that straight-line path.
    vt_target = x1 - x0
    vt_pred = model(xt, t)

    # Mean squared error between predicted and target velocities.
    loss = torch.mean((vt_pred - vt_target) ** 2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 1000 == 0:
        # .item() logs a plain float instead of a tensor repr with grad_fn.
        print(f"epoch:{epoch}/{epochs},loss:{loss.item()}")

num_of_point = 100
# Start from the first num_of_point noise samples (these are among the green
# points drawn on the plot).
x = noise_data[:num_of_point]

# Integrate dx/dt = v(x, t) from t = 0 to t = 1 with the explicit Euler method.
delta_t = 1 / num_steps
with torch.no_grad():
    for i in range(num_steps):
        # Recompute t from the step index rather than accumulating delta_t,
        # so floating-point error does not build up across steps.
        t = i * delta_t
        # One copy of t per point. The size comes from x so it still matches
        # if fewer than num_of_point noise rows are available.
        t_batch = torch.full((x.shape[0], 1), t, dtype=x.dtype)
        vt = model(x, t_batch)
        x = x + vt * delta_t  # x(t + Δt) = x(t) + v(x(t), t) Δt

# Plot the target samples, the initial noise, and where the integrated
# points ended up.
fig, ax = plt.subplots(figsize=(10, 5))
layers = (
    (target_data, dict(c='blue', label='Target (sin(x))')),
    (noise_data, dict(c='green', alpha=0.3, label='Noise')),
    (x, dict(c='red', label='final distribution')),
)
for points, style in layers:
    ax.scatter(points[:, 0], points[:, 1], **style)
ax.legend()
ax.set_title("Flow Matching: From Noise to Target Distribution")

plt.show()