# 问题:如何提高两幅高光谱影像拼接融合后重叠区光谱的一致性,即让对应像元的光谱角(SAM)尽可能小?这段代码的问题出在哪里?
import cv2
import numpy as np
import spectral.io.envi as envi
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.exposure import match_histograms
from sklearn.linear_model import RANSACRegressor
import warnings
warnings.filterwarnings('ignore')
# ------------------------------
# 1. Load the two hyperspectral cubes
# ------------------------------
# Shapes noted by the author: image 1 (329, 492, 272), image 2 (329, 769, 272),
# i.e. (rows, cols, bands).
hdr_file1, hdr_file2 = '01.hdr', '02.hdr'
img1, img2 = (envi.open(hdr) for hdr in (hdr_file1, hdr_file2))
data1_raw, data2_raw = (np.array(img.load()) for img in (img1, img2))
for caption, cube in (("图像1形状:", data1_raw), ("图像2形状:", data2_raw)):
    print(caption, cube.shape)
# ------------------------------
# 2. Choose the registration band (most corners)
# ------------------------------
def best_band_corner(data, max_corners=2000, quality=0.01, min_dist=10, return_counts=False):
    """Return the index of the band with the most Shi-Tomasi corners.

    Every band of ``data`` (indexed ``[row, col, band]``) is min-max stretched to
    uint8 and passed to ``cv2.goodFeaturesToTrack``; the band with the largest
    corner count wins. Progress is printed every 50 bands.

    Args:
        data: hyperspectral cube with bands on the last axis.
        max_corners, quality, min_dist: forwarded to ``goodFeaturesToTrack``.
        return_counts: if True, also return the per-band corner counts.

    Returns:
        ``best_idx`` (int), or ``(best_idx, counts)`` when ``return_counts`` is True.
    """
    n_bands = data.shape[2]
    corner_counts = np.zeros(n_bands, dtype=np.int64)
    for i in range(n_bands):
        # float32 copy so cv2 always receives a native-endian floating type,
        # whatever dtype the ENVI reader produced.
        band = np.asarray(data[:, :, i], dtype=np.float32)
        band_norm = cv2.normalize(band, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        corners = cv2.goodFeaturesToTrack(band_norm, max_corners, quality, min_dist)
        corner_counts[i] = 0 if corners is None else len(corners)
        if i % 50 == 0:
            print(f"波段 {i}/{n_bands} 角点数: {corner_counts[i]}")
    best_idx = int(np.argmax(corner_counts))
    print(f"最佳配准波段索引: {best_idx}, 角点数: {corner_counts[best_idx]}")
    if return_counts:
        return best_idx, corner_counts
    return best_idx


# Register on ONE shared band. Matching band i of image 1 against a different
# band j of image 2 compares different wavelengths, which weakens the SIFT
# matches and the homography; any residual misregistration then shows up
# directly as extra spectral angle in the overlap.
_, corner_counts1 = best_band_corner(data1_raw, return_counts=True)
_, corner_counts2 = best_band_corner(data2_raw, return_counts=True)
if data1_raw.shape[2] == data2_raw.shape[2]:
    # NOTE(review): assumes equal band counts mean identical wavelengths - confirm
    # from the .hdr files.
    # Score each band by its weaker image so the chosen band is textured in both.
    best_idx1 = best_idx2 = int(np.argmax(np.minimum(corner_counts1, corner_counts2)))
    print(f"公共配准波段索引: {best_idx1}")
else:
    best_idx1 = int(np.argmax(corner_counts1))
    best_idx2 = int(np.argmax(corner_counts2))
band1 = data1_raw[:, :, best_idx1]
band2 = data2_raw[:, :, best_idx2]
# ------------------------------
# 3. Preprocessing for feature extraction: min-max stretch + histogram equalisation
# ------------------------------
def preprocess_band(band):
    """Stretch ``band`` to 0-255 (min-max), convert to uint8 and histogram-equalise it.

    The result is only used as input to SIFT, not for radiometry.
    """
    stretched = cv2.normalize(band, None, 0, 255, cv2.NORM_MINMAX)
    return cv2.equalizeHist(stretched.astype(np.uint8))


gray1_feat, gray2_feat = preprocess_band(band1), preprocess_band(band2)
# ------------------------------
# 4. SIFT feature extraction
# ------------------------------
sift = cv2.SIFT_create(nfeatures=5000, contrastThreshold=0.04, edgeThreshold=10)
(kp1, des1), (kp2, des2) = (sift.detectAndCompute(img, None) for img in (gray1_feat, gray2_feat))
print(f"图像1关键点数: {len(kp1)}")
print(f"图像2关键点数: {len(kp2)}")
# A homography needs at least 4 correspondences, so stop early if either image
# produced no descriptors or too few keypoints.
missing_descriptors = des1 is None or des2 is None
if missing_descriptors or min(len(kp1), len(kp2)) < 4:
    raise ValueError("特征点不足")
# ------------------------------
# 5. FLANN matching + Lowe's ratio test
# ------------------------------
FLANN_INDEX_KDTREE = 1
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
search_params = dict(checks=50)
flann = cv2.FlannBasedMatcher(index_params, search_params)
matches = flann.knnMatch(des1, des2, k=2)
# knnMatch may return fewer than two neighbours for some queries; unpacking
# `for m, n in matches` would raise on those, so keep complete pairs only.
knn_pairs = [pair for pair in matches if len(pair) == 2]


def ratio_test(pairs, ratio):
    """Apply Lowe's ratio test, then keep at most one match per image-2 keypoint.

    A best match survives when its distance is below ``ratio`` times the
    runner-up's. Several image-1 keypoints can pass with the same image-2
    keypoint; in a one-to-one correspondence at most one of them can be right,
    so only the closest is kept. Results are ordered by image-1 keypoint index.
    """
    best_per_train = {}
    for m, n in pairs:
        if m.distance >= ratio * n.distance:
            continue
        prev = best_per_train.get(m.trainIdx)
        if prev is None or m.distance < prev.distance:
            best_per_train[m.trainIdx] = m
    return sorted(best_per_train.values(), key=lambda m: m.queryIdx)


ratio = 0.7
good_matches = ratio_test(knn_pairs, ratio)
print(f"优良匹配对: {len(good_matches)}")
if len(good_matches) < 4:
    # Relax the ratio once before giving up.
    ratio = 0.8
    good_matches = ratio_test(knn_pairs, ratio)
    print(f"降低ratio后优良匹配数: {len(good_matches)}")
if len(good_matches) < 4:
    raise ValueError("优良匹配不足")
# ------------------------------
# 6. RANSAC homography
# ------------------------------
# RANSAC reprojection threshold in pixels. Matches within it count as inliers,
# so a large value lets slightly wrong correspondences pull H off; try smaller
# values (e.g. 1-3 px) when sub-pixel alignment matters for the spectral angle.
RANSAC_REPROJ_THRESHOLD = 5.0

src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
# src = image 1, dst = image 2: H maps image-1 pixel coordinates to image-2 coordinates.
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC,
                             ransacReprojThreshold=RANSAC_REPROJ_THRESHOLD)
