System instructions |
You are an experienced programmer, well versed in common algorithms. Your task is to write thorough unit tests for a binary search tree implementation to verify that the class is correct. Do not offer explanations; just write the unit tests.
|
I have an implementation of a binary search tree, but I'm not sure whether it is correct. To confirm this, please write unit tests that exercise every method of this class except the private ones. Your unit tests must include edge cases such as operations on an empty tree and handling of duplicate values. Additionally, your code should contain a main function.
bst.py
from dataclasses import dataclass
from typing import Optional
@dataclass
class TreeNode:
    """A single node of a binary search tree: a value plus optional children."""

    value: int
    left: Optional['TreeNode'] = None
    right: Optional['TreeNode'] = None


class BST:
    """Unbalanced binary search tree of ints.

    Duplicates are allowed: a value equal to a node's value is placed in that
    node's right subtree. Every traversal is iterative, so degenerate trees
    (e.g. built from sorted input) deeper than Python's recursion limit do not
    raise ``RecursionError``.
    """

    def __init__(self) -> None:
        self.root: Optional[TreeNode] = None

    def insert(self, value: int) -> None:
        """Insert *value* into the tree; ties go to the right subtree."""
        if self.root is None:
            self.root = TreeNode(value)
        else:
            self._insert_helper(self.root, value)

    def _insert_helper(self, node: TreeNode, value: int) -> None:
        """Walk down from *node* and attach a new leaf holding *value*."""
        current = node
        while True:
            if value < current.value:
                if current.left is None:
                    current.left = TreeNode(value)
                    return
                current = current.left
            else:
                # Equal values descend right, so duplicates form a right chain.
                if current.right is None:
                    current.right = TreeNode(value)
                    return
                current = current.right

    def find(self, value: int) -> Optional[TreeNode]:
        """Return the first node equal to *value* on the search path, or None.

        With duplicates present, this is the occurrence closest to the root.
        """
        return self._find_helper(self.root, value)

    def _find_helper(self, node: Optional[TreeNode], value: int) -> Optional[TreeNode]:
        """Iteratively search the subtree rooted at *node* for *value*."""
        current = node
        while current is not None:
            if current.value == value:
                return current
            current = current.left if value < current.value else current.right
        return None

    def _lookup_min(self, node: TreeNode) -> TreeNode:
        """Return the leftmost (smallest) node of the subtree rooted at *node*."""
        current = node
        while current.left is not None:
            current = current.left
        return current

    def print_inorder(self) -> None:
        """Print all values in ascending order, each followed by a space, then a newline."""
        self._print_inorder_helper(self.root)
        print()

    def _print_inorder_helper(self, node: Optional[TreeNode]) -> None:
        """Print the subtree rooted at *node* in order, using an explicit stack."""
        stack = []  # ancestors whose left subtree is being visited
        current = node
        while stack or current is not None:
            # Descend as far left as possible, remembering the path.
            while current is not None:
                stack.append(current)
                current = current.left
            current = stack.pop()
            print(current.value, end=' ')
            current = current.right
