本文將以八数码问题为例,讨论在解决问题时候的搜索算法的选择和优化。

人工智能-搜索算法

jskyzero 2017/10/11

以简单的问题切入

其实不用特别关注八数码究竟是什么问题,总之也是给我们一个输入,输出,变化规则,我们需要想办法找到从输入到输出的一个变化的过程,

我们来具体看下输入和输出,输入的是包含0的九个数,代表九个格子现在上面的数字,输出是按照1-8,最后是0的顺序。也就是要把数字放到对应的格子的意思,然后变化方法就是,移动0和上下左右的数字交换。

将问题转化到搜索上

每个状态可以理解为,一个节点,由一个状态转化到另一个状态可以理解为一条有向边,发现重复的状态的话就不在生成,这样我们就把这个问题变成了一个有向无环图(树),我们只要遍历这棵树,找到符合的目标状态,就算解决问题了。

那么树的遍历我们是熟悉的,无外乎深度优先广度优先。

广度优先可以找到最优解,但是相对而言对内存和时间的需求要高一些,深度的话可以在内存不足的时候起到关键作用。

随便实现一下逻辑

引入头文件

import heapq

initial = [8, 7, 6, 5, 4, 3, 2, 1, 0]
final = [1, 2, 3, 4, 5, 6, 7, 8, 0]

一些基本函数

def sub_move_choice(now_state):
    next_state = []
    # 0 1 2
    # 3 4 5
    # 6 7 8
    index = now_state.index(0)
    if index % 3 == 0:
        next_state.append(swap(now_state, index, index + 1))
    if index % 3 == 1:
        next_state.append(swap(now_state, index, index - 1))
        next_state.append(swap(now_state, index, index + 1))
    if index % 3 == 2:
        next_state.append(swap(now_state, index, index - 1))
    if index // 3 == 0:
        next_state.append(swap(now_state, index, index + 3))
    if index // 3 == 1:
        next_state.append(swap(now_state, index, index - 3))
        next_state.append(swap(now_state, index, index + 3))
    if index // 3 == 2:
        next_state.append(swap(now_state, index, index - 3))
    return next_state


def swap(now_state, left, right):
    new_state = now_state[:]
    temp = new_state[left]
    new_state[left] = new_state[right]
    new_state[right] = temp
    return new_state

核心部分


def BFS(initial_state, end_state):
    next_states = [initial_state]
    traveled = set(tuple(initial_state))
    depth = 0
    while not end_state in next_states and not len(next_states) == 0:
        depth = depth + 1
        new_next_states = []
        for next_state in next_states:
            for sub_next_state in sub_move_choice(next_state):
                if not tuple(sub_next_state) in traveled:
                    new_next_states.append(sub_next_state)
                    traveled.add(tuple(sub_next_state))
        next_states = new_next_states[:]
        # print("depth = ", depth,
        #       "len = ", len(new_next_states),
        #       "traveled = ", len(traveled))
    print( "FAILED" if len(next_states) == 0 else "traveled = ", len(traveled))


def DFS(initial_state, end_state, Afunction):
    next_states = []
    heapq.heappush(
        next_states, (-Afunction(initial_state, end_state), initial_state))
    traveled = set(tuple(initial_state))
    while not len(next_states) == 0:
        top_state = heapq.heappop(next_states)[-1]
        if top_state == end_state:
            print("traveled = ", len(traveled))
            return
        else:
            for next_state in sub_move_choice(top_state):
                if not tuple(next_state) in traveled:
                    heapq.heappush(
                        next_states, (-Afunction(next_state, end_state), next_state))
                    traveled.add(tuple(next_state))
    print("FAILED")


def diferent_bit(one_state, another_state=final):
    return -sum([0 if x == y else 1 for(x, y) in zip(one_state, another_state)])

def diferent_sum(one_state, another_state=final):
    return -sum([abs(x - y) for(x, y) in zip(one_state, another_state)])

def manhadun_space(one_state, another_state):
    space = 0
    for i in range(0, 9, 1):
      j = another_state.index(one_state[i])
      space = space + abs(j % 3 - i % 3) + abs(j // 3 - i // 3)
    return -space

我们可以看到这个差别,还是挺明显的。

BFS(initial, final)
DFS(initial, final, diferent_sum)
DFS(initial, final, diferent_bit)
DFS(initial, final, manhadun_space)

# traveled =  181447
# traveled =  18656
# traveled =  1327
# traveled =  247

results matching ""

    No results matching ""