• 技术文章 >Python技术 >Python基础教程

    python A*算法是什么

    小妮浅浅小妮浅浅2021-08-31 10:15:22原创3789

    说明

    1、A*算法是静态路网中解决最短路径最有效的直接搜索方法。

    2、A*算法是启发式算法,采用最佳优先搜索策略(Best-first),基于评估函数对每个搜索位置的评估结果,猜测最佳优先搜索位置。

    A*算法大大降低了低质量的搜索路径,因此搜索效率高,比传统的路径规划算法更实时、更灵活。但A*算法找到的是相对最优的路径,而不是绝对最短的路径,适合大规模、实时性高的问题。

    实例

    1

    2

    3

    4

    5

    6

    7

    8

    9

    10

    11

    12

    13

    14

    15

    16

    17

    18

    19

    20

    21

    22

    23

    24

    25

    26

    27

    28

    29

    30

    31

    32

    33

    34

    35

    36

    37

    38

    39

    40

    41

    42

    43

    44

    45

    46

    47

    48

    49

    50

    51

    52

    53

    54

    55

    56

    57

    58

    59

    60

    61

    62

    63

    64

    65

    66

    67

    68

    69

    70

    71

    72

    73

    74

    75

    76

    77

    78

    79

    80

    81

    82

    83

    84

    85

    86

    87

    88

    89

    90

    91

    92

    93

    94

    95

    96

    97

    98

    99

    100

    101

    102

    103

    104

    105

    106

    107

    108

    109

    110

    111

    112

    113

    114

    115

    116

    117

    118

    119

    120

    121

    122

    123

    124

    125

    126

    127

    128

    129

    130

    131

    132

    133

    134

    135

    136

    137

    138

    139

    140

    141

    142

    143

    144

    145

    146

    147

    148

    149

    150

    151

    152

    153

    154

    155

    156

    157

    158

    159

    160

    161

    162

    163

    164

    165

    166

    167

    168

    169

    170

    171

    172

    173

    174

    175

    176

    177

    178

    179

    180

    181

    182

    183

    184

    185

    186

    187

    188

    189

    190

    191

    192

    193

    194

    195

    196

    197

    198

    199

    200

    201

    202

    203

    204

    205

    206

    207

    208

    209

    210

    211

    212

    213

    214

    215

    216

    217

    218

    219

    220

    221

    222

    223

    224

    225

    226

    227

    228

    229

    230

    231

    232

    import heapq

    import copy

    import re

    import datetime

      

    BLOCK = []  # 给定状态

    GOAL = []  # 目标状态

      

    # 4个方向

    direction = [[0, 1], [0, -1], [1, 0], [-1, 0]]

      

    # OPEN表

    OPEN = []

      

    # 节点的总数

    SUM_NODE_NUM = 0

      

      

    # 状态节点

    class State(object):

        def __init__(self, gn=0, hn=0, state=None, hash_value=None, par=None):

            '''

            初始化

            :param gn: gn是初始化到现在的距离

            :param hn: 启发距离

            :param state: 节点存储的状态

            :param hash_value: 哈希值,用于判重

            :param par: 父节点指针

            '''

            self.gn = gn

            self.hn = hn

            self.fn = self.gn + self.hn

            self.child = []  # 孩子节点

            self.par = par  # 父节点

            self.state = state  # 局面状态

            self.hash_value = hash_value  # 哈希值

      

        def __lt__(self, other):  # 用于堆的比较,返回距离最小的

            return self.fn < other.fn

      

        def __eq__(self, other):  # 相等的判断

            return self.hash_value == other.hash_value

      

        def __ne__(self, other):  # 不等的判断

            return not self.__eq__(other)

      

      

    def manhattan_dis(cur_node, end_node):

        '''

        计算曼哈顿距离

        :param cur_state: 当前状态

        :return: 到目的状态的曼哈顿距离

        '''

        cur_state = cur_node.state

        end_state = end_node.state

        dist = 0

        N = len(cur_state)

        for i in range(N):

            for j in range(N):

                if cur_state[i][j] == end_state[i][j]:

                    continue

                num = cur_state[i][j]

                if num == 0:

                    x = N - 1

                    y = N - 1

                else:

                    x = num / N  # 理论横坐标

                    y = num - N * x - 1  # 理论的纵坐标

                dist += (abs(x - i) + abs(y - j))

      

        return dist

      

      

    def test_fn(cur_node, end_node):

        return 0

      

      

    def generate_child(cur_node, end_node, hash_set, open_table, dis_fn):

        '''

        生成子节点函数

        :param cur_node:  当前节点

        :param end_node:  最终状态节点

        :param hash_set:  哈希表,用于判重

        :param open_table: OPEN表

        :param dis_fn: 距离函数

        :return: None

        '''

        if cur_node == end_node:

            heapq.heappush(open_table, end_node)

            return

        num = len(cur_node.state)

        for i in range(0, num):

            for j in range(0, num):

                if cur_node.state[i][j] != 0:

                    continue

                for d in direction:  # 四个偏移方向

                    x = i + d[0]

                    y = j + d[1]

                    if x < 0 or x >= num or y < 0 or y >= num:  # 越界了

                        continue

                    # 记录扩展节点的个数

                    global SUM_NODE_NUM

                    SUM_NODE_NUM += 1

      

                    state = copy.deepcopy(cur_node.state)  # 复制父节点的状态

                    state[i][j], state[x][y] = state[x][y], state[i][j]  # 交换位置

                    h = hash(str(state))  # 哈希时要先转换成字符串

                    if h in hash_set:  # 重复了

                        continue

                    hash_set.add(h)  # 加入哈希表

                    gn = cur_node.gn + 1  # 已经走的距离函数

                    hn = dis_fn(cur_node, end_node)  # 启发的距离函数

                    node = State(gn, hn, state, h, cur_node)  # 新建节点

                    cur_node.child.append(node)  # 加入到孩子队列

                    heapq.heappush(open_table, node)  # 加入到堆中

      

      

    def print_path(node):

        '''

        输出路径

        :param node: 最终的节点

        :return: None

        '''

        num = node.gn

      

        def show_block(block):

            print("---------------")

            for b in block:

                print(b)

      

        stack = []  # 模拟栈

        while node.par is not None:

            stack.append(node.state)

            node = node.par

        stack.append(node.state)

        while len(stack) != 0:

            t = stack.pop()

            show_block(t)

        return num

      

      

    def A_start(start, end, distance_fn, generate_child_fn, time_limit=10):

        '''

        A*算法

        :param start: 起始状态

        :param end: 终止状态

        :param distance_fn: 距离函数,可以使用自定义的

        :param generate_child_fn: 产生孩子节点的函数

        :param time_limit: 时间限制,默认10秒

        :return: None

        '''

        root = State(0, 0, start, hash(str(BLOCK)), None)  # 根节点

        end_state = State(0, 0, end, hash(str(GOAL)), None)  # 最后的节点

        if root == end_state:

            print("start == end !")

      

        OPEN.append(root)

        heapq.heapify(OPEN)

      

        node_hash_set = set()  # 存储节点的哈希值

        node_hash_set.add(root.hash_value)

        start_time = datetime.datetime.now()

        while len(OPEN) != 0:

            top = heapq.heappop(OPEN)

            if top == end_state:  # 结束后直接输出路径

                return print_path(top)

            # 产生孩子节点,孩子节点加入OPEN表

            generate_child_fn(cur_node=top, end_node=end_state, hash_set=node_hash_set,

                              open_table=OPEN, dis_fn=distance_fn)

            cur_time = datetime.datetime.now()

            # 超时处理

            if (cur_time - start_time).seconds > time_limit:

                print("Time running out, break !")

                print("Number of nodes:", SUM_NODE_NUM)

                return -1

      

        print("No road !")  # 没有路径

        return -1

      

      

    def read_block(block, line, N):

        '''

        读取一行数据作为原始状态

        :param block: 原始状态

        :param line: 一行数据

        :param N: 数据的总数

        :return: None

        '''

        pattern = re.compile(r'\d+')  # 正则表达式提取数据

        res = re.findall(pattern, line)

        t = 0

        tmp = []

        for i in res:

            t += 1

            tmp.append(int(i))

            if t == N:

                t = 0

                block.append(tmp)

                tmp = []

      

      

    if __name__ == '__main__':

        try:

            file = open("./infile.txt", "r")

        except IOError:

            print("can not open file infile.txt !")

            exit(1)

      

        f = open("./infile.txt")

        NUMBER = int(f.readline()[-2])

        n = 1

        for i in range(NUMBER):

            l = []

            for j in range(NUMBER):

                l.append(n)

                n += 1

            GOAL.append(l)

        GOAL[NUMBER - 1][NUMBER - 1] = 0

      

        for line in f:  # 读取每一行数据

            OPEN = []  # 这里别忘了清空

            BLOCK = []

            read_block(BLOCK, line, NUMBER)

            SUM_NODE_NUM = 0

            start_t = datetime.datetime.now()

            # 这里添加5秒超时处理,可以根据实际情况选择启发函数

            length = A_start(BLOCK, GOAL, manhattan_dis, generate_child, time_limit=10)

            end_t = datetime.datetime.now()

            if length != -1:

                print("length =", length)

                print("time = ", (end_t - start_t).total_seconds(), "s")

                print("Nodes =", SUM_NODE_NUM)

    以上就是python A*算法的介绍,希望对大家有所帮助。更多Python学习指路:python基础教程

    本文教程操作环境:windows7系统、Python 3.9.1,DELL G3电脑。

    专题推荐:python a*算法
    上一篇:python Floyd算法是什么 下一篇:python最短路径算法如何选择

    相关文章推荐

    • python使用Pyecharts绘制疫情分布图• 如何查看python解释器的路径• python用户输入的方法• python if-elif-else语句是什么• python列表中if语句的用途• Python f-string字符串格式化的介绍• Python Dijkstra算法是什么• python最短路径问题的介绍• python Bellman-Ford算法是什么• python Floyd算法是什么

    全部评论我要评论

    © 2021 Python学习网 苏ICP备2021003149号-1

  • 取消发布评论
  • 

    Python学习网