# findHomography can fail and return None; treat that as "no inliers" so the
# affine fallback below still gets a chance instead of crashing on len(None).
inlier_flags = (np.zeros(len(good_matches), dtype=bool) if mask is None
                else mask.ravel().astype(bool))
inliers = [m for m, keep in zip(good_matches, inlier_flags) if keep]
# 0/1 ints for cv2.drawMatches, the same type on both the homography and affine paths.
matchesMask = inlier_flags.astype(np.uint8).tolist()
print(f"RANSAC内点数量: {len(inliers)}")
# Fallback: 4-DOF similarity transform when the homography is missing or weakly supported.
if H is None or len(inliers) < 10:
    H_affine, mask_affine = cv2.estimateAffinePartial2D(
        src_pts, dst_pts, method=cv2.RANSAC, ransacThreshold=RANSAC_REPROJ_THRESHOLD)
    if H_affine is not None:
        H = np.vstack([H_affine, [0, 0, 1]])
        inlier_flags = mask_affine.ravel().astype(bool)
        inliers = [m for m, keep in zip(good_matches, inlier_flags) if keep]
        matchesMask = inlier_flags.astype(np.uint8).tolist()
        print(f"改用仿射变换,内点数: {len(inliers)}")
if H is None or len(inliers) < 4:
    raise ValueError("内点不足")
