目录
一,基类模板
template<typename T>
class SearchBase // CRTP base class for backtracking "fill in the numbers" searches
{
public:
	// Sets up `num` cells whose candidate values are the integers in [low, high].
	// Every cell starts as type 0 (to be filled) and cells are visited in index order.
	SearchBase(int num, int low, int high)
	{
		this->num = num;
		this->low = low;
		this->high = high;
		grid.resize(num);
		type.resize(num);
		for (int cell = 0; cell < num; cell++) {
			type[cell] = 0;              // default cell type
			searchOrder.push_back(cell); // default search order
		}
	}

	// Runs the search from the first position of the search order and returns the grid.
	// The success flag of the search is not reported to the caller.
	vector<int> getAns()
	{
		trys(0);
		return grid;
	}

protected:
	// Depth-first search over searchOrder starting at position `id`.
	// Returns true once every position from `id` onwards has been accepted.
	bool trys(int id)
	{
		if (static_cast<size_t>(id) >= searchOrder.size()) {
			return true;
		}
		const int cell = searchOrder[id];
		if (type[cell] == -1) {
			return trys(id + 1); // invalid cells are skipped
		}
		for (int value = low; value <= high; value++) {
			if (!ok(cell, value)) {
				continue;
			}
			// Only cells of type 0 are written; other cells are merely checked.
			if (type[cell] == 0) {
				set(cell, value);
			}
			if (trys(id + 1)) {
				return true;
			}
			if (type[cell] == 0) {
				reback(cell, value);
			}
		}
		return false;
	}

	// Forwards to T::ok, which decides from grid and the derived statistics whether k fits cell id.
	inline bool ok(int id, int k)
	{
		return static_cast<T*>(this)->ok(id, k);
	}

	// Forwards to T::set, which records k for cell id and refreshes the statistics.
	inline void set(int id, int k)
	{
		static_cast<T*>(this)->set(id, k);
	}

	// Forwards to T::reback, which rolls the statistics back after a failed attempt.
	inline void reback(int id, int k)
	{
		static_cast<T*>(this)->reback(id, k);
	}

	int num;                 // number of cells
	int low, high;           // each cell takes a value in [low, high]; restriction 1: low > 0
	vector<int> type;        // per-cell type: 0 = to be filled, 1 = known value, -1 = invalid cell
	vector<int> searchOrder; // order in which cells are searched
	vector<int> grid;        // value stored in every cell
	// derived classes add their own statistics
};
二,分组定和搜索算法
class SearchWithGroupSum :public SearchBase<SearchWithGroupSum> // search in which every group of cells must reach a fixed sum
{
	friend class SearchBase<SearchWithGroupSum>;
public:
	// gridGroup[g] lists the cell ids of group g and groupSum[g] is the sum that group must reach;
	// both are copied. num, low and high are forwarded to SearchBase.
	// Preconditions: cell ids lie in [0, num), groupSum has an entry for every group,
	// low > 0 and high - low < M.
	SearchWithGroupSum(vector<vector<int>>& gridGroup, vector<int>& groupSum, int num, int low, int high) : SearchBase(num, low, high)
	{
		this->gridGroup = gridGroup;
		this->groupSum = groupSum;
		anti.resize(num);
		groupCollect.assign(gridGroup.size(), 0);
		for (size_t g = 0; g < gridGroup.size(); g++) {
			for (int id : gridGroup[g]) anti[id].push_back(static_cast<int>(g));
		}
		setLastId();
	}

	// Registers the known digits: `grid` holds 0 for unknown cells and the digit otherwise.
	// Must be called when there are cells of type 1. Each non-zero entry is subtracted from the
	// remaining sum of its groups, so call it once per instance.
	void setGrid(vector<int>& grid)
	{
		for (size_t i = 0; i < grid.size(); i++) {
			if (grid[i]) set(static_cast<int>(i), grid[i]);
		}
	}

	// Replaces the per-cell types; must be called when there are cells of type 1 or -1.
	// The automatic setSearchOrder() reads the types, so call this first.
	void setType(vector<int>& type)
	{
		this->type = type;
	}

	// Installs a fully custom search order (optional).
	void setSearchOrder(vector<int>& searchOrder)
	{
		this->searchOrder = searchOrder;
		setLastId();
	}

