用户树查找优化

1
2
3
4
5
6
7
8
9
    A                   B
/ \ /|\
C D E F G
/ \ / \
H I J K
/ \
L M
/ \
N O
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
id  pid name age  email
1 0 A 18 A@baomidou.com
2 0 B 20 B@baomidou.com
3 1 C 28 C@baomidou.com
4 1 D 21 D@baomidou.com
5 2 E 24 E@baomidou.com
6 2 F 18 F@baomidou.com
7 2 G 10 G@baomidou.com
8 3 H 35 H@baomidou.com
9 3 I 21 I@baomidou.com
10 5 J 18 J@baomidou.com
11 5 K 23 K@baomidou.com
12 9 L 18 L@baomidou.com
13 9 M 33 M@baomidou.com
14 13 N 22 N@baomidou.com
15 13 O 11 O@baomidou.com

背景

节点和父节点之间只有 id 和 pid 上有关联;没有层级字段,没有当前所在的树高度字段,如何在不改变表的情况下,获取所有子节点?;

如果数据库可以重新设计的话,可以用下面的链接介绍的方式来设计:
在数据库中存储一棵树,实现无限级分类 - 个人文章 - SegmentFault 思否

思路

递归的方式

找到自己的子节点,再找子子节点,再找子子子节点…
这种方式虽能实现,但是如果节点关系复杂越来,因为要每次都要从 sql 中建立连接并查询,就会越来越慢
(递归的方式 200 多个后代要 2s 左右才能查出来; 在同样的数据下,用下面介绍的迭代的方式来查不到 100ms)。

迭代的方式

  1. 根据 id 查找后代节点时,拼成集合的方式 ids=set(id),根据 select id,pid from user where pid in (ids) 获取到所有的子节点(拼成实体对象:如具有 id 和 pid 属性的 ParentChild),加入到 ParentChild 列表中;
  2. 根据 1 获取的 id 拼成一组新的 ids 获取第二波数据的所有子节点;以此类推,直到 ids 为空为止,说明再无后代了;
  3. 根据 ParentChild 列表来构建 map 结构(其中的 key 是节点 ID,value 是直属子节点 ID 列表)
  4. 通过在第 3 步获取的 map 缓存,来进行更多的操作,如:查找直属子节点、 查找所有后代;

通过迭代的方式,一波一波地获取后代再数据库中进行查询,如果有 10 代,只需要从数据库中,查 10 次即可
(用递归的话,可能得成千上万次了),当然如果用户量特别多的时候,就不要用这种只有 id 和 pid 的方式了,得重新设计表了;
在数据库中存储一棵树,实现无限级分类 - 个人文章 - SegmentFault 思否

如何查找直属子节点

根据节点 id,直接从 map 缓存中获取即可;

如何查找所有后代

第 1 步: 建 1 个 set 集合,用来存储所有的后代;
第 2 步: 根据节点 id,从 map 中先找到直属子节点列表,加入到 set 中;
第 3 步: 再根据第 2 步得到的子节点列表,遍历这些子节点,从 map 中获取这些子节点的子节点(递归到第 2 步)

代码实现

service 层实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
@Service
public class UserServiceImpl extends ServiceImpl<UserMapper, User> implements UserService {

@Override
public List<User> listDirectChild(Integer userId) {
AbsUserTreeHelper helper = newUserTreeHelper(userId);
return getUsers(helper.getDirectChildIds(userId));
}

private AbsUserTreeHelper newUserTreeHelper(Integer userId) {
checkUserId(userId);
AbsUserTreeHelper helper = new UserTreeHelper(this.baseMapper);
helper.init(userId);
return helper;
}

private void checkUserId(Integer userId) {
if (userId == null) {
throw new ParamException("参数无效:userId为" + userId);
}
}
}
1
2
3
4
5
6
7
8
9
10
11
12
public class UserTreeHelper extends AbsUserTreeHelper {
UserMapper userMapper;

public UserTreeHelper(UserMapper userMapper) {
this.userMapper = userMapper;
}

@Override
protected List<ParentChild> listParentChild(List<Integer> userIdList) {
return this.userMapper.listParentChild(userIdList);
}
}

mapper 层实现

1
2
3
4
public interface UserMapper extends BaseMapper<User> {

List<ParentChild> listParentChild(@Param("userIdList") List<Integer> userIdList);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.lyloou.demo.mapper.UserMapper">

<!-- 通用查询映射结果 -->
<resultMap id="BaseResultMap" type="com.lyloou.demo.model.User">
<id column="id" property="id"/>
<result column="name" property="name"/>
<result column="age" property="age"/>
<result column="email" property="email"/>
</resultMap>

<select id="listParentChild" resultType="com.lyloou.demo.model.ParentChild">
select id as childId, pid as parentId from t_user where pid in
<foreach item="item" index="index" collection="userIdList"
open="(" separator="," close=")">
#{item}
</foreach>
</select>

</mapper>

UserTreeHelper 源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// UserTreeHelper.java

/**
* 获取用户的直接下属,或全部下属
*/
public class UserTreeHelper {

private Map<Integer, List<Integer>> parentToChildrenMap;

/**
* 获取map中的所有用户ID
*
* @return 用户ID集合
*/
public Set<Integer> getAllIds() {
checkInit();
Set<Integer> allIds = new HashSet<>();
Set<Integer> keySet = parentToChildrenMap.keySet();
for (Integer key : keySet) {
allIds.add(key);
allIds.addAll(getAllChildIds(key));
}
return allIds;
}

/**
* 获取树上的全部用户ID列表(不包含自己)
*
* @param userId 用户ID
* @return 全部用户ID
*/
public Set<Integer> getAllChildIds(Integer userId) {
checkInit();
return getAllChildIds(userId, null);
}

private void checkInit() {
if (parentToChildrenMap == null) {
throw new RuntimeException("没有调用初始化方法,请调用init方法初始化");
}
}

private Set<Integer> getAllChildIds(Integer userId, Set<Integer> idSet) {
if (idSet == null) {
idSet = new HashSet<>();
}
List<Integer> idList = parentToChildrenMap.getOrDefault(userId, Collections.emptyList());
if (idList.isEmpty()) {
return idSet;
}
idSet.addAll(idList);
for (Integer id : idList) {
getAllChildIds(id, idSet);
}
return idSet;
}

/**
* 获取直属用户ID列表 (不包含自己)
*
* @param userId 用户的ID
* @return 直属用户ID列表
*/
public Set<Integer> getDirectChildIds(Integer userId) {
checkInit();
return new HashSet<>(parentToChildrenMap.getOrDefault(userId, new ArrayList<>()));
}


/**
* 根据用户ID来初始化 Helper
*
* @param userId 用户ID
*/
public void init(Integer userId) {
init(Collections.singletonList(userId));
}

/**
* 把一组用户和它们的后代缓存起来
* 后面可以通过map.get(userId) 的方式获取userId对应的直接子节点
*
* @param userIdList 用户列表
*/
public void init(List<Integer> userIdList) {
parentToChildrenMap = new HashMap<>();
iterateUser(userIdList);
}

/**
* 通过迭代的方式重新实现,一波一波地获取,而不是一个一个地获取
*
* @param userIdList 用户列表
*/
private void iterateUser(List<Integer> userIdList) {
Set<ParentChild> allParentChildren = new HashSet<>();
List<ParentChild> parentChildren = listParentChild(userIdList);
while (!parentChildren.isEmpty()) {
allParentChildren.addAll(parentChildren);

// 获取这一波的id
List<Integer> newRoundIds = parentChildren.stream().map(ParentChild::getChildId).collect(Collectors.toList());
parentChildren = listParentChild(newRoundIds);
}

for (ParentChild parentChild : allParentChildren) {
List<Integer> oldList = parentToChildrenMap.getOrDefault(parentChild.getParentId(), new ArrayList<>());
oldList.add(parentChild.getChildId());
parentToChildrenMap.put(parentChild.getParentId(), oldList);
}

}

public List<ParentChild> listParentChild(List<Integer> userIdList){
//伪代码 List<ParentChild> list = select id,pid from user where pid in (userIdList);
}

}

// ParentChild.java
@Data
public class ParentChild {
private Integer id;
private Integer pid;
}

// Main.java
public class Main{
public static void main(String[] args) {
UserTreeHelper helper = new UserTreeHelper();
helper.init(1);
List<Integer> ids = helper.getAllChildIds(1);
System.out.println(ids);
}
}

更多可查看源码实现

https://github.com/lyloou/spring-boot-web/blob/v1.1.0/src/main/java/com/lyloou/demo/service/helper/AbsUserTreeHelper.java