I came up with the following approach, which works for the few cases I've tested so far. Note that I have only tested it with Spring 2.5.
Code:
package ca.digitalrapids.spring.context.support;
import java.util.HashMap;
import java.util.Map;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanNotOfRequiredTypeException;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.ClassPathXmlApplicationContext;
/**
 * Subclass of {@link ClassPathXmlApplicationContext} that allows overriding of
 * individual beans with custom, pre-built instances. This is useful for mocking
 * out particular beans in integration tests.
 * <p>
 * Overrides are keyed by bean name and matched exactly against the name passed
 * to {@code getBean}; this class performs no alias resolution of its own. Only
 * lookups that go through {@code getBean(String, Class, Object[])} on the
 * internal bean factory see the overrides. This does NOT work with autowiring.
 */
public class OverridableClassPathXmlApplicationContext
    extends ClassPathXmlApplicationContext
{
    /**
     * Bean factory that consults a fixed map of override instances before
     * falling back to the standard bean lookup.
     */
    private static class OverridingBeanFactory extends DefaultListableBeanFactory {
        private final Map<String, Object> overridesByBeanName;

        OverridingBeanFactory(BeanFactory parentBeanFactory,
                              Map<String, Object> overridesByBeanName)
        {
            super(parentBeanFactory);
            this.overridesByBeanName = overridesByBeanName;
        }

        /**
         * Returns the override registered under {@code name}, if any;
         * otherwise delegates to the standard lookup. The raw {@code Class}
         * parameter is dictated by the Spring 2.5 superclass signature.
         *
         * @throws BeanNotOfRequiredTypeException if an override exists but is
         *         not an instance of {@code requiredType}
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        @Override
        public Object getBean(final String name, final Class requiredType,
                              final Object[] args) throws BeansException {
            Object override = overridesByBeanName.get(name);
            if (override == null) {
                return super.getBean(name, requiredType, args);
            }
            // Honour the getBean(name, type) contract instead of handing back
            // an object of the wrong type.
            if (requiredType != null && !requiredType.isInstance(override)) {
                throw new BeanNotOfRequiredTypeException(
                    name, requiredType, override.getClass());
            }
            return override;
        }
    }

    /** Override instances keyed by bean name; copied from the constructor argument. */
    private final Map<String, Object> singletonsByBeanName = new HashMap<String, Object>();

    /**
     * Creates a context in which the named beans are replaced by the supplied
     * instances.
     *
     * @param configLocations      XML configuration resource locations
     * @param refresh              whether to call {@link #refresh()} at the end
     *                             of construction
     * @param parent               parent context, or {@code null}
     * @param singletonsByBeanName instances to return instead of the configured
     *                             beans, keyed by bean name; may be {@code null}
     * @throws BeansException if refreshing the context fails
     */
    public OverridableClassPathXmlApplicationContext(String[] configLocations,
                                                     boolean refresh,
                                                     ApplicationContext parent,
                                                     Map<String, Object> singletonsByBeanName)
        throws BeansException
    {
        // Never refresh inside super(): this class's fields are initialised only
        // after super() returns, and the overrides must be in place before the
        // bean factory is created by refresh().
        super(configLocations, false, parent);
        if (singletonsByBeanName != null) {
            this.singletonsByBeanName.putAll(singletonsByBeanName);
        }
        if (refresh) {
            refresh();
        }
    }

    /** Supplies the override-aware bean factory in place of the default one. */
    @Override
    protected DefaultListableBeanFactory createBeanFactory()
    {
        return new OverridingBeanFactory(getInternalParentBeanFactory(), singletonsByBeanName);
    }
}
Unit test:
Code:
package ca.digitalrapids.spring.context.support;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.*;
public class OverridableClassPathXmlApplicationContextTest
{
    private static final String CONFIG_LOCATION =
        "ca/digitalrapids/spring/context/support/spring-test-beans.xml";
    /** Bean id declared in spring-test-beans.xml; the value must match the XML exactly. */
    private static final String TO_BE_OVERRIDDEN_BEAN_NAME = "toBeOverriden";
    private static final String ANOTHER_BEAN_NAME = "anotherBean";

    public interface SomeInterface {
    }

    /** Implementation configured in the XML. */
    public static class SomeInterfaceImpl implements SomeInterface {
    }

    /** Implementation used as the override. */
    public static class SomeInterfaceImpl2 implements SomeInterface {
    }

    /** Bean with one constructor-injected and one setter-injected dependency. */
    public static class AnotherBean {
        private final SomeInterface dependency;
        private SomeInterface dependency2;

        public AnotherBean(SomeInterface dependency)
        {
            this.dependency = dependency;
        }

        public SomeInterface getDependency()
        {
            return dependency;
        }

        public SomeInterface getDependency2()
        {
            return dependency2;
        }

        public void setDependency2(SomeInterface dependency2)
        {
            this.dependency2 = dependency2;
        }
    }

    /**
     * Creates and refreshes a context over the test XML.
     *
     * @param overrides instances keyed by bean name, or {@code null} for none
     */
    private static OverridableClassPathXmlApplicationContext createContext(
        Map<String, Object> overrides)
    {
        return new OverridableClassPathXmlApplicationContext(
            new String[] { CONFIG_LOCATION }, true, null, overrides);
    }

    @Test
    public void testNoOverride()
    {
        OverridableClassPathXmlApplicationContext context = createContext(null);
        try {
            assertEquals(SomeInterfaceImpl.class,
                context.getBean(TO_BE_OVERRIDDEN_BEAN_NAME).getClass());
            AnotherBean anotherBean = (AnotherBean) context.getBean(ANOTHER_BEAN_NAME);
            assertNotNull(anotherBean);
            assertEquals(SomeInterfaceImpl.class, anotherBean.getDependency().getClass());
            assertEquals(SomeInterfaceImpl.class, anotherBean.getDependency2().getClass());
        } finally {
            // Release the context's resources and singletons after each test.
            context.close();
        }
    }

    @Test
    public void testOverride()
    {
        SomeInterface overrider = new SomeInterfaceImpl2();
        Map<String, Object> singletonsByBeanName = new HashMap<String, Object>();
        singletonsByBeanName.put(TO_BE_OVERRIDDEN_BEAN_NAME, overrider);
        OverridableClassPathXmlApplicationContext context = createContext(singletonsByBeanName);
        try {
            // assertSame: the exact supplied instance must be returned and injected,
            // not merely another object of the same class.
            assertSame(overrider, context.getBean(TO_BE_OVERRIDDEN_BEAN_NAME));
            AnotherBean anotherBean = (AnotherBean) context.getBean(ANOTHER_BEAN_NAME);
            assertNotNull(anotherBean);
            assertSame(overrider, anotherBean.getDependency());
            assertSame(overrider, anotherBean.getDependency2());
        } finally {
            context.close();
        }
    }
}
The spring-test-beans.xml file used by the unit test:
Code:
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE beans PUBLIC "-//SPRING//DTD BEAN//EN"
    "http://www.springframework.org/dtd/spring-beans.dtd">
<beans>
    <!-- The bean that the override test replaces. Its id must stay in sync
         with the bean-name constant in the unit test. -->
    <bean id="toBeOverriden"
          class="ca.digitalrapids.spring.context.support.OverridableClassPathXmlApplicationContextTest$SomeInterfaceImpl"/>

    <!-- Refers to "toBeOverriden" twice, once through the constructor and once
         through a setter, so that both injection paths are exercised. -->
    <bean id="anotherBean"
          class="ca.digitalrapids.spring.context.support.OverridableClassPathXmlApplicationContextTest$AnotherBean">
        <constructor-arg>
            <ref bean="toBeOverriden"/>
        </constructor-arg>
        <property name="dependency2">
            <ref bean="toBeOverriden"/>
        </property>
    </bean>
</beans>