玖叭玖 2020-06-30 20:47 Acceptance rate: 100%

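Help wanted: a MATLAB program to compute the Wasserstein distance between two distributions

As the title says, I need to use the Wasserstein distance to measure the distance between two discrete point distributions. Could someone who knows how to do this please help?

I am willing to pay for a working solution.

For example, dataset 1 is {45,46,46,49,50,47,49,41,52,46} and dataset 2 is {49,47,42,38,53,42,47,41,45,50}.

Please write it in MATLAB. Thank you!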

1 answer

  • i-Data 2020-07-01 15:44

    The source code is as follows:

    function wsd = ws_distance(u_samples, v_samples, p)
    % WS_DISTANCE 1- and 2- Wasserstein distance between two discrete 
    % probability measures 
    %   
    %   wsd = WS_DISTANCE(u_samples, v_samples) returns the 1-Wasserstein 
    %   distance between the discrete probability measures u and v 
    %   corresponding to the sample vectors u_samples and v_samples
    %
    %   wsd = WS_DISTANCE(u_samples, v_samples, p) returns the p-Wasserstein 
    %   distance between the discrete probability measures u and v
    %   corresponding to the sample vectors u_samples and v_samples. 
    %   p must be 1 or 2.
    %
    % from https://github.com/nklb/wasserstein-distance
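    if nargin < 3 || isempty(p)
        p = 1;  % default to the 1-Wasserstein distance
    end
    % Flatten both inputs to column vectors and sort them in ascending order.
    u_samples_sorted = sort(u_samples(:));
    v_samples_sorted = sort(v_samples(:));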
    if p == 1
    
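        % Pooled, de-duplicated sample values: the breakpoints of both CDFs.
        all_samples = unique([u_samples_sorted; v_samples_sorted], 'sorted');
    
        % Empirical CDFs evaluated at every breakpoint except the last one.
        u_cdf = find_interval(u_samples_sorted, all_samples(1:end-1)) ...
            / numel(u_samples);
        v_cdf = find_interval(v_samples_sorted, all_samples(1:end-1)) ...
            / numel(v_samples);
    
        % Integrate |u_cdf - v_cdf| piecewise: height times interval width.
        wsd = sum(abs(u_cdf - v_cdf) .* diff(all_samples));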
    
    elseif p == 2
    
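        u_N = numel(u_samples);
        v_N = numel(v_samples);
        % Pooled probability levels at which either quantile function jumps.
        all_prob = unique([(0:u_N) / u_N, (0:v_N) / v_N], 'sorted').';
    
        % Quantile (inverse CDF) values on each probability interval.
        % The 1e-9 tolerance is an added safeguard against round-off such as
        % (1/49)*49 < 1; it is not part of the original code.
        u_icdf = u_samples_sorted(floor(all_prob(1:end-1) * u_N + 1e-9) + 1);
        v_icdf = v_samples_sorted(floor(all_prob(1:end-1) * v_N + 1e-9) + 1);
    
        % Integrate the squared quantile difference, then take the square root.
        wsd = sqrt(sum((u_icdf-v_icdf).^2 .* diff(all_prob)));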
    
    else
    
        error('Only p=1 or p=2 allowed.')
    
    end
    end
    
    
    % Note: this is the second part, the helper function called above.
    function idx = find_interval(bounds, vals)
    % Given the two sorted arrays bounds and vals, the function 
    % idx = FIND_INTERVAL(bounds, vals) identifies for each vals(i) the index 
    % idx(i) s.t. bounds(idx(i)) <= vals(i) < bounds(idx(i) + 1).
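    m = 0;                        % number of bounds found so far that are <= vals(i)
    bounds = [bounds(:); inf];    % sentinel so the while loop always terminates
    idx = zeros(numel(vals), 1);
    for i = 1:numel(vals)
        % Both inputs are sorted, so m only ever moves forward.
        while bounds(m+1) <= vals(i)
            m = m + 1;
        end
        idx(i) = m;               % 0 means vals(i) is below every bound
    end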
    end
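
    For reference, the two branches of ws_distance appear to correspond to the usual closed forms for one-dimensional distributions. The symbols F_u and F_v (the empirical CDFs of the two samples), their inverses (the quantile functions), and the points x_1 < ... < x_K (the distinct pooled sample values, all_samples in the code) are notation introduced here for illustration, not names taken from the source:

    ```latex
    W_1(u, v) = \int_{-\infty}^{\infty} \lvert F_u(x) - F_v(x) \rvert \, dx
              = \sum_{k=1}^{K-1} \lvert F_u(x_k) - F_v(x_k) \rvert \, (x_{k+1} - x_k),
    \qquad
    W_2(u, v) = \left( \int_0^1 \bigl( F_u^{-1}(t) - F_v^{-1}(t) \bigr)^2 \, dt \right)^{1/2}
    ```

    Because both empirical CDFs are step functions, each integral reduces to a finite sum of "height times width" terms, which is what the code evaluates with diff(all_samples) for p = 1 and diff(all_prob) for p = 2.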
    

    Run the following:

    wsd = ws_distance([45,46,46,49,50,47,49,41,52,46], [49,47,42,38,53,42,47,41,45,50], 1)
    

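    I don't have MATLAB installed on this computer, so I haven't actually run the code above. I did compute the result in Python and got 1.9, so please run it in MATLAB to confirm.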
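
    One quick way to cross-check the value without the function file: when both samples contain the same number of points, the 1-Wasserstein distance between the empirical distributions reduces to the mean absolute difference of the sorted samples, and the 2-Wasserstein distance to the root mean squared difference. The sketch below relies on that equal-size assumption, and its variable names are illustrative only.

    ```matlab
    % Cross-check for equal-size samples only (assumption: numel(u) == numel(v)).
    u = [45,46,46,49,50,47,49,41,52,46];
    v = [49,47,42,38,53,42,47,41,45,50];
    assert(numel(u) == numel(v), 'This shortcut needs equal sample sizes.');

    % Pair the i-th smallest value of u with the i-th smallest value of v.
    d = sort(u(:)) - sort(v(:));

    w1 = mean(abs(d));        % 1-Wasserstein distance
    w2 = sqrt(mean(d.^2));    % 2-Wasserstein distance

    fprintf('W1 = %.4f\n', w1);   % prints W1 = 1.9000
    fprintf('W2 = %.4f\n', w2);   % prints W2 = 2.5100
    ```

    For this data, ws_distance(u, v, 1) and ws_distance(u, v, 2) should return the same two numbers up to rounding. For samples of different sizes the shortcut does not apply, and the general function is needed.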

    This answer was selected by the asker as the best answer.