# BERT bfloat16 inference problems for parameter b = 1, including all the
# relevant post-ops and data type propagation
# 2D problems have M = b * 384
# 4D problems have batch = b x 16
#              
# In total, there are 24 identical fragments in the topology:
#         ____|____ -----.
#        /    |    \     :
#      MM_1  MM_2  MM_3  :
#       |     |   /      :
#       |    MM_4 -------:
#        \   /           :
#         MM_5           :
#           |            :
#         MM_6 ----------`
#           |
#     Layer_norm ---.
#           |       :
#         MM_7      :
#           |       :
#         MM_8 -----`
#           |
#     Layer_norm

# MM_1: 2D matmul, 384x1024 by 1024x1024, all-bf16 with a bf16 bias (mask 2).
# MM_2, MM_3 and MM_6 use this same problem, so it is listed only once; MM_6
# additionally carries a binary post-op in the topology (see the disabled
# MM_6 entry further down).
# NOTE(review): the "*96" in the name presumably counts 4 matmuls x 24
# fragments -- confirm.
--reset --skip-impl=ref
--cfg=bf16bf16bf16
--bia_dt=bf16 --bia_mask=2
--stag=ab
--wtag=any
--dtag=ab
384x1024:1024x1024n"BERT:MM_1*96"

# MM_4: 4D (batched) matmul, 1x16x384x64 by 1x16x64x384, all-bf16, no bias.
# Source and destination use the abcd tag; weights use abdc.
--reset --skip-impl=ref
--cfg=bf16bf16bf16
--stag=abcd
--wtag=abdc
--dtag=abcd
# Disabled post-op, kept for reference:
#--attr-post-ops=add:bf16:per_dim_023
1x16x384x64:1x16x64x384n"BERT:MM_4*24"

# MM_5: 4D (batched) matmul, 1x16x384x384 by 1x16x384x64, all-bf16, no bias.
# Source, weights and destination all use the abcd tag.
--reset --skip-impl=ref
--cfg=bf16bf16bf16
--stag=abcd
--wtag=abcd
--dtag=abcd
1x16x384x384:1x16x384x64n"BERT:MM_5*24"

#--reset
#--skip-impl=ref
#--cfg=bf16bf16bf16 --bia_dt=bf16 --bia_mask=2
#--attr-post-ops=add:bf16:per_tensor
#384x1024:1024x1024n"BERT:MM_6"

# MM_7: 2D matmul, 384x1024 by 1024x4096, all-bf16 with a bf16 bias (mask 2).
# Source and destination use the ab tag; the weights tag is left as "any".
--reset --skip-impl=ref
--cfg=bf16bf16bf16
--bia_dt=bf16 --bia_mask=2
--stag=ab
--wtag=any
--dtag=ab
384x1024:1024x4096n"BERT:MM_7*24"

# MM_8: 2D matmul, 384x4096 by 4096x1024, all-bf16 with a bf16 bias (mask 2).
# Source and destination use the ab tag; the weights tag is left as "any".
--reset --skip-impl=ref
--cfg=bf16bf16bf16
--bia_dt=bf16 --bia_mask=2
--stag=ab
--wtag=any
--dtag=ab
# Disabled post-op, kept for reference:
#--attr-post-ops=add:bf16:per_tensor
384x4096:4096x1024n"BERT:MM_8*24"