# ------------------------------
# 7. Canvas size and the transform that places image 2 on it
# ------------------------------
# H was fitted with src = image 1, dst = image 2, so it maps image-1 pixels INTO
# image 2. Placing image 2 in image 1's frame therefore needs the inverse.
# Warping image 2 with H itself (as before) moves it in the wrong direction,
# which corrupts the overlap mask, the mosaic and the overlap statistics.
h1, w1 = gray1_feat.shape
h2, w2 = gray2_feat.shape
H_inv = np.linalg.inv(H)
corners1 = np.float32([[0, 0], [0, h1 - 1], [w1 - 1, h1 - 1], [w1 - 1, 0]]).reshape(-1, 1, 2)
corners2 = np.float32([[0, 0], [0, h2 - 1], [w2 - 1, h2 - 1], [w2 - 1, 0]]).reshape(-1, 1, 2)
corners2_in_1 = cv2.perspectiveTransform(corners2, H_inv)
all_pts = np.vstack((corners1.reshape(-1, 2), corners2_in_1.reshape(-1, 2)))
# floor/ceil rather than int truncation (which rounds negative values up).
# xmax/ymax are EXCLUSIVE, so a canvas of (xmax - xmin) x (ymax - ymin) holds the
# last pixel row/column. With the old inclusive max, pasting image 1 failed with
# a shape mismatch whenever image 1 defined both edges of an axis.
xmin, ymin = (int(v) for v in np.floor(all_pts.min(axis=0)))
xmax, ymax = (int(v) + 1 for v in np.ceil(all_pts.max(axis=0)))
trans = np.float32([[1, 0, -xmin], [0, 1, -ymin], [0, 0, 1]])
H_trans = trans @ H_inv  # image-2 pixel -> canvas pixel
x1_offset, y1_offset = -xmin, -ymin  # image 1 is pasted at this offset
# ------------------------------
# 8. Radiometric normalisation (optional histogram matching + per-band RANSAC regression)
# ------------------------------
print("\n===== 增强辐射归一化 =====")
# 8.1 Overlap mask on the canvas, built from the valid-pixel footprints of both
# images instead of an intensity threshold on one stretched band. The old `> 10`
# test silently dropped dark but valid pixels (shadow, water) from the overlap.
# A pixel counts as valid when any band is non-zero, matching this script's use
# of 0 as background.
# NOTE(review): if the ENVI headers declare a different no-data value, test for it here.
canvas_size = (xmax - xmin, ymax - ymin)  # (width, height), as cv2 expects
valid1 = np.any(data1_raw != 0, axis=2)
valid2 = np.any(data2_raw != 0, axis=2).astype(np.uint8)
# Nearest-neighbour keeps the warped footprint strictly 0/1.
warped_valid2 = cv2.warpPerspective(valid2, H_trans, canvas_size, flags=cv2.INTER_NEAREST)
canvas_valid1 = np.zeros((ymax - ymin, xmax - xmin), dtype=bool)
canvas_valid1[y1_offset:y1_offset + h1, x1_offset:x1_offset + w1] = valid1
overlap_mask_temp = canvas_valid1 & (warped_valid2 > 0)
print(f"重叠像素总数: {np.sum(overlap_mask_temp)}")
if np.sum(overlap_mask_temp) == 0:
    raise ValueError("无有效重叠区域,请检查配准结果")
# 8.2 Whole-image, per-band histogram matching is OFF by default. The two scenes
# only partially overlap, so each band's histogram is driven by different ground
# content. Forcing the histograms to match applies a different non-linear curve to
# every band, which distorts spectral shape and RAISES the spectral angle. The
# per-band linear gain/offset fitted on true overlap pixels (8.3) removes the
# radiometric difference without that distortion.
USE_HISTOGRAM_MATCHING = False
if USE_HISTOGRAM_MATCHING:
    print("正在进行直方图匹配...")
    data2_histmatched = match_histograms(data2_raw, data1_raw, channel_axis=-1)
    print("直方图匹配完成。")
