- 查找命中所需的时间与被查找的键的长度成正比;
- 查找未命中只需检查若干个字符;
1.1 数据结构定义
- Trie树的根结点不保存字符(也可看成保存空字符"");
- Trie树的每个结点含有R条链接(R为字母表的大小),每条链接对应字母表中的一个字符,指向一棵以其所连子结点为根结点的子树;
- 每个键所关联的值保存在该键的最后一个字符所在的结点中(值为空的结点在Trie树中没有对应的键)。
/**
 * Symbol table with string keys, implemented as an R-way trie over the
 * extended ASCII alphabet (R = 256).
 *
 * @param <V> the type of the values
 */
public class TrieST<V> {
    private static final int R = 256;   // extended ASCII

    private Node root;  // root of trie
    private int n;      // number of keys in trie

    // R-way trie node: one value plus one link per character of the alphabet.
    private class Node {
        private V val;

        // Node here really means TrieST<V>.Node, a non-reifiable type, so
        // "new Node[R]" is a compile-time "generic array creation" error.
        // Creating an array of the raw (reifiable) TrieST.Node and casting is
        // the standard workaround; the array only ever holds nodes of this trie.
        @SuppressWarnings("unchecked")
        private Node[] children = (Node[]) new TrieST.Node[R];
    }
}
1.2 API定义
2.1 查找
- 从根结点开始搜索;
- 取得要查找关键词的第一个字母,并根据该字母选择对应的子树并转到该子树继续进行检索;
- 在相应的子树上,取得要查找关键词的第二个字母,并进一步选择对应的子树进行检索。
- 迭代过程……
- 在某个结点处,关键词的所有字母已被取出,则读取附在该结点上的信息,即完成查找。
查找结束时有以下三种情况:
- 键的尾字符对应的结点中保存的值为空;(未命中)
- 键的尾字符对应的结点中保存的值非空;(命中)
- 查找结束于一条空链接。(未命中)
public V get(String key) {
if (key == null)
throw new IllegalArgumentException("argument to get() is null");
Node x = get(root, key, 0);
if (x == null)
return null;
return x.val;
private Node get(Node x, String key, int d) {
if (x == null)
return null;
if (d == key.length())
return x;
char c = key.charAt(d);
return get(x.children[c], key, d+1);
2.2 插入
插入时先像查找一样沿Trie树向下,可能出现以下两种情况:
- 在到达键的尾字符之前就遇到了一个空链接:为键中剩余的每个字符创建新结点,并将值保存在尾字符对应的结点中;
- 在遇到空链接之前就到达了键的尾字符:将该结点的值设为新值(若键已存在,则覆盖旧值)。
public void put(String key, Value val) {
if (key == null)
throw new IllegalArgumentException("first argument to put() is null");
root = put(root, key, val, 0);
private Node put(Node x, String key, Value val, int d) {
if (x == null)
x = new Node();
if (d == key.length()) {
if (x.val == null)
x.val = val;
return x;
char c = key.charAt(d);
x.children[c] = put(x.children[c], key, val, d + 1);
return x;
2.3 删除
- 查找键所在结点,并将值置为null;
- 判断该结点是否含有指向子结点的非空链接:若有,则保留该结点;若没有(且其值为空),则删除该结点,并递归地对其父结点进行同样的判断。
public void delete(String key) {
if (key == null)
throw new IllegalArgumentException("argument to delete() is null");
root = delete(root, key, 0);
private Node delete(Node x, String key, int d) {
if (x == null) return null;
if (d == key.length()) {
if (x.val != null) n--;
x.val = null;
else {
char c = key.charAt(d);
x.next[c] = delete(x.next[c], key, d+1);
// remove subtrie rooted at x if it is completely empty
if (x.val != null) return x;
for (int c = 0; c < R; c++)
if (x.next[c] != null)
return x;
return null;
2.4 遍历
注:根结点相当于保存空字符 ""。
public Iterable<String> keys() {
return keysWithPrefix("");
public Iterable<String> keysWithPrefix(String prefix) {
Queue<String> results = new Queue<String>();
Node x = get(root, prefix, 0);
collect(x, new StringBuilder(prefix), results);
return results;
private void collect(Node x, StringBuilder prefix, Queue<String> results) {
if (x == null) return;
if (x.val != null)
for (char c = 0; c < R; c++) {
collect(x.next[c], prefix, results);
prefix.deleteCharAt(prefix.length() - 1);
2.5 完整源码
public class TrieST<Value> {
private static final int R = 256; // extended ASCII
private Node root; // root of trie
private int n; // number of keys in trie
// R-way trie node
private static class Node {
private Object val;
private Node[] next = new Node[R];
public TrieST() {
* Returns the value associated with the given key.
* @param key the key
* @return the value associated with the given key if the key is in the symbol table
* and {@code null} if the key is not in the symbol table
* @throws IllegalArgumentException if {@code key} is {@code null}
public Value get(String key) {
if (key == null) throw new IllegalArgumentException("argument to get() is null");
Node x = get(root, key, 0);
if (x == null) return null;
return (Value) x.val;
* Does this symbol table contain the given key?
* @param key the key
* @return {@code true} if this symbol table contains {@code key} and
* {@code false} otherwise
* @throws IllegalArgumentException if {@code key} is {@code null}
public boolean contains(String key) {
if (key == null) throw new IllegalArgumentException("argument to contains() is null");
return get(key) != null;
private Node get(Node x, String key, int d) {
if (x == null) return null;
if (d == key.length()) return x;
char c = key.charAt(d);
return get(x.next[c], key, d+1);
* Inserts the key-value pair into the symbol table, overwriting the old value
* with the new value if the key is already in the symbol table.
* If the value is {@code null}, this effectively deletes the key from the symbol table.
* @param key the key
* @param val the value
* @throws IllegalArgumentException if {@code key} is {@code null}
public void put(String key, Value val) {
if (key == null) throw new IllegalArgumentException("first argument to put() is null");
if (val == null) delete(key);
else root = put(root, key, val, 0);
private Node put(Node x, String key, Value val, int d) {
if (x == null) x = new Node();
if (d == key.length()) {
if (x.val == null) n++;
x.val = val;
return x;
char c = key.charAt(d);
x.next[c] = put(x.next[c], key, val, d+1);
return x;
* Returns the number of key-value pairs in this symbol table.
* @return the number of key-value pairs in this symbol table
public int size() {
return n;
* Is this symbol table empty?
* @return {@code true} if this symbol table is empty and {@code false} otherwise
public boolean isEmpty() {
return size() == 0;
* Returns all keys in the symbol table as an {@code Iterable}.
* To iterate over all of the keys in the symbol table named {@code st},
* use the foreach notation: {@code for (Key key : st.keys())}.
* @return all keys in the symbol table as an {@code Iterable}
public Iterable<String> keys() {
return keysWithPrefix("");
* Returns all of the keys in the set that start with {@code prefix}.
* @param prefix the prefix
* @return all of the keys in the set that start with {@code prefix},
* as an iterable
public Iterable<String> keysWithPrefix(String prefix) {
Queue<String> results = new Queue<String>();
Node x = get(root, prefix, 0);
collect(x, new StringBuilder(prefix), results);
return results;
private void collect(Node x, StringBuilder prefix, Queue<String> results) {
if (x == null) return;
if (x.val != null) results.enqueue(prefix.toString());
for (char c = 0; c < R; c++) {
collect(x.next[c], prefix, results);
prefix.deleteCharAt(prefix.length() - 1);
* Returns all of the keys in the symbol table that match {@code pattern},
* where . symbol is treated as a wildcard character.
* @param pattern the pattern
* @return all of the keys in the symbol table that match {@code pattern},
* as an iterable, where . is treated as a wildcard character.
public Iterable<String> keysThatMatch(String pattern) {
Queue<String> results = new Queue<String>();
collect(root, new StringBuilder(), pattern, results);
return results;
private void collect(Node x, StringBuilder prefix, String pattern, Queue<String> results) {
if (x == null) return;
int d = prefix.length();
if (d == pattern.length() && x.val != null)
if (d == pattern.length())
char c = pattern.charAt(d);
if (c == '.') {
for (char ch = 0; ch < R; ch++) {
collect(x.next[ch], prefix, pattern, results);
prefix.deleteCharAt(prefix.length() - 1);
else {
collect(x.next[c], prefix, pattern, results);
prefix.deleteCharAt(prefix.length() - 1);
* Returns the string in the symbol table that is the longest prefix of {@code query},
* or {@code null}, if no such string.
* @param query the query string
* @return the string in the symbol table that is the longest prefix of {@code query},
* or {@code null} if no such string
* @throws IllegalArgumentException if {@code query} is {@code null}
public String longestPrefixOf(String query) {
if (query == null) throw new IllegalArgumentException("argument to longestPrefixOf() is null");
int length = longestPrefixOf(root, query, 0, -1);
if (length == -1) return null;
else return query.substring(0, length);
// returns the length of the longest string key in the subtrie
// rooted at x that is a prefix of the query string,
// assuming the first d character match and we have already
// found a prefix match of given length (-1 if no such match)
private int longestPrefixOf(Node x, String query, int d, int length) {
if (x == null) return length;
if (x.val != null) length = d;
if (d == query.length()) return length;
char c = query.charAt(d);
return longestPrefixOf(x.next[c], query, d+1, length);
* Removes the key from the set if the key is present.
* @param key the key
* @throws IllegalArgumentException if {@code key} is {@code null}
public void delete(String key) {
if (key == null) throw new IllegalArgumentException("argument to delete() is null");
root = delete(root, key, 0);
private Node delete(Node x, String key, int d) {
if (x == null) return null;
if (d == key.length()) {
if (x.val != null) n--;
x.val = null;
else {
char c = key.charAt(d);
x.next[c] = delete(x.next[c], key, d+1);
// remove subtrie rooted at x if it is completely empty
if (x.val != null) return x;
for (int c = 0; c < R; c++)
if (x.next[c] != null)
return x;
return null;
* Unit tests the {@code TrieST} data type.
* @param args the command-line arguments
public static void main(String[] args) {
// build symbol table from standard input
TrieST<Integer> st = new TrieST<Integer>();
for (int i = 0; !StdIn.isEmpty(); i++) {
String key = StdIn.readString();
st.put(key, i);
// print results
if (st.size() < 100) {
for (String key : st.keys()) {
StdOut.println(key + " " + st.get(key));
for (String s : st.keysWithPrefix("shor"))
for (String s : st.keysThatMatch(".he.l."))
4.1 定义
/**
 * Symbol table with string keys, implemented as a ternary search trie (TST).
 *
 * @param <V> the type of the values
 */
public class TST<V> {
    private int n;          // size
    private Node<V> root;   // root of TST

    // Each node holds one character, three links and the value (if any) of
    // the key whose last character is this node.
    private static class Node<V> {
        private char c;                    // character
        private Node<V> left, mid, right;  // left, middle, and right subtries
        private V val;                     // value associated with the key ending here
    }
}
4.2 实现
4.2.1 查找
- 比较键的首字符与树的根结点字符的大小:如果较小,则选择左链接;如果较大,则选择右链接(这两种情况下仍用当前字符继续比较);如果相等,则选择中链接,并取键的下一个字符;
- 在选定的子树中递归地重复步骤1;
- 直到遇到一个空链接(未命中)或到达键的末尾;到达键的末尾时,若该结点的值非空则命中,否则未命中。
public V get(String key) {
if (key == null)
throw new IllegalArgumentException("calls get() with null argument");
if (key.length() == 0)
throw new IllegalArgumentException("key must have length >= 1");
Node<V> x = get(root, key, 0);
if (x == null)
return null;
return x.val;
// 在以x为根结点的树中,查找键key[d]
private Node<V> get(Node<V> x, String key, int d) {
if (x == null)
return null;
if (key.length() == 0)
throw new IllegalArgumentException("key must have length >= 1");
char c = key.charAt(d);
if (c < x.c)
return get(x.left, key, d);
else if (c > x.c)
return get(x.right, key, d);
else {
if (d < key.length() - 1)
return get(x.mid, key, d + 1);
return x;
4.2.2 插入
public void put(String key, V val) {
if (key == null) {
throw new IllegalArgumentException("calls put() with null key");
if (!contains(key))
root = put(root, key, val, 0);
private Node<V> put(Node<V> x, String key, V val, int d) {
char c = key.charAt(d);
if (x == null) {
x = new Node<V>();
x.c = c;
if (c < x.c)
x.left = put(x.left, key, val, d);
else if (c > x.c)
x.right = put(x.right, key, val, d);
else {
if (d < key.length() - 1)
x.mid = put(x.mid, key, val, d + 1);
x.val = val;
return x;
4.2.3 完整源码
public class TST<Value> {
private int n; // size
private Node<Value> root; // root of TST
private static class Node<Value> {
private char c; // character
private Node<Value> left, mid, right; // left, middle, and right subtries
private Value val; // value associated with string
public TST() {
* Returns the number of key-value pairs in this symbol table.
* @return the number of key-value pairs in this symbol table
public int size() {
return n;
* Does this symbol table contain the given key?
* @param key the key
* @return {@code true} if this symbol table contains {@code key} and
* {@code false} otherwise
* @throws IllegalArgumentException if {@code key} is {@code null}
public boolean contains(String key) {
if (key == null) {
throw new IllegalArgumentException("argument to contains() is null");
return get(key) != null;
* Returns the value associated with the given key.
* @param key the key
* @return the value associated with the given key if the key is in the symbol table
* and {@code null} if the key is not in the symbol table
* @throws IllegalArgumentException if {@code key} is {@code null}
public Value get(String key) {
if (key == null) {
throw new IllegalArgumentException("calls get() with null argument");
if (key.length() == 0) throw new IllegalArgumentException("key must have length >= 1");
Node<Value> x = get(root, key, 0);
if (x == null) return null;
return x.val;
// return subtrie corresponding to given key
private Node<Value> get(Node<Value> x, String key, int d) {
if (x == null) return null;
if (key.length() == 0) throw new IllegalArgumentException("key must have length >= 1");
char c = key.charAt(d);
if (c < x.c) return get(x.left, key, d);
else if (c > x.c) return get(x.right, key, d);
else if (d < key.length() - 1) return get(x.mid, key, d+1);
else return x;
* Inserts the key-value pair into the symbol table, overwriting the old value
* with the new value if the key is already in the symbol table.
* If the value is {@code null}, this effectively deletes the key from the symbol table.
* @param key the key
* @param val the value
* @throws IllegalArgumentException if {@code key} is {@code null}
public void put(String key, Value val) {
if (key == null) {
throw new IllegalArgumentException("calls put() with null key");
if (!contains(key)) n++;
root = put(root, key, val, 0);
private Node<Value> put(Node<Value> x, String key, Value val, int d) {
char c = key.charAt(d);
if (x == null) {
x = new Node<Value>();
x.c = c;
if (c < x.c) x.left = put(x.left, key, val, d);
else if (c > x.c) x.right = put(x.right, key, val, d);
else if (d < key.length() - 1) x.mid = put(x.mid, key, val, d+1);
else x.val = val;
return x;
* Returns the string in the symbol table that is the longest prefix of {@code query},
* or {@code null}, if no such string.
* @param query the query string
* @return the string in the symbol table that is the longest prefix of {@code query},
* or {@code null} if no such string
* @throws IllegalArgumentException if {@code query} is {@code null}
public String longestPrefixOf(String query) {
if (query == null) {
throw new IllegalArgumentException("calls longestPrefixOf() with null argument");
if (query.length() == 0) return null;
int length = 0;
Node<Value> x = root;
int i = 0;
while (x != null && i < query.length()) {
char c = query.charAt(i);
if (c < x.c) x = x.left;
else if (c > x.c) x = x.right;
else {
if (x.val != null) length = i;
x = x.mid;
return query.substring(0, length);
* Returns all keys in the symbol table as an {@code Iterable}.
* To iterate over all of the keys in the symbol table named {@code st},
* use the foreach notation: {@code for (Key key : st.keys())}.
* @return all keys in the symbol table as an {@code Iterable}
public Iterable<String> keys() {
Queue<String> queue = new Queue<String>();
collect(root, new StringBuilder(), queue);
return queue;
* Returns all of the keys in the set that start with {@code prefix}.
* @param prefix the prefix
* @return all of the keys in the set that start with {@code prefix},
* as an iterable
* @throws IllegalArgumentException if {@code prefix} is {@code null}
public Iterable<String> keysWithPrefix(String prefix) {
if (prefix == null) {
throw new IllegalArgumentException("calls keysWithPrefix() with null argument");
Queue<String> queue = new Queue<String>();
Node<Value> x = get(root, prefix, 0);
if (x == null) return queue;
if (x.val != null) queue.enqueue(prefix);
collect(x.mid, new StringBuilder(prefix), queue);
return queue;
// all keys in subtrie rooted at x with given prefix
private void collect(Node<Value> x, StringBuilder prefix, Queue<String> queue) {
if (x == null) return;
collect(x.left, prefix, queue);
if (x.val != null) queue.enqueue(prefix.toString() + x.c);
collect(x.mid, prefix.append(x.c), queue);
prefix.deleteCharAt(prefix.length() - 1);
collect(x.right, prefix, queue);
* Returns all of the keys in the symbol table that match {@code pattern},
* where . symbol is treated as a wildcard character.
* @param pattern the pattern
* @return all of the keys in the symbol table that match {@code pattern},
* as an iterable, where . is treated as a wildcard character.
public Iterable<String> keysThatMatch(String pattern) {
Queue<String> queue = new Queue<String>();
collect(root, new StringBuilder(), 0, pattern, queue);
return queue;
private void collect(Node<Value> x, StringBuilder prefix, int i, String pattern, Queue<String> queue) {
if (x == null) return;
char c = pattern.charAt(i);
if (c == '.' || c < x.c) collect(x.left, prefix, i, pattern, queue);
if (c == '.' || c == x.c) {
if (i == pattern.length() - 1 && x.val != null) queue.enqueue(prefix.toString() + x.c);
if (i < pattern.length() - 1) {
collect(x.mid, prefix.append(x.c), i+1, pattern, queue);
prefix.deleteCharAt(prefix.length() - 1);
if (c == '.' || c > x.c) collect(x.right, prefix, i, pattern, queue);
* Unit tests the {@code TST} data type.
* @param args the command-line arguments
public static void main(String[] args) {
// build symbol table from standard input
TST<Integer> st = new TST<Integer>();
for (int i = 0; !StdIn.isEmpty(); i++) {
String key = StdIn.readString();
st.put(key, i);
// print results
if (st.size() < 100) {
for (String key : st.keys()) {
StdOut.println(key + " " + st.get(key));
for (String s : st.keysWithPrefix("shor"))
for (String s : st.keysThatMatch(".he.l."))
4.3 性能分析
- 时间复杂度:在由N个随机字符串键构造的TST中,查找未命中平均需要约 $\ln N$ 次字符比较。
- 空间复杂度:由N个平均长度为w的字符串键构造的TST中,链接总数在 $3N$ 到 $3Nw$ 之间。