I modified your code a little, mainly to get the restart to work. I also chose to use an existing RowMapper rather than implement that myself.
The restart scenario exposes what I believe is a flaw in SqlParameterSource: there is no way to transparently add a value to it, or to get hold of all the values in it. This is necessary in order to add the key after a restart. If you know the SqlParameterSource is Map-based, you can cast it and put a value into it. If you know it's bean-based, you can cast it, read all of its values, and copy them into a new Map-based source, to which you can then add the key. I would have liked to see SqlParameterSource implement the Map interface, or at least provide a way to get all entries.
Anyway, here is my take on your code:
Code:
package example.item.reader;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.sql.DataSource;
import org.springframework.batch.item.ExecutionContext;
import org.springframework.batch.item.database.KeyCollector;
import org.springframework.batch.item.util.ExecutionContextUserSupport;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SingleColumnRowMapper;
import org.springframework.jdbc.core.namedparam.BeanPropertySqlParameterSource;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
/**
 * Parameterized {@link KeyCollector} implementation that supports restart.
 *
 * <p>
 * On a fresh start the keys are read with the query set through
 * {@link #setSql(String)}, using the configured named parameters. Each
 * processed key is stored in the {@link ExecutionContext} by
 * {@link #updateContext(Object, ExecutionContext)}. When a stored key is
 * found in the context, the query set through
 * {@link #setRestartSql(String)} is used instead, with the stored key added
 * to the named parameters under the name set through
 * {@link #setRestartIdName(String)}.
 * </p>
 *
 * @author Ulrik Sandberg
 * @author "litius"
 */
