Savvy在英国 2020-08-29 16:05 采纳率: 33.3%
浏览 81
已结题

麻烦大神帮我用python表达以下scala代码

之前询问过,也得到了回复,可是自己能力不够,有些 lambda 不会转译。下面给出修改后的代码和需要翻译的代码,麻烦大神帮忙一下,非常感谢!

```class MultiSenseSkipGramEmbeddingModel
// NOTE(review): this fragment is truncated by the forum paste (the markdown
// fences cut through the class body); `currWord`, `contexts`, `kmeans`,
// `dpmeans`, `trainer` and `vocab` are declared elsewhere — do not compile as-is.
// Hyper-parameters copied from command-line options:
negative = opts.negative.value
window = opts.window.value
rng = random
sample = opts.sample.value.toDouble
// Tokenizes one document, maps words to vocab ids (-1 = OOV, dropped),
// trains one positive + `negative` sampled examples per context,
// and returns the token count for progress reporting.
def process(doc: String):
Int = {
sen = doc.stripLineEnd.split(' ').map(word => vocab.getId(word.toLowerCase())).filter(id => id != -1)
wordCount = sen.size

```var rightSense = 0
// Pick the sense of the current word by the configured prediction strategy.
if (kmeans == 1)
rightSense = cbow_predict_kmeans(currWord, contexts)
else if (dpmeans == 1)
rightSense = cbow_predict_dpmeans(currWord, contexts)
else
rightSense = cbow_predict(currWord, contexts)

  contexts.foreach(context => {

  // One positive example (label +1) for the observed (word, context) pair ...
  trainer.processExample(new MSCBOWSkipGramNegSamplingExample(this, currWord, rightSense, context, 1))

   // ... plus `negative` randomly sampled negatives (label -1).
   (0 until negative).foreach(neg => trainer.processExample(new MSCBOWSkipGramNegSamplingExample(this, currWord, rightSense, vocab.getRandWordId, -1)))

  })
}
return wordCount

}

/** Picks the sense of `word` whose embedding has the largest dot product
  * with the summed context embedding. Single-sense words always get sense 0;
  * ties resolve to the lowest sense index.
  */
def cbow_predict(word : Int, contexts: Seq[Int]): Int = {
  // Sum of the global embeddings of all context words.
  val ctxSum = new DenseTensor1(D, 0)
  for (c <- contexts) ctxSum.+=(global_weights(c).value)

  if (!learnMultiVec(word)) 0
  else {
    // Linear argmax scan over the S sense vectors (strict > keeps first max).
    var best = 0
    var bestScore = ctxSum.dot(sense_weights(word)(0).value)
    var s = 1
    while (s < S) {
      val score = ctxSum.dot(sense_weights(word)(s).value)
      if (score > bestScore) {
        bestScore = score
        best = s
      }
      s += 1
    }
    best
  }
}

/** k-means style sense assignment: choose the existing cluster whose mean is
  * closest (1 - cosine similarity) to the summed context embedding, then fold
  * this context into that cluster's running center/count. Single-sense words
  * always use cluster 0.
  */
def cbow_predict_kmeans(word: Int, contexts: Seq[Int]): Int = {
  // Aggregate context embedding.
  val ctxSum = new DenseTensor1(D, 0)
  for (c <- contexts) ctxSum.+=(global_weights(c).value)

  var chosen = 0
  if (learnMultiVec(word)) {
    var bestDist = Double.MaxValue
    var s = 0
    while (s < ncluster(word)) {
      // Cluster mean = running center sum / membership count.
      val mean = clusterCenter(word)(s) / (clusterCount(word)(s))
      val dist = 1 - TensorUtils.cosineDistance(ctxSum, mean)
      if (dist < bestDist) {
        bestDist = dist
        chosen = s
      }
      s += 1
    }
  }

  // Online update of the winning cluster's statistics.
  clusterCenter(word)(chosen).+=(ctxSum)
  clusterCount(word)(chosen) += 1
  chosen
}

/** DP-means sense assignment: like the k-means variant, but when even the
  * nearest existing cluster is farther than `createClusterlambda` and the
  * word is still below the sense cap `S`, a new cluster is spawned for this
  * context. The chosen cluster's running center/count is then updated.
  *
  * @param word     vocabulary id of the centre word
  * @param contexts vocabulary ids of the surrounding context words
  * @return the chosen (possibly newly created) sense index
  */
