The first idea that came to mind was to use two Sets: grow the set of nodes reachable from node1 and the set of nodes reachable from node2 one step at a time, and pick the first node that shows up in both. That gave me the first solution:

from collections import defaultdict
from typing import List

class Solution:
    def bfs(self, node1: int, node2: int, conns, length) -> int:
        # set1 / set2 hold every node reachable from node1 / node2 within `step` steps
        set1 = {node1}
        set2 = {node2}
        step = 0
        while step <= length:
            inter = set1 & set2
            if inter:
                # every node shared for the first time at this step ties, so take the smallest index
                return min(inter)
            new_set1 = set()
            new_set2 = set()
            for node in set1:
                new_set1 |= conns[node]
            for node in set2:
                new_set2 |= conns[node]
            # neither side reached a new node, so the two walks can never meet
            if not (new_set1 - set1) and not (new_set2 - set2):
                return -1
            set1 |= new_set1
            set2 |= new_set2
            step += 1
        return -1

    def closestMeetingNode(self, edges: List[int], node1: int, node2: int) -> int:
        # adjacency as sets; edges[i] == -1 means node i has no outgoing edge
        conns = defaultdict(set)
        for index, edge in enumerate(edges):
            if edge >= 0:
                conns[index].add(edge)
        return self.bfs(node1, node2, conns, len(edges))

I was pretty satisfied with the simplicity of the code above. Unfortunately, it exceeded the time limit.
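One likely contributor to the timeout (an assumption; I haven't profiled it) is that every round loops over all nodes already in set1 and set2 and intersects the full Sets, instead of touching only the nodes added in the previous round, so a long path costs roughly O(n²) work. The sketch below keeps the Sets but only expands and checks each side's frontier. The function name closest_meeting_node is hypothetical, and it is written as a standalone function rather than a LeetCode Solution method.

```python
from typing import List, Set


def closest_meeting_node(edges: List[int], node1: int, node2: int) -> int:
    # seen*: all nodes reached so far; frontier*: nodes reached for the first time this round
    seen1: Set[int] = {node1}
    seen2: Set[int] = {node2}
    frontier1: Set[int] = {node1}
    frontier2: Set[int] = {node2}
    while frontier1 or frontier2:
        # Only nodes first reached in this round can be new meeting points,
        # and all of them share the same larger distance, so the smallest index wins.
        meet = (frontier1 & seen2) | (frontier2 & seen1)
        if meet:
            return min(meet)
        # Follow one edge from each frontier node, dropping nodes this side has already seen.
        frontier1 = {edges[x] for x in frontier1 if edges[x] >= 0} - seen1
        frontier2 = {edges[x] for x in frontier2 if edges[x] >= 0} - seen2
        seen1 |= frontier1
        seen2 |= frontier2
    # Both walks ended (dead end or cycle) without sharing a node.
    return -1
```

Since every node has at most one outgoing edge, each frontier here holds at most one node, so this variant does about the same amount of work per round as walking two pointers.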

Sometimes there is no need to start over with a new solution; optimising the first one is enough. Maybe I didn’t need Sets at all, since they are relatively expensive in Python. Because every node has at most one outgoing edge, each walk is just a single path, so I can use one array to record every visited node, and stepping onto a node already VISITED by the other walk means an “intersection”. To tell the two walks apart, node1’s walk marks “1” in the array and node2’s walk marks “2”. That led to my second solution. It’s a little longer, but it uses an array instead of Sets:

from typing import List

class Solution:
    def closestMeetingNode(self, edges: List[int], node1: int, node2: int) -> int:
        if node1 == node2:
            return node1
        n = len(edges)
        # 0 = unvisited, 1 = visited by node1's walk, 2 = visited by node2's walk
        visited = [0] * n
        visited[node1] = 1
        visited[node2] = 2
        while True:
            ans = []  # meeting nodes found in this round
            old_node1 = node1
            nxt = edges[node1]
            if nxt >= 0:
                if visited[nxt] == 0:
                    visited[nxt] = 1
                    node1 = nxt
                elif visited[nxt] == 2:  # node1's walk steps onto node2's path
                    ans.append(nxt)
            old_node2 = node2
            nxt = edges[node2]
            if nxt >= 0:
                if visited[nxt] == 0:
                    visited[nxt] = 2
                    node2 = nxt
                elif visited[nxt] == 1:  # node2's walk steps onto node1's path
                    ans.append(nxt)
            if ans:
                return min(ans)
            # neither walk moved: each hit a dead end (-1) or looped back onto its own path
            if old_node1 == node1 and old_node2 == node2:
                return -1

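In the code above, “old_node1” and “old_node2” detect the case where neither walk can move any further (each has hit a dead end or looped back onto its own path), which would otherwise be an infinite loop. It beats 97% on runtime. Not bad.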
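To gain more confidence in the array walk, one option is to compare it on random inputs against a more direct method: follow each start’s path to record its distance to every node, then return the smallest-index node that minimises the larger of the two distances. In the sketch below, walk_meet is a compact standalone copy of the second solution, while distances, reference_meet and random_edges are hypothetical helper names; the random graphs give each node either no outgoing edge (-1) or exactly one edge.

```python
import random
from typing import List


def walk_meet(edges: List[int], node1: int, node2: int) -> int:
    """Compact copy of the array-marking walk (1 = node1's walk, 2 = node2's walk)."""
    if node1 == node2:
        return node1
    visited = [0] * len(edges)
    visited[node1], visited[node2] = 1, 2
    while True:
        ans = []
        old1, old2 = node1, node2
        nxt = edges[node1]
        if nxt >= 0:
            if visited[nxt] == 0:
                visited[nxt] = 1
                node1 = nxt
            elif visited[nxt] == 2:
                ans.append(nxt)
        nxt = edges[node2]
        if nxt >= 0:
            if visited[nxt] == 0:
                visited[nxt] = 2
                node2 = nxt
            elif visited[nxt] == 1:
                ans.append(nxt)
        if ans:
            return min(ans)
        if old1 == node1 and old2 == node2:
            return -1


def distances(edges: List[int], start: int) -> List[int]:
    """Distance from start to each node along its single path; -1 means unreachable."""
    dist = [-1] * len(edges)
    d, node = 0, start
    # Stop at a dead end (-1) or when the path revisits a node (cycle).
    while node != -1 and dist[node] == -1:
        dist[node] = d
        d += 1
        node = edges[node]
    return dist


def reference_meet(edges: List[int], node1: int, node2: int) -> int:
    """Smallest-index node minimising max(dist1, dist2), or -1 if none is shared."""
    d1, d2 = distances(edges, node1), distances(edges, node2)
    best, best_cost = -1, len(edges) + 1
    for i in range(len(edges)):
        if d1[i] >= 0 and d2[i] >= 0 and max(d1[i], d2[i]) < best_cost:
            best, best_cost = i, max(d1[i], d2[i])
    return best


def random_edges(n: int, rng: random.Random) -> List[int]:
    """Random graph where each node has no outgoing edge (-1) or exactly one."""
    return [rng.randint(-1, n - 1) for _ in range(n)]


if __name__ == "__main__":
    rng = random.Random(0)
    for _ in range(2000):
        n = rng.randint(1, 8)
        edges = random_edges(n, rng)
        a, b = rng.randrange(n), rng.randrange(n)
        assert walk_meet(edges, a, b) == reference_meet(edges, a, b), (edges, a, b)
    print("all random cases agree")
```

Running the file prints a confirmation if every random case agrees; otherwise the failing assertion reports the edges and start nodes, which makes a small counterexample easy to trace by hand.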