#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
using namespace std;

using ll = long long;
using Vi = vector<int>;
using Pii = pair<int, int>;

#define pb push_back
#define mp make_pair
#define x first
#define y second
#define rep(i,b,e) for (int i=(b); i < (e); i++)
#define each(a,x) for(auto& a : (x))
#define all(x) (x).begin(), (x).end()
#define sz(x) int((x).size())
// Rename the global `prev` so it does not collide with std::prev
// (visible through `using namespace std`).
#define prev prev_

// Rolling DP rows indexed by column: `prev` holds row i-1, `cur` row i.
Vi prev, cur;
int n, m;
vector<string> tbl;  // the grid: n rows read as strings
ll ans;              // running total accumulated over all solve() calls

// One DP pass over `tbl` in its current orientation.
//
// For row i and column j, cur[j] becomes the largest k such that the
// last k characters of row i ending at column j, the last k-1 of row i-1
// ending at column j, ..., and the last 1 of row i-k+1 are all equal to
// tbl[i][j] (a staircase whose corner is at (i, j), extending up and to
// the left). Recurrence:
//   d      = length of the run of equal characters in row i ending at j
//   cur[j] = min(d, prev[j] + 1), or 1 if tbl[i][j] != tbl[i-1][j].
// Row 0 and column 0 are always 1. Each cell contributes cur[j] - 1
// (the staircases of size >= 2 cornered there) to `ans`.
void solve() {
    fill(all(cur), 1);  // row 0: no row above, so every value is 1
    rep(i, 1, sz(tbl)) {
        swap(cur, prev);  // previous row's values become `prev`
        cur[0] = 1;
        int d = 1;  // run length of equal chars in row i ending at column j
        rep(j, 1, sz(tbl[i])) {
            if (tbl[i][j-1] == tbl[i][j]) d++;
            else d = 1;
            cur[j] = min(d, prev[j] + 1);
            // The staircase cannot extend upward across a different character.
            if (tbl[i][j] != tbl[i-1][j]) cur[j] = 1;
            if (cur[j] > 1) {
                ans += cur[j] - 1;
            }
        }
    }
}

// Reads n, m and n grid rows, then runs solve() on the four orientations
// obtained by reversing the row order and/or each row, and prints the sum.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    tbl.resize(n);
    each(r, tbl) cin >> r;
    // Only the first m entries are touched by solve(); max(n, m) is kept
    // from the original sizing.
    prev.resize(max(n, m));
    cur.resize(max(n, m));
    solve();                       // corner at bottom-right of staircase
    reverse(all(tbl));
    solve();                       // rows flipped vertically
    each(s, tbl) reverse(all(s));
    solve();                       // flipped both ways
    reverse(all(tbl));
    solve();                       // columns flipped only
    cout << ans << '\n';
    return 0;
}