public int countNodes(TreeNode root) {
int leftHeight = getLeftMostHeight(root.left);
int rightHeight = getLeftMostHeight(root.right);
if (leftHeight == rightHeight) {
count += (1 << leftHeight) - 1;
count += countNodes(root.right);
count += (1 << rightHeight) - 1;
count += countNodes(root.left);
public int getLeftMostHeight(TreeNode node) {