# First line: grid height n and width m; then n lines, one string per grid row.
n, m = (int(token) for token in input().split())

w = []
for _ in range(n):
    w.append(input())

def rotate(w):
    """Return grid *w* rotated a quarter turn counter-clockwise.

    Row ``i`` of the result is column ``cols - 1 - i`` of *w*, read top to
    bottom, so the element at ``(r, c)`` moves to ``(cols - 1 - c, r)``.
    The result is a list of lists of characters.

    The grid's shape is taken from *w* itself rather than from globals, so
    any rectangular grid is handled. The module-level ``n`` and ``m`` are
    still updated to the rotated grid's row/column counts, which keeps code
    that reads those globals consistent with the returned grid.
    """
    global n, m
    rows = len(w)
    cols = len(w[0]) if rows else 0
    rotated = [[w[r][cols - 1 - c] for r in range(rows)] for c in range(cols)]
    # The rotated grid has `cols` rows and `rows` columns.
    n, m = cols, rows
    return rotated

def count(w):
    """Count same-character triangles of size >= 2 in one fixed orientation.

    For cell ``(i, j)`` define ``T(i, j, k)`` as the cells whose row
    ``i - t`` (``0 <= t < k``) spans columns ``j - t .. j``. That is a
    triangle with its bottom tip at ``(i, j)``, a vertical edge along column
    ``j`` and a horizontal edge along row ``i - k + 1``. The DP value at
    ``(i, j)`` is the largest ``k`` for which ``T(i, j, k)`` lies inside the
    grid and is filled with a single character. Every such cell contributes
    ``k - 1`` triangles (sizes ``2 .. k``).

    The grid's shape is read from *w*. Rows may be strings or lists of
    characters.

    Returns the total number of such triangles (0 for an empty grid).
    """
    rows = len(w)
    cols = len(w[0]) if rows else 0
    result = 0
    # DP values of the previous row. Every cell of the first row can only be
    # the tip of a size-1 triangle.
    prev = [1] * cols
    for i in range(1, rows):
        upper, row = w[i - 1], w[i]
        cur = [1] * cols  # column 0 can never extend past size 1
        for j in range(1, cols):
            # A triangle of size k at (i, j) is the cell itself plus triangles
            # of size k - 1 at (i-1, j) and (i-1, j-1), all of one character.
            if row[j] == upper[j] == upper[j - 1]:
                cur[j] = min(prev[j], prev[j - 1]) + 1
                result += cur[j] - 1
        prev = cur
    return result

# Tally triangles in all four orientations: count the current grid, then turn
# it a quarter (rotate also swaps the global n/m). After the fourth turn the
# grid and its dimensions are back to their original orientation.
result = 0
quarter_turns_left = 4
while quarter_turns_left:
    result += count(w)
    w = rotate(w)
    quarter_turns_left -= 1

print(result)