	// Derives a search order from the groups (optional): repeatedly takes a cell from the
	// group that has the fewest free cells left. Only cells of type 0 are ordered.
	void setSearchOrder()
	{
		searchOrder.clear();
		vector<int> numInGroup(gridGroup.size());
		for (size_t i = 0; i < gridGroup.size(); i++) {
			for (auto id : gridGroup[i]) if (type[id] == 0) numInGroup[i]++;
			// A group without free cells must never be selected as "smallest":
			// it has no candidate to offer. INT_MAX is the marker bestGridId uses.
			if (numInGroup[i] == 0) numInGroup[i] = INT_MAX;
		}
		int nums = 0;
		for (int i = 0; i < num; i++) if (type[i] == 0) nums++;
		map<int, int> visit;
		while (nums--) {
			const int id = bestGridId(visit, numInGroup);
			if (id < 0) break; // the groups have no free cell left
			searchOrder.push_back(id);
		}
		// Free cells that belong to no group are appended in index order.
		for (int i = 0; i < num; i++) {
			if (type[i] == 0 && !visit[i]) searchOrder.push_back(i);
		}
		setLastId();
	}

	// Picks the next cell for the automatic search order.
	// numInGroup[g] is the number of still unordered free cells of group g (INT_MAX once exhausted);
	// visit marks cells that were already ordered. Both are updated for the chosen cell.
	// Returns the chosen cell id, or -1 when the selected group offers no candidate.
	int bestGridId(map<int, int>&visit, vector<int>&numInGroup)
	{
		if (numInGroup.empty()) return -1;
		int minId = min_element(numInGroup.begin(), numInGroup.end()) - numInGroup.begin();
		vector<int> v1, v2; // candidate cells and their scores
		for (auto id : gridGroup[minId]) {
			if (type[id] || visit[id]) continue;
			v1.push_back(id);
			int s = 0;
			for (auto k : anti[id]) {
				// Score expression kept exactly as before so the resulting order is unchanged.
				s += gridGroup[k].size() - numInGroup[k];
			}
			v2.push_back(s);
		}
		if (v1.empty()) return -1;
		int ans = max_element(v2.begin(), v2.end()) - v2.begin();
		visit[v1[ans]] = 1;
		for (auto gid : anti[v1[ans]]) {
			numInGroup[gid]--;
			if (numInGroup[gid] == 0) numInGroup[gid] = INT_MAX;
		}
		return v1[ans];
	}

	// Call this when digits may repeat inside a group.
	void setDifFlagFalse()
	{
		difFlag = false;
	}

private:
	// Recomputes, for every group, which of its cells comes last in the search order.
	void setLastId()
	{
		groupLastId.clear();
		for (const auto& group : gridGroup) groupLastId.push_back(calLastId(group));
	}

	// Returns the cell of `group` that appears last in searchOrder,
	// or -1 (an id no cell has) when none of its cells is searched.
	int calLastId(const vector<int>& group)
	{
		for (int i = static_cast<int>(searchOrder.size()) - 1; i >= 0; i--) {
			if (find(group.begin(), group.end(), searchOrder[i]) != group.end())
				return searchOrder[i];
		}
		return -1;
	}

	// Digit k is acceptable for cell id when the sum rule and, if enabled, the distinctness rule hold.
	bool ok(int id, int k)
	{
		if (!checkSum(id, k)) return false;
		if (difFlag && !checkDif(id, k)) return false;
		return true;
	}

	// Stores k in cell id and updates the remaining sum and used-digit mask of its groups.
	void set(int id, int k)
	{
		grid[id] = k;
		const unsigned long long bit = 1ULL << (k - low);
		for (auto p : anti[id]) {
			groupCollect[p] ^= bit;
			groupSum[p] -= k;
		}
	}

	// Undoes set(id, k). The cell is cleared as well so that a failed search
	// does not leave digits of abandoned attempts in the returned grid.
	void reback(int id, int k)
	{
		grid[id] = 0;
		const unsigned long long bit = 1ULL << (k - low);
		for (auto p : anti[id]) {
			groupCollect[p] ^= bit;
			groupSum[p] += k;
		}
	}

