weixin_45750394 2021-09-27 10:17 采纳率: 100%
浏览 24
已结题

如何将这段tensorflow代码转化为pytorch ?

def _soft_assignment(self, embeddings, cluster_centers):
        """Implemented a soft assignment as the  probability of assigning sample i to cluster j.
        
        Args:
            embeddings: (num_points, dim)
            cluster_centers: (num_cluster, dim)
            
        Return:
            q_i_j: (num_points, num_cluster)
        """
        def _pairwise_euclidean_distance(a,b):
            p1 = tf.matmul(
                tf.expand_dims(tf.reduce_sum(tf.square(a), 1), 1),
                tf.ones(shape=(1, self.n_cluster))
            )
            p2 = tf.transpose(tf.matmul(
                tf.reshape(tf.reduce_sum(tf.square(b), 1), shape=[-1, 1]),
                tf.ones(shape=(self.ae.input_batch_size, 1)),
                transpose_b=True
            ))
            res = tf.sqrt(tf.add(p1, p2) - 2 * tf.matmul(a, b, transpose_b=True))
            return res

        dist = _pairwise_euclidean_distance(embeddings, cluster_centers)
        q = 1.0/(1.0+dist**2/self.alpha)**((self.alpha+1.0)/2.0)
        q = (q/tf.reduce_sum(q, axis=1, keepdims=True))
        return q

我自己转化不知道是否转化正确了:

def soft_assignment(self, embeddings, cluster_centers):
    """Implemented a soft assignment as the  probability of assigning sample i to cluster j.

    Args:
        embeddings: (num_points, dim)
        cluster_centers: (num_cluster, dim)

    Return:
        q_i_j: (num_points, num_cluster)
    """

    n_obj = embeddings.shape[0]
    n_clusters = cluster_centers.shape[0]

    def _pairwise_euclidean_distance(a, b, n_obj, n_clusters):
        p1 = torch.matmul(
            torch.unsqueeze(torch.sum(torch.mul(a,a), 1), 1),
            torch.ones(1, n_clusters).to(self.device)
        )
        p2 = torch.transpose(torch.matmul(
            torch.reshape(torch.sum(torch.mul(b,b), 1), (-1, 1)),
            torch.ones(1, n_obj).to(self.device),
        ), 0, 1)
        res = torch.sqrt(torch.add(p1, p2) - 2 * torch.matmul(a, torch.transpose(b, 0, 1)))
        return res

    dist = _pairwise_euclidean_distance(embeddings, cluster_centers, n_obj, n_clusters)
    q = 1.0 / (1.0 + dist ** 2 / self.alpha) ** ((self.alpha + 1.0) / 2.0)
    q = (q / torch.sum(q, dim=1, keepdim=True))
    return q
  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 10月5日
    • 修改了问题 9月27日
    • 修改了问题 9月27日
    • 修改了问题 9月27日
    • 展开全部

    悬赏问题

    • ¥15 逻辑谓词和消解原理的运用
    • ¥15 三菱伺服电机按启动按钮有使能但不动作
    • ¥15 js,页面2返回页面1时定位进入的设备
    • ¥200 关于#c++#的问题,请各位专家解答!网站的邀请码
    • ¥50 导入文件到网吧的电脑并且在重启之后不会被恢复
    • ¥15 (希望可以解决问题)ma和mb文件无法正常打开,打开后是空白,但是有正常内存占用,但可以在打开Maya应用程序后打开场景ma和mb格式。
    • ¥20 ML307A在使用AT命令连接EMQX平台的MQTT时被拒绝
    • ¥20 腾讯企业邮箱邮件可以恢复么
    • ¥15 有人知道怎么将自己的迁移策略布到edgecloudsim上使用吗?
    • ¥15 错误 LNK2001 无法解析的外部符号