HMM遇到问题,numpy.isfinite()报错TypeError,代码及报错如下
本程序处理韩语,预测元音类别,读入后拆分成音素,只保留元音类别序列。
import numpy as np
from hmmlearn import hmm
import jamotools
# Specify the position of the vowel to predict
test_position = 2
# Define the vowel categories
vowel_categories = {'0': ['ㅏ', 'ㅑ', 'ㅗ', 'ㅛ', 'ㅐ', 'ㅘ', 'ㅚ', 'ㅙ'],
'1': ['ㅓ', 'ㅕ', 'ㅜ', 'ㅠ', 'ㅔ', 'ㅝ', 'ㅟ', 'ㅞ'],
'2': ['ㅡ', 'ㅣ', 'ㅢ']}
states = ['0', '1', '2']
observations = np.array(['ㅏ', 'ㅑ', 'ㅗ', 'ㅛ', 'ㅐ', 'ㅘ', 'ㅚ', 'ㅙ', 'ㅓ', 'ㅕ', 'ㅜ', 'ㅠ', 'ㅔ', 'ㅝ', 'ㅟ', 'ㅞ', 'ㅡ', 'ㅣ', 'ㅢ'])
emission_probability3 = np.array([[0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.01, 0, 0, 0, 0, 0, 0, 0, 0, 0.09, 0.09, 0.09],
[0, 0, 0, 0, 0, 0, 0, 0, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09],
[0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0526, 0.0532, 0.0526]])
emission_probability3 = emission_probability3.astype(float)
emission_probability2 = np.array([[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.33, 0.34, 0.33]])
emission_probability2 = emission_probability2.astype(float)
# Function to split a Korean character into its components
def split_korean_character(character):
components = []
x = jamotools.split_syllables(character)
for i in x:
i = str(i)
components.append(i)
return components
# Assign vowel categories to each vowel in the text
def mark_words(text):
marked_text = []
text_split = text.split()
for word in text_split:
word = word.strip("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-'!?.,—%1234567890(《》<>)")
#print("learning word:",word)
marked_word = ""
for char in word:
components = split_korean_character(char)
word_categories = ""
for component in components:
if component in ['ㅏ', 'ㅑ', 'ㅗ', 'ㅛ', 'ㅐ', 'ㅘ', 'ㅚ', 'ㅙ']:
category = '0'
elif component in ['ㅓ', 'ㅕ', 'ㅜ', 'ㅠ', 'ㅔ', 'ㅝ', 'ㅟ', 'ㅞ']:
category = '1'
elif component in ['ㅡ', 'ㅣ', 'ㅢ']:
category = '2'
else:
category = ''
word_categories += category
marked_word += word_categories
if len(marked_word) > test_position + 1:
marked_text.append(marked_word)
#print("marked word:",marked_word)
#print("length of training words:",len(marked_text))
return marked_text
# Learn from the annotated data
def learn_markov_chain(data):
model = hmm.CategoricalHMM(n_components=len(states), n_features= 19, init_params='st')
model.emissionprob_ = np.array(emission_probability2, dtype=object)
observed = []
lengths = []
for sentence in data:
lengths.append(len(sentence))
observed.extend([vowel[0] for vowel in sentence])
#print("sentence:",sentence)
X = np.array([[int(obs)] for obs in observed])
print("Xtype",X.dtype)
lengths = np.array(lengths)
lengths = lengths.astype(int)
print("lengths type:",lengths.dtype)
model.fit(X, lengths)
return model
# Test the model by hiding the specified syllable at the test position in each word
def predict_vowel_category(model, text, position):
print("prediction begins")
sentence = mark_words(text)
corretto = 0
fiadata = 0
corretto_fiadata = 0
for word_idx, word in enumerate(sentence):
#print("\nsentence:", word_idx, '\n', 'word:', word, 'len_sen:',len(sentence))
if position <= len(word):
hidden_word = word[:position-1] + word[position:]
#print("hidden_word:",hidden_word)
X = np.array([[int(char)] for char in hidden_word])
# calculate correct ratio
predicted_category = states[model.predict(X)[position-1]]
real_category = word[position-1]
print(f"Word '{text.split()[word_idx]}': Predicted category at position {position} is {predicted_category}, Real category is {real_category}")
if predicted_category == real_category:
corretto += 1
# calculate correct ratio of the words vialating vowel harmony
add = 0
len_drop = 0
for i in X:
if i!= 2:
add+=i
len_drop+=1
if add != len_drop and add!= 0:
fiadata += 1
if predicted_category == real_category:
corretto_fiadata += 1
print('add:', add, 'len_drop:', len_drop, 'Vowel Harmony Vialation')
print('ratio:', corretto / len(sentence))
print('fiadata:', fiadata, 'ratio of vialated:', corretto_fiadata/fiadata)
# Training text
korean2 = open('korean2.txt','r',encoding='utf-8')
training_text = korean2.read()
korean2.close()
marked_data = mark_words(training_text)
for i in range(200):
model = learn_markov_chain(marked_data)
print(f"training time:{i}")
# Test text
with open('korean1.txt', 'r', encoding='utf-8') as korean1:
test_text = korean1.read() # Test text
test_split = test_text.split()
elements_to_remove = []
for i in test_split:
if len(i) < test_position+1:
elements_to_remove.append(i)
test_split = [x for x in test_split if x not in elements_to_remove]
test_text = ' '.join(test_split)
predicted_category = predict_vowel_category(model, test_text, test_position)
"""
Traceback (most recent call last):
File "c:\Users\11609\Desktop\hmm2.py", line 127, in <module>
model = learn_markov_chain(marked_data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "c:\Users\11609\Desktop\hmm2.py", line 79, in learn_markov_chain
model.fit(X, lengths)
File "C:\conda\envs\mne\Lib\site-packages\hmmlearn\_emissions.py", line 27, in <lambda>
return functools.wraps(func)(lambda *args, **kwargs: func(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^
File "C:\conda\envs\mne\Lib\site-packages\hmmlearn\base.py", line 481, in fit
self._check()
File "C:\conda\envs\mne\Lib\site-packages\hmmlearn\hmm.py", line 148, in _check
self._check_sum_1("emissionprob_")
File "C:\conda\envs\mne\Lib\site-packages\hmmlearn\base.py", line 950, in _check_sum_1
if not np.allclose(s, 1):
^^^^^^^^^^^^^^^^^
File "C:\conda\envs\mne\Lib\site-packages\numpy\core\numeric.py", line 2241, in allclose
res = all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\conda\envs\mne\Lib\site-packages\numpy\core\numeric.py", line 2348, in isclose
xfin = isfinite(x)
^^^^^^^^^^^
TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
"""
谢谢大家。