Как пройти по дереву решений тестовой выборкой на Python?



@officialandrey

Привет.
Есть код построения дерева решений по обучающей выборке.
Теперь нужно пройти по дереву с тестовой выборкой и получить результат классификации.
Спасибо.

Код

Import math
import pandas as pd
from functools import reduce

data = {
    "Кредитная_история":["Плохое","Неизвестно","Неизвестно","Неизвестно","Неизвестно","Неизвестно","Плохая","Плохая","Хорошая","Хорошая","Хорошая","Хорошая","Хорошая","Плохая"],
    "Долг":["Высокий","Высокий","Низкий","Низкий","Низкий","Низкий","Низкий","Низкий","Низкий","Высокий","Высокий","Высокий","Высокий","Низкий"], 
    "Поручитель":["Нет","Нет","Нет","Нет","Нет","Адекватный","Нет","Адекватный","Нет","Адекватный","Нет","Нет","Нет","Нет"],
    "Доход":["0-15000","15000-30000","15000-30000","0-15000","35000+","35000+","0-15000","35000+","35000+","35000+","0-15000","15000-30000","35000+","15000-30000"],
    "Риск":["Высокий","Высокий","Средний","Высокий","Низкий","Низкий","Высокий","Средний","Низкий","Низкий","Высокий","Средний","Низкий","Высокий"],
}

df = pd.DataFrame(data)

# массив содержащий строку вида: кортеж(key, value)
convert_str = lambda string:[key+":"+str(value) for key, value in sorted(string.value_counts().items())]

Tree = {
    "name": df.columns[-1]+" "+str(cstr(df.iloc[:,-1])),
    "df": df,
    "edges":[],
}

def log_(x):
    i = len(df.iloc[:,-1].unique())
    return math.log(x)/math.log(6)

X = [Tree]
entropy = lambda s:-reduce(lambda x,y:x+y,map(lambda x:(x/len(s))*math.log2(x/len(s)),s.value_counts()))

def solution():
    while(len(X)!=0):
        n = X.pop(0)
        df_n = n["df"]
        if 0==entropy(df_n.iloc[:,-1]):
            continue
        attrs = {}
        for attr in df_n.columns[:-1]:
            attrs[attr] = {"entropy":0,"dfs":[],"values":[]}
            for value in sorted(set(df_n[attr])):
                df_m = df_n.query(attr+"=='"+value+"'")
                attrs[attr]["entropy"] += entropy(df_m.iloc[:,-1])*df_m.shape[0]/df_n.shape[0]
                attrs[attr]["dfs"] += [df_m]
                attrs[attr]["values"] += [value]
                pass
            pass
        if len(attrs)==0:
            continue
        attr = min(attrs,key=lambda x:attrs[x]["entropy"])
        for d,v in zip(attrs[attr]["dfs"],attrs[attr]["values"]):
            m = {"name":attr+"="+v,"edges":[],"df":d.drop(columns=attr)}
            n["edges"].append(m)
            X.append(m)
        pass

def tstr(tree,indent=""):
    s = indent+tree["name"]+str(cstr(tree["df"].iloc[:,-1]) if len(tree["edges"])==0 else "")+"n"
    for e in tree["edges"]:
        s += tstr(e,indent+"  ")
        pass
    return s

solution()
print(tstr(Tree))
Тестовая выборка

data = {
    "Кредитная_история":["Плохое","Неизвестно","Неизвестно","Неизвестно","Неизвестно","Неизвестно","Плохая","Плохая","Хорошая","Хорошая","Хорошая","Хорошая","Хорошая","Плохая"],
    "Долг":["Высокий","Высокий","Низкий","Низкий","Низкий","Низкий","Низкий","Низкий","Низкий","Высокий","Высокий","Высокий","Высокий","Низкий"], 
    "Поручитель":["Нет","Нет","Нет","Нет","Нет","Адекватный","Нет","Адекватный","Нет","Адекватный","Нет","Нет","Нет","Нет"],
    "Доход":["0-15000","15000-30000","15000-30000","0-15000","35000+","35000+","0-15000","35000+","35000+","35000+","0-15000","15000-30000","35000+","15000-30000"],
}


Решения вопроса 0


Ответы на вопрос 1



@officialandrey Автор вопроса

Вот додумался только до этого.

for i in range(len(data["Долг"])):
    temp = list()
    for key,_ in data.items():
        temp.append(key+"="+data[key][i])
        pass
    new_data.append(temp.copy())

print(new_data)
print('--------------------')


def bypass(tree, data_x, data_y, data_z):
    name_tree = tree["name"]
    target_tree = str(cstr(tree["df"].iloc[:,-1]) if len(tree["edges"])==0 else "")
    for i in [data_x, data_y, data_z]:
        if i == name_tree:
            for e in tree["edges"]:
                name_tree += bypass(e, data_x, data_y, data_z)
                pass
            print(i, name_tree)
    for e in tree["edges"]:
        name_tree += bypass(e, data_x, data_y, data_z)
        pass
    return name_tree
  
for i in range(len(new_data)):
    bypass(Tree, new_data[i][3], new_data[i][0], new_data[i][1])

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *