Оптимизация поиска по дереву Монте-Карло и предотвращение потерь

Я работаю над реализацией Поиск по дереву Монте-Карло в Быстрый.

Неплохо, но могло быть и лучше! Меня принципиально интересует создание моего алгоритма:

  1. быстрее (больше итераций в секунду)
  2. расставляйте приоритеты в действиях, предотвращающих мгновенные потери (вы увидите …)

Здесь Основной драйвер:

final class MonteCarloTreeSearch {
    var player: Player
    var timeBudget: Double
    var maxDepth: Int
    var explorationConstant: Double
    var root: Node?
    var iterations: Int

    init(for player: Player, timeBudget: Double = 5, maxDepth: Int = 5, explorationConstant: Double = sqrt(2)) {
        self.player = player
        self.timeBudget = timeBudget
        self.maxDepth = maxDepth
        self.explorationConstant = explorationConstant
        self.iterations = 0
    }
    
    func update(with game: Game) {
        if let newRoot = findNode(for: game) {
            newRoot.parent = nil
            newRoot.move = nil
            root = newRoot
        } else {
            root = Node(game: game)
        }
    }

    func findMove(for game: Game? = nil) -> Move? {
        iterations = 0
        let start = CFAbsoluteTimeGetCurrent()
        if let game = game {
            update(with: game)
        }
        while CFAbsoluteTimeGetCurrent() - start < timeBudget {
            refine()
            iterations += 1
        }
        print("Iterations: (iterations)")
        return bestMove
    }
    
    private func refine() {
        let leafNode = root!.select(explorationConstant)
        let value = rollout(leafNode)
        leafNode.backpropogate(value)
    }
    
    private func rollout(_ node: Node) -> Double {
        var depth = 0
        var game = node.game
        while !game.isFinished {
            if depth >= maxDepth { break }
            guard let move = game.randomMove() else { break }
            game = game.update(move)
            depth += 1
        }
        let value = game.evaluate(for: player).value
        return value
    }
    
    private var bestMove: Move? {
        root?.selectChildWithMaxUcb(0)?.move
    }
    
    private func findNode(for game: Game) -> Node? {
        guard let root = root else { return nil }
        var queue = [root]
        while !queue.isEmpty {
            let head = queue.removeFirst()
            if head.game == game {
                return head
            }
            for child in head.children {
                queue.append(child)
            }
        }
        return nil
    }
}

Я построил этот драйвер с maxDepth аргумент, потому что воспроизведения / развертывания в моем настоящий игры довольно длинные, и у меня есть доступ к приличной функции статической оценки. Также BFS findNode метод таков, что я могу повторно использовать части дерева.

Вот что за узел в драйвере выглядит так:

final class Node {
    weak var parent: Node?
    var move: Move?
    var game: Game
    var untriedMoves: [Move]
    var children: [Node]
    var cumulativeValueFor: Double
    var cumulativeValueAgainst: Double
    var visits: Double

    init(parent: Node? = nil, move: Move? = nil, game: Game) {
        self.parent = parent
        self.move = move
        self.game = game
        self.children = []
        self.untriedMoves = game.availableMoves()
        self.cumulativeValueFor = 0
        self.cumulativeValueAgainst = 0
        self.visits = 0
    }
    
    var isFullyExpanded: Bool {
        untriedMoves.isEmpty
    }
    
    lazy var isTerminal: Bool = {
        game.isFinished
    }()
    
    func select(_ c: Double) -> Node {
        var leafNode = self
        while !leafNode.isTerminal {
            if !leafNode.isFullyExpanded {
                return leafNode.expand()
            } else {
                leafNode = leafNode.selectChildWithMaxUcb(c)!
            }
        }
        return leafNode
    }
    
    func expand() -> Node {
        let move = untriedMoves.popLast()!
        let nextGame = game.update(move)
        let childNode = Node(parent: self, move: move, game: nextGame)
        children.append(childNode)
        return childNode
    }
    
    func backpropogate(_ value: Double) {
        visits += 1
        cumulativeValueFor += value
        if let parent = parent {
            parent.backpropogate(value)
        }
    }
    
    func selectChildWithMaxUcb(_ c: Double) -> Node? {
        children.max { $0.ucb(c) < $1.ucb(c) }
    }

    func ucb(_ c: Double) -> Double {
        q + c * u
    }
    
    private var q: Double {
        let value = cumulativeValueFor - cumulativeValueAgainst
        return value / visits
    }
    
    private var u: Double {
        sqrt(log(parent!.visits) / visits)
    }
}

extension Node: CustomStringConvertible {
    var description: String {
        guard let move = move else { return "" }
        return "(move) ((cumulativeValueFor)/(visits))"
    }
}

Не думаю, что в моем объекте узла есть что-то экстраординарное? (Я надеюсь, что смогу что-нибудь сделать с / о q чтобы я мог предотвратить «мгновенную» потерю моего контрольная работа игра…


Я тестировал эту реализацию MCTS на одномерном варианте «Connect 4».

Вот игра и все ее примитивы:

enum Player: Int {
    case one = 1
    case two = 2
    
    var opposite: Self {
        switch self {
        case .one: return .two
        case .two: return .one
        }
    }
}

extension Player: CustomStringConvertible {
    var description: String {
        "(rawValue)"
    }
}

typealias Move = Int

enum Evaluation {
    case win
    case loss
    case draw
    case ongoing(Double)
    
    var value: Double {
        switch self {
        case .win: return 1
        case .loss: return 0
        case .draw: return 0.5
        case .ongoing(let v): return v
        }
    }
}

struct Game {
    var array: Array<Int>
    var currentPlayer: Player
    
    init(length: Int = 10, currentPlayer: Player = .one) {
        self.array = Array.init(repeating: 0, count: length)
        self.currentPlayer = currentPlayer
    }
    
    var isFinished: Bool {
        switch evaluate() {
        case .ongoing: return false
        default: return true
        }
    }

    func availableMoves() -> [Move] {
        array
            .enumerated()
            .compactMap { $0.element == 0 ? Move($0.offset) : nil}
    }
    
    func update(_ move: Move) -> Self {
        var copy = self
        copy.array[move] = currentPlayer.rawValue
        copy.currentPlayer = currentPlayer.opposite
        return copy
    }
    
    func evaluate(for player: Player) -> Evaluation {
        let player3 = three(for: player)
        let oppo3 = three(for: player.opposite)
        let remaining0 = array.contains(0)
        switch (player3, oppo3, remaining0) {
        case (true, true, _): return .draw
        case (true, false, _): return .win
        case (false, true, _): return .loss
        case (false, false, false): return .draw
        default: return .ongoing(0.5)
        }
    }
    
    private func three(for player: Player) -> Bool {
        var count = 0
        for slot in array {
            if slot == player.rawValue {
                count += 1
            } else {
                count = 0
            }
            if count == 3 {
                return true
            }
        }
        return false
    }
}

extension Game {
    func evaluate() -> Evaluation {
        evaluate(for: currentPlayer)
    }
    
    func randomMove() -> Move? {
        availableMoves().randomElement()
    }
}

extension Game: CustomStringConvertible {
    var description: String {
        return array.reduce(into: "") { result, i in
            result += String(i)
        }
    }
}

extension Game: Equatable {}

Хотя определенно можно добиться повышения эффективности в оптимизации evaluate/three(for:) методы оценки, меня больше беспокоит улучшение производительности драйвера и узла, так как эта игра «1d-connect-3» мне не подходит. настоящий игра. Тем не менее, если здесь есть большая ошибка и простое исправление, я ее устраню!

Еще одно примечание: на самом деле я использую ongoing(Double) в моем настоящий game (у меня есть статическая функция оценки, которая может надежно поставить игрока на 1-99% вероятности выигрыша).


Немного кода игровой площадки:

var mcts = MonteCarloTreeSearch(for: .two, timeBudget: 5, maxDepth: 3)
var game = Game(length: 10)
// 0000000000
game = game.update(0) // player 1
// 1000000000
game = game.update(8) // player 2
// 1000000020
game = game.update(1) // player 1
// 1100000020
let move1 = mcts.findMove(for: game)!
// usually 7 or 9... and not 2
print(mcts.root!.children)
game = game.update(move1) // player 2
mcts.update(with: game)
game = game.update(4) // player 1
mcts.update(with: game)
let move2 = mcts.findMove()!

К несчастью, move1 в этом примере «playthru» не пытается предотвратить мгновенное выигрышное условие на следующем ходу для игрока 1 ?! (Я знаю, что ортодоксальный поиск по дереву Монте-Карло направлен на максимальное увеличение выигрыша, а не на минимизацию проигрыша, но не на сбор 2 вот прискорбно).

Так что да, любая помощь в ускорении всего этого (возможно, за счет распараллеливания) и исправлении ситуации с «мгновенными потерями» будет великолепной!

0

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *