Aprendizado Ativo com Propagação de Rótulos

Beginner

This tutorial is from open-source community. Access the source code

Introdução

Este laboratório demonstra uma técnica de aprendizagem ativa para aprender dígitos manuscritos usando propagação de rótulos. A Propagação de Rótulos é um método de aprendizagem semi-supervisionada que utiliza uma abordagem baseada em grafos para propagar rótulos entre pontos de dados. A aprendizagem ativa é um processo que nos permite selecionar iterativamente pontos de dados para serem rotulados e utilizar esses pontos rotulados para treinar novamente o modelo.

Dicas da Máquina Virtual

Após o arranque da máquina virtual, clique no canto superior esquerdo para mudar para a aba Notebook para aceder ao Jupyter Notebook para a prática.

Por vezes, pode ser necessário esperar alguns segundos para o Jupyter Notebook terminar de carregar. A validação das operações não pode ser automatizada devido a limitações no Jupyter Notebook.

Se tiver problemas durante a aprendizagem, não hesite em contactar o Labby. Forneça feedback após a sessão e resolveremos prontamente o problema para si.

Carregar o Conjunto de Dados de Dígitos

Começaremos carregando o conjunto de dados de dígitos da biblioteca scikit-learn.

from sklearn import datasets

digits = datasets.load_digits()

Embaralhar e Dividir os Dados

Em seguida, embaralharemos e dividiremos o conjunto de dados em partes rotuladas e não rotuladas. Começaremos com apenas 10 pontos rotulados.

import numpy as np

rng = np.random.RandomState(0)
indices = np.arange(len(digits.data))
rng.shuffle(indices)

X = digits.data[indices[:330]]
y = digits.target[indices[:330]]
images = digits.images[indices[:330]]

n_total_samples = len(y)
n_labeled_points = 10
unlabeled_indices = np.arange(n_total_samples)[n_labeled_points:]

Treinar o Modelo de Propagação de Rótulos

Agora, treinaremos um modelo de propagação de rótulos com os pontos de dados rotulados e o usaremos para prever os rótulos dos pontos de dados não rotulados restantes.

from sklearn.semi_supervised import LabelSpreading

lp_model = LabelSpreading(gamma=0.25, max_iter=20)
lp_model.fit(X, y_train)
predicted_labels = lp_model.transduction_[unlabeled_indices]

Selecionar os Pontos Mais Incertos

Selecionaremos os cinco pontos mais incertos com base nas suas distribuições de rótulos previstos e solicitaremos rótulos humanos para eles.

from scipy import stats

pred_entropies = stats.distributions.entropy(lp_model.label_distributions_.T)
uncertainty_index = np.argsort(pred_entropies)[::-1]
uncertainty_index = uncertainty_index[np.in1d(uncertainty_index, unlabeled_indices)][:5]

Rotular os Pontos Mais Incertos

Adicionaremos os rótulos humanos aos pontos de dados rotulados e treinaremos o modelo com eles.

y_train[uncertainty_index] = y[uncertainty_index]
lp_model.fit(X, y_train)

Repetir

Repetiremos o processo de seleção dos cinco pontos mais incertos, adicionando seus rótulos aos pontos de dados rotulados e treinando o modelo até termos 30 pontos de dados rotulados.

max_iterations = 3

for i in range(max_iterations):
    if len(unlabeled_indices) == 0:
        print("Não há mais itens não rotulados para serem rotulados.")
        break

    ## selecionar os cinco pontos mais incertos
    pred_entropies = stats.distributions.entropy(lp_model.label_distributions_.T)
    uncertainty_index = np.argsort(pred_entropies)[::-1]
    uncertainty_index = uncertainty_index[np.in1d(uncertainty_index, unlabeled_indices)][:5]

    ## adicionar rótulos aos pontos de dados rotulados
    y_train[uncertainty_index] = y[uncertainty_index]

    ## treinar o modelo
    lp_model.fit(X, y_train)

    ## remover pontos de dados rotulados do conjunto não rotulado
    delete_indices = np.array([], dtype=int)
    for index, image_index in enumerate(uncertainty_index):
        (delete_index,) = np.where(unlabeled_indices == image_index)
        delete_indices = np.concatenate((delete_indices, delete_index))
    unlabeled_indices = np.delete(unlabeled_indices, delete_indices)
    n_labeled_points += len(uncertainty_index)

Resumo

Em resumo, este laboratório demonstrou uma técnica de aprendizado ativo utilizando a Propagação de Rótulos para aprender dígitos manuscritos. Iniciamos treinando um modelo de propagação de rótulos com apenas 10 pontos rotulados e, iterativamente, selecionamos os cinco pontos mais incertos para rotulação até termos 30 pontos de dados rotulados. Essa técnica de aprendizado ativo pode ser útil para minimizar o número de pontos de dados rotulados necessários para treinar um modelo, maximizando seu desempenho.