else:
    # Name kept for downstream code; here it is just image 2's raw values as float32.
    data2_histmatched = np.asarray(data2_raw, dtype=np.float32)
# 8.3 Pair overlap pixels between the two images and fit, per band, a robust
# linear map  image1 ~= slope * image2 + intercept. Image 2 is sampled at the
# sub-pixel position given by H (image-1 pixel coords -> image-2 coords).


def bilinear_spectra(cube, xs, ys):
    """Bilinearly sample full spectra of ``cube`` at float pixel positions.

    ``cube`` is indexed ``[row, col, band]``; ``xs``/``ys`` are column/row
    coordinates that must satisfy ``0 <= xs < width - 1`` and
    ``0 <= ys < height - 1`` (callers filter for this). Returns a float32 array
    of shape ``(len(xs), n_bands)``.
    """
    x0 = np.floor(xs).astype(np.intp)
    y0 = np.floor(ys).astype(np.intp)
    dx = (xs - x0)[:, None]
    dy = (ys - y0)[:, None]
    return ((1 - dx) * (1 - dy) * cube[y0, x0, :]
            + dx * (1 - dy) * cube[y0, x0 + 1, :]
            + (1 - dx) * dy * cube[y0 + 1, x0, :]
            + dx * dy * cube[y0 + 1, x0 + 1, :]).astype(np.float32)


rng = np.random.default_rng(0)  # seeded: same samples, same fit on every run
overlap_coords = np.argwhere(overlap_mask_temp)  # (y, x) canvas coordinates
max_samples = min(5000, len(overlap_coords))
if len(overlap_coords) > max_samples:
    overlap_coords_sample = overlap_coords[rng.choice(len(overlap_coords), max_samples, replace=False)]
else:
    overlap_coords_sample = overlap_coords
# Canvas -> image-1 pixel coordinates (image 1 sits at a pure offset on the canvas).
ys1 = overlap_coords_sample[:, 0] - y1_offset
xs1 = overlap_coords_sample[:, 1] - x1_offset
inside1 = (xs1 >= 0) & (xs1 < w1) & (ys1 >= 0) & (ys1 < h1)
ys1, xs1 = ys1[inside1], xs1[inside1]
if xs1.size:
    pts_in_2 = cv2.perspectiveTransform(
        np.stack([xs1, ys1], axis=1).astype(np.float32).reshape(-1, 1, 2), H).reshape(-1, 2)
else:
    pts_in_2 = np.empty((0, 2), dtype=np.float32)
xs2, ys2 = pts_in_2[:, 0], pts_in_2[:, 1]
# 1-pixel margin so all four bilinear neighbours exist in image 2.
inside2 = (xs2 >= 0) & (xs2 < w2 - 1) & (ys2 >= 0) & (ys2 < h2 - 1)
xs1, ys1, xs2, ys2 = xs1[inside2], ys1[inside2], xs2[inside2], ys2[inside2]
sample_pts1 = list(zip(ys1.tolist(), xs1.tolist()))  # image-1 (y1, x1), integers
sample_pts2 = list(zip(xs2.tolist(), ys2.tolist()))  # image-2 (x2, y2), floats
print(f"有效采样点数量: {len(sample_pts1)}")
if len(sample_pts1) < 100:
    print("采样点不足,将跳过逐波段RANSAC校正,仅使用直方图匹配结果")
    data2_corrected_raw = data2_histmatched.astype(np.float32)
