涉及类
BeanFactory、FactoryBean、BeanDefinitionRegistryPostProcessor、JDK动态代理、ClassPathBeanDefinitionScanner
使用spring提供的包扫描工具获取需要代理的mapper接口的BeanDefinition集合
for循环设置每个BeanDefinition的beanClass为MapperFactoryBean.class
spring ioc 在创建对象时发现该对象的beanClass是FactoryBean的实现类
则通过FactoryBean接口的getObject()方法获取对象
MyBatis 在 getObject() 方法中利用 JDK 动态代理创建代理对象并返回。
简单模拟代码
自定义注解
/**
 * Marker annotation for mapper interfaces that should be replaced by a
 * dynamic proxy at container start-up.
 *
 * <p>It is retained at runtime so that it can be read reflectively.
 */
@Target({ ElementType.TYPE, ElementType.METHOD, ElementType.FIELD, ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
@Inherited
@Documented
public @interface MyMapper {

    /**
     * Free-form value attached to the annotated element.
     *
     * @return the configured value, or the empty string when not set
     */
    String value() default "";
}
被代理接口类
/**
 * Sample mapper interface that is meant to be proxied rather than implemented by hand.
 */
@MyMapper(value = "student")
public interface MyStudentMapper {

    /**
     * Looks up a single entry by its id.
     *
     * @param id the id to look up
     * @return the result as a string
     */
    String selectById(int id);
}
通知类
/**
 * Invocation handler behind every mapper proxy.
 *
 * <p>It keeps a small table from method name to SQL prefix for one table and,
 * when a mapped method is called, prints the resulting SQL and returns a fixed
 * placeholder result. Methods without a mapping yield {@code null}.
 */
public class MyMapperProxy implements InvocationHandler {

    /** Method name -> SQL prefix; populated once in the constructor. */
    private final Map<String, String> nameSqlMap;

    /**
     * Builds the SQL table for the given table name.
     *
     * @param tableName name of the table the generated statements target
     */
    public MyMapperProxy(String tableName) {
        // A real implementation would receive the database connection, table mappings, etc. here.
        // NOTE(review): the SQL is assembled by string concatenation, which is only
        // acceptable for this simulation; real code must use parameterized statements.
        this.nameSqlMap = new HashMap<>();
        this.nameSqlMap.put("selectById"
                , "select * from " + tableName + " where id = ");
        this.nameSqlMap.put("selectOne"
                , "select * from " + tableName + " limit 1 ");
    }

    /**
     * Entry point called by the dynamic proxy for every method invocation.
     *
     * <p>{@code equals}, {@code hashCode} and {@code toString} are answered here
     * directly. Without that they would fall through to {@link #execute} and
     * return {@code null}, which fails when the proxy unboxes the result to
     * {@code boolean} / {@code int}.
     *
     * @param proxy  the proxy instance the method was invoked on
     * @param method the invoked interface method
     * @param args   the call arguments; {@code null} for a no-arg method
     * @return the method result
     * @throws Throwable declared by {@link InvocationHandler#invoke}
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        if (method.getDeclaringClass() == Object.class) {
            switch (method.getName()) {
                case "equals":
                    // Identity semantics: a proxy is only equal to itself.
                    return proxy == args[0];
                case "hashCode":
                    return System.identityHashCode(proxy);
                case "toString":
                    return proxy.getClass().getName() + "@"
                            + Integer.toHexString(System.identityHashCode(proxy));
                default:
                    // Any other Object method keeps the previous behaviour.
                    break;
            }
        }
        return execute(method, args);
    }

    /**
     * Resolves the SQL for the method, prints it and returns a placeholder result.
     *
     * @param method the invoked method; its name selects the SQL
     * @param args   the call arguments, may be {@code null} or empty
     * @return the placeholder result, or {@code null} when the method has no SQL mapping
     */
    public Object execute(Method method, Object[] args) {
        String sql = getSql(method, args);
        if (sql == null) {
            return null;
        }
        System.out.println("执行sql ======> " + sql);
        return "这是查到的数据";
    }

    /**
     * Looks up the SQL prefix for the method and appends the first argument when there is one.
     *
     * @return the SQL text, or {@code null} when the method name is not mapped
     */
    private String getSql(Method method, Object[] args) {
        String sql = nameSqlMap.get(method.getName());
        if (sql == null) {
            return null;
        }
        // Proxies pass null for no-arg methods (e.g. selectOne); nothing to append then.
        if (args == null || args.length == 0) {
            return sql;
        }
        return sql + args[0];
    }
}
FactoryBean类
/**
 * {@code FactoryBean} that hands out a dynamic proxy for a mapper interface
 * instead of an instance of the interface itself.
 *
 * @param <T> the mapper interface type
 */
public class MyProxyFactoryBean<T> implements FactoryBean<T> {

    /** Helper that actually creates the proxy instance. */
    private final MyMapperProxyFactory<?> factory;
    /** The interface the produced proxy implements. */
    private final Class<T> mapperInterface;
    /** Class loader used to define the proxy class. */
    private final ClassLoader loader;
    /** Invocation handler that receives every call made on the proxy. */
    private final MyMapperProxy proxy;

    /**
     * @param factory         helper used to create the proxy
     * @param mapperInterface interface to be proxied
     * @param loader          class loader for the proxy class
     * @param proxy           handler that backs the proxy
     */
    public MyProxyFactoryBean(MyMapperProxyFactory<?> factory,
                              Class<T> mapperInterface,
                              ClassLoader loader,
                              MyMapperProxy proxy) {
        this.factory = factory;
        this.mapperInterface = mapperInterface;
        this.loader = loader;
        this.proxy = proxy;
    }

    /**
     * Creates the proxy for {@code mapperInterface}.
     *
     * @return a proxy implementing the mapper interface
     */
    @Override
    public T getObject() {
        Object mapper = factory.gerProxy(this.loader, this.mapperInterface, this.proxy);
        // Class.cast gives a checked conversion instead of an unchecked (T) cast.
        return this.mapperInterface.cast(mapper);
    }

    /**
     * @return the mapper interface as the type of object this factory produces
     */
    @Override
    public Class<?> getObjectType() {
        return this.mapperInterface;
    }
}
代理生成工具类
public class MyMapperProxyFactory<T> {
public T gerProxy(ClassLoader loader,
Class<?> interfaces,
InvocationHandler proxy) {
return (T) Proxy.newProxyInstance(loader, new Class[]{interfaces}, proxy);
}
}
扫描包工具类
/**
 * Classpath scanner that collects interfaces annotated with {@link MyMapper}
 * and rewrites their bean definitions so that the beans are created through
 * {@link MyProxyFactoryBean}.
 */
public class MyBeanDefinitionScanner extends ClassPathBeanDefinitionScanner {

    /**
     * @param registry the registry the scanned bean definitions are registered with
     */
    public MyBeanDefinitionScanner(BeanDefinitionRegistry registry) {
        // Second argument false: skip the default filters, only MyMapper is included.
        super(registry, false);
        addIncludeFilter(new AnnotationTypeFilter(MyMapper.class));
    }

    /**
     * Runs the regular scan and then post-processes whatever was found.
     *
     * @param basePackages the packages to scan
     * @return the holders returned by the parent scan
     */
    @Override
    protected Set<BeanDefinitionHolder> doScan(String... basePackages) {
        Set<BeanDefinitionHolder> scanned = super.doScan(basePackages);
        if (!scanned.isEmpty()) {
            postProcessBeanDefinitions(scanned);
        }
        return scanned;
    }

    /**
     * Turns each scanned definition into a {@link MyProxyFactoryBean} definition
     * whose constructor arguments are, by index: the proxy factory, the mapper
     * interface, the class loader and the invocation handler.
     */
    private void postProcessBeanDefinitions(Set<BeanDefinitionHolder> beanDefinitionHolders) {
        for (BeanDefinitionHolder holder : beanDefinitionHolders) {
            AbstractBeanDefinition definition = (AbstractBeanDefinition) holder.getBeanDefinition();
            String mapperClassName = definition.getBeanClassName();
            ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
            Class<?> mapperInterface = loadMapperInterface(mapperClassName, contextLoader);

            MyMapper mapperAnnotation = mapperInterface.getAnnotation(MyMapper.class);
            if (mapperAnnotation != null) {
                ConstructorArgumentValues ctorArgs = definition.getConstructorArgumentValues();
                ctorArgs.addIndexedArgumentValue(0, new MyMapperProxyFactory<Object>());
                ctorArgs.addIndexedArgumentValue(1, mapperInterface);
                ctorArgs.addIndexedArgumentValue(2, contextLoader);
                // The annotation value is passed to the handler as its table name.
                ctorArgs.addIndexedArgumentValue(3, new MyMapperProxy(mapperAnnotation.value()));
                // Swap the bean class so the bean is produced by the factory bean.
                definition.setBeanClass(MyProxyFactoryBean.class);
            }
        }
    }

    /** Loads (and initializes) the named class, wrapping a failed lookup in a RuntimeException. */
    private Class<?> loadMapperInterface(String name, ClassLoader classLoader) {
        try {
            return Class.forName(name, true, classLoader);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Accepts only independent interfaces as candidates.
     *
     * @param beanDefinition the definition whose metadata is inspected
     * @return {@code true} when the candidate is an interface and independent
     */
    @Override
    protected boolean isCandidateComponent(AnnotatedBeanDefinition beanDefinition) {
        if (!beanDefinition.getMetadata().isInterface()) {
            return false;
        }
        return beanDefinition.getMetadata().isIndependent();
    }
}
修改BeanDefinition启动类
/**
 * Registry post-processor that triggers the {@link MyBeanDefinitionScanner}
 * over the application's auto-configuration packages.
 *
 * <p>NOTE(review): on Spring Framework versions before 6.1 the
 * {@code BeanDefinitionRegistryPostProcessor} contract presumably also requires
 * {@code postProcessBeanFactory(...)} to be implemented; confirm against the
 * Spring version in use.
 */
@Component
public class MyMapperBeanPostProcessor implements BeanDefinitionRegistryPostProcessor
        , BeanFactoryAware {

    /** Injected through {@link #setBeanFactory}; used to look up the base packages. */
    private BeanFactory beanFactory;

    /**
     * Scans every registered auto-configuration package for mapper interfaces.
     *
     * <p>Previously only the first package in the list was scanned, so mappers
     * under any further registered base package were silently ignored.
     *
     * @param registry the registry the scanner registers bean definitions with
     * @throws BeansException declared by the overridden method
     */
    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
        MyBeanDefinitionScanner scanner = new MyBeanDefinitionScanner(registry);
        // Join all packages with a comma, which is one of the delimiters used to split below.
        String packageNames = String.join(",", AutoConfigurationPackages.get(this.beanFactory));
        scanner.scan(
                StringUtils.tokenizeToStringArray(packageNames,
                        ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS));
    }

    /**
     * Stores the bean factory for later use.
     *
     * @param beanFactory the owning bean factory
     * @throws BeansException declared by the overridden method
     */
    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        this.beanFactory = beanFactory;
    }
}