
Java Spring shard DataSource with MySQL/Oracle

If you are implementing database sharding with Spring JDBC, you are out of luck if you want declarative transactions and a ready-made Spring DataSource that handles sharding. I had to implement my own DataSource manager and my own annotations to get declarative-style transactions and hide the complexity from the average developer. It is very important to abstract out cross-cutting concerns such as sharding and transactions, so that junior developers are not confused and do not start copying code left and right without understanding the global impact of their changes.

So the idea is:
1) Implement a ShardDataSourceManager, which is basically a pool of connection pools, and look up a DataSource by shard id.
2) Define your own transactional annotation and annotate methods with it.
3) Write an interceptor at the DAO layer that reads the annotation on the method along with some context info. From the context info, look up the shard id, then the DataSource, and put it into a thread local.
4) When the DAO layer needs a DataSource, it reads the thread local, constructs a JdbcTemplate from it, and executes queries on it.
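
The four steps can be sketched without Spring or JDBC. This is only a minimal illustration of the thread-local routing idea, not the production code: NamedPool, ShardRoutingSketch, UserDaoSketch, and every method on them are hypothetical stand-ins (a real version would hold DataSource objects and build a JdbcTemplate in the DAO).

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

/** Hypothetical stand-in for a real connection pool; only the name matters here. */
final class NamedPool {
    final String name;
    NamedPool(String name) { this.name = name; }
}

/** Steps 1 and 3: a registry of pools keyed by shard id, plus a per-thread "current pool". */
final class ShardRoutingSketch {
    private static final ThreadLocal<NamedPool> CURRENT = new ThreadLocal<>();
    private final Map<Integer, NamedPool> poolsByShardId = new HashMap<>();

    void register(int shardId, NamedPool pool) {
        poolsByShardId.put(shardId, pool);
    }

    /** Binds the shard's pool to the calling thread while the work runs. */
    <T> T withShard(int shardId, Supplier<T> work) {
        NamedPool pool = poolsByShardId.get(shardId);
        if (pool == null) {
            throw new IllegalArgumentException("No pool configured for shard " + shardId);
        }
        CURRENT.set(pool);
        try {
            return work.get();
        } finally {
            CURRENT.remove(); // do not leak the binding to the next task on a pooled thread
        }
    }

    /** Step 4: what the DAO layer calls instead of holding a fixed pool. */
    static NamedPool currentPool() {
        NamedPool pool = CURRENT.get();
        if (pool == null) {
            throw new IllegalStateException("No shard bound to this thread");
        }
        return pool;
    }
}

/** A DAO that never mentions shards; it just asks for the current pool. */
final class UserDaoSketch {
    String describeTarget() {
        return "querying " + ShardRoutingSketch.currentPool().name;
    }
}
```

With pools registered for shards 1 and 2, `routing.withShard(2, dao::describeTarget)` runs the DAO call against the second pool, and the binding is gone once the call returns. Note that this simple version does not support nested calls: an inner withShard would clear the outer binding.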

Here is a sample ShardTransactional annotation, ShardTransactionInterceptor, and ShardDataSourceManager:

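import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

// RUNTIME retention is required so the interceptor can see the annotation via reflection.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface ShardTransactional {
    boolean readOnly() default false;
}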
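
To illustrate how such an annotation is used and detected, here is a small self-contained sketch. ShardTransactionalSketch, FileDaoSketch, and AnnotationProbe are hypothetical names invented for this example, and a plain String stands in for the User argument. The point it makes is that the reflective lookup only works because the annotation is retained at runtime.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Redeclared under a different name so the sketch compiles on its own.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface ShardTransactionalSketch {
    boolean readOnly() default false;
}

/** Hypothetical DAO: one annotated method and one plain method. */
class FileDaoSketch {
    @ShardTransactionalSketch(readOnly = true)
    public String listFiles(String user) {
        return "files of " + user;
    }

    public String ping() {
        return "pong";
    }
}

final class AnnotationProbe {
    /** Returns "none", "read-only" or "read-write" for the named public method. */
    static String describe(Class<?> type, String methodName, Class<?>... parameterTypes) {
        Method method;
        try {
            method = type.getMethod(methodName, parameterTypes);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(e);
        }
        ShardTransactionalSketch tx = method.getAnnotation(ShardTransactionalSketch.class);
        if (tx == null) {
            return "none";
        }
        return tx.readOnly() ? "read-only" : "read-write";
    }
}
```

If the @Retention line were removed, getAnnotation would return null for listFiles as well, and an interceptor built on this check would silently skip every method.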

public class ShardTransactionInterceptor implements MethodInterceptor {
    private static final AppLogger logger = AppLogger.getLogger(ShardTransactionInterceptor.class);
    private static final ThreadLocal<DataSource> dataSourceThreadLocal = new ThreadLocal<DataSource>();
    private ShardDataSourceManager shardDataSourceManager;
   
    public ShardDataSourceManager getShardDataSourceManager() {
        return shardDataSourceManager;
    }

    public void setShardDataSourceManager(ShardDataSourceManager shardDataSourceManager) {
        this.shardDataSourceManager = shardDataSourceManager;
    }


    @Override
    public Object invoke(final MethodInvocation method) throws Throwable {
        if (method.getMethod().isAnnotationPresent(ShardTransactional.class)) {
            try {
                ShardTransactional annotation = method.getMethod().getAnnotation(ShardTransactional.class);
                User user = getParam(method, User.class);
                if (user == null) {
                    throw new IllegalStateException("All transactional methods must have user argument");
                }
                TransactionTemplate transactionTemplate = new TransactionTemplate();
                boolean readOnly = annotation.readOnly();
                transactionTemplate.setReadOnly(readOnly);
                ShardInfo shardInfo =  getShardInfo(user);
                transactionTemplate.setName("ShardTransaction");
                transactionTemplate.setTransactionManager(shardDataSourceManager.getTransactionManagerByHostId(shardInfo.getHostId(), readOnly));
                cacheDataSourceInThreadLocal(shardInfo.getHostId(),readOnly);
                return transactionTemplate.execute(new TransactionCallback() {
                    @Override
                    public Object doInTransaction(TransactionStatus transactionStatus) {
                        try {
                            return method.proceed();
                        } catch (Throwable t) {
                            transactionStatus.setRollbackOnly();
                            logger.error("Rolling back transaction due to", t);
                            throw new RuntimeException(t);
                        }
                    }
                });
            } finally {
                // remove() drops the entry entirely so pooled threads do not keep a stale slot.
                dataSourceThreadLocal.remove();
            }
        } else {
            return method.proceed();
        }
    }

    private ShardInfo getShardInfo(User user) {
        // ...code to look up shard by user
        return shardInfo;
    }

    public static DataSource getDataSource() {
        return dataSourceThreadLocal.get();
    }
   
    private DataSource cacheDataSourceInThreadLocal(int hostId, boolean readOnly) {
        DataSource datasource = shardDataSourceManager.getDataSourceByHostId(hostId, readOnly);
        dataSourceThreadLocal.set(datasource);
        return datasource;
    }

    private <T> T getParam(MethodInvocation method, Class<T> clazz) {
        Class<?>[] parameterTypes = method.getMethod().getParameterTypes();
        Object[] arguments = method.getArguments();
        // Return the first argument whose declared parameter type is assignable to clazz.
        for (int i = 0; i < parameterTypes.length; i++) {
            if (clazz.isAssignableFrom(parameterTypes[i])) {
                return clazz.cast(arguments[i]);
            }
        }
        return null;
    }
}
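
The interceptor finds the shard key by scanning the method's declared parameter types for a User. That lookup can be tried in isolation with the sketch below, where ArgumentLookupSketch, SampleDaoSketch, and their methods are hypothetical names used only for illustration.

```java
import java.lang.reflect.Method;

/** Hypothetical DAO used only to have a method signature to inspect. */
final class SampleDaoSketch {
    public void rename(Long fileId, CharSequence newName) { }
}

final class ArgumentLookupSketch {
    /** Returns the first argument whose declared parameter type is assignable to wanted, or null. */
    static <T> T firstArgumentOfType(Method method, Object[] args, Class<T> wanted) {
        Class<?>[] parameterTypes = method.getParameterTypes();
        for (int i = 0; i < parameterTypes.length; i++) {
            if (wanted.isAssignableFrom(parameterTypes[i])) {
                return wanted.cast(args[i]);
            }
        }
        return null;
    }

    /** Convenience wrapper that turns the checked reflection exception into an unchecked one. */
    static Method methodOf(Class<?> type, String name, Class<?>... parameterTypes) {
        try {
            return type.getMethod(name, parameterTypes);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(e);
        }
    }
}
```

Because the match is made on the declared type, a parameter declared as Object would not be found even if a User were passed in it at runtime. That keeps the rule simple: every annotated method must declare the shard key in its signature.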


public class ShardDataSourceManager {
   
    private static final AppLogger logger = AppLogger.getLogger(ShardDataSourceManager.class);
    private static boolean autoCommit = false;
   
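    private Map<Integer, BasicDataSource> dataSourceMap = new HashMap<Integer, BasicDataSource>();

    private Map<Integer, DataSourceTransactionManager> transactionManagerMap = new HashMap<Integer, DataSourceTransactionManager>();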

    private ShardManager shardManager;

    private String driverClassName = "org.gjt.mm.mysql.Driver";

    private int maxActive = 20;

    private int maxIdle = 5;

    private int maxWait = 180000;
   
    private int minEvictableIdleTimeMillis = 300000;
   
    private boolean testWhileIdle = true;

    private String validationQuery = "select 1 from dual";
   
    private String userName;

    private String userPassword;

    public String getDriverClassName() {
        return driverClassName;
    }

    public void setDriverClassName(String driverClassName) {
        this.driverClassName = driverClassName;
    }

    public int getMaxActive() {
        return maxActive;
    }

    public void setMaxActive(int maxActive) {
        this.maxActive = maxActive;
    }

    public int getMaxIdle() {
        return maxIdle;
    }

    public void setMaxIdle(int maxIdle) {
        this.maxIdle = maxIdle;
    }

    public int getMaxWait() {
        return maxWait;
    }

    public void setMaxWait(int maxWait) {
        this.maxWait = maxWait;
    }

    public int getMinEvictableIdleTimeMillis() {
        return minEvictableIdleTimeMillis;
    }

    public void setMinEvictableIdleTimeMillis(int minEvictableIdleTimeMillis) {
        this.minEvictableIdleTimeMillis = minEvictableIdleTimeMillis;
    }

    public boolean isTestWhileIdle() {
        return testWhileIdle;
    }

    public void setTestWhileIdle(boolean testWhileIdle) {
        this.testWhileIdle = testWhileIdle;
    }

    public String getValidationQuery() {
        return validationQuery;
    }

    public void setValidationQuery(String validationQuery) {
        this.validationQuery = validationQuery;
    }

    public String getUserPassword() {
        return userPassword;
    }

    public void setUserPassword(String userPassword) {
        this.userPassword = userPassword;
    }

    public String getUserName() {
        return userName;
    }

    public void setUserName(String userName) {
        this.userName = userName;
    }

    // Needed so the container can inject the ShardManager before init() runs.
    public void setShardManager(ShardManager shardManager) {
        this.shardManager = shardManager;
    }

    public void init() throws Exception {
        for (DbHost shardInfo : shardManager.getDbHosts()) {
            String url = "jdbc:mysql://" + shardInfo.getMasterHost();
            BasicDataSource dataSource = createDataSource(url, userName);
            dataSourceMap.put(shardInfo.getHostId(), dataSource);
            DataSourceTransactionManager masterTransactionManager = new DataSourceTransactionManager(dataSource);
            transactionManagerMap.put(shardInfo.getHostId(), masterTransactionManager);
            logger.info("DataSource Created for hostid= {}, url= {}", shardInfo.getHostId(), dataSource.getUrl());
        }
    }

    private BasicDataSource createDataSource(String url, String username) {
        logger.info("Creating DataSource {}", url);
        BasicDataSource dataSource = new BasicDataSource();
        dataSource.setDriverClassName(driverClassName);
        dataSource.setUrl(url);
        dataSource.setUsername(username);
        dataSource.setPassword(userPassword);
        dataSource.setValidationQuery(validationQuery);
        // Idle testing and eviction only take effect when the pool's evictor thread is enabled
        // (timeBetweenEvictionRunsMillis > 0), which this class does not configure.
        dataSource.setTestWhileIdle(testWhileIdle);
        dataSource.setMinEvictableIdleTimeMillis(minEvictableIdleTimeMillis);
        dataSource.setConnectionProperties("useUnicode=true;characterEncoding=utf8");
        dataSource.setDefaultAutoCommit(autoCommit);
        dataSource.setMaxIdle(maxIdle);
        dataSource.setMaxWait(maxWait);
        dataSource.setMaxActive(maxActive);
        return dataSource;
    }

    private DataSource getDataSourceByHostId(int hostId) {
        DataSource dataSource = dataSourceMap.get(hostId);
        if (dataSource == null) {
            logger.warn("Could not find a data source for: {}", hostId);
            throw new IllegalArgumentException("No data source configured for hostId=" + hostId);
        }
        return dataSource;
    }

    public DataSource getDataSourceByHostId(int hostId, boolean readOnly) {
        // readOnly is currently ignored: every lookup resolves to the master pool.
        logger.debug("Using Master datasource for hostid={}", hostId);
        DataSource dataSource = dataSourceMap.get(hostId);
        if (dataSource == null) {
            String msg = "Could not find a data source for hostId=" + hostId;
            throw new IllegalArgumentException(msg);
        }
        return dataSource;
    }

    public DataSourceTransactionManager getTransactionManagerByHostId(int hostId, boolean readOnly) {
        // readOnly is currently ignored: every lookup resolves to the master transaction manager.
        logger.debug("Using Master transactionmanager for hostid={}", hostId);
        DataSourceTransactionManager transactionManager = transactionManagerMap.get(hostId);
        if (transactionManager == null) {
            String msg = "Could not find a transaction manager for hostId=" + hostId;
            throw new IllegalArgumentException(msg);
        }
        return transactionManager;
    }

    public void destroy() throws Exception {
        logger.info("destroying pools");
        destroyPool(dataSourceMap);
        transactionManagerMap.clear();
    }

    private void destroyPool(Map<Integer, BasicDataSource> dsMap) throws SQLException {
        if (dsMap != null) {
            for (BasicDataSource dataSource : dsMap.values()) {
                logger.info("Discarding pools: {}", dataSource);
                dataSource.close();
            }
        }
    }
}

Comments

  1. Seems like a neat solution. However, as I have observed, sharding eventually becomes much more than just inserts into a "shard-aware" connection pool. Cross-shard queries, transaction consistency, and administration of the entire array are crucial to a good sharding solution. You can have a look at ScaleBase (disclaimer: I work there), http://www.scalebase.com, to see how this can be your one-stop shop for all of your sharding needs, totally transparent (standard conn pool... :) ).

  2. Can I get the source code for this to play with?

  3. Except for the imports, the code pasted above is the real source code we have live in production, serving 1B+ rows from 20 MySQL servers. I haven't had a chance to put it on GitHub yet.

  4. Any GitHub project? Looks nice. I'm doing similar stuff and I'd like to fork and contribute if possible.