	// Sum rule. groupSum holds the sum still missing in each group.
	bool checkSum(int id, int k)
	{
		if (type[id]) {
			// Cell that is not filled by the search: when it closes a group, nothing may be missing.
			for (auto p : anti[id]) {
				if (id == groupLastId[p] && groupSum[p] != 0) return false;
			}
		}
		else {
			for (auto p : anti[id]) {
				// The last cell of a group must supply exactly the missing sum ...
				if (id == groupLastId[p] && k != groupSum[p]) return false;
				// ... and any earlier cell must not exceed it.
				if (id != groupLastId[p] && k > groupSum[p]) return false;
			}
		}
		return true;
	}

	// Distinctness rule: a cell that is not filled by the search only accepts its stored digit,
	// a free cell must not repeat a digit already used in one of its groups.
	bool checkDif(int id, int k)
	{
		if (type[id]) return k == grid[id];
		const unsigned long long bit = 1ULL << (k - low);
		for (auto p : anti[id]) if (groupCollect[p] & bit) return false;
		return true;
	}

private:
	static constexpr int M = 64;            // restriction 2: high - low < M (width of the digit mask)
	vector<vector<int>> gridGroup;          // cells of every group
	vector<vector<int>> anti;               // groups every cell belongs to
	vector<int> groupLastId;                // per group: its cell searched last, -1 if none
	vector<int> groupSum;                   // per group: sum still missing
	bool difFlag = true;                    // whether digits inside a group must be pairwise different
	vector<unsigned long long> groupCollect;// per group: bit mask of the digits already used
};
三,应用
1,数和
参见数和
// Decodes one two-character clue field of a board token:
// a leading 'X' yields -1, a leading '.' yields 0, otherwise the two characters
// are read as a two-digit decimal number.
int num(char a, char b)
{
	switch (a) {
	case 'X':
		return -1;
	case '.':
		return 0;
	default:
		break;
	}
	const int tens = a - '0';
	const int ones = b - '0';
	return tens * 10 + ones;
}
int main()
{
ios::sync_with_stdio(false);
int n, m;
map<int, int>rid, cid;
cin >> n >> m;
string s;
vector<vector<int>>gridGroup;
vector<int>groupSum;
vector<int> type(n * m);
for (int i = 0; i < n; i++)for (int j = 0; j < m; j++)
{
cin >> s;
if (num(s[0], s[1]) > 0) {
cid[i * m + j] = groupSum.size();
groupSum.push_back(num(s[0], s[1]));
gridGroup.push_back(vector<int>{});
}
if (num(s[3], s[4]) > 0) {
rid[i * m + j] = groupSum.size();
groupSum.push_back(num(s[3], s[4]));
gridGroup.push_back(vector<int>{});
}
if (num(s[0], s[1]) == 0) {
gridGroup[cid[i * m + j] = cid[i * m + j - m]].push_back(i * m + j);
gridGroup[rid[i * m + j] = rid[i * m + j - 1]].push_back(i * m + j);
type[i * m + j] = 0;
}
else type[i * m + j] = -1;
}
auto s1 = clock();
auto opt = SearchWithGroupSum(gridGroup, groupSum, n * m, 1, 9);
opt.setType(type);
auto ans = opt.getAns();
auto e1 = clock();
cout << endl << e1 - s1;
for (int i = 0; i < n; i++)for (int j = 0; j < m; j++)
{
if (ans[i * m + j])cout << ans[i * m + j];
else cout << '_';
if (j < m - 1)cout << ' ';
else cout << endl;
}
return 0;
}
输入:
6 6
XXXXX XXXXX 28\XX 17\XX 28\XX XXXXX
XXXXX 22\22 ..... ..... ..... 10\XX
XX\34 ..... ..... ..... ..... .....
XX\14 ..... ..... 16\13 ..... .....
XX\22 ..... ..... ..... ..... XXXXX
XXXXX XX\16 ..... ..... XXXXX XXXXX
输出:
_ _ _ _ _ _
_ _ 5 8 9 _
_ 7 6 9 8 4
_ 6 8 _ 7 6
_ 9 2 7 4 _
_ _ 7 9 _ _
耗时 0 ms
2,标准数独
// Solves a standard 9x9 sudoku. `s` holds the 81 cells row by row, with `cEmpty` marking
// an empty cell and any other character taken as a digit. Returns the 81 cells of the
// solver's grid as characters.
string Sudoku(string s, char cEmpty = '.')
{
	vector<vector<int>> gridGroup;
	// Row `line` and column `line`, appended alternately.
	for (int line = 0; line < 9; line++) {
		vector<int> row, column;
		for (int k = 0; k < 9; k++) {
			row.push_back(line * 9 + k);
			column.push_back(k * 9 + line);
		}
		gridGroup.push_back(row);
		gridGroup.push_back(column);
	}
	// The nine 3x3 boxes, left to right and top to bottom.
	for (int box = 0; box < 9; box++) {
		const int top = box / 3 * 3;
		const int left = box % 3 * 3;
		vector<int> cells;
		for (int t = 0; t < 9; t++) {
			cells.push_back((top + t / 3) * 9 + left + t % 3);
		}
		gridGroup.push_back(cells);
	}
	// Every row, column and box must add up to 45.
	vector<int> groupSum(27, 45);
	SearchWithGroupSum opt(gridGroup, groupSum, 81, 1, 9);
	vector<int> grid(81);
	vector<int> type(81);
	for (int i = 0; i < 81; i++) {
		if (s[i] == cEmpty) continue;
		grid[i] = s[i] - '0';
		type[i] = 1; // known digit
	}
	opt.setGrid(grid);
	opt.setType(type);
	const vector<int> solved = opt.getAns();
	string ans(81, '0');
	for (int i = 0; i < 81; i++) {
		ans[i] = static_cast<char>(solved[i] + '0');
	}
	return ans;
}
// Reads puzzles from standard input until "end" or end of input; for each one prints the
// solution on its own line, then a blank line and the elapsed clock() ticks.
int main()
{
	ios::sync_with_stdio(false);
	for (string puzzle; cin >> puzzle; )
	{
		const auto started = clock();
		if (puzzle == "end") {
			break;
		}
		cout << Sudoku(puzzle) << endl;
		const auto finished = clock();
		cout << endl << finished - started;
	}
	return 0;
}
输入:
.2738..1..1...6735.......293.5692.8...........6.1745.364.......9518...7..8..6534.
......52..8.4......3...9...5.1...6..2..7........3.....6...1..........7.4.......3.
输出:
527389416819426735436751829375692184194538267268174593643217958951843672782965341
416837529982465371735129468571298643293746185864351297647913852359682714128574936
分别耗时 0ms、28ms
3,不规则数独
和标准数独的代码差异非常小
// Solves an irregular (jigsaw) sudoku. `s` holds the 81 cells row by row with `cEmpty` for an
// empty cell; par[c] is the region id of cell c, or `parEmpty` when the cell is in no region.
// Rows, columns and regions are each given the target sum 45. Returns the 81 cells of the
// solver's grid as characters. `s` and `par` must provide at least 81 entries.
string Sudoku(string s, vector<int>&par, char cEmpty = '.', int parEmpty = -1)
{
	vector<vector<int>> gridGroup;
	map<int, vector<int>> m; // region id -> cells of that region
	for (int i = 0; i < 9; i++) {
		vector<int> row, column;
		for (int j = 0; j < 9; j++) {
			row.push_back(i * 9 + j);
			column.push_back(j * 9 + i);
			if (par[i * 9 + j] != parEmpty) m[par[i * 9 + j]].push_back(i * 9 + j);
		}
		gridGroup.push_back(row);
		gridGroup.push_back(column);
	}
	for (const auto& mi : m) gridGroup.push_back(mi.second);
	// One target per group actually built: the number of regions comes from `par`
	// and is not necessarily nine, so the size must not be hard-coded.
	vector<int> groupSum(gridGroup.size(), 45);
	SearchWithGroupSum opt(gridGroup, groupSum, 81, 1, 9);
	vector<int> grid(81);
	vector<int> type(81);
	for (int i = 0; i < 81; i++) {
		if (s[i] != cEmpty) {
			grid[i] = s[i] - '0';
			type[i] = 1; // known digit
		}
	}
	opt.setGrid(grid);
	opt.setType(type);
	const vector<int> v = opt.getAns();
	string ans(81, '0');
	for (int i = 0; i < 81; i++) ans[i] = static_cast<char>(v[i] + '0');
	return ans;
}
// Reads, until "end" or end of input, a puzzle string followed by 81 region ids;
// prints each solution on its own line, then a blank line and the elapsed clock() ticks.
int main()
{
	ios::sync_with_stdio(false);
	string puzzle;
	vector<int> par(81);
	while (cin >> puzzle)
	{
		for (int& region : par) {
			cin >> region;
		}
		const auto started = clock();
		if (puzzle == "end") {
			break;
		}
		cout << Sudoku(puzzle, par) << endl;
		const auto finished = clock();
		cout << endl << finished - started;
	}
	return 0;
}
输入:
.3.159.8.2.9...6.3..78.34..9...4...57.6...1.83...9...6..29.75..5.1...8.2.7.516.2.
1 2 2 3 3 3 4 4 4
1 2 2 3 3 3 4 4 4
1 2 2 2 3 3 4 4 4
1 2 5 2 5 3 6 6 6
1 1 5 5 5 5 5 6 6
1 1 1 8 5 9 5 9 6
7 7 7 8 8 9 9 9 6
7 7 7 8 8 8 9 9 6
7 7 7 8 8 8 9 9 6
输出:
634159287259478613127863459918642375746325198385291746462987531591734862873516924
耗时1ms
4,满覆盖杀手数独
和标准数独、不规则数独的代码差异都非常小
// Killer sudoku, first version: constrains the 9 rows, the 9 columns and the cages.
// NOTE: the 3x3 boxes are not added as groups in this variant.
// `s` holds the 81 cells row by row with `cEmpty` for an empty cell; par[c] is the cage id of
// cell c or `parEmpty`; groupSum lists the cage sums in ascending cage-id order and needs one
// entry per cage. groupSum is only read. Returns the 81 cells of the solver's grid as characters.
string Sudoku(string s, vector<int>&par, vector<int>&groupSum,char cEmpty = '.', int parEmpty = -1)
{
	vector<vector<int>> gridGroup;
	map<int, vector<int>> m; // cage id -> cells of that cage
	for (int i = 0; i < 9; i++) {
		vector<int> row, column;
		for (int j = 0; j < 9; j++) {
			row.push_back(i * 9 + j);
			column.push_back(j * 9 + i);
			if (par[i * 9 + j] != parEmpty) m[par[i * 9 + j]].push_back(i * 9 + j);
		}
		gridGroup.push_back(row);
		gridGroup.push_back(column);
	}
	// Targets are assembled in a local vector: 45 for every row and column, then the cage sums.
	// The caller's vector is left untouched, so it can be reused for another call.
	vector<int> sums(gridGroup.size(), 45);
	sums.insert(sums.end(), groupSum.begin(), groupSum.end());
	for (const auto& mi : m) gridGroup.push_back(mi.second);
	SearchWithGroupSum opt(gridGroup, sums, 81, 1, 9);
	vector<int> grid(81);
	vector<int> type(81);
	for (int i = 0; i < 81; i++) {
		if (s[i] != cEmpty) {
			grid[i] = s[i] - '0';
			type[i] = 1; // known digit
		}
	}
	opt.setGrid(grid);
	opt.setType(type);
	const vector<int> v = opt.getAns();
	string ans(81, '0');
	for (int i = 0; i < 81; i++) ans[i] = static_cast<char>(v[i] + '0');
	return ans;
}
// Reads, until "end" or end of input: a puzzle string, the cage sums terminated by 0, and
// 81 cage ids. Prints each solution on its own line, then a blank line and the elapsed
// clock() ticks. Returns 1 when a puzzle's data is truncated or not numeric.
int main()
{
	ios::sync_with_stdio(false);
	string s;
	vector<int>par(81);
	while (cin >> s)
	{
		if (s == "end")return 0;
		// Fresh list per puzzle: sums of an earlier puzzle must not leak into this one.
		vector<int>groupSum;
		int x;
		while (cin >> x) {
			if (x == 0)break;
			groupSum.push_back(x);
		}
		for (int i = 0; i < 81; i++)cin >> par[i];
		if (!cin)return 1;
		auto s1 = clock();
		cout << Sudoku(s, par, groupSum) << endl;
		auto e1 = clock();
		cout << endl << e1 - s1;
	}
	return 0;
}
输入:
.................................................................................
15 24 11 17 5 14 10 16 10 15 10 9 10 21 7 20 15 15 12 17 8 16 18 10 13 12 13 32 10 0
1 2 3 3 5 5 8 8 8
1 2 2 4 6 7 7 9 9
1 2 2 4 6 6 14 15 9
10 10 12 13 13 14 14 15 16
11 11 12 19 18 17 17 16 16
20 19 19 19 18 22 17 16 23
20 20 21 21 22 22 22 23 23
24 25 25 25 25 28 28 28 28
24 26 26 27 27 27 29 29 28
输出:
946532781183974652527816943691287534378495216452163897765348129814629375239751468
耗时8秒左右,与预期严重不符。
很快我发现,原来是漏了规则。
正确代码:
// Killer sudoku with the full rule set: the 9 boxes, 9 rows and 9 columns must each add up
// to 45, and every cage must reach its given sum.
// `s` holds the 81 cells row by row with `cEmpty` for an empty cell; par[c] is the cage id of
// cell c or `parEmpty`; groupSum lists the cage sums in ascending cage-id order and needs one
// entry per cage. groupSum is only read. Returns the 81 cells of the solver's grid as characters.
string Sudoku(string s, vector<int>&par, vector<int>&groupSum,char cEmpty = '.', int parEmpty = -1)
{
	vector<vector<int>> gridGroup;
	map<int, vector<int>> m; // cage id -> cells of that cage
	// 3x3 boxes.
	for (int i = 0; i < 3; i++)for (int j = 0; j < 3; j++) {
		vector<int> box;
		for (int r = i * 3; r < i * 3 + 3; r++)for (int c = j * 3; c < j * 3 + 3; c++)box.push_back(r * 9 + c);
		gridGroup.push_back(box);
	}
	// Rows and columns; the cages are collected in the same pass.
	for (int i = 0; i < 9; i++) {
		vector<int> row, column;
		for (int j = 0; j < 9; j++) {
			row.push_back(i * 9 + j);
			column.push_back(j * 9 + i);
			if (par[i * 9 + j] != parEmpty) m[par[i * 9 + j]].push_back(i * 9 + j);
		}
		gridGroup.push_back(row);
		gridGroup.push_back(column);
	}
	// Targets are assembled in a local vector: 45 for each of the 27 groups so far, then the
	// cage sums. The caller's vector is left untouched, so it can be reused for another call.
	vector<int> sums(gridGroup.size(), 45);
	sums.insert(sums.end(), groupSum.begin(), groupSum.end());
	for (const auto& mi : m) gridGroup.push_back(mi.second);
	SearchWithGroupSum opt(gridGroup, sums, 81, 1, 9);
	vector<int> grid(81);
	vector<int> type(81);
	for (int i = 0; i < 81; i++) {
		if (s[i] != cEmpty) {
			grid[i] = s[i] - '0';
			type[i] = 1; // known digit
		}
	}
	opt.setGrid(grid);
	opt.setType(type);
	const vector<int> v = opt.getAns();
	string ans(81, '0');
	for (int i = 0; i < 81; i++) ans[i] = static_cast<char>(v[i] + '0');
	return ans;
}
// Reads, until "end" or end of input: a puzzle string, the cage sums terminated by 0, and
// 81 cage ids. Prints each solution on its own line, then a blank line and the elapsed
// clock() ticks. Returns 1 when a puzzle's data is truncated or not numeric.
int main()
{
	ios::sync_with_stdio(false);
	string s;
	vector<int>par(81);
	while (cin >> s)
	{
		if (s == "end")return 0;
		// Fresh list per puzzle: sums of an earlier puzzle must not leak into this one.
		vector<int>groupSum;
		int x;
		while (cin >> x) {
			if (x == 0)break;
			groupSum.push_back(x);
		}
		for (int i = 0; i < 81; i++)cin >> par[i];
		if (!cin)return 1;
		auto s1 = clock();
		cout << Sudoku(s, par, groupSum) << endl;
		auto e1 = clock();
		cout << endl << e1 - s1;
	}
	return 0;
}
耗时300毫秒。
或许即使少了条件也有唯一答案,但搜索效率一定会降低。
5,非满覆盖杀手数独
代码和满覆盖杀手数独完全相同。
输入:
.................................................................................
26 14 16 12 16 6 14 14 12 8 20 12 26 22 6 6 38 16 14 6 14 0
-1 1 2 2 3 3 -1 4 -1
1 1 1 5 5 3 -1 4 4
6 1 -1 5 5 8 8 -1 7
6 9 9 -1 10 10 11 12 7
13 13 13 14 -1 10 11 12 12
15 13 13 14 14 -1 11 11 21
15 -1 16 16 17 17 -1 -1 21
18 -1 -1 17 17 17 -1 19 19
18 18 -1 17 17 20 20 19 -1
输出:
956834217742169583183725946539241768867953124214678395421587639395416872678392451
耗时39秒
如果改成自定义搜索顺序:
// Killer sudoku solved with a caller-supplied search order: the 9 boxes, 9 rows and 9 columns
// must each add up to 45, and every cage must reach its given sum.
// `s` holds the 81 cells row by row with `cEmpty` for an empty cell; par[c] is the cage id of
// cell c or `parEmpty`; groupSum lists the cage sums in ascending cage-id order and needs one
// entry per cage; searchOrder lists the cell ids in the order they are to be tried.
// groupSum is only read. Returns the 81 cells of the solver's grid as characters.
string Sudoku(string s, vector<int>& par, vector<int>& groupSum, vector<int> &searchOrder,
	char cEmpty = '.', int parEmpty = -1)
{
	vector<vector<int>> gridGroup;
	map<int, vector<int>> m; // cage id -> cells of that cage
	// 3x3 boxes.
	for (int i = 0; i < 3; i++)for (int j = 0; j < 3; j++) {
		vector<int> box;
		for (int r = i * 3; r < i * 3 + 3; r++)for (int c = j * 3; c < j * 3 + 3; c++)box.push_back(r * 9 + c);
		gridGroup.push_back(box);
	}
	// Rows and columns; the cages are collected in the same pass.
	for (int i = 0; i < 9; i++) {
		vector<int> row, column;
		for (int j = 0; j < 9; j++) {
			row.push_back(i * 9 + j);
			column.push_back(j * 9 + i);
			if (par[i * 9 + j] != parEmpty) m[par[i * 9 + j]].push_back(i * 9 + j);
		}
		gridGroup.push_back(row);
		gridGroup.push_back(column);
	}
	// Targets are assembled in a local vector: 45 for each of the 27 groups so far, then the
	// cage sums. The caller's vector is left untouched, so it can be reused for another call.
	vector<int> sums(gridGroup.size(), 45);
	sums.insert(sums.end(), groupSum.begin(), groupSum.end());
	for (const auto& mi : m) gridGroup.push_back(mi.second);
	SearchWithGroupSum opt(gridGroup, sums, 81, 1, 9);
	vector<int> grid(81);
	vector<int> type(81);
	for (int i = 0; i < 81; i++) {
		if (s[i] != cEmpty) {
			grid[i] = s[i] - '0';
			type[i] = 1; // known digit
		}
	}
	opt.setGrid(grid);
	opt.setType(type);
	opt.setSearchOrder(searchOrder);
	const vector<int> v = opt.getAns();
	string ans(81, '0');
	for (int i = 0; i < 81; i++) ans[i] = static_cast<char>(v[i] + '0');
	return ans;
}
// Reads, until "end" or end of input: a puzzle string, the cage sums terminated by 0,
// 81 cage ids, and 81 ranks where the value read for cell i is the 1-based position of
// cell i in the search order. Prints each solution on its own line, then a blank line and
// the elapsed clock() ticks. Returns 1 when the data is truncated, not numeric, or the
// ranks are not a permutation of 1..81.
int main()
{
	ios::sync_with_stdio(false);
	string s;
	vector<int>par(81);
	while (cin >> s)
	{
		if (s == "end")return 0;
		// Fresh list per puzzle: sums of an earlier puzzle must not leak into this one.
		vector<int>groupSum;
		int x;
		while (cin >> x) {
			if (x == 0)break;
			groupSum.push_back(x);
		}
		vector<int> searchOrder(81, -1);
		for (int i = 0; i < 81; i++)cin >> par[i];
		for (int i = 0; i < 81; i++) {
			// The rank indexes searchOrder, so it must be in range and not used twice.
			if (!(cin >> x) || x < 1 || x > 81 || searchOrder[x - 1] != -1) {
				cerr << "invalid search rank for cell " << i << endl;
				return 1;
			}
			searchOrder[x - 1] = i;
		}
		auto s1 = clock();
		cout << Sudoku(s, par, groupSum, searchOrder) << endl;
		auto e1 = clock();
		cout << endl << e1 - s1;
	}
	return 0;
}
输入:
.................................................................................
26 14 16 12 16 6 14 14 12 8 20 12 26 22 6 6 38 16 14 6 14 0
-1 1 2 2 3 3 -1 4 -1
1 1 1 5 5 3 -1 4 4
6 1 -1 5 5 8 8 -1 7
6 9 9 -1 10 10 11 12 7
13 13 13 14 -1 10 11 12 12
15 13 13 14 14 -1 11 11 21
15 -1 16 16 17 17 -1 -1 21
18 -1 -1 17 17 17 -1 19 19
18 18 -1 17 17 20 20 19 -1
18 19 25 26 30 31 47 42 46
17 20 21 35 36 32 45 43 44
1 22 27 37 38 33 34 39 40
2 7 8 51 48 49 52 53 41
9 10 11 54 57 50 61 59 60
3 12 13 55 56 58 62 63 64
4 23 5 6 75 76 73 70 65
14 24 28 78 79 77 74 68 66
15 16 29 80 81 71 72 69 67
耗时177毫秒
6,幻方
幻方要使用根据分组自动计算出的搜索顺序(即调用无参的 setSearchOrder()),否则性能会差很多。
// Searches for an n x n magic square filled with the numbers 1..n*n and returns the
// solver's grid row by row. Groups: every row, every column, both diagonals (each with
// sum (n*n+1)*n/2) and one group holding all cells, which keeps all numbers distinct.
// The solver's own limit on high - low applies to n*n - 1.
vector<int> magicSquare(int n)
{
	vector<vector<int>> gridGroup;
	// Row `line` and column `line`, appended alternately.
	for (int line = 0; line < n; line++) {
		vector<int> row, column;
		for (int k = 0; k < n; k++) {
			row.push_back(line * n + k);
			column.push_back(k * n + line);
		}
		gridGroup.push_back(row);
		gridGroup.push_back(column);
	}
	// Main diagonal, then anti-diagonal.
	vector<int> mainDiagonal, antiDiagonal;
	for (int i = 0; i < n; i++) {
		mainDiagonal.push_back(i * n + i);
		antiDiagonal.push_back(i * n + n - 1 - i);
	}
	gridGroup.push_back(mainDiagonal);
	gridGroup.push_back(antiDiagonal);
	// One group with every cell.
	vector<int> everyCell;
	for (int cell = 0; cell < n * n; cell++) {
		everyCell.push_back(cell);
	}
	gridGroup.push_back(everyCell);

	const int lineSum = (n*n + 1)*n / 2;
	const int totalSum = (n*n + 1)*n*n / 2;
	vector<int> groupSum(n * 2 + 2, lineSum);
	groupSum.push_back(totalSum);

	SearchWithGroupSum opt(gridGroup, groupSum, n*n, 1, n*n);
	opt.setSearchOrder(); // order derived from the groups
	return opt.getAns();
}
// Prints the grid found for order 5, five numbers per line, followed by a blank line
// and the elapsed clock() ticks.
int main()
{
	ios::sync_with_stdio(false);
	const auto started = clock();
	const auto square = magicSquare(5);
	int printed = 0;
	for (int value : square) {
		cout << value << " ";
		++printed;
		if (printed % 5 == 0) {
			cout << endl;
		}
	}
	const auto finished = clock();
	cout << endl << finished - started;
	return 0;
}
输出:
1 2 13 24 25
16 21 18 6 4
22 8 12 14 9
23 19 5 11 7
3 15 17 10 20
5阶幻方耗时33毫秒
6阶幻方搜了很久都没搜出来。