/**
 * Searches outward from a starting node, over an undirected adjacency map
 * built from the tree, looking for a node that has no children (a leaf).
 *
 * NOTE(review): this excerpt is incomplete. The declaration of {@code size},
 * the statements that seed {@code queue}/{@code visited}, the return
 * statements and the closing braces are not visible here, so the exact
 * return value cannot be confirmed from these lines.
 *
 * @param root root of the binary tree; passed straight to {@code buildGraph}
 * @param k    value forwarded to {@code buildGraph}; presumably the value of
 *             the node the search starts from (TODO confirm)
 * @return an int; presumably the value of the nearest leaf (TODO confirm
 *         against the return statements, which are not shown)
 */
public int findClosestLeaf(TreeNode root, int k) {
// Single-element array handed to buildGraph; presumably an out-parameter
// that receives the node whose value equals k (TODO confirm).
TreeNode[] start = new TreeNode[1];
// Adjacency map: each node maps to the list of nodes it is linked to.
Map<TreeNode, List<TreeNode>> adjList = new HashMap<>();
// The root is passed with a null parent.
buildGraph(root, null, k, start, adjList);
Queue<TreeNode> queue = new LinkedList<>();
// Nodes already encountered, consulted below before following an edge.
Set<TreeNode> visited = new HashSet<>();
// NOTE(review): nothing visible adds the start node to queue or visited
// before this loop; presumably that happens on a line not shown.
while (!queue.isEmpty()) {
// NOTE(review): `size` is not declared in the visible lines; presumably it
// is a snapshot of queue.size() taken here so that one pass handles one level.
for (int i = 0; i < size; i++) {
TreeNode curNode = queue.remove();
// A node with neither child is a leaf.
if (curNode.left == null && curNode.right == null) {
// Walk every node linked to curNode in the adjacency map.
for (TreeNode nextNode : adjList.get(curNode)) {
// Only follow nodes that have not been seen yet.
if (!visited.contains(nextNode)) {
/**
 * Recursively records an undirected edge between each tree node and its
 * parent in {@code adjList}.
 *
 * NOTE(review): this excerpt is incomplete. No base case for a null
 * {@code curNode}, no guard for a null {@code parent}, no use of
 * {@code k} or {@code start}, and no closing brace are visible here;
 * presumably they live on lines not shown (verify in the full file).
 *
 * @param curNode node being processed
 * @param parent  parent of {@code curNode}; the visible caller passes
 *                {@code null} for the root
 * @param k       not used in the visible lines; presumably the target value
 *                to look for (TODO confirm)
 * @param start   not used in the visible lines; presumably receives the node
 *                whose value equals {@code k} (TODO confirm)
 * @param adjList adjacency map that this method fills in
 */
public void buildGraph(TreeNode curNode, TreeNode parent, int k, TreeNode[] start, Map<TreeNode, List<TreeNode>> adjList) {
// Make sure both endpoints have a neighbour list before adding to it.
adjList.putIfAbsent(curNode, new ArrayList<>());
// NOTE(review): as shown, a null parent would be inserted as a map key and
// later added as a neighbour; confirm a guard exists in the elided lines.
adjList.putIfAbsent(parent, new ArrayList<>());
// Record the edge in both directions so the map can be walked up or down.
adjList.get(curNode).add(parent);
adjList.get(parent).add(curNode);
// Recurse into both children, with curNode as their parent.
buildGraph(curNode.left, curNode, k, start, adjList);
buildGraph(curNode.right, curNode, k, start, adjList);