public class SingleColumnWithParametersJdbcKeyCollector extends ExecutionContextUserSupport implements KeyCollector,
        InitializingBean {

    private static final String RESTART_KEY = "key";

    /** Named parameters of the driving query; may be null when the query has none. */
    private SqlParameterSource queryNamedParameters;

    private NamedParameterJdbcTemplate namedParameterJdbcTemplate;

    private String sql;

    private String restartSql;

    private String restartIdName;

    private RowMapper keyMapper = new SingleColumnRowMapper();

    /**
     * Creates an unconfigured instance whose name, used to qualify keys in
     * the {@link ExecutionContext}, is the short name of this class.
     */
    public SingleColumnWithParametersJdbcKeyCollector() {
        setName(ClassUtils.getShortName(SingleColumnWithParametersJdbcKeyCollector.class));
    }

    /**
     * Constructs a new instance using the provided namedParameterJdbcTemplate
     * and string representing the sql statement that should be used to
     * retrieve keys.
     *
     * @param namedParameterJdbcTemplate template used to run the queries
     * @param sql query returning the keys to be processed
     * @throws IllegalArgumentException if namedParameterJdbcTemplate is null.
     * @throws IllegalArgumentException if sql string is empty or null.
     */
    public SingleColumnWithParametersJdbcKeyCollector(NamedParameterJdbcTemplate namedParameterJdbcTemplate, String sql) {
        this();
        Assert.notNull(namedParameterJdbcTemplate, "namedParameterJdbcTemplate must not be null.");
        Assert.hasText(sql, "The sql statement must not be null or empty.");
        this.namedParameterJdbcTemplate = namedParameterJdbcTemplate;
        this.sql = sql;
    }

    /**
     * Retrieve the keys to be processed. If the given context holds a key
     * stored by {@link #updateContext(Object, ExecutionContext)}, the restart
     * query is executed with that key added to the named parameters;
     * otherwise the regular query is executed.
     *
     * @param executionContext context that may hold the last processed key
     * @return the keys mapped by the configured {@link RowMapper}
     * @throws IllegalArgumentException if executionContext is null, or if on
     * restart the configured parameter source is of an unsupported type.
     * @throws IllegalStateException if a restart is required but the restart
     * query or the restart id name has not been set.
     */
    public List retrieveKeys(ExecutionContext executionContext) {
        Assert.notNull(executionContext, "The ExecutionContext must not be null");
        String restartKey = getKey(RESTART_KEY);

        if (!executionContext.containsKey(restartKey)) {
            // A query without named parameters is legal; hand the template an
            // empty source rather than null in that case.
            SqlParameterSource parameters = queryNamedParameters;
            if (parameters == null) {
                parameters = new MapSqlParameterSource();
            }
            return namedParameterJdbcTemplate.query(sql, parameters, keyMapper);
        }

        Assert.state(StringUtils.hasText(restartSql), "The restart sql query must not be null or empty"
                + " in order to restart.");
        Assert.state(StringUtils.hasText(restartIdName), "The restart id name must not be null or empty"
                + " in order to restart.");
        Object storedKey = executionContext.get(restartKey);
        return namedParameterJdbcTemplate.query(restartSql, getModifiedParameterSource(storedKey), keyMapper);
    }

    /**
     * Build the parameter source for the restart query: a copy of the
     * configured named parameters (if any) plus the stored key registered
     * under the restart id name. The configured source itself is never
     * modified, so a later fresh start is unaffected.
     *
     * @param storedKey key read from the {@link ExecutionContext}
     * @return a new map-based parameter source
     * @throws IllegalArgumentException if the configured parameter source is
     * neither map-based nor bean-based.
     */
    private MapSqlParameterSource getModifiedParameterSource(Object storedKey) {
        MapSqlParameterSource restartParameters = new MapSqlParameterSource();

        if (queryNamedParameters instanceof MapSqlParameterSource) {
            Map values = ((MapSqlParameterSource) queryNamedParameters).getValues();
            restartParameters.addValues(values);
            copySqlTypes(values.keySet().toArray(), restartParameters);
        }
        else if (queryNamedParameters instanceof BeanPropertySqlParameterSource) {
            BeanPropertySqlParameterSource beanSource = (BeanPropertySqlParameterSource) queryNamedParameters;
            String[] names = beanSource.getReadablePropertyNames();
            for (int i = 0; i < names.length; i++) {
                restartParameters.addValue(names[i], beanSource.getValue(names[i]));
            }
            copySqlTypes(names, restartParameters);
        }
        else if (queryNamedParameters != null) {
            throw new IllegalArgumentException("Unsupported SqlParameterSource implementation: "
                    + ClassUtils.getQualifiedName(queryNamedParameters.getClass()));
        }
        // else: no base parameters configured; the restart id is the only one.

        // Added last so that the stored key wins over a base parameter that
        // happens to use the same name.
        restartParameters.addValue(restartIdName, storedKey);
        return restartParameters;
    }

    /**
     * Carry over the SQL types known to the configured parameter source so
     * that the restart query binds the copied parameters the same way the
     * regular query does. The restart id parameter is skipped because its
     * value is replaced by the stored key.
     *
     * @param names parameter names to inspect; non-String entries are ignored
     * @param target parameter source receiving the type registrations
     */
    private void copySqlTypes(Object[] names, MapSqlParameterSource target) {
        for (int i = 0; i < names.length; i++) {
            if (!(names[i] instanceof String)) {
                continue;
            }
            String name = (String) names[i];
            if (name.equals(restartIdName)) {
                continue;
            }
            int sqlType = queryNamedParameters.getSqlType(name);
            if (sqlType != SqlParameterSource.TYPE_UNKNOWN) {
                target.registerSqlType(name, sqlType);
            }
        }
    }

    /**
     * Store the given key in the {@link ExecutionContext} as the restart
     * data, replacing any previously stored key.
     *
     * @param key the key to remember
     * @param executionContext context to store the key in
     * @throws IllegalArgumentException if key or executionContext is null.
     */
    public void updateContext(Object key, ExecutionContext executionContext) {
        Assert.notNull(key, "The key must not be null.");
        Assert.notNull(executionContext, "The ExecutionContext must not be null");
        executionContext.put(getKey(RESTART_KEY), key);
    }

    /**
     * Verify the configuration: a template and a query are mandatory, and a
     * restart query is only usable together with a restart id name.
     *
     * @throws IllegalArgumentException if a mandatory property is missing.
     * @see org.springframework.beans.factory.InitializingBean#afterPropertiesSet()
     */
    public void afterPropertiesSet() throws Exception {
        Assert.notNull(namedParameterJdbcTemplate, "JdbcTemplate must not be null.");
        Assert.hasText(sql, "The DrivingQuery must not be null or empty.");
        if (StringUtils.hasText(restartSql)) {
            Assert.hasText(restartIdName, "The restart id name must not be null or empty"
                    + " when a restart sql query is set.");
        }
    }

    /**
     * Set the {@link Map} to be used to create the {@link SqlParameterSource}.
     *
     * @param map parameter values keyed by parameter name
     */
    public void setQueryNamedParametersMap(Map map) {
        this.queryNamedParameters = new MapSqlParameterSource(map);
    }

    /**
     * Set the {@link SqlParameterSource} holding the named parameters of the
     * query. For restart to work it must be a {@link MapSqlParameterSource}
     * or a {@link BeanPropertySqlParameterSource}.
     *
     * @param queryNamedParameters the parameter source, or null for none
     */
    public void setQueryNamedParameters(SqlParameterSource queryNamedParameters) {
        this.queryNamedParameters = queryNamedParameters;
    }

    /**
     * Set the {@link RowMapper} to be used to map each key to an object.
     * Defaults to a {@link SingleColumnRowMapper}.
     *
     * @param keyMapper mapper applied to every row of the key query
     */
    public void setKeyMapper(RowMapper keyMapper) {
        this.keyMapper = keyMapper;
    }

    /**
     * Set the SQL statement to be used to return the keys to be processed.
     *
     * @param sql query executed on a fresh start
     */
    public void setSql(String sql) {
        this.sql = sql;
    }

    /**
     * Set the SQL query to be used to return the remaining keys to be
     * processed.
     *
     * @param restartSql query executed when a stored key is found
     */
    public void setRestartSql(String restartSql) {
        this.restartSql = restartSql;
    }

    /**
     * Set the parameter name used in the restart query. Useful for supporting
     * restart, where the parameter source needs to be updated with a new id
     * that is read from the ExecutionContext.
     *
     * @param idName name of the identifier used in the restartSql
     */
    public void setRestartIdName(String idName) {
        this.restartIdName = idName;
    }

    /**
     * Set the {@link DataSource} to be used. A new
     * {@link NamedParameterJdbcTemplate} is created on top of it.
     *
     * @param dataSource the data source to query
     */
    public void setDataSource(DataSource dataSource) {
        this.namedParameterJdbcTemplate = new NamedParameterJdbcTemplate(dataSource);
    }
}