else:
    n_bands = data1_raw.shape[2]
    slopes = np.ones(n_bands)
    intercepts = np.zeros(n_bands)
    # Gather every sampled spectrum once instead of re-sampling pixel by pixel per band.
    spectra1 = np.asarray(data1_raw[ys1, xs1, :], dtype=np.float32)
    spectra2 = bilinear_spectra(data2_histmatched, xs2, ys2)
    for b in range(n_bands):
        vals1, vals2 = spectra1[:, b], spectra2[:, b]
        spread = float(np.std(vals1))
        if spread <= 0 or float(np.std(vals2)) <= 0:
            continue  # constant band in the overlap: keep the identity mapping
        try:
            ransac = RANSACRegressor(residual_threshold=spread * 0.5, random_state=0)
            ransac.fit(vals2.reshape(-1, 1), vals1)
        except ValueError as err:
            # RANSACRegressor raises ValueError e.g. when it finds no consensus set;
            # fall back to the identity mapping for this band, but say so.
            print(f"波段 {b}: RANSAC 拟合失败,保持恒等映射 ({err})")
            continue
        slopes[b] = ransac.estimator_.coef_[0]
        intercepts[b] = ransac.estimator_.intercept_
        if b % 50 == 0:
            print(f"波段 {b}: slope={slopes[b]:.4f}, intercept={intercepts[b]:.2f}")
    # Apply gain/offset per band (broadcast over the last axis); stays in raw units.
    data2_corrected_raw = (np.asarray(data2_histmatched, dtype=np.float32) * slopes.astype(np.float32)
                           + intercepts.astype(np.float32))
    data2_corrected_raw = np.clip(data2_corrected_raw, 0, None)
    print("增强辐射校正完成(原始辐射值)")
# 8.4 Normalise to [0, 1] (used only for the grayscale mosaic and visualisation)
def normalize_cube(cube, lo=None, hi=None):
    """Min-max scale every band of ``cube`` (bands on the last axis) to [0, 1].

    Args:
        cube: array indexed ``[row, col, band]``.
        lo, hi: optional per-band lower/upper bounds (one value per band).
            By default each band's own min/max is used, as before. Passing shared
            bounds scales two cubes identically.

    Returns:
        float32 array of the same shape. Bands whose range is <= 1e-6 stay 0.
        With explicit bounds, values outside [lo, hi] map outside [0, 1].
    """
    lo = cube.min(axis=(0, 1)) if lo is None else np.asarray(lo)
    hi = cube.max(axis=(0, 1)) if hi is None else np.asarray(hi)
    normed = np.zeros_like(cube, dtype=np.float32)
    for b in range(cube.shape[2]):
        mn, mx = lo[b], hi[b]
        if mx - mn > 1e-6:
            normed[:, :, b] = (cube[:, :, b] - mn) / (mx - mn)
    return normed


# Shared per-band bounds. Normalising each cube by its own min/max would rescale
# image 2 independently and undo the radiometric correction, leaving a brightness
# seam in the grayscale mosaic.
shared_lo = np.minimum(data1_raw.min(axis=(0, 1)), data2_corrected_raw.min(axis=(0, 1)))
shared_hi = np.maximum(data1_raw.max(axis=(0, 1)), data2_corrected_raw.max(axis=(0, 1)))
data1_norm = normalize_cube(data1_raw, shared_lo, shared_hi)
data2_norm = normalize_cube(data2_corrected_raw, shared_lo, shared_hi)
print("两幅数据归一化完成")
# ------------------------------
# 9. Build grayscale views and feather-blend them
# ------------------------------
band1_norm = data1_norm[:, :, best_idx1]
band2_norm = data2_norm[:, :, best_idx2]
gray1 = (band1_norm * 255).clip(0, 255).astype(np.uint8)
gray2 = (band2_norm * 255).clip(0, 255).astype(np.uint8)
canvas_h, canvas_w = ymax - ymin, xmax - xmin
warped = cv2.warpPerspective(gray2, H_trans, (canvas_w, canvas_h))
canvas = np.zeros((canvas_h, canvas_w), dtype=np.uint8)
canvas[y1_offset:y1_offset + h1, x1_offset:x1_offset + w1] = gray1
mask1 = (canvas > 0)
mask2 = (warped > 0)
overlap = mask1 & mask2
if not overlap.any():
    result = np.maximum(canvas, warped)