def cbow_predict_dpmeans(word: Int, contexts: Seq[Int]): Int = {
  // Aggregate context embedding.
  val contextsEmbedding = new DenseTensor1(D, 0)
  contexts.foreach(context => contextsEmbedding.+=(global_weights(context).value))

  var sense = 0
  if (learnMultiVec(word)) {
    var minDist = Double.MaxValue
    val ncluster_word = ncluster(word) // never reassigned below, so a val
    // NOTE: the original version filled a `prob` array here but never read
    // it; that write-only dead code has been removed.
    for (s <- 0 until ncluster_word) {
      val mu = clusterCenter(word)(s) / (clusterCount(word)(s))
      val dist = 1 - TensorUtils.cosineDistance(contextsEmbedding, mu)
      if (dist < minDist) {
        minDist = dist
        sense = s
      }
    }

    // Spawn a new cluster when the closest existing one is too far away.
    if (ncluster_word < S && createClusterlambda < minDist) {
      sense = ncluster_word
      ncluster(word) = ncluster_word + 1
    }
  }

  // Fold this context into the chosen cluster's running statistics.
  clusterCenter(word)(sense).+=(contextsEmbedding)
  clusterCount(word)(sense) += 1
  sense
}

/** Word2vec-style subsampling of frequent words.
  *
  * Draws a uniform value in [0, 1] and keeps the word only if its
  * precomputed keep-probability exceeds the draw.
  *
  * @param word vocabulary id
  * @return `word` if it survives sampling, -1 if it should be skipped
  */
def subSample(word: Int): Int = {
  val keepProb = vocab.getSubSampleProb(word)
  // 0xFFFF mirrors word2vec's 16-bit random draw; ratio is uniform in [0, 1].
  val draw = rng.nextInt(0xFFFF) / 0xFFFF.toDouble
  if (keepProb < draw) -1 else word // no redundant `return`: last expression is the result
}
}

/** One (word-sense, context) negative-sampling training example.
  *
  * `label` is +1 for an observed pair and -1 for a randomly sampled negative;
  * any other value contributes zero objective and zero gradient.
  */
class MSCBOWSkipGramNegSamplingExample(model: MultiSenseWordEmbeddingModel, word: Int, sense : Int, context : Int, label: Int) extends Example {

  def accumulateValueAndGradient(value: DoubleAccumulator, gradient: WeightsMapAccumulator): Unit = {
    val senseVec = model.sense_weights(word)(sense).value
    val ctxVec = model.global_weights(context).value

    val score: Double = senseVec.dot(ctxVec)
    val exp: Double = math.exp(-score) // TODO : pre-compute expTable similar to word2vec

    // Logistic (sigmoid) loss and its derivative factor, by label.
    var objective: Double = 0.0
    var factor: Double = 0.0
    label match {
      case 1 =>
        // positive pair: maximize log sigma(score)
        objective = -math.log1p(exp) // log1p -> log(1+x)
        factor = exp / (1 + exp)
      case -1 =>
        // sampled negative: maximize log sigma(-score)
        objective = -score - math.log1p(exp)
        factor = -1 / (1 + exp)
      case _ => () // unknown label: leave objective/factor at 0
    }

    if (value ne null) value.accumulate(objective)
    if (gradient ne null) {
      gradient.accumulate(model.sense_weights(word)(sense), ctxVec, factor)
      // don't update if global_weights is fixed.
      if (model.updateGlobal == 1) gradient.accumulate(model.global_weights(context), senseVec, factor)
    }
  }
}
比较多,麻烦了!

  • 写回答

2条回答 默认 最新

报告相同问题?

悬赏问题

  • ¥15 执行 virtuoso 命令后,界面没有,cadence 启动不起来
  • ¥50 comfyui下连接animatediff节点生成视频质量非常差的原因
  • ¥20 有关区间dp的问题求解
  • ¥15 多电路系统共用电源的串扰问题
  • ¥15 slam rangenet++配置
  • ¥15 有没有研究水声通信方面的帮我改俩matlab代码
  • ¥15 ubuntu子系统密码忘记
  • ¥15 保护模式-系统加载-段寄存器
  • ¥15 电脑桌面设定一个区域禁止鼠标操作
  • ¥15 求NPF226060磁芯的详细资料