public class SqlTest {
// NOTE(review): not referenced by the methods in this part of the file.
StringBuffer flag = new StringBuffer();
/** Sample SELECT with mixed-case tokens ("SElect", "If"). */
String sqlSelect1 = "SElect * from table_name where If = where_data";
/**
 * Sample SELECT with a plain (non-sub-query) join and a column literally named "where".
 * The first literal ends with a space so that "where_data" and the WHERE keyword stay separate tokens.
 */
String sqlSelect2 = "select *,t1.id From table_name t1 join table_name t2 on t1.where = t1.where_data " +
"where t1.id = 1";
/** Sample SELECT whose WHERE clause compares a column named "where". */
String sqlSelect3 = "select * from table_name where where = where_data";
/** Sample SELECT that joins a parenthesised sub-query aliased t2. */
String sqlSelect4 = "select * from table_name t1 join (select * from table_name2 where id = 1) t2 on t1.id = t2.id";
/**
 * Large sample query with nested sub-queries, backtick-quoted identifiers and Chinese column names;
 * it looks like MySQL (ifnull/isnull/format) — NOTE(review): confirm the intended dialect.
 * This is the sample that run() passes to SqlParse.
 */
String SqlSelect5 = "select `t`.`date` AS `date`,`t`.`税收` AS `税收`,`t`.`accumulative_total` AS `accumulative_total`,`t`.`tb` AS `tb`,`t1`.`环比` AS `环比` " +
"from (((select `a1`.`date` AS `date`,`a1`.`税收` AS `税收`,`a1`.`accumulative_total` AS `accumulative_total`,`a2`.`tb` AS `tb` " +
"from (`zhouyu`.`b_question13_2` `a1` left join (select `b2`.`date` AS `date_odl`,(case when (isnull(`b1`.`税收`) or ((`b1`.`税收` = 0) and (`b1`.`date` = `b2`.`date`))) " +
"then '无穷大' else concat(format((((ifnull(`b2`.`税收`,0) - ifnull(`b1`.`税收`,0)) * 100) / ifnull(`b1`.`税收`,0)),2),'%') end) AS `tb` from (`zhouyu`.`b_question13_1` `b1` " +
"left join `zhouyu`.`b_question13_2` `b2` on((`b1`.`date` = `b2`.`odl`)))) `a2` on((`a1`.`date` = `a2`.`date_odl`))))) `t` left join (select `b2`.`date` AS `date`," +
"(select concat(round((((`b2`.`税收` - `b1`.`税收`) / `b1`.`税收`) * 100),2),'%') from `zhouyu`.`b_question13_2` `b1` where (((`b2`.`y` = `b1`.`y`) and (`b2`.`m` = (`b1`.`m` + 1))) or " +
"((`b2`.`y` = (`b1`.`y` + 1)) and (`b2`.`m` = (`b1`.`m` - 3))))) AS `环比` from `zhouyu`.`b_question13_2` `b2`) `t1` on((`t`.`date` = `t1`.`date`)))";
/** Sample DELETE statement; NOTE(review): not used by the methods in this part of the file. */
String sqlDelete1 = "DELETE FROM table_name WHERE some_column = some_value";
/** Counter that getSQL post-increments to build the aliases it returns ("table_0", "table_1", ...). */
int tableName;
/** Sub-query SQL text keyed by the alias getSQL generated for it. */
Map<String, StringBuffer> SQLMap = new HashMap<>();
/**
 * Test entry point: feeds the {@code SqlSelect5} sample to {@link #SqlParse(String)}, which prints
 * the processed SQL. The other samples ({@code sqlSelect1}..{@code sqlSelect4}) can be passed the same way.
 */
@Test
public void run() {
    final String sample = this.SqlSelect5;
    this.SqlParse(sample);
}
/**
 * Normalizes a SQL statement into tokens, runs {@link #getSQL(List, StringBuffer)} on them and
 * prints the text that getSQL wrote into the buffer it was given.
 *
 * <p>Deliberately not annotated with {@code @Test}: it takes a parameter, and JUnit refuses such
 * test methods (JUnit 4 rejects the whole class, which would also prevent {@link #run()} from running).
 *
 * @param sqlSelect the SQL statement to analyse
 */
public void SqlParse(String sqlSelect) {
    List<StringBuffer> sqlOldList = new ArrayList<>();
    StringBuffer sqlCaChe = new StringBuffer();
    // Template cleanup: lower-case with a fixed locale so keyword checks ("select", "from", "join")
    // do not depend on the default locale (e.g. Turkish maps 'I' to a dotless 'ı'), and pad
    // parentheses with spaces so each one becomes a standalone token.
    String sql = sqlSelect.toLowerCase(java.util.Locale.ROOT)
            .replace("(", " ( ")
            .replace(")", " ) ");
    // Split on any run of whitespace (spaces, tabs, newlines); a leading run yields one empty
    // fragment, which is skipped.
    for (String s : sql.split("\\s+")) {
        if (!s.isEmpty()) {
            sqlOldList.add(new StringBuffer(s));
        }
    }
    this.getSQL(sqlOldList, sqlCaChe);
    System.out.println(sqlCaChe);
}
/** Keyword list "select", "from", "join"; NOTE(review): not referenced by the methods in this part of the file. */
List<String> Lexclude = Arrays.asList(new String[] {"select", "from", "join"});
/** Keywords after which getSQL recurses into the remaining tokens. */
List<String> Rexclude = Arrays.asList(new String[] {"full", "union"});
/**
 * Recursively walks a token list and copies every token into {@code sqlCaChe}. It recurses into
 * the tokens that follow {@code from (}, {@code join (}, or any keyword in {@link #Rexclude}. The
 * goal is to break the SQL into sub-queries, stored in {@link #SQLMap} with a generated alias as
 * key and the sub-query SQL as value.
 *
 * <p>Example: {@code select * from table_name t1 join (select * from table_name2 where id = 1) t2 on t1.id = t2.id}
 *
 * <p>Behaviour worth knowing:
 * <ul>
 *   <li>A sub-query is merged into the enclosing text and stored in {@link #SQLMap} only while a
 *       {@code select} has been seen and no {@code )} has been seen since.</li>
 *   <li>After a recursion this call keeps writing into a fresh local buffer. The caller's
 *       {@code sqlCaChe} therefore only receives the tokens up to the first recursion trigger,
 *       plus (when merged) the sub-query text and its alias.</li>
 *   <li>Each recursive call processes every remaining token, and this call then processes them
 *       again; nothing marks tokens as consumed.</li>
 *   <li>A {@code flag = 0} condition is spliced after the first standalone {@code where} token of
 *       the buffer currently in use, or appended as a new WHERE clause if there is none.</li>
 * </ul>
 *
 * @param sqlSelect SQL tokens, e.g. as produced by {@link #SqlParse(String)} (lower-cased, parentheses split out)
 * @param sqlCaChe  buffer that receives the copied tokens (see above for what actually lands in it)
 * @return a newly generated alias {@code table_N}, N taken from (and then incrementing) the {@code tableName} field
 */
public String getSQL(List<StringBuffer> sqlSelect, StringBuffer sqlCaChe) {
    String tableAs = "";
    String targetTable = "";     // table/alias written into the "flag = 0" condition
    List<String> stack = new ArrayList<>();
    boolean isVerify = false;    // set by "select", cleared by ")"
    for (int i = 0; i < sqlSelect.size(); i++) {
        // StringBuffer does not override equals(), so all keyword tests compare Strings.
        String token = sqlSelect.get(i).toString();
        boolean nextIsOpenParen = i + 1 < sqlSelect.size() && sqlSelect.get(i + 1).toString().equals("(");
        sqlCaChe.append(token);
        sqlCaChe.append(" ");

        String recursionMarker = null; // stack marker pushed when this token opens a sub-query
        if (token.equals("select")) {
            isVerify = true;
            stack.add("select");
            // i > 0 guard: the first token of a statement is normally "select" itself.
            if (i > 0 && sqlSelect.get(i - 1).toString().equals("(")) {
                stack.add("(");
            }
        } else if (token.equals("from") && nextIsOpenParen) {
            recursionMarker = "from(";
        } else if (token.equals("from")) { // plain "select ... from table ..."
            stack.add("from");
            StringBuffer sqlEndSelect = this.rangeList(sqlSelect.subList(i, sqlSelect.size()), "select");
            if (sqlEndSelect.toString().contains("join")) {
                targetTable = this.getTableName(sqlEndSelect);
            }
        } else if (token.equals("join") && nextIsOpenParen) {
            recursionMarker = "join(";
        } else if (this.Rexclude.contains(token)) {
            recursionMarker = token;
        }

        if (recursionMarker != null) {
            stack.add(recursionMarker);
            StringBuffer outerCache = sqlCaChe;            // text before the sub-query
            StringBuffer innerCache = new StringBuffer();  // receives the sub-query text
            tableAs = this.getSQL(sqlSelect.subList(i + 1, sqlSelect.size()), innerCache);
            if (isVerify) {
                outerCache.append(innerCache);  // merge the sub-query text
                outerCache.append(tableAs);     // followed by its generated alias
                this.SQLMap.put(tableAs, outerCache);
                targetTable = tableAs;
            }
            sqlCaChe = new StringBuffer();      // continue in a fresh buffer
        }

        if (token.equals(")") && !stack.isEmpty()) {
            isVerify = false;
            System.out.println(stack.remove(stack.size() - 1));
            // There is deliberately no early exit here: the stack is popped on every ")" (including
            // function-call parentheses), so an empty stack does not mark the end of a sub-query.
            // A real stop condition needs the recursive call to report how many tokens it consumed.
        }
    }
    // NOTE(review): the condition is emitted as "<table> flag = 0" with no '.' between table and
    // column — confirm the intended SQL.
    int whereIndex = indexOfKeywordToken(sqlCaChe, "where");
    if (whereIndex >= 0) {
        sqlCaChe.insert(whereIndex + "where".length(), " " + targetTable + " flag = 0 " + " and ");
    } else {
        sqlCaChe.append(" where " + targetTable + " flag = 0");
    }
    return "table_" + this.tableName++;
}

/**
 * Returns the index of the first occurrence of {@code keyword} in {@code text} that is delimited by
 * whitespace or the text boundaries (so "t1.where" or "where_data" do not match), or -1 if none.
 */
private static int indexOfKeywordToken(CharSequence text, String keyword) {
    String s = text.toString();
    int from = 0;
    while (true) {
        int idx = s.indexOf(keyword, from);
        if (idx < 0) {
            return -1;
        }
        int end = idx + keyword.length();
        boolean startsToken = idx == 0 || Character.isWhitespace(s.charAt(idx - 1));
        boolean endsToken = end == s.length() || Character.isWhitespace(s.charAt(end));
        if (startsToken && endsToken) {
            return idx;
        }
        from = idx + 1;
    }
}
/**
 * Joins the tokens of {@code arrayList} up to, but not including, the first token equal to
 * {@code end}, skipping {@code (} and {@code )} tokens. Each kept token is followed by one space.
 * If {@code end} never occurs, every non-parenthesis token is used.
 *
 * @param arrayList tokens to join
 * @param end       token at which to stop (parenthesis tokens are always skipped, never matched)
 * @return the joined tokens, with a trailing space when non-empty
 */
public StringBuffer rangeList(List<StringBuffer> arrayList, String end) {
    StringBuffer string = new StringBuffer();
    for (StringBuffer str : arrayList) {
        String token = str.toString();
        if (token.equals("(") || token.equals(")")) {
            continue;
        }
        if (token.equals(end)) {
            // Stop here: callers (getSQL with end = "select") only want the current clause, not the
            // tokens of later sub-queries.
            break;
        }
        string.append(token);
        string.append(" ");
    }
    return string;
}
/**
 * Returns the token directly to the left of the first standalone {@code join} token of a
 * space-separated SQL fragment. For {@code left join} / {@code right join} it returns the token
 * before {@code left}/{@code right}. For example:
 * {@code select * from table_name as t join table_name2 as t2 on t.id = t2.id where t.id = 1}
 * gives {@code t}, and {@code select * from table_name as t where t.id = 1} gives {@code null}.
 * Keywords are compared in lower case.
 *
 * @param sqlSelect space-separated SQL fragment
 * @return the token before the join, or {@code null} when there is no standalone {@code join}
 *         token (e.g. only a column such as {@code join_date}) or nothing precedes it
 */
public String getTableName(StringBuffer sqlSelect) {
    List<String> tokens = Arrays.asList(sqlSelect.toString().split(" "));
    int joinIndex = tokens.indexOf("join");
    if (joinIndex < 1) {
        return null;
    }
    String previous = tokens.get(joinIndex - 1);
    if (previous.equals("left") || previous.equals("right")) {
        // Skip the join-type keyword to reach the actual table/alias.
        return joinIndex >= 2 ? tokens.get(joinIndex - 2) : null;
    }
    return previous;
}