else:
    # Linear feathering across the overlap's column span. The ramp must fade from
    # the left-hand image to the right-hand one; decide which is which from the
    # footprint centroids instead of assuming image 2 is on the right.
    ov_rows, ov_cols = np.nonzero(overlap)
    col_lo, col_hi = ov_cols.min(), ov_cols.max()
    ramp = (ov_cols - col_lo) / max(1, col_hi - col_lo)  # 0 at left edge, 1 at right edge
    if np.nonzero(mask2)[1].mean() < np.nonzero(mask1)[1].mean():
        ramp = 1.0 - ramp  # image 2 is the left-hand image
    weight1 = np.zeros(canvas.shape, dtype=np.float32)
    weight2 = np.zeros(canvas.shape, dtype=np.float32)
    weight1[ov_rows, ov_cols] = 1.0 - ramp
    weight2[ov_rows, ov_cols] = ramp
    weight1[mask1 & ~overlap] = 1.0
    weight2[mask2 & ~overlap] = 1.0
    blended = canvas.astype(np.float32) * weight1 + warped.astype(np.float32) * weight2
    result = np.clip(blended, 0, 255).astype(np.uint8)
# ------------------------------
# 10. Evaluation metrics (overlap region only)
# ------------------------------
print("\n===== 评估指标(仅重叠区域)=====")


def in_overlap(match, kp, Ht, ov_mask):
    """Return True if keypoint ``kp[match.queryIdx]``, mapped into canvas coordinates by
    the 3x3 matrix ``Ht`` and rounded to the nearest pixel, lies inside ``ov_mask``."""
    pt = np.array(kp[match.queryIdx].pt, dtype=np.float32).reshape(1, 1, 2)
    pt_c = cv2.perspectiveTransform(pt, np.asarray(Ht, dtype=np.float64)).flatten()
    x, y = int(round(pt_c[0])), int(round(pt_c[1]))
    if 0 <= x < ov_mask.shape[1] and 0 <= y < ov_mask.shape[0]:
        return bool(ov_mask[y, x])
    return False


# Query keypoints belong to image 1, which is pasted onto the canvas by the pure
# translation `trans` (offset x1_offset, y1_offset). H_trans also contains the
# homography, so mapping image-1 keypoints through it put them in the wrong place.
good_in = [m for m in good_matches if in_overlap(m, kp1, trans, overlap_mask_temp)]
inl_in = [m for m in inliers if in_overlap(m, kp1, trans, overlap_mask_temp)]
print(f"重叠区内优良匹配: {len(good_in)}, 内点: {len(inl_in)}")
if good_in:
    print(f"匹配正确率: {len(inl_in) / len(good_in):.4f}")
# Residuals of the RANSAC inliers under H (image 1 -> image 2). Inliers are
# selected by (roughly) this same error, so CE90/CE95 are optimistic. Independent
# check points not used in the fit would give an unbiased registration error.
errs = []
for m in inl_in:
    pt1 = np.array(kp1[m.queryIdx].pt).reshape(1, 1, 2)
    pt1t = cv2.perspectiveTransform(pt1, H).flatten()
    pt2a = np.array(kp2[m.trainIdx].pt)
    errs.append(np.linalg.norm(pt1t - pt2a))
if errs:
    print(f"CE90: {np.percentile(errs, 90):.3f} px")
    print(f"CE95: {np.percentile(errs, 95):.3f} px")
if np.sum(overlap) > 0:
    orig_canvas = np.zeros_like(result, dtype=np.uint8)
    orig_canvas[y1_offset:y1_offset + h1, x1_offset:x1_offset + w1] = gray1
    ov1 = orig_canvas[overlap]
    ov2 = warped[overlap]
    rmse = np.sqrt(np.mean((ov1.astype(np.float32) - ov2.astype(np.float32)) ** 2))
    print(f"RMSE (灰度图): {rmse:.3f}")
    rows, cols = np.where(overlap)
    rmin, rmax = rows.min(), rows.max()
    cmin, cmax = cols.min(), cols.max()
    roi1 = orig_canvas[rmin:rmax + 1, cmin:cmax + 1]
    roi2 = warped[rmin:rmax + 1, cmin:cmax + 1]
    if roi1.shape[0] >= 7 and roi1.shape[1] >= 7:
        # Average the SSIM map over overlap pixels only. The bounding box also holds
        # pixels covered by a single image (or neither), which would skew the score.
        _, ssim_map = ssim(roi1, roi2, data_range=255, full=True)
        ssim_val = float(np.mean(ssim_map[overlap[rmin:rmax + 1, cmin:cmax + 1]]))
        print(f"SSIM (灰度图): {ssim_val:.4f}")
    else:
        print("SSIM: 重叠区域过小")