Response
# test_bst.py
import unittest
import io
from contextlib import redirect_stdout
# Assuming bst.py is in the same directory
from bst import BST, TreeNode
class TestBST(unittest.TestCase):
    """Unit tests for the public API of BST: __init__, insert, find, print_inorder."""

    # Balanced seven-node tree used by several tests:
    #            10
    #          /    \
    #         5      15
    #        / \    /  \
    #       3   7  12   18
    BALANCED_VALUES = (10, 5, 15, 3, 7, 12, 18)

    def setUp(self):
        """Set up a new, empty BST for each test."""
        self.bst = BST()

    # Helpers (not collected as tests because they do not start with "test")
    def _insert_all(self, values):
        """Insert every value of *values* into self.bst, in order."""
        for val in values:
            self.bst.insert(val)

    def _capture_inorder(self):
        """Run print_inorder() and return everything it wrote to stdout."""
        buffer = io.StringIO()
        with redirect_stdout(buffer):
            self.bst.print_inorder()
        return buffer.getvalue()

    # __init__ tests
    def test_initialization(self):
        """Test that a new BST is empty (root is None)."""
        self.assertIsNone(self.bst.root)

    def test_instances_are_independent(self):
        """Test that inserting into one tree does not affect another tree."""
        other = BST()
        self.bst.insert(1)
        self.assertIsNone(other.root)
        self.assertIsNotNone(self.bst.root)

    # insert() tests
    def test_insert_into_empty_tree(self):
        """Test inserting a single node into an empty tree."""
        self.bst.insert(10)
        self.assertIsNotNone(self.bst.root)
        self.assertIsInstance(self.bst.root, TreeNode)
        self.assertEqual(self.bst.root.value, 10)
        self.assertIsNone(self.bst.root.left)
        self.assertIsNone(self.bst.root.right)

    def test_insert_returns_none(self):
        """Test that insert returns None both for the root and for later nodes."""
        self.assertIsNone(self.bst.insert(10))
        self.assertIsNone(self.bst.insert(5))

    def test_insert_smaller_value(self):
        """Test inserting a smaller value, which should go to the left."""
        self.bst.insert(10)
        self.bst.insert(5)
        self.assertIsNotNone(self.bst.root.left)
        self.assertIsNone(self.bst.root.right)
        self.assertEqual(self.bst.root.left.value, 5)

    def test_insert_larger_value(self):
        """Test inserting a larger value, which should go to the right."""
        self.bst.insert(10)
        self.bst.insert(15)
        self.assertIsNone(self.bst.root.left)
        self.assertIsNotNone(self.bst.root.right)
        self.assertEqual(self.bst.root.right.value, 15)

    def test_insert_multiple_values(self):
        """Test inserting multiple values to create a more complex tree."""
        self._insert_all(self.BALANCED_VALUES)
        # Check root
        self.assertEqual(self.bst.root.value, 10)
        # Check left subtree
        self.assertEqual(self.bst.root.left.value, 5)
        self.assertEqual(self.bst.root.left.left.value, 3)
        self.assertEqual(self.bst.root.left.right.value, 7)
        # Check right subtree
        self.assertEqual(self.bst.root.right.value, 15)
        self.assertEqual(self.bst.root.right.left.value, 12)
        self.assertEqual(self.bst.root.right.right.value, 18)

    def test_insert_negative_and_zero_values(self):
        """Test that zero and negative values are ordered correctly."""
        self._insert_all([0, -5, 5, -10, -1])
        self.assertEqual(self.bst.root.value, 0)
        self.assertEqual(self.bst.root.left.value, -5)
        self.assertEqual(self.bst.root.right.value, 5)
        self.assertEqual(self.bst.root.left.left.value, -10)
        self.assertEqual(self.bst.root.left.right.value, -1)

    def test_insert_duplicate_value(self):
        """Test inserting a duplicate value. Per implementation, it should go to the right."""
        self.bst.insert(10)
        self.bst.insert(5)
        self.bst.insert(10)  # Duplicate
        self.assertEqual(self.bst.root.value, 10)
        self.assertIsNotNone(self.bst.root.right)
        self.assertEqual(self.bst.root.right.value, 10)
        self.assertIsNone(self.bst.root.right.left)
        self.assertIsNone(self.bst.root.right.right)

    def test_insert_duplicate_of_inner_node(self):
        """Test that a duplicate of a non-root value goes to that node's right subtree."""
        self._insert_all([10, 5, 15, 5])
        first_five = self.bst.root.left
        self.assertEqual(first_five.value, 5)
        self.assertIsNone(first_five.left)
        self.assertIsNotNone(first_five.right)
        self.assertEqual(first_five.right.value, 5)
        self.assertIsNone(self.bst.root.right.left)

    def test_insert_multiple_duplicates(self):
        """Test inserting the same value multiple times builds a right-leaning chain."""
        self._insert_all([10, 10, 10])
        self.assertEqual(self.bst.root.value, 10)
        self.assertIsNone(self.bst.root.left)
        self.assertIsNotNone(self.bst.root.right)
        self.assertEqual(self.bst.root.right.value, 10)
        self.assertIsNone(self.bst.root.right.left)
        self.assertIsNotNone(self.bst.root.right.right)
        self.assertEqual(self.bst.root.right.right.value, 10)
        self.assertIsNone(self.bst.root.right.right.left)
        self.assertIsNone(self.bst.root.right.right.right)

    # find() tests
    def test_find_in_empty_tree(self):
        """Test find on an empty tree, should return None and leave the tree empty."""
        self.assertIsNone(self.bst.find(10))
        self.assertIsNone(self.bst.root)

    def test_find_root_node(self):
        """Test finding the root node in a tree with one element."""
        self.bst.insert(10)
        found_node = self.bst.find(10)
        self.assertIsNotNone(found_node)
        self.assertIs(found_node, self.bst.root)
        self.assertEqual(found_node.value, 10)

    def test_find_missing_value_in_single_node_tree(self):
        """Test find for values smaller and larger than the only node."""
        self.bst.insert(10)
        self.assertIsNone(self.bst.find(5))
        self.assertIsNone(self.bst.find(15))

    def test_find_existing_node(self):
        """Test finding an existing node in a larger tree."""
        self._insert_all(self.BALANCED_VALUES)
        found_node = self.bst.find(7)
        self.assertIsNotNone(found_node)
        self.assertEqual(found_node.value, 7)
        self.assertIs(found_node, self.bst.root.left.right)

    def test_find_non_existing_node(self):
        """Test finding a node that does not exist in a populated tree."""
        self._insert_all(self.BALANCED_VALUES)
        self.assertIsNone(self.bst.find(99))
        self.assertIsNone(self.bst.find(0))
        self.assertIsNone(self.bst.find(13))

    def test_find_negative_value(self):
        """Test finding a negative value deep in the tree."""
        self._insert_all([0, -5, 5, -10, -1])
        found_node = self.bst.find(-1)
        self.assertIs(found_node, self.bst.root.left.right)
        self.assertIsNone(self.bst.find(-2))

    def test_find_min_value(self):
        """Test finding the minimum value in the tree."""
        self._insert_all(self.BALANCED_VALUES)
        min_node = self.bst.find(3)
        self.assertIsNotNone(min_node)
        self.assertEqual(min_node.value, 3)
        self.assertIs(min_node, self.bst.root.left.left)

    def test_find_max_value(self):
        """Test finding the maximum value in the tree."""
        self._insert_all(self.BALANCED_VALUES)
        max_node = self.bst.find(18)
        self.assertIsNotNone(max_node)
        self.assertEqual(max_node.value, 18)
        self.assertIs(max_node, self.bst.root.right.right)

    def test_find_with_duplicate_value(self):
        """Test find with duplicates. It should find the first occurrence encountered."""
        self._insert_all([10, 15, 10])  # second 10 lands at root.right.left
        found_node = self.bst.find(10)
        self.assertIsNotNone(found_node)
        # It should find the root node, not the duplicate in the right subtree
        self.assertIs(found_node, self.bst.root)

    # print_inorder() tests
    def test_print_inorder_empty_tree(self):
        """Test inorder printing on an empty tree. Should print a newline."""
        self.assertEqual(self._capture_inorder(), "\n")

    def test_print_inorder_single_node(self):
        """Test inorder printing on a tree with a single node."""
        self.bst.insert(10)
        self.assertEqual(self._capture_inorder(), "10 \n")

    def test_print_inorder_multiple_nodes(self):
        """Test inorder printing on a complex tree, should produce sorted output."""
        self._insert_all(self.BALANCED_VALUES)
        self.assertEqual(self._capture_inorder(), "3 5 7 10 12 15 18 \n")

    def test_print_inorder_negative_values(self):
        """Test inorder printing with negative values and zero."""
        self._insert_all([0, -5, 5, -10, -1])
        self.assertEqual(self._capture_inorder(), "-10 -5 -1 0 5 \n")

    def test_print_inorder_with_duplicates(self):
        """Test inorder printing with duplicate values."""
        self._insert_all([10, 5, 15, 5, 10])
        # Inorder traversal of the resulting tree: 5 -> 5 -> 10 -> 10 -> 15
        self.assertEqual(self._capture_inorder(), "5 5 10 10 15 \n")

    def test_print_inorder_all_duplicates(self):
        """Test inorder printing when every inserted value is the same."""
        self._insert_all([7, 7, 7])
        self.assertEqual(self._capture_inorder(), "7 7 7 \n")

    def test_print_inorder_left_skewed_tree(self):
        """Test inorder printing on a left-skewed tree."""
        self._insert_all([10, 8, 6, 4, 2])
        self.assertEqual(self._capture_inorder(), "2 4 6 8 10 \n")

    def test_print_inorder_right_skewed_tree(self):
        """Test inorder printing on a right-skewed tree."""
        self._insert_all([2, 4, 6, 8, 10])
        self.assertEqual(self._capture_inorder(), "2 4 6 8 10 \n")

    def test_print_inorder_is_repeatable(self):
        """Test that printing does not modify the tree, so repeated calls agree."""
        self._insert_all(self.BALANCED_VALUES)
        first = self._capture_inorder()
        second = self._capture_inorder()
        self.assertEqual(first, "3 5 7 10 12 15 18 \n")
        self.assertEqual(first, second)
def main(argv=None):
    """Run every test case defined in this module.

    argv: command-line style argument list for unittest (argv[0] is the
    program name); None means use sys.argv. Exits the process with a
    non-zero status when any test fails, so CI can detect failures.
    """
    unittest.main(module=__name__, argv=argv)


if __name__ == '__main__':
    main()
|
Model: |
gemini-2.5-pro |
Temperature: |
0.2 |
Max output tokens: |
8192 |
TopK: |
40.0 |
TopP: |
0.95 |
|