对序列进行分类
我们可以使用我们的 SequenceKernel
对序列进行分类。我们将使用一个包含6个序列的数据集,并根据序列中是否存在 'A' 进行训练。然后,我们将对另外5个序列进行预测,实际情况就是序列中是否至少有一个 'A'。在这里,我们做出了四个正确的分类,有一个分类错误。
X_train = np.array(["AGCT", "CGA", "TAAC", "TCG", "CTTT", "TGCT"])
## 序列中是否存在 'A'
Y_train = np.array([True, True, True, False, False, False])
gp = GaussianProcessClassifier(kernel)
gp.fit(X_train, Y_train)
X_test = ["AAA", "ATAG", "CTC", "CT", "C"]
Y_test = [True, True, False, False, False]
plt.figure(figsize=(8, 5))
plt.scatter(
np.arange(len(X_train)),
[1.0 if c else -1.0 for c in Y_train],
s=100,
marker="o",
edgecolor="none",
facecolor=(1, 0.75, 0),
label="training",
)
plt.scatter(
len(X_train) + np.arange(len(X_test)),
[1.0 if c else -1.0 for c in Y_test],
s=100,
marker="o",
edgecolor="none",
facecolor="r",
label="truth",
)
plt.scatter(
len(X_train) + np.arange(len(X_test)),
[1.0 if c else -1.0 for c in gp.predict(X_test)],
s=100,
marker="x",
facecolor="b",
linewidth=2,
label="prediction",
)
plt.xticks(np.arange(len(X_train) + len(X_test)), np.concatenate((X_train, X_test)))
plt.yticks([-1, 1], [False, True])
plt.title("Classification on sequences")
plt.legend()
plt.show()