File indexing completed on 2025-12-16 09:18:06
0001 import torch
0002 import numpy as np
0003 import matplotlib.pyplot as plt
0004
def visualize_model(model, device, x_min=0.001, x_max=0.5, num_points=500,
                    save_path="outputFile/model_func.png"):
    """Plot a model's output over an interval of ``x`` and save it as an image.

    The model is fed the reciprocal ``1 / x`` of each sample point, but the
    plot's horizontal axis shows the original ``x`` values.

    Args:
        model: Callable model with an ``eval()`` method. It is called on a
            float32 tensor of shape ``(num_points, 1)``. Its output is
            flattened and plotted against ``x``, so it presumably yields one
            value per sample (confirm against the model definition).
        device: Device the input tensor is moved to before the forward pass.
        x_min: One end of the sampled ``x`` interval (inclusive).
        x_max: Other end of the sampled ``x`` interval (inclusive).
        num_points: Number of evenly spaced sample points.
        save_path: File the figure is written to. Missing parent directories
            are created.

    Returns:
        Tuple ``(x_test, y_pred)`` of NumPy arrays: the sampled ``x`` values
        and the flattened model outputs.

    Raises:
        ValueError: If the interval between ``x_min`` and ``x_max`` contains 0,
            where the ``1 / x`` transform is undefined.
    """
    from pathlib import Path  # local import: only this function needs it

    # 1/x diverges at 0; fail loudly instead of feeding inf into the model.
    if min(x_min, x_max) <= 0 <= max(x_min, x_max):
        raise ValueError(
            f"interval [{x_min}, {x_max}] contains 0; the 1/x transform is undefined there"
        )

    x_test = np.linspace(x_min, x_max, num_points)
    # The model is queried on the reciprocal of x, not on x itself.
    proxy_trans_test = 1 / x_test
    # Shape (num_points, 1): one scalar feature per sample.
    x_tensor = torch.tensor(proxy_trans_test, dtype=torch.float32).unsqueeze(1).to(device)

    # Evaluate in eval mode, then put the model back in training mode if it was
    # in training mode on entry, so the caller's state is not silently changed.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        with torch.no_grad():
            y_pred = model(x_tensor).cpu().numpy().reshape(-1)
    finally:
        if was_training:
            model.train()

    out_path = Path(save_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(x_test, y_pred, color='red', linewidth=2)
        ax.set_xlabel("x")
        ax.set_ylabel("Model Output f(x)")
        ax.set_title("Visualization of Model Output f(x)")
        ax.grid(True)
        fig.savefig(out_path)
    finally:
        # pyplot keeps figures alive until closed; close to avoid accumulating
        # figures when this is called repeatedly.
        plt.close(fig)

    return x_test, y_pred
0025