print("\n===== 光谱角(基于原始辐射值)=====")


def bilinear_spectra(cube, xs, ys):
    """Bilinearly sample full spectra of ``cube`` at float pixel positions.

    ``cube`` is indexed ``[row, col, band]``; ``xs``/``ys`` are column/row
    coordinates that must satisfy ``0 <= xs < width - 1`` and
    ``0 <= ys < height - 1`` (callers filter for this). Returns a float32 array
    of shape ``(len(xs), n_bands)``.
    """
    x0 = np.floor(xs).astype(np.intp)
    y0 = np.floor(ys).astype(np.intp)
    dx = (xs - x0)[:, None]
    dy = (ys - y0)[:, None]
    return ((1 - dx) * (1 - dy) * cube[y0, x0, :]
            + dx * (1 - dy) * cube[y0, x0 + 1, :]
            + (1 - dx) * dy * cube[y0 + 1, x0, :]
            + dx * dy * cube[y0 + 1, x0 + 1, :]).astype(np.float32)


def spectral_angles_deg(spec_a, spec_b):
    """Row-wise spectral angle, in degrees, between two ``(N, bands)`` arrays.

    Rows where either spectrum has a norm <= 1e-6 are dropped.
    """
    norm_a = np.linalg.norm(spec_a, axis=1)
    norm_b = np.linalg.norm(spec_b, axis=1)
    ok = (norm_a > 1e-6) & (norm_b > 1e-6)
    cos_sim = np.einsum("ij,ij->i", spec_a[ok], spec_b[ok]) / (norm_a[ok] * norm_b[ok])
    return np.degrees(np.arccos(np.clip(cos_sim, -1.0, 1.0)))


# Band indices included in SAM. All bands by default. Low-SNR bands (for example
# at the spectral ends) can dominate the angle; restrict this index set to drop them.
sam_bands = np.arange(data1_raw.shape[2])
if np.sum(overlap_mask_temp) == 0:
    print("无重叠区域")
else:
    rng_sam = np.random.default_rng(1)  # seeded: reproducible statistics
    full_coords = np.argwhere(overlap_mask_temp)
    if len(full_coords) > 5000:
        full_coords = full_coords[rng_sam.choice(len(full_coords), 5000, replace=False)]
    sam_y1 = full_coords[:, 0] - y1_offset
    sam_x1 = full_coords[:, 1] - x1_offset
    inside1 = (sam_x1 >= 0) & (sam_x1 < w1) & (sam_y1 >= 0) & (sam_y1 < h1)
    sam_x1, sam_y1 = sam_x1[inside1], sam_y1[inside1]
    sa_deg = np.empty(0)
    if sam_x1.size:
        # H maps image-1 pixel coordinates to image-2 coordinates.
        mapped = cv2.perspectiveTransform(
            np.stack([sam_x1, sam_y1], axis=1).astype(np.float32).reshape(-1, 1, 2), H).reshape(-1, 2)
        sam_x2, sam_y2 = mapped[:, 0], mapped[:, 1]
        inside2 = (sam_x2 >= 0) & (sam_x2 < w2 - 1) & (sam_y2 >= 0) & (sam_y2 < h2 - 1)
        sam_x1, sam_y1 = sam_x1[inside2], sam_y1[inside2]
        sam_x2, sam_y2 = sam_x2[inside2], sam_y2[inside2]
        spec1 = np.asarray(data1_raw[sam_y1, sam_x1, :], dtype=np.float32)[:, sam_bands]
        spec2 = bilinear_spectra(data2_corrected_raw, sam_x2, sam_y2)[:, sam_bands]
        sa_deg = spectral_angles_deg(spec1, spec2)
    if sa_deg.size:
        print(f"平均光谱角: {np.mean(sa_deg):.2f}°")
        print(f"光谱角标准差: {np.std(sa_deg):.2f}°")
        print(f"中位数光谱角: {np.median(sa_deg):.2f}°")
        print(f"95%分位数光谱角: {np.percentile(sa_deg, 95):.2f}°")
        # Baseline on the same pixels using image 2's uncorrected data. It shows
        # whether the radiometric correction lowers the angle at all.
        sa_before = spectral_angles_deg(
            spec1, bilinear_spectra(data2_raw, sam_x2, sam_y2)[:, sam_bands])
        if sa_before.size:
            print(f"校正前平均光谱角(对照): {np.mean(sa_before):.2f}°")
    else:
        print("有效采样点不足")
