#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Grid of single-character cells; every row is expected to have the same length.
using Grid = std::vector<std::string>;

/// Mirror the grid left-to-right (reverse every row).
Grid flip1(Grid a) {
  for (std::string &row : a) std::reverse(row.begin(), row.end());
  return a;
}

/// Mirror the grid top-to-bottom (reverse the order of the rows).
Grid flip2(Grid a) {
  std::reverse(a.begin(), a.end());
  return a;
}

/// For every cell (i, j), compute the largest k such that the staircase
/// "triangle" anchored at (i, j) -- row i+t covering columns j .. j+k-1-t for
/// t = 0 .. k-1 -- consists of one repeated character. Returns the sum over all
/// cells of (k - 1), i.e. the number of such triangles of size >= 2 in this
/// one orientation.
///
/// Assumes `a` is rectangular (all rows the same length). Empty input yields 0.
long long solve(const Grid &a) {
  const int rows = static_cast<int>(a.size());
  if (rows == 0) return 0;
  const int cols = static_cast<int>(a[0].size());
  if (cols == 0) return 0;  // original indexed at[m - 1] here: out of bounds

  // Only row i+1's results are needed to compute row i, so keep two rolling
  // rows instead of a full rows x cols table.
  std::vector<int> below(cols, 0);  // triangle sizes for row i+1
  std::vector<int> cur(cols, 0);    // triangle sizes for row i
  std::vector<int> run(cols, 0);    // run[j]: length of equal-char run a[i][j..]

  long long ans = 0;
  for (int i = rows - 1; i >= 0; --i) {
    run[cols - 1] = 1;
    for (int j = cols - 2; j >= 0; --j) {
      run[j] = (a[i][j] == a[i][j + 1]) ? run[j + 1] + 1 : 1;
    }
    for (int j = 0; j < cols; ++j) {
      if (i == rows - 1 || a[i][j] != a[i + 1][j]) {
        cur[j] = 1;
      } else {
        // Triangle at (i, j) is bounded by the run in this row and by the
        // (same-coloured) triangle one row down, which it must contain.
        cur[j] = std::min(run[j], below[j] + 1);
      }
      ans += cur[j] - 1;
    }
    below.swap(cur);
  }
  return ans;
}

/// Reads `n m` followed by n rows of at least m characters (only the first m
/// are used) and prints the total triangle count over all four orientations.
/// Returns 1 without output on malformed input.
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int n = 0;
  int m = 0;
  if (!(std::cin >> n >> m) || n < 0 || m < 0) return 1;
  if (n == 0 || m == 0) {
    std::cout << 0 << '\n';
    return 0;
  }

  Grid a(n);
  for (std::string &row : a) {
    // std::string grows as needed: no fixed-size buffer to overflow.
    if (!(std::cin >> row) || static_cast<int>(row.size()) < m) return 1;
    row.resize(m);
  }

  const Grid mirrored = flip1(a);
  const long long total =
      solve(a) + solve(mirrored) + solve(flip2(a)) + solve(flip2(mirrored));
  std::cout << total << '\n';
  return 0;
}