# ------------------------------
# 11. Show the inlier matches and the stitched grayscale mosaic
# ------------------------------
draw_params = {
    'matchColor': (0, 255, 0),
    'singlePointColor': None,
    'matchesMask': matchesMask,  # 1 = RANSAC inlier (drawn), 0 = hidden
    'flags': 2,                  # cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS
}
match_img = cv2.drawMatches(gray1_feat, kp1, gray2_feat, kp2, good_matches, None, **draw_params)
panels = (
    (match_img, None, f"SIFT Inliers: {len(inliers)}/{len(good_matches)}"),
    (result, 'gray', "Stitched Result (Enhanced Radiometric Correction)"),
)
plt.figure(figsize=(15, 10))
for slot, (image, cmap, caption) in enumerate(panels, start=1):
    plt.subplot(2, 1, slot)
    plt.imshow(image, cmap=cmap)
    plt.title(caption)
    plt.axis('off')
plt.tight_layout()
plt.show()
# ------------------------------
# 12. Plot spectra of corresponding overlap points (image 1 vs corrected image 2)
# ------------------------------
print("\n===== 提取对应点光谱曲线 =====")
N_points = min(5, len(sample_pts1))  # show at most five points
if not N_points:
    print("没有可用的采样点,无法绘制光谱曲线。")
else:
    n_bands = data1_raw.shape[2]
    wavelengths = np.arange(n_bands)  # band index stands in for the wavelength axis
    plt.figure(figsize=(15, 8))
    point_pairs = zip(sample_pts1[:N_points], sample_pts2[:N_points])
    for i, ((y1, x1), (x2, y2)) in enumerate(point_pairs):
        spec1 = data1_raw[y1, x1, :].astype(np.float32)
        # Corrected image-2 spectrum at sub-pixel (x2, y2) via bilinear weights.
        col0, row0 = int(np.floor(x2)), int(np.floor(y2))
        fx, fy = x2 - col0, y2 - row0
        cube2 = data2_corrected_raw
        spec2 = ((1 - fx) * (1 - fy) * cube2[row0, col0, :]
                 + fx * (1 - fy) * cube2[row0, col0 + 1, :]
                 + (1 - fx) * fy * cube2[row0 + 1, col0, :]
                 + fx * fy * cube2[row0 + 1, col0 + 1, :]).astype(np.float32)
        # To compare spectral shape only, min-max scale spec1 and spec2 to [0, 1] here.
        plt.subplot(1, N_points, i + 1)
        plt.plot(wavelengths, spec1, 'b-', linewidth=1.5, label='Image1 original')
        plt.plot(wavelengths, spec2, 'r--', linewidth=1.5, label='Image2 corrected')
        plt.xlabel('Band index')
        plt.ylabel('Radiance (raw)')
        plt.title(f'Point {i + 1} (overlap)')
        plt.legend(fontsize=8)
        plt.grid(True, alpha=0.3)
        # Debug output: the first ten band values of both spectra.
        print(f"Point {i + 1}: img1[0:10]={spec1[:10]}, img2[0:10]={spec2[:10]}")
    plt.suptitle('Spectral Comparison of Corresponding Points (Original vs Corrected)')
    plt.tight_layout()